## Libraries & CUDA

In [None]:
from tqdm import tqdm
import torch,torchvision
import torch.nn as nn
import matplotlib.pyplot as plt
from monai.data import Dataset, ArrayDataset, create_test_image_3d, DataLoader
from monai.data import CacheDataset
from torch.utils.data import random_split
import torchvision.transforms as transform
from torch.utils.tensorboard import SummaryWriter

In [None]:
import torch

# Report whether PyTorch can see a CUDA device in this environment.
cuda_available = torch.cuda.is_available()
print(cuda_available)

In [None]:
# IPython magic: load the autoreload extension so that project modules
# (e.g. Models.pix2pix) can be re-imported automatically after edits.
%load_ext autoreload

## Importing Data

In [None]:
# Load the pre-built dataset object saved with torch.save.
# NOTE(review): torch.load unpickles the file, so only load trusted data. On recent
# PyTorch versions the default weights_only=True may reject non-tensor objects —
# confirm against the installed version.
train_data = torch.load(f='Data/UpdatedFullDataV_2_0')

In [None]:
# Hold out VAL_PERCENT % of the samples for validation; the rest is used for training.
# Integer arithmetic avoids float rounding truncating the count down by one.
VAL_PERCENT = 20
val = len(train_data) * VAL_PERCENT // 100
org = len(train_data) - val
print(val, org)
# NOTE: random_split is unseeded here, so the split differs between runs. Pass
# generator=torch.Generator().manual_seed(<seed>) for a reproducible split.
train_ds, val_ds = random_split(train_data, [org, val])

In [None]:
batch_size = 5
# Training batches are reshuffled every epoch.
train_loder = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
# Validation order does not change the average loss. A fixed order keeps the
# per-step validation curves and the last batch (the one visualised with show())
# comparable from epoch to epoch.
val_loder = DataLoader(val_ds, batch_size=batch_size, shuffle=False)

In [None]:
# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

## Model

In [None]:
# Explicit switch for the CPU override. It was previously an unconditional
# `device = 'cpu'` that silently discarded the CUDA detection.
# FORCE_CPU = True keeps that behaviour; set it to False to use the detected device.
FORCE_CPU = True
if FORCE_CPU:
    device = 'cpu'

In [None]:

# IPython magic: re-import all modules before each cell runs, so edits to the
# project model code take effect without restarting the kernel.
%autoreload 2
import Models.pix2pix as Model
# Logging switches handed to the project's logging helpers. Their exact effect is
# defined in Models.pix2pix (not visible here) — confirm there.
PRINTLOG_G = False
PRINTLOG_D = True
PRINTLOG = False
print(Model.Logs(PRINTLOG))
print(Model.G_Logs(PRINTLOG_G))
print(Model.D_Logs(PRINTLOG_D))
# NOTE(review): the positional (1, 1) are presumably input/output channel counts —
# confirm against Pix2Pix.__init__. save_after=5 is forwarded as-is.
model = Model.Pix2Pix(1,1,device, save_after= 5)
# Calls Pix2Pix's own `train(train_loader, val_loader, 1)`; the last argument is
# presumably an epoch count — confirm. This is a project method taking loaders,
# not nn.Module.train(mode).
model.train(train_loder,val_loder,1)
# Turn all three logging switches off again after the run.
PRINTLOG = False
print(Model.Logs(PRINTLOG))
print(Model.G_Logs(PRINTLOG))
print(Model.D_Logs(PRINTLOG))

In [None]:
def show(img, output, label, denorm=False):
    """Plot input, ground truth and prediction side by side, one row per sample.

    Channel 0 of each sample is drawn in grayscale. Each batch is indexed as
    ``batch[i][0, :, :]``, so the tensors are expected to be at least 4-D
    (batch, channel, H, W).

    Args:
        img: batch of input images (column "Input Image").
        output: batch of model predictions (column "Predicted Output").
        label: batch of ground-truth targets (column "Actual Output").
        denorm: currently unused; kept so existing calls keep working.
    """
    img, output, label = img.cpu(), output.cpu(), label.cpu()
    n_rows = len(output)
    if n_rows == 0:
        # plt.subplots cannot build a figure with zero rows; nothing to draw.
        return

    # A single row gets a wide figure; several rows get a square one.
    figsize = (30, 10) if n_rows == 1 else (15, 15)
    # squeeze=False keeps `ax` 2-D even for one row, so no separate branch is needed.
    _, ax = plt.subplots(n_rows, 3, figsize=figsize, squeeze=False)

    cols = ['Input Image', 'Actual Output', 'Predicted Output']
    for col, title in enumerate(cols):
        ax[0][col].set_title(title)

    for i in range(n_rows):
        # Column order matches `cols`: input, ground truth, prediction.
        panels = (img[i], label[i], output[i])
        for col, tensor in enumerate(panels):
            # detach(): predictions may still be attached to the autograd graph.
            ax[i][col].imshow(tensor.detach().numpy()[0, :, :], cmap='gray')
    plt.show()

In [None]:
def get_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group.

    Returns None when the optimizer has no parameter groups.
    """
    return next((group['lr'] for group in optimizer.param_groups), None)

## Parametric Tuning

In [None]:
# model = UNet(1).float().to(device)

In [None]:
# Optimisation settings. The upper-case names are only labels used to build the
# TensorBoard run name.
lr = 1e-4

LOSS_FUNC = 'L1LOSS'
lossfunc = nn.L1Loss()

OPTIM = 'RAdam'
optimizer = torch.optim.RAdam(model.parameters(), lr=lr)

# LR scheduling is disabled. The training loop only steps the scheduler when
# SCHEDULER is truthy, so enable both together, e.g.:
#   scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.9)
#   SCHEDULER = 'RedOnPlatu'
SCHEDULER = None
scheduler = None

In [None]:
# Number of passes over train_loder in the manual training loop; also part of the run name.
epochs = 50

In [None]:
VERSION = 1
DATA = 'UpdatedFullData_V_2'
log_dir = 'logs'

# Encode the main hyper-parameters in the run name so that each configuration
# gets its own TensorBoard directory (training and validation side by side).
METHOD = (
    f'{DATA}_Lr{lr}_optim{OPTIM}_loss{LOSS_FUNC}'
    f'_schedular{SCHEDULER}_epoch{epochs}_ver{VERSION}_randtest'
)
path = f'{log_dir}/{METHOD}'
path2 = f'{path}_validation'
METHOD

In [None]:
# Remove stale TensorBoard event files from an earlier run with the same name,
# so old and new curves are not mixed. Only this run's two directories (under the
# relative log_dir) are deleted; other runs' logs are left untouched. Never use
# an absolute path like /logs here.
import shutil

for run_dir in (path, path2):
    shutil.rmtree(run_dir, ignore_errors=True)

In [None]:
# One writer per split, so training and validation appear as separate runs in TensorBoard.
writter_train, writter_validate = SummaryWriter(path), SummaryWriter(path2)

## Training

In [None]:
# Per-epoch history. train_loss / val_loss are appended by the training loop;
# train_acc / val_acc are not written by that loop.
train_acc, val_acc, train_loss, val_loss = [], [], [], []

In [None]:
# TensorBoard step counters: *_count advance once per batch, *_count_epi once per epoch.
train_count = valid_count = train_count_epi = valid_count_epi = 0

In [None]:
# Display the training DataLoader object as the cell output (quick sanity check).
train_loder

In [None]:
for i in range(epochs):
    trainloss = 0
    valloss = 0

    # ---- Training pass ----
    for d in tqdm(train_loder):
        optimizer.zero_grad()
        # Each batch is a dict: 'MR' is the model input and 'CT' the target.
        img = d['MR'].to(device).float()
        label = d['CT'].to(device).float()
        output = model(img)
        loss = lossfunc(output, label)
        loss.backward()
        optimizer.step()
        trainloss += loss.item()
        writter_train.add_scalar('Loss/Per Step Loss', loss.item(), train_count)
        writter_train.add_scalar('Learning Rate/Per Step LR', get_lr(optimizer), train_count)
        train_count += 1

    avg_train_loss = trainloss / len(train_loder)

    # Absolute change of the mean training loss relative to the previous epoch.
    if i == 0:
        prev_loss = avg_train_loss
    else:
        diff_in_loss = abs(avg_train_loss - prev_loss)
        prev_loss = avg_train_loss
        writter_train.add_scalar('Loss/Rate of change', diff_in_loss, train_count_epi)

    train_loss.append(avg_train_loss)
    writter_train.add_scalar('Loss/Avg Loss', avg_train_loss, train_count_epi)
    writter_train.add_scalar('Learning Rate/Per Batch LR', get_lr(optimizer), train_count_epi)
    train_count_epi += 1

    # ---- Validation pass ----
    # Forward passes only. Without no_grad every batch would build an autograd
    # graph that is never used.
    # NOTE(review): the model is not switched to eval mode because `model.train`
    # is a custom Pix2Pix method here; dropout/batch-norm may still behave as in
    # training — confirm whether Pix2Pix exposes an eval switch.
    with torch.no_grad():
        for d in tqdm(val_loder):
            img = d['MR'].to(device).float()
            label = d['CT'].to(device).float()
            output = model(img)
            loss = lossfunc(output, label)
            valloss += loss.item()
            writter_validate.add_scalar('Loss/Per Step Loss', loss.item(), valid_count)
            valid_count += 1

    avg_val_loss = valloss / len(val_loder)

    if SCHEDULER:
        scheduler.step(avg_val_loss)

    # Every 20 epochs, visualise the last validation batch.
    if i % 20 == 0:
        show(img, output, label)
    val_loss.append(avg_val_loss)

    # Log every per-epoch validation value at the same step before advancing it.
    # Previously the diff and the image were logged one step later than the average.
    writter_validate.add_scalar('Loss/Avg Loss', avg_val_loss, valid_count_epi)
    writter_validate.add_scalar('Loss/Diff in Loss', avg_val_loss - avg_train_loss, valid_count_epi)
    img_grid = torchvision.utils.make_grid(output)
    writter_validate.add_image('Images/Validation', img_grid, valid_count_epi)
    valid_count_epi += 1

    print("epoch : {} ,train loss : {:.6f} ,valid loss : {:.6f} ".format(i, train_loss[-1], val_loss[-1]))

## CUDA Cleanup

In [None]:
# Drop the reference to the model first so its tensors become collectable, then
# let PyTorch return its now-unused cached CUDA blocks. Calling empty_cache()
# before `del model` cannot release the model's memory.
del model
if device == 'cuda':
    torch.cuda.empty_cache()

In [None]:
# Drop the training DataLoader (and its reference to the training subset) to free memory.
del train_loder

In [None]:
# Release PyTorch's cached, unused CUDA memory after the deletions above.
# Skipped when training ran on the CPU.
if device == 'cuda':
    torch.cuda.empty_cache()