In [1]:
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from sksurv.metrics import concordance_index_censored
from sklearn.metrics import roc_auc_score
import time


from model_architecture import *
from losses_paired import *  # we use the ranking loss but only on intra-subject pairs + consistency loss
from Dataloader.Dataloader import *



######################## configure device ###############
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# verify that gpu is recognized
print(device)

################## Set random seem for reproducibility ##########
manualSeed = 9432
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)

########## interactive mode for plots ###################
plt.ion()   
%matplotlib inline

print(torch.__version__)

cuda
Random Seed:  9432
2.0.0


In [2]:
def complete_inference(val_loader):
    encoder_model.eval()
    classifier_model.eval()
    
    t_inp=torch.from_numpy(np.array([6.0, 12.0, 18.0, 24.0, 30.0, 36.0])).to(device)
    t_inp=(t_inp/36.0)
    
    gt_lst=[] # 2D array:  sample, 6 time-points
    pred_lst=[]
    risk_scr_lst=[]
    tcnv_lst=[]
    indctr_lst=[]
    nm_lst=[]
    
    for i, sample in enumerate(val_loader):
        img=sample['img'].to(device)
        gt=sample['gt'] # 6-D  conversion within 0/6/12/18/24/30/36 month time-points
        tcnv=sample['tcnv']
        indctr=sample['indctr']
        nm=sample['nm']
        
        img=rearrange(img, 'b 1 c h w -> b c h w')
        
        ### Forward Pass
        with torch.no_grad():
            ftr=encoder_model(img)                    # 2B,768
            
            tmp=t_inp.unsqueeze(dim=0).repeat(img.shape[0],1).to(dtype=torch.float32)
            rsk, pred_logits=classifier_model(ftr, tmp) # 3B, 1  and 3B,1
            pred=F.sigmoid(pred_logits)
            del tmp
        
        pred=pred.detach().cpu().numpy()
        ################################################
        
        pred_lst.append(pred)
        risk_scr_lst.append(rsk.detach().cpu().numpy()) # risk predicted for the current input scan
        nm_lst.append(nm)
        gt_lst.append(gt)
        tcnv_lst.append(tcnv)
        indctr_lst.append(indctr)
        
        del img, gt, tcnv, indctr, nm,rsk, pred
    
    encoder_model.train()
    classifier_model.train()
    
    
    pred_lst=np.concatenate(pred_lst, axis=0) # or stack?  # B,6
    risk_scr_lst=np.concatenate(risk_scr_lst, axis=0)      # B,1
    nm_lst=np.concatenate(nm_lst, axis=0)                  # B,
    gt_lst=np.concatenate(gt_lst, axis=0)                  # B,6
    tcnv_lst=np.concatenate(tcnv_lst, axis=0)              # B,
    indctr_lst=np.concatenate(indctr_lst, axis=0)          # B,
    
    ### Sort the list which is used to define the index for the samples of each bootstrap re-sampling.
    idx=np.argsort(nm_lst, axis=0)
    pred_lst=pred_lst[idx,:]           # B,6
    risk_scr_lst=risk_scr_lst[idx,:]   # B,1
    nm_lst=nm_lst[idx]                 # B,
    gt_lst=gt_lst[idx,:]               # B,6
    tcnv_lst=tcnv_lst[idx]             # B,
    indctr_lst=indctr_lst[idx]         # B,
    del idx
    
    ### Check that the sorted nm_lst is same as the validation dataloader
    flag=np.array_equal(nm_lst, val_data.nm_lst) # this is the order used to define bootstrap samplings
    if flag==False:
        print('the nm_lst doesnot match !')
    
    
    ###### Now everything is sorted by name. So now, we can use the pre-saved indices ###
    indices=val_data.sampling_index
    c_lst=[]
    auc_lst=[]
    for k in range(0, len(indices)): # No. of bootstrap re-samplings
        idx=indices[k]
        tmp_rsk=np.squeeze(risk_scr_lst[idx,:], axis=1) # B,
        tmp_tcnv=tcnv_lst[idx]
        tmp_indctr=indctr_lst[idx]
        c_index = concordance_index_censored(tmp_indctr.astype(bool), tmp_tcnv, tmp_rsk)
        c_index = c_index[0]
        c_lst.append(c_index)
        del c_index, tmp_indctr, tmp_tcnv, tmp_rsk
        
        tmp_gt=gt_lst[idx,:]
        tmp_pred=pred_lst[idx,:]
        auc=[]
        for cls in range(0,6):
            gt=tmp_gt[:, cls]
            pred=tmp_pred[:, cls]
            idx2=np.where(gt !=-1) # -1 implies GT is unavailable (eg. time after censoring has occured)
            gt=gt[idx2]
            pred=pred[idx2]
            # roc_auc_score(y_true, y_score
            auc.append(roc_auc_score(gt, pred))
            del gt, pred, idx2
            # suppose within timepoint 36 months but image is censored at 24 months. then lbl is -1, needs to be avoided
        
        auc=np.expand_dims(np.array(auc), axis=0)
        auc_lst.append(auc) # list of 6-dim arrays
        del auc, idx
        
    
    auc_lst=np.concatenate(auc_lst, axis=0) # B,6
    c_lst=np.array(c_lst)
    
    mean_concordance=np.mean(c_lst)
    avg_auc=np.mean(auc_lst, axis=1) # avg across 6 time-points
    mn_avg_auc=np.mean(avg_auc, axis=0) # avg across each sampling.
    # confidence intervals np.percentile(c_lst, 0.95)   and 0.05
    
    print('\n CI: '+str(mean_concordance)+'  AUC: '+str(mn_avg_auc))
    metric=mean_concordance+mn_avg_auc # this has to be maximized
    return metric

In [3]:
def train_one_batch(sample, optimizer, scheduler):
    
    img1=sample['img1'].to(device)   # B,3,224,224
    img2=sample['img2'].to(device)   # B,3,224,224
    
    img1=rearrange(img1, '1 b 1 c h w -> b c h w')
    img2=rearrange(img2, '1 b 1 c h w -> b c h w')
    
    tintrvl=torch.unsqueeze(torch.squeeze(sample['tintrvl'], dim=0), dim=1).to(device)  # B,1
    
    
    ########################  Forward Pass ########################
    B=img1.shape[0]
    img=torch.cat((img1, img2), dim=0)                  # 2B,3,H,W
    
    #### prediction for img1,img2,img1_to_ftr ####
    ftr=encoder_model(img)                    # 2B,768
    
    # Now add 
    t=torch.cat((torch.zeros((2*B, 1)).to(device), tintrvl), dim=0) # 3B,1
    # for first 2B time, evaluate current risk and predictions so t=0 for last B, predict risk from img1
    ftr=torch.cat((ftr, ftr[0:B, :]), dim=0)
    rsk, pred_logits=classifier_model(ftr, t) # 3B, 1  and 3B,1
        
    ####################################### COMPUTE LOSSES ########################################
    # gt1(O:N, 1:0);  gt2(0:N/2, 1:N/2);   tot 0:1.5N, 1:.5N
    # gt3(0:N/2, 1:N/2)   
    
    #### Consistency Loss ####
    cnstncy_loss=F.binary_cross_entropy_with_logits(pred_logits[(2*B):, 0:1], F.sigmoid(pred_logits[B:(2*B), 0:1]))
    
    ###################### Risk score ranking Loss ###########
    rank_loss=my_risk_concordance_loss_paired(rsk[0:B, 0:1], rsk[B:(2*B),0:1])
                                       
        
    loss=cnstncy_loss+rank_loss
    ################# Backpropagation ###########################
    # remove previously stored gradients
    optimizer.zero_grad()
    # Compute Gradients
    loss.backward()
    # Update weights
    #optimizer.step([cls_loss, cnstncy_loss, rank_loss], [1, 1, 1], None)
    optimizer.step()
    # Update learning rate scheduler
    scheduler.step()
    
    ############  Return Loss for Displaying #####
    #cls_loss=cls_loss.detach().cpu().numpy()
    cnstncy_loss=cnstncy_loss.detach().cpu().numpy()
    rank_loss=rank_loss.detach().cpu().numpy()
    
    loss=loss.detach().cpu().numpy()
    
    return loss, cnstncy_loss, rank_loss, optimizer, scheduler

In [4]:
def train_complete():
    mn_lr=10**(-6.0)
    mx_lr=10**(-4.0)
    
    nupdates=200#1000 # no of batch updates in each epoch
    max_epochs=200
    max_patience=100 # Early stopping if validation metric doesnot improve in this many consecutive epochs
    
    ####
    max_metric=complete_inference(val_loader)
    
    ############################################################################################################
    optimizer = torch.optim.AdamW(list(encoder_model.parameters()) + list(classifier_model.parameters()), lr=mn_lr, amsgrad=True)
    scheduler=torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=mn_lr, max_lr=mx_lr, cycle_momentum=False,
                                            step_size_up=nupdates//2, step_size_down=None, mode='triangular')
    ############################################################################################################
    
    patience=0
    data_iter = iter(train_loader) 
    for epochs in range(0, max_epochs):
        run_loss=0 # total loss  
        #run_cls_loss=0 #  BCE loss(mixup) with GT (both Encoder and ODE)
        run_cnstncy_loss=0 # MSE loss of ODE feature pred
        run_rnk_loss=0  # Concordance Index loss
        
        tic=time.time()
        for i in range(0, nupdates): # batch updates in each round of training (epoch) 
            try:
                sample = next(data_iter) 
            except StopIteration:
                # StopIteration is thrown if dataset ends
                # reinitialize data loader 
                data_iter = iter(train_loader)
                sample = next(data_iter)
                
            
            
            loss, cnstncy_loss, rank_loss, optimizer, scheduler=train_one_batch(sample, optimizer, scheduler)
            del sample
            
                
            run_loss=run_loss+loss
            #run_cls_loss=run_cls_loss+cls_loss
            run_cnstncy_loss=run_cnstncy_loss+cnstncy_loss
            run_rnk_loss=run_rnk_loss+rank_loss
            
            del loss,  cnstncy_loss, rank_loss
                
            if (i+1) % 10== 0: # displays after every 10 batch updates
                print ("Epoch [{}/{}], Batch [{}/{}], Train Loss: {:.4f}, CONSISTENCY: {:.4f}, RANKING: {:.4f}"
                       .format(epochs+1, max_epochs, i+1, nupdates, (run_loss/i), (run_cnstncy_loss/i), (run_rnk_loss/i)), end ="\r")
        
            
        ### End of an epoch. Check validation loss
        metric=complete_inference(val_loader) 
        toc=time.time()
        print('\n Val Metric: '+str(metric)+'  Last Epoch took '+str(toc-tic)+' seconds')
        
        
        run_loss=0 # total loss  
        #run_cls_loss=0 #  BCE loss(mixup) with GT (both Encoder and ODE)
        run_cnstncy_loss=0  # BCE loss of consistency in pred for ODE
        run_rnk_loss=0  # Concordance Index loss
        
        #### Early stopping
        if metric>max_metric:
            max_metric=metric
            patience=0
            print('Validation metric improved !')
            torch.save({
                        'encoder_state_dict': encoder_model.state_dict(),
                        'classifier_state_dict': classifier_model.state_dict()
            },'best_weight_fld'+str(fld)+'_metric'+str(metric)+'.pt')
        else:
            patience=patience+1
            print('\n Validation metric has not improved in last '+str(patience)+' epochs')
            if patience>max_patience:
                print('Early Stopping !')
                break
    
    

In [5]:
def worker_init_fn(worker_id):                                                          
    np.random.seed(np.random.get_state()[1][0] + worker_id)
    
def count_params(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [6]:
bce_loss_logits=nn.BCEWithLogitsLoss(pos_weight=torch.FloatTensor([3.0]).to(device))

In [None]:
fld=5
############### Data Loader #######################################

val_data=validation_dataset()
val_loader=DataLoader(dataset=val_data, batch_size=16, shuffle=False, num_workers=2, 
                             pin_memory=False, drop_last=False, worker_init_fn=worker_init_fn)


train_data=train_dataset(fold=fld, prcnt=25, discard_converted=False)
train_loader=DataLoader(dataset=train_data, batch_size=1, shuffle=True, num_workers=2, 
                             pin_memory=False, drop_last=False, worker_init_fn=worker_init_fn)

################## Prepare the model  ############################### 
encoder_model=Encoder_Network2D()
classifier_model=Classification_Network()


encoder_model.to(device)
classifier_model.to(device)

############ Load source domain weights ##################
# Load weights
wt_nm='src_weight_fld5_metric1.4942585858396749.pt'

checkpoint=torch.load(wt_nm)
encoder_model.load_state_dict(checkpoint['encoder_state_dict'])
classifier_model.load_state_dict(checkpoint['classifier_state_dict'])
del checkpoint

#######################################################################
train_complete()

903

 CI: 0.6558512329643176  AUC: 0.6601325265482484
Epoch [1/200], Batch [200/200], Train Loss: 0.7310, CONSISTENCY: 0.0417, RANKING: 0.6894
 CI: 0.7287608495785871  AUC: 0.7342547838134881

 Val Metric: 1.4630156333920752  Last Epoch took 180.71966552734375 seconds
Validation metric improved !
Epoch [2/200], Batch [200/200], Train Loss: 0.6412, CONSISTENCY: 0.0072, RANKING: 0.6340
 CI: 0.7301478960558329  AUC: 0.7377587326447499

 Val Metric: 1.4679066287005829  Last Epoch took 176.84829139709473 seconds
Validation metric improved !
Epoch [3/200], Batch [200/200], Train Loss: 0.6181, CONSISTENCY: 0.0030, RANKING: 0.6152
 CI: 0.7321734441306823  AUC: 0.7425580236948551

 Val Metric: 1.4747314678255374  Last Epoch took 179.7371346950531 seconds
Validation metric improved !
Epoch [4/200], Batch [200/200], Train Loss: 0.6151, CONSISTENCY: 0.0088, RANKING: 0.6063
 CI: 0.7236163458120576  AUC: 0.7218453475260985

 Val Metric: 1.445461693338156  Last Epoch took 178.49827075004578 seconds

