### Imports

In [1]:
import torch 
import torchvision
import torch.nn as nn
from torch.utils.data import Dataset , DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
import numpy as np

import matplotlib.pyplot as plt
import os

from typing import List,Tuple,Callable,Dict

from utility_functions import *
from dataset_dataloader import *
from variable_unet import *
from augmentation_class import *
from train_utility_functions import *


import pickle

### Initialize

In [60]:

from argparse import Namespace

# Configurations
name = 'basline_sift'
args = {
    'root' : '',
    'samples':1,
    'log_root' : 'log',
    'ckpt_root' : 'checkpoints',
    'sift' : True,
    'batch_size' : 16,
    'crop_size' : 256,
    'epochs' : 100,
    'lr' : 1e-2,
    'device' : 'cuda:0',

    # U-Net configs
    'init_filters': 4,
    'block_count': 7,
    'upsample': False,
    'conv_block' : 'unet',

    # Loss function
    'dice_gam': 1,
    'th': 0.5,

    # Visualizer
    'display_interval': 100,

    # Checkpoint
    'save_interval': 5,
}

configs = Namespace(**args)



In [42]:
#### Dataloader
data_loader = heart_dataloader( root='train',
                                batch_size=configs.batch_size,
                                sift=configs.sift,
                                test=False)
    
train_dataloader,validation_dataloader= data_loader.train_dl , data_loader.val_dl

# check the loaded files paths
# print(data_loader.images_paths)
# print(data_loader.labels_paths)


  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

### Build model 

In [61]:

# Define some peripheral variables that are often used
epoch = configs.epochs

dice_gam = torch.tensor(configs.dice_gam, dtype=torch.float32, device=configs.device)


# Input transforms including: (preprocessing, data augmentation)
input_transform = PreprocessAug(device=configs.device, 
                                batch_size=configs.batch_size,
                                crop_size=configs.crop_size)


# Define U-Net
model = variable_unet_class(in_channels=1,
                            out_channels=1,
                            init_filters=configs.init_filters,
                            block_count=configs.block_count,
                            upsample=configs.upsample,
                            conv_block=configs.conv_block)


# Define loss

def dice_coeff(input, label):
    smooth = 1.
    iflat = input.contiguous().view(-1)
    lflat = label.contiguous().view(-1)
    intersection = (iflat * lflat).sum()
    
    return ((2. * intersection + smooth) / (iflat.sum() + lflat.sum() + smooth))

class dice_loss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input, label):
        dice_val = dice_coeff(input, label)
        return 1 - dice_val

BCE_LOSS = nn.BCEWithLogitsLoss()
DICE_LOSS = dice_loss()

sigmoid = nn.Sigmoid()

# Define optimizers
optimizer = optim.Adam(model.parameters(), lr=configs.lr)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5, cooldown=1, min_lr=1e-8, eps=1e-08, verbose=True)

# Visualizer
writer = SummaryWriter(str(configs.log_root))

count_params = lambda x : sum(p.numel() for p in x.parameters() if p.requires_grad)

total_params = count_params(model)

print(f'UNET param count = {total_params}')

UNET param count = 7786325


### Train

In [62]:
# Define forward pass used for both training/validation
def forward_pass( input : torch.Tensor, 
                  label : torch.Tensor, 
                  model : torch.nn.Module,
                  dice_gam):
    
    # forward pass
    output = model(input).float()

    # calculate loss
    step_bceloss = BCE_LOSS(output, label)
    step_diceloss = DICE_LOSS(output,label)
    step_loss = step_bceloss
    step_metric = step_diceloss

    return output, step_loss, step_metric

def normalize_im(tensor: torch.Tensor)->torch.Tensor:
    
    large = torch.max(tensor).cpu().data
    small = torch.min(tensor).cpu().data
    diff = large - small
    
    normalized_tensor = (tensor.clamp(min=small, max=large) - small) * (torch.tensor(1) / diff)
    
    return normalized_tensor

def postprocess_im(tensor: torch.Tensor ,th: float = 0.5):
    """
    1. Applies sigmoid layer to U-Net output
    2. Thresholds to certain value. 
    i.e.) Values lower than th = 0 / Values higher than th = 1
    """
    
    tensor_th = tensor.clone()
    tensor_th = sigmoid(tensor_th)
    tensor_th[tensor_th <= th] = 0
    tensor_th[tensor_th > th] = 1
    
    return tensor_th


# logger
def log_epoch(epoch, epoch_loss, epoch_metrics, writer, elapsed_secs,lr ,training=True):
    
    mode = 'Training' if training else 'Validation'
    
    string = \
    f'''Epoch {epoch:03d} {mode}.\tloss: {epoch_loss[i-1]:.4e}.\tTime: {elapsed_secs // 60} min {elapsed_secs % 60} sec\tDice={epoch_metrics:.4e}\tLR ={lr:.3e}'''
    
    print(string)

# metric calculator
def get_step_metric(output, label):
    
    dice = dice_coeff(output, label)   
    
    return dice

In [63]:

train_losses = list()
val_losses = list()
train_metric = list()
val_metric = list()
learning_rate = list()

#move model to gpu
if configs.device.startswith('cuda') :
    model.cuda()



for i in range(1, epoch+1):
    
    '''
     Training
    '''
    

    epoch_train_loss    = list()
    epoch_train_metrics = list()
    tic_train = time()

    model.train()
    torch.autograd.set_grad_enabled(True)
    
    
    
    for step,(input, label) in enumerate(train_dataloader):
        
        print(f'epoch={i-1:003d}/{epoch:003d},\tstep={step:003d}',end='\r')
        
        input, label = input_transform(input.float(), label)
        
        # Flush gradients
        optimizer.zero_grad()

        # forward pass
        output, step_loss, step_metric = forward_pass(input, label, model, dice_gam)

        # backward pass
        step_loss.backward()
        optimizer.step()

        # Step metric, Aggregate to epoch_metrics
        output_th = postprocess_im(output, th=configs.th)
        step_metric = get_step_metric(output_th, label)

        # Aggregate to epoch_loss and epoch metric
        epoch_train_loss.append(step_loss.detach().cpu().numpy().astype('float32'))
        epoch_train_metrics.append(step_metric.detach().cpu().numpy().astype('float32'))

    
    # log epoch
    toc_train = int(time() - tic_train)
    epoch_train_loss = sum(epoch_train_loss)
    epoch_train_metrics = np.mean(epoch_train_metrics)
    train_losses.append(epoch_train_loss)
    train_metric.append(epoch_train_metrics)
    

    #-----------------------------------------------------------------------------#
    
    '''
    Validation
    '''
    
    epoch_val_loss    = list()
    epoch_val_metrics = list()

    tic_val = time()
    model.eval()
    torch.autograd.set_grad_enabled(False)

    for step,(input, label) in enumerate(validation_dataloader):
        
        print(f'epoch={i-1:003d}/{epoch:003d},\tstep={step:003d}',end='\r')  
        
        input, label = input_transform(input.float(), label, training=False)

        # forward pass
        output, step_loss, step_metric = forward_pass(input, label, model, dice_gam)
        
        
        # Step metric, Aggregate to epoch_metrics
        output_th = postprocess_im(output, th=configs.th)
        step_metric = get_step_metric(output_th, label) 
        
        # Aggregate to epoch_loss
        epoch_val_loss.append(step_loss.detach().cpu().numpy().astype('float32'))
        epoch_val_metrics.append(step_metric.detach().cpu().numpy().astype('float32'))
        

        
        
    # log epoch
    toc_val = int(time() - tic_val)
    epoch_val_loss = sum(epoch_val_loss)
    epoch_val_metrics = np.mean(epoch_val_metrics)
    
    val_losses.append(epoch_val_loss)
    val_metric.append(epoch_val_metrics)

    scheduler.step(epoch_val_loss)
    learning_rate.append(scheduler.state_dict()['_last_lr'][0])    
    
    # moving train logging here , since last lr is updated after validation step . otherwise it will throw an error
    log_epoch(epoch          = i, 
          epoch_loss     = train_losses, 
          epoch_metrics  = epoch_train_metrics,
          writer         = writer, 
          elapsed_secs   = toc_train, 
          lr             = learning_rate[-1],
          training       = True
         )
    
    log_epoch(epoch          = i, 
              epoch_loss     = val_losses, 
              epoch_metrics  = epoch_val_metrics,
              writer         = writer, 
              elapsed_secs   = toc_val, 
              lr             = learning_rate[-1],
              training       = False
         )
    
    
    

    
        
    print('-'*100,)
    
    


Epoch 001 Training.	loss: 1.9369e+01.	Time: 0 min 2 sec	Dice=1.4482e-03	LR =1.000e-02
Epoch 001 Validation.	loss: 8.4777e-01.	Time: 0 min 0 sec	Dice=9.1696e-05	LR =1.000e-02
----------------------------------------------------------------------------------------------------
Epoch 002 Training.	loss: 3.3675e+00.	Time: 0 min 2 sec	Dice=9.0189e-05	LR =1.000e-02
Epoch 002 Validation.	loss: 4.5858e-01.	Time: 0 min 0 sec	Dice=9.1696e-05	LR =1.000e-02
----------------------------------------------------------------------------------------------------
Epoch 003 Training.	loss: 2.0448e+00.	Time: 0 min 2 sec	Dice=8.9422e-05	LR =1.000e-02
Epoch 003 Validation.	loss: 2.9949e-01.	Time: 0 min 0 sec	Dice=9.1696e-05	LR =1.000e-02
----------------------------------------------------------------------------------------------------
Epoch 004 Training.	loss: 1.6939e+00.	Time: 0 min 2 sec	Dice=9.0327e-05	LR =1.000e-02
Epoch 004 Validation.	loss: 4.4753e-01.	Time: 0 min 0 sec	Dice=9.1696e-05	LR =1.000e-02
-

#### save info

In [64]:

name = f'block_count_{configs.block_count}_params_{total_params}'

save_path = os.path.join('models',name)

#save model weights
torch.save(model.state_dict(), f'{save_path}_model.tar')

#save training history
history = {
            'train_loss':np.array(train_losses),
            'val_loss': np.array(val_losses),
            'train_metric':np.array(train_metric),
            'val_metric':np.array(val_metric),
            'lr':np.array(learning_rate)
           }

with open(f'{save_path}_history.pickle', 'wb') as file:
    pickle.dump(history, file, protocol=pickle.HIGHEST_PROTOCOL)

### Testing

#### Get dice score

In [None]:
# dice_score = []

# test_dataset = SegDataset(root='test',sift=False)

# test_input_transform = PreprocessAug(device=0, batch_size=1,crop_size=320,full_size=320)
# loaded_model = variable_unet_class(in_channels=1,out_channels=1,init_filters=64,upsample=False)
# loaded_model.load_state_dict(torch.load('20_oct.tar'))
# torch.autograd.set_grad_enabled(False)
# loaded_model.eval()
# loaded_model.cuda();


# for test_image,test_label in test_dataset:
    
#     test_image_cuda,label_image_cuda = test_input_transform(test_image.float(),test_label.float(),training=False)
    
#     label_prediction = postprocess_im(loaded_model(test_image_cuda))

#     dice_score.append(dice_coeff(label_prediction,label_image_cuda).detach().cpu().numpy())
    
# print(f'Tested samples\t={len(dice_score)}\ndice score\t={np.mean(dice_score)}')

#### Visualize test results

In [None]:
# from ipywidgets import *
# imshow_4dtensor = lambda tensor : plt.imshow(tensor[0,0].cpu().numpy())
# test_dataset = SegDataset(root='test',sift=False)

# test_input_transform = PreprocessAug(device=0, batch_size=1,crop_size=320,full_size=320)
# loaded_model = variable_unet_class(in_channels=1,out_channels=1,init_filters=64,upsample=False)
# loaded_model.load_state_dict(torch.load('20_oct.tar'))

# loaded_model.cuda()

# @interact(index = IntSlider(min=0,max=len(test_dataset)-1,continuous_update=False,))
# def show_slice(index):

#     plt.figure(figsize=(15,15))
#     plt.subplot(1,3,1)
#     plt.title('Input image')
#     plt.imshow(test_dataset[index][0][0])
#     plt.subplot(1,3,2)
#     plt.title('True Mask')
#     plt.imshow(test_dataset[index][1][0])
#     plt.subplot(1,3,3)
    
#     plt.title('Predicted Mask')
    
#     test_image_cuda,label_image_cuda = test_input_transform(test_dataset[index][0].float(),test_dataset[index][1].float())
    
#     label_prediction = postprocess_im(loaded_model(test_image_cuda))
    
#     plt.imshow(label_prediction.detach().cpu().numpy()[0,0])