In [None]:
import os
import numpy as np
import cv2
import pandas as pd
import time
from tqdm import tqdm
# from PIL import Image
import pydicom
import torch
import torch.nn as nn
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from sklearn.model_selection import train_test_split
import torch.nn
import torchvision.models as models  
import matplotlib.pyplot as plt
import pre_function as pre
import SimpleITK as sitk
# import nibabel as nib
import joblib
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.tensorboard import SummaryWriter

# TensorBoard writer; its event files go to the ./logs/ directory.
summaryWriter = SummaryWriter(log_dir="./logs/")



In [None]:
# Location of the samples evaluated by this notebook (presumably the
# validation split -- confirm against the data layout).
path = "./vali"


In [None]:
# Intensity statistics handed to transforms.Normalize to standardise PET
# volumes (presumably computed on the training set -- TODO confirm).
normMeanPET, normStdPET = 516.5095631343381, 1782.5365338158515


In [None]:
# Use the first CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Standardise PET intensities: (x - mean) / std with the scalar statistics above.
normalize_pet = transforms.Normalize(mean=normMeanPET, std=normStdPET)

# NOTE(review): torchvision's ToTensor reads an ndarray as H x W x C and
# returns C x H x W, so a 3-D volume leaves these pipelines with its axes
# permuted (identically for PET and mask) -- confirm this matches training.
transform_pet = transforms.Compose([
    transforms.ToTensor(),
    normalize_pet,
])
transform_mask = transforms.Compose([
    transforms.ToTensor(),
])


In [None]:
class PLEDataset(Dataset):
    """Dataset of (PET volume, mask) pairs stored as joblib files.

    Each item is loaded from one file, optionally trimmed along the first
    axis, then zero-padded and centre-cropped so that both arrays have exactly
    ``shape`` elements per axis before the transforms are applied.

    Args:
        data_list: indexable collection of file paths, one per sample. Each
            file must load (via ``joblib.load``) to a sequence whose first two
            entries are the PET array and the mask array.
        shape: target size of the three array axes, e.g. ``[128, 128, 128]``.
        transform_pet: callable applied to the float32 PET array.
        transform_mask: callable applied to the float32 mask array.
    """

    def __init__(self, data_list, shape, transform_pet, transform_mask):
        self.list = data_list
        self.shape = shape
        self.transform_mask = transform_mask
        self.transform_pet = transform_pet

    def __len__(self):
        """Return the number of samples (entries in ``data_list``)."""
        return len(self.list)

    @staticmethod
    def _pad_crop_plan(volume_shape, target_shape):
        """Return ``(pad_width, crop)`` that brings a volume to ``target_shape``.

        Axes shorter than the target are zero-padded symmetrically (the extra
        element, if any, goes after the data); axes longer than the target are
        centre-cropped.
        """
        pad_width = []
        crop = []
        for axis in range(3):
            size, target = volume_shape[axis], target_shape[axis]
            missing = max(target - size, 0)
            before = missing // 2
            pad_width.append((before, missing - before))
            # After padding the axis is at least `target` long; crop the centre.
            start = (size + missing - target) // 2
            crop.append(slice(start, start + target))
        return tuple(pad_width), tuple(crop)

    def __getitem__(self, index):
        """Load sample ``index`` and return ``(pet, mask)`` after the transforms."""
        # NOTE: joblib.load unpickles arbitrary objects -- only load trusted files.
        data = joblib.load(self.list[index])
        pet, mask = data[0], data[1]

        num_layer = pet.shape[0]
        if num_layer > 200:
            # Volumes with more than 200 slices keep a fixed 134-slice window
            # counted from the end of axis 0 (presumably trimming long scans
            # to the region of interest -- TODO confirm the constants).
            window = slice(num_layer - 178, num_layer - 44)
            pet = pet[window, :, :]
            mask = mask[window, :, :]

        # Use the shape given to the constructor, not a notebook-level global.
        # The plan is derived from the PET array and applied to both arrays so
        # they stay voxel-aligned.
        pad_width, crop = self._pad_crop_plan(pet.shape, self.shape)
        img_pet = np.pad(pet, pad_width=pad_width, mode="constant", constant_values=0)[crop]
        img_mask = np.pad(mask, pad_width=pad_width, mode="constant", constant_values=0)[crop]

        img_pet = img_pet.astype(np.float32)
        img_mask = img_mask.astype(np.float32)

        TENSOR_pet = self.transform_pet(img_pet)
        TENSOR_mask = self.transform_mask(img_mask)
        return TENSOR_pet, TENSOR_mask
shape = [128,128,128]
# PLEDataset indexes `data_list` to obtain one file path per sample, so it
# needs the list of files inside `path`; passing the directory string itself
# would make len() the string length and each "sample" a single character.
# Sorted for a deterministic sample order.
# NOTE(review): assumes every regular file in `path` is a joblib sample --
# add an extension filter if the directory holds anything else.
test_files = sorted(
    os.path.join(path, name)
    for name in os.listdir(path)
    if os.path.isfile(os.path.join(path, name))
)
test_set = PLEDataset(test_files,shape,transform_pet,transform_mask)
test_loader = torch.utils.data.DataLoader(test_set,batch_size=2,shuffle=False)

In [None]:
import monai

# 3-D U-Net with one input channel and two output channels.
model = monai.networks.nets.UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(32, 64, 128, 256),
    strides=(2, 2, 2),
).to(device)
# Alternatives tried that ran out of memory: BasicUNet, AttentionUnet, SwinUNETR.
pretrain_path = './output/unet_pre60_meandice/weights/model120.pth'

# Derive the switch from the machine instead of hard-coding False, so the CPU
# fallback already allowed by `device` does not crash on model.cuda().
no_cuda = not torch.cuda.is_available()
gpu_id = [0]
if not no_cuda:
    if len(gpu_id) > 1:
        model = model.cuda()
        model = nn.DataParallel(model, device_ids=gpu_id)
    else:
        # NOTE(review): the model was already moved to `device` above, so CUDA
        # is probably initialised and this variable may no longer take effect
        # in this process -- set it before starting the kernel to be sure.
        os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id[0])
        model = model.cuda()
        model = nn.DataParallel(model, device_ids=None)
net_dict = model.state_dict()

print('loading pretrained model {}'.format(pretrain_path))
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
pretrain = torch.load(pretrain_path, map_location=device)
# Accept both a bare state dict and a checkpoint wrapping it under 'state_dict'.
if isinstance(pretrain, dict) and isinstance(pretrain.get('state_dict'), dict):
    pretrain = pretrain['state_dict']

# nn.DataParallel prefixes every parameter name with 'module.'. Match keys with
# or without that prefix so a checkpoint saved from a wrapped model loads into
# an unwrapped one (and vice versa) instead of being silently dropped.
dp_prefix = 'module.'
pretrain_dict = {}
for ckpt_key, ckpt_value in pretrain.items():
    bare_key = ckpt_key[len(dp_prefix):] if ckpt_key.startswith(dp_prefix) else ckpt_key
    for candidate in (ckpt_key, bare_key, dp_prefix + bare_key):
        if candidate in net_dict:
            pretrain_dict[candidate] = ckpt_value
            break

# Zero matches would leave the network randomly initialised while still
# reporting success; fail loudly instead.
if not pretrain_dict:
    raise RuntimeError(
        'no parameter in {} matches the model; refusing to evaluate '
        'randomly initialised weights'.format(pretrain_path)
    )

net_dict.update(pretrain_dict)
model.load_state_dict(net_dict)
print('loaded {} of {} model tensors from the checkpoint'.format(len(pretrain_dict), len(net_dict)))
print("-------- pre-train model load successfully --------")



In [None]:
from utils import common, logger, loss, metrics, weights_init

# Dice loss from the project's utils.loss module, moved to the compute device.
criterion = loss.DiceLoss()
criterion = criterion.to(device)

# Adam with an exponentially decaying learning rate (factor 0.99 per step).
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
scheduler = ExponentialLR(optimizer=optimizer, gamma=0.99)

# Number of passes made over the loader by the loop below.
num_epochs = 1

In [None]:
from sklearn.metrics import roc_curve

from sklearn.metrics import auc

# Built once outside the loop: the arguments never change between batches.
dice_metric = monai.metrics.DiceHelper(include_background=False, sigmoid=True, softmax=True)

for epoch in range(num_epochs):
    start = time.time()
    val_per_epoch_loss = 0
    val_per_epoch_metrics = 0
    num_batches = 0

    model.eval()
    with torch.no_grad():
        for x,label in tqdm(test_loader):
            x = x.float()
            x = x.to(device)
            label = label.long()
            # Expand the integer mask into two one-hot channels for the loss/metric.
            label = common.to_one_hot_3d(label, 2)
            label = label.to(device)
            # Add the channel axis expected by the network.
            x = torch.unsqueeze(x, dim=1)

            # Forward pass. `logits` and `label` of the last batch stay in scope
            # because the cells below export them.
            logits = model(x)
            # Named batch_loss so the imported `loss` module is not shadowed.
            batch_loss = criterion(logits, label)
            val_per_epoch_loss += batch_loss.item()
            score, not_nans = dice_metric(logits, label)
            val_per_epoch_metrics += score.item()
            num_batches += 1

    if num_batches == 0:
        raise RuntimeError('test_loader produced no batches; nothing was evaluated')

    # Report the per-batch averages that were previously accumulated and dropped.
    # NOTE(review): this is a mean over batches, not over cases -- confirm that
    # is the intended aggregation for the Dice score.
    mean_loss = val_per_epoch_loss / num_batches
    mean_dice = val_per_epoch_metrics / num_batches
    print("val Epoch: {}\t Loss: {:.6f} Dice: {:.6f} Time: {:.1f}s".format(
        epoch, mean_loss, mean_dice, time.time() - start))
    summaryWriter.add_scalar('val/loss', mean_loss, epoch)
    summaryWriter.add_scalar('val/dice', mean_dice, epoch)

    # scheduler.step() was dropped: this loop never calls optimizer.step(), so
    # decaying the learning rate here had no effect on the evaluation.

In [None]:
# Copy the last batch's raw network output and one-hot target to host memory
# as NumPy arrays.
output, target = (tensor.cpu().numpy() for tensor in (logits, label))

In [None]:
# First sample of the batch, one volume per channel (axis 1). Channel 0 is
# presumably background and channel 1 foreground -- TODO confirm.
output1, output2 = output[0, 0, :, :, :], output[0, 1, :, :, :]
target1, target2 = target[0, 0, :, :, :], target[0, 1, :, :, :]

In [None]:
# Wrap each volume in a SimpleITK image, then write all four to the working
# directory as compressed NIfTI files.
img1, img2, img3, img4 = (
    sitk.GetImageFromArray(volume)
    for volume in (output1, output2, target1, target2)
)
for image, filename in (
    (img1, './output1.nii.gz'),
    (img2, './output2.nii.gz'),
    (img3, './target1.nii.gz'),
    (img4, './target2.nii.gz'),
):
    sitk.WriteImage(image, filename)
