In [1]:
import torch
import torch.nn as nn
# from apex import amp
from torch.cuda import amp
import numpy as np
from tqdm import tqdm
import torch
from torch.nn import functional as F

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as transforms
from torch.autograd import Variable
from PIL import Image
from matplotlib.pyplot import imshow
from scipy import ndimage as ndi
from skimage.transform import resize
import pandas as pd
import os
import nibabel as nib


import matplotlib.pyplot as plt
%matplotlib inline

%gui qt

In [2]:
# from google.colab import drive
# drive.mount('/content/drive', force_remount=True)

In [3]:
# Show the GPU model, driver/CUDA version and current memory use of this machine.
# The torch device itself is selected in the following cell.
!nvidia-smi

Thu May 13 06:53:13 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 430.50       Driver Version: 430.50       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  GeForce GTX 108...  Off  | 00000000:00:05.0 Off |                  N/A |
|  0%   31C    P0    57W / 250W |      0MiB / 11178MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage    

In [4]:
# Pick the compute device: the default CUDA GPU when one is visible, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

# Load Dataset Metadata

In [5]:
# Root of the M&Ms open dataset (use the commented path on Colab with Google Drive mounted).
dataDir = 'downloads/OpenDataset/'
# dataDir = 'drive/My Drive/B.Sc proj/OpenDataset/'

# Per-subject table; SegDataset reads its 'External code', 'ES' and 'ED' columns.
METADATA_CSV = '201014_M&Ms_Dataset_Information_-_opendataset.csv'
metadata = pd.read_csv(dataDir + METADATA_CSV)
metadata

Unnamed: 0,External code,VendorName,Vendor,Centre,ED,ES
0,A0S9V9,Siemens,A,1,0,9
1,A1D0Q7,Philips,B,2,0,9
2,A1D9Z7,Siemens,A,1,22,11
3,A1E9Q1,Siemens,A,1,0,9
4,A1K2P5,Canon,D,5,33,11
...,...,...,...,...,...,...
340,T2Z1Z9,Canon,D,5,29,9
341,T9U9W2,Siemens,A,1,0,10
342,V4W8Z5,GE,C,4,19,9
343,W5Z4Z8,Philips,B,2,29,11


# Define Dataset

In [6]:
def normalize_01(inp: np.ndarray):
    """Min-max scale an array into [0, 1].

    Parameters
    ----------
    inp : np.ndarray
        Array of any shape (must not be empty).

    Returns
    -------
    np.ndarray
        Floating-point array of the same shape; the minimum maps to 0 and the
        maximum to 1. A constant array (zero value range, e.g. an all-zero
        volume) maps to all zeros instead of NaN.
    """
    shifted = inp - np.min(inp)
    value_range = np.ptp(inp)
    # dividing by a zero range would turn every element into NaN; the shifted array
    # is already all zeros in that case, so divide by 1 instead
    divisor = value_range if value_range != 0 else 1
    return shifted / divisor

def change_sizes(img, diff_seq, is_mask, out_size=(200, 200)):
    """Center-crop a volume along its slice axis and resize every remaining slice.

    `img` is indexed as img[:, :, slice]. `diff_seq` slices are removed in total:
    `diff_seq // 2` from the start and the rest from the end. Each remaining slice
    is resized to `out_size`, and the slices are stacked on a new leading axis,
    giving an array of shape (img.shape[2] - diff_seq, *out_size).

    Parameters
    ----------
    img : np.ndarray
        Volume with slices on axis 2.
    diff_seq : int
        Number of slices to drop (expected >= 0).
    is_mask : bool
        True for label maps. Labels are resized with nearest-neighbour
        interpolation and no anti-aliasing, so no intermediate (invalid) class
        values appear at label boundaries. Intensity images keep the default
        `resize` settings.
    out_size : tuple of int
        In-plane output size, (200, 200) by default.
    """
    stop = img.shape[2] - (diff_seq - diff_seq // 2)
    cropped = img[:, :, diff_seq // 2:stop]
    if is_mask:
        resize_kwargs = dict(order=0, preserve_range=True, anti_aliasing=False)
    else:
        resize_kwargs = {}
    return np.asarray([resize(cropped[:, :, i], out_size, **resize_kwargs)
                       for i in range(cropped.shape[2])])
    
def change_size_and_pad(img, diff_seq, is_mask=False, out_size=(200, 200)):
    """Resize every slice of a volume and zero-pad it along the slice axis.

    `img` is indexed as img[:, :, slice]. Each slice is resized to `out_size`;
    `diff_seq // 2` all-zero slices are put before them and the remaining
    `diff_seq - diff_seq // 2` after them. The result has shape
    (img.shape[2] + diff_seq, *out_size).

    Parameters
    ----------
    img : np.ndarray
        Volume with slices on axis 2.
    diff_seq : int
        Number of zero slices to add in total (expected >= 0).
    is_mask : bool
        True for label maps: nearest-neighbour resizing without anti-aliasing,
        so label values are preserved exactly. Defaults to False (intensity image).
    out_size : tuple of int
        In-plane output size, (200, 200) by default.
    """
    if is_mask:
        resize_kwargs = dict(order=0, preserve_range=True, anti_aliasing=False)
    else:
        resize_kwargs = {}
    pad = np.zeros(out_size)
    before = diff_seq // 2
    after = diff_seq - before
    slices = [pad] * before
    slices += [resize(img[:, :, i], out_size, **resize_kwargs) for i in range(img.shape[2])]
    slices += [pad] * after
    return np.asarray(slices)


def convert_two_class(x, inp_classes, target_class):
    """Relabel, in place, every element of `x` equal to one of `inp_classes` as `target_class`.

    Returns the same (mutated) object. Not called by SegDataset at the moment.
    """
    for c in inp_classes:
        x[x == c] = target_class
    return x


# Ground-truth mask of one validation subject; SegDataset.get_class_weights falls back to it.
p = dataDir + 'Validation/A5C2D2/A5C2D2_sa_gt.nii.gz'
a = nib.load(p).get_fdata()


class SegDataset(Dataset):
    """M&Ms dataset yielding one cardiac phase (ES or ED) of one subject per item.

    Every subject directory contributes two consecutive items: an even index
    selects the 'ES' frame and the following odd index the 'ED' frame, with frame
    numbers taken from `metadata`. The selected 3-D frame is cropped or
    zero-padded along the slice axis to `num_seq` slices, and every slice is
    resized to 200x200.

    Parameters
    ----------
    all_dirs : list of str
        Subject directory names inside `addr`.
    metadata : pandas.DataFrame
        Must contain the columns 'External code', 'ES' and 'ED'.
    addr : str
        Parent directory of the subject folders, including the trailing separator
        (subject paths are built as `addr + dir_name`).

    Items are (image, mask): a float32 tensor of shape (1, num_seq, 200, 200)
    normalised to [0, 1], and an int64 label map of shape (num_seq, 200, 200).
    """

    def __init__(self, all_dirs, metadata, addr):
        self.all_dirs = all_dirs
        self.metadata = metadata
        self.addr = addr
        self.num_seq = 8  # slices per returned volume

    def get_class_weights(self, mask=None):
        """Return [fraction of label-1 voxels, fraction of label-0/2/3 voxels] in `mask`.

        NOTE(review): when `mask` is None the module-level reference mask `a`
        (one validation subject) is used, not this dataset's own labels.
        """
        if mask is None:
            mask = a
        eps = 1e-6
        ones = np.sum(mask == 1)
        zeros = np.sum(mask == 0) + np.sum(mask == 2) + np.sum(mask == 3)
        return [ones / (ones + zeros + eps), zeros / (ones + zeros + eps)]

    def __len__(self):
        # two items (ES and ED frames) per subject
        return 2 * len(self.all_dirs)

    def __getitem__(self, idx):
        subject_dir = self.addr + self.all_dirs[idx // 2]

        # Lazily load the 4-D image and its ground truth (the file ending in '_gt.nii.gz').
        img = mask = subject_id = None
        for root, _dirs, files in os.walk(subject_dir):
            for fname in files:
                if not fname.endswith('.nii.gz'):
                    continue  # ignore stray files (e.g. .DS_Store) that nibabel cannot read
                volume = nib.load(os.path.join(root, fname)).get_fdata()
                if fname.endswith('_gt.nii.gz'):
                    subject_id = fname.split('_')[0]
                    mask = volume
                else:
                    img = volume
        if img is None or mask is None:
            raise FileNotFoundError(f'expected an image and a *_gt.nii.gz mask in {subject_dir}')

        # even index -> end-systolic frame, odd index -> end-diastolic frame
        phase = 'ES' if idx % 2 == 0 else 'ED'
        rows = self.metadata[self.metadata['External code'] == subject_id]
        if rows.empty:
            raise KeyError(f'subject {subject_id!r} not found in metadata')
        t = rows[phase].iloc[0]

        img = img[:, :, :, t]
        gt = mask[:, :, :, t]

        diff_seq = img.shape[2] - self.num_seq
        if diff_seq > 0:
            img = change_sizes(img, diff_seq, False)
            gt = change_sizes(gt, diff_seq, True)
        else:
            img = change_size_and_pad(img, self.num_seq - img.shape[2])
            # nearest-neighbour resizing keeps the label values exact
            gt = change_size_and_pad(gt, self.num_seq - gt.shape[2], is_mask=True)

        img = normalize_01(img)

        # add the channel axis to the image; labels stay a plain (num_seq, H, W) map
        img = torch.from_numpy(img).float().unsqueeze(0)
        mask = torch.from_numpy(gt).long()
        return img, mask

# Build Datasets and DataLoaders

In [7]:
# Training hyper-parameters
batch_size = 4        # volumes per optimisation step
num_epochs = 50
learning_rate = 1e-3  # Adam step size

# Label value reserved for "unlabeled" voxels (currently referenced only in loss-helper comments).
UNLABELED = 4

In [8]:
def _list_subject_dirs(parent):
    """Return the subject sub-directories of `parent`, sorted by name.

    os.listdir order is filesystem-dependent; sorting makes the dataset order and
    the validation subset below reproducible. Plain files (e.g. .DS_Store) are
    skipped because SegDataset expects one folder per subject.
    """
    return sorted(d for d in os.listdir(parent) if os.path.isdir(os.path.join(parent, d)))


# SegDataset builds subject paths as `addr + name`, so keep a trailing separator.
train_dir = os.path.join(dataDir, 'Training', 'Labeled', '')
train_all = _list_subject_dirs(train_dir)
train_dataset = SegDataset(train_all, metadata, train_dir)

N_VAL_SUBJECTS = 8  # size of the validation subset
val_dir = os.path.join(dataDir, 'Validation', '')
val_all = _list_subject_dirs(val_dir)[:N_VAL_SUBJECTS]
val_dataset = SegDataset(val_all, metadata, val_dir)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)
# Validation is not shuffled: the metrics do not need it, and the result cells index
# batches/samples by position, so a fixed order keeps those plots reproducible.
val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                         batch_size=batch_size,
                                         shuffle=False)

# Define Metrics and Losses


In [9]:
def dice_score(inputs, targets, class_num, dims=[1,2,3], smooth=0.5):
    """Hard Dice score of a single class, averaged over the batch.

    Parameters
    ----------
    inputs : torch.Tensor
        Raw network outputs with the class scores on dim 1. They are turned
        into a label map with argmax over that dim.
    targets : torch.Tensor
        Label map with the same shape as the argmax of `inputs`.
    class_num : int
        Label whose one-vs-rest Dice is computed.
    dims : list of int
        Dimensions summed for each sample; the remaining batch dim is averaged.
    smooth : float
        Added to numerator and denominator, so a class absent from both the
        prediction and the target scores 1.

    Returns
    -------
    float
        Batch mean of (2 * |P & T| + smooth) / (|P| + |T| + smooth).

    Neither input tensor is modified.
    """
    # binary membership masks; the prediction stays int64 (argmax dtype), the target keeps its dtype
    hits_pred = (torch.argmax(inputs, dim=1) == class_num).long()
    hits_true = (targets == class_num).to(targets.dtype)

    overlap = (2 * (hits_pred * hits_true)).sum(dim=dims)
    size_sum = (hits_pred + hits_true).sum(dim=dims)

    per_sample = (overlap + smooth) / (size_sum + smooth)
    return per_sample.mean().item()

In [10]:
import torch
from torch.nn import functional as F

smooth = 1e-2  # additive smoothing used by dice_coef


def dice_coef(outputs, target):
    """Sum over the batch of per-sample hard Dice coefficients (binary segmentation).

    `outputs` are raw logits; they become {0, 1} predictions through sigmoid
    followed by rounding. `target` holds 0/1 values with the same shape, e.g.
    (batch_size, 1, image_slices, width, height). Each sample contributes
    (2 * sum(P * T) + smooth) / (sum(P + T) + smooth). The per-sample values
    are added (not averaged) and returned as a 0-dim tensor.
    """
    if target.shape[0] > 1:
        # evaluate one sample at a time and add the scores up
        return sum(dice_coef(outputs[k:k + 1], target[k:k + 1]) for k in range(len(target)))

    hard_pred = torch.round(torch.sigmoid(outputs))
    n = target.shape[0]

    overlap = (target * hard_pred).view(n, -1).sum(-1).float()
    combined = (target + hard_pred).view(n, -1).sum(-1).float()

    coef = (2. * overlap + smooth) / (combined + smooth)
    return coef.sum(0)


def soft_binray_cross_entropy(outputs, target, class_weights, voxel_weights=None):
    """Class-weighted binary cross-entropy on logits, summed over the batch.

    For each sample, the numerically stable per-voxel BCE
    ``max(x, 0) - x * z + log(1 + exp(-|x|))`` (x = logit, z = target) is
    averaged with per-voxel weights. A voxel gets ``class_weights[0]`` where
    ``target == 0``, ``class_weights[1]`` where ``target == 1`` and 0 for any other
    value (e.g. UNLABELED). The weights are optionally multiplied by
    ``voxel_weights``. The per-sample weighted means are then added up.

    Parameters
    ----------
    outputs : torch.Tensor
        Logits with the same shape as `target`,
        e.g. (batch_size, 1, image_slices, width, height).
    target : torch.Tensor
        Per-voxel labels.
    class_weights : torch.Tensor or sequence of length 2
        Weights of class 0 and class 1; must not be None.
    voxel_weights : torch.Tensor, optional
        Extra per-voxel weights with the shape of `target`.

    Returns
    -------
    torch.Tensor
        0-dim loss. A sample whose total weight is zero (no voxel labelled 0
        or 1) contributes 0 instead of NaN.
    """
    batch_size = target.shape[0]

    if batch_size > 1:
        # the loss of a batch is the sum of the per-sample losses
        result = soft_binray_cross_entropy(
            outputs[0:1],
            target[0:1],
            class_weights,
            None if voxel_weights is None else voxel_weights[0:1]
        )
        for i in range(1, len(target)):
            result = result + soft_binray_cross_entropy(
                outputs[i:i + 1],
                target[i:i + 1],
                class_weights,
                None if voxel_weights is None else voxel_weights[i:i + 1]
            )
        return result

    # stable BCE-with-logits; log1p is more accurate than log(1 + .) when exp(-|x|) is tiny
    losses = (torch.clamp(outputs, min=0) - outputs * target
              + torch.log1p(torch.exp(-torch.abs(outputs))))

    # voxels whose label is neither 0 nor 1 keep weight 0 and do not contribute
    weights = torch.zeros_like(losses, dtype=torch.float32, device=target.device)
    weights[target == 0] = class_weights[0]
    weights[target == 1] = class_weights[1]

    if voxel_weights is not None:
        weights = weights * voxel_weights

    weighted_sum = (losses * weights).sum()
    total_weight = weights.sum()
    if total_weight == 0:
        # nothing labelled in this sample: 0/0 would be NaN and poison the gradients.
        # weighted_sum is 0 here and keeps the autograd graph intact.
        return weighted_sum
    return weighted_sum / total_weight

# Define UNet


In [11]:
import torch
import torch.nn as nn
# from apex import amp
from torch.cuda import amp
import numpy as np
from tqdm import tqdm

class UNet3D(nn.Module):
    """3-D U-Net with `depth` encoder stages, a bridge and `depth` decoder stages.

    Every stage is two (Conv3d k=3 -> BatchNorm3d -> ReLU) layers. Encoder stage i
    outputs ``initial_filters * 2**i`` channels and is followed by a 2x max-pool.
    The bridge outputs ``initial_filters * 2**depth`` channels. Each decoder stage
    upsamples with a stride-2 ConvTranspose3d, concatenates the matching encoder
    output on the channel axis and convolves. Dropout is applied to the bridge
    output and to the last decoder output. A final 1x1x1 convolution produces
    `output_dim` raw scores per voxel (no activation; that is left to the loss).

    Parameters
    ----------
    input_dim : int
        Number of input channels (1 in this notebook).
    output_dim : int
        Number of output channels, one per label.
    initial_filters : int
        Number of filters (output channels) of the first stage, usually 16 or 32.
        This is a channel count, not a kernel size.
    depth : int
        Number of down/up-sampling stages. Every spatial input size must be
        divisible by ``2 ** depth``.
    dropout : float
        Dropout probability.
    """

    def __init__(self, input_dim, output_dim, initial_filters=16, depth=4, dropout=0):
        super(UNet3D, self).__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.initial_filters = initial_filters
        self.depth = depth
        self.dropout = dropout
        # one ReLU instance (no parameters) is shared by every conv block
        activation = nn.ReLU(inplace=True)

        # NOTE: the registration order of these containers fixes the order of
        # self.modules() (hence of the random initialisation) and of state_dict keys.
        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.pools = nn.ModuleList()
        self.trans = nn.ModuleList()
        self.dropouts = nn.ModuleList()

        # Encoder: stage i maps (input_dim or f*2**(i-1)) -> f*2**i channels, then pools.
        for i in range(depth):
            in_dim = self.input_dim if i == 0 else self.initial_filters * 2 ** (i - 1)
            out_dim = self.initial_filters * 2 ** i
            self.downs.append(self.double_conv_block(in_dim, out_dim, activation))
            self.pools.append(self.max_pool())

        # Bridge: f*2**(depth-1) -> f*2**depth channels at the coarsest resolution.
        self.bridge = self.double_conv_block(
            self.initial_filters * 2 ** (depth - 1),
            self.initial_filters * 2 ** depth,
            activation
        )

        # Dropout on the bridge output
        self.dropouts.append(self.dropout_layer())

        # Decoder: stage i upsamples f*2**(depth-i) channels (same channel count out),
        # concatenates the f*2**(depth-i-1)-channel skip and reduces to f*2**(depth-i-1).
        for i in range(depth):
            trans_in_out_dim = self.initial_filters * 2 ** (depth - i)
            self.trans.append(self.conv_transpose_block(trans_in_out_dim, trans_in_out_dim, activation))

            up_in_dim = self.initial_filters * (2 ** (depth - i) + 2 ** (depth - i - 1))
            up_out_dim = self.initial_filters * 2 ** (depth - i - 1)
            self.ups.append(self.double_conv_block(up_in_dim, up_out_dim, activation))

        # Dropout on the last decoder output
        self.dropouts.append(self.dropout_layer())

        # Output: 1x1x1 conv to one raw score per label
        self.out = self.prediction_mask(initial_filters, self.output_dim)

        self.initialize_parameters()

    @staticmethod
    def weight_init(module, method, **kwargs):
        """Apply `method` to the weight of 2-D/3-D (transposed) conv layers; skip other modules."""
        if isinstance(module, (nn.Conv3d, nn.Conv2d, nn.ConvTranspose3d, nn.ConvTranspose2d)):
            method(module.weight, **kwargs)

    @staticmethod
    def bias_init(module, method, **kwargs):
        """Apply `method` to the bias of 2-D/3-D (transposed) conv layers that have one."""
        conv_types = (nn.Conv3d, nn.Conv2d, nn.ConvTranspose3d, nn.ConvTranspose2d)
        # a conv built with bias=False has module.bias = None, which the nn.init functions fail on
        if isinstance(module, conv_types) and module.bias is not None:
            method(module.bias, **kwargs)

    def initialize_parameters(self,
                              method_weights=nn.init.xavier_uniform_,
                              method_bias=nn.init.zeros_,
                              kwargs_weights=None,
                              kwargs_bias=None
                              ):
        """Re-initialise all conv weights (Xavier-uniform) and biases (zeros) by default.

        `kwargs_weights` / `kwargs_bias` are forwarded to the init functions; None
        means no extra arguments (avoids shared mutable default dicts).
        """
        kwargs_weights = {} if kwargs_weights is None else kwargs_weights
        kwargs_bias = {} if kwargs_bias is None else kwargs_bias
        for module in self.modules():
            self.weight_init(module, method_weights, **kwargs_weights)
            self.bias_init(module, method_bias, **kwargs_bias)

    def single_conv_block(self, input_dim, output_dim, activation):
        """Conv3d(k=3, padding=1, keeps size) -> BatchNorm3d -> activation."""
        return nn.Sequential(
            nn.Conv3d(input_dim, output_dim, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm3d(output_dim),
            activation
        )

    def double_conv_block(self, input_dim, output_dim, activation):
        """Two single conv blocks: input_dim -> output_dim -> output_dim."""
        return nn.Sequential(
            self.single_conv_block(input_dim, output_dim, activation),
            self.single_conv_block(output_dim, output_dim, activation)
        )

    def conv_transpose_block(self, input_dim, output_dim, activation):
        """Stride-2 ConvTranspose3d that exactly doubles D, H and W, then BatchNorm3d and activation."""
        return nn.Sequential(
            nn.ConvTranspose3d(input_dim, output_dim, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm3d(output_dim),
            activation,
        )

    def prediction_mask(self, input_dim, output_dim):
        """1x1x1 conv producing raw per-label scores; the activation is applied in the loss."""
        return nn.Conv3d(input_dim, output_dim, kernel_size=1, stride=1, padding=0)

    def max_pool(self):
        """2x max-pooling in every spatial dimension."""
        return nn.MaxPool3d(kernel_size=2, stride=2, padding=0)

    def dropout_layer(self):
        """Dropout with the probability given to the constructor."""
        return nn.Dropout(p=self.dropout)

    def forward(self, x):
        """Map (batch, input_dim, D, H, W) to (batch, output_dim, D, H, W) raw scores.

        Raises
        ------
        ValueError
            If D, H or W is not divisible by ``2 ** depth`` (the skip connections
            could not be concatenated otherwise).
        """
        factor = 2 ** self.depth
        if any(size % factor != 0 for size in x.shape[2:]):
            raise ValueError(f'spatial input sizes {tuple(x.shape[2:])} must all be '
                             f'divisible by 2 ** depth = {factor}')

        # Encoder: remember every stage output for the skip connections
        skips = []
        h = x
        for down, pool in zip(self.downs, self.pools):
            h = down(h)
            skips.append(h)
            h = pool(h)

        # Bridge + dropout
        h = self.dropouts[0](self.bridge(h))

        # Decoder: the deepest skip is consumed first
        for trans, up, skip in zip(self.trans, self.ups, reversed(skips)):
            h = up(torch.cat([trans(h), skip], dim=1))

        # Dropout + output
        h = self.dropouts[1](h)
        return self.out(h)


# Define Trainer

In [12]:
class Trainer:
    """Epoch loop for the segmentation model.

    Each epoch trains on `training_DataLoader`, evaluates on
    `validation_DataLoader` when one is given, steps the LR scheduler, prints
    the per-label Dice scores (labels 0-3) and losses, and appends them as one
    row to `self.log` (a pandas DataFrame). Every 10th epoch the whole model is
    saved with ``torch.save`` to ``dataDir + 'unet_3d_4class.pt'`` and the log is
    written to 'log_4class.csv'; the log is written again when training ends.

    The loss is ``self.criterion(model_output, target)``. The Dice scores come
    from the module-level `dice_score(output, target, class_num)` helper.
    """

    NUM_CLASSES = 4  # labels whose Dice scores are tracked

    def __init__(self,
                 model: torch.nn.Module,
                 device: torch.device,
                 criterion: torch.nn.Module,
                 optimizer: torch.optim.Optimizer,
                 training_DataLoader: torch.utils.data.DataLoader,
                 validation_DataLoader: torch.utils.data.DataLoader = None,
                 lr_scheduler=None,
                 epochs: int = 100,
                 epoch: int = 0,
                 notebook: bool = False
                 ):
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.training_DataLoader = training_DataLoader
        self.validation_DataLoader = validation_DataLoader
        self.device = device
        self.epochs = epochs
        self.epoch = epoch
        self.notebook = notebook

        # per-epoch histories
        self.training_loss = []
        self.validation_loss = []
        self.learning_rate = []

        self.validation_dices_c0 = []
        self.validation_dices_c1 = []
        self.validation_dices_c2 = []
        self.validation_dices_c3 = []

        self.train_dices_c0 = []
        self.train_dices_c1 = []
        self.train_dices_c2 = []
        self.train_dices_c3 = []

        self._log_columns = ['epoch', 't_dice_0', 't_dice_1', 't_dice_2', 't_dice_3',
                             'v_dice_0', 'v_dice_1', 'v_dice_2', 'v_dice_3', 't_loss', 'v_loss']
        self._log_rows = []
        self.log = pd.DataFrame(columns=self._log_columns)

        # moved to `device` instead of a hard-coded cuda:0 so the trainer also runs on CPU.
        # NOTE(review): not used by any method of this class at the moment.
        self.class_weights = torch.Tensor(training_DataLoader.dataset.get_class_weights()).to(device)

    def _progress_classes(self):
        """Return the (tqdm, trange) pair matching the `notebook` flag."""
        if self.notebook:
            from tqdm.notebook import tqdm, trange
        else:
            from tqdm import tqdm, trange
        return tqdm, trange

    def _train_dice_histories(self):
        """Per-label training Dice histories, indexed by label."""
        return [self.train_dices_c0, self.train_dices_c1, self.train_dices_c2, self.train_dices_c3]

    def _validation_dice_histories(self):
        """Per-label validation Dice histories, indexed by label."""
        return [self.validation_dices_c0, self.validation_dices_c1,
                self.validation_dices_c2, self.validation_dices_c3]

    def run_trainer(self):
        """Train for `self.epochs` epochs.

        Returns
        -------
        tuple
            (training losses, validation losses, validation Dice of label 0),
            one entry per epoch run so far.
        """
        _, trange = self._progress_classes()

        progressbar = trange(self.epochs, desc='Progress')
        for i in progressbar:
            self.epoch += 1  # epoch counter

            self._train()

            if self.validation_DataLoader is not None:
                self._validate()

            if self.lr_scheduler is not None:
                if (self.validation_DataLoader is not None
                        and self.lr_scheduler.__class__.__name__ == 'ReduceLROnPlateau'):
                    # ReduceLROnPlateau needs the monitored metric
                    self.lr_scheduler.step(self.validation_loss[-1])
                else:
                    self.lr_scheduler.step()

            row = self._epoch_record()
            self._print_epoch(row)
            self._append_log(row)

            if i % 10 == 0 and i > 0:
                torch.save(self.model, dataDir + 'unet_3d_4class.pt')
                self.log.to_csv('log_4class.csv')

        self.log.to_csv('log_4class.csv')
        return self.training_loss, self.validation_loss, self.validation_dices_c0

    def _epoch_record(self):
        """Build the log row of the epoch that just finished (NaN for validation when there is none)."""
        has_val = self.validation_DataLoader is not None
        row = {'epoch': self.epoch}
        for c, history in enumerate(self._train_dice_histories()):
            row[f't_dice_{c}'] = history[-1]
        for c, history in enumerate(self._validation_dice_histories()):
            row[f'v_dice_{c}'] = history[-1] if has_val else np.nan
        row['t_loss'] = self.training_loss[-1]
        # the original logged the whole validation-loss list here instead of this epoch's value
        row['v_loss'] = self.validation_loss[-1] if has_val else np.nan
        return row

    def _print_epoch(self, row):
        """Print the metrics of one epoch."""
        print(f"epoch number {self.epoch} finished with following results:")
        print('\ttrain dice c0', row['t_dice_0'], '\ttrain dice c1', row['t_dice_1'])
        print('\ttrain dice c2', row['t_dice_2'], '\ttrain dice c3', row['t_dice_3'])
        if self.validation_DataLoader is not None:
            print('\tvalidation dice c0', row['v_dice_0'], '\tvalidation dice c1', row['v_dice_1'])
            print('\tvalidation dice c2', row['v_dice_2'], '\tvalidation dice c3', row['v_dice_3'])
        print('\ttrain loss ', row['t_loss'], '\tvalidation loss ', row['v_loss'])

    def _append_log(self, row):
        """Add one row to `self.log`."""
        # DataFrame.append returns a new frame (and is removed in pandas 2.0), so the
        # original call never grew the log; rebuild the frame from all rows instead.
        self._log_rows.append(row)
        self.log = pd.DataFrame(self._log_rows, columns=self._log_columns)

    def _train(self):
        """Run one optimisation pass over the training loader and record mean loss / Dice."""
        tqdm, _ = self._progress_classes()

        self.model.train()  # train mode
        batch_losses = []
        batch_dices = [[] for _ in range(self.NUM_CLASSES)]

        batch_iter = tqdm(enumerate(self.training_DataLoader), 'Training',
                          total=len(self.training_DataLoader), leave=False)

        for _, (x, y) in batch_iter:
            inputs, target = x.to(self.device), y.to(self.device)
            self.optimizer.zero_grad()
            out = self.model(inputs)
            loss = self.criterion(out, target)
            loss_value = loss.item()
            batch_losses.append(loss_value)
            loss.backward()
            self.optimizer.step()

            for class_num, scores in enumerate(batch_dices):
                scores.append(dice_score(out, target, class_num))

            batch_iter.set_description(f'Training: (loss {loss_value:.4f})')

        self.training_loss.append(np.mean(batch_losses))
        for history, scores in zip(self._train_dice_histories(), batch_dices):
            history.append(np.mean(scores))
        self.learning_rate.append(self.optimizer.param_groups[0]['lr'])

        batch_iter.close()

    def _validate(self):
        """Run one gradient-free pass over the validation loader and record mean loss / Dice."""
        tqdm, _ = self._progress_classes()

        self.model.eval()  # evaluation mode
        batch_losses = []
        batch_dices = [[] for _ in range(self.NUM_CLASSES)]

        batch_iter = tqdm(enumerate(self.validation_DataLoader), 'Validation',
                          total=len(self.validation_DataLoader), leave=False)

        for _, (x, y) in batch_iter:
            inputs, target = x.to(self.device), y.to(self.device)

            with torch.no_grad():
                out = self.model(inputs)
                loss_value = self.criterion(out, target).item()
            batch_losses.append(loss_value)

            for class_num, scores in enumerate(batch_dices):
                scores.append(dice_score(out, target, class_num))

            batch_iter.set_description(f'Validation: (loss {loss_value:.4f})')

        self.validation_loss.append(np.mean(batch_losses))
        for history, scores in zip(self._validation_dice_histories(), batch_dices):
            history.append(np.mean(scores))

        batch_iter.close()

# Train

In [13]:
# Seed torch's RNG so the weight initialisation and later random draws (dropout masks,
# training-set shuffling) repeat across runs (up to CUDA nondeterminism).
torch.manual_seed(0)

model = UNet3D(input_dim=1,
               output_dim=4,   # one output channel per label 0-3
               depth=3,        # SegDataset yields 8 slices, divisible by 2**3
               dropout=0.2).to(device)

# model = torch.load(dataDir + 'unet_3d.pt')

In [None]:
# Multi-class cross-entropy on raw logits (the model has no final activation).
# Alternatives tried before: torch.nn.BCEWithLogitsLoss(), CostumLoss().
criterion = torch.nn.CrossEntropyLoss().to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

trainer = Trainer(
    model=model,
    device=device,
    criterion=criterion,
    optimizer=optimizer,
    training_DataLoader=train_loader,
    validation_DataLoader=val_loader,
    lr_scheduler=None,   # constant learning rate
    epochs=num_epochs,
    epoch=0,
    notebook=True,       # tqdm notebook widgets
)

# Train; returns per-epoch training/validation losses and the validation Dice of label 0.
training_losses, validation_losses, val_dices_0 = trainer.run_trainer()

HBox(children=(FloatProgress(value=0.0, description='Progress', max=50.0, style=ProgressStyle(description_widt…

HBox(children=(FloatProgress(value=0.0, description='Training', max=75.0, style=ProgressStyle(description_widt…

In [None]:
# Pickle the whole nn.Module (class reference + weights), not only a state_dict.
torch.save(obj=model, f=dataDir + 'unet_3d_4class.pt')

In [None]:
# The checkpoint is a pickled nn.Module: torch.load unpickles arbitrary objects, so only
# load files from a trusted source. map_location lets a checkpoint written during a CUDA
# run be loaded on a CPU-only machine.
# NOTE(review): PyTorch >= 2.6 defaults to weights_only=True; pass weights_only=False there.
model = torch.load(dataDir + 'unet_3d_4class.pt', map_location=device)

# Loss Plots

In [None]:
# One point per recorded epoch, numbered from 1 like Trainer.epoch; derived from the data
# so the plot stays correct if num_epochs is edited after training.
x_axis = range(1, len(training_losses) + 1)
fig, ax = plt.subplots()
ax.plot(x_axis, training_losses, label='Train')
ax.plot(x_axis, validation_losses, label='Validation')
ax.set(title='Loss per epoch', xlabel='Epoch', ylabel='Loss')
ax.legend()
plt.show()

# Show Results

In [None]:
from tqdm.notebook import tqdm, trange

# Run the trained model over the validation set and keep the inputs, ground truth and
# predicted label maps for the plots below.
model.eval()  # evaluation mode
batch_iter = tqdm(enumerate(val_loader), 'Validation', total=len(val_loader),
                  leave=False)

model_outputs = []    # predicted label maps, one (batch, 1, slices, H, W) tensor per batch
original_images = []  # network inputs, one tensor per batch
gts = []              # ground-truth label maps, one tensor per batch
dice_c0 = []
dice_c1 = []
dice_c2 = []
dice_c3 = []

for _, (x, y) in batch_iter:
    images, target = x.to(device), y.to(device)  # send to device (GPU or CPU)

    original_images.append(images.detach().clone())
    gts.append(target.detach().clone())

    with torch.no_grad():
        logits = model(images)

    # The model outputs one score per label, so the prediction is the argmax over dim 1.
    # The previous sigmoid + round produced independent per-channel 0/1 maps, which is not
    # a multi-class prediction and made dice_score's argmax pick the lowest tied channel.
    # keepdim keeps a channel axis, so indexing [batch][sample][0][slice] still works.
    model_outputs.append(torch.argmax(logits, dim=1, keepdim=True))

    # dice_score takes raw class scores on dim 1 and applies the argmax itself
    for class_num, scores in enumerate((dice_c0, dice_c1, dice_c2, dice_c3)):
        scores.append(dice_score(logits, target, class_num))

batch_iter.close()

In [None]:
# Keep the per-batch results as lists of CPU tensors. np.asarray would either fail
# (CUDA tensors cannot be converted to NumPy) or yield a plain ndarray, which lacks the
# `.cpu()` used by the plotting cells; list indexing and `.shape` behave the same.
model_outputs = [batch.cpu() for batch in model_outputs]
original_images = [batch.cpu() for batch in original_images]
gts = [batch.cpu() for batch in gts]

In [None]:
# shapes of the first batch: prediction, input image, ground truth
tuple(batch_list[0].shape for batch_list in (model_outputs, original_images, gts))

In [None]:
# input image: batch 0, sample 2, channel 0, slice 6
sample_image = original_images[0][2][0][6].cpu()
plt.imshow(sample_image, cmap='gray')

In [None]:
# Ground truth for batch 0, sample 2, slice 6 (the sample/slice shown in the image cell).
# Label maps have no channel axis, so the former extra [0] selected a 1-D row that imshow rejects.
plt.imshow(gts[0][2][6].cpu(), cmap='gray')

In [None]:
# Prediction for batch 0, sample 2, slice 6 — the same sample as the image and ground-truth
# cells (it previously showed sample 1). [0] selects the first channel of the stored prediction.
plt.imshow(model_outputs[0][2][0][6].cpu(), cmap='gray')

In [None]:
# per-batch validation Dice scores of label 1
dice_c1