In [None]:
import torch
import torch.nn as nn 
import torch.optim as optim 
import torch.nn.functional as F

import seaborn as sn 
import numpy as np
from sklearn.metrics import confusion_matrix 

from torch.utils.data import TensorDataset 
from torch.utils.data import DataLoader 

from load_warwick import load_warwick 

from matplotlib import pyplot as plt 

In [None]:
# Load the Warwick dataset.
#   xtra: 85 training images, 128x128x3     ytra: 85 training labels, 128x128
#   xtes: 60 testing images,  128x128x3     ytes: 60 testing labels,  128x128
(xtra, ytra,
 xtes, ytes) = load_warwick()

In [None]:
# Resetting model function
# Credits: https://discuss.pytorch.org/t/reset-model-weights/19180/4
def reset_model(model):
    """Re-initialise every learnable layer of ``model`` in place.

    All submodules are visited (``model.modules()``), not only the direct
    children. Layers nested inside containers such as an inner
    ``nn.Sequential`` are therefore reset too, as is ``model`` itself when it
    is a single layer. Modules without a ``reset_parameters`` method
    (activations, pooling, containers) are left untouched.

    Args:
        model: the ``nn.Module`` whose parameters should be re-initialised.
    """
    for layer in model.modules():
        reset = getattr(layer, 'reset_parameters', None)
        if callable(reset):
            reset()

In [None]:
# Sorensen-Dice coefficient
def dsc(A, B):
    """Sorensen-Dice coefficient of two masks: 2|A and B| / (|A| + |B|).

    Both inputs are binarised with ``.bool()``, i.e. every non-zero element
    counts as foreground. Soft predictions (probabilities or logits) should
    therefore be thresholded by the caller before being passed in.

    Args:
        A, B: tensors of the same (or broadcastable) shape.

    Returns:
        0-dim float tensor in [0, 1]. When both masks are empty they agree
        perfectly, and 1.0 is returned instead of NaN (0 / 0).
    """
    a = A.bool()
    b = B.bool()

    overlap = torch.logical_and(a, b).sum()
    total = a.sum() + b.sum()

    if total == 0:
        # Neither mask has any foreground: define this as perfect agreement.
        return torch.ones((), device=a.device)

    return (2.0 * overlap) / total

In [None]:
# Preprocessing: scale images and labels from [0, 255] to [0, 1] by dividing
# by the maximum 8-bit intensity.
# NOTE(review): this assumes load_warwick returns the masks as 0/255; if they
# are already 0/1 this scaling is wrong — confirm.
# Re-running this cell divides again, so run it exactly once after loading.
xtra, ytra = xtra / 255, ytra / 255
xtes, ytes = xtes / 255, ytes / 255

In [None]:
# Basic dimensions: image side length and number of train / test samples.
pixels = 128
ntrain, ntest = xtra.shape[0], xtes.shape[0]

In [None]:
# Convert to Torch tensors.
# Images: (N, H, W, C) numpy -> (N, C, H, W) float32, the channel-first layout
# nn.Conv2d expects.
# Masks: cast to float32 as well, so the BCE targets share the model's dtype.
# The numpy masks are presumably float64 after the /255 scaling; left as-is,
# the loss would be computed in double precision against float32 outputs.
xtra_torch = torch.from_numpy(xtra).permute(0, 3, 1, 2).float()
ytra_torch = torch.from_numpy(ytra).float()
xtes_torch = torch.from_numpy(xtes).permute(0, 3, 1, 2).float()
ytes_torch = torch.from_numpy(ytes).float()

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# torch.cuda.get_device_name only works for CUDA devices (it raises on a
# CPU-only machine), so only query it when a GPU was selected.
if device.type == "cuda":
    print("Device:", torch.cuda.get_device_name(device))
else:
    print("Device: cpu")

In [None]:
# The dataset is small, so all tensors are moved to `device` once up front.
# The training loader then only shuffles and slices tensors already on the device.
num_batch = 10
xtest, ytest = xtes_torch.to(device), ytes_torch.to(device)

training_set = TensorDataset(xtra_torch.to(device), ytra_torch.to(device))
training_loader = DataLoader(training_set, batch_size = num_batch, shuffle = True)

In [None]:
# Convolutional networks for binary segmentation.
# Every variant is a plain encoder/decoder:
#   - encoder: stride-1 "same" convolutions, each followed by optional
#     BatchNorm and an activation, and by 2x2 max-pooling except after the
#     last one;
#   - decoder: one stride-2 transposed convolution per pooling step, to return
#     to the input resolution;
#   - head: a 1x1 convolution to 2 output channels.
# The networks return raw logits. The training loss is
# binary_cross_entropy_with_logits, which applies the sigmoid itself, so a
# final nn.Sigmoid() would squash the outputs twice.
def _build_segnet(channels, kernel_size, activation, batch_norm):
    """Build one encoder/decoder segmentation network.

    Args:
        channels: encoder channel counts, starting with the input channels,
            e.g. (3, 8, 16, 32, 64).
        kernel_size: odd side length of the square encoder kernels. Padding of
            kernel_size // 2 keeps the spatial size unchanged.
        activation: activation module class (e.g. nn.ReLU, nn.SELU) placed
            after every encoder convolution.
        batch_norm: if True, insert nn.BatchNorm2d between each encoder
            convolution and its activation.

    Returns:
        nn.Sequential mapping (N, channels[0], H, W) to (N, 2, H, W) logits,
        for H and W divisible by 2 ** (len(channels) - 2).
    """
    layers = []
    n_convs = len(channels) - 1
    for i, (c_in, c_out) in enumerate(zip(channels[:-1], channels[1:])):
        layers.append(nn.Conv2d(in_channels = c_in, out_channels = c_out,
                                kernel_size = (kernel_size, kernel_size), stride = 1,
                                padding = kernel_size // 2))
        if batch_norm:
            layers.append(nn.BatchNorm2d(c_out))
        layers.append(activation())
        # No pooling after the deepest convolution.
        if i < n_convs - 1:
            layers.append(nn.MaxPool2d(kernel_size = 2, stride = 2))

    # Decoder: mirror the encoder widths, doubling the resolution at each step.
    widths = channels[:0:-1]  # e.g. (64, 32, 16, 8)
    for c_in, c_out in zip(widths[:-1], widths[1:]):
        layers.append(nn.ConvTranspose2d(in_channels = c_in, out_channels = c_out,
                                         kernel_size = (4, 4), stride = 2, padding = 1))

    layers.append(nn.Conv2d(in_channels = widths[-1], out_channels = 2,
                            kernel_size = (1, 1), stride = 1, padding = 0))
    return nn.Sequential(*layers)


# 3x3 kernels, 3 pooling steps.
segnet = _build_segnet((3, 8, 16, 32, 64), 3, nn.ReLU, batch_norm = False)
segnetbn = _build_segnet((3, 8, 16, 32, 64), 3, nn.ReLU, batch_norm = True)
m_seg_selu = _build_segnet((3, 8, 16, 32, 64), 3, nn.SELU, batch_norm = False)

# 5x5 kernels, one extra encoder stage (4 pooling steps).
m_seg_k_big = _build_segnet((3, 8, 16, 32, 64, 128), 5, nn.ReLU, batch_norm = True)
m_seg_k_big_selu = _build_segnet((3, 8, 16, 32, 64, 128), 5, nn.SELU, batch_norm = False)

# Pick the architecture to train.
model = segnetbn

# Re-initialise its weights so that switching architectures always starts from
# scratch, then move the model to the selected device.
reset_model(model)
model.to(device)

# Optional sanity check: push a random batch through the layers one by one and
# print each intermediate shape.
# Credits: https://d2l.ai/chapter_convolutional-neural-networks/lenet.html
check_dims = False
if check_dims:
    X = torch.rand(size = (num_batch, 3, 128, 128)).to(device)
    for layer in model:
        X = layer(X)
        print(layer.__class__.__name__, 'output shape:\t', X.shape)

In [None]:
# Loss functions: plain BCE (expects probabilities) and BCE-with-logits
# (applies the sigmoid internally and expects raw scores).
bce = F.binary_cross_entropy
bce_logits = F.binary_cross_entropy_with_logits

# Training settings.
it = 0             # iteration counter, advanced by the training loop
num_epochs = 300
l_rate = 0.001     # 0.001 works well

# Candidate optimisers, all built over the currently selected model's
# parameters. The training cell chooses one of them via `optr`.
sgd_opt = optim.SGD(model.parameters(), lr = l_rate,
                    momentum = 0.0, weight_decay = 0)

rmsp_opt = optim.RMSprop(model.parameters(), lr = l_rate,
                         alpha = 0.99, eps = 1e-08, momentum = 0.1,
                         centered = False, weight_decay = 0)

adam_opt = optim.Adam(model.parameters(), lr = l_rate,
                      betas = (0.9, 0.999), eps = 1e-08,
                      amsgrad = False, weight_decay = 0.0)

adad_opt = optim.Adadelta(model.parameters(), lr = 1.0,
                          rho = 0.9, eps = 1e-06, weight_decay = 0)

adag_opt = optim.Adagrad(model.parameters(), lr = 0.01,
                         lr_decay = 0, initial_accumulator_value = 0,
                         eps = 1e-10, weight_decay = 0)

adamax_opt = optim.Adamax(model.parameters(), lr = 0.002,
                          betas = (0.9, 0.999), eps = 1e-08, weight_decay = 0)

asgd_opt = optim.ASGD(model.parameters(), lr = 0.01,
                      lambd = 0.0001, alpha = 0.75, t0 = 1000000.0,
                      weight_decay = 0)

# Optional learning-rate schedule (re-enable together with sched.step() in the training loop):
#sched = torch.optim.lr_scheduler.ExponentialLR(optr, gamma, last_epoch = -1, verbose = False)

In [None]:
# Train the selected model. Per optimisation step, log the training BCE loss,
# the test BCE loss and the test Dice coefficient.
optr = adam_opt
dscore = []     # test Dice coefficient, one entry per iteration
bcetrain = []   # training BCE loss, one entry per iteration
bcetest = []    # test BCE loss, one entry per iteration

for epoch in range(num_epochs):
    for xbatch, ybatch in training_loader:
        # Re-enter training mode on every step. The evaluation below switches the
        # model to eval mode; without this, every batch after the first in each
        # epoch would be trained in eval mode (BatchNorm using running stats).
        model.train()
        optr.zero_grad()

        prediction = model(xbatch)
        pred = prediction.mean(1)  # average the output channels into one map per image
        #pred = prediction.amax(1)

        bce_l = bce_logits(pred, ybatch)
        bcetrain.append(bce_l.item())

        bce_l.backward()
        optr.step()
        #sched.step()

        model.eval()
        with torch.no_grad():
            testpreds = model(xtest)
            testpred = testpreds.mean(1)
            #testpred = testpreds.amax(1)

            bce_test = bce_logits(testpred, ytest)
            bcetest.append(bce_test.item())

            # bce_logits treats the outputs as logits, so binarise at 0
            # (probability 0.5). dsc() casts with .bool(), which would otherwise
            # count every non-zero logit as foreground.
            dice_score = dsc(testpred > 0, ytest)
            dscore.append(dice_score.item())

            # Keep the predictions from the final epoch for plotting.
            if epoch == (num_epochs - 1):
                te_pred = testpred
                test_batch = ytest

                tra_pred = pred
                train_batch = ybatch

        it += 1
    print("Epoch %s/%s" % (epoch + 1, num_epochs))

In [None]:
def _save_heatmap(fig_num, image, title, filename):
    """Draw a 2-D map as a 'jet' heat map with a colour bar in figure `fig_num` and save it at 500 dpi."""
    plt.figure(fig_num)
    plt.imshow(image.detach().cpu(), cmap = 'jet')
    plt.title(title)
    plt.colorbar()
    plt.savefig(filename, dpi = 500)


# Figure 1: training vs. testing BCE loss per iteration.
plt.figure(1)
train, = plt.plot(bcetrain, 'r')
test, = plt.plot(bcetest, 'b')
plt.xlabel("Iteration")
plt.ylabel("Loss")
plt.title("Training and testing loss")
plt.legend([train, test], ['Train loss', 'Test loss'])
plt.annotate("Final train loss: %s" % (bcetrain[-1]), xycoords = 'figure fraction', xy = (0.25, 0.7))
plt.annotate("Final test loss: %s" % (bcetest[-1]), xycoords = 'figure fraction', xy = (0.25, 0.75))
print("Final training loss: %s." % bcetrain[-1])
print("Final testing loss: %s." % bcetest[-1])
plt.savefig("bce_loss_plots", dpi = 500)

# Figures 2-3: one training prediction next to its mask.
num = 0
_save_heatmap(2, tra_pred[num, :, :], 'Training prediction', "train_pred")
_save_heatmap(3, train_batch[num, :, :], 'Training mask', "train_mask")

# Figure 4: test Dice coefficient per iteration.
plt.figure(4)
plt.plot(dscore, 'b')
plt.xlabel("Iteration")
plt.ylabel("DSC")
plt.savefig("dice_plot", dpi = 500)
print("Final dice coefficient: %s." % dscore[-1])

# Figures 5-8: two test predictions next to their masks.
ind = 10
_save_heatmap(5, te_pred[ind, :, :], 'Testing prediction', "test_pred_1")
_save_heatmap(6, test_batch[ind, :, :], 'Testing mask', "test_mask_1")

ind = 20
_save_heatmap(7, te_pred[ind, :, :], 'Testing prediction', "test_pred_2")
_save_heatmap(8, test_batch[ind, :, :], 'Testing mask', "test_mask_2")