# **Drive Mounting**

In [None]:
# Disabled cell (kept as a string literal): remove the surrounding triple quotes to
# mount Google Drive on Colab and change into the project folder; the commented-out
# %cd line is an alternative project location.
"""from google.colab import drive
drive.mount('/content/drive', force_remount=True)
%cd 'drive/My Drive/Colab Notebooks/MISA PROJECT/'
#%cd 'drive/My Drive/MISA PROJECT/'"""

# Library Imports


In [None]:
import os
import numpy as np
import nibabel as nib
!pip install antspyx
import ants
import copy
import pandas as pd
import time
import warnings
warnings.filterwarnings('ignore')

In [None]:
#for interactive plots
%matplotlib notebook
#for static images
%matplotlib inline
import matplotlib.pyplot as plt

In [None]:
import torch
import torchvision
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.optim import Adam

In [None]:
from helper import *
from datagen import *
from model import *
from eval_helper import *
from metrics import *

# Initialize variables

In [None]:
# Dataset split locations; each split holds one sub-directory per scan.
training_path = 'dataset/Training_Set'
validation_path = 'dataset/Validation_Set'
test_path = 'dataset/Test_Set'

# Training settings
num_epochs = 100       # upper bound on the number of training epochs
multi = True           # True: train with calc_loss; False: clamped-log cross-entropy
device = set_device()  # compute device returned by the project helper

# 3-D patch extraction (presumably for volumes with 256- and 128-voxel axes:
# 256/32 = 8, 128/32 = 4 patches; 256/16 = 16, 128/16 = 8 sampling positions).
patch_size = (32,) * 3
sampling_step = (16,) * 3
batchsize = 32

# Generate a brain mask (ROI) for every scan (run once)

In [None]:
def _write_brainmasks(root):
    """Write ``<id>_brainmask.nii.gz`` next to ``<id>.nii.gz`` for each scan folder in ``root``.

    The mask is obtained by applying the project helper ``mask_image`` to a
    clone of the scan through ``ANTsImage.apply``. Entries of ``root`` that are
    hidden (e.g. ``.ipynb_checkpoints``) or are not directories (e.g. a stray
    ``.DS_Store`` file) are skipped, because they contain no scan and would make
    ``ants.image_read`` fail. Folders are processed in sorted order so reruns
    are deterministic.
    """
    for scan_id in sorted(os.listdir(root)):
        scan_dir = os.path.join(root, scan_id)
        if scan_id.startswith('.') or not os.path.isdir(scan_dir):
            continue
        scan = ants.image_read(os.path.join(scan_dir, '{}.nii.gz'.format(scan_id)))
        # clone so the loaded scan object itself is left untouched
        brainmask = ants.image_clone(scan).apply(mask_image)
        brainmask.to_filename(os.path.join(scan_dir, '{}_brainmask.nii.gz'.format(scan_id)))


for split_path in (training_path, validation_path):
    _write_brainmasks(split_path)

In [None]:
# Write <id>_brainmask.nii.gz beside each test scan, using the project helper
# mask_image applied to a clone of the scan. Hidden entries and non-directories
# (e.g. .DS_Store, .ipynb_checkpoints) hold no scan and are skipped so that
# ants.image_read does not fail; sorted order keeps reruns deterministic.
for scan_id in sorted(os.listdir(test_path)):
    scan_dir = os.path.join(test_path, scan_id)
    if scan_id.startswith('.') or not os.path.isdir(scan_dir):
        continue
    scan = ants.image_read(os.path.join(scan_dir, '{}.nii.gz'.format(scan_id)))
    brainmask = ants.image_clone(scan).apply(mask_image)
    brainmask.to_filename(os.path.join(scan_dir, '{}_brainmask.nii.gz'.format(scan_id)))

# Load data paths

In [None]:
#Training
input_train_data={}
input_train_labels={}
input_train_rois={}
for scan_id in os.listdir(training_path):
  input_train_data[scan_id]=[os.path.join(training_path, scan_id, '{}.nii.gz'.format(scan_id))]
  input_train_labels[scan_id]=[os.path.join(training_path, scan_id, '{}_seg.nii.gz'.format(scan_id))]
  input_train_rois[scan_id]=[os.path.join(training_path, scan_id, '{}_brainmask.nii.gz'.format(scan_id))]

#Validation
input_val_data={}
input_val_labels={}
input_val_rois={}
for scan_id in os.listdir(validation_path):
  input_val_data[scan_id]=[os.path.join(validation_path, scan_id, '{}.nii.gz'.format(scan_id))]
  input_val_labels[scan_id]=[os.path.join(validation_path, scan_id, '{}_seg.nii.gz'.format(scan_id))]
  input_val_rois[scan_id]=[os.path.join(validation_path, scan_id, '{}_brainmask.nii.gz'.format(scan_id))]

In [None]:
#test
input_test_data={}
input_test_rois={}
for scan_id in os.listdir(test_path):
  input_test_data[scan_id]=[os.path.join(test_path, scan_id, '{}.nii.gz'.format(scan_id))]
  input_test_rois[scan_id]=[os.path.join(test_path, scan_id, '{}_brainmask.nii.gz'.format(scan_id))]
print(len(input_test_data))

# Build data generators

In [None]:
#train
#train
training_dataset = MRI_DataPatchLoader(input_data=input_train_data, labels=input_train_labels, rois=input_train_rois, patch_size=patch_size,
                                       apply_padding=True, normalize=True, sampling_type='mask', sampling_step=sampling_step)
# training patches are reshuffled every epoch
training_dataloader = DataLoader(training_dataset, batch_size=batchsize, shuffle=True)
#Validation
validation_dataset = MRI_DataPatchLoader(input_data=input_val_data, labels=input_val_labels, rois=input_val_rois, patch_size=patch_size,
                                         apply_padding=True, normalize=True, sampling_type='mask', sampling_step=sampling_step)
# No shuffling for validation: the validation metric is a mean of per-batch means,
# so a shuffled order (different last-batch composition) makes the value that drives
# the LR scheduler and early stopping non-reproducible between epochs.
validation_dataloader = DataLoader(validation_dataset, batch_size=batchsize, shuffle=False)

In [None]:
# Disabled preview cell (kept as a string literal): would display a grid built from one
# batch of training patches. NOTE(review): `imshow` is not defined in this notebook
# (presumably it comes from one of the star-imported project modules), and the
# `[:,:,:,0]` slice assumes a particular patch tensor layout — verify both before re-enabling.
"""inputs, classes = next(iter(training_dataloader))# Get a batch of training data
out = torchvision.utils.make_grid(inputs[:,:,:,0])# Make a grid from batch
imshow(out)"""

# Training

In [None]:
# Output directory for the best checkpoint (model.pt) and the training-curve plots.
tmpdir = "saved"
# exist_ok makes re-running the cell safe and avoids the exists()/mkdir() race.
os.makedirs(tmpdir, exist_ok=True)

In [None]:
#MODEL
model = Unet(input_size=1, output_size=4)
model = model.to(device)
# define the optimizer
optimizer = Adam(model.parameters())
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5, verbose=True)

In [None]:
train_loss_all = []
train_acc_all = []
val_loss_all = []
val_acc_all = []
best_acc = 0.0
best_loss = 1e+5
early_count = 0
early_stop_patience = 20  # epochs without validation-loss improvement before stopping
dice = True  # NOTE(review): not read in this cell — confirm whether anything else uses it
training = True
epoch = 1


def _batch_loss(pred, y):
    """Return the loss of model output ``pred`` against labels ``y`` for one batch.

    With ``multi`` set, the project's ``calc_loss`` is used. Otherwise the output
    is clamped to [1e-7, 1], log-transformed and scored with cross-entropy against
    ``y`` with dim 1 squeezed out (this presumes the model emits per-class
    probabilities — TODO confirm in model.py).
    """
    if multi:
        return calc_loss(pred, y)
    return F.cross_entropy(torch.log(torch.clamp(pred, 1E-7, 1.0)), y.squeeze(dim=1).long())


def _batch_accuracy(pred, y):
    """Return the fraction of elements of ``y`` equal to the arg-max of ``pred`` over dim 1."""
    labels = pred.detach().max(1, keepdim=True)[1]
    correct = labels.eq(y.view_as(labels).long())
    return correct.sum().item() / np.prod(y.shape)


def _run_epoch(loader, train):
    """Run ``model`` over every batch of ``loader``; return (mean loss, mean accuracy).

    With ``train=True`` the model is in training mode and ``optimizer`` takes a
    step after each batch; otherwise the model is in eval mode with gradients
    disabled. Both metrics are averaged over batches.

    Raises:
        RuntimeError: if ``loader`` yields no batches (nothing to average).
    """
    model.train(train)  # train(False) is equivalent to model.eval()
    total_loss = 0.0
    total_acc = 0.0
    n_batches = 0
    for batch in loader:
        x = batch[0].to(device)
        y = batch[1].to(device)
        with torch.set_grad_enabled(train):
            if train:
                optimizer.zero_grad()
            pred = model(x)
            loss = _batch_loss(pred, y)
            if train:
                loss.backward()
                optimizer.step()
        total_loss += loss.item()
        total_acc += _batch_accuracy(pred, y)
        n_batches += 1
    if n_batches == 0:
        raise RuntimeError("data loader yielded no batches")
    return total_loss / n_batches, total_acc / n_batches


# Fallback so load_state_dict below always has weights, even if the validation
# loss never beats best_loss (e.g. it is NaN from the start).
best_model_wts = copy.deepcopy(model.state_dict())
since = time.time()
while training:
    train_loss, train_accuracy = _run_epoch(training_dataloader, train=True)
    val_loss, val_accuracy = _run_epoch(validation_dataloader, train=False)

    train_loss_all.append(train_loss)
    train_acc_all.append(train_accuracy)
    val_loss_all.append(val_loss)
    val_acc_all.append(val_accuracy)
    print('Epoch {:d} train_loss {:.4f} train_acc {:.4f} val_loss {:.4f} val_acc {:.4f}'.format(
        epoch,
        train_loss,
        train_accuracy,
        val_loss,
        val_accuracy))

    if val_loss < best_loss:
        # save weights
        best_loss = val_loss
        best_acc = val_accuracy
        print("val loss decreased...saving model")
        best_model_wts = copy.deepcopy(model.state_dict())  # in-memory copy of the best weights
        model_path = "{}/model.pt".format(tmpdir)
        torch.save(model.state_dict(), model_path)
        early_count = 0
    else:
        early_count += 1
    epoch += 1
    scheduler.step(val_loss)

    if early_count >= early_stop_patience:
        print("Early stopping")
        training = False
    # epoch has already been advanced, so `>` runs exactly num_epochs epochs
    if epoch > num_epochs:
        training = False
model.load_state_dict(best_model_wts)
time_elapsed = time.time() - since
print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

# Plotting Results

In [None]:
# training and val accuracy
def _plot_curves(train_values, val_values, title, ylabel, filename):
    """Plot per-epoch training vs. validation values and save the figure as ``tmpdir/filename``."""
    plt.figure()
    plt.plot(train_values)
    plt.plot(val_values)
    plt.title(title)
    plt.ylabel(ylabel)
    plt.xlabel('epochs')
    # the second curve comes from the validation loader, not the test set
    plt.legend(['train', 'validation'], loc='upper left')
    plt.savefig("{}/{}".format(tmpdir, filename))


# training and val accuracy
_plot_curves(train_acc_all, val_acc_all, 'Accuracy', 'accuracy', 'accuracy.png')

# training and val loss
_plot_curves(train_loss_all, val_loss_all, 'Loss', 'loss', 'loss.png')

# Evaluation

In [None]:
model_path = "{}/model.pt".format(tmpdir)
#MODEL: rebuild the same architecture and restore the best checkpoint saved during training
model = Unet(input_size=1, output_size=4)
model = model.to(device)
# map_location lets a checkpoint written on a GPU be restored on a CPU-only runtime
model.load_state_dict(torch.load(model_path, map_location=device))
# inference mode (fixes dropout / batch-norm behaviour, if the Unet uses them)
model.eval()

In [None]:
# Evaluate on the validation split with the project helper `evaluation`; presumably
# it returns mean metrics per class/structure — confirm in eval_helper.
m_mean = evaluation(validation_path)
# bare expression so the notebook renders the result as the cell output
m_mean

In [None]:
# Same evaluation on the training split. NOTE: this reassigns m_mean, replacing any
# previously stored validation result.
m_mean = evaluation(training_path)
print(m_mean)

In [None]:
# Run the project helper on the test split. No labels are collected for the test scans
# in this notebook, so this presumably produces predictions rather than metrics —
# NOTE(review): confirm in eval_helper; the return value (if any) is not stored.
evaluation_test(test_path)