In [None]:
import torch
import importlib
import os
import sys
import time


sys.path.append(os.getcwd()+'/model/')
sys.path.append(os.getcwd()+'/losses/')
sys.path.append(os.getcwd()+'/Dataloader/')


from Encoder_model import *
from Decoder_model import *
from Comparator_model import *
from Dense_Spatial_Transformation import * 
from MorphSSL_dataloader import *

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
'''
The perceptual Loss: MSE in the feature space using the comparator network
'''
############## Compute Perceptual Loss #############
def update_perceptual_model():
    """Move the comparator weights towards the encoder weights (EMA update).

    Reads the module-level ``comparator_model`` and ``encoder_model``. Their
    parameters are paired positionally with ``zip``, so pairing stops at the
    shorter of the two parameter lists. Each comparator parameter becomes
    ``0.99 * old + (1 - 0.99) * encoder``. The assignment goes through
    ``.data``, so it is not recorded by autograd.
    """
    momentum = 0.99  # fraction of the comparator's current value that is kept
    param_pairs = zip(comparator_model.parameters(), encoder_model.parameters())
    for target, source in param_pairs:
        target.data = (target.data * momentum) + (1.0 - momentum) * source.data
        
    

def compute_perceptual_loss(I_AB, I_B):
    """Return the perceptual (feature-space MSE) loss between ``I_AB`` and ``I_B``.

    ``update_perceptual_model()`` is called first, so the module-level
    ``comparator_model`` is refreshed from the encoder on every call. Both
    inputs are then passed through the comparator, and the MSE between the
    outputs at indices 0, 1 and 2 is averaged.
    """
    # Refresh the comparator (momentum copy of the encoder) before using it
    update_perceptual_model()

    feats_pred = comparator_model(I_AB)
    feats_ref = comparator_model(I_B)

    # One MSE term per comparator output, averaged over the three outputs
    per_level = [F.mse_loss(feats_pred[lvl], feats_ref[lvl]) for lvl in range(3)]
    return (per_level[0] + per_level[1] + per_level[2]) / 3


In [None]:
def inference(I_A, I_B, IA_msk, IB_msk):
    """Compute the masked registration loss for one image pair, without gradients.

    Uses the module-level ``encoder_model``, ``decoder_model``,
    ``spatial_transform`` and ``grid_4``.

    Args:
        I_A: image at time t.
        I_B: image at time t+k.
        IA_msk: ROI mask belonging to ``I_A``; it is warped with the same
            predicted deformation as the image.
        IB_msk: ROI mask belonging to ``I_B``.

    Returns:
        The MSE between the masked prediction and the masked target,
        as a NumPy value on the CPU.
    """
    # The whole forward pass runs without autograd. Previously only the encoder
    # was wrapped, so the decoder, the warps and the loss still built a graph
    # that was thrown away immediately.
    with torch.no_grad():
        ftr_A = encoder_model(I_A)  # time t
        ftr_B = encoder_model(I_B)  # time t+k

        D_AB, C_AB = decoder_model(ftr_A, ftr_B)

        # Warp the image, then add the predicted additive component
        I_out, _ = spatial_transform(I_A, D_AB, grid_4, 'nearest')
        I_out = I_out + C_AB
        I_AB_msk, _ = spatial_transform(IA_msk, D_AB, grid_4, 'nearest')

        # Registration/Metamorphosis loss
        loss = F.mse_loss((I_out * I_AB_msk), (I_B * IB_msk))

    return loss.detach().cpu().numpy()

#####################################################################

def complete_inference(val_loader, num_classes=16):
    """Evaluate on ``val_loader`` and return the negated class-balanced loss.

    Uses the module-level ``encoder_model``, ``decoder_model``, ``device`` and
    ``inference``. Both models are put in eval mode for the duration of the
    loop and are always switched back to train mode, even if evaluation raises.

    The loss of every sample is grouped by the first entry of ``sample['lbl']``.
    A mean is taken per label value ``0 .. num_classes - 1``, and those
    per-class means are averaged. Classes with no validation sample are skipped
    instead of turning the whole metric into NaN.

    Args:
        val_loader: iterable of dict samples with the keys ``'I_A'``, ``'I_B'``,
            ``'IA_msk'``, ``'IB_msk'`` and ``'lbl'``.
        num_classes: number of label values to balance over (default 16).

    Returns:
        The negative of the class-balanced mean loss (higher is better).
        NaN if no sample carries a label in ``range(num_classes)``.
    """
    t = time.time()

    decoder_model.eval()
    encoder_model.eval()
    loss_lst = []
    lbl_lst = []
    try:
        for sample in val_loader:
            lbl_lst.append(sample['lbl'][0])

            I_A = sample['I_A'].to(device)
            I_B = sample['I_B'].to(device)
            IA_msk = sample['IA_msk'].to(device)
            IB_msk = sample['IB_msk'].to(device)

            loss_lst.append(inference(I_A, I_B, IA_msk, IB_msk))
            del I_A, I_B, IA_msk, IB_msk
    finally:
        # Never leave the models stuck in eval mode if validation fails
        decoder_model.train()
        encoder_model.train()

    loss_lst = np.array(loss_lst)
    lbl_lst = np.array(lbl_lst)

    run_loss = 0
    n_present = 0
    missing = []
    for k in range(num_classes):
        class_losses = loss_lst[np.where(lbl_lst == k)]
        if class_losses.size == 0:
            # np.mean of an empty slice is NaN and would poison the metric
            missing.append(k)
            continue
        run_loss = run_loss + np.mean(class_losses)
        n_present += 1

    if missing:
        print('Validation classes without samples (skipped): ' + str(missing))

    # With every class present this equals the original sum / 16
    run_loss = run_loss / n_present if n_present > 0 else float('nan')

    print('Validation loss: ' + str(run_loss))
    print('Validation time is ' + str(time.time() - t) + ' seconds')
    return -run_loss  # metric is neg of loss

In [None]:
def train_one_batch(I_A, I_B, IA_msk, IB_msk, opt_enc, opt_dec, sched_enc, sched_dec):
    """Run one optimisation step on a single batch and report every loss term.

    Uses the module-level ``encoder_model``, ``decoder_model``,
    ``spatial_transform``, ``grid_4``, ``compute_perceptual_loss`` and
    ``comp_smooth_loss``.

    Args:
        I_A, I_B: images at time t and time t+k.
        IA_msk, IB_msk: ROI masks belonging to ``I_A`` and ``I_B``.
        opt_enc, opt_dec: optimizers of the encoder and the decoder.
        sched_enc, sched_dec: learning-rate schedulers, stepped once per call.

    Returns:
        An 11-tuple ``(loss, reg_loss_dfrm, reg_loss_add, dscrm_loss_img,
        fld_loss, dfrm_smth_loss, add_l1_loss, opt_enc, opt_dec, sched_enc,
        sched_dec)``. The first seven entries are the weighted loss values
        converted to NumPy; the last four are the objects passed in.
    """
    # Fixed weights of the individual loss terms
    w_reg_dfrm, w_reg_add, w_dscrm = 10.0 ** 1, 10.0 ** 2, 10 ** 1
    w_fld, w_smth, w_l1 = 10.0 ** 6, 10 ** -1, 10.0 ** -5

    # ---------------- Forward pass ----------------
    feat_t = encoder_model(I_A)   # time t
    feat_tk = encoder_model(I_B)  # time t+k

    # Predicted deformation field and additive component
    D_AB, C_AB = decoder_model(feat_t, feat_tk)

    # Warp the image (bilinear and nearest) and the ROI mask with D_AB
    warped_img, fld_loss = spatial_transform(I_A, D_AB, grid_4, 'bilinear')
    warped_img_nn, _ = spatial_transform(I_A, D_AB, grid_4, 'nearest')
    warped_msk, _ = spatial_transform(IA_msk, D_AB, grid_4, 'bilinear')

    # ---------------- Weighted loss terms ----------------
    # Registration error of the deformation alone
    reg_loss_dfrm = w_reg_dfrm * F.mse_loss(warped_img * warped_msk, I_B * IB_msk)

    # The additive component should match what the nearest-warped image misses
    reg_loss_add = w_reg_add * F.mse_loss(C_AB, (I_B - warped_img_nn) * IB_msk)

    # Perceptual loss on the full prediction (warped image + additive component)
    dscrm_loss_img = compute_perceptual_loss((warped_img + C_AB) * warped_msk,
                                             I_B * IB_msk) * w_dscrm

    # Smoothness of the deformation field and L1 sparsity of the additive term
    dfrm_smth_loss = w_smth * comp_smooth_loss(D_AB)
    add_l1_loss = w_l1 * torch.mean(torch.abs(torch.flatten(C_AB)))

    # Field loss reported by the bilinear image warp
    fld_loss = w_fld * fld_loss

    loss = (reg_loss_dfrm + reg_loss_add + dscrm_loss_img) + (fld_loss + dfrm_smth_loss + add_l1_loss)

    # ---------------- Backpropagation ----------------
    for optimiser in (opt_enc, opt_dec):
        optimiser.zero_grad()   # drop gradients of the previous step
    loss.backward()
    for optimiser in (opt_enc, opt_dec):
        optimiser.step()        # update the weights
    for scheduler in (sched_enc, sched_dec):
        scheduler.step()        # advance the learning-rate schedule

    # ---------------- Values for logging ----------------
    logged = tuple(
        term.detach().cpu().numpy()
        for term in (loss, reg_loss_dfrm, reg_loss_add, dscrm_loss_img,
                     fld_loss, dfrm_smth_loss, add_l1_loss)
    )
    return logged + (opt_enc, opt_dec, sched_enc, sched_dec)

In [None]:
def train_complete(max_epochs=300, patience=40, cycle_length=100,
                   mn_lr=10**(-5.0), mx_lr=10**(-4.0)):
    """Train encoder and decoder with a cyclic LR, periodic validation and early stopping.

    Uses the module-level ``encoder_model``, ``decoder_model``, ``train_loader``,
    ``val_loader``, ``device``, ``train_one_batch`` and ``complete_inference``.

    One Adam optimizer and one triangular ``CyclicLR`` scheduler are created per
    model. After every ``2 * cycle_length`` batch updates (one full triangle,
    where the learning rate is back at its minimum) the weights are saved to
    ``last_weight.pt``, the validation metric is computed, and the running
    training losses are reset. A better metric writes
    ``best_weight<metric>.pt``; ``patience`` consecutive validations without
    improvement stop the training.

    Args:
        max_epochs: maximum number of passes over ``train_loader``.
        patience: number of consecutive non-improving validations tolerated.
        cycle_length: ``step_size_up`` of the cyclic schedule, in batch updates.
        mn_lr: base (minimum) learning rate.
        mx_lr: maximum learning rate.

    Returns:
        None. Checkpoints are written to the current working directory.
    """
    ###################  Optimizers and learning-rate schedulers  ###################
    opt_enc = torch.optim.Adam(encoder_model.parameters(), lr=mn_lr,
                               betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-5)
    opt_dec = torch.optim.Adam(decoder_model.parameters(), lr=mn_lr,
                               betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-5)

    sched_enc = torch.optim.lr_scheduler.CyclicLR(opt_enc, base_lr=mn_lr, max_lr=mx_lr, cycle_momentum=False,
                                                  step_size_up=cycle_length, step_size_down=None, mode='triangular')
    sched_dec = torch.optim.lr_scheduler.CyclicLR(opt_dec, base_lr=mn_lr, max_lr=mx_lr, cycle_momentum=False,
                                                  step_size_up=cycle_length, step_size_down=None, mode='triangular')

    ##################################################################################
    total_batches = len(train_loader)
    # Validate at the end of each scheduler triangle, where the lr is minimal
    tot_batch_updates_val = 2 * cycle_length

    ##################### Initialize ################################################
    ptnc_cnt = 0                # validations without improvement so far
    stop_training = False       # set once early stopping triggers
    # -inf (not a finite sentinel such as -999) so that the first valid metric
    # always counts as an improvement, however large the validation loss is.
    max_metric = float('-inf')

    cnt = 0  # batch updates since the last validation
    # Running sums, in the order returned by train_one_batch:
    # [loss, reg_dfrm, reg_add, dscimg, fld, smth, l1]
    running = [0] * 7

    t2 = time.time()

    for epoch in range(0, max_epochs):
        for i, sample in enumerate(train_loader):
            cnt = cnt + 1

            batch = [sample[key].to(device) for key in ('I_A', 'I_B', 'IA_msk', 'IB_msk')]
            del sample

            result = train_one_batch(*batch, opt_enc, opt_dec, sched_enc, sched_dec)
            del batch  # do not keep the batch alive during validation
            opt_enc, opt_dec, sched_enc, sched_dec = result[7:]
            running = [acc + val for acc, val in zip(running, result[:7])]
            del result

            if cnt % 10 == 0:  # display after every 10 batch updates
                print ("Epoch [{}/{}], Batch [{}/{}], cnt {}, Train Loss: {:.4f}, REG_DFRM: {:.4f}, REG_ADD: {:.4f}, DSC_REG: {:.4f}, FLD: {:.4f}, L1: {:.4f}, SMTH: {:.4f}"
                       .format(epoch+1, max_epochs, i+1, total_batches, cnt, (running[0]/cnt), (running[1]/cnt), (running[2]/cnt),
                        (running[3]/cnt), (running[4]/cnt), (running[6]/cnt), (running[5]/cnt)), end ="\r")

            ############# Monitor validation metric and early stopping ############
            if cnt >= tot_batch_updates_val:
                torch.save({'model_state_dict_encoder_model': encoder_model.state_dict(),
                            'model_state_dict_decoder_model': decoder_model.state_dict()}, 'last_weight.pt')

                print('\n Training time for 1 cycle is: '+str(time.time() - t2) +' seconds')

                metric = complete_inference(val_loader)

                # Reset the running losses and the timer for the next cycle
                # (optimizers and schedulers carry on unchanged)
                cnt = 0
                running = [0] * 7
                t2 = time.time()

                # Early stopping
                if metric > max_metric:
                    torch.save({'model_state_dict_encoder_model': encoder_model.state_dict(),
                                'model_state_dict_decoder_model': decoder_model.state_dict()},
                               'best_weight'+str(metric)+'.pt')
                    max_metric = metric
                    ptnc_cnt = 0
                else:
                    ptnc_cnt = ptnc_cnt + 1
                    # ptnc_cnt counts validation checks, not batch updates
                    print('\n Validation metric has not improved in last '+str(ptnc_cnt)+' validation checks')
                    if ptnc_cnt >= patience:
                        print("\n Early Stopping ! from current epoch \n \n")
                        stop_training = True
                        break  # leave the loop over the dataloader
        if stop_training:
            break  # leave the loop over epochs

In [None]:
def worker_init_fn(worker_id):
    """Seed NumPy's global RNG differently in every DataLoader worker.

    The seed is the first word of the current NumPy RNG state plus
    ``worker_id``, wrapped into the ``[0, 2**32 - 1]`` range that
    ``np.random.seed`` accepts.

    Args:
        worker_id: integer id passed by the DataLoader to each worker.
    """
    # Cast to a Python int first: adding to the uint32 state word could
    # otherwise overflow or leave the valid seed range.
    base_seed = int(np.random.get_state()[1][0])
    # NOTE(review): presumably the parent's NumPy state does not change between
    # epochs, so each worker would get the same seed every epoch. Confirm;
    # deriving the seed from torch.initial_seed() would avoid that.
    np.random.seed((base_seed + worker_id) % (2 ** 32))

In [None]:
eps = 10 ** (-14)
# Spatial index used to apply the spatial deformation
grid_4 = create_mesh_grid(192, 192, 32)

############### Dataloader #####################
# see comments in /Dataloader/MorphSSL_dataloader.py  for img_pth, img_pairs
_ssl_img_dir = '/msc/home/achakr83/PINNACLE/SSL_training/May30/final_full_training/preprocessed_SSL_images2/'
_train_pairs_file = '/msc/home/achakr83/PINNACLE/SSL_training/May30/final_full_training/step4_final_train_ssl_data.npz'
_val_pairs_file = '/msc/home/achakr83/PINNACLE/SSL_training/May30/final_full_training/step4_final_val_ssl_data.npz'

# Options shared by the training and the validation loader
_loader_kwargs = dict(batch_size=1, num_workers=2, pin_memory=False,
                      drop_last=False, worker_init_fn=worker_init_fn)

train_data = train_dataset(img_pth=_ssl_img_dir, img_pairs=_train_pairs_file)
train_loader = DataLoader(dataset=train_data, shuffle=True, **_loader_kwargs)

val_data = val_dataset(img_pth=_ssl_img_dir, img_pairs=_val_pairs_file)
val_loader = DataLoader(dataset=val_data, shuffle=False, **_loader_kwargs)


########### Instantiate the Models ##############
# Encoder
encoder_model = Encoder_Architecture(base_chnls=16, out_ftr_dim=(64 * 2))
encoder_model.to(device)
# Decoder
decoder_model = Decoder_Architecture(in_dim=64, first_dim=512)
decoder_model.to(device)
# Comparator (built with the same arguments as the encoder)
comparator_model = Comparator_Architecture(base_chnls=16, out_ftr_dim=(64 * 2))
comparator_model.to(device)

In [None]:
######### Begin Training #########
# Launch the training loop defined in train_complete(). It relies on the models
# and data loaders created as module-level names earlier in the notebook.
train_complete()