In [None]:
import torch
import importlib
import os
import sys
import time
from sklearn.metrics import roc_auc_score

# Make the project-local packages (relative to the current working directory)
# importable, so the wildcard imports below can be resolved.
sys.path.append(os.getcwd()+'/model/')
sys.path.append(os.getcwd()+'/losses/')
sys.path.append(os.getcwd()+'/Dataloader/')


# NOTE(review): later cells use names that are never imported explicitly in this
# notebook (np, F, DataLoader, Encoder_Architecture, Classification_Network,
# train_dataset, val_dataset). Presumably they are re-exported by the wildcard
# imports below -- confirm against the project modules.
from Encoder_model import *
from Classifier_Model import *
from Decoder_model import *
from Comparator_model import *
from Dense_Spatial_Transformation import * 
from MorphSSL_dataloader import *
from TTC_dataloader import *

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
def complete_inference(val_loader):
    """Evaluate the current models on ``val_loader`` and return the mean ROC AUC.

    Uses the module-level ``feature_model``, ``classifier_model``, ``device``
    and ``trn_type``. For every batch the conversion probability is computed as
    ``sigmoid((t - b) / (a + 0.05))`` from the classifier outputs ``a`` and
    ``b``. The ROC AUC is then computed independently for the first four
    prediction columns (reported as 0, 6, 12 and 18 months), skipping entries
    whose ground truth is ``-1``.

    Args:
        val_loader: iterable of batches; each batch is a mapping that provides
            at least the keys ``'I'`` (images), ``'t'`` (time points) and
            ``'gt'`` (labels).

    Returns:
        The unweighted mean of the four per-horizon ROC AUC values.

    Raises:
        ValueError: propagated from ``np.concatenate`` when the loader yields
            no batch, or from ``roc_auc_score`` when it rejects the labels of
            a horizon.

    Side effects:
        Prints the four ROC AUC values. On return ``classifier_model`` is put
        back in training mode; ``feature_model`` only when
        ``trn_type == 'finetune'`` (it stays in eval mode otherwise).
    """
    classifier_model.eval()
    feature_model.eval()

    gt_lst = []
    pred_lst = []

    # A single no_grad scope covers the whole pass; nothing here is trained.
    with torch.no_grad():
        for sample in val_loader:
            img = sample['I'].to(device)
            t = sample['t'].to(device)

            ftr = feature_model(img)
            a, b, _ = classifier_model(ftr)
            pred = torch.sigmoid((t - b) / (a + 0.05))

            pred_lst.append(pred.detach().cpu().numpy())
            # The labels are only needed on the host, so they never visit the device.
            gt_lst.append(sample['gt'].detach().cpu().numpy())

    pred_all = np.concatenate(pred_lst, axis=0)  # B,4
    gt_all = np.concatenate(gt_lst, axis=0)  # B,4

    ####### ROC_AUC for 0 month, 6 month, 12 month, 18 month
    roc_aucs = []
    for col in range(4):
        col_gt = np.squeeze(gt_all[:, col])
        col_pred = np.squeeze(pred_all[:, col])
        valid = col_gt != -1  # -1 marks entries to be excluded from the score
        roc_aucs.append(roc_auc_score(col_gt[valid], col_pred[valid]))

    metric = sum(roc_aucs) / 4

    print("ROC AUC, 0 mnth: {:.4f}, 6 mnth: {:.4f}, 12 mnth: {:.4f}, 18 mnth: {:.4f}"
          .format(*roc_aucs))

    # Hand the models back in the mode the training loop expects.
    classifier_model.train()
    if trn_type == 'finetune':
        feature_model.train()
    return metric

In [None]:
def compute_loss(a,b,t,gt,t_cnv,t_bfr, roi_msk, cam):
    """Combine the prediction, uncertainty and CAM losses for one batch.

    Args:
        a, b: classifier outputs; the conversion probability is modelled as
            ``sigmoid((t - b) / (a + 0.05))``.
        t: time points at which the probability is evaluated.
        gt: targets for the binary cross-entropy (0 = not converted,
            1 = converted, i.e. the inverse of survival).
        t_cnv, t_bfr: accepted for interface compatibility; not used here.
        roi_msk: region-of-interest mask, multiplied with ``cam`` as
            ``cam * (1 - roi_msk)``.
        cam: activation map returned by the classifier
            (B,1,H,W,D according to the original author's comment).

    Returns:
        Tuple ``(loss, loss_pred, loss_uncrtn, loss_cam)`` where
        ``loss = 1.0*loss_pred + 0.1*loss_uncrtn + 0.1*loss_cam``.
    """
    w_pred, w_uncrtn, w_cam = 1.0, 0.1, 0.1

    # Prediction term: BCE averaged per sample (dim 1) and then over the batch.
    prob = torch.sigmoid((t - b) / (a + 0.05))
    bce = F.binary_cross_entropy(prob, gt, reduction='none')
    loss_pred = bce.mean(dim=1).mean().squeeze()

    # CAM term: mean squared activation that falls outside the ROI.
    outside_roi = (cam * (1 - roi_msk)).view(-1)
    loss_cam = (outside_roi ** 2).mean(dim=0).squeeze()

    # Uncertainty term: mean squared value of `a`.
    loss_uncrtn = (a.view(-1) ** 2).mean(dim=0).squeeze()

    loss = w_pred * loss_pred + w_uncrtn * loss_uncrtn + w_cam * loss_cam
    return loss, loss_pred, loss_uncrtn, loss_cam
    
    

def train_one_batch(img,t,gt,t_cnv,t_bfr,roi_msk, optimizer1, scheduler1,optimizer2, scheduler2):
    """Run one optimisation step on a single batch.

    Uses the module-level ``feature_model``, ``classifier_model`` and
    ``trn_type``. With ``trn_type == 'finetune'`` the encoder forward pass
    tracks gradients and ``optimizer2``/``scheduler2`` are stepped as well;
    with ``trn_type == 'freeze'`` the encoder runs without gradient tracking
    and only ``optimizer1``/``scheduler1`` are stepped.

    Args:
        img, t, gt, t_cnv, t_bfr, roi_msk: batch tensors; everything except
            ``img`` is forwarded unchanged to ``compute_loss``.
        optimizer1, scheduler1: optimiser and LR scheduler that are always stepped.
        optimizer2, scheduler2: optimiser and LR scheduler that are stepped only
            when fine-tuning (they are not touched in ``'freeze'`` mode).

    Returns:
        ``(loss, loss_pred, loss_uncrtn, loss_cam, optimizer1, scheduler1,
        optimizer2, scheduler2)`` with the four losses detached and converted
        to NumPy arrays on the CPU.

    Raises:
        ValueError: if ``trn_type`` is neither ``'finetune'`` nor ``'freeze'``.
    """
    # Previously an unknown mode left `ftr` unbound and failed with an obscure
    # UnboundLocalError further down; report the real cause instead.
    if trn_type not in ('finetune', 'freeze'):
        raise ValueError("trn_type must be 'finetune' or 'freeze', got {!r}".format(trn_type))
    finetune = trn_type == 'finetune'

    # Only build the encoder's autograd graph when its weights will be updated.
    with torch.set_grad_enabled(finetune):
        ftr = feature_model(img)

    a, b, cam = classifier_model(ftr)
    loss, loss_pred, loss_uncrtn, loss_cam = compute_loss(a, b, t, gt, t_cnv, t_bfr, roi_msk, cam)

    ###################### Backpropagation ###########################
    # remove previous gradients
    optimizer1.zero_grad()
    if finetune:
        optimizer2.zero_grad()
    # compute the gradients
    loss.backward()
    # Update the weights, then advance the learning-rate schedule
    optimizer1.step()
    scheduler1.step()

    if finetune:
        optimizer2.step()
        scheduler2.step()

    loss = loss.detach().cpu().numpy()
    loss_pred = loss_pred.detach().cpu().numpy()
    loss_uncrtn = loss_uncrtn.detach().cpu().numpy()
    loss_cam = loss_cam.detach().cpu().numpy()

    return loss, loss_pred, loss_uncrtn, loss_cam, optimizer1, scheduler1, optimizer2, scheduler2

In [None]:
def train_complete():
    """Train the classifier (and optionally the encoder) with early stopping.

    Relies on module-level names: ``train_loader``, ``val_loader``,
    ``classifier_model``, ``feature_model``, ``trn_type``, ``wt_decay`` and
    ``wt_nm``.

    One AdamW optimiser with a triangular ``CyclicLR`` schedule is created for
    the classifier; a second pair is created for the encoder only when
    ``trn_type == 'finetune'``. After every ``2*cycle_length`` batch updates
    the validation metric returned by ``complete_inference`` is computed. An
    improvement saves both state dicts to ``wt_nm``; ``patience`` consecutive
    checks without improvement stop the training.

    Returns:
        None. The best weights are written to the file ``wt_nm``.
    """
    # Learning-rate bounds of the cyclic schedule.
    mx_lr=10**(-5.0) 
    mn_lr=10**(-6.0)     
    total_batches=len(train_loader) 
    ################################################################################################# 
    cycle_length=240 # batch updates in the rising half of one LR triangle (step_size_up)
    max_epochs=1000 # upper bound on the number of epochs; early stopping may end training earlier
    # step_size_down=None makes CyclicLR reuse step_size_up, so one full triangle
    # spans 2*cycle_length updates and validation runs where the lr is back at its minimum.
    tot_batch_updates_val= 2*cycle_length
    
    patience=100 # Stop training if val metric does not improve for "patience" successive validation checks
    
    #################################################################################################
    
    optimizer1=torch.optim.AdamW(classifier_model.parameters(), lr=mn_lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=wt_decay)
    
    scheduler1=torch.optim.lr_scheduler.CyclicLR(optimizer1, base_lr=mn_lr, max_lr=mx_lr, cycle_momentum=False,
                                            step_size_up=cycle_length, step_size_down=None, mode='triangular')
    
    # NOTE(review): any other trn_type leaves optimizer2/scheduler2 unbound and
    # the first train_one_batch call below fails with a NameError.
    if trn_type=='finetune':
        optimizer2=torch.optim.AdamW(feature_model.parameters(), lr=mn_lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=wt_decay)
    
        scheduler2=torch.optim.lr_scheduler.CyclicLR(optimizer2, base_lr=mn_lr, max_lr=mx_lr, cycle_momentum=False,
                                            step_size_up=cycle_length, step_size_down=None, mode='triangular')
    elif trn_type=='freeze':
        optimizer2=None
        scheduler2=None
    
    ##### Initial metric ########
    # Printed for reference only: this value is overwritten by the first
    # validation check before it is ever compared with max_metric.
    metric=complete_inference(val_loader)
    
    #### Initialize ######
    ptnc_cnt=0 # for Early stopping: validation checks since the last improvement
    flag=0 # 0=> continue, 1=> Early Stopping has occurred, used to break the outer loop over epochs
    max_metric=-999 # best validation metric (mean ROC AUC) so far; sentinel below any real-valued metric
    
    # Running sums since the last validation check; cnt counts the batch updates.
    cnt=0
    run_loss=0
    run_loss_pred=0
    run_loss_uncrtn=0
    run_loss_cam=0
    
    t2 = time.time()
    
    for epoch in range(0, max_epochs):
        for i, sample in enumerate(train_loader):
            cnt=cnt+1
            
            img=sample['I']
            t=sample['t']
            gt=sample['gt']
            t_cnv=sample['t_cnv']
            t_bfr=sample['t_bfr']
            roi_msk=sample['roi_msk']
            del sample
            
            # t_cnv and t_bfr are passed on as loaded, without a device transfer.
            img=img.to(device)
            t=t.to(device)
            gt=gt.to(device)
            roi_msk=roi_msk.to(device)
            
            loss, loss_pred, loss_uncrtn, loss_cam,optimizer1,scheduler1,optimizer2,scheduler2=train_one_batch(img,t,gt,t_cnv,t_bfr, roi_msk, optimizer1, scheduler1,optimizer2, scheduler2)
            
            run_loss=run_loss+loss
            run_loss_pred=run_loss_pred+loss_pred
            run_loss_uncrtn=run_loss_uncrtn+loss_uncrtn
            run_loss_cam=run_loss_cam+loss_cam
            
            del img, gt, t, loss, loss_pred, loss_uncrtn, loss_cam
            
            # Progress line every 10 batches of the current epoch; the losses shown
            # are averages over the cnt updates since the last validation check.
            if (i+1) % 10== 0:
                print ("Epoch [{}/{}], Batch [{}/{}], cnt {}, Train Loss: {:.4f}, Loss Pred: {:.4f}, Loss Uncrtn: {:.4f}, Loss CAM: {:.4f}"
                       .format(epoch+1, max_epochs, i+1, total_batches, cnt, (run_loss/cnt),(run_loss_pred/cnt),(run_loss_uncrtn/cnt),(run_loss_cam/cnt)), end ="\r")
            
            ############# Monitor the validation metric and Early Stopping ############
            if cnt>=tot_batch_updates_val:
                print('\n Training time for 1 cycle is: '+str(time.time() - t2) +' seconds')
                # Monitor val auc
                metric=complete_inference(val_loader)
                
                ########## Reset the update counter, running losses and timer for the next cycle.
                ########## The optimizers and schedulers are NOT recreated; they keep their state.
                cnt=0
                run_loss=0
                run_loss_pred=0
                run_loss_uncrtn=0
                run_loss_cam=0
                t2 = time.time()
            
                # Early Stopping
                if metric>max_metric:
                    # New best: checkpoint both models and reset the patience counter.
                    torch.save({'classifier_model_state_dict_model': classifier_model.state_dict(),
                                'feature_model_state_dict_model': feature_model.state_dict()},
                               wt_nm)
                    max_metric=metric
                    ptnc_cnt=0
                else:
                    ptnc_cnt=ptnc_cnt+1
                    # NOTE(review): ptnc_cnt counts validation checks, although the
                    # message below calls them "batch updates".
                    print('\n Validation metric has not improved in last '+str(ptnc_cnt)+' batch updates')
                    if ptnc_cnt>=patience:
                        print("\n Early Stopping ! \n")
                        flag=1 # this will be used to break out of the outer loop for epochs.
                        break # this breaks out of inner loop of Dataloader
        if flag==1: # Early Stopping has occurred
            break # break out of the outer loop for epochs.   

In [None]:
def worker_init_fn(worker_id):
    """Re-seed NumPy's global RNG inside a DataLoader worker.

    The new seed is the first word of the current global RNG state plus
    ``worker_id``, so different workers get different seeds.

    Args:
        worker_id: integer index of the worker, as passed by ``DataLoader``.
    """
    # Convert the uint32 state word to a Python int before adding: this avoids
    # NumPy scalar overflow, and the modulo keeps the result inside the
    # [0, 2**32 - 1] range that np.random.seed accepts.
    base_seed = int(np.random.get_state()[1][0])
    np.random.seed((base_seed + worker_id) % (2**32))

In [None]:
"""
This code was modified to support both training with the encoder model frozen and end-to-end finetuning.
Originally, the frozen experiments were performed separately: 10 different data-augmented versions of each
sample were passed through the encoder in advance, and the extracted features were used to train the classifier.
This precomputation is faster and requires less computation, because the forward pass through the encoder
does not have to be repeated during training.
"""

eps=10**(-9)
fld=3 # the fold for which to train
wt_decay=0.01
trn_type='freeze' # or 'finetune'. 'freeze' should always be run before running 'finetune'

# Fail fast on a misspelt mode: otherwise wt_nm would stay undefined and the
# problem would only surface once training is already under way.
if trn_type not in ('freeze', 'finetune'):
    raise ValueError("trn_type must be 'freeze' or 'finetune', got {!r}".format(trn_type))

#################### Dataloaders #####################################
fld_pth='/msc/home/achakr83/PINNACLE/SSL_training/June30/cross_validation_splits_new/fold'+str(fld)+'.npz' 
# fld_pth is the pth to a npz file containing the list of scans in the current fold and GT. 
# See /Dataloader/TTC_dataloader/  for details.

# directory containing the preprocessed OCT scans in npz files.
img_pth='/msc/home/achakr83/PINNACLE/preprocessed_Downstream_images2/' 

train_data=train_dataset(fold=fld_pth, img_pth=img_pth)
train_loader=DataLoader(dataset=train_data, batch_size=3, shuffle=True, num_workers=2, 
                             pin_memory=False, drop_last=False)
    
# NOTE(review): train_dataset receives the fold file path (fld_pth) while
# val_dataset receives the integer fold index (fld) -- confirm against
# TTC_dataloader that this asymmetry is intended.
# NOTE(review): worker_init_fn is attached to the validation loader only --
# confirm that the training loader does not need it as well.
val_data=val_dataset(fold=fld, img_pth=img_pth)
val_loader=DataLoader(dataset=val_data, batch_size=3, shuffle=False, num_workers=2, 
                             pin_memory=False, drop_last=False, worker_init_fn=worker_init_fn)


################## Prepare the model  ############################### 
feature_model=Encoder_Architecture(base_chnls=16, out_ftr_dim=(64*2))
feature_model.to(device)

classifier_model=Classification_Network()
classifier_model.to(device)

################## Load the feature_model ##############################
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# on a GPU also loads when this notebook falls back to the CPU.
checkpoint = torch.load('best_weight-0.007980789116118103.pt', map_location=device)
feature_model.load_state_dict(checkpoint['model_state_dict_encoder_model'])
del checkpoint

if trn_type=='finetune':
    # Additionally load initial weights for the classifier, obtained by training with the frozen model wts first.
    checkpoint = torch.load(str(fld)+'_best_weight_frozen.pt', map_location=device)
    classifier_model.load_state_dict(checkpoint['classifier_model_state_dict_model'])
    del checkpoint    
    wt_nm=str(fld)+'_best_weight_finetune.pt'
elif trn_type=='freeze':
    wt_nm=str(fld)+'_best_weight_frozen.pt'
    # freeze the weights of the feature_model(encoder)
    for param in feature_model.parameters():
        param.requires_grad = False
    
    feature_model.eval() # its features are frozen
    


# Start training; the best checkpoint is written to wt_nm.
train_complete()