In [1]:
import torch
import torch.nn as nn
# from apex import amp
from torch.cuda import amp
import numpy as np
from tqdm import tqdm
import torch
from torch.nn import functional as F

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as transforms
from torch.autograd import Variable
from PIL import Image
from matplotlib.pyplot import imshow
from scipy import ndimage as ndi
from skimage.transform import resize
import pandas as pd
import os
import nibabel as nib


import matplotlib.pyplot as plt
%matplotlib inline

%gui qt

In [2]:
# from google.colab import drive
# drive.mount('/content/drive', force_remount=True)

In [3]:
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
!nvidia-smi

Fri May 28 05:45:42 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 430.50       Driver Version: 430.50       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  GeForce GTX 108...  Off  | 00000000:00:05.0 Off |                  N/A |
|  0%   19C    P0    53W / 250W |      0MiB / 11178MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage    

In [4]:
# device
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device= torch.device('cpu')
device

device(type='cuda')

# Import Data

In [5]:
dataDir = 'downloads/OpenDataset/'
# dataDir = 'drive/My Drive/B.Sc proj/OpenDataset/'
metadata = pd.read_csv(dataDir + '201014_M&Ms_Dataset_Information_-_opendataset.csv')
metadata

Unnamed: 0,External code,VendorName,Vendor,Centre,ED,ES
0,A0S9V9,Siemens,A,1,0,9
1,A1D0Q7,Philips,B,2,0,9
2,A1D9Z7,Siemens,A,1,22,11
3,A1E9Q1,Siemens,A,1,0,9
4,A1K2P5,Canon,D,5,33,11
...,...,...,...,...,...,...
340,T2Z1Z9,Canon,D,5,29,9
341,T9U9W2,Siemens,A,1,0,10
342,V4W8Z5,GE,C,4,19,9
343,W5Z4Z8,Philips,B,2,29,11


In [6]:
# train_folder = 'train/train'

# for root, dir, files in os.walk(train_folder):
#             for f in files:
#                 a = pd.read_csv(train_folder)
# #                 

# Define Dataset

In [7]:
def normalize_01(inp: np.ndarray):
    inp_out = (inp - np.min(inp)) / np.ptp(inp)
    return inp_out

def change_sizes(img, diff_seq, is_mask):
    img = img[:,:, diff_seq//2:img.shape[2]-(diff_seq-diff_seq//2)]
    res = []
    for i in range(img.shape[2]):
        x = resize(img[:,:, i], (224,224))
        res.append(x)
    return np.asarray(res)
    
def change_size_and_pad(img, diff_seq):
    res = []
    temp = np.zeros((224, 224))
    for i in range(diff_seq//2):
        res.append(temp)
    for i in range(img.shape[2]):
        x = resize(img[:,:, i], (224,224))
        res.append(x)
    for i in range(diff_seq - diff_seq//2):
        res.append(temp)
    return np.asarray(res)


def convert_two_class(x, inp_classes, target_class):
    for c in inp_classes:
        x[x == c] = target_class
    return x


p = dataDir + 'Validation/A5C2D2/A5C2D2_sa_gt.nii.gz'
a = nib.load(p).get_fdata()


class SegDataset(Dataset):
    def __init__(self, all_dirs, metadata, addr):
        self.all_dirs = all_dirs 
        self.metadata = metadata
        self.addr = addr
        self.num_seq = 16


    def __len__(self):
        return 2*len(self.all_dirs)
    
    def __getitem__(self, idx):
        item = self.all_dirs[idx//2]
        path = self.addr + item

        mask, img, id = None, None, None

        ### lazy load data
        for root, dir, files in os.walk(path):
            for f in files:
                img_arr = nib.load(path + '/' + f).get_fdata()
#                 print(f"file name is {f}")
                if f[-8] == 't':
                    id = f.split('_')[0]
                    mask = img_arr
                else:
                    img = img_arr
        
        # preparing data 
#         print(f"df: {metadata[metadata['External code'] == id]['ES']}")
        if idx %2 == 0:
            t = metadata[metadata['External code'] == id]['ES'].iloc[0]
        else:
            t = metadata[metadata['External code'] == id]['ED'].iloc[0]
        
        img = img[:, :, :, t]
        gt = mask[:, :, :, t]


#         print('maximum number of mask', np.max(gt) , '\tmax of img', np.max(img))

        diff_seq = img.shape[2] - self.num_seq
        if diff_seq > 0: 
            img = change_sizes(img, diff_seq, False)
            gt = change_sizes(gt, diff_seq, True)
        else:
#             print(f'BEFORE shape of img is {img.shape} and shape of mask is {gt.shape}')
            img = change_size_and_pad(img, self.num_seq-img.shape[2])
            gt = change_size_and_pad(gt, self.num_seq-gt.shape[2])
        
        # print(f'AFTER shape of img is {img.shape} and shape of mask is {gt.shape}')

        img = normalize_01(img)
                

        # convert to tensors
        img = torch.FloatTensor([img])
#         mask = torch.LongTensor([gt])
        mask = torch.LongTensor(gt)

        # convert to 2 class format
#         mask = convert_two_class(mask, [1,3], 0)
#         mask[mask == 2] = 1
#         mask = mask.float()# we only focus on class 1 
        # print('maxxxx minnnnn', torch.max(mask), torch.min(mask))

        return img, mask

# Load Dataset, DataLoader

In [8]:
batch_size = 4
num_epochs = 80
learning_rate = 0.0001

UNLABELED = 4 

In [9]:
train_dir = dataDir + '/Training/Labeled/'
train_all = os.listdir(train_dir)
train_dataset = SegDataset(train_all, metadata, train_dir)


val_dir = dataDir + '/Validation/'
val_all = os.listdir(val_dir)[:8]
val_dataset = SegDataset(val_all, metadata, val_dir)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)
val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)

# Define Loss


In [10]:
def dice_score(inputs, targets, class_num, dims=[1,2,3], smooth=0.5):
    inputs = inputs.detach().clone()
    targets = targets.detach().clone()    
    
    inputs = torch.argmax(inputs, dim=1)
#     print(f"check dimensions {inputs.shape}, {targets.shape}")



#     inputs = torch.sigmoid(inputs)
    
    # round the probabilities to the closest integer
    # target = torch.round(target)
#     inputs = torch.round(inputs)
    
#     print('min and max', torch.min(inputs), torch.max(inputs))
    
    inputs[inputs == class_num] = -1
    inputs[inputs != -1] = 0
    inputs[inputs != 0] = 1

    targets[targets == class_num] = -1
    targets[targets != -1] = 0
    targets[targets != 0] = 1   

#     print(f"inputs:{inputs}\ntargets:{targets}")

    intersection = 2 * (inputs * targets)
    total = (inputs + targets)
#     print(f"intersection: {intersection}\ntotal: {total}")
    
    intersection = intersection.sum(dim=dims)
    total = total.sum(dim=dims)
#     print(f"intersection: {intersection}\ntotal: {total}")

    dice = torch.add(intersection, smooth)/torch.add(total, smooth)
    
#     print('shape of dice', dice.shape)
    dice = torch.mean(dice).item()
#     print(dice)
    return dice  

In [11]:
import torch
from torch.nn import functional as F

smooth = 1e-2 # smooth value for dice_coef calculation

def dice_coef(outputs, target):

    # target: labels of current batch. can contain UNLABELED or probabilities.
    #     shape: (batch_size, 1, image_slices, width, height)
    
    # outputs: outputs of the model for current batch, without any activations.
    #     shape: (batch_size, 1, image_slices, width, height)
    
    # get the batch size
    batch_size = target.shape[0]
    
    # we calcaulate dice for each sample separately and return their sum.
    if batch_size > 1:
        result = dice_coef(outputs[0:1], target[0:1]) # dice of the first sample in batch
        for i in range(1, len(target)): # add dice of the other samples
            result = result + dice_coef(outputs[i:i+1], target[i:i+1])
        return result
    
    # convert `outputs` to probability
    outputs = torch.sigmoid(outputs)
    
    # round the probabilities to the closest integer
    # target = torch.round(target)
    outputs = torch.round(outputs)
    
    # shape of target and outputs will be (batch_size, 1, image_slices, width, height) here.
    #     containing the class of each voxel.
    
    # find the actual labeled voxels. 
    # labeled_voxels = target != UNLABELED
    # target = target[labeled_voxels]
    # outputs = outputs[labeled_voxels]
    
    # so now both target and outputs have only values 0 and 1 for background and liver respectively.

    
    # calculate intersection of target and outputs for each sample in batch
    intersection = (target * outputs).view(batch_size, -1).sum(-1).float()
    
    # calculate sum of target and outputs for each sample in batch
    union = (target + outputs).view(batch_size, -1).sum(-1).float()

    # numerator of dice_coef
    numerator = 2. * intersection + smooth

    # denominator of dice_coef
    denominator = union + smooth

    # calculate dice for each sample in batch
    coef = numerator / denominator
    
    # sum over samples in batch
    return coef.sum(0)


def soft_binray_cross_entropy(outputs, target, class_weights, voxel_weights=None):
    # target: labels of current batch. can contain UNLABELED or probabilities.
    #     shape: (batch_size, 1, image_slices, width, height)
    
    # outputs: outputs of the model for current batch, without any activations.
    #     shape: (batch_size, 1, image_slices, width, height)
    
    # class_weights: weights of classes. a tensor with shape (2, ). can't be None.
    
    # voxel_weights: a tensor with the exact shape of target's, containing a weight for each voxel. can be set to None.
    #     shape: (batch_size, 1, image_slices, width, height)
    
    # get the batch size
    batch_size = target.shape[0]
    
    # we calcaulate loss for each sample separately and return their sum.
    if batch_size > 1:
        # loss of the first sample in batch
        result = soft_binray_cross_entropy(
            outputs[0:1], 
            target[0:1], 
            class_weights, 
            None if voxel_weights is None else voxel_weights[0:1]
        )
        
        for i in range(1, len(target)): # add loss of the other samples
            result = result + soft_binray_cross_entropy(
                outputs[i:i+1], 
                target[i:i+1], 
                class_weights, 
                None if voxel_weights is None else voxel_weights[i:i+1]
            )
        return result
    
    # max(x, 0) - x * z + log(1 + exp(-abs(x)))
    outputs_positives = outputs.clone()
    outputs_positives[outputs_positives < 0] = 0.
    losses = outputs_positives - outputs * target + torch.log(1 + torch.exp(-torch.abs(outputs)))
    # print('shape of loss', loss.shape)
    
    # calculate strict target (round the probabilities to closest integers)
    # strict_target = torch.round(target)
    strict_target = target
    
    # initialize weights with zeros
    weights = torch.zeros_like(losses).to(target.device).float()
    
    # we set the weights of voxels according to their class
    # the weight of voxels with class `UNLABELED` will remain zero
    weights[strict_target == 0] = class_weights[0]
    weights[strict_target == 1] = class_weights[1]
    
    # apply voxel_weights if not None
    if voxel_weights is not None:
        weights = weights * voxel_weights
    
    # weighted mean
    return (losses * weights).sum() / weights.sum()

# Define UNet


In [12]:
import torch
import torch.nn as nn
# from apex import amp
from torch.cuda import amp
import numpy as np
from tqdm import tqdm

class UNet3D(nn.Module):
    def __init__(self, input_dim, output_dim, initial_filters=16, depth=4, dropout=0):
        # input_dim: number of input channles. 1 in our case.
        # output dim: number of output channels. 3 in our case. (since we have 3 labels)
        # initial_filter: kernel_size of the first conv. usually 16 or 32.
        # depth: depth of the U in unet. `image_slices` should be divisible by (2 ** depth)
        # dropout: the possibility of zeroing a node out.
        
        super(UNet3D, self).__init__()
        
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.initial_filters = initial_filters
        self.depth = depth
        self.dropout = dropout
        activation = nn.ReLU(inplace=True)
        
        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.pools = nn.ModuleList()
        self.trans = nn.ModuleList()
        self.dropouts = nn.ModuleList()
        
        # Down
        for i in range(depth):
            in_dim = self.input_dim
            if i > 0:
                in_dim = self.initial_filters * 2 ** (i - 1)
            out_dim = self.initial_filters * 2 ** i
            self.downs.append(self.double_conv_block(in_dim, out_dim, activation))
            self.pools.append(self.max_pool())
            
        
        # Bridge
        self.bridge = self.double_conv_block(
            self.initial_filters * 2 ** (depth - 1),
            self.initial_filters * 2 ** depth,
            activation
        )
        
        # Dropout
        self.dropouts.append(self.dropout_layer())
        
        # Up
        for i in range(depth):
            trans_in_out_dim = self.initial_filters * 2 ** (depth - i)
            self.trans.append(self.conv_transpose_block(trans_in_out_dim, trans_in_out_dim, activation))
            
            up_in_dim = self.initial_filters * (2 ** (depth - i) + 2 ** (depth - i - 1))
            up_out_dim = self.initial_filters * 2 ** (depth - i - 1)
            self.ups.append(self.double_conv_block(up_in_dim, up_out_dim, activation))
            
        # Dropout
        self.dropouts.append(self.dropout_layer())
        
        # Output
        self.out = self.prediction_mask(initial_filters, self.output_dim)
        
        self.initialize_parameters()
        
        
    @staticmethod
    def weight_init(module, method, **kwargs):
        if isinstance(module, (nn.Conv3d, nn.Conv2d, nn.ConvTranspose3d, nn.ConvTranspose2d)):
            method(module.weight, **kwargs)  # weights

    @staticmethod
    def bias_init(module, method, **kwargs):
        if isinstance(module, (nn.Conv3d, nn.Conv2d, nn.ConvTranspose3d, nn.ConvTranspose2d)):
            method(module.bias, **kwargs)  # bias

    def initialize_parameters(self,
                              method_weights=nn.init.xavier_uniform_,
                              method_bias=nn.init.zeros_,
                              kwargs_weights={},
                              kwargs_bias={}
                              ):
        for module in self.modules():
            self.weight_init(module, method_weights, **kwargs_weights)  # initialize weights
            self.bias_init(module, method_bias, **kwargs_bias)  # initialize bias
        
    def single_conv_block(self, input_dim, output_dim, activation):
        return nn.Sequential(
            nn.Conv3d(input_dim, output_dim, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm3d(output_dim),
            activation
        )

    def double_conv_block(self, input_dim, output_dim, activation):
        return nn.Sequential(
            self.single_conv_block(input_dim, output_dim, activation),
            #nn.Dropout(p=self.dropout),
            self.single_conv_block(output_dim, output_dim, activation)
        )

    def conv_transpose_block(self, input_dim, output_dim, activation):
        return nn.Sequential(
            nn.ConvTranspose3d(input_dim, output_dim, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm3d(output_dim),
            activation,
            #nn.Dropout(p=self.dropout)
        )

    def prediction_mask(self, input_dim, output_dim):
        # the activation is considered to in the loss.
        return nn.Conv3d(input_dim, output_dim, kernel_size=1, stride=1, padding=0) 

    def max_pool(self):
        return nn.MaxPool3d(kernel_size=2, stride=2, padding=0) 
    
    def dropout_layer(self):
        return nn.Dropout(p=self.dropout)
        
    def forward(self, x):
        downs = []
        ups = []
        pools = []
        trans = []
        concats = []
        dropouts = []
        
        # Down
        for i in range(self.depth):
            inp = x
            if i > 0:
                inp = pools[-1]
            downs.append(self.downs[i](inp))
            pools.append(self.pools[i](downs[-1]))
        
        # Bridge
        bridge = self.bridge(pools[-1])
        
        # Dropout
        dropouts.append(self.dropouts[0](bridge))
        
        # Up
        for i in range(self.depth):
            inp = dropouts[-1]
            if i > 0:
                inp = ups[-1]
            trans.append(self.trans[i](inp))
            concats.append(torch.cat([trans[-1], downs[self.depth - i - 1]], dim=1))
            ups.append(self.ups[i](concats[-1]))
        
        # Dropout
        dropouts.append(self.dropouts[1](ups[-1]))
        
        # Output
        out = self.out(dropouts[-1])
        return out


# Define Trainer

In [13]:
class Trainer:
    def __init__(self,
                 model: torch.nn.Module,
                 device: torch.device,
                 criterion: torch.nn.Module,
                 optimizer: torch.optim.Optimizer,
                 training_DataLoader: torch.utils.data.Dataset,
                 validation_DataLoader: torch.utils.data.Dataset = None,
                 lr_scheduler: torch.optim.lr_scheduler = None,
                 epochs: int = 100,
                 epoch: int = 0,
                 notebook: bool = False
                 ):

        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.training_DataLoader = training_DataLoader
        self.validation_DataLoader = validation_DataLoader
        self.device = device
        self.epochs = epochs
        self.epoch = epoch
        self.notebook = notebook
        

        self.training_loss = []
        self.validation_loss = []
        self.learning_rate = []
        
        self.validation_dices_c0 = []
        self.validation_dices_c1 = []
        self.validation_dices_c2 = []
        self.validation_dices_c3 = []


        self.train_dices_c0 = []
        self.train_dices_c1 = []
        self.train_dices_c2 = []
        self.train_dices_c3 = []
        
        cols = ['epoch', 't_dice_0', 't_dice_1', 't_dice_2', 't_dice_3', 'v_dice_0', 'v_dice_1', 'v_dice_2', 'v_dice_3', 't_loss', 'v_loss']
        self.log = pd.DataFrame(columns=cols)
        
        self.log_file_name = 'log_initial_filter_8.csv'
        self.model_name = 'unet_3d_initial_filter_8.pt'

#         self.class_weights = torch.Tensor(training_DataLoader.dataset.get_class_weights()).cuda(0)
    


    def run_trainer(self):

        if self.notebook:
            from tqdm.notebook import tqdm, trange
        else:
            from tqdm import tqdm, trange

        progressbar = trange(self.epochs, desc='Progress')
        for i in progressbar:
            """Epoch counter"""
            self.epoch += 1  # epoch counter

            """Training block"""
            self._train()

            """Validation block"""
            if self.validation_DataLoader is not None:
                self._validate()

            """Learning rate scheduler block"""
            if self.lr_scheduler is not None:
                if self.validation_DataLoader is not None and self.lr_scheduler.__class__.__name__ == 'ReduceLROnPlateau':
                    self.lr_scheduler.batch(self.validation_loss[i])  # learning rate scheduler step with validation loss
                else:
                    self.lr_scheduler.batch()  # learning rate scheduler step
            
            print(f"epoch number {self.epoch} finished with following results:")
            print('\ttrain dice c0', self.train_dices_c0[i], '\ttrain dice c1', self.train_dices_c1[i])
            print('\ttrain dice c2', self.train_dices_c2[i], '\ttrain dice c3', self.train_dices_c3[i])
            print('\tvalidation dice c0', self.validation_dices_c0[i], '\tvalidation dice c1', self.validation_dices_c1[i])
            print('\tvalidation dice c2', self.validation_dices_c2[i], '\tvalidation dice c4', self.validation_dices_c3[i])
            print('\ttrain loss ', self.training_loss[i], '\tvalidation loss ', self.validation_loss[i])
            

            log_dict = {'epoch' : self.epoch,
                        't_dice_0': self.train_dices_c0[i], 't_dice_1': self.train_dices_c1[i],
                        't_dice_2': self.train_dices_c2[i], 't_dice_3': self.train_dices_c3[i],
                        'v_dice_0':self.validation_dices_c0[i], 'v_dice_1':self.validation_dices_c1[i], 
                        'v_dice_2':self.validation_dices_c2[i], 'v_dice_3':self.validation_dices_c3[i], 
                        't_loss': self.training_loss[i], 'v_loss':self.validation_loss[i] }
            
            self.log = self.log.append(log_dict, ignore_index=True)

            if i % 10 == 0 and i > 0:
                torch.save(self.model, dataDir + self.model_name)
                self.log.to_csv(self.log_file_name)
                # self.dice_0_per_slice_log.to_csv('dice_0_per_slice_log.csv')
        
        
        self.log.to_csv(self.log_file_name)
        return self.training_loss, self.validation_loss, self.validation_dices_c0

    def _train(self):

        if self.notebook:
            from tqdm.notebook import tqdm, trange
        else:
            from tqdm import tqdm, trange

        self.model.train()  # train mode
        train_losses = []  # accumulate the losses here

        t_dices_c0 = []
        t_dices_c1 = []
        t_dices_c2 = []
        t_dices_c3 = []
        
        batch_iter = tqdm(enumerate(self.training_DataLoader), 'Training', total=len(self.training_DataLoader),
                          leave=False)

        for i, (x, y) in batch_iter:
#             print('shape of input:', x.shape)
#             print('shape of output:', y.shape)
            input, target = x.to(0), y.to(0)  # send to device (GPU or CPU)
            self.optimizer.zero_grad()  # zerograd the parameters
            out = self.model(input)  # one forward pass
            # out = torch.argmax(out, dim=1)
            # out = torch.unsqueeze(out, 1)
            # print('**** shape of output:', out.shape)
            # print('**** shape of target:', target.shape)
#             loss = soft_binray_cross_entropy( 
#                     out,
#                     target, # target
#                     class_weights=self.class_weights,
#                     # voxel_weights=VW.cuda(self.gpu)
#                 )  # calculate loss
#             loss = Variable(loss.data, requires_grad=True)
            loss = criterion(out, target)
            loss_value = loss.item()
            train_losses.append(loss_value)
            loss.backward()  # one backward pass
            self.optimizer.step()  # update the parameters
            
#             t_dices_c0.append(dice_coef(out, target).item())
            t_dices_c0.append(dice_score(out, target,0))
            t_dices_c1.append(dice_score(out, target,1))
            t_dices_c2.append(dice_score(out, target,2))
            t_dices_c3.append(dice_score(out, target,3))
                              
            batch_iter.set_description(f'Training: (loss {loss_value:.4f})')  # update progressbar

        self.training_loss.append(np.mean(train_losses))

        self.train_dices_c0.append(np.mean(t_dices_c0))
        self.train_dices_c1.append(np.mean(t_dices_c1))
        self.train_dices_c2.append(np.mean(t_dices_c2))
        self.train_dices_c3.append(np.mean(t_dices_c3))

        self.learning_rate.append(self.optimizer.param_groups[0]['lr'])

        batch_iter.close()

    def _validate(self):

        if self.notebook:
            from tqdm.notebook import tqdm, trange
        else:
            from tqdm import tqdm, trange

        self.model.eval()  # evaluation mode
        valid_losses = []  # accumulate the losses here
        valid_dices_c0 = []
        valid_dices_c1 = []
        valid_dices_c2 = []
        valid_dices_c3 = []
        
        batch_iter = tqdm(enumerate(self.validation_DataLoader), 'Validation', total=len(self.validation_DataLoader),
                          leave=False)

        for i, (x, y) in batch_iter:
            input, target = x.to(self.device), y.to(self.device)  # send to device (GPU or CPU)

            with torch.no_grad():
                out = self.model(input)
                # out = torch.argmax(out, dim=1)
                # out = torch.unsqueeze(out, 1)
#                 loss = soft_binray_cross_entropy( 
#                     out,
#                     target, # target
#                     class_weights=self.class_weights,
#                     # voxel_weights=VW.cuda(self.gpu)
#                 )  # calculate loss
                loss = criterion(out, target)
                loss_value = loss.item()
                valid_losses.append(loss_value)
                
    
#                 valid_dices_c0.append(dice_coef(out, target).item())
                valid_dices_c0.append(dice_score(out, target,0))
                valid_dices_c1.append(dice_score(out, target,1))
                valid_dices_c2.append(dice_score(out, target,2))
                valid_dices_c3.append(dice_score(out, target,3))


                batch_iter.set_description(f'Validation: (loss {loss_value:.4f})')

        self.validation_loss.append(np.mean(valid_losses))
        
        self.validation_dices_c0.append(np.mean(valid_dices_c0))
        self.validation_dices_c1.append(np.mean(valid_dices_c1))
        self.validation_dices_c2.append(np.mean(valid_dices_c2))
        self.validation_dices_c3.append(np.mean(valid_dices_c3))
        
        batch_iter.close()

# Train

In [14]:
model = UNet3D(input_dim=1,
             output_dim=4,
             depth=4, dropout=0, initial_filters=8).to(device)

# model = torch.load(dataDir + 'unet_3d_4class_dropout.pt')

In [15]:
p = dataDir + 'Validation/A5C2D2/A5C2D2_sa_gt.nii.gz'
a = nib.load(p).get_fdata()
nSamples = [np.sum(a == 0), np.sum(a == 1), np.sum(a == 2), np.sum(a == 3)]
print(nSamples, np.unique(a))
weights = [1 - (x / sum(nSamples)) for x in nSamples]

#############
weights = [0.1, 0.3, 0.3 , 0.3]
#############

weights = torch.FloatTensor(weights).to(device)
weights

[20711519, 8830, 6467, 9184] [0. 1. 2. 3.]


tensor([0.1000, 0.3000, 0.3000, 0.3000], device='cuda:0')

In [None]:
# criterion
# criterion = torch.nn.BCEWithLogitsLoss()
criterion = torch.nn.CrossEntropyLoss(weight=weights).to(device)
# criterion = CostumLoss().to(device)

# optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# trainer
trainer = Trainer(model=model,
                  device=device,
                  criterion=criterion,
                  optimizer=optimizer,
                  training_DataLoader=train_loader,
                  validation_DataLoader=val_loader,
                  lr_scheduler=None,
                  epochs=num_epochs,
                  epoch=0, #################todo check itttttt
                  notebook=True)
# start training
training_losses, validation_losses, val_dices_0 = trainer.run_trainer()

HBox(children=(FloatProgress(value=0.0, description='Progress', max=80.0, style=ProgressStyle(description_widt…

HBox(children=(FloatProgress(value=0.0, description='Training', max=75.0, style=ProgressStyle(description_widt…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=4.0, style=ProgressStyle(description_wid…

epoch number 1 finished with following results:
	train dice c0 0.7870184675852457 	train dice c1 0.04027067912121614
	train dice c2 0.016635506693273783 	train dice c3 0.018812849985394373
	validation dice c0 0.8797599226236343 	validation dice c1 0.029332368168979883
	validation dice c2 0.014486528700217605 	validation dice c4 0.01737809064798057
	train loss  1.109695833524068 	validation loss  0.9909138083457947


HBox(children=(FloatProgress(value=0.0, description='Training', max=75.0, style=ProgressStyle(description_widt…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=4.0, style=ProgressStyle(description_wid…

epoch number 2 finished with following results:
	train dice c0 0.9255771104494731 	train dice c1 0.09157377464075883
	train dice c2 0.019935716409236193 	train dice c3 0.0227933563098001
	validation dice c0 0.9439001977443695 	validation dice c1 0.04463878087699413
	validation dice c2 0.011079032206907868 	validation dice c4 0.0255808113142848
	train loss  0.9460289080937704 	validation loss  0.901009812951088


HBox(children=(FloatProgress(value=0.0, description='Training', max=75.0, style=ProgressStyle(description_widt…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=4.0, style=ProgressStyle(description_wid…

epoch number 3 finished with following results:
	train dice c0 0.9611104130744934 	train dice c1 0.1687760828435421
	train dice c2 0.025585389366994303 	train dice c3 0.04741551437046534
	validation dice c0 0.964658334851265 	validation dice c1 0.08569386787712574
	validation dice c2 0.010496286675333977 	validation dice c4 0.006034955644281581
	train loss  0.8573711363474528 	validation loss  0.8226979821920395


HBox(children=(FloatProgress(value=0.0, description='Training', max=75.0, style=ProgressStyle(description_widt…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=4.0, style=ProgressStyle(description_wid…

epoch number 4 finished with following results:
	train dice c0 0.9762494587898254 	train dice c1 0.25094528347253797
	train dice c2 0.02937782025585572 	train dice c3 0.06060693428541223
	validation dice c0 0.9807600677013397 	validation dice c1 0.13674421794712543
	validation dice c2 0.018351634964346886 	validation dice c4 0.0281780909281224
	train loss  0.7909470637639363 	validation loss  0.7469772547483444


HBox(children=(FloatProgress(value=0.0, description='Training', max=75.0, style=ProgressStyle(description_widt…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=4.0, style=ProgressStyle(description_wid…

epoch number 5 finished with following results:
	train dice c0 0.98638436794281 	train dice c1 0.3543785709142685
	train dice c2 0.07309235220154127 	train dice c3 0.048255409048482155
	validation dice c0 0.9892620742321014 	validation dice c1 0.24159176275134087
	validation dice c2 0.052206539548933506 	validation dice c4 0.02870829962193966
	train loss  0.7243493970235189 	validation loss  0.6949487328529358


HBox(children=(FloatProgress(value=0.0, description='Training', max=75.0, style=ProgressStyle(description_widt…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=4.0, style=ProgressStyle(description_wid…

epoch number 7 finished with following results:
	train dice c0 0.9926735870043437 	train dice c1 0.5103318870067597
	train dice c2 0.128379336198171 	train dice c3 0.05881605253865321
	validation dice c0 0.9918368607759476 	validation dice c1 0.3393283262848854
	validation dice c2 0.09426957927644253 	validation dice c4 0.055736192502081394
	train loss  0.6192969783147176 	validation loss  0.6197923868894577


HBox(children=(FloatProgress(value=0.0, description='Training', max=75.0, style=ProgressStyle(description_widt…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=4.0, style=ProgressStyle(description_wid…

epoch number 8 finished with following results:
	train dice c0 0.9933757535616556 	train dice c1 0.5258812077840169
	train dice c2 0.14720539182424544 	train dice c3 0.06682904841999213
	validation dice c0 0.9941049367189407 	validation dice c1 0.3503505662083626
	validation dice c2 0.07664915733039379 	validation dice c4 0.045902212616056204
	train loss  0.5823039388656617 	validation loss  0.5603630244731903


HBox(children=(FloatProgress(value=0.0, description='Training', max=75.0, style=ProgressStyle(description_widt…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=4.0, style=ProgressStyle(description_wid…

epoch number 9 finished with following results:
	train dice c0 0.9938343079884847 	train dice c1 0.5424126489957174
	train dice c2 0.1736303237080574 	train dice c3 0.0720168050006032
	validation dice c0 0.9929644614458084 	validation dice c1 0.37653395533561707
	validation dice c2 0.0788955525495112 	validation dice c4 0.045093526132404804
	train loss  0.548238860766093 	validation loss  0.5356578677892685


HBox(children=(FloatProgress(value=0.0, description='Training', max=75.0, style=ProgressStyle(description_widt…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=4.0, style=ProgressStyle(description_wid…

epoch number 10 finished with following results:
	train dice c0 0.9939759047826131 	train dice c1 0.5579040531317393
	train dice c2 0.2029946442445119 	train dice c3 0.07686553651622186
	validation dice c0 0.994631826877594 	validation dice c1 0.36610667407512665
	validation dice c2 0.13875185698270798 	validation dice c4 0.050282849464565516
	train loss  0.5177715849876404 	validation loss  0.5087489560246468


HBox(children=(FloatProgress(value=0.0, description='Training', max=75.0, style=ProgressStyle(description_widt…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=4.0, style=ProgressStyle(description_wid…

epoch number 11 finished with following results:
	train dice c0 0.9940216628710429 	train dice c1 0.5730298062165579
	train dice c2 0.2676149464646975 	train dice c3 0.07492679469908277
	validation dice c0 0.9944742769002914 	validation dice c1 0.39408012852072716
	validation dice c2 0.2373376041650772 	validation dice c4 0.03820742230163887
	train loss  0.489989842971166 	validation loss  0.48205312341451645


HBox(children=(FloatProgress(value=0.0, description='Training', max=75.0, style=ProgressStyle(description_widt…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=4.0, style=ProgressStyle(description_wid…

epoch number 12 finished with following results:
	train dice c0 0.9941290807723999 	train dice c1 0.5895186698436737
	train dice c2 0.3617490797241529 	train dice c3 0.11244191761439046
	validation dice c0 0.9941563755273819 	validation dice c1 0.3745235688984394
	validation dice c2 0.19251557812094688 	validation dice c4 0.08953182946424931
	train loss  0.4613295110066732 	validation loss  0.4573062136769295


HBox(children=(FloatProgress(value=0.0, description='Training', max=75.0, style=ProgressStyle(description_widt…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=4.0, style=ProgressStyle(description_wid…

epoch number 13 finished with following results:
	train dice c0 0.9943488256136577 	train dice c1 0.6202277545134226
	train dice c2 0.4388219992319743 	train dice c3 0.2940850932523608
	validation dice c0 0.9944931119680405 	validation dice c1 0.4278876893222332
	validation dice c2 0.21912332996726036 	validation dice c4 0.2963857725262642
	train loss  0.434276917775472 	validation loss  0.4276108220219612


HBox(children=(FloatProgress(value=0.0, description='Training', max=75.0, style=ProgressStyle(description_widt…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=4.0, style=ProgressStyle(description_wid…

epoch number 14 finished with following results:
	train dice c0 0.9945560765266418 	train dice c1 0.6353644625345866
	train dice c2 0.46168199261029563 	train dice c3 0.4259170127908389
	validation dice c0 0.9945053905248642 	validation dice c1 0.4178594499826431
	validation dice c2 0.26196854189038277 	validation dice c4 0.16575562980142422
	train loss  0.41019771416982015 	validation loss  0.4110295996069908


HBox(children=(FloatProgress(value=0.0, description='Training', max=75.0, style=ProgressStyle(description_widt…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=4.0, style=ProgressStyle(description_wid…

epoch number 15 finished with following results:
	train dice c0 0.9947814361254375 	train dice c1 0.6566304429372152
	train dice c2 0.49608655969301857 	train dice c3 0.5082036489248276
	validation dice c0 0.9946064352989197 	validation dice c1 0.39137761294841766
	validation dice c2 0.2644750662147999 	validation dice c4 0.24767822213470936
	train loss  0.3866536871592204 	validation loss  0.3841656967997551


HBox(children=(FloatProgress(value=0.0, description='Training', max=75.0, style=ProgressStyle(description_widt…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=4.0, style=ProgressStyle(description_wid…

epoch number 16 finished with following results:
	train dice c0 0.9949384641647339 	train dice c1 0.6708981823921204
	train dice c2 0.5119982413450876 	train dice c3 0.5365508975585301
	validation dice c0 0.9931674599647522 	validation dice c1 0.382291242480278
	validation dice c2 0.21660156548023224 	validation dice c4 0.2542428640444996
	train loss  0.36452065348625184 	validation loss  0.38412580639123917


HBox(children=(FloatProgress(value=0.0, description='Training', max=75.0, style=ProgressStyle(description_widt…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=4.0, style=ProgressStyle(description_wid…

epoch number 17 finished with following results:
	train dice c0 0.9951773365338643 	train dice c1 0.6783260186513265
	train dice c2 0.5356305360794067 	train dice c3 0.6079465083281199
	validation dice c0 0.993145763874054 	validation dice c1 0.34863342344760895
	validation dice c2 0.26636506989598274 	validation dice c4 0.3518775701522827
	train loss  0.3440809881687164 	validation loss  0.3657902926206589


HBox(children=(FloatProgress(value=0.0, description='Training', max=75.0, style=ProgressStyle(description_widt…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=4.0, style=ProgressStyle(description_wid…

epoch number 18 finished with following results:
	train dice c0 0.995262439250946 	train dice c1 0.6899120807647705
	train dice c2 0.5459750807285308 	train dice c3 0.6420532651742299
	validation dice c0 0.9940585941076279 	validation dice c1 0.3854065537452698
	validation dice c2 0.2947377786040306 	validation dice c4 0.2695274204015732
	train loss  0.32528167406717934 	validation loss  0.3407963141798973


HBox(children=(FloatProgress(value=0.0, description='Training', max=75.0, style=ProgressStyle(description_widt…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=4.0, style=ProgressStyle(description_wid…

epoch number 19 finished with following results:
	train dice c0 0.9953812185923259 	train dice c1 0.6963885219891867
	train dice c2 0.5485183628400166 	train dice c3 0.6456981120506923
	validation dice c0 0.9942955076694489 	validation dice c1 0.3816524110734463
	validation dice c2 0.29008013382554054 	validation dice c4 0.2846896652918076
	train loss  0.30829193512598674 	validation loss  0.32255133241415024


HBox(children=(FloatProgress(value=0.0, description='Training', max=75.0, style=ProgressStyle(description_widt…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=4.0, style=ProgressStyle(description_wid…

epoch number 20 finished with following results:
	train dice c0 0.9955662147204081 	train dice c1 0.7025220553080241
	train dice c2 0.5548663751284282 	train dice c3 0.6482149012883505
	validation dice c0 0.9936392158269882 	validation dice c1 0.39990154653787613
	validation dice c2 0.255139771848917 	validation dice c4 0.3214284721761942
	train loss  0.2916911244392395 	validation loss  0.31672707945108414


HBox(children=(FloatProgress(value=0.0, description='Training', max=75.0, style=ProgressStyle(description_widt…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=4.0, style=ProgressStyle(description_wid…

epoch number 21 finished with following results:
	train dice c0 0.9957323789596557 	train dice c1 0.7138168557484945
	train dice c2 0.5758918730417887 	train dice c3 0.6781519385178884
	validation dice c0 0.9943694621324539 	validation dice c1 0.4060436710715294
	validation dice c2 0.3038521409034729 	validation dice c4 0.32397129700984806
	train loss  0.2752937177817027 	validation loss  0.29555704444646835


HBox(children=(FloatProgress(value=0.0, description='Training', max=75.0, style=ProgressStyle(description_widt…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=4.0, style=ProgressStyle(description_wid…

epoch number 22 finished with following results:
	train dice c0 0.9957806658744812 	train dice c1 0.7219712654749553
	train dice c2 0.5882141558329265 	train dice c3 0.70160497824351
	validation dice c0 0.9949767142534256 	validation dice c1 0.4455854222178459
	validation dice c2 0.31743136048316956 	validation dice c4 0.2683595269918442
	train loss  0.26059388935565947 	validation loss  0.2769482880830765


HBox(children=(FloatProgress(value=0.0, description='Training', max=75.0, style=ProgressStyle(description_widt…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=4.0, style=ProgressStyle(description_wid…

epoch number 23 finished with following results:
	train dice c0 0.9958822464942932 	train dice c1 0.7207971453666687
	train dice c2 0.5818714606761932 	train dice c3 0.6884509364763896
	validation dice c0 0.9949154853820801 	validation dice c1 0.4332287386059761
	validation dice c2 0.3008887991309166 	validation dice c4 0.2957347743213177
	train loss  0.24855977257092793 	validation loss  0.26304956153035164


HBox(children=(FloatProgress(value=0.0, description='Training', max=75.0, style=ProgressStyle(description_widt…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=4.0, style=ProgressStyle(description_wid…

epoch number 24 finished with following results:
	train dice c0 0.9960452651977539 	train dice c1 0.7312789209683737
	train dice c2 0.5962648479143778 	train dice c3 0.7263921479384104
	validation dice c0 0.9951745718717575 	validation dice c1 0.4340420179069042
	validation dice c2 0.3099249079823494 	validation dice c4 0.3361584199592471
	train loss  0.23442568798859914 	validation loss  0.25176000967621803


HBox(children=(FloatProgress(value=0.0, description='Training', max=75.0, style=ProgressStyle(description_widt…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=4.0, style=ProgressStyle(description_wid…

epoch number 25 finished with following results:
	train dice c0 0.9962505308787027 	train dice c1 0.7393546326955159
	train dice c2 0.6153311562538147 	train dice c3 0.7405756576855977
	validation dice c0 0.9949050545692444 	validation dice c1 0.4042827934026718
	validation dice c2 0.3212167099118233 	validation dice c4 0.3534872345626354
	train loss  0.22142780860265096 	validation loss  0.2444426566362381


HBox(children=(FloatProgress(value=0.0, description='Training', max=75.0, style=ProgressStyle(description_widt…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=4.0, style=ProgressStyle(description_wid…

epoch number 26 finished with following results:
	train dice c0 0.9964599792162577 	train dice c1 0.7486987177530925
	train dice c2 0.626291795571645 	train dice c3 0.7497127803166708
	validation dice c0 0.9952020347118378 	validation dice c1 0.4262388199567795
	validation dice c2 0.3461432009935379 	validation dice c4 0.35531216114759445
	train loss  0.20960562348365783 	validation loss  0.23102326691150665


HBox(children=(FloatProgress(value=0.0, description='Training', max=75.0, style=ProgressStyle(description_widt…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=4.0, style=ProgressStyle(description_wid…

epoch number 27 finished with following results:
	train dice c0 0.9965140509605408 	train dice c1 0.75052814245224
	train dice c2 0.6264972400665283 	train dice c3 0.7566410350799561
	validation dice c0 0.9954268336296082 	validation dice c1 0.4297341853380203
	validation dice c2 0.326394222676754 	validation dice c4 0.364594042301178
	train loss  0.19962095896402995 	validation loss  0.22353874146938324


HBox(children=(FloatProgress(value=0.0, description='Training', max=75.0, style=ProgressStyle(description_widt…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=4.0, style=ProgressStyle(description_wid…

epoch number 28 finished with following results:
	train dice c0 0.9966486461957296 	train dice c1 0.7521719837188721
	train dice c2 0.6350902438163757 	train dice c3 0.7658266631762186
	validation dice c0 0.9951543807983398 	validation dice c1 0.43420542776584625
	validation dice c2 0.3334128186106682 	validation dice c4 0.33719434986414853
	train loss  0.18961629370848337 	validation loss  0.21819305047392845


HBox(children=(FloatProgress(value=0.0, description='Training', max=75.0, style=ProgressStyle(description_widt…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=4.0, style=ProgressStyle(description_wid…

epoch number 29 finished with following results:
	train dice c0 0.9966757782300313 	train dice c1 0.7557528003056844
	train dice c2 0.6358238991101582 	train dice c3 0.7620629811286926
	validation dice c0 0.994935154914856 	validation dice c1 0.3907152786850929
	validation dice c2 0.2979198358952999 	validation dice c4 0.351827759295702
	train loss  0.18091748813788097 	validation loss  0.21622388437390327


HBox(children=(FloatProgress(value=0.0, description='Training', max=75.0, style=ProgressStyle(description_widt…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=4.0, style=ProgressStyle(description_wid…

epoch number 30 finished with following results:
	train dice c0 0.9967383193969727 	train dice c1 0.7578481086095175
	train dice c2 0.6458534487088521 	train dice c3 0.7663489945729574
	validation dice c0 0.9953247457742691 	validation dice c1 0.4519655406475067
	validation dice c2 0.3206322081387043 	validation dice c4 0.3296353258192539
	train loss  0.1725973435242971 	validation loss  0.2036532275378704


HBox(children=(FloatProgress(value=0.0, description='Training', max=75.0, style=ProgressStyle(description_widt…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=4.0, style=ProgressStyle(description_wid…

epoch number 31 finished with following results:
	train dice c0 0.996857176621755 	train dice c1 0.7586306039492289
	train dice c2 0.6445671188831329 	train dice c3 0.7700876303513845
	validation dice c0 0.9950649589300156 	validation dice c1 0.45009778439998627
	validation dice c2 0.3154607079923153 	validation dice c4 0.3383576311171055
	train loss  0.16458440641562144 	validation loss  0.1988922879099846


HBox(children=(FloatProgress(value=0.0, description='Training', max=75.0, style=ProgressStyle(description_widt…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=4.0, style=ProgressStyle(description_wid…

epoch number 32 finished with following results:
	train dice c0 0.9970079747835795 	train dice c1 0.7647480177879333
	train dice c2 0.6517540621757507 	train dice c3 0.7922780148188273
	validation dice c0 0.9953616410493851 	validation dice c1 0.42625585943460464
	validation dice c2 0.2922193296253681 	validation dice c4 0.27227075956761837
	train loss  0.156199479897817 	validation loss  0.1938903108239174


HBox(children=(FloatProgress(value=0.0, description='Training', max=75.0, style=ProgressStyle(description_widt…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=4.0, style=ProgressStyle(description_wid…

epoch number 33 finished with following results:
	train dice c0 0.9970432758331299 	train dice c1 0.7705736327171325
	train dice c2 0.6584371888637542 	train dice c3 0.7956838822364807
	validation dice c0 0.9953265190124512 	validation dice c1 0.4484763443470001
	validation dice c2 0.3333531394600868 	validation dice c4 0.3417927846312523
	train loss  0.149524267911911 	validation loss  0.18167676776647568


HBox(children=(FloatProgress(value=0.0, description='Training', max=75.0, style=ProgressStyle(description_widt…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=4.0, style=ProgressStyle(description_wid…

epoch number 34 finished with following results:
	train dice c0 0.9970851691563924 	train dice c1 0.7719523564974466
	train dice c2 0.6621103088061014 	train dice c3 0.8132900794347128
	validation dice c0 0.9942860454320908 	validation dice c1 0.460900716483593
	validation dice c2 0.3007280007004738 	validation dice c4 0.29258930310606956
	train loss  0.14278920928637187 	validation loss  0.18296918272972107


HBox(children=(FloatProgress(value=0.0, description='Training', max=75.0, style=ProgressStyle(description_widt…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=4.0, style=ProgressStyle(description_wid…

epoch number 35 finished with following results:
	train dice c0 0.9970975240071615 	train dice c1 0.7690968068440756
	train dice c2 0.6597998483975729 	train dice c3 0.8016705234845479
	validation dice c0 0.9944058507680893 	validation dice c1 0.4411824494600296
	validation dice c2 0.29920122027397156 	validation dice c4 0.33600951358675957
	train loss  0.1371776129802068 	validation loss  0.17726266756653786


HBox(children=(FloatProgress(value=0.0, description='Training', max=75.0, style=ProgressStyle(description_widt…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=4.0, style=ProgressStyle(description_wid…

epoch number 36 finished with following results:
	train dice c0 0.9971655813852945 	train dice c1 0.7703967920939128
	train dice c2 0.6682984153429667 	train dice c3 0.824314292271932
	validation dice c0 0.9950936734676361 	validation dice c1 0.4433388337492943
	validation dice c2 0.31752726435661316 	validation dice c4 0.3285215198993683
	train loss  0.13096978306770324 	validation loss  0.17224489152431488


HBox(children=(FloatProgress(value=0.0, description='Training', max=75.0, style=ProgressStyle(description_widt…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=4.0, style=ProgressStyle(description_wid…

epoch number 37 finished with following results:
	train dice c0 0.9972726066907247 	train dice c1 0.7793884038925171
	train dice c2 0.6696864104270935 	train dice c3 0.8253774627049764
	validation dice c0 0.9949050396680832 	validation dice c1 0.4587501659989357
	validation dice c2 0.32640910893678665 	validation dice c4 0.3320278115570545
	train loss  0.1255198100209236 	validation loss  0.16463062725961208


HBox(children=(FloatProgress(value=0.0, description='Training', max=75.0, style=ProgressStyle(description_widt…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=4.0, style=ProgressStyle(description_wid…

epoch number 38 finished with following results:
	train dice c0 0.9971808981895447 	train dice c1 0.7767703986167908
	train dice c2 0.6690380318959555 	train dice c3 0.8010048508644104
	validation dice c0 0.9949678033590317 	validation dice c1 0.47201429307460785
	validation dice c2 0.2952220179140568 	validation dice c4 0.38317935168743134
	train loss  0.12098688105742136 	validation loss  0.16016997210681438


HBox(children=(FloatProgress(value=0.0, description='Training', max=75.0, style=ProgressStyle(description_widt…

HBox(children=(FloatProgress(value=0.0, description='Validation', max=4.0, style=ProgressStyle(description_wid…

epoch number 39 finished with following results:
	train dice c0 0.9974446972211202 	train dice c1 0.7860892566045126
	train dice c2 0.6835036253929139 	train dice c3 0.8287782057126363
	validation dice c0 0.9949727952480316 	validation dice c1 0.4468256086111069
	validation dice c2 0.3126034662127495 	validation dice c4 0.3976138308644295
	train loss  0.11480402568976085 	validation loss  0.15811968222260475


HBox(children=(FloatProgress(value=0.0, description='Training', max=75.0, style=ProgressStyle(description_widt…

In [None]:
torch.save(model, dataDir + 'unet_3d_initial_filter_8.pt')

In [None]:
model = torch.load(dataDir + 'unet_3d_initial_filter_8.pt')

In [None]:
print(model)

# Loss Plots

In [None]:
log = pd.read_csv('log_4class_dropout.csv')
log

In [None]:
x_axis = range(0, len(log))
fig, ax = plt.subplots()
ax.plot(x_axis, log['t_loss'], label='Train')
ax.plot(x_axis, log['v_loss'], label='Validation')
ax.legend()
plt.ylabel('Loss')
plt.show()

In [None]:
from torchsummary import summary

summary(model, (1, 16, 224, 224))

# Show Results

In [None]:
from tqdm.notebook import tqdm, trange



val_dir = dataDir + '/Validation/'
val_all = os.listdir(val_dir)
val_dataset = SegDataset(val_all, metadata, val_dir)

val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)



model.eval()  # evaluation mode
batch_iter = tqdm(enumerate(val_loader), 'Validation', total=len(val_loader),
                    leave=False)

model_outputs = []
original_images = []
gts = []
dice_c0 = []
dice_c1 = []
dice_c2 = []
dice_c3 = []

for i, (x, y) in batch_iter:
    input, target = x.to(device), y.to(device)  # send to device (GPU or CPU)
    
    original_images.append(input.detach().clone())
    gts.append(target.detach().clone())
    
    with torch.no_grad():
        out = model(input)
        out = torch.sigmoid(out)
        out = torch.round(out)
        print(out.shape)
    
    out_soft = torch.argmax(out, dim=1) 
    model_outputs.append(out_soft)
#     model_outputs.append(out)

    
    dice_c0.append(dice_score(out.detach().clone(), target, 0))
    dice_c1.append(dice_score(out.detach().clone(), target, 1))
    dice_c2.append(dice_score(out.detach().clone(), target, 2))
    dice_c3.append(dice_score(out.detach().clone(), target, 3))


batch_iter.close()

print(np.mean(dice_c0), np.mean(dice_c1), np.mean(dice_c2), np.mean(dice_c3))

In [None]:
model_outputs = np.asarray(model_outputs)
original_images = np.asarray(original_images)
gts = np.asarray(gts)

In [None]:
model_outputs[0].shape, original_images[0].shape, gts[0].shape

In [None]:
plt.imshow(original_images[1][2][0][7].cpu(), cmap='gray')

In [None]:
plt.imshow(gts[1][2][7].cpu(), cmap='gray')

In [None]:
plt.imshow(model_outputs[1][2][7].cpu(), cmap='gray')

In [None]:
model_outputs[1].shape