In [None]:
import torch
import torch.nn as nn
# from apex import amp
from torch.cuda import amp
import numpy as np
from tqdm import tqdm
import torch
from torch.nn import functional as F

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as transforms
from torch.autograd import Variable
from PIL import Image
from matplotlib.pyplot import imshow
from scipy import ndimage as ndi
from skimage.transform import resize
import pandas as pd
import os
import nibabel as nib


import matplotlib.pyplot as plt
%matplotlib inline

%gui qt

In [None]:
from google.colab import drive
# Mount Google Drive; only needed when dataDir points at the Drive copy of the dataset.
drive.mount(mountpoint='/content/drive', force_remount=True)

Mounted at /content/drive


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
!nvidia-smi

Sat May  8 06:31:37 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.19.01    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   69C    P8    11W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

# Import Data

In [None]:
# Root of the M&Ms OpenDataset; swap in the Drive path when working from Google Drive.
dataDir = 'downloads/OpenDataset/'
# dataDir = 'drive/My Drive/B.Sc proj/OpenDataset/'

# Per-subject metadata (vendor, centre, and the ED/ES frame indices SegDataset reads).
metadata = pd.read_csv(filepath_or_buffer=dataDir + '201014_M&Ms_Dataset_Information_-_opendataset.csv')
metadata

Unnamed: 0,External code,VendorName,Vendor,Centre,ED,ES
0,A0S9V9,Siemens,A,1,0,9
1,A1D0Q7,Philips,B,2,0,9
2,A1D9Z7,Siemens,A,1,22,11
3,A1E9Q1,Siemens,A,1,0,9
4,A1K2P5,Canon,D,5,33,11
...,...,...,...,...,...,...
340,T2Z1Z9,Canon,D,5,29,9
341,T9U9W2,Siemens,A,1,0,10
342,V4W8Z5,GE,C,4,19,9
343,W5Z4Z8,Philips,B,2,29,11


# Define Dataset

In [None]:
def normalize_01(inp: np.ndarray):
    """Min-max scale an array to the [0, 1] range.

    Args:
        inp: array (or array-like) of any shape.

    Returns:
        Array of the same shape. A constant input has zero dynamic range; it maps
        to all zeros instead of the all-NaN result of a 0/0 division.
    """
    inp = np.asarray(inp)
    shifted = inp - np.min(inp)
    value_range = np.ptp(inp)
    if value_range == 0:
        # Every shifted value is already 0. Dividing by 1 gives the same dtype the
        # normal branch would produce (e.g. float64 for integer input).
        return shifted / 1
    return shifted / value_range

def change_sizes(img, diff_seq, is_mask):
    """Center-crop a volume along its last axis and resize every slice to 200x200.

    Args:
        img: 3-D array indexed as ``img[:, :, slice]``.
        diff_seq: number of surplus slices to drop. ``diff_seq // 2`` are removed
            from the start and the remainder from the end.
        is_mask: True when ``img`` holds integer label maps. Labels must not be
            interpolated: blending labels 0 and 3 would create fake intermediate
            labels after the later integer cast. Masks are therefore resized with
            nearest-neighbour sampling, no anti-aliasing, and the value range
            preserved.

    Returns:
        Array of shape ``(kept_slices, 200, 200)``. Note that the slice axis
        moves to the front.
    """
    start = diff_seq // 2
    stop = img.shape[2] - (diff_seq - diff_seq // 2)
    cropped = img[:, :, start:stop]
    if is_mask:
        resize_kwargs = dict(order=0, anti_aliasing=False, preserve_range=True)
    else:
        resize_kwargs = {}  # images: skimage defaults (interpolating resize)
    res = [resize(cropped[:, :, i], (200, 200), **resize_kwargs) for i in range(cropped.shape[2])]
    return np.asarray(res)
    
def change_size_and_pad(img, diff_seq, is_mask=False):
    """Resize every slice to 200x200 and zero-pad the slice axis with ``diff_seq`` slices.

    Args:
        img: 3-D array indexed as ``img[:, :, slice]``.
        diff_seq: number of all-zero 200x200 slices to add. ``diff_seq // 2`` go
            before the data and the remainder after it; a non-positive value
            adds nothing.
        is_mask: pass True for integer label maps. Slices are then resized with
            nearest-neighbour sampling and no anti-aliasing, so labels are never
            blended. The default False keeps the plain ``resize`` used for images.

    Returns:
        Array of shape ``(img.shape[2] + diff_seq, 200, 200)`` for
        ``diff_seq >= 0``; the slice axis moves to the front.
    """
    resize_kwargs = dict(order=0, anti_aliasing=False, preserve_range=True) if is_mask else {}
    pad = np.zeros((200, 200))
    before = [pad] * (diff_seq // 2)
    after = [pad] * (diff_seq - diff_seq // 2)
    body = [resize(img[:, :, i], (200, 200), **resize_kwargs) for i in range(img.shape[2])]
    return np.asarray(before + body + after)


class SegDataset(Dataset):
    """Segmentation dataset yielding two (image, mask) samples per subject folder.

    Even indices use the subject's 'ES' frame and odd indices its 'ED' frame.
    Both frame numbers are looked up in ``metadata``. Volumes are center-cropped
    or zero-padded to ``num_seq`` slices, and every slice is resized to 200x200.

    Args:
        all_dirs: subject folder names, relative to ``addr``.
        metadata: DataFrame with 'External code', 'ED' and 'ES' columns.
        addr: directory that contains the subject folders.
    """

    def __init__(self, all_dirs, metadata, addr):
        self.all_dirs = all_dirs
        self.metadata = metadata
        self.addr = addr
        self.num_seq = 8  # slices per sample after crop / pad

    def __len__(self):
        # two samples (ES and ED frame) per subject folder
        return 2 * len(self.all_dirs)

    @staticmethod
    def _load_subject(path):
        """Read the NIfTI files directly inside ``path``; return (image, mask, subject_id).

        The mask is the file whose name ends in ``_gt`` before the extension. The
        subject id is that file name's prefix up to the first '_'.

        Raises:
            FileNotFoundError: if the folder lacks an image or a ``_gt`` mask file.
        """
        img, mask, subject_id = None, None, None
        for f in sorted(os.listdir(path)):
            full_path = os.path.join(path, f)
            if not os.path.isfile(full_path):
                continue
            if f.endswith('.nii.gz'):
                stem = f[:-len('.nii.gz')]
            elif f.endswith('.nii'):
                stem = f[:-len('.nii')]
            else:
                continue  # ignore non-NIfTI files such as .DS_Store
            data = nib.load(full_path).get_fdata()
            if stem.endswith('_gt'):
                subject_id = f.split('_')[0]
                mask = data
            else:
                img = data
        if img is None or mask is None:
            raise FileNotFoundError(f"expected an image and a '_gt' mask NIfTI file in {path}")
        return img, mask, subject_id

    def __getitem__(self, idx):
        """Return sample ``idx`` as ``(image, mask)``.

        image: float32 tensor of shape (1, num_seq, 200, 200), scaled to [0, 1].
        mask: int64 tensor of shape (num_seq, 200, 200) holding the labels.
        """
        path = os.path.join(self.addr, self.all_dirs[idx // 2])
        img, mask, subject_id = self._load_subject(path)

        # Use the metadata passed to the constructor, not a notebook-level global.
        phase = 'ES' if idx % 2 == 0 else 'ED'
        rows = self.metadata[self.metadata['External code'] == subject_id]
        if rows.empty:
            raise KeyError(f"subject {subject_id!r} not found in metadata")
        t = int(rows[phase].iloc[0])

        img = img[:, :, :, t]
        gt = mask[:, :, :, t]

        diff_seq = img.shape[2] - self.num_seq
        if diff_seq > 0:
            img = change_sizes(img, diff_seq, False)
            gt = change_sizes(gt, diff_seq, True)
        else:
            img = change_size_and_pad(img, self.num_seq - img.shape[2])
            # NOTE(review): this resizes the mask with skimage's default
            # interpolating resize, which can blend labels at boundaries. The
            # rounding below limits the damage, but nearest-neighbour resizing
            # is what masks need.
            gt = change_size_and_pad(gt, self.num_seq - gt.shape[2])

        img = normalize_01(img)

        # Add the channel axis: (1, num_seq, 200, 200).
        image_tensor = torch.from_numpy(np.asarray(img, dtype=np.float32)).unsqueeze(0)
        # Round before the integer cast so float noise such as 2.9999 becomes
        # label 3 instead of being truncated to 2.
        mask_tensor = torch.from_numpy(np.rint(gt)).long()
        return image_tensor, mask_tensor

# Build the Datasets and DataLoaders

In [None]:
# Training hyper-parameters: samples per batch, epochs to run, SGD step size.
batch_size, num_epochs, learning_rate = 4, 60, 0.001

In [None]:
N_SUBJECTS = 8  # subject folders used per split (quick-debug subset); None = use all

# os.listdir order is arbitrary, so sort to pick the same subjects on every run.
# Keep only directories, because SegDataset treats every entry as a subject folder.
train_dir = dataDir + '/Training/Labeled/'
train_all = sorted(d for d in os.listdir(train_dir)
                   if os.path.isdir(os.path.join(train_dir, d)))[:N_SUBJECTS]
train_dataset = SegDataset(train_all, metadata, train_dir)

val_dir = dataDir + '/Validation/'
val_all = sorted(d for d in os.listdir(val_dir)
                 if os.path.isdir(os.path.join(val_dir, d)))[:N_SUBJECTS]
val_dataset = SegDataset(val_all, metadata, val_dir)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)
# Validation is not shuffled: the metrics do not depend on order, and the
# "Show Results" cells index specific batches/samples that should be stable.
val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                         batch_size=batch_size,
                                         shuffle=False)

# Define Dice

In [None]:
import torch.nn.functional as F

def dice_score(inputs, targets, class_num, dims=[1,2,3], smooth=0.5):
    """Hard (argmax-based) Dice score of one class, averaged over the batch.

    Args:
        inputs: network scores with classes on dim 1; argmax over dim 1 gives
            the predicted label map.
        targets: integer labels shaped like ``argmax(inputs, dim=1)``.
        class_num: label scored one-vs-rest.
        dims: axes of the label maps summed per sample. The default ``[1, 2, 3]``
            fits (batch, slices, H, W) label maps.
        smooth: additive smoothing in the numerator and the denominator.

    Returns:
        Python float: the mean of ``(2*|P & T| + smooth) / (|P| + |T| + smooth)``.
        The value is detached from autograd, and neither argument is modified.
    """
    predicted = (torch.argmax(inputs.detach(), dim=1) == class_num).long()
    expected = (targets.detach() == class_num).to(targets.dtype)

    overlap = torch.sum(2 * (predicted * expected), dim=dims)
    volume = torch.sum(predicted + expected, dim=dims)

    per_sample = (overlap + smooth) / (volume + smooth)
    return per_sample.mean().item()

In [None]:
class CostumLoss(nn.Module):
    """Cross-entropy plus a differentiable ``-2 * log(soft Dice)`` term per foreground class.

    The Dice terms used to come from ``dice_score``, which takes an argmax and
    returns a Python float. They therefore had no gradient and only shifted the
    reported loss value. Here Dice is computed from softmax probabilities, so
    those terms train the network. (The class name keeps its original spelling
    for existing callers.)

    Args:
        weight: optional per-class weights for the cross-entropy term.
        size_average: unused; kept for signature compatibility.
        classes: labels that get a Dice term (default 1, 2, 3).
    """

    def __init__(self, weight=None, size_average=True, classes=(1, 2, 3)):
        super(CostumLoss, self).__init__()
        # Register as a buffer so ``.to(device)`` moves the weights with the module.
        self.register_buffer('weight', None if weight is None else torch.as_tensor(weight, dtype=torch.float32))
        self.classes = tuple(classes)

    def forward(self, outs, targets, smooth=0.5):
        """Return the scalar loss tensor.

        Args:
            outs: raw logits with classes on dim 1, e.g. (batch, classes, slices, H, W).
            targets: integer labels shaped like ``outs`` without dim 1.
            smooth: additive Dice smoothing (same default as ``dice_score``).
        """
        cross_entropy = F.cross_entropy(outs, targets, weight=self.weight, reduction='mean')

        probs = torch.softmax(outs.float(), dim=1)
        # Reduce every non-batch axis of the label tensor per sample, then average
        # over the batch (mirrors dice_score with dims=[1, 2, 3] for 4-D labels).
        dims = tuple(range(1, targets.dim()))
        res = cross_entropy
        for c in self.classes:
            p = probs[:, c]
            t = (targets == c).to(p.dtype)
            intersection = 2 * (p * t).sum(dim=dims)
            total = (p + t).sum(dim=dims)
            dice = ((intersection + smooth) / (total + smooth)).mean()
            res = res - 2 * torch.log(dice)
        return res

# Define 3D UNet

In [None]:
class UNet3D(nn.Module):
    """3-D U-Net for volumetric segmentation that returns raw per-class logits.

    Layout built in ``__init__``:

    * ``depth`` encoder stages. Stage ``i`` is a double conv block
      (Conv3d 3x3x3 -> BatchNorm3d -> ReLU, twice) with
      ``initial_filters * 2**i`` output channels, followed by a 2x2x2 max-pool.
    * A bridge double conv block with ``initial_filters * 2**depth`` channels,
      followed by dropout.
    * ``depth`` decoder stages. Each upsamples 2x with ConvTranspose3d (channel
      count unchanged), concatenates the output of the mirrored encoder stage on
      the channel axis, and applies a double conv block whose output has half the
      upsampled channel count.
    * Dropout, then a 1x1x1 Conv3d to ``output_dim`` channels (no softmax).

    Every spatial dimension of the input must be divisible by ``2 ** depth``.
    Max-pooling floors odd sizes while each transposed conv exactly doubles its
    input, so otherwise the skip concatenation in ``forward`` sees mismatched
    sizes.

    A single ``nn.ReLU(inplace=True)`` instance is shared by all blocks. ReLU has
    no parameters, so sharing it does not tie any weights.
    """
    def __init__(self, input_dim, output_dim, initial_filters=16, depth=4, dropout=0):
        """
        Args:
            input_dim: number of input channels (1 for the single-channel images here).
            output_dim: number of output channels, one logit per class (this
                notebook uses 4, for labels 0-3).
            initial_filters: number of output channels (filters) of the first
                encoder block; each deeper stage doubles it. Usually 16 or 32.
                It is not a kernel size; all convs use 3x3x3 kernels except the
                1x1x1 output conv.
            depth: number of encoder/decoder stages (depth of the "U").
                Every spatial input dimension must be divisible by ``2 ** depth``.
            dropout: probability of the two ``nn.Dropout`` layers.
        """

        super(UNet3D, self).__init__()
        
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.initial_filters = initial_filters
        self.depth = depth
        self.dropout = dropout
        # One shared, parameter-free activation module reused by every block.
        activation = nn.ReLU(inplace=True)
        
        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.pools = nn.ModuleList()
        self.trans = nn.ModuleList()
        self.dropouts = nn.ModuleList()
        
        # Down: stage i maps the previous stage's channels to initial_filters * 2**i.
        for i in range(depth):
            in_dim = self.input_dim
            if i > 0:
                in_dim = self.initial_filters * 2 ** (i - 1)
            out_dim = self.initial_filters * 2 ** i
            self.downs.append(self.double_conv_block(in_dim, out_dim, activation))
            self.pools.append(self.max_pool())
            
        
        # Bridge
        self.bridge = self.double_conv_block(
            self.initial_filters * 2 ** (depth - 1),
            self.initial_filters * 2 ** depth,
            activation
        )
        
        # Dropout applied to the bridge output (self.dropouts[0])
        self.dropouts.append(self.dropout_layer())
        
        # Up: the transposed conv keeps the channel count; the double conv then takes
        # (upsampled channels + skip channels of encoder stage depth-i-1).
        for i in range(depth):
            trans_in_out_dim = self.initial_filters * 2 ** (depth - i)
            self.trans.append(self.conv_transpose_block(trans_in_out_dim, trans_in_out_dim, activation))
            
            up_in_dim = self.initial_filters * (2 ** (depth - i) + 2 ** (depth - i - 1))
            up_out_dim = self.initial_filters * 2 ** (depth - i - 1)
            self.ups.append(self.double_conv_block(up_in_dim, up_out_dim, activation))
            
        # Dropout applied before the output conv (self.dropouts[1])
        self.dropouts.append(self.dropout_layer())
        
        # Output
        self.out = self.prediction_mask(initial_filters, self.output_dim)
        
    def single_conv_block(self, input_dim, output_dim, activation):
        """Conv3d 3x3x3 (stride 1, padding 1: spatial size unchanged) -> BatchNorm3d -> activation."""
        return nn.Sequential(
            nn.Conv3d(input_dim, output_dim, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm3d(output_dim),
            activation
        )

    def double_conv_block(self, input_dim, output_dim, activation):
        """Two single conv blocks: input_dim -> output_dim -> output_dim channels."""
        return nn.Sequential(
            self.single_conv_block(input_dim, output_dim, activation),
            #nn.Dropout(p=self.dropout),
            self.single_conv_block(output_dim, output_dim, activation)
        )

    def conv_transpose_block(self, input_dim, output_dim, activation):
        """2x upsampling: ConvTranspose3d (k=3, s=2, p=1, output_padding=1) -> BatchNorm3d -> activation.

        With these settings each spatial size becomes exactly twice the input size.
        """
        return nn.Sequential(
            nn.ConvTranspose3d(input_dim, output_dim, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm3d(output_dim),
            activation,
            #nn.Dropout(p=self.dropout)
        )

    def prediction_mask(self, input_dim, output_dim):
        """1x1x1 Conv3d producing raw logits.

        No softmax is applied here; the loss (F.cross_entropy) and the argmax-based
        Dice metric in this notebook consume logits directly.
        """
        return nn.Conv3d(input_dim, output_dim, kernel_size=1, stride=1, padding=0) 

    def max_pool(self):
        """2x2x2 max-pool with stride 2 (halves each spatial size, flooring odd sizes)."""
        return nn.MaxPool3d(kernel_size=2, stride=2, padding=0) 
    
    def dropout_layer(self):
        """Dropout with the probability given to the constructor."""
        return nn.Dropout(p=self.dropout)
        
    def forward(self, x):
        """Return logits with ``output_dim`` channels and the same spatial size as ``x``.

        Args:
            x: input tensor (batch, input_dim, D, H, W); D, H and W must be
               divisible by ``2 ** depth``.
        """
        downs = []
        ups = []
        pools = []
        trans = []
        concats = []
        dropouts = []
        
        # Down: keep each stage's pre-pool output in `downs` for the skip connections.
        for i in range(self.depth):
            inp = x
            if i > 0:
                inp = pools[-1]
            downs.append(self.downs[i](inp))
            pools.append(self.pools[i](downs[-1]))
        
        # Bridge
        bridge = self.bridge(pools[-1])
        
        # Dropout
        dropouts.append(self.dropouts[0](bridge))
        
        # Up: upsample, concatenate the mirrored encoder output on channels, double conv.
        for i in range(self.depth):
            inp = dropouts[-1]
            if i > 0:
                inp = ups[-1]
            trans.append(self.trans[i](inp))
            concats.append(torch.cat([trans[-1], downs[self.depth - i - 1]], dim=1))
            ups.append(self.ups[i](concats[-1]))
        
        # Dropout
        dropouts.append(self.dropouts[1](ups[-1]))
        
        # Output
        out = self.out(dropouts[-1])
        return out


class SegmentationModel(UNet3D):
    """``UNet3D`` placed on a CUDA device, with epoch-level train / validate / labelling helpers.

    NOTE(review): nothing else in this notebook uses this class; training goes
    through ``Trainer``. Its helpers call names not defined in this notebook:
    ``soft_binray_cross_entropy``, ``dice_coef`` and
    ``loader.dataset.get_class_weights()``. They also unpack 7-tuples
    ``(X, Y, _, P, R, VW, I)`` from the loader, whereas ``SegDataset`` yields
    ``(img, mask)`` pairs. Confirm these against the original project before use.
    """

    def __init__(self, input_dim, output_dim, gpu, initial_filters=16, depth=4, dropout=0, name='Segmentation'):
        """
        Args:
            input_dim: number of input channels.
            output_dim: number of output channels.
            gpu: CUDA device index the model is moved to.
            initial_filters: output channels of the first encoder block (doubled per stage).
            depth: number of encoder/decoder stages. Every spatial input dimension
                must be divisible by ``2 ** depth``.
            dropout: dropout probability.
            name: free-form model name.
        """
        # Forward the caller's architecture arguments. The previous call passed the
        # literals 16 / 4 / 0 and silently ignored initial_filters, depth and dropout.
        super().__init__(input_dim, output_dim, initial_filters=initial_filters, depth=depth, dropout=dropout)
        self.name = name
        self.gpu = gpu
        self.cuda(self.gpu)

    def train_epoch(self, loader, optimizer, scaler=None):
        """Run one training epoch under ``amp.autocast``.

        Args:
            loader: DataLoader yielding ``(X, Y, _, P, _, VW, _)`` batches, where
                ``P`` holds one pivot-slice index per sample.
            optimizer: optimizer over this model's parameters.
            scaler: optional ``torch.cuda.amp.GradScaler``. The forward pass runs
                under autocast; with a scaler, the scaled loss is backpropagated,
                ``scaler.step`` replaces ``optimizer.step``, and 'scaled_loss' is
                accumulated. Without a scaler (default) the unscaled loss is
                backpropagated, as before.

        Returns:
            dict with 'loss', 'scaled_loss', 'dice' and 'pivot_dice', each
            divided by ``len(loader.dataset)``.
            NOTE(review): 'loss' and 'dice' add one value per batch but divide by
            the number of samples; if those helpers return batch means, the
            results are scaled down by the batch size. Confirm.
        """
        self.train()  # be explicit; the other helpers switch to eval() temporarily
        class_weights = torch.Tensor(loader.dataset.get_class_weights()).cuda(self.gpu)

        running_loss = 0.
        running_scaled_loss = 0.
        running_dice = 0.
        running_pivot_dice = 0.

        for X, Y, _, P, _, VW, _ in tqdm(loader):  # iterate over batches
            # Per the original notes: X and Y are (batch_size, 1, image_slices, width, height).
            X, Y = X.cuda(self.gpu).float(), Y.cuda(self.gpu).float()

            optimizer.zero_grad()

            with amp.autocast():
                outputs = self(X)
                # UNLABELED voxels are meant to be ignored by this loss (per the original notes).
                loss = soft_binray_cross_entropy(
                    outputs,
                    Y,  # target
                    class_weights=class_weights,
                    voxel_weights=VW.cuda(self.gpu)
                )
                running_loss += loss.item()

            # Backward pass runs outside autocast.
            if scaler is None:
                loss.backward()
                optimizer.step()
            else:
                scaled_loss = scaler.scale(loss)
                running_scaled_loss += scaled_loss.item()
                scaled_loss.backward()
                scaler.step(optimizer)
                scaler.update()

            running_dice += dice_coef(outputs, Y).item()

            # Dice restricted to each sample's pivot slice.
            for i, p in enumerate(P):
                running_pivot_dice += dice_coef(
                    outputs[i:i+1, :, p, :, :],
                    Y[i:i+1, :, p, :, :]
                ).item()

        n_samples = len(loader.dataset)
        return {
            'loss': running_loss / n_samples,
            'scaled_loss': running_scaled_loss / n_samples,
            'dice': running_dice / n_samples,
            'pivot_dice': running_pivot_dice / n_samples
        }

    def val_epoch(self, loader):
        """Evaluate loss and Dice on each sample's pivot slice only.

        Returns:
            dict with 'loss' and 'pivot_dice' sums divided by ``len(loader.dataset)``.
        """
        self.eval()
        class_weights = torch.Tensor(loader.dataset.get_class_weights()).cuda(self.gpu)

        running_loss = 0
        running_pivot_dice = 0

        try:
            with torch.no_grad():
                for X, Y, _, P, _, _, _ in loader:
                    X, Y = X.cuda(self.gpu).float(), Y.cuda(self.gpu).float()
                    outputs = self(X)

                    for i, p in enumerate(P):
                        running_loss += soft_binray_cross_entropy(
                            outputs[i:i+1, :, p, :, :],
                            Y[i:i+1, :, p, :, :],
                            class_weights=class_weights
                        ).item()

                        running_pivot_dice += dice_coef(
                            outputs[i:i+1, :, p, :, :],
                            Y[i:i+1, :, p, :, :]
                        ).item()
        finally:
            self.train()  # restore training mode even if evaluation fails

        return {
            'loss': running_loss / len(loader.dataset),
            'pivot_dice': running_pivot_dice / len(loader.dataset)
        }

    def generate_labels(self, loader, iteration, step_size):
        """Predict slices within ``iteration`` of each sample's pivot slice.

        Args:
            loader: DataLoader with batch size 1 yielding ``(X, _, _, P, R, _, I)``.
                P is the pivot slice, R the slice count and I the sample id.
            iteration: radius, in slices, around the pivot.
            step_size: must be 1 (only value supported).

        Returns:
            ``{sample_id: {slice_index: prediction}}``. The pivot slice itself and
            indices outside ``[0, R)`` are excluded.
            NOTE(review): predictions use ``torch.sigmoid``, which suits a
            single-channel/binary output, not a multi-class softmax head.
        """
        assert step_size == 1

        self.eval()

        labels = {}
        try:
            with torch.no_grad():
                for X, _, _, P, R, _, I in loader:
                    assert len(X) == 1  # batch size must be 1 in this implementation

                    p = P.item()
                    r = R.item()
                    i = I.item()

                    new_slices = [s for s in range(p - iteration, p + iteration + 1) if s >= 0 and s < r and s != p]

                    X = X.cuda(self.gpu).float()
                    outputs = torch.sigmoid(self(X)).cpu().detach().numpy().squeeze()
                    labels[i] = {s: outputs[s] for s in new_slices}
        finally:
            self.train()
        return labels

# Define Trainer

In [None]:
class Trainer:
    """Epoch loop for a segmentation model with per-class Dice tracking and CSV logging.

    Each epoch does the following:
    * runs ``_train`` and, if a validation loader was given, ``_validate``;
    * steps the optional LR scheduler;
    * prints the latest validation Dice per class and appends a row to ``self.log``.

    Every 10 epochs it pickles the whole model to ``dataDir + 'unet_3d.pt'``
    (``dataDir`` is the notebook-level global) and writes the CSV logs to the
    working directory.
    """

    def __init__(self,
                 model: torch.nn.Module,
                 device: torch.device,
                 criterion: torch.nn.Module,
                 optimizer: torch.optim.Optimizer,
                 training_DataLoader: torch.utils.data.DataLoader,
                 validation_DataLoader: torch.utils.data.DataLoader = None,
                 lr_scheduler: torch.optim.lr_scheduler = None,
                 epochs: int = 100,
                 epoch: int = 0,
                 notebook: bool = False
                 ):
        """
        Args:
            model: network returning logits with classes on dim 1.
            device: device that each batch is moved to.
            criterion: loss called as ``criterion(logits, target)``.
            optimizer: optimizer over ``model``'s parameters.
            training_DataLoader: yields ``(x, y)`` training batches.
            validation_DataLoader: optional loader yielding ``(x, y)`` batches.
            lr_scheduler: optional scheduler, stepped once per epoch.
            epochs: epochs to run per ``run_trainer`` call.
            epoch: starting value of the global epoch counter.
            notebook: use ``tqdm.notebook`` progress bars.
        """
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.training_DataLoader = training_DataLoader
        self.validation_DataLoader = validation_DataLoader
        self.device = device
        self.epochs = epochs
        self.epoch = epoch
        self.notebook = notebook

        # Per-epoch histories.
        self.training_loss = []
        self.validation_loss = []
        self.learning_rate = []

        self.validation_dices_c0 = []
        self.validation_dices_c1 = []
        self.validation_dices_c2 = []
        self.validation_dices_c3 = []

        self.train_dices_c0 = []
        self.train_dices_c1 = []
        self.train_dices_c2 = []
        self.train_dices_c3 = []

        cols = ['epoch', 't_dice_0', 't_dice_1', 't_dice_2', 't_dice_3', 'v_dice_0', 'v_dice_1', 'v_dice_2', 'v_dice_3', 't_loss', 'v_loss']
        self.log = pd.DataFrame(columns=cols)

        # One row per dice_per_slice call, one column per slice (first 12 slices).
        self.dice_0_per_slice_log = pd.DataFrame(columns=list(range(12)))
        self.dice_1_per_slice_log = pd.DataFrame(columns=list(range(12)))
        self.dice_2_per_slice_log = pd.DataFrame(columns=list(range(12)))
        self.dice_3_per_slice_log = pd.DataFrame(columns=list(range(12)))

    @staticmethod
    def _append_row(frame, row):
        """Append ``row`` to ``frame`` in place.

        ``row`` is a dict keyed by column, or a list in column order. Missing
        columns become NaN and unknown ones are dropped.
        """
        frame.loc[len(frame)] = pd.Series(row)

    @staticmethod
    def _last(history):
        """Last element of ``history``, or NaN when it is empty."""
        return history[-1] if history else float('nan')

    def _train_dice_histories(self):
        """Per-class training Dice history lists, for classes 0-3."""
        return (self.train_dices_c0, self.train_dices_c1, self.train_dices_c2, self.train_dices_c3)

    def _validation_dice_histories(self):
        """Per-class validation Dice history lists, for classes 0-3."""
        return (self.validation_dices_c0, self.validation_dices_c1,
                self.validation_dices_c2, self.validation_dices_c3)

    def dice_per_slice(self, inputs, targets, class_num=None):
        """Append one row of per-slice Dice (classes 0-3) to each per-slice log.

        Args:
            inputs: logits shaped (batch, classes, slices, H, W).
            targets: labels shaped (batch, slices, H, W).
            class_num: unused; kept so existing three-argument calls still work.

        Only as many slices as the logs have columns (12) are scored. Slices the
        input does not have are logged as NaN.
        """
        n_slices = min(inputs.shape[2], len(self.dice_0_per_slice_log.columns))
        per_class = [[], [], [], []]
        for s in range(n_slices):
            inp = inputs[:, :, s, :, :]
            gt = targets[:, s, :, :]
            for c, row in enumerate(per_class):
                row.append(dice_score(inp, gt, c, dims=[1, 2]))
        logs = (self.dice_0_per_slice_log, self.dice_1_per_slice_log,
                self.dice_2_per_slice_log, self.dice_3_per_slice_log)
        for frame, row in zip(logs, per_class):
            self._append_row(frame, row)

    def run_trainer(self):
        """Run ``self.epochs`` epochs and return the recorded histories.

        Returns:
            (training_loss, validation_loss, validation_dices_c0, ..._c1, ..._c2, ..._c3):
            lists with one entry per epoch run by this Trainer. The validation
            lists stay empty when no validation loader was given.
        """
        if self.notebook:
            from tqdm.notebook import trange
        else:
            from tqdm import trange

        for i in trange(self.epochs, desc='Progress'):
            self.epoch += 1  # global counter; continues from the `epoch` argument
            self._train()
            if self.validation_DataLoader is not None:
                self._validate()
            self._step_scheduler()
            self._report_epoch()
            if i % 10 == 0 and i > 0:
                self._save_checkpoint()

        self.log.to_csv('log.csv')
        return self.training_loss, self.validation_loss, self.validation_dices_c0, self.validation_dices_c1, self.validation_dices_c2, self.validation_dices_c3

    def _step_scheduler(self):
        """Advance the LR scheduler once. PyTorch schedulers expose ``step``, not ``batch``."""
        if self.lr_scheduler is None:
            return
        if self.lr_scheduler.__class__.__name__ == 'ReduceLROnPlateau':
            # Plateau schedulers need a metric: validation loss if available,
            # otherwise training loss.
            history = self.validation_loss if self.validation_DataLoader is not None else self.training_loss
            self.lr_scheduler.step(history[-1])
        else:
            self.lr_scheduler.step()

    def _report_epoch(self):
        """Print the latest validation Dice per class and append this epoch's row to ``self.log``."""
        train_hist = self._train_dice_histories()
        val_hist = self._validation_dice_histories()

        if self.validation_DataLoader is not None:
            print(f"epoch number {self.epoch} finished with validation dices as follows:")
            for c, history in enumerate(val_hist):
                print(f'\t class {c}:', history[-1], end='')
            print()

        row = {'epoch': self.epoch}
        for c, history in enumerate(train_hist):
            row[f't_dice_{c}'] = self._last(history)
        for c, history in enumerate(val_hist):
            row[f'v_dice_{c}'] = self._last(history)
        row['t_loss'] = self._last(self.training_loss)
        row['v_loss'] = self._last(self.validation_loss)  # this epoch's value, not the whole list
        # DataFrame.append returned a new frame that was discarded; append in place instead.
        self._append_row(self.log, row)

    def _save_checkpoint(self):
        """Pickle the whole model to ``dataDir + 'unet_3d.pt'`` and write all CSV logs."""
        torch.save(self.model, dataDir + 'unet_3d.pt')
        self.log.to_csv('log.csv')
        self.dice_0_per_slice_log.to_csv('dice_0_per_slice_log.csv')
        self.dice_1_per_slice_log.to_csv('dice_1_per_slice_log.csv')
        self.dice_2_per_slice_log.to_csv('dice_2_per_slice_log.csv')
        self.dice_3_per_slice_log.to_csv('dice_3_per_slice_log.csv')

    def _train(self):
        """One optimisation pass over the training loader.

        Records the mean loss, mean Dice per class (0-3) and the current learning rate.
        """
        if self.notebook:
            from tqdm.notebook import tqdm
        else:
            from tqdm import tqdm

        self.model.train()
        train_losses = []
        t_dices = [[], [], [], []]  # per-batch Dice for classes 0-3

        batch_iter = tqdm(enumerate(self.training_DataLoader), 'Training', total=len(self.training_DataLoader),
                          leave=False)

        for i, (x, y) in batch_iter:
            # Use the configured device rather than a hard-coded GPU 0, so CPU-only runs work.
            inputs, target = x.to(self.device), y.to(self.device)
            self.optimizer.zero_grad()
            out = self.model(inputs)
            loss = self.criterion(out, target)
            loss_value = loss.item()
            train_losses.append(loss_value)
            loss.backward()
            self.optimizer.step()

            for c, bucket in enumerate(t_dices):
                bucket.append(dice_score(out, target, c))

            batch_iter.set_description(f'Training: (loss {loss_value:.4f})')

        self.training_loss.append(np.mean(train_losses))
        for history, bucket in zip(self._train_dice_histories(), t_dices):
            history.append(np.mean(bucket))
        self.learning_rate.append(self.optimizer.param_groups[0]['lr'])

        batch_iter.close()

    def _validate(self):
        """One no-grad pass over the validation loader, recording mean loss and mean Dice per class (0-3)."""
        if self.notebook:
            from tqdm.notebook import tqdm
        else:
            from tqdm import tqdm

        self.model.eval()
        valid_losses = []
        v_dices = [[], [], [], []]  # per-batch Dice for classes 0-3

        batch_iter = tqdm(enumerate(self.validation_DataLoader), 'Validation', total=len(self.validation_DataLoader),
                          leave=False)

        for i, (x, y) in batch_iter:
            inputs, target = x.to(self.device), y.to(self.device)

            with torch.no_grad():
                out = self.model(inputs)
                loss = self.criterion(out, target)
                loss_value = loss.item()
                valid_losses.append(loss_value)

                for c, bucket in enumerate(v_dices):
                    bucket.append(dice_score(out, target, c))

                batch_iter.set_description(f'Validation: (loss {loss_value:.4f})')

        self.validation_loss.append(np.mean(valid_losses))
        for history, bucket in zip(self._validation_dice_histories(), v_dices):
            history.append(np.mean(bucket))

        batch_iter.close()

# Train

In [None]:
# 4 output channels, one per label 0-3 (the classes dice_score tracks during training).
# depth=3 halves each dimension three times, which the 8-slice, 200x200 samples allow.
model = UNet3D(
    input_dim=1,
    output_dim=4,
    depth=3,
    dropout=0.2,
).to(device)

# model = torch.load(dataDir + 'unet_3d.pt')

In [None]:
# criterion
# Loss: cross-entropy plus -2*log(Dice) terms for classes 1-3 (CostumLoss).
# Plain alternative: criterion = torch.nn.CrossEntropyLoss()
criterion = CostumLoss().to(device)

# Vanilla SGD (no momentum) over every model parameter.
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

# Train on train_loader and validate on val_loader every epoch; no LR scheduler.
trainer = Trainer(
    model=model,
    device=device,
    criterion=criterion,
    optimizer=optimizer,
    training_DataLoader=train_loader,
    validation_DataLoader=val_loader,
    lr_scheduler=None,
    epochs=num_epochs,
    epoch=0,
    notebook=True,
)

# Per-epoch training/validation loss and validation Dice for classes 0-3.
(training_losses, validation_losses,
 val_dices_0, val_dices_1, val_dices_2, val_dices_3) = trainer.run_trainer()

HBox(children=(FloatProgress(value=0.0, description='Progress', max=60.0, style=ProgressStyle(description_widt…

HBox(children=(FloatProgress(value=0.0, description='Training', max=4.0, style=ProgressStyle(description_width…

AFTER shape of img is (8, 200, 200) and shape of mask is (8, 200, 200)
AFTER shape of img is (8, 200, 200) and shape of mask is (8, 200, 200)
AFTER shape of img is (8, 200, 200) and shape of mask is (8, 200, 200)
AFTER shape of img is (8, 200, 200) and shape of mask is (8, 200, 200)
AFTER shape of img is (8, 200, 200) and shape of mask is (8, 200, 200)
AFTER shape of img is (8, 200, 200) and shape of mask is (8, 200, 200)
AFTER shape of img is (8, 200, 200) and shape of mask is (8, 200, 200)
AFTER shape of img is (8, 200, 200) and shape of mask is (8, 200, 200)
AFTER shape of img is (8, 200, 200) and shape of mask is (8, 200, 200)
AFTER shape of img is (8, 200, 200) and shape of mask is (8, 200, 200)
AFTER shape of img is (8, 200, 200) and shape of mask is (8, 200, 200)
AFTER shape of img is (8, 200, 200) and shape of mask is (8, 200, 200)
AFTER shape of img is (8, 200, 200) and shape of mask is (8, 200, 200)
AFTER shape of img is (8, 200, 200) and shape of mask is (8, 200, 200)
AFTER 

HBox(children=(FloatProgress(value=0.0, description='Validation', max=4.0, style=ProgressStyle(description_wid…

AFTER shape of img is (8, 200, 200) and shape of mask is (8, 200, 200)
AFTER shape of img is (8, 200, 200) and shape of mask is (8, 200, 200)
AFTER shape of img is (8, 200, 200) and shape of mask is (8, 200, 200)
AFTER shape of img is (8, 200, 200) and shape of mask is (8, 200, 200)
AFTER shape of img is (8, 200, 200) and shape of mask is (8, 200, 200)
AFTER shape of img is (8, 200, 200) and shape of mask is (8, 200, 200)
AFTER shape of img is (8, 200, 200) and shape of mask is (8, 200, 200)
AFTER shape of img is (8, 200, 200) and shape of mask is (8, 200, 200)
AFTER shape of img is (8, 200, 200) and shape of mask is (8, 200, 200)
AFTER shape of img is (8, 200, 200) and shape of mask is (8, 200, 200)
AFTER shape of img is (8, 200, 200) and shape of mask is (8, 200, 200)
AFTER shape of img is (8, 200, 200) and shape of mask is (8, 200, 200)
AFTER shape of img is (8, 200, 200) and shape of mask is (8, 200, 200)
AFTER shape of img is (8, 200, 200) and shape of mask is (8, 200, 200)
AFTER 

HBox(children=(FloatProgress(value=0.0, description='Training', max=4.0, style=ProgressStyle(description_width…

AFTER shape of img is (8, 200, 200) and shape of mask is (8, 200, 200)
AFTER shape of img is (8, 200, 200) and shape of mask is (8, 200, 200)
AFTER shape of img is (8, 200, 200) and shape of mask is (8, 200, 200)
AFTER shape of img is (8, 200, 200) and shape of mask is (8, 200, 200)
AFTER shape of img is (8, 200, 200) and shape of mask is (8, 200, 200)
AFTER shape of img is (8, 200, 200) and shape of mask is (8, 200, 200)
AFTER shape of img is (8, 200, 200) and shape of mask is (8, 200, 200)
AFTER shape of img is (8, 200, 200) and shape of mask is (8, 200, 200)
AFTER shape of img is (8, 200, 200) and shape of mask is (8, 200, 200)
AFTER shape of img is (8, 200, 200) and shape of mask is (8, 200, 200)
AFTER shape of img is (8, 200, 200) and shape of mask is (8, 200, 200)
AFTER shape of img is (8, 200, 200) and shape of mask is (8, 200, 200)
AFTER shape of img is (8, 200, 200) and shape of mask is (8, 200, 200)
AFTER shape of img is (8, 200, 200) and shape of mask is (8, 200, 200)
AFTER 

HBox(children=(FloatProgress(value=0.0, description='Validation', max=4.0, style=ProgressStyle(description_wid…

AFTER shape of img is (8, 200, 200) and shape of mask is (8, 200, 200)
AFTER shape of img is (8, 200, 200) and shape of mask is (8, 200, 200)
AFTER shape of img is (8, 200, 200) and shape of mask is (8, 200, 200)
AFTER shape of img is (8, 200, 200) and shape of mask is (8, 200, 200)
AFTER shape of img is (8, 200, 200) and shape of mask is (8, 200, 200)
AFTER shape of img is (8, 200, 200) and shape of mask is (8, 200, 200)
AFTER shape of img is (8, 200, 200) and shape of mask is (8, 200, 200)
AFTER shape of img is (8, 200, 200) and shape of mask is (8, 200, 200)
AFTER shape of img is (8, 200, 200) and shape of mask is (8, 200, 200)
AFTER shape of img is (8, 200, 200) and shape of mask is (8, 200, 200)
AFTER shape of img is (8, 200, 200) and shape of mask is (8, 200, 200)
AFTER shape of img is (8, 200, 200) and shape of mask is (8, 200, 200)
AFTER shape of img is (8, 200, 200) and shape of mask is (8, 200, 200)
AFTER shape of img is (8, 200, 200) and shape of mask is (8, 200, 200)
AFTER 

HBox(children=(FloatProgress(value=0.0, description='Training', max=4.0, style=ProgressStyle(description_width…

AFTER shape of img is (8, 200, 200) and shape of mask is (8, 200, 200)
AFTER shape of img is (8, 200, 200) and shape of mask is (8, 200, 200)
AFTER shape of img is (8, 200, 200) and shape of mask is (8, 200, 200)
AFTER shape of img is (8, 200, 200) and shape of mask is (8, 200, 200)
AFTER shape of img is (8, 200, 200) and shape of mask is (8, 200, 200)
AFTER shape of img is (8, 200, 200) and shape of mask is (8, 200, 200)
AFTER shape of img is (8, 200, 200) and shape of mask is (8, 200, 200)


KeyboardInterrupt: ignored

In [None]:
# Pickle the whole model object (architecture + weights) next to the dataset.
torch.save(model, f=dataDir + 'unet_3d.pt')

In [None]:
# Reload the whole pickled model saved above. map_location lets a checkpoint written
# on a GPU load on a CPU-only runtime. torch.load unpickles arbitrary objects, so only
# load files you created yourself. torch >= 2.6 defaults to weights_only=True, which
# rejects a pickled nn.Module; older torch (< 1.13) has no weights_only argument.
try:
    model = torch.load(dataDir + 'unet_3d.pt', map_location=device, weights_only=False)
except TypeError:
    model = torch.load(dataDir + 'unet_3d.pt', map_location=device)

In [None]:
# Epochs actually recorded vs. epochs requested.
(len(training_losses), len(validation_losses), num_epochs)

In [None]:
# Per-epoch validation Dice for classes 0, 1, 2 and 3.
(val_dices_0, val_dices_1, val_dices_2, val_dices_3)

# Loss Plots

In [None]:
# Size each x-axis from the recorded history rather than num_epochs. The histories
# can differ from num_epochs, e.g. when num_epochs is edited after training or a
# trainer is reused; validation_losses is empty without a validation loader.
fig, ax = plt.subplots()
ax.plot(range(1, len(training_losses) + 1), training_losses, label='Train')
ax.plot(range(1, len(validation_losses) + 1), validation_losses, label='Validation')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training and validation loss')
ax.legend()
plt.show()

# Show Results

In [None]:
from tqdm.notebook import tqdm, trange

# Run the trained model over the validation loader. Keep, per batch, the input, the
# ground truth and the argmax label map (still on `device`) for plotting, plus Dice per class.
model.eval()  # evaluation mode
batch_iter = tqdm(enumerate(val_loader), desc='Validation', total=len(val_loader), leave=False)

model_outputs, original_images, gts = [], [], []
dice_c0, dice_c1, dice_c2, dice_c3 = [], [], [], []

for i, (x, y) in batch_iter:
    input, target = x.to(device), y.to(device)  # send to device (GPU or CPU)
    original_images.append(input.detach().clone())
    gts.append(target.detach().clone())

    with torch.no_grad():
        out = model(input)
        print(out.shape)

    out_soft = torch.argmax(out, dim=1)  # predicted label per voxel
    model_outputs.append(out_soft)

    # dice_score works on its own detached copy, so `out` can be passed directly.
    dice_c0.append(dice_score(out, target, 0))
    dice_c1.append(dice_score(out, target, 1))
    dice_c2.append(dice_score(out, target, 2))
    dice_c3.append(dice_score(out, target, 3))

batch_iter.close()

In [None]:
# Keep the collected batches as lists of tensors, moved to the CPU.
# - NumPy conversion of CUDA tensors raises, since they must be on the CPU first.
# - On CPU tensors, np.asarray would return NumPy arrays, on which the `.cpu()` calls
#   in the plotting cells below do not exist.
# - A ragged last batch could not be stacked into one array.
# Moving to the CPU also frees GPU memory, and re-running this cell is harmless.
model_outputs = [t.cpu() for t in model_outputs]
original_images = [t.cpu() for t in original_images]
gts = [t.cpu() for t in gts]

In [None]:
# Shapes of the first validation batch: prediction, input, ground truth.
tuple(t.shape for t in (model_outputs[0], original_images[0], gts[0]))

In [None]:
# Input of batch 0, sample 2, channel 0, slice 6.
plt.imshow(original_images[0][2, 0, 6].cpu(), cmap='gray')

In [None]:
# Ground-truth labels of batch 0, sample 2, slice 6.
plt.imshow(gts[0][2, 6].cpu(), cmap='gray')

In [None]:
# Predicted (argmax) labels of batch 0, sample 2, slice 6.
plt.imshow(model_outputs[0][2, 6].cpu(), cmap='gray')

# Sanity Check: Inspect One Raw Ground-Truth Mask

In [None]:
# Metadata row (including the ED/ES frame indices) of the subject inspected below.
metadata.loc[metadata['External code'] == 'A5C2D2']

In [None]:
# Raw ground-truth label volume of validation subject A5C2D2 (4-D: frames on the last axis).
p = (dataDir
     + 'Validation/A5C2D2/A5C2D2_sa_gt.nii.gz')
a = nib.load(p).get_fdata()

In [None]:
np.shape(a)  # 4-D; the last axis indexes the frame

In [None]:
# Keep only frame 24 of the label volume (compare the ED/ES columns of the metadata row
# above). The ndim guard makes the cell safe to re-run: without it, a second run would
# index the already-3-D array and raise.
if a.ndim == 4:
    a = a[:, :, :, 24]

In [None]:
# diff_seq=0 keeps every slice. Each slice is resized to 200x200 (slice axis moved
# first), then the label values are rescaled to [0, 1] for display.
a = normalize_01(change_sizes(a, 0, True))
a.shape

In [None]:
# Slice 2 of the resized ground-truth mask.
plt.imshow(a[2, :, :], cmap='gray')