In [1]:
import numpy as np
import sys
import glob, os

import pydicom
import random 
import math
import time
import pickle
import gc

# https://scikit-survival.readthedocs.io/en/stable/api/generated/sksurv.metrics.concordance_index_censored.html#
from sksurv.metrics import concordance_index_censored, cumulative_dynamic_auc, integrated_brier_score

from torchvision.models import resnet50, ResNet50_Weights

from scipy import ndimage
from scipy.ndimage import zoom

from skimage import filters
from skimage import io
from skimage import transform 

from einops import rearrange#, reduce, repeat


import matplotlib.pyplot as plt

import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision
import torch.nn as nn
import torch.nn.functional as F


sys.path.append("torchdiffeq-master") # go to parent dir'
from torchdiffeq import odeint_adjoint as odeint
#from torchdiffeq import odeint

from sklearn.metrics import roc_auc_score
from sklearn.metrics import average_precision_score

#import warnings
#warnings.filterwarnings("ignore")

######################## Select the compute device ###############
# To pin a specific GPU, set the variable below before CUDA is initialised:
#os.environ["CUDA_VISIBLE_DEVICES"]="5" # 1 is GT 740
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
# Echo the selection so that a missing GPU is noticed immediately
print(device)



################## Set random seed for reproducibility ##########
# Seed every RNG the notebook draws from. NumPy is included because
# worker_init_fn derives each DataLoader worker's seed from the global
# NumPy state; leaving it unseeded makes the workers differ between runs.
manualSeed = 9432
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
np.random.seed(manualSeed)
torch.manual_seed(manualSeed)


########## interactive mode for plots ###################
# Switch matplotlib to interactive mode.
plt.ion()   
# IPython magic selecting the inline backend (figures appear in the cell output).
%matplotlib inline

# Print the installed PyTorch version so it is recorded with the run.
print(torch.__version__)

cuda
Random Seed:  9432
2.0.0


In [2]:
# Working directory of the notebook; later cells build the paths of the
# helper notebooks from it.
pwd=os.getcwd()
# Make modules stored in ./backprop/ importable.
sys.path.append(pwd+'/backprop/')

# MTAdam is resolved through the sys.path entry added above.
from mtadam import MTAdam

In [3]:
# Execute the project's helper notebooks with %run, in the same order as before.
# They presumably supply the losses, model classes and datasets used further
# down in this notebook — verify against the notebooks themselves.
# NOTE(review): losses.ipynb is executed twice; confirm whether the second run
# is needed (e.g. to re-define a name shadowed by a later notebook).
for _nb_rel_path in ("/backprop/losses.ipynb",
                     "/model/Encoder.ipynb",
                     "/model/Neural_ODE.ipynb",
                     "/backprop/losses.ipynb",
                     "/Dataloader/Dataloader.ipynb"):
    get_ipython().run_line_magic('run', pwd+_nb_rel_path)
del _nb_rel_path

In [4]:
def complete_inference(val_loader):
    """Evaluate the current models on the validation set and return the selection metric.

    For every batch the encoder produces a feature vector, the classifier turns
    it into the current risk score, and the neural ODE propagates the feature to
    the 6/12/18/24/30/36-month horizons, where the classifier is applied again.
    The predictions are sorted by sample name and scored on the pre-saved
    bootstrap resamplings in ``val_data.sampling_index``.

    Relies on the notebook globals ``encoder_model``, ``classifier_model``,
    ``ode_func``, ``val_data``, ``device``, ``ode_solver`` and ``stp_sz``.
    The three models are put in eval mode for the pass and switched back to
    train mode before the metrics are computed.

    Parameters
    ----------
    val_loader : iterable
        Yields dict batches with the keys 'img', 'gt', 'tcnv', 'indctr', 'nm'.

    Returns
    -------
    float
        Mean concordance index plus mean AUC (averaged over the 6 horizons and
        over all bootstrap resamplings); larger is better.
    """
    encoder_model.eval()
    classifier_model.eval()
    ode_func.eval()

    # Horizons in months, rescaled so that 36 months corresponds to t=1
    t_inp=torch.from_numpy(np.array([0.0, 6.0, 12.0, 18.0, 24.0, 30.0, 36.0])).to(device)
    t_inp=(t_inp/36.0)

    gt_lst=[] # 2D array:  sample, 6 time-points
    pred_lst=[]
    risk_scr_lst=[]
    tcnv_lst=[]
    indctr_lst=[]
    nm_lst=[]

    for sample in val_loader:
        img=sample['img'].to(device)
        img=rearrange(img, 'b 1 c h w -> b c h w')

        # The whole forward pass, including the ODE integration and the second
        # classifier call, runs without autograd: nothing is back-propagated
        # during validation, so building a graph would only cost memory.
        with torch.no_grad():
            ftr=encoder_model(img)
            rsk_curr=classifier_model(ftr)

            ##### ODE to get future risk predictions #####
            ftr_ode=odeint(ode_func, ftr, t_inp,
                           atol=1e-8, rtol=1e-8, method=ode_solver, options=dict(step_size=stp_sz)) # ftr_ode: t=7,B,768

            ftr_ode=ftr_ode[1:,:,:] # remove t=0    t=6,B,768
            B=ftr_ode.shape[1]
            ftr_ode=rearrange(ftr_ode, 't b d -> (b t) d') # (B*6), 768
            rsk_ode=classifier_model(ftr_ode) # (B*6,1)
            rsk_ode=rearrange(rsk_ode, '(b t) 1 -> b t', b=B, t=6) # B,6

            p=F.sigmoid(rsk_ode) # B,6

        pred_lst.append(p.detach().cpu().numpy())
        risk_scr_lst.append(rsk_curr.detach().cpu().numpy()) # risk predicted for the current input scan
        nm_lst.append(sample['nm'])
        gt_lst.append(sample['gt'])   # 6-D  conversion within 6/12/18/24/30/36 month time-points
        tcnv_lst.append(sample['tcnv'])
        indctr_lst.append(sample['indctr'])

    encoder_model.train()
    classifier_model.train()
    ode_func.train()

    pred_lst=np.concatenate(pred_lst, axis=0)              # N,6
    risk_scr_lst=np.concatenate(risk_scr_lst, axis=0)      # N,1
    nm_lst=np.concatenate(nm_lst, axis=0)                  # N,
    gt_lst=np.concatenate(gt_lst, axis=0)                  # N,6
    tcnv_lst=np.concatenate(tcnv_lst, axis=0)              # N,
    indctr_lst=np.concatenate(indctr_lst, axis=0)          # N,

    ### Sort by name: the bootstrap indices are defined w.r.t. the name-sorted sample order.
    order=np.argsort(nm_lst, axis=0)
    pred_lst=pred_lst[order,:]
    risk_scr_lst=risk_scr_lst[order,:]
    nm_lst=nm_lst[order]
    gt_lst=gt_lst[order,:]
    tcnv_lst=tcnv_lst[order]
    indctr_lst=indctr_lst[order]

    ### Check that the sorted nm_lst is same as the validation dataset's list
    ### (this is the order used to define the bootstrap samplings).
    if not np.array_equal(nm_lst, val_data.nm_lst):
        print('the nm_lst doesnot match !')

    ###### Everything is sorted by name, so the pre-saved indices can be used ###
    indices=val_data.sampling_index
    c_lst=[]
    auc_lst=[]
    for k in range(0, len(indices)): # No. of bootstrap re-samplings
        idx=indices[k]

        tmp_rsk=np.squeeze(risk_scr_lst[idx,:], axis=1) # N,
        c_index=concordance_index_censored(indctr_lst[idx].astype(bool), tcnv_lst[idx], tmp_rsk)
        c_lst.append(c_index[0])

        tmp_gt=gt_lst[idx,:]
        tmp_pred=pred_lst[idx,:]
        auc=[]
        for cls in range(0,6):
            cls_gt=tmp_gt[:, cls]
            cls_pred=tmp_pred[:, cls]
            # -1 implies GT is unavailable (e.g. the horizon lies after the
            # censoring time), so those samples are excluded for this horizon.
            valid=np.where(cls_gt !=-1)
            # NOTE: roc_auc_score raises ValueError when only one class remains.
            auc.append(roc_auc_score(cls_gt[valid], cls_pred[valid]))

        auc_lst.append(np.expand_dims(np.array(auc), axis=0)) # 1,6 per resampling

    auc_lst=np.concatenate(auc_lst, axis=0) # n_resamplings,6
    c_lst=np.array(c_lst)

    mean_concordance=np.mean(c_lst)
    avg_auc=np.mean(auc_lst, axis=1) # avg across 6 time-points
    mn_avg_auc=np.mean(avg_auc, axis=0) # avg across each sampling.
    # confidence intervals could be read off with np.percentile(c_lst, [5, 95])

    print('\n CI: '+str(mean_concordance)+'  AUC: '+str(mn_avg_auc))
    metric=mean_concordance+mn_avg_auc # this has to be maximized
    return metric

In [5]:
def train_one_batch(sample, optimizer, scheduler):
    """Run one optimisation step on a batch of image pairs.

    Each pair consists of an earlier scan (``img1``) and a later scan
    (``img2``) separated by ``tintrvl``. Both scans are encoded and classified;
    the feature of the first scan is additionally propagated through the neural
    ODE to the time of the second scan and classified again.

    Four loss terms are handed separately to ``optimizer.step`` (MTAdam runs the
    backward passes internally, so ``backward()`` is not called here):
    ranking loss, ODE risk-consistency loss, classification loss and ODE
    feature-consistency loss.

    Relies on the notebook globals ``encoder_model``, ``classifier_model``,
    ``ode_func``, ``bce_loss_logits``, ``my_risk_concordance_loss``,
    ``device``, ``ode_solver`` and ``stp_sz``.

    Parameters
    ----------
    sample : dict
        Batch with keys 'img1', 'img2', 'gt1', 'gt2', 'tcnv1', 'tcnv2',
        'indctr1', 'indctr2', 'tintrvl'. Every entry carries a leading axis of
        size 1 (presumably from the DataLoader's batch_size=1 — confirm).
    optimizer : MTAdam-style optimizer whose ``step`` accepts a list of losses.
    scheduler : learning-rate scheduler stepped once per batch.

    Returns
    -------
    tuple
        ``(loss, cls_loss, ftr_ode_loss, rsk_ode_loss, rank_loss, optimizer,
        scheduler)`` where the first five entries are NumPy scalars and
        ``loss`` is the sum of the four terms.
    """

    def _as_column(key):
        # Drop the leading size-1 axis and return the entry as a (B,1) column on `device`.
        return torch.unsqueeze(torch.squeeze(sample[key], dim=0), dim=1).to(device)

    img1=rearrange(sample['img1'].to(device), '1 b 1 c h w -> b c h w')
    img2=rearrange(sample['img2'].to(device), '1 b 1 c h w -> b c h w')

    gt1=_as_column('gt1')          # B,1
    gt2=_as_column('gt2')
    tcnv1=_as_column('tcnv1')      # B,1
    tcnv2=_as_column('tcnv2')
    indctr1=_as_column('indctr1')  # B,1
    indctr2=_as_column('indctr2')
    tintrvl=_as_column('tintrvl')  # B,1

    ###############################################################################################
    img=torch.cat((img1, img2), dim=0)         # 2B,C,H,W
    gt=torch.cat((gt1, gt2), dim=0)            # 2B,1
    indctr=torch.cat((indctr1, indctr2), dim=0)# 2B,1
    tcnv=torch.cat((tcnv1, tcnv2), dim=0)      # 2B,1

    ######################################## FORWARD PASS #########################################
    B=img.shape[0]//2
    ftr=encoder_model(img)         # 2B, D
    rsk=classifier_model(ftr)

    ftr1=ftr[0:B,:]
    ftr2=ftr[B:,:]

    ###############################################################################################
    # Integrate once over the sorted set of distinct intervals; `idx` maps every
    # pair back to the position of its own interval in that set.
    t_inp, idx=torch.unique(tintrvl, sorted=True, return_inverse=True, return_counts=False, dim=0)
    t_inp=t_inp.to(device)

    # Gather index of shape 1,B,D. The feature width is taken from the encoder
    # output instead of being hard-coded, so other encoder sizes work as well.
    idx=torch.unsqueeze(idx, dim=1)                # B,1
    idx_ftr=idx.repeat(1, ftr1.shape[1])           # B,D
    idx_ftr=torch.unsqueeze(idx_ftr, dim=0)        # 1,B,D

    ### add t=0 to t_inp
    # NOTE(review): the solver needs strictly monotonic times, so the intervals
    # are presumably all > 0 — confirm against the dataset.
    t_inp=torch.squeeze(t_inp, dim=1)
    t_inp=torch.cat((torch.from_numpy(np.array([0.0])).to(device), t_inp), dim=0)

    ftr_ode=odeint(ode_func, ftr1, t_inp,
                   atol=1e-8, rtol=1e-8, method=ode_solver, options=dict(step_size=stp_sz))

    ### Remove the first t=0 time-point, then pick each pair's own time-point
    ftr_ode=ftr_ode[1:,:,:] # T,B,D
    ftr_ode=torch.squeeze(torch.gather(input=ftr_ode, dim=0, index=idx_ftr), dim=0) # B,D

    ###############################################################################################
    rsk_ode=classifier_model(ftr_ode)

    ####################################### COMPUTE LOSSES ########################################
    # the rsk uses loss with a weight of 3 for + class.
    cls_loss=bce_loss_logits(rsk, gt)

    ##################### Consistency (ODE) losses ###########
    ftr_ode_loss=F.mse_loss(ftr_ode, ftr2, reduction='mean')
    # with logits still requires the target to be in [0,1], hence the sigmoid
    rsk_ode_loss=F.binary_cross_entropy_with_logits(rsk_ode, F.sigmoid(rsk[B:,:]), reduction='mean')+ F.binary_cross_entropy_with_logits(rsk_ode, gt[B:,:])

    ###################### Risk score ranking Loss ###########
    # The ODE risk is paired with the second scan's indicator and time (3B rows in total).
    rank_loss=my_risk_concordance_loss(torch.cat((rsk, rsk_ode), dim=0),
                                       torch.cat((indctr, indctr2), dim=0),
                                       torch.cat((tcnv, tcnv2), dim=0))

    ################# Backpropagation ###########################
    # remove previously stored gradients
    optimizer.zero_grad()
    # Gradients are computed inside MTAdam.step(); update weights
    optimizer.step([rank_loss, rsk_ode_loss, cls_loss, ftr_ode_loss], [1, 1, 1, 1], None)
    # Update learning rate scheduler
    scheduler.step()

    ############  Return Loss for Displaying #####
    cls_loss=cls_loss.detach().cpu().numpy()
    ftr_ode_loss=ftr_ode_loss.detach().cpu().numpy()
    rsk_ode_loss=rsk_ode_loss.detach().cpu().numpy()
    rank_loss=rank_loss.detach().cpu().numpy()

    loss=cls_loss+ftr_ode_loss+rsk_ode_loss+rank_loss

    return loss, cls_loss, ftr_ode_loss, rsk_ode_loss, rank_loss, optimizer, scheduler
     
'''
the use of MTAdam is as simple as using Adam, and requires the following steps: 
(a) initiating the MTAdam optimizer (in a similar way to Adam). 
(b) keeping the multi-term loss objective decomposed as a sequence of single terms, 
    and sending the sequence as an argument to MTAdam.step(). 
(c) avoid calling the function loss.backward(), since it is done internally in MTAdam.step().
'''

'\nthe use of MTAdam is as simple as using Adam, and requires the following steps: \n(a) initiating the MTAdam optimizer (in a similar way to Adam). \n(b) keeping the multi-term loss objective decomposed as a sequence of single terms, \n    and sending the sequence as an argument to MTAdam.step(). \n(c) avoid calling the function loss.backward(), since it is done internally in MTAdam.step().\n'

In [6]:
def train_complete():
    """Train encoder, classifier and neural ODE with early stopping on the validation metric.

    Every "epoch" consists of ``nupdates`` batch updates drawn from
    ``train_loader`` (the iterator is restarted whenever it is exhausted),
    followed by one call to ``complete_inference(val_loader)``. Whenever the
    validation metric improves, the state dicts of all four networks are saved
    to ``best_weight_fld<fld>_metric<metric>.pt``. Training stops after
    ``max_epochs`` epochs or once the metric has failed to improve for more
    than ``max_patience`` consecutive epochs.

    Relies on the notebook globals ``encoder_model``, ``classifier_model``,
    ``ode_func``, ``rank_model``, ``train_loader``, ``val_loader`` and ``fld``.
    Returns nothing.
    """
    mn_lr=10**(-6.0)
    mx_lr=10**(-4.0)

    nupdates=300#1000  200 # no of batch updates in each epoch
    max_epochs=200
    max_patience=50 # 100 # Early stopping if validation metric doesnot improve in this many consecutive epochs

    max_metric=0

    ############################################################################################################
    optimizer = MTAdam(list(encoder_model.parameters()) + list(classifier_model.parameters()) + list(ode_func.parameters()) + list(rank_model.parameters()), lr=mn_lr, amsgrad=True)
    # One triangular learning-rate cycle per epoch (nupdates//2 steps up, same number down)
    scheduler=torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=mn_lr, max_lr=mx_lr, cycle_momentum=False,
                                            step_size_up=nupdates//2, step_size_down=None, mode='triangular')
    ############################################################################################################

    patience=0
    data_iter = iter(train_loader)
    for epochs in range(0, max_epochs):
        # Running sums of the losses, restarted at the beginning of each epoch
        run_loss=0 # total loss
        run_cls_loss=0 #  BCE loss with GT
        run_ftr_ode_loss=0 # MSE loss of ODE feature pred
        run_rsk_ode_loss=0  # BCE loss of consistency in pred for ODE
        run_rnk_loss=0  # Concordance Index loss

        tic=time.time()
        for i in range(0, nupdates): # batch updates in each round of training (epoch)
            try:
                sample = next(data_iter)
            except StopIteration:
                # The dataset is exhausted: restart the loader and continue
                data_iter = iter(train_loader)
                sample = next(data_iter)

            loss, cls_loss, ftr_ode_loss, rsk_ode_loss, rank_loss, optimizer, scheduler=train_one_batch(sample, optimizer, scheduler)

            run_loss=run_loss+loss
            run_cls_loss=run_cls_loss+cls_loss
            run_ftr_ode_loss=run_ftr_ode_loss+ftr_ode_loss
            run_rsk_ode_loss=run_rsk_ode_loss+rsk_ode_loss
            run_rnk_loss=run_rnk_loss+rank_loss

            if (i+1) % 10== 0: # displays after every 10 batch updates
                # i is 0-based, so i+1 updates have been accumulated so far;
                # dividing by i would overstate the running averages.
                n_done=i+1
                print ("Epoch [{}/{}], Batch [{}/{}], Train Loss: {:.4f}, Classification: {:.4f}, FTR_ODE: {:.4f}, RSK_ODE: {:.4f}, RANKING: {:.4f}"
                       .format(epochs+1, max_epochs, n_done, nupdates, (run_loss/n_done), (run_cls_loss/n_done), (run_ftr_ode_loss/n_done), (run_rsk_ode_loss/n_done), (run_rnk_loss/n_done)), end ="\r")

        ### End of an epoch. Check validation metric
        metric=complete_inference(val_loader)
        toc=time.time()
        print('\n Val Metric: '+str(metric)+'  Last Epoch took '+str(toc-tic)+' seconds')

        #### Early stopping
        if metric>max_metric:
            max_metric=metric
            patience=0
            print('Validation metric improved !')
            torch.save({
                        'ode_state_dict': ode_func.state_dict(),
                        'encoder_state_dict': encoder_model.state_dict(),
                        'rank_state_dict': rank_model.state_dict(),
                        'classifier_state_dict': classifier_model.state_dict()
            },'best_weight_fld'+str(fld)+'_metric'+str(metric)+'.pt')
        else:
            patience=patience+1
            print('\n Validation metric has not improved in last '+str(patience)+' epochs')
            if patience>max_patience:
                print('Early Stopping !')
                break
    
    

In [7]:
def worker_init_fn(worker_id):
    """Give a DataLoader worker its own NumPy seed.

    The seed is the first word of the current global NumPy RNG state plus
    ``worker_id``, so that workers do not all draw the same random numbers.

    Parameters
    ----------
    worker_id : int
        Index of the worker, as passed by ``torch.utils.data.DataLoader``.
    """
    base_seed = int(np.random.get_state()[1][0])
    # np.random.seed only accepts values in [0, 2**32 - 1]. Do the addition in
    # Python ints and wrap, so a large state word can neither overflow uint32
    # nor produce an out-of-range seed.
    np.random.seed((base_seed + worker_id) % (2**32))
    
def count_params(model):
    """Return the total number of elements over all parameters of ``model`` that require gradients."""
    n_trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            n_trainable += param.numel()
    return n_trainable

In [8]:
# Composition of one training batch of 12 image pairs (12 X 2 samples):
#   non-converters: 6 X 2 + 6 first images of converters (not yet converted) = 18
#   converters:     6
# (presumably the reason the positive class is weighted 3x below — confirm)
bce_loss_logits=nn.BCEWithLogitsLoss(
    pos_weight=torch.tensor([3.0], dtype=torch.float32, device=device)
)


# for ODE: 6 converter + 6 non-converter

In [None]:
############### Run configuration ################################
fld=1                # cross-validation fold passed to the datasets and used in the checkpoint name
# Neural-ODE step size; time is scaled so that 36 months = 1:
# 1 mnth step_sz=0.03, 2 mnth=0.06, 3 mnth=0.08, 4 mnth=0.10
stp_sz=0.06
ode_solver='euler'   # alternatives: rk4, midpoint


############### Data Loader #######################################
# Settings shared by both loaders
_loader_kwargs=dict(num_workers=2, pin_memory=False, drop_last=False,
                    worker_init_fn=worker_init_fn)

val_data=validation_dataset(fold=fld)
val_loader=DataLoader(dataset=val_data, batch_size=16, shuffle=False, **_loader_kwargs)

train_data=train_dataset(fold=fld, discard_converted=False)
train_loader=DataLoader(dataset=train_data, batch_size=1, shuffle=True, **_loader_kwargs)


################## Prepare the model  ###############################
encoder_model=Encoder_Network2D()
classifier_model=Classification_Network()
ode_func=ODE_network(n_head=12, depth=3, ftr_dim=768)
rank_model=Temporal_Order_Classification()

# Move every network to the selected device
for _net in (encoder_model, classifier_model, ode_func, rank_model):
    _net.to(device)

#######################################################################
train_complete()

607
Epoch [1/150], Batch [200/200], Train Loss: 2.2978, Classification: 0.7626, FTR_ODE: 0.0169, RSK_ODE: 1.1687, RANKING: 0.3497
 CI: 0.7457176243690165  AUC: 0.7971576391228378

 Val Metric: 1.542875263491854  Last Epoch took 722.8822522163391 seconds
Validation metric improved !
Epoch [2/150], Batch [200/200], Train Loss: 1.5130, Classification: 0.4648, FTR_ODE: 0.0107, RSK_ODE: 0.8476, RANKING: 0.1899
 CI: 0.7458335205721834  AUC: 0.7950954196304759

 Val Metric: 1.5409289402026594  Last Epoch took 806.8543956279755 seconds

 Validation metric has not improved in last 1 epochs
Epoch [3/150], Batch [200/200], Train Loss: 1.2183, Classification: 0.3636, FTR_ODE: 0.0090, RSK_ODE: 0.7010, RANKING: 0.1448
 CI: 0.7500369592649101  AUC: 0.8072814229731856

 Val Metric: 1.5573183822380958  Last Epoch took 812.2652823925018 seconds
Validation metric improved !
Epoch [4/150], Batch [200/200], Train Loss: 0.9961, Classification: 0.2755, FTR_ODE: 0.0072, RSK_ODE: 0.6000, RANKING: 0.1133
 CI: 0

Epoch [55/150], Batch [200/200], Train Loss: 0.2974, Classification: 0.0512, FTR_ODE: 0.0010, RSK_ODE: 0.2093, RANKING: 0.0360
 CI: 0.6945123013036941  AUC: 0.7359945222065628

 Val Metric: 1.430506823510257  Last Epoch took 801.3268096446991 seconds

 Validation metric has not improved in last 42 epochs
Epoch [56/150], Batch [200/200], Train Loss: 0.2548, Classification: 0.0322, FTR_ODE: 0.0010, RSK_ODE: 0.1901, RANKING: 0.0315
 CI: 0.7303335376632042  AUC: 0.7644273314334022

 Val Metric: 1.4947608690966065  Last Epoch took 800.6249837875366 seconds

 Validation metric has not improved in last 43 epochs
Epoch [57/150], Batch [200/200], Train Loss: 0.3276, Classification: 0.0646, FTR_ODE: 0.0010, RSK_ODE: 0.2252, RANKING: 0.0368
 CI: 0.7096611261121806  AUC: 0.7546389153481909

 Val Metric: 1.4643000414603715  Last Epoch took 809.5641093254089 seconds

 Validation metric has not improved in last 44 epochs
Epoch [58/150], Batch [200/200], Train Loss: 0.3218, Classification: 0.0453, FTR

In [None]:
.769