In [None]:
import torch
import importlib
import os
import sys
import time
from sklearn.metrics import roc_auc_score

# Make the project-local packages importable. The entries are built from the
# current working directory, so they only resolve when the notebook is run
# from the directory that contains model/, losses/ and Dataloader/.
sys.path.append(os.getcwd()+'/model/')
sys.path.append(os.getcwd()+'/losses/')
sys.path.append(os.getcwd()+'/Dataloader/')


# NOTE(review): these star imports hide where names come from. Later cells use
# `np`, `DataLoader`, `Encoder_Architecture`, `Classification_Network` and
# `test_dataset`, none of which is imported by name in this cell; presumably
# they are exposed by the modules below -- verify, and prefer explicit imports.
from Encoder_model import *
from Classifier_Model import *
from Decoder_model import *
from Comparator_model import *
from Dense_Spatial_Transformation import * 
from MorphSSL_dataloader import *
from Test_dataloader_TTC import *

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
def complete_test(val_loader, feature_net=None, classifier_net=None, run_device=None):
    """Run inference over ``val_loader`` and collect predictions and labels.

    Each batch must be a mapping with the keys ``'I'`` (network input),
    ``'gt'`` (ground-truth labels) and ``'t'`` (values fed into the sigmoid
    read-out below). For every batch the encoder output is passed to the
    classifier, which returns three values ``(a, b, cmap)``; only ``a`` and
    ``b`` are used, and the prediction is ``sigmoid((t - b) / (a + 0.05))``.

    Parameters
    ----------
    val_loader : iterable
        Yields the batch mappings described above.
    feature_net, classifier_net : optional
        Encoder and classifier to evaluate. When omitted, the notebook-level
        ``feature_model`` / ``classifier_model`` globals are used, which keeps
        the original single-argument call working.
    run_device : optional
        Device the inputs are moved to. Defaults to the notebook-level
        ``device`` global.

    Returns
    -------
    (pred_lst, gt_lst) : tuple of numpy.ndarray
        Predictions and ground truth of all batches, concatenated along
        axis 0 in loader order.

    Raises
    ------
    ValueError
        If ``val_loader`` yields no batches.
    """
    # Fall back to the notebook globals so existing callers are unaffected.
    if feature_net is None:
        feature_net = feature_model
    if classifier_net is None:
        classifier_net = classifier_model
    if run_device is None:
        run_device = device

    # Put BOTH networks in eval mode; the original only switched the
    # classifier and relied on the caller having done so for the encoder.
    feature_net.eval()
    classifier_net.eval()

    pred_batches = []
    gt_batches = []

    for sample in val_loader:
        img = sample['I'].to(run_device)
        t = sample['t'].to(run_device)
        # The labels are never used on the device, so they are not moved there.
        gt = sample['gt']

        with torch.no_grad():
            ftr = feature_net(img)
            a, b, _cmap = classifier_net(ftr)
            # NOTE(review): the 0.05 offset presumably keeps the denominator
            # away from zero, assuming `a` is non-negative -- confirm.
            pred = torch.sigmoid((t - b) / (a + 0.05))

        pred_batches.append(pred.detach().cpu().numpy())
        gt_batches.append(gt.detach().cpu().numpy())

    if not pred_batches:
        # np.concatenate([]) would fail with a message that hides the cause.
        raise ValueError("complete_test: val_loader yielded no batches")

    pred_lst = np.concatenate(pred_batches, axis=0)
    gt_lst = np.concatenate(gt_batches, axis=0)  # (samples, n_horizons)
    return pred_lst, gt_lst
    

###############################################################################################################
    

In [None]:
def compute_metrics_tst(pred_lst, gt_lst, n_horizons=4, score_fn=None):
    """Compute one score per prediction horizon, skipping unlabeled samples.

    Column ``k`` of ``gt_lst`` / ``pred_lst`` holds the labels / predictions
    for horizon ``k`` (0, 6, 12 and 18 months for the default four columns).
    Samples whose ground truth equals ``-1`` in a column are excluded from
    that column's score.

    Parameters
    ----------
    pred_lst, gt_lst : array-like
        Arrays indexed as ``[:, k]``; a trailing singleton axis is squeezed
        away.
    n_horizons : int, optional
        Number of leading columns to score. Defaults to 4, the original
        hard-coded behavior.
    score_fn : callable, optional
        ``score_fn(y_true, y_score) -> float``. Defaults to
        ``roc_auc_score``.

    Returns
    -------
    numpy.ndarray
        Array of length ``n_horizons`` with one score per column.

    Notes
    -----
    Exceptions raised by ``score_fn`` are propagated unchanged.
    """
    if score_fn is None:
        score_fn = roc_auc_score

    gt_arr = np.asarray(gt_lst)
    pred_arr = np.asarray(pred_lst)

    scores = []
    for k in range(n_horizons):
        # atleast_1d guards the single-sample case, where squeeze alone would
        # collapse the column to a 0-d array that cannot be masked.
        gt_k = np.atleast_1d(np.squeeze(gt_arr[:, k]))
        pred_k = np.atleast_1d(np.squeeze(pred_arr[:, k]))

        labeled = gt_k != -1  # -1 marks samples without a label here
        scores.append(score_fn(gt_k[labeled], pred_k[labeled]))

    return np.array(scores)

In [None]:
def worker_init_fn(worker_id):
    """Re-seed NumPy's global RNG in a DataLoader worker.

    The new seed is the first word of the current global RNG state plus
    ``worker_id``, so workers that start from the same state end up with
    different NumPy random streams.

    The sum is computed with Python integers and reduced modulo 2**32,
    because ``np.random.seed`` only accepts seeds in ``[0, 2**32 - 1]``.
    Adding ``worker_id`` directly to the uint32 state word can leave that
    range (or wrap with an overflow warning, depending on the NumPy version).
    """
    base_seed = int(np.random.get_state()[1][0])
    np.random.seed((base_seed + worker_id) % (2 ** 32))
    
def count_params(model):
    """Return the total element count of ``model``'s trainable parameters.

    Only parameters whose ``requires_grad`` flag is set contribute; frozen
    parameters are skipped.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

In [None]:
# Cross-validation fold to evaluate; selects both the split file and the
# checkpoint file name below.
fld=3
# Not referenced in the visible part of this notebook; kept in case another
# cell relies on it.
eps=10**(-8)

################## Prepare Dataloaders ###############################
# NOTE(review): absolute, user-specific paths -- the notebook only runs where
# these exist. Consider moving them into a configurable data directory.
tst_data=test_dataset(fold='/msc/home/achakr83/PINNACLE/SSL_training/June30/cross_validation_splits_new/fold'+str(fld)+'.npz',
                      img_pth='/msc/home/achakr83/PINNACLE/preprocessed_Downstream_images2/')
# shuffle=False and drop_last=False: every test sample is visited once, in
# dataset order.
tst_loader=DataLoader(dataset=tst_data, batch_size=16, shuffle=False, num_workers=2, 
                             pin_memory=False, drop_last=False, worker_init_fn=worker_init_fn)


################## Prepare the model  ############################### 
feature_model=Encoder_Architecture(base_chnls=16, out_ftr_dim=(64*2))
feature_model.to(device)

classifier_model=Classification_Network()
classifier_model.to(device)


# Load the best performing weights for this fold. The file name is relative,
# so it is resolved against the current working directory.
wt_nm=str(fld)+'_best_weight_frozen.pt'
# map_location=device lets a checkpoint saved on a GPU be loaded on a
# CPU-only machine, matching the CPU fallback used when `device` is chosen.
checkpoint = torch.load(wt_nm, map_location=device)
feature_model.load_state_dict(checkpoint['feature_model_state_dict_model'])
classifier_model.load_state_dict(checkpoint['classifier_model_state_dict_model'])
del checkpoint  # release the loaded state dicts once they are copied into the models
feature_model.eval()
classifier_model.eval()    


# Run inference on the whole test set, then score each horizon.
pred_lst, gt_lst=complete_test(tst_loader)
roc=compute_metrics_tst(pred_lst, gt_lst)

In [None]:
# Report the four per-horizon AUROC values held in `roc`, four decimals each.
print("################################################")
print("AUROC")
print(f"0 month:{roc[0]:.4f}")
print(f"6 month:{roc[1]:.4f}")
print(f"12 month:{roc[2]:.4f}")
print(f"18 month:{roc[3]:.4f}")
print("################################################")