In [1]:
from torch.utils.data import DataLoader
from sksurv.metrics import concordance_index_censored
from sklearn.metrics import roc_auc_score
import random
import matplotlib.pyplot as plt



###################################################################
from model_architecture import *
from losses import *
from Dataloader.Test_Dataloader import *


######################## configure device ###############
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
# NOTE(review): `torch` is not imported explicitly in this cell; presumably it is
# re-exported by one of the wildcard imports above -- confirm.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# verify that gpu is recognized
print(device)



################## Set random seed for reproducibility ##########
# Seeds Python's `random` module and PyTorch's global generator.
# NumPy's global RNG is not seeded in this cell.
manualSeed = 9432
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)


########## interactive mode for plots ###################
plt.ion()   
%matplotlib inline

# Record the PyTorch version in the notebook output.
print(torch.__version__)

cuda
Random Seed:  9432
2.0.0


In [2]:
######## Computes Balanced Accuracy ###########
def compute_bacc(gt, pred, thresh):
    """Balanced accuracy of thresholded scores at every cut-off in `thresh`.

    A sample is predicted positive when ``pred > t`` (strict inequality).
    Only samples labelled exactly 0 or 1 in `gt` are counted; samples with
    any other label value are ignored.

    Args:
        gt: array-like of ground-truth labels, same shape as `pred`.
        pred: array-like of predicted scores.
        thresh: sequence of thresholds.

    Returns:
        Float array of shape (len(thresh), 1) holding
        (sensitivity + specificity) / 2 for each threshold. Every entry is
        NaN when `gt` has no positive or no negative sample, because the
        metric is undefined in that case (the previous implementation raised
        a bare ZeroDivisionError there).
    """
    gt = np.asarray(gt)
    pred = np.asarray(pred)
    thresh = np.asarray(thresh).reshape(-1)

    def _count_above(scores):
        # Number of scores strictly greater than each threshold, via one sort
        # and a binary search instead of one full scan per threshold.
        # NaN scores are dropped first: `nan > t` is False, so they must never
        # be counted as predicted positives.
        ordered = np.sort(scores[~np.isnan(scores)])
        return ordered.shape[0] - np.searchsorted(ordered, thresh, side='right')

    n_pos = int(np.count_nonzero(gt == 1))
    n_neg = int(np.count_nonzero(gt == 0))

    tp = _count_above(pred[gt == 1])   # positives scored above the threshold
    fp = _count_above(pred[gt == 0])   # negatives scored above the threshold
    tn = n_neg - fp

    undefined = np.full(thresh.shape, np.nan)
    sens = tp / n_pos if n_pos > 0 else undefined
    spec = tn / n_neg if n_neg > 0 else undefined

    return ((sens + spec) / 2)[:, np.newaxis]

In [3]:
############# Compute the AUROC,Concordance Index and Balanced Accuracy ################
def compute_performance(indctr_lst, tcnv_lst, risk_scr_lst, pred_lst, gt_lst, thresh):
    """Concordance index plus per-time-point AUROC and balanced accuracy.

    Args:
        indctr_lst: event indicators; cast to bool before being handed to
            `concordance_index_censored`.
        tcnv_lst: time values passed as the second argument of
            `concordance_index_censored`.
        risk_scr_lst: (N, 1) array of risk scores; axis 1 is squeezed away.
        pred_lst: 2-D array of predicted scores; columns 0..5 are used.
        gt_lst: 2-D array of labels aligned with `pred_lst`; entries equal
            to -1 mark unavailable ground truth and are excluded per column.
        thresh: thresholds forwarded to `compute_bacc`.

    Returns:
        (auc_lst, c_index, b_acc_lst): a (1, 6) array of AUROCs, the first
        element of the `concordance_index_censored` result, and a
        (len(thresh), 6) array of balanced accuracies.
    """
    risk_flat = np.squeeze(risk_scr_lst, axis=1)
    c_index = concordance_index_censored(indctr_lst.astype(bool), tcnv_lst, risk_flat)[0]

    per_cls_auc = []
    per_cls_bacc = []
    for cls in range(6):
        col_gt = gt_lst[:, cls]
        col_pred = pred_lst[:, cls]

        # -1 means the label is unknown (e.g. querying a time-point after
        # censoring has occurred); drop those samples for this column only.
        known = np.where(col_gt != -1)
        col_gt = col_gt[known]
        col_pred = col_pred[known]

        per_cls_auc.append(roc_auc_score(col_gt, col_pred))
        per_cls_bacc.append(compute_bacc(col_gt, col_pred, thresh))  # (len(thresh), 1)

    auc_lst = np.expand_dims(np.array(per_cls_auc), axis=0)   # (1, 6)
    b_acc_lst = np.concatenate(per_cls_bacc, axis=1)          # (len(thresh), 6)
    return auc_lst, c_index, b_acc_lst

In [4]:
######### The inference Code ################
def complete_inference(tst_loader):
    """Run inference over `tst_loader` and save predictions and metrics.

    Relies on the module-level `encoder_model`, `classifier_model` and
    `device`. The dataset behind the loader (`tst_loader.dataset`) must
    expose `nm_lst` (scan names in the order the bootstrap indices refer to)
    and `sampling_index` (one index array per bootstrap re-sampling).

    Side effects:
        * writes 'save_predictions.npz' (sorted predictions, risk scores,
          names, labels, times, indicators and the bootstrap indices);
        * writes 'performance.npz' (scan-level and per-bootstrap AUROC,
          concordance index and balanced accuracy);
        * prints progress every 100 bootstraps, the selected thresholds and
          the shape of the bootstrap balanced-accuracy array.

    Args:
        tst_loader: DataLoader yielding dicts with keys 'img1', 'img2',
            'img3', 'gt', 'tcnv', 'indctr' and 'nm'.

    Returns:
        None.
    """
    encoder_model.eval()
    classifier_model.eval()

    # Read dataset attributes from the loader that was actually passed in,
    # rather than from a notebook-global variable.
    dataset = tst_loader.dataset

    # time-points at which conversion performance is measured (AUROC and Bal. Accuracy)
    t_inp = torch.from_numpy(np.array([6.0, 12.0, 18.0, 24.0, 30.0, 36.0])).to(device)
    # Normalization s.t. 3 years corresponds to 1. The division is done in
    # float64 and only then cast, exactly as before, so values are unchanged.
    t_inp = (t_inp / 36.0).to(dtype=torch.float32)

    gt_lst = []        # per batch: (b, 6) labels for the 6 time-points
    pred_lst = []
    risk_scr_lst = []
    tcnv_lst = []
    indctr_lst = []
    nm_lst = []

    for sample in tst_loader:
        # Three input images per scan; their predictions are averaged below.
        views = [rearrange(sample[key].to(device), 'b 1 c h w -> b c h w')
                 for key in ('img1', 'img2', 'img3')]

        ### Forward Pass
        with torch.no_grad():
            t_batch = t_inp.unsqueeze(dim=0).repeat(views[0].shape[0], 1)
            ftrs = [encoder_model(view) for view in views]
            outputs = [classifier_model(ftr, t_batch) for ftr in ftrs]
            probs = [torch.sigmoid(logits) for _, logits in outputs]

        probs = [p.detach().cpu().numpy() for p in probs]
        risks = [rsk.detach().cpu().numpy() for rsk, _ in outputs]

        pred_lst.append((probs[0] + probs[1] + probs[2]) / 3.0)
        risk_scr_lst.append((risks[0] + risks[1] + risks[2]) / 3.0)  # risk predicted for the current input scan
        nm_lst.append(np.array(sample['nm']))
        gt_lst.append(sample['gt'])  # conversion within 6/12/18/24/30/36 months
        tcnv_lst.append(sample['tcnv'])
        indctr_lst.append(sample['indctr'])

    pred_lst = np.concatenate(pred_lst, axis=0)           # B,6
    risk_scr_lst = np.concatenate(risk_scr_lst, axis=0)   # B,1
    nm_lst = np.concatenate(nm_lst, axis=0)               # B,
    gt_lst = np.concatenate(gt_lst, axis=0)               # B,6
    tcnv_lst = np.concatenate(tcnv_lst, axis=0)           # B,
    indctr_lst = np.concatenate(indctr_lst, axis=0)       # B,

    ### Sort by scan name: this is the order the bootstrap indices refer to.
    order = np.argsort(nm_lst, axis=0)
    pred_lst = pred_lst[order, :]           # B,6
    risk_scr_lst = risk_scr_lst[order, :]   # B,1
    nm_lst = nm_lst[order]                  # B,
    gt_lst = gt_lst[order, :]               # B,6
    tcnv_lst = tcnv_lst[order]              # B,
    indctr_lst = indctr_lst[order]          # B,

    ### Check that the sorted nm_lst is the same as the dataset's ordering.
    ## This ordering is required to compute the bootstrap performance at the eye-level.
    # NOTE(review): on a mismatch the bootstrap metrics below are computed on
    # misaligned indices; the original only warns, which is kept here.
    if not np.array_equal(nm_lst, dataset.nm_lst):
        print('the nm_lst doesnot match !')

    ####### save the predictions   ##############
    np.savez('save_predictions.npz', pred_lst=pred_lst, risk_scr_lst=risk_scr_lst, nm_lst=nm_lst,
             gt_lst=gt_lst, tcnv_lst=tcnv_lst, indctr_lst=indctr_lst, indices=dataset.sampling_index)

    ##### Compute the scan-level performance ####
    # np.unique already returns its result sorted, so no extra sort is needed.
    thresh = np.unique(pred_lst.flatten())
    scn_lvl_auc, scn_lvl_ci, scn_lvl_bacc = compute_performance(
        indctr_lst, tcnv_lst, risk_scr_lst, pred_lst, gt_lst, thresh)
    scn_lvl_bacc = np.max(scn_lvl_bacc, axis=0)  # best balanced accuracy per time-point

    ###### Everything is sorted by name, so the pre-saved indices can be used ###
    # In each bootstrap only the scan from 1 visit per eye is randomly selected.
    # The indices are pre-saved in `sampling_index` so that the same bootstrap
    # samplings are used for different methods.
    indices = dataset.sampling_index
    c_lst = []
    auc_lst = []
    bacc_lst = []
    ############ This part takes some time: one full evaluation per re-sampling
    for k in range(len(indices)):
        if (k % 100) == 0:
            print(k)

        idx = indices[k]  # which scans are in the current bootstrap sampling
        tmp_auc, tmp_ci, tmp_bacc = compute_performance(
            indctr_lst[idx], tcnv_lst[idx], risk_scr_lst[idx, :],
            pred_lst[idx, :], gt_lst[idx, :], thresh)
        # tmp_auc: (1, 6)   tmp_ci: scalar   tmp_bacc: (len(thresh), 6)

        auc_lst.append(tmp_auc)
        c_lst.append(tmp_ci)
        bacc_lst.append(tmp_bacc)

    auc_lst = np.concatenate(auc_lst, axis=0)   # n_boot, 6
    c_lst = np.array(c_lst)                     # n_boot,
    bacc_lst = np.stack(bacc_lst, axis=0)       # n_boot, len(thresh), 6

    # For every time-point pick the threshold with the highest balanced
    # accuracy averaged over all bootstraps.
    mean_bacc = np.mean(bacc_lst, axis=0)       # len(thresh), 6
    best_idx = np.argmax(mean_bacc, axis=0)     # 6,
    print('thresholds Scan Level: ')
    # The label above used to be printed without the values it announces.
    print(thresh[best_idx])

    # bacc_lst[b, best_idx[j], j] for every bootstrap b and time-point j.
    bacc_lst = bacc_lst[:, best_idx, np.arange(bacc_lst.shape[2])]   # n_boot, 6
    print('bacc_lst: '+str(bacc_lst.shape))

    ######## Save the metrics #######
    np.savez('performance.npz', scn_lvl_auc=scn_lvl_auc, scn_lvl_ci=scn_lvl_ci, scn_lvl_bacc=scn_lvl_bacc,
             bootstrp_auc=auc_lst, bootstrp_ci=c_lst, bootstrp_bacc=bacc_lst)
    return

In [5]:
def worker_init_fn(worker_id):
    """Re-seed NumPy's global RNG in a DataLoader worker.

    The new seed is the first word of the current NumPy RNG state plus
    `worker_id`, wrapped into the 32-bit range `np.random.seed` accepts.
    Without the wrap, a state word close to 2**32 - 1 could push the sum out
    of range (ValueError) or overflow uint32, depending on the NumPy version.

    Args:
        worker_id: integer id of the worker, as passed by DataLoader.
    """
    base_seed = int(np.random.get_state()[1][0])
    np.random.seed((base_seed + worker_id) % (2 ** 32))
    
def count_params(model):
    """Total number of elements across the parameters of `model` that have
    `requires_grad` set."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

In [6]:
############### Data Loader #######################################
# shuffle=False and drop_last=False: every test scan is visited exactly once.
tst_data=test_dataset()
tst_loader=DataLoader(dataset=tst_data, batch_size=16, shuffle=False, num_workers=2,
                             pin_memory=False, drop_last=False, worker_init_fn=worker_init_fn)



################## Prepare the model  ###############################
encoder_model=Encoder_Network2D()
classifier_model=Classification_Network()

encoder_model.to(device)
classifier_model.to(device)

#######################################################################

# Load weights
# map_location makes the checkpoint loadable when `device` is the CPU even if
# the tensors were saved from a GPU.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source.
checkpoint=torch.load('src_weight_fld5_metric1.4942585858396749.pt', map_location=device)
encoder_model.load_state_dict(checkpoint['encoder_state_dict'])
classifier_model.load_state_dict(checkpoint['classifier_state_dict'])
del checkpoint

# test mode
encoder_model.eval()
classifier_model.eval()


complete_inference(tst_loader)

1887
0
100
200
300
400
500
600
700
800
900
thresholds Scan Level: 
bacc_lst: (1000, 6)
