#Start

In [None]:
import os
# Directory the notebook was started from; later cells build cache paths from it.
current_path = os.getcwd()
# Show the GPU(s) visible to this runtime (IPython shell escape, notebook-only syntax).
!nvidia-smi

In [None]:
import os
current_path = os.getcwd()
from google.colab import drive
drive.mount('/content/drive')
!pip install simpleITK==1.2.4;
!pip install -q "monai-weekly[gdown, nibabel, tqdm, itk]"
import monai
import os
import shutil
import tempfile
import matplotlib.pyplot as plt
from PIL import Image
import torch
import numpy as np
from sklearn.metrics import classification_report
from monai.apps import download_and_extract
from monai.config import print_config
from monai.metrics import ROCAUCMetric
from monai.networks.nets import DenseNet121
from monai.transforms import AddChannel, Compose, RandAffine, RandRotate90, RandFlip, apply_transform, ToTensor
from monai.data import Dataset, DataLoader
from monai.utils import set_determinism
from monai.transforms import \
    LoadImageD, EnsureChannelFirstD, AddChannelD, ScaleIntensityD, ToTensorD, Compose, \
    AsDiscreteD, SpacingD, OrientationD, ResizeD, RandAffineD
def _make_affine_pipeline(keys):
    """Build a load -> random affine -> tensor pipeline for an (image, label) key pair.

    ``mode`` is given per key in ``keys`` order: the image is resampled
    bilinearly and the label with nearest neighbour. The affine is applied
    with probability 0.2 and uses up to +/- pi/2 rotation and +/- 10 % scaling
    on each axis.
    """
    return Compose([
        LoadImageD(keys),
        RandAffineD(
            keys,
            rotate_range=(np.pi / 2,) * 3,
            scale_range=(0.1,) * 3,
            mode=('bilinear', 'nearest'),
            prob=0.2,
        ),
        # Previously tried and disabled: translate_range=(20, 20, 20) and
        # RandFlipD(keys, prob=0.5).
        ToTensorD(keys),
    ])


# Dictionary keys for the standard (local) patch pair ...
KEYS = ("img", "seg")
xform = _make_affine_pipeline(KEYS)

# ... and for the expanded (global) patch pair.
KEYS_G = ("img_global", "seg_global")
xform_G = _make_affine_pipeline(KEYS_G)

#Pre-processing

In [None]:
import pandas as pd
import numpy as np
import SimpleITK as sitk
import matplotlib.pyplot as plt
from ipywidgets import interact, fixed
from IPython.display import display
import gc
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import gc
# Calculate parameters low and high from window and level
def wl_to_lh(window, level):
    """Convert a display window/level pair into (low, high) intensity limits.

    The returned interval is ``window`` wide and centred on ``level``.
    """
    half_width = window / 2
    return level - half_width, level + half_width

def display_image(img, x=None, y=None, z=None, window=None, level=None, colormap='gray', crosshair=False,figure_name=''):
    """Plot three orthogonal slices of a 3-D SimpleITK image side by side.

    Args:
        img: SimpleITK image.
        x, y, z: voxel index of each slice; ``None`` selects the middle of
            that axis.
        window: display window width; defaults to the full intensity range.
        level: display window centre; defaults to ``min + window / 2``.
        colormap: matplotlib colormap name.
        crosshair: when True, draw lines marking the selected (x, y, z)
            position on every view.
        figure_name: when non-empty, the figure is also saved as
            ``<figure_name>.jpg`` inside the module-level ``figure_dir``.
    """
    arr = sitk.GetArrayFromImage(img)  # NumPy axes are ordered (z, y, x)

    size = img.GetSize()
    spacing = img.GetSpacing()
    # Physical extent of each axis (voxel count times spacing).
    width, height, depth = (size[i] * spacing[i] for i in range(3))

    x = np.floor(size[0] / 2).astype(int) if x is None else x
    y = np.floor(size[1] / 2).astype(int) if y is None else y
    z = np.floor(size[2] / 2).astype(int) if z is None else z

    if window is None:
        window = np.max(arr) - np.min(arr)
    if level is None:
        level = window / 2 + np.min(arr)
    clim = wl_to_lh(window, level)

    fig, axes = plt.subplots(1, 3, figsize=(10, 4))
    panels = (
        (arr[z, :, :], {'extent': (0, width, height, 0)}),
        (arr[:, y, :], {'origin': 'lower', 'extent': (0, width, 0, depth)}),
        (arr[:, :, x], {'origin': 'lower', 'extent': (0, height, 0, depth)}),
    )
    for ax, (plane, opts) in zip(axes, panels):
        ax.imshow(plane, cmap=colormap, clim=clim, **opts)

    if crosshair:
        # (horizontal, vertical) line positions per panel, in physical units.
        marks = (
            (y * spacing[1], x * spacing[0]),
            (z * spacing[2], x * spacing[0]),
            (z * spacing[2], y * spacing[1]),
        )
        for ax, (horizontal, vertical) in zip(axes, marks):
            ax.axhline(horizontal, lw=1)
            ax.axvline(vertical, lw=1)

    if figure_name != '':
        plt.savefig(os.path.join(figure_dir, str(figure_name) + '.jpg'))
    plt.show()
    plt.close()

def interactive_view(img):
    """Show ``display_image`` with ipywidgets sliders for slice indices and window/level."""
    arr = sitk.GetArrayFromImage(img)
    lowest, highest = np.min(arr), np.max(arr)
    nx, ny, nz = (img.GetSize()[i] for i in range(3))
    interact(
        display_image,
        img=fixed(img),
        x=(0, nx - 1),
        y=(0, ny - 1),
        z=(0, nz - 1),
        window=(0, highest - lowest),
        level=(lowest, highest),
    )

def make_one_hot(vol, mask):
    """One-hot encode a label tensor along dimension 1.

    Args:
        vol: tensor whose dimension 1 has size 1 (presumably (B, 1, ...)
            label volumes -- confirm against callers).
        mask: sequence of label values; output channel ``i`` is 1 where
            ``vol == mask[i]`` and 0 elsewhere.

    Returns:
        Tensor with ``vol``'s dtype and device, where dimension 1 has size
        ``len(mask)``.
    """
    out_shape = list(vol.shape)
    out_shape[1] = len(mask)
    one_hot = vol.new_zeros(out_shape)
    for channel, label in enumerate(mask):
        one_hot[:, channel] = (vol == label).squeeze(1)
    return one_hot

In [None]:
def zero_mean_unit_var(image, mask):
    """Normalise an image to zero mean and unit variance over a mask.

    Mean and standard deviation are computed over voxels where ``mask > 0``.
    If the masked standard deviation is positive, the image is standardised
    and voxels where ``mask == 0`` are set to 0. Otherwise (constant masked
    intensities, or an empty mask) the intensities are returned unchanged,
    cast to float32.

    Args:
        image: SimpleITK image to normalise.
        mask: SimpleITK image, expected to be on the same voxel grid as ``image``.

    Returns:
        A float32 SimpleITK image with ``image``'s meta-information copied
        via ``CopyInformation``.
    """
    img_array = sitk.GetArrayFromImage(image).astype(np.float32)
    msk_array = sitk.GetArrayFromImage(mask)
    foreground = msk_array > 0

    # Skip an empty mask explicitly: np.mean/np.std over an empty selection
    # return NaN and emit RuntimeWarnings.
    if foreground.any():
        values = img_array[foreground]
        mean = np.mean(values)
        std = np.std(values)
        if std > 0:
            img_array = (img_array - mean) / std
            img_array[msk_array == 0] = 0

    image_normalised = sitk.GetImageFromArray(img_array)
    image_normalised.CopyInformation(image)
    return image_normalised


def resample_image(image, out_spacing=(1.0, 1.0, 1.0), out_size=None, is_label=False, pad_value=0):
    """Resample an image onto a new voxel grid with the same centre and direction.

    Args:
        image: SimpleITK image to resample.
        out_spacing: output voxel spacing, one value per axis (a scalar is
            broadcast to all axes).
        out_size: output size in voxels; when None it is chosen so the
            physical extent is preserved (``size * spacing / out_spacing``,
            rounded).
        is_label: use nearest-neighbour interpolation (for label maps)
            instead of B-spline.
        pad_value: value for output voxels that fall outside the input.

    Returns:
        The resampled SimpleITK image. Its origin is shifted so that the
        image centre stays at the same physical position.
    """
    original_spacing = np.array(image.GetSpacing())
    original_size = np.array(image.GetSize())
    out_spacing = np.broadcast_to(np.asarray(out_spacing, dtype=float), original_spacing.shape)

    if out_size is None:
        out_size = np.round(original_size * original_spacing / out_spacing).astype(int)
    else:
        out_size = np.array(out_size)

    # Keep the physical centre fixed: compute both centres (in index space
    # scaled by spacing), rotate them by the image direction, and offset the origin.
    direction = np.array(image.GetDirection()).reshape(len(original_spacing), -1)
    original_center = direction @ ((original_size.astype(float) - 1.0) / 2.0 * original_spacing)
    out_center = direction @ ((out_size.astype(float) - 1.0) / 2.0 * out_spacing)
    out_origin = np.array(image.GetOrigin()) + (original_center - out_center)

    resample = sitk.ResampleImageFilter()
    resample.SetOutputSpacing(out_spacing.tolist())
    resample.SetSize(out_size.tolist())
    resample.SetOutputDirection(image.GetDirection())
    resample.SetOutputOrigin(out_origin.tolist())
    resample.SetTransform(sitk.Transform())
    resample.SetDefaultPixelValue(pad_value)
    resample.SetInterpolator(sitk.sitkNearestNeighbor if is_label else sitk.sitkBSpline)

    return resample.Execute(image)

def prepadding(data,kernel_c,kernel_h,kernel_w,step):
    """Zero-pad a (batch, depth, height, width) tensor so patches tile it exactly.

    Each spatial axis of length ``n`` is padded so that windows of size ``k``
    taken every ``step`` elements end exactly on the last element, i.e.
    ``(n_padded - k) % step == 0`` and ``n_padded >= k``. The padding is split
    between both sides, with the extra element (odd totals) on the far side.

    Args:
        data: tensor of shape (batch, depth, height, width).
        kernel_c, kernel_h, kernel_w: patch size along depth, height, width.
        step: stride between patches (shared by all axes).

    Returns:
        A new zero-filled tensor with ``data``'s dtype and device, with
        ``data`` copied into its interior.
    """

    def cal_padding(n, k, stride):
        """Return the (before, after) padding for one axis of length n."""
        # Length covered by the windows that fit; int() truncates toward zero,
        # which is what the original formula relied on for n < k.
        covered = k + int((n - k) / stride) * stride
        total = stride - (n - covered)
        if total == stride:  # windows already end exactly on the last element
            total = 0
        # Never leave an axis shorter than one kernel, otherwise unfold fails
        # (happens when n < k and k - n is a multiple of the stride).
        total = max(total, k - n)
        before = int(total / 2)
        return before, total - before

    batch_size, depth, height, width = data.shape[:4]
    pad_front, pad_back = cal_padding(depth, kernel_c, step)
    pad_top, pad_bottom = cal_padding(height, kernel_h, step)
    pad_left, pad_right = cal_padding(width, kernel_w, step)

    output = torch.zeros(
        (batch_size,
         depth + pad_front + pad_back,
         height + pad_top + pad_bottom,
         width + pad_left + pad_right),
        dtype=data.dtype,
        device=data.device,
    )
    output[:, pad_front:pad_front + depth, pad_top:pad_top + height, pad_left:pad_left + width] = data
    return output



def get_patches(data,kernel_c=32, kernel_h=32, kernel_w=32,step=32):
    """Split a (batch, depth, height, width) tensor into flattened 3-D patches.

    Same as ``get_patches_and_bound`` but discards the patch-grid shape.

    Returns:
        Tensor of shape (n_patches * batch, kernel_c, kernel_h, kernel_w).
    """
    patches, _grid = get_patches_and_bound(data, kernel_c, kernel_h, kernel_w, step)
    return patches

def get_patches_and_bound(data,kernel_c=32, kernel_h=32, kernel_w=32,step=32):
    """Split a (batch, depth, height, width) tensor into 3-D patches and report the grid.

    The input is first zero-padded with ``prepadding`` so the windows tile it.

    Returns:
        (patches, grid): ``patches`` has shape
        (n_d * n_h * n_w * batch, kernel_c, kernel_h, kernel_w), ordered with
        depth slowest, then height, width, and batch fastest. ``grid`` is the
        ``torch.Size`` (batch, n_d, n_h, n_w).
    """
    padded = prepadding(data, kernel_c, kernel_h, kernel_w, step)
    gc.collect()
    windows = (
        padded.unfold(1, kernel_c, step)
              .unfold(2, kernel_h, step)
              .unfold(3, kernel_w, step)
    )
    grid = windows.shape[:4]
    patches = windows.permute(1, 2, 3, 0, 4, 5, 6).reshape(-1, kernel_c, kernel_h, kernel_w)
    return patches, grid

#find patch locations
def find_location(index_patches,shape1,step):
    """Return the (depth, height, width) start offset of a flattened patch index.

    Patches are numbered depth-slowest / width-fastest, matching the ordering
    produced by ``get_patches_and_bound`` for a batch of one.

    Args:
        index_patches: flattened patch index.
        shape1: patch-grid shape (batch, n_d, n_h, n_w); only entries 1-3
            are used.
        step: stride between patch origins.

    Returns:
        ``(d, h, w)`` voxel offsets, or None if the index is outside the grid.
    """
    n_d, n_h, n_w = shape1[1], shape1[2], shape1[3]
    # Closed form replaces the former linear scan, which made the per-patch
    # loops in the dataset quadratic in the number of patches.
    if not 0 <= index_patches < n_d * n_h * n_w:
        return None
    c, rem = divmod(index_patches, n_h * n_w)
    h, w = divmod(rem, n_w)
    return (c * step, h * step, w * step)
def _crop_and_pad(x, starts, lengths, expand_before, expand_after):
    """Crop an expanded window from a 5-D tensor, zero-padding where it leaves the volume.

    Along each of the last three axes the window spans
    ``[start - before, start + length + after)``. Any part of the window
    outside ``[0, size)`` is filled with zeros, so for start offsets inside
    the volume the output has ``before + length + after`` elements per axis.

    Args:
        x: tensor of shape (N, C, D, H, W).
        starts, lengths, expand_before, expand_after: 3-tuples for (D, H, W).
    """
    index = [slice(None), slice(None)]
    pads = []
    for start, length, before, after, size in zip(
            starts, lengths, expand_before, expand_after, x.shape[2:5]):
        lo = start - before
        hi = start + length + after
        index.append(slice(max(lo, 0), min(hi, size)))
        pads.append((max(-lo, 0), max(hi - size, 0)))
    crop = x[tuple(index)]
    # F.pad takes (before, after) pairs starting from the LAST dimension.
    pad_dims = tuple(p for pair in reversed(pads) for p in pair)
    return F.pad(crop, pad_dims, "constant")


def crop_and_pad_2x(x,ld,lh,lw,len_d,len_h,len_w):
    """Crop a 2x-sized context window around the patch starting at (ld, lh, lw).

    The (len_d, len_h, len_w) patch is grown by ``int(len / 2)`` before and
    ``len - int(len / 2)`` after on each axis, giving a
    (2*len_d, 2*len_h, 2*len_w) window that is zero-padded where it extends
    past the volume.

    Args:
        x: tensor of shape (N, C, D, H, W).
        ld, lh, lw: patch start offsets.
        len_d, len_h, len_w: patch size.
    """
    lengths = (len_d, len_h, len_w)
    before = tuple(int(n / 2) for n in lengths)
    after = tuple(n - b for n, b in zip(lengths, before))
    return _crop_and_pad(x, (ld, lh, lw), lengths, before, after)


def crop_and_pad_1_5x(x,ld,lh,lw,len_d,len_h,len_w):
    """Crop a 1.5x-sized context window around the patch starting at (ld, lh, lw).

    The (len_d, len_h, len_w) patch is grown by ``int(len / 4)`` before and
    ``int(len / 2) - int(len / 4)`` after on each axis, giving a window of
    ``len + int(len / 2)`` elements per axis that is zero-padded where it
    extends past the volume.

    Args:
        x: tensor of shape (N, C, D, H, W).
        ld, lh, lw: patch start offsets.
        len_d, len_h, len_w: patch size.
    """
    lengths = (len_d, len_h, len_w)
    before = tuple(int(n / 4) for n in lengths)
    after = tuple(int(n / 2) - b for n, b in zip(lengths, before))
    return _crop_and_pad(x, (ld, lh, lw), lengths, before, after)

### dataset 

In [None]:
class ImageSegmentationDataset_separate(Dataset):
    """Patch-based dataset of paired image / label volumes cached on disk as .npy files.

    On construction each (image, label) volume pair is split into patches of
    size (kernel_c, kernel_h, kernel_w) taken every ``step`` voxels, and every
    kept patch is written to
    ``<working_path>/<train|val_test>/<img|seg>/<volume index>/<patch index>.npy``.

    In training mode (``TEST=False``):
      * patches whose background (label 0) fraction is >= ``threshold`` are
        kept only when a pre-drawn uniform number is below ``roulette``;
      * a 1.5x-sized context ("global") patch around every kept patch is also
        saved under ``img_global`` / ``seg_global``;
      * ``__getitem__`` applies the random augmentations ``xform`` / ``xform_G``.
    In test mode every patch is kept, no global patches are produced and
    ``__getitem__`` returns the raw patches with placeholder global tensors.

    Relies on notebook globals defined elsewhere: ``device``, ``working_path``,
    ``threshold``, ``roulette``, ``num_classes``, ``xform`` and ``xform_G``.

    Args:
        file_list_img, file_list_seg: parallel lists of image / label paths.
        img_spacing, img_size, patch: accepted for compatibility; unused
            (resampling is disabled).
        kernel_c, kernel_h, kernel_w: patch size along depth, height, width.
        step: stride between patch origins.
        TEST: build a validation/test dataset (see above).
    """

    # Number of uniform draws generated at a time for the background roulette.
    _ROULETTE_BATCH = 50000

    def __init__(self, file_list_img, file_list_seg, img_spacing, img_size,
                 patch=True, kernel_c=64, kernel_h=64, kernel_w=64, step=32, TEST=False):
        self.imgs = []
        self.imgs_global = []
        self.segs = []
        self.segs_global = []
        # Source file name of every kept patch (index-aligned with imgs/segs).
        self.img_names = []
        self.seg_names = []
        # Voxel count per label value over all kept patches (sized 9; only the
        # first ``num_classes`` entries are filled).
        self.counts = torch.zeros(9)
        self.test = TEST
        self.datasetname = 'train' if not TEST else 'val_test'

        # Pre-drawn uniform numbers for the background roulette; the fixed seed
        # makes the set of kept patches reproducible.
        torch.manual_seed(1928)
        roulette_fix = list(torch.rand(self._ROULETTE_BATCH))

        with torch.no_grad():
            for idx in tqdm(range(len(file_list_img)), desc='Loading Data'):
                valid_index = self._save_seg_patches(
                    idx, file_list_seg[idx], kernel_c, kernel_h, kernel_w, step, roulette_fix)
                self._save_img_patches(
                    idx, file_list_img[idx], valid_index, kernel_c, kernel_h, kernel_w, step)

    def _save_patch(self, subdir, idx, index_patches, sample):
        """Save one CPU patch as <working_path>/<datasetname>/<subdir>/<idx>/<index>.npy; return the path."""
        folder = os.path.join(working_path, self.datasetname, subdir, str(idx))
        os.makedirs(folder, exist_ok=True)
        path = os.path.join(folder, str(index_patches) + '.npy')
        np.save(path, sample)
        return path

    def _save_seg_patches(self, idx, seg_path, kernel_c, kernel_h, kernel_w, step, roulette_fix):
        """Patch one label volume, save the kept patches and return their indices as a set.

        ``roulette_fix`` is consumed (popped) in place and refilled when empty.
        """
        valid_index = set()
        seg = sitk.ReadImage(seg_path, sitk.sitkInt8)
        # Whole label volume as a (1, *array_shape) tensor on the working device.
        seg_whole = torch.from_numpy(sitk.GetArrayFromImage(seg)).unsqueeze(0).to(device)
        seg, shape1 = get_patches_and_bound(seg_whole, kernel_c, kernel_h, kernel_w, step)
        seg_whole_pad = None
        if not self.test:
            # Padded whole volume with an extra leading dim, for crop_and_pad_1_5x.
            seg_whole_pad = prepadding(seg_whole, kernel_c, kernel_h, kernel_w, step).unsqueeze(0)
        del seg_whole

        for index_patches in range(seg.shape[0]):
            patch = seg[index_patches]
            background = torch.sum(patch == 0) / patch.numel()
            if background >= threshold and not self.test:
                if not roulette_fix:
                    # Previously an IndexError after 50000 background patches.
                    roulette_fix.extend(torch.rand(self._ROULETTE_BATCH))
                if not roulette_fix.pop() < roulette:
                    continue

            for i in range(num_classes):
                self.counts[i] += torch.sum(patch == i).cpu()

            valid_index.add(index_patches)
            self.segs.append(self._save_patch('seg', idx, index_patches, patch.unsqueeze(0).cpu()))

            if not self.test:
                ld, lh, lw = find_location(index_patches, shape1, step)
                seg_global = crop_and_pad_1_5x(seg_whole_pad, ld, lh, lw, kernel_c, kernel_h, kernel_w).char()
                self.segs_global.append(
                    self._save_patch('seg_global', idx, index_patches, seg_global.squeeze(0).cpu()))
                del seg_global

            self.seg_names.append(os.path.basename(seg_path))

        # Release the (possibly GPU-resident) patch tensors before loading the image.
        del seg, seg_whole_pad
        gc.collect()
        return valid_index

    def _save_img_patches(self, idx, img_path, valid_index, kernel_c, kernel_h, kernel_w, step):
        """Patch one image volume and save the patches whose index is in ``valid_index``."""
        img = sitk.ReadImage(img_path, sitk.sitkUInt8)
        img_whole = torch.from_numpy(sitk.GetArrayFromImage(img)).unsqueeze(0).to(device)
        img, shape1 = get_patches_and_bound(img_whole, kernel_c, kernel_h, kernel_w, step)
        img_whole_pad = None
        if not self.test:
            img_whole_pad = prepadding(img_whole, kernel_c, kernel_h, kernel_w, step).unsqueeze(0)
        del img_whole
        gc.collect()

        for index_patches in range(img.shape[0]):
            if index_patches not in valid_index:
                continue

            self.imgs.append(
                self._save_patch('img', idx, index_patches, img[index_patches].unsqueeze(0).cpu()))

            if not self.test:
                ld, lh, lw = find_location(index_patches, shape1, step)
                img_global = crop_and_pad_1_5x(img_whole_pad, ld, lh, lw, kernel_c, kernel_h, kernel_w).byte()
                self.imgs_global.append(
                    self._save_patch('img_global', idx, index_patches, img_global.squeeze(0).cpu()))
                del img_global

            self.img_names.append(os.path.basename(img_path))

        del img, img_whole_pad
        gc.collect()

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, item):
        """Return a dict with 'img', 'seg', 'img_global' and 'seg_global' tensors for patch ``item``."""
        if not self.test:
            # NOTE(review): the local and global pairs go through separate random
            # affines, so they are not spatially aligned after augmentation --
            # confirm this is intended.
            local = xform({'img': self.imgs[item], 'seg': self.segs[item]})
            img = local['img'].byte()  # uint8
            seg = local['seg'].char()  # int8
            glob = xform_G({'img_global': self.imgs_global[item],
                            'seg_global': self.segs_global[item]})
            img_global = glob['img_global'].byte()  # uint8
            # Halve the global label resolution; nearest keeps label values discrete.
            seg_global = F.interpolate(glob['seg_global'].unsqueeze(0), scale_factor=0.5, mode='nearest',
                                       recompute_scale_factor=False).squeeze(0).char()
            return {'img': img, 'seg': seg, 'img_global': img_global, 'seg_global': seg_global}

        # Test mode: raw patches, placeholder global tensors (none were generated).
        img = torch.from_numpy(np.load(self.imgs[item]))
        seg = torch.from_numpy(np.load(self.segs[item]))
        img_global = torch.zeros(1)
        seg_global = torch.zeros(1)
        return {'img': img, 'seg': seg, 'img_global': img_global, 'seg_global': seg_global}

    def get_sample(self, item):
        """Load the un-augmented patches for ``item`` directly from disk."""
        img = torch.from_numpy(np.load(self.imgs[item]))
        seg = torch.from_numpy(np.load(self.segs[item]))
        if self.test:
            # Global patches are not generated for TEST datasets; return the same
            # placeholders as __getitem__ instead of raising IndexError.
            img_global = torch.zeros(1)
            seg_global = torch.zeros(1)
        else:
            img_global = torch.from_numpy(np.load(self.imgs_global[item]))
            seg_global = torch.from_numpy(np.load(self.segs_global[item]))
        return {'img': img, 'seg': seg, 'img_global': img_global, 'seg_global': seg_global}

    def get_img_name(self, item):
        return self.img_names[item]

    def get_seg_name(self, item):
        return self.seg_names[item]


In [None]:

cuda_dev = '0'  # GPU index as a string; change it if several GPUs are available

use_cuda = torch.cuda.is_available()
device = torch.device(f"cuda:{cuda_dev}") if use_cuda else torch.device("cpu")

print(f"Device: {device}")
if use_cuda:
    print(f"GPU: {torch.cuda.get_device_name(int(cuda_dev))}")


### Hyperparameters for training and validation


In [None]:
import os
import sys
import argparse
# --- fixed settings (not exposed on the command line) ---
rnd_seed = 42  # fixed random seed
load_checkpoint = True
img_size = [470, 470, 1072]
img_spacing = [1, 1, 1]
val_interval = 1
num_classes = 3
use_amp = True

# --- command-line options (further options are registered later in the notebook) ---
_TRUE_FALSE = ['True', 'False']
parser = argparse.ArgumentParser()
parser.add_argument("--load_checkpoint", default='False', type=str, choices=_TRUE_FALSE)
parser.add_argument("--num_epochs", default=400, type=int)
parser.add_argument("--learning_rate", default=0.0009, type=float)
parser.add_argument("--batch_size", default=4, type=int)
parser.add_argument(
    "--criterion_type",
    default='W_D',
    type=str,
    choices=['CrossEntropy', 'WeightedCrossEntropy', 'WeightedCrossEntropyAdaptive',
             'FocalLoss', 'DiceLoss', 'W_D'],
)
parser.add_argument("--use_amp", default='True', type=str, choices=_TRUE_FALSE)
parser.add_argument("--out_dir", default='/content/drive/MyDrive/Computing_Project/publish/meUnet', type=str)

In [None]:
# Patch geometry: patch size along depth (c), height (h) and width (w).
for _axis in ('c', 'h', 'w'):
    parser.add_argument(f"--kernel_{_axis}", type=int, default=160)
# Stride between patch origins.
parser.add_argument("--step", type=int, default=80)
# Background-patch sampling for the training set: a patch whose background
# fraction is >= threshold is kept only with probability `roulette`.
parser.add_argument("--threshold", type=float, default=0.98)
parser.add_argument("--roulette", type=float, default=0.05)
# Dataset selection.
parser.add_argument("--label_type", type=str, choices=['all'], default='all')
parser.add_argument("--data_type", type=str, choices=['all_data', 'male', 'mated', 'virgin'], default='all_data')
parser.add_argument("--labelset", type=str, choices=['original', '1072_470', 'gut'], default='gut')

In [None]:
# Root folder holding the per-group image and segmentation sub-directories.
parser.add_argument(
    "--data_dir_root",
    default='/content/drive/MyDrive/Computing_Project/fruit_fly',
    type=str,
)

##parse

In [None]:
try:
    # Script run: parse the real command line.
    args = parser.parse_args()
except SystemExit as exc:
    # argparse exits on '--help' (code 0): honour that instead of continuing.
    if exc.code == 0:
        raise
    # Jupyter/Colab: sys.argv holds the kernel's own flags (e.g. "-f <file>"),
    # which argparse rejects by exiting; fall back to the defaults. Only this
    # parse failure is caught, so errors in the script-only setup below are
    # no longer mistaken for "running in Jupyter".
    print('This is jupyter')
    args = parser.parse_args(args=[])
    data_dir_root = args.data_dir_root
else:
    print('This is py')
    # Cluster locations for script runs; these take precedence over
    # --data_dir_root and the notebook's working directory.
    data_dir_root = '/rds/general/user/yw720/home/fruit_fly_google_drive'
    current_path = '/rds/general/user/yw720/ephemeral'
    import SimpleITK as sitk
    MAX_THREADS = 8
    sitk.ProcessObject.SetGlobalDefaultNumberOfThreads(MAX_THREADS)
print(args)

# NOTE(review): load_checkpoint is already True from the settings cell, so this
# flag can only confirm it -- '--load_checkpoint False' has no effect. Confirm intent.
if args.load_checkpoint == 'True':
    load_checkpoint = True
num_epochs = args.num_epochs
learning_rate = args.learning_rate
batch_size = args.batch_size
criterion_type = args.criterion_type
if args.use_amp == 'False':
    use_amp = False

out_dir = args.out_dir
os.makedirs(out_dir, exist_ok=True)
figure_dir = os.path.join(out_dir, 'Figure_separate2')
os.makedirs(figure_dir, exist_ok=True)

kernel_c, kernel_h, kernel_w = args.kernel_c, args.kernel_h, args.kernel_w
step = args.step
threshold = args.threshold
roulette = args.roulette
label_type = args.label_type
data_type = args.data_type
labelset = args.labelset

In [None]:
# Scratch directory for the cached .npy patches: 'save' + the last (up to) 7
# characters of the output folder name. Taking them from the basename keeps path
# separators out of the name (the old out_dir[-8:-1] slice yielded 'saveh/meUne'
# for the default out_dir); for an out_dir given with a trailing '/', the result
# is unchanged.
working_path = os.path.join(current_path, 'save' + os.path.basename(os.path.normpath(out_dir))[-7:])

#Load data

##load dir

Load male data directories

In [None]:
# (segmentation sub-directory, image sub-directory) under data_dir_root per label set.
_male_dirs = {
    'original': ('/male_seg/', '/male_microCT/'),
    '1072_470': ('/male_seg_1072_470/', '/male_1072_470/'),
    'gut': ('/male_seg_gut/', '/male_gut/'),
}
if labelset not in _male_dirs:
    # Fail loudly instead of silently reusing directories from an earlier cell.
    raise ValueError(f"Unknown labelset {labelset!r}; expected one of {sorted(_male_dirs)}")
data_dir_combination = data_dir_root + _male_dirs[labelset][0]
data_dir = data_dir_root + _male_dirs[labelset][1]

# Specimens M1A..M1E followed by M2A..M2E.
_male_ids = [f'M{group}{letter}' for group in (1, 2) for letter in 'ABCDE']
list_male_seg = [data_dir_combination + name + '.nii.gz' for name in _male_ids]
list_male = [data_dir + name + '.nii.gz' for name in _male_ids]


Load mated data directories

In [None]:
# (segmentation sub-directory, image sub-directory) under data_dir_root per label set.
_mated_dirs = {
    'original': ('/mated_seg/', '/mated_microCT/'),
    '1072_470': ('/mated_seg_1072_470/', '/mated_1072_470/'),
    'gut': ('/mated_seg_gut/', '/mated_gut/'),
}
if labelset not in _mated_dirs:
    # Fail loudly instead of silently reusing the previous cell's directories.
    raise ValueError(f"Unknown labelset {labelset!r}; expected one of {sorted(_mated_dirs)}")
data_dir_combination = data_dir_root + _mated_dirs[labelset][0]
data_dir = data_dir_root + _mated_dirs[labelset][1]

# Specimens F1A..F1E followed by F2A..F2E.
_mated_ids = [f'F{group}{letter}' for group in (1, 2) for letter in 'ABCDE']
list_mated_seg = [data_dir_combination + name + '.nii.gz' for name in _mated_ids]
list_mated = [data_dir + name + '.nii.gz' for name in _mated_ids]


Load virgin data directories

In [None]:
# (segmentation sub-directory, image sub-directory) under data_dir_root per label set.
_virgin_dirs = {
    'original': ('/virgin_seg/', '/virgin_microCT/'),
    '1072_470': ('/virgin_seg_1072_470/', '/virgin_1072_470/'),
    'gut': ('/virgin_seg_gut/', '/virgin_gut/'),
}
if labelset not in _virgin_dirs:
    # Fail loudly instead of silently reusing the previous cell's directories.
    raise ValueError(f"Unknown labelset {labelset!r}; expected one of {sorted(_virgin_dirs)}")
data_dir_combination = data_dir_root + _virgin_dirs[labelset][0]
data_dir = data_dir_root + _virgin_dirs[labelset][1]

# Specimens V1A..V1E followed by V2A..V2E.
_virgin_ids = [f'V{group}{letter}' for group in (1, 2) for letter in 'ABCDE']
list_virgin_seg = [data_dir_combination + name + '.nii.gz' for name in _virgin_ids]
list_virgin = [data_dir + name + '.nii.gz' for name in _virgin_ids]


Load training and validation file lists

In [None]:
# Per-group split of the 10 specimens: indices 0-4 and 7-8 for training and
# index 9 for validation; indices 5-6 are used by neither list here.
# 'all_data' concatenates the groups in the order male, mated, virgin.
# NOTE(review): single-group data_types previously left these lists undefined;
# they now reuse the same per-group indices -- confirm this is the intended split.
_groups = {
    'male': (list_male, list_male_seg),
    'mated': (list_mated, list_mated_seg),
    'virgin': (list_virgin, list_virgin_seg),
}
if data_type == 'all_data':
    _selected = ['male', 'mated', 'virgin']
elif data_type in _groups:
    _selected = [data_type]
else:
    raise ValueError(f"Unknown data_type {data_type!r}")

files_seg_img_train, files_seg_seg_train = [], []
files_seg_img_val, files_seg_seg_val = [], []
for _name in _selected:
    _imgs, _segs = _groups[_name]
    files_seg_img_train += _imgs[0:5] + _imgs[7:9]
    files_seg_seg_train += _segs[0:5] + _segs[7:9]
    files_seg_img_val += _imgs[9:10]
    files_seg_seg_val += _segs[9:10]

##Load training and validation data

In [None]:
# LOAD TRAINING DATA separate
# Training set: patches of (kernel_c, kernel_h, kernel_w) every `step` voxels,
# with background roulette and 1.5x global-context patches.
dataset_train = ImageSegmentationDataset_separate(
    files_seg_img_train, files_seg_seg_train, img_spacing, img_size,
    patch=True, kernel_c=kernel_c, kernel_h=kernel_h, kernel_w=kernel_w,
    step=step, TEST=False,
)
dataloader_train = torch.utils.data.DataLoader(
    dataset_train, batch_size=batch_size, shuffle=True,
    num_workers=4, prefetch_factor=2, pin_memory=True,
)

# Validation set: patches 1.5x the training patch size, stride int(kernel_c * 1.5),
# every patch kept and no global patches (TEST=True).
_val_patch = [int(k * 1.5) for k in (kernel_c, kernel_h, kernel_w)]
dataset_val = ImageSegmentationDataset_separate(
    files_seg_img_val, files_seg_seg_val, img_spacing, img_size,
    patch=True, kernel_c=_val_patch[0], kernel_h=_val_patch[1], kernel_w=_val_patch[2],
    step=int(kernel_c * 1.5), TEST=True,
)
dataloader_val = torch.utils.data.DataLoader(
    dataset_val, batch_size=batch_size, shuffle=False,
    num_workers=4, prefetch_factor=2, pin_memory=True,
)


#Train

##model: meU-net

In [None]:
import torch.utils.checkpoint as cp
class DummyLayer(nn.Module):
    """Identity layer that ties its output to one trainable parameter.

    Its output requires grad even when the input does not (e.g. a tensor
    computed under ``torch.no_grad()``), which the checkpointed segments that
    follow it in the model need.
    """

    def __init__(self):
        super().__init__()
        self.dummy = nn.Parameter(torch.ones(1, dtype=torch.float32))

    def forward(self, x):
        # Add an exact zero that still depends on the parameter. The former
        # ``x + self.dummy - self.dummy`` evaluates ``(x + 1) - 1`` and rounds
        # small activations away (e.g. 1e-8 became 0 in float32).
        return x + (self.dummy - self.dummy)

def conv_block(in_chan, out_chan, stride=1):
    """Return Conv3d(3x3x3, padding 1, given stride) -> BatchNorm3d -> in-place ReLU."""
    layers = [
        nn.Conv3d(in_chan, out_chan, kernel_size=3, padding=1, stride=stride),
        nn.BatchNorm3d(out_chan),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)

def conv_stage(in_chan, out_chan):
    """Return two stacked conv_block layers mapping in_chan -> out_chan -> out_chan."""
    first = conv_block(in_chan, out_chan)
    second = conv_block(out_chan, out_chan)
    return nn.Sequential(first, second)


class UNet3d(nn.Module):
    """3D U-Net with an auxiliary "global" branch that is only evaluated in training.

    forward(x, x_global) returns ``(out, dec2_global)``:
      * ``out``: 3-channel logits for the standard patch ``x`` (conv_out on dec1).
      * ``dec2_global``: in training mode, 3-channel logits for the expanded patch
        ``x_global``, taken from the dec2 level (enc2 resolution) through
        ``global_out``. In eval mode ``x_global`` is ignored and None is returned.
    Both branches share every encoder/decoder stage. Each stage runs under
    ``torch.utils.checkpoint`` to trade recomputation for memory.
    """

    def __init__(self):
        super().__init__()
        # Registration order is kept fixed: it determines parameter order for
        # the optimizer state stored in checkpoints.
        self.dummy_layer = DummyLayer()
        self.enc1 = conv_stage(1, 16)
        self.enc2 = conv_stage(16, 32)
        self.enc3 = conv_stage(32, 64)
        self.enc4 = conv_stage(64, 128)
        self.enc5 = conv_stage(128, 128)
        self.pool = nn.MaxPool3d(2, 2)

        # Decoder input channels = skip channels + upsampled deeper channels.
        self.dec4 = conv_stage(256, 64)
        self.dec3 = conv_stage(128, 32)
        self.dec2 = conv_stage(64, 16)
        self.dec1 = conv_stage(32, 16)
        self.conv_out = nn.Conv3d(16, 3, 1)
        self.global_out = nn.Conv3d(16, 3, 1)

    @staticmethod
    def _upsample_cat(skip, deeper):
        """Nearest-upsample `deeper` to `skip`'s spatial size and concatenate on channels (skip first)."""
        upsampled = F.interpolate(deeper, skip.size()[2:], mode='nearest')
        return torch.cat((skip, upsampled), 1)

    def _local_branch(self, x):
        """Full encoder/decoder pass over the standard patch; returns conv_out logits."""
        x = self.dummy_layer(x)
        e1 = cp.checkpoint(self.enc1, x)
        e2 = cp.checkpoint(self.enc2, self.pool(e1))
        e3 = cp.checkpoint(self.enc3, self.pool(e2))
        e4 = cp.checkpoint(self.enc4, self.pool(e3))
        e5 = cp.checkpoint(self.enc5, self.pool(e4))

        d4 = cp.checkpoint(self.dec4, self._upsample_cat(e4, e5))
        d3 = cp.checkpoint(self.dec3, self._upsample_cat(e3, d4))
        d2 = cp.checkpoint(self.dec2, self._upsample_cat(e2, d3))
        d1 = cp.checkpoint(self.dec1, self._upsample_cat(e1, d2))
        return self.conv_out(d1)

    def _global_branch(self, x_global):
        """Pass over the expanded patch down to dec2; returns global_out logits.

        enc1 and the first pooling run under no_grad, so this branch sends no
        gradient into enc1. dummy_layer then re-attaches the pooled features
        to the graph, which the checkpointed stages require.
        """
        with torch.no_grad():
            g1 = cp.checkpoint(self.enc1, x_global)
            p1 = self.pool(g1)
        p1 = self.dummy_layer(p1)
        g2 = cp.checkpoint(self.enc2, p1)
        g3 = cp.checkpoint(self.enc3, self.pool(g2))
        g4 = cp.checkpoint(self.enc4, self.pool(g3))
        g5 = cp.checkpoint(self.enc5, self.pool(g4))

        d4 = cp.checkpoint(self.dec4, self._upsample_cat(g4, g5))
        d3 = cp.checkpoint(self.dec3, self._upsample_cat(g3, d4))
        d2 = cp.checkpoint(self.dec2, self._upsample_cat(g2, d3))
        return self.global_out(d2)

    def forward(self, x, x_global):
        #x: standard patch x_global: expanded patch
        if not self.training:
            # Validation mode: only the standard patch is segmented.
            return (self._local_branch(x), None)
        # Training mode: global branch first, then the standard patch.
        dec2_global = self._global_branch(x_global)
        return (self._local_branch(x), dec2_global)


##criterion

###weight

In [None]:
import numpy as np
def soft_max(weight):
    """Normalise `weight` to sum to 1 (printing it), then return the softmax of that.

    `weight` must be a NumPy array; a plain list would fail the division.
    """
    normalised = weight / np.sum(weight)
    print(normalised)
    exponentials = np.exp(normalised)
    return exponentials / np.sum(exponentials)

In [None]:
# count the number in patches for weight instead of in total data
# Class weights for the weighted losses: inverse class frequency, normalised
# and passed through soft_max, then expanded to a num_classes-long list.
# NOTE(review): the set of classes (`uni`) comes from the FIRST training
# segmentation only, while the counts come from dataset_train.counts
# (presumably accumulated over the sampled training patches; confirm).
# A class missing from that first file gets weight 0.
seg_path = files_seg_seg_train[0]
seg = sitk.ReadImage(seg_path,sitk.sitkInt8)
seg = torch.from_numpy(sitk.GetArrayFromImage(seg))
count_list=[]
uni = torch.unique(seg)
counts = dataset_train.counts
for i in uni:
    n = counts[i]
    print('number of {}: '.format(i),n)
    count_list.append((n.numpy()))

# Inverse frequency. NOTE(review): a zero count would make this infinite.
weight = 1/(count_list/sum(count_list))
weight = soft_max(weight)
weight = weight.tolist()
weight_= []
weight_max = max(weight)  # not used in this cell
# Place the weights (in the ascending order torch.unique returns) at their
# class-index positions; classes not in `uni` get 0.
for i in range(num_classes):
  if i in uni:
    weight_.append(weight.pop(0))
  else:
    weight_.append(0)
del seg
gc.collect()

In [None]:
# Per-class weights as a float32 tensor on the training device.
weight_ = torch.tensor(weight_, dtype=torch.float32).to(device)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
# Plain cross-entropy and cross-entropy weighted per class by weight_.
CrossEntropy, WeightedCrossEntropy = nn.CrossEntropyLoss(), nn.CrossEntropyLoss(weight=weight_)

Focal loss

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss_class(nn.modules.loss._WeightedLoss):
    """Multi-class focal loss: FL = alpha_t * (1 - p_t) ** gamma * CE, per element.

    Args:
        weight: optional per-class tensor playing the role of alpha. It is
            registered as a buffer by _WeightedLoss. None means no class weighting.
        gamma: focusing exponent. gamma=0 gives the per-element cross-entropy
            (times alpha_t).
        reduction: 'mean' (default), 'sum' or 'none' (per-element losses).

    forward(input, target) takes logits with classes on dim 1 and integer
    class indices, like F.cross_entropy.
    """
    def __init__(self, weight=None, gamma=2,reduction='mean'):
        super(FocalLoss_class, self).__init__(weight,reduction=reduction)
        self.gamma = gamma

    def forward(self, input, target):
        # The modulating factor must be applied per element: p_t = exp(-CE)
        # is the true-class probability only for an unreduced, unweighted CE.
        # Reducing first and exponentiating the mean CE merely rescales plain CE.
        ce_loss = F.cross_entropy(input, target, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = torch.pow(1 - pt, self.gamma) * ce_loss
        if self.weight is not None:
            # alpha_t: weight of each element's target class.
            focal_loss = focal_loss * self.weight[target]
        if self.reduction == 'none':
            return focal_loss
        if self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss.mean()
# from https://github.com/gokulprasadthekkel/pytorch-multi-class-focal-loss/blob/master/focal_loss.py
# FocalLoss = FocalLoss_class(weight = weight_)
FocalLoss = FocalLoss_class(weight = None)

In [None]:

import torch
import torch.nn as nn
import numpy as np

def make_one_hot(vol, mask):
    """One-hot encode a label tensor along dim 1.

    Args:
        vol: label tensor with a singleton channel axis at dim 1, e.g. (B, 1, D, H, W).
        mask: sequence of label values. Output channel ``idx`` is 1 where
            ``vol == mask[idx]`` and 0 elsewhere.

    Returns:
        Tensor shaped like ``vol`` but with ``len(mask)`` channels on dim 1,
        with the same dtype and device as ``vol``.
    """
    shape = list(vol.shape)
    shape[1] = len(mask)
    # Allocate directly on vol's device. Building on the host and calling
    # .to() costs a full-size host allocation plus a transfer on every call.
    result = torch.zeros(shape, dtype=vol.dtype, device=vol.device)

    for idx, label in enumerate(mask):
        result[:, idx] = (vol == label).squeeze(1)

    return result

class dice_loss(nn.Module):
    '''Soft Dice loss averaged over the channels of `vol1`.

    forward(vol1, vol2):
        vol1: per-class scores (e.g. softmax probabilities), classes on dim 1.
        vol2: integer label tensor with a singleton channel axis on dim 1. It
              is one-hot encoded here over labels 0..num_classes-1, using the
              module-level num_classes.
    For each channel c: 1 - 2*sum(p_c * g_c) / (sum(p_c) + sum(g_c) + epsilon),
    with sums taken over batch and all spatial positions together.
    '''
    def __init__(self, epsilon=1e-5):
        super(dice_loss, self).__init__()
        self.epsilon = epsilon

    def forward(self, vol1, vol2):
        n_channels = vol1.shape[1]
        vol2 = make_one_hot(vol2, [label for label in range(num_classes)])
        vol2.requires_grad = False

        total_loss = 0
        for c in range(n_channels):
            overlap = 2 * torch.sum(torch.mul(vol1[:, c], vol2[:, c]))
            denom = torch.sum(vol1[:, c]) + torch.sum(vol2[:, c])
            denom += self.epsilon
            total_loss += 1 - (overlap / denom)

        return total_loss / n_channels
DiceLoss = dice_loss()

###selection

In [None]:
# Fail fast on an unknown loss name. The list matches every branch that
# criterion() dispatches on, including the adaptive weighted CE.
_VALID_CRITERIA = ['CrossEntropy', 'WeightedCrossEntropy', 'WeightedCrossEntropyAdaptive', 'FocalLoss', 'DiceLoss', 'W_D']
if criterion_type not in _VALID_CRITERIA:
  raise ValueError('please input valid criterion_type: got {!r}, expected one of {}'.format(criterion_type, _VALID_CRITERIA))

In [None]:
def _adaptive_weighted_ce(prediction, target):
  """Cross-entropy on logits with class weights recomputed from the current batch.

  For class c, s_c is the summed softmax probability of class c over the
  voxels whose target is c, and w_c = (N - s_c) / s_c with N = target.numel().
  The weights are computed without gradient.
  """
  with torch.no_grad():
    probs = F.softmax(prediction, dim=1)
    n_total = target.numel()
    target_1hot = make_one_hot(target.unsqueeze(1), [x for x in range(num_classes)])
    hits = target_1hot * probs
    weight_list = []
    for i in range(num_classes):
      # clamp keeps the weight finite if the probability mass is (or underflows to) zero
      s = torch.sum(hits[:, i]).clamp(min=1e-6)
      weight_list.append((n_total - s) / s)
    # stack keeps the weights on prediction's device without host round-trips
    class_weights = torch.stack(weight_list).float()
  # F.cross_entropy applies log-softmax itself, so it must receive the raw
  # logits; passing `probs` would apply softmax twice.
  return F.cross_entropy(prediction, target, weight=class_weights)


def criterion(prediction,target,weight):
  """Compute the training loss selected by the module-level ``criterion_type``.

  Args:
    prediction: network logits with classes on dim 1.
    target: integer class-index tensor without a channel axis.
    weight: accepted for call-site compatibility. The branches use the
      module-level loss objects (already built with weight_) or batch-adaptive
      weights instead.

  Raises:
    ValueError: if ``criterion_type`` does not name a known loss.
  """
  if criterion_type == 'CrossEntropy':
    return CrossEntropy(prediction,target)
  if criterion_type == 'WeightedCrossEntropy':
    return WeightedCrossEntropy(prediction,target)
  if criterion_type == 'WeightedCrossEntropyAdaptive':
    return _adaptive_weighted_ce(prediction, target)
  if criterion_type == 'FocalLoss':
    return FocalLoss(prediction,target)
  if criterion_type == 'DiceLoss':
    return DiceLoss(F.softmax(prediction,dim = 1),target.unsqueeze(1).byte())
  if criterion_type == 'W_D':
    return DiceLoss(F.softmax(prediction,dim = 1),target.unsqueeze(1).byte()) + WeightedCrossEntropy(prediction,target)
  raise ValueError('unknown criterion_type: {!r}'.format(criterion_type))

In [None]:
def save_checkpoint(epoch,model,optimizer,scaler,best_val,path):
  """Write a resumable training checkpoint to `path` with torch.save.

  Stored keys: 'epoch', 'model_state_dict', 'optimizer_state_dict',
  'scaler' (the GradScaler state) and 'best_val'.
  """
  payload = dict(
      epoch=epoch,
      model_state_dict=model.state_dict(),
      optimizer_state_dict=optimizer.state_dict(),
      scaler=scaler.state_dict(),
      best_val=best_val,
  )
  torch.save(payload, path)

## training

In [None]:
from sklearn import metrics
import sklearn.metrics

# Checkpoints (best: model.pt, last: model_last.pt) are written to <out_dir>/model.
model_dir = os.path.join(out_dir, 'model')
if not os.path.exists(model_dir):
    os.makedirs(model_dir)

torch.manual_seed(rnd_seed)  # fix random seed

# Histories filled by the training loop.
loss_train_log, loss_val_log, epoch_val_log = [], [], []

# Loop state; best_val and start_epoch are overwritten when resuming.
best_val = 9999
seg_val_idx, seg_val = 1018, None
start_epoch = 0
early_stopping_threshold, early_stopping_count = 30, 0


# Build model, optimizer and AMP grad-scaler (identical in both modes). When
# resuming, restore their state plus the epoch counter and best validation
# loss from model.pt. `checkpoint` stays alive: the training loop reloads the
# optimizer state from it once more.
model = UNet3d().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

if load_checkpoint:
  print('loading the checkpoint')
  PATH = save_path = os.path.join(model_dir, 'model.pt')
  # map_location lets a checkpoint saved on another device (e.g. GPU) be
  # restored in the current runtime.
  checkpoint = torch.load(PATH, map_location=device)
  model.load_state_dict(checkpoint['model_state_dict'])
  optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  scaler.load_state_dict(checkpoint['scaler'])
  start_epoch = checkpoint['epoch']
  best_val = checkpoint['best_val']
  print('start_epoch = ',start_epoch)

model.train()
torch.backends.cudnn.benchmark = True



print('START TRAINING...')
import time

time_start=time.time()
# Epochs are numbered from start_epoch + 1 (start_epoch > 0 when resuming).
# The range is evaluated once, so later changes to start_epoch do not move it.
for epoch in range(1+start_epoch, num_epochs + 1+start_epoch):

    # Training
    model.train()
    for batch_idx, batch_samples in enumerate(dataloader_train):
        # Standard patch and label, plus the expanded ("global") patch for the
        # auxiliary branch.
        img, seg = batch_samples['img'].to(device).float(), batch_samples['seg'].to(device)
        img_global = batch_samples['img_global'].to(device).float()
        # print(batch_idx)
        optimizer.zero_grad()
        with torch.cuda.amp.autocast(enabled=use_amp):
          prd = model(img,img_global)
          # Inputs are dropped right after the forward pass to lower peak memory.
          del img;del img_global;gc.collect()
          seg_global = batch_samples['seg_global'].to(device)
          # Total loss = loss of the standard-patch output (prd[0]) against seg
          # + loss of the global-branch output (prd[1]) against seg_global.
          loss = criterion(prd[0], seg.squeeze(1).long(),weight = weight_) +  criterion(prd[1], seg_global.squeeze(1).long(),weight = weight_)
          del seg; del seg_global
          del prd;gc.collect()

        scaler.scale(loss).backward()
        lossitem = loss.item()
        del loss; gc.collect()
        # optimizer.step()
        # When resuming, the optimizer state is loaded from `checkpoint` again
        # here (after the first backward, before the first step), then
        # `checkpoint` is freed. Resetting start_epoch to 0 makes the condition
        # false for the remaining batches, provided the resumed epoch is > 1.
        if load_checkpoint and epoch == 1+start_epoch:
          optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
          start_epoch = 0
          print('reload optimizer')
          del checkpoint
          gc.collect()
          torch.cuda.empty_cache()


        scaler.step(optimizer)
        scaler.update()
        gc.collect()
        torch.cuda.empty_cache()

    # NOTE: lossitem is the loss of the LAST batch in the epoch, not an epoch
    # average. Both the log and the printout below use it.
    loss_train_log.append(lossitem)
        
    print('+ TRAINING \tEpoch: {} \tLoss: {:.6f}'.format(epoch, lossitem))



    # Validation
    model.eval()
    val_interval = 1
    # Validate on epochs 1 and 2, then every val_interval epochs after epoch 20
    # (epochs 3-20 are skipped).
    if epoch == 1 or epoch == 2 or (epoch % val_interval == 0 and epoch>20):
        loss_val = 0
        sum_pts = 0
        torch.cuda.empty_cache() 
        with torch.no_grad():
          with torch.cuda.amp.autocast(enabled=use_amp):
            # criterion_val = FocalLoss(weight = criterion.weight,gamma =criterion.gamma,reduction= 'sum' )
            for idx,data_sample in enumerate(dataloader_val):
                # Validation input is cast to half precision (training uses float).
                img, seg = data_sample['img'].to(device).half(), data_sample['seg'].to(device)
                # img_global = data_sample['img_global'].to(device).half()
                # prd = model(img)
                img_global = None
                # In eval mode the model ignores x_global and returns (out, None).
                prd = model(img,img_global)

                # Validation metric: soft Dice loss of the standard-patch output,
                # averaged over validation batches.
                loss_val +=  DiceLoss(F.softmax(prd[0],dim = 1), seg).item()
                # sum_pts += np.prod(img_size)
                sum_pts += 1
        loss_val /= sum_pts

        loss_val_log.append(loss_val)
        epoch_val_log.append(epoch)

        # Only the best model so far (lowest validation loss) is kept in model.pt.
        if loss_val < best_val:
          best_val = loss_val
          save_path = os.path.join(model_dir, 'model.pt')
          save_checkpoint(epoch,model,optimizer,scaler,best_val,save_path)
          early_stopping_count = 0
          
        else:
          early_stopping_count += val_interval
        
        # NOTE: on early stop the loop exits before this epoch's validation
        # summary is printed and before the cleanup below.
        if early_stopping_count>early_stopping_threshold:
          print('no improvement, triggering early stopping')
          break


        print('--------------------------------------------------')
        print('+ VALIDATE \tEpoch: {} \tLoss: {:.6f}'.format(epoch, loss_val))
        print('--------------------------------------------------')
        del img;del img_global
        del seg;
        gc.collect()
        torch.cuda.empty_cache()

time_end=time.time()

print('\nFinished TRAINING.')
print('Time cost for training: ',time_end-time_start,'s')



In [None]:
# Also keep the final-epoch weights (model.pt holds the best validation model).
save_path = os.path.join(model_dir, 'model_last.pt')
save_checkpoint(epoch=epoch, model=model, optimizer=optimizer, scaler=scaler, best_val=best_val, path=save_path)

Free RAM and GPU memory before testing


In [None]:
# Release training-time objects before testing. The per-batch tensors may not
# exist (e.g. if a loop never ran), so drop them only if present. A missing
# name is the only failure expected here, so nothing else is swallowed.
for _leftover in ('prd', 'img', 'seg', 'loss', 'img_global'):
    globals().pop(_leftover, None)
del _leftover
del optimizer
gc.collect()
torch.cuda.empty_cache()
del dataset_train
del dataloader_train
del dataset_val
del dataloader_val
gc.collect()

#Test

##model test

In [None]:
import torch.utils.checkpoint as cp
class DummyLayer(nn.Module):
    """Identity layer that owns a trainable scalar parameter.

    It is placed in front of ``cp.checkpoint`` calls so that the checkpointed
    segment receives an input attached to the autograd graph. The parameter
    name matches the training model, so saved checkpoints load unchanged.
    """
    def __init__(self):
        super().__init__()
        self.dummy = nn.Parameter(torch.ones(1, dtype=torch.float32))

    def forward(self,x):
        # Adding dummy * 0 leaves x exactly unchanged. x + dummy - dummy
        # evaluates (x + 1.0) - 1.0 and loses low-order bits (1e-9 -> 0).
        return x + self.dummy * 0

def conv_block(in_chan, out_chan, stride=1):
    """Return Conv3d(kernel 3, padding 1, given stride) -> BatchNorm3d -> in-place ReLU."""
    layers = [
        nn.Conv3d(in_chan, out_chan, kernel_size=3, padding=1, stride=stride),
        nn.BatchNorm3d(out_chan),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)


def conv_stage(in_chan, out_chan):
    """Two stacked conv_blocks: in_chan -> out_chan -> out_chan (stride 1, size preserved)."""
    first = conv_block(in_chan, out_chan)
    second = conv_block(out_chan, out_chan)
    return nn.Sequential(first, second)


class UNet3d(nn.Module):
    """Inference variant of the training UNet3d, with the same submodule names.

    forward(x, x_global) always runs the standard-patch path and returns
    ``(logits, None)``. ``x_global`` is accepted only for call compatibility.
    ``global_out`` is unused here; presumably it is kept so checkpoints saved
    from the training model load with matching keys.
    """

    def __init__(self):
        super().__init__()
        self.dummy_layer = DummyLayer()
        self.enc1 = conv_stage(1, 16)
        self.enc2 = conv_stage(16, 32)
        self.enc3 = conv_stage(32, 64)
        self.enc4 = conv_stage(64, 128)
        self.enc5 = conv_stage(128, 128)
        self.pool = nn.MaxPool3d(2, 2)

        # Decoder input channels = skip channels + upsampled deeper channels.
        self.dec4 = conv_stage(256, 64)
        self.dec3 = conv_stage(128, 32)
        self.dec2 = conv_stage(64, 16)
        self.dec1 = conv_stage(32, 16)
        self.conv_out = nn.Conv3d(16, 3, 1)

        self.global_out = nn.Conv3d(16, 3, 1)

    @staticmethod
    def _upsample_cat(skip, deeper):
        """Nearest-upsample `deeper` to `skip`'s spatial size and concatenate on channels (skip first)."""
        upsampled = F.interpolate(deeper, skip.size()[2:], mode='nearest')
        return torch.cat((skip, upsampled), 1)

    def forward(self, x, x_global):
        x = self.dummy_layer(x)
        e1 = cp.checkpoint(self.enc1, x)
        e2 = cp.checkpoint(self.enc2, self.pool(e1))
        e3 = cp.checkpoint(self.enc3, self.pool(e2))
        e4 = cp.checkpoint(self.enc4, self.pool(e3))
        e5 = cp.checkpoint(self.enc5, self.pool(e4))

        d4 = cp.checkpoint(self.dec4, self._upsample_cat(e4, e5))
        d3 = cp.checkpoint(self.dec3, self._upsample_cat(e3, d4))
        d2 = cp.checkpoint(self.dec2, self._upsample_cat(e2, d3))
        d1 = cp.checkpoint(self.dec1, self._upsample_cat(e1, d2))

        return (self.conv_out(d1), None)


## Hyperparameters for testing

In [None]:
# Test-time geometry: original volume size and the sliding-window kernel.
# The stride is half the kernel, so neighbouring windows overlap by 50%.
img_size_original = [470, 470, 1072]
kernel_c = kernel_h = kernel_w = 240
step = 120

In [None]:
# img_size is ordered (height, width, depth); C is the depth axis here.
H,W,C = img_size
def cal_shape_padding(n,k,stride,p=0):
    """Return n grown so that windows of length k with the given stride tile it.

    For n >= k this is the smallest m >= n with (m - k) divisible by stride.
    `p` (default 0) enters the window count as 2*p, as in the usual
    convolution output-size formula.
    """
    whole_steps = int((n - k + 2 * p) / stride)
    remainder = n - (k + whole_steps * stride)
    if remainder == 0:
        return n
    return n + stride - remainder

# Padded depth/height/width so each axis is covered by whole sliding windows.
C_, H_, W_ = (cal_shape_padding(n, k, step) for n, k in ((C, kernel_c), (H, kernel_h), (W, kernel_w)))

In [None]:
def cal_shape(batch_size, C, H, W, kernel_c, kernel_h, kernel_w,
                    step):
    """Total number of sliding windows: windows along C * along H * along W * batch_size.

    Each axis contributes int((size - kernel) / step) + 1 windows.
    """
    def windows(size, kernel):
        return int((size - kernel) / step) + 1

    n_c, n_h, n_w = windows(C, kernel_c), windows(H, kernel_h), windows(W, kernel_w)
    return n_c * n_h * n_w * batch_size
# Number of sliding windows in one padded test volume.
num_patch_per_block = cal_shape(batch_size=1, C=C_, H=H_, W=W_,
                                kernel_c=kernel_c, kernel_h=kernel_h, kernel_w=kernel_w,
                                step=step)


##load test data

In [None]:
if data_type == 'all_data':
  if label_type == 'all':
    # Test split: indices 5-6 of each group, the positions not used by the
    # train ([0:5] + [7:9]) and validation ([9:10]) split.
    # NOTE(review): nothing is defined for other data_type/label_type values.
    #img
    files_seg_img_test = list_male[5:7] +list_mated[5:7] + list_virgin[5:7]
    #seg
    files_seg_seg_test = list_male_seg[5:7]+list_mated_seg[5:7] + list_virgin_seg[5:7]


#Evaluation

##utils

In [None]:
def display_image(img, x=None, y=None, z=None, window=None, level=None, colormap='gray', crosshair=False,figure_name=''):
    """Plot three orthogonal slices of a 3D SimpleITK image through voxel (x, y, z).

    Args:
        img: SimpleITK image; slices come from its array view, indexed [z, y, x].
        x, y, z: slice indices; each defaults to the centre of its axis.
        window, level: display window/level. They default to the full
            intensity range and its midpoint, and are converted to colour
            limits by wl_to_lh (defined elsewhere in the notebook).
        colormap: matplotlib colormap name.
        crosshair: if True, draw lines marking the other two slice positions.
        figure_name: if non-empty, each panel is also saved as
            figure_dir/<figure_name>_1.jpg, _2.jpg and _3.jpg.
    """
    img_array = sitk.GetArrayFromImage(img)

    size = img.GetSize()
    spacing = img.GetSpacing()
    # Physical extent of each axis in millimetres.
    width, height, depth = (size[axis] * spacing[axis] for axis in range(3))

    x = np.floor(size[0] / 2).astype(int) if x is None else x
    y = np.floor(size[1] / 2).astype(int) if y is None else y
    z = np.floor(size[2] / 2).astype(int) if z is None else z

    if window is None:
        window = np.max(img_array) - np.min(img_array)
    if level is None:
        level = window / 2 + np.min(img_array)
    low, high = wl_to_lh(window, level)
    limits = (low, high)

    fig, panels = plt.subplots(1, 3, figsize=(10, 4))
    ax1, ax2, ax3 = panels
    ax1.imshow(img_array[z, :, :], cmap=colormap, clim=limits, extent=(0, width, height, 0))
    ax2.imshow(img_array[:, y, :], origin='lower', cmap=colormap, clim=limits, extent=(0, width, 0, depth))
    ax3.imshow(img_array[:, :, x], origin='lower', cmap=colormap, clim=limits, extent=(0, height, 0, depth))
    for panel in panels:
        panel.set_xticks([])
        panel.set_yticks([])

    if crosshair:
        ax1.axhline(y * spacing[1], lw=1)
        ax1.axvline(x * spacing[0], lw=1)
        ax2.axhline(z * spacing[2], lw=1)
        ax2.axvline(x * spacing[0], lw=1)
        ax3.axhline(z * spacing[2], lw=1)
        ax3.axvline(y * spacing[1], lw=1)

    if figure_name != '':
        # Save each panel on its own by cropping the figure to the panel's
        # bounding box (converted to inches).
        for panel, suffix in zip(panels, ('_1.jpg', '_2.jpg', '_3.jpg')):
            extent = panel.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
            fig.savefig(os.path.join(figure_dir, str(figure_name) + suffix), bbox_inches=extent)
    plt.show()

    plt.close()


def load_original_image_torch(idx):
  """Read test image `idx` as UInt8, display/save its slices, and return the SimpleITK image."""
  print(idx)
  img_original = sitk.ReadImage(files_seg_img_test[idx], sitk.sitkUInt8)
  voxels = torch.from_numpy(sitk.GetArrayFromImage(img_original))

  # Round-trip through NumPy to get a plain uint8 image for display.
  preview = sitk.GetImageFromArray(voxels.squeeze(0).cpu().numpy().astype(np.uint8))
  display_image(preview, figure_name='original_image' + str(idx))
  del preview

  return img_original

def weight_patch(batch_size,num_classes,C,H,W):
    """Voting weights for one patch: 1 everywhere, 2 inside a central cube.

    Along each axis of length n the cube spans n//2 voxels, starting at
    (n//2)//2. Returns an int8 tensor of shape (batch_size, num_classes, C, H, W).
    """
    import torch

    def centre(n):
        half = n // 2
        start = int(half / 2)
        return slice(start, half + start)

    mask = torch.ones((batch_size, num_classes, C, H, W), dtype=torch.int8)
    mask[:, :, centre(C), centre(H), centre(W)] = 2
    return mask

def patch_combine(windows,batch_size=1, C=96, H=96, W=96, kernel_c = 32, kernel_h=32, kernel_w = 32,step = 32 ):
    """Stitch per-patch label predictions back into full volumes by weighted voting.

    Args:
        windows: patch label maps reshapeable to
            (nC, nH, nW, batch_size, kernel_c, kernel_h, kernel_w). This is the
            order get_patches produces; presumably callers pass its output.
        batch_size: number of volumes in `windows`.
        C, H, W: padded volume size (depth, height, width).
        kernel_c, kernel_h, kernel_w, step: sliding-window geometry of the patches.

    Returns:
        (batch_size, C, H, W) tensor of argmax class indices. Every patch
        votes one-hot for its label, with weight 2 in the central cube and 1
        elsewhere (weight_patch). Votes are summed in an int8 accumulator, so
        the total vote weight at any voxel must stay <= 127 or it overflows.
    """
    import torch

    def windows_per_axis(n, k):
        return int((n - k) / step) + 1

    grid = (windows_per_axis(C, kernel_c), windows_per_axis(H, kernel_h), windows_per_axis(W, kernel_w))
    windows = windows.reshape(grid + (batch_size, kernel_c, kernel_h, kernel_w))
    # -> (batch, nC, nH, nW, kernel_c, kernel_h, kernel_w)
    windows = windows.permute(3, 0, 1, 2, 4, 5, 6)
    shape1 = windows.shape

    print(C,H,W)

    repatch_com = torch.zeros((batch_size,num_classes,C,H,W),dtype = torch.int8, device=windows.device)
    print(repatch_com.shape)

    # Loop-invariant: the label list and the vote-weight mask are the same for
    # every patch. Building the mask once avoids re-creating it on the host and
    # re-copying it to the device for each patch.
    labels = [x for x in range(num_classes)]
    patch_weight = weight_patch(batch_size, num_classes, kernel_c, kernel_h, kernel_w).squeeze().to(windows.device)

    for b in range(batch_size):
        for c in range(shape1[1]):
            for h in range(shape1[2]):
                for w in range(shape1[3]):
                    votes = make_one_hot(windows[b,c,h,w,:,:,:].unsqueeze(0).unsqueeze(0), labels).squeeze() * patch_weight
                    repatch_com[b,:,c*step:c*step+kernel_c,h*step:h*step+kernel_h,w*step:w*step+kernel_w] += votes
    repatch_com = torch.argmax(repatch_com,dim=1)
    del windows
    gc.collect()
    return repatch_com


def cal_padding(n,k,stride,p):
    """Split the padding needed for sliding windows into (before, after).

    total = stride - remainder, or 0 when the windows already tile n.
    `before` gets the floor half and `after` the rest.
    """
    covered = k + int((n - k + 2 * p) / stride) * stride
    total = stride - (n - covered)
    if total == stride:
        total = 0
    before = int(total / 2)
    return before, total - before



def patch_combine_one_by_one(data,num_image=1, C=96, H=96, W=96, kernel_c = 32, kernel_h=32, kernel_w = 32,step = 32 ):
    """Reassemble each image's patches with patch_combine and crop away the padding.

    `data` holds the patches of `num_image` images back to back, in equal
    shares. C, H, W are the padded sizes. The crop uses the unpadded size from
    the module-level img_size, ordered (height, width, depth).
    Returns an int8 tensor of shape (num_image, depth, height, width).
    """
    depth, height, width = img_size[2], img_size[0], img_size[1]
    pad_left, pad_right = cal_padding(width, kernel_w, step, 0)
    pad_top, _ = cal_padding(height, kernel_h, step, 0)
    pad_front, _ = cal_padding(depth, kernel_c, step, 0)
    print(pad_left, pad_right, W, kernel_w, step)

    recon_image_array = torch.zeros((num_image, C, H, W), dtype=torch.int8).to(data.device)
    per_image = int(data.shape[0] / num_image)
    for i in range(num_image):
        recon_image = patch_combine(data[per_image * i:per_image * (i + 1)], 1, C, H, W, kernel_c, kernel_h, kernel_w, step)
        recon_image_array[i] = recon_image
    del recon_image
    gc.collect()
    return recon_image_array[:, pad_front:depth + pad_front, pad_top:height + pad_top, pad_left:width + pad_left]

# Load a test reference segmentation (no padding is applied here).
def load_reference_torch(idx):
  """Read test segmentation `idx` as Int8, display it as an RGB label map, and return it as a tensor."""
  print(idx)
  reference = sitk.ReadImage(files_seg_seg_test[idx], sitk.sitkInt8)
  reference = torch.from_numpy(sitk.GetArrayFromImage(reference))

  overlay = sitk.GetImageFromArray(reference.squeeze(0).cpu().numpy().astype(np.uint8))
  overlay = sitk.LabelToRGB(overlay)
  display_image(overlay, figure_name='reference' + str(idx))
  del overlay

  return reference.squeeze(0)


def confusion_matrix_torch(img,seg):
  """Per-class one-vs-rest confusion counts between a prediction and a reference.

  Args:
    img: predicted label volume; it is moved to `device` and cast to int8.
    seg: reference label volume of the same shape; it is moved to `device`.

  Returns:
    (TP, TN, FN, FP): float32 CPU tensors of length num_classes, where entry
    i counts voxels with (pred == i & ref == i), (pred != i & ref != i),
    (pred != i & ref == i) and (pred == i & ref != i) respectively.
  """
  TP = torch.zeros(num_classes)
  TN = torch.zeros(num_classes)
  FN = torch.zeros(num_classes)
  FP = torch.zeros(num_classes)

  # Loop-invariant: move and cast once rather than once per class.
  img = img.to(device).char()
  seg = seg.to(device)

  for i in range(num_classes):
    # Elementwise masks avoid allocating boolean-indexed copies of the volume.
    pred_i = img == i
    ref_i = seg == i
    TP[i] = torch.sum(pred_i & ref_i)
    FN[i] = torch.sum(ref_i & ~pred_i)
    FP[i] = torch.sum(pred_i & ~ref_i)
    TN[i] = torch.sum(~ref_i & ~pred_i)

  return TP,TN,FN,FP


def dice_coefficient(TP,TN,FN,FP):
  """Dice score 2*TP / (FN + 2*TP + FP + 1).

  The +1 is only in the denominator, so a class absent from both prediction
  and reference scores 0. TN is accepted for a uniform signature and unused.
  """
  denominator = FN + (2 * TP) + FP + 1.
  return 2 * TP / denominator

def iou(TP,TN,FN,FP):
  """Intersection over union TP / (TP + FP + FN + 1). TN is unused."""
  union_plus_one = TP + FP + FN + 1
  return TP / union_plus_one



In [None]:

def zero_mean_unit_var(image, mask):
    """Normalise an image to zero mean / unit variance using voxels where mask > 0.

    When the masked standard deviation is positive, the image is standardised
    and voxels where mask == 0 are set to 0. Otherwise the float32 image is
    returned unscaled. The result keeps `image`'s spatial metadata.
    """
    values = sitk.GetArrayFromImage(image).astype(np.float32)
    mask_values = sitk.GetArrayFromImage(mask)
    inside = mask_values > 0

    mu = np.mean(values[inside])
    sigma = np.std(values[inside])

    if sigma > 0:
        values = (values - mu) / sigma
        values[mask_values == 0] = 0

    normalised = sitk.GetImageFromArray(values)
    normalised.CopyInformation(image)
    return normalised


def resample_image(image, out_spacing=(1.0, 1.0, 1.0), out_size=None, is_label=False, pad_value=0):
    """Resample an image to the given voxel spacing (and optionally size), keeping it centred.

    Args:
        image: SimpleITK image.
        out_spacing: target spacing per axis.
        out_size: target size per axis. If None, it is
            round(size * spacing / out_spacing), which keeps the physical extent.
        is_label: use nearest-neighbour interpolation (label maps) instead of B-spline.
        pad_value: value for output voxels that map outside the input.

    Returns:
        The resampled SimpleITK image. It keeps `image`'s direction, and its
        origin is shifted so both images share the same physical centre.
    """
    original_spacing = np.array(image.GetSpacing())
    original_size = np.array(image.GetSize())

    if out_size is None:
        out_size = np.round(np.array(original_size * original_spacing / np.array(out_spacing))).astype(int)
    else:
        out_size = np.array(out_size)

    original_direction = np.array(image.GetDirection()).reshape(len(original_spacing),-1)
    original_center = (np.array(original_size, dtype=float) - 1.0) / 2.0 * original_spacing
    out_center = (np.array(out_size, dtype=float) - 1.0) / 2.0 * np.array(out_spacing)

    # Rotate both centre offsets into physical space and shift the origin so
    # the physical centres coincide.
    original_center = np.matmul(original_direction, original_center)
    out_center = np.matmul(original_direction, out_center)
    out_origin = np.array(image.GetOrigin()) + (original_center - out_center)

    resample = sitk.ResampleImageFilter()
    resample.SetOutputSpacing(out_spacing)
    resample.SetSize(out_size.tolist())
    resample.SetOutputDirection(image.GetDirection())
    resample.SetOutputOrigin(out_origin.tolist())
    resample.SetTransform(sitk.Transform())
    resample.SetDefaultPixelValue(pad_value)

    if is_label:
        resample.SetInterpolator(sitk.sitkNearestNeighbor)
    else:
        resample.SetInterpolator(sitk.sitkBSpline)
    # The former function-local `import os, sys, psutil` were unused. psutil
    # is not guaranteed to be installed, so they could make every call fail.

    gc.collect()
    return resample.Execute(image)

def prepadding(data,kernel_c,kernel_h,kernel_w,step):
    """Zero-pad a (batch, depth, rows, cols) tensor so sliding windows tile it.

    For each axis of size >= kernel, the padded size minus the kernel is a
    multiple of `step`. Padding is split with the floor half before the data
    and the rest after it. The result has data's dtype and is allocated on
    the module-level `device`.
    """
    import torch
    batch_size, depth, height, width = data.shape[0],data.shape[1],data.shape[2],data.shape[3]

    # calculate the padding for two edges
    def cal_padding(n,k,stride,p):
        #the two boundry: left  right
        total_padding = stride - (n - (k+(int((n-k+2*p)/stride))*stride))
        if total_padding == stride:
          total_padding = 0
        left = int(total_padding/2)
        right = total_padding - left
        return left,right

    padding_left,padding_right = cal_padding(n=width,k=kernel_w,stride=step,p=0)
    padding_top,padding_bottom = cal_padding(n=height,k=kernel_h,stride=step,p=0)
    padding_front,padding_back = cal_padding(n=depth,k=kernel_c,stride=step,p=0)

    # Allocate directly on `device` (no full-size host buffer plus copy).
    # `data` is copied in as (batch, depth, rows, cols). Inserting a channel
    # axis first, i.e. (batch, 1, ...), only broadcast into this slice when
    # batch_size == 1.
    output = torch.zeros((batch_size,depth+padding_front+padding_back,height+padding_top+padding_bottom,width+padding_left+padding_right),dtype = data.dtype, device=device)
    output[:,padding_front:depth+padding_front,padding_top:height+padding_top,padding_left:width+padding_left]= data
    print(output.shape)
    del data
    gc.collect()
    return output


def get_patches(data,kernel_c=32, kernel_h=32, kernel_w=32,step=32):
    """Pad `data` with prepadding, then cut it into 3D sliding windows.

    Returns a tensor of shape (num_patches, kernel_c, kernel_h, kernel_w),
    ordered by (depth window, row window, col window, batch).
    """
    padded = prepadding(data, kernel_c, kernel_h, kernel_w, step)
    gc.collect()
    tiles = padded.unfold(1, kernel_c, step)
    tiles = tiles.unfold(2, kernel_h, step)
    tiles = tiles.unfold(3, kernel_w, step)
    # (batch, nC, nH, nW, kc, kh, kw) -> (nC, nH, nW, batch, kc, kh, kw) -> flat
    return tiles.permute(1, 2, 3, 0, 4, 5, 6).reshape(-1, kernel_c, kernel_h, kernel_w)


def get_patches_stage1(data,kernel_c=32, kernel_h=32, kernel_w=32,step=32):
    """Like get_patches, but one-hot encodes each patch over labels 0..num_classes-1.

    Returns a tensor of shape (num_patches, num_classes, kernel_c, kernel_h, kernel_w).
    """
    padded = prepadding(data, kernel_c, kernel_h, kernel_w, step)

    gc.collect()
    tiles = (padded
             .unfold(1, kernel_c, step)
             .unfold(2, kernel_h, step)
             .unfold(3, kernel_w, step)
             .permute(1, 2, 3, 0, 4, 5, 6)
             .reshape(-1, kernel_c, kernel_h, kernel_w))
    tiles = make_one_hot(tiles.unsqueeze(1), list(range(num_classes)))  # num, channel(3), D,H,W
    print('stage1 seg shape', tiles.shape)
    return tiles


In [None]:
import torch
import gc
class ImageSegmentationDataset_stage2(Dataset):
    """Patch dataset for the second segmentation stage.

    For every volume ``idx`` the constructor reads three inputs:
    the label map ``file_list_seg[idx]``, the image ``file_list_img[idx]``,
    and the stage-1 prediction ``stage1_prediction[idx]``. All three are tiled
    with ``get_patches`` (window ``kernel_c x kernel_h x kernel_w``, stride
    ``step``), and a subset of patch positions is kept:

    * when ``TEST`` is False, a patch whose background fraction (label 0) is
      at least the global ``threshold`` is kept only with probability
      ``roulette`` (also a global);
    * every other patch is kept.

    For each kept position the dataset stores:

    * the label patch, shape (1, kc, kh, kw), in ``self.segs``;
    * a 2-channel uint8 sample (image, stage-1 label) in ``self.imgs``;
    * the source file names.

    ``self.counts`` accumulates per-class voxel counts over the kept label
    patches.

    ``img_spacing``, ``img_size`` and ``patch`` are accepted but not used.
    """

    def __init__(self, stage1_prediction, file_list_img, file_list_seg, img_spacing, img_size,
                 patch=True, kernel_c=64, kernel_h=64, kernel_w=64, step=32, TEST=False):
        self.imgs = []
        self.segs = []
        self.img_names = []
        self.seg_names = []
        # Per-class voxel counts of the kept label patches, e.g. for loss weights.
        # 9 was the historical fixed length and is kept as a minimum. The length
        # grows with num_classes so the counting loop cannot index past the end.
        self.counts = torch.zeros(max(9, num_classes))

        with torch.no_grad():
            for idx in tqdm(range(len(file_list_img)), desc='Loading Data'):
                # Labels come first: they decide which patch positions are kept.
                valid_index, num_seg_patches = self._add_seg_patches(
                    file_list_seg[idx], kernel_c, kernel_h, kernel_w, step, TEST)
                self._add_img_patches(
                    file_list_img[idx], stage1_prediction[idx], valid_index,
                    num_seg_patches, kernel_c, kernel_h, kernel_w, step)

    def _add_seg_patches(self, seg_path, kernel_c, kernel_h, kernel_w, step, TEST):
        """Tile one label volume and store the kept patches.

        Returns a tuple (kept indices, total patch count).
        """
        seg = sitk.ReadImage(seg_path, sitk.sitkInt8)
        seg = torch.from_numpy(sitk.GetArrayFromImage(seg)).unsqueeze(0).to(device)
        print('seg.dtype', seg.dtype)
        print('seg.shape', seg.shape)
        seg = get_patches(seg, kernel_c, kernel_h, kernel_w, step)
        seg_name = os.path.basename(seg_path)

        valid_index = []
        for index_patches in range(seg.shape[0]):
            seg_patch = seg[index_patches]
            background = torch.sum(seg_patch == 0) / seg_patch.numel()
            # Under-sample mostly-background patches during training.
            # torch.rand is only drawn when the first two conditions hold,
            # so the RNG stream is consumed exactly as before.
            if background >= threshold and not TEST and torch.rand(1) >= roulette:
                continue
            for i in range(num_classes):
                self.counts[i] += torch.sum(seg_patch == i).cpu()
            valid_index.append(index_patches)
            self.segs.append(seg_patch.unsqueeze(0).cpu())
            self.seg_names.append(seg_name)

        num_patches = seg.shape[0]
        del seg
        gc.collect()
        return valid_index, num_patches

    def _add_img_patches(self, img_path, stage1_volume, valid_index, num_seg_patches,
                         kernel_c, kernel_h, kernel_w, step):
        """Tile one image and its stage-1 prediction.

        Stores a 2-channel sample at each kept index.
        """
        img = sitk.ReadImage(img_path, sitk.sitkUInt8)
        img = torch.from_numpy(sitk.GetArrayFromImage(img)).unsqueeze(0).to(device)
        torch.cuda.empty_cache()
        # Shape (num_patches, 1, kc, kh, kw); the extra axis becomes the image channel.
        img = get_patches(img, kernel_c, kernel_h, kernel_w, step).unsqueeze(1)
        if img.shape[0] != num_seg_patches:
            # Different patch grids would pair image and label patches taken
            # from different positions. Fail loudly instead of misaligning.
            raise ValueError(
                f'{img_path}: {img.shape[0]} image patches vs {num_seg_patches} label patches')
        stage1_seg = get_patches(stage1_volume.to(device), kernel_c, kernel_h, kernel_w, step)
        img_name = os.path.basename(img_path)

        # valid_index is ascending and mirrors what was appended to self.segs,
        # so self.imgs[k] stays paired with self.segs[k].
        for index_patches in valid_index:
            sample = torch.cat(
                (img[index_patches], stage1_seg[index_patches].unsqueeze(0)), 0).byte()
            self.imgs.append(sample.cpu())
            self.img_names.append(img_name)

        del img, stage1_seg
        gc.collect()

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, item):
        img = self.imgs[item]
        seg = self.segs[item]
        return {'img': img, 'seg': seg}

    def get_sample(self, item):
        """Same as ``self[item]``."""
        return self.__getitem__(item)

    def get_img_name(self, item):
        return self.img_names[item]

    def get_seg_name(self, item):
        return self.seg_names[item]
torch.cuda.empty_cache()

## Metrics

In [None]:
dataset_test = ImageSegmentationDataset_separate(
    files_seg_img_test, files_seg_seg_test, img_spacing, img_size,
    True, kernel_c, kernel_h, kernel_w, step, True,
)

# Build one DataLoader per test volume. Volume `index` is read from the
# contiguous patch range [index*num_patch_per_block, (index+1)*num_patch_per_block)
# of dataset_test. This assumes every volume contributes exactly
# num_patch_per_block consecutive patches; confirm against the dataset.
dataloader_testall = []
for index in range(len(files_seg_img_test)):
    first_patch = int(num_patch_per_block * index)
    end_patch = int(num_patch_per_block * (index + 1))
    volume_patches = torch.utils.data.Subset(dataset_test, list(range(first_patch, end_patch)))
    dataloader_test = torch.utils.data.DataLoader(
        volume_patches,
        batch_size=1,
        shuffle=False,
        num_workers=2,
        prefetch_factor=16,
        pin_memory=True,
    )
    dataloader_testall.append(dataloader_test)

In [None]:
# Evaluate the trained model on every test volume. For each volume:
#   1. predict every patch;
#   2. stitch the patches back into a full label map;
#   3. display the prediction and an overlay;
#   4. accumulate per-volume confusion-matrix counts and volumetric similarity.
uni_labels = [x for x in range(num_classes)]
num_volumes = len(files_seg_img_test)
# Volumetric similarity per volume: 1 - |FN - FP| / (2TP + FP + FN).
vs_scores = torch.zeros(num_volumes, len(uni_labels))
TP = torch.zeros((num_volumes, num_classes))
TN = torch.zeros((num_volumes, num_classes))
FN = torch.zeros((num_volumes, num_classes))
FP = torch.zeros((num_volumes, num_classes))

num_image = 1  # volumes reconstructed from each per-volume dataloader
model_dir = os.path.join(out_dir, 'model')
pred_dir = os.path.join(out_dir, 'pred')
os.makedirs(pred_dir, exist_ok=True)

# The checkpoint is the same for every volume, so build and load the model
# once rather than once per test volume.
# map_location=device covers both the CPU and the GPU case.
model = UNet3d()
PATH = save_path = os.path.join(model_dir, 'model.pt')
checkpoint = torch.load(PATH, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.eval()

print(num_patch_per_block)
print(step)

for index in range(num_volumes):
    patch_store = torch.zeros(num_image * num_patch_per_block, kernel_c, kernel_h, kernel_w,
                              dtype=torch.int8, device=device)

    print('START TESTING...')
    prd = None  # keeps the `del prd` below valid even if the loader yields nothing
    with torch.no_grad():
        for index_patch, data_sample in enumerate(dataloader_testall[index]):
            with torch.cuda.amp.autocast(enabled=use_amp):
                img = data_sample['img'].to(device).half()
                img_global = data_sample['img_global'].to(device).half()
                prd = model(img, img_global)
                # Class axis of the first model output; assumed (batch, classes, D, H, W). TODO confirm.
                prd = torch.argmax(prd[0], dim=1)
                patch_store[index_patch] = prd
    del prd
    gc.collect()
    print('\nFinished TESTING.')
    # The shape passed to the reconstruction below.
    print(C_, H_, W_)

    # Stitch the patch predictions back into full volume(s).
    recon_image_original = patch_combine_one_by_one(
        patch_store, num_image, C_, H_, W_, kernel_c, kernel_h, kernel_w, step).cpu()

    for idx in range(num_image):
        # Reference label map for this volume.
        seg_torch = load_reference_torch(index)
        seg_shape = seg_torch.shape
        C, H, W = seg_shape
        seg_recon = recon_image_original[idx]

        # Display the prediction and its overlay on the original image.
        prediction = sitk.GetImageFromArray(seg_recon.numpy().astype(np.uint8))
        prediction2 = sitk.LabelToRGB(prediction)
        display_image(prediction2, figure_name='prediction' + str(index))

        green = [0, 255, 0]
        gold = [255, 215, 0]
        colo = [238, 64, 0]
        image_test = load_original_image_torch(index)
        overlay_prd = sitk.LabelOverlay(image_test, prediction, colormap=gold + green + colo, opacity=0.55)
        display_image(overlay_prd, figure_name='overlay_prediction' + str(index))
        del prediction, prediction2, image_test, overlay_prd
        gc.collect()

        # Confusion-matrix counts on the flattened prediction vs. reference.
        img = seg_recon.squeeze().flatten()
        seg = seg_torch.squeeze().flatten()
        tp, tn, fn, fp = confusion_matrix_torch(img, seg)
        vs = 1 - abs(fn - fp) / (2 * tp + fp + fn + 1e-9)

        vs_scores[index] = vs
        TP[index] = tp
        TN[index] = tn
        FN[index] = fn
        FP[index] = fp
        del seg_torch, img, seg, seg_recon
        gc.collect()

    del patch_store
    del recon_image_original
    gc.collect()


iou_score = iou(TP, TN, FN, FP)
dice_scores = dice_coefficient(TP, TN, FN, FP)
print('dice_scores:', dice_scores)
print('IoU:', iou_score)

# Spread and mean across test volumes (dim 0), one value per column.
dice_std = torch.std(dice_scores, dim=0)
dice_mean = torch.mean(dice_scores, dim=0)
IoU_std = torch.std(iou_score, dim=0)
IoU_mean = torch.mean(iou_score, dim=0)


In [None]:
# Report the Dice / IoU spread and mean computed (over dim 0) in the evaluation cell.
for _stat_name, _stat_value in (
    ('dice_std', dice_std),
    ('dice_mean', dice_mean),
    ('IoU_std', IoU_std),
    ('IoU_mean', IoU_mean),
):
    print(_stat_name, _stat_value)

<!-- torch.Size([1, 1, 160, 160, 160]) torch.Size([1, 1, 160, 160, 160])
dice_scores: tensor([[0.9998, 0.7049, 0.6927],
        [0.9998, 0.8169, 0.5936],
        [0.9991, 0.9381, 0.7051],
        [0.9988, 0.9097, 0.8005],
        [0.9998, 0.9486, 0.5345],
        [0.9998, 0.9640, 0.8184]])
IoU: tensor([[0.9995, 0.5442, 0.5298],
        [0.9997, 0.6905, 0.4220],
        [0.9983, 0.8835, 0.5445],
        [0.9976, 0.8343, 0.6674],
        [0.9997, 0.9021, 0.3647],
        [0.9996, 0.9306, 0.6926]])
dice_std tensor([0.0004, 0.1007, 0.1117])
dice_mean tensor([0.9995, 0.8804, 0.6908])
IoU_std tensor([0.0009, 0.1503, 0.1298])
IoU_mean tensor([0.9991, 0.7975, 0.5368]) -->

<!-- torch.Size([1, 1, 160, 160, 160]) torch.Size([1, 1, 160, 160, 160])
dice_scores: tensor([[0.9998, 0.7049, 0.6927],
        [0.9998, 0.8169, 0.5936],
        [0.9991, 0.9381, 0.7051],
        [0.9988, 0.9097, 0.8005],
        [0.9998, 0.9486, 0.5345],
        [0.9998, 0.9640, 0.8184]])
IoU: tensor([[0.9995, 0.5442, 0.5298],
        [0.9997, 0.6905, 0.4220],
        [0.9983, 0.8835, 0.5445],
        [0.9976, 0.8343, 0.6674],
        [0.9997, 0.9021, 0.3647],
        [0.9996, 0.9306, 0.6926]])
dice_std tensor([0.0004, 0.1007, 0.1117])
dice_mean tensor([0.9995, 0.8804, 0.6908])
IoU_std tensor([0.0009, 0.1503, 0.1298])
IoU_mean tensor([0.9991, 0.7975, 0.5368]) -->

<!-- torch.Size([1, 1, 160, 160, 160]) torch.Size([1, 1, 160, 160, 160])
dice_scores: tensor([[0.9998, 0.7049, 0.6927],
        [0.9998, 0.8169, 0.5936],
        [0.9991, 0.9381, 0.7051],
        [0.9988, 0.9097, 0.8005],
        [0.9998, 0.9486, 0.5345],
        [0.9998, 0.9640, 0.8184]])
IoU: tensor([[0.9995, 0.5442, 0.5298],
        [0.9997, 0.6905, 0.4220],
        [0.9983, 0.8835, 0.5445],
        [0.9976, 0.8343, 0.6674],
        [0.9997, 0.9021, 0.3647],
        [0.9996, 0.9306, 0.6926]])
dice_std tensor([0.0004, 0.1007, 0.1117])
dice_mean tensor([0.9995, 0.8804, 0.6908])
IoU_std tensor([0.0009, 0.1503, 0.1298])
IoU_mean tensor([0.9991, 0.7975, 0.5368]) -->