In [None]:
import os, datetime
from pathlib import Path
import yaml
import numpy as np
import pandas as pd
import torch
import scipy
import sklearn
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import scikit_posthocs as sp
import torchmetrics
import pytorch_lightning as L
from hnc_foundation_dm_prediction import data_prep as dp
from hnc_foundation_dm_prediction.pytorch.run_model_lightning import RunModel
from hnc_foundation_dm_prediction.pytorch.user_metrics import MMetric
from MLstatkit.stats import Delong_test

In [None]:
# Metric objects shared by the evaluation cells below; add or remove entries as desired.
# They are stateful: values passed in persist until the metric is reset.
_clf = torchmetrics.classification
auc_fn, ap_fn = _clf.BinaryAUROC(), _clf.BinaryAveragePrecision()
spe_fn, sen_fn = _clf.BinarySpecificity(), _clf.BinaryRecall()
# Multilabel AUROC over 4 labels, reported per label (average=None).
mul_auc_fn = _clf.MultilabelAUROC(num_labels=4, average=None)
roc_fn, pr_fn = _clf.BinaryROC(), _clf.BinaryPrecisionRecallCurve()

In [None]:
# Clear the accumulated state of every metric object. Run this before each new evaluation
# of probabilities, because values entered into the metrics persist until they are reset.
for _metric_fn in (auc_fn, ap_fn, spe_fn, sen_fn, roc_fn, pr_fn, mul_auc_fn):
    _metric_fn.reset()

In [None]:
# Run directories and file names used by the comparison cells below.
# Exactly one `no_graph_log_dir` and one `graph_log_dir` should be active; the commented
# lines are alternative runs kept for quick switching.
#no_graph_log_dir = './logs/lightning_no_graph_feat64_true_weight3_balance_real_dp3_l21e6_11180_minmax_rot3_balance_newclinical_nolrfinder_nocensor_v54'
#no_graph_log_dir = './logs/lightning_gtvp_gcn_foundation_rad_nopool_weight3_32_dp3_v86'
#no_graph_log_dir = './logs/lightning_gtvp_gcn_rad_nopool_weight3_32_dp3_v92'
#no_graph_log_dir = './logs/lightning_gtvp_gcn_rad_foundation_image_nopool_weight3_32_dp0_v99'
no_graph_log_dir = './logs/lightning_gtvp_image_vit_nomask_nopool_weight7_22_dp2_v110'
#no_graph_log_dir = './logs/lightning_gtvp_spottune18_weight7_dp3_v77'
top_graph_dir = './logs/lightning_graph_feat64_true_weight3_balance_real_dp3_l21e6_11180_minmax_rot3_balance_newclinical_nolrfinder_nocensor_v55'
tmp_dir = './logs/'
#graph_log_dir = './logs/lightning_graph_feat64_weight1_balance_real_dp3_l21e6_11180_minmax_rot3_balance_newclinical_nolrfinder_nocensor_v49'
spottune18_dir = './logs/lightning_spottune18_graph_weight1_v69'
#graph_log_dir = './logs/lightning_spottune18_graph_weight1_v69'
#graph_log_dir = './logs/lightning_gtv_gan_foundation_nopool_weight3_dp3_v83'
#graph_log_dir = './logs/lightning_gtv_gan_foundation_rad_nopool_weight3_dp3_v84'
#graph_log_dir = './logs/lightning_gtv_gcn_foundation_rad_nopool_weight3_dp3_v85'
#graph_log_dir = './logs/lightning_gtv_gcn_image_foundation_rad_nopool_weight3_dp3_v85'
#graph_log_dir = './logs/lightning_gtvp_gcn_foundation_nopool_weight3_32_dp3_v87'
#graph_log_dir = './logs/lightning_gtvp_gcn_rad_nopool_weight3_32_dp3_v88'
#graph_log_dir = './logs/lightning_gtvp_gcn_rad_foundation_nopool_weight3_32_dp0_v96'
#graph_log_dir = './logs/lightning_gtvp_gcn_rad_nopool_weight3_32_dp0_v98'
graph_log_dir = './logs/lightning_gtvp_foundation_image_vit_nomask_nopool_weight7_22_dp2_v111'
#graph_log_dir = './logs/lightning_gtv_plusradiomics_gat_vit_weight3_dp3_v78'
#graph_log_dir = './logs/lightning_undirected_edge_graph_feat64_true_weight3_balance_real_dp3_l21e6_11180_minmax_rot3_balance_newclinical_nolrfinder_nocensor_v56'
#graph_log_dir = './logs/lightning_star_graph_v59'
#graph_log_dir = './logs/lightning_star_graph_gat_vit_v61'
#graph_log_dir = './logs/lightning_star_graph_gat_v60'
#graph_log_dir = './logs/lightning_reverse_edge_graph_feat64_true_weight3_balance_real_dp3_l21e6_11180_minmax_rot3_balance_newclinical_nolrfinder_nocensor_v56'
#graph_log_dir = './logs/lightning_undirected_edge_graph_feat64_true_weight3_balance_real_dp4_l21e6_11180_minmax_rot3_balance_newclinical_nolrfinder_nocensor_v57'
#graph_log_dir = './logs/lightning_undirected_edge_graph_multi_label_v58'

# NOTE: despite the `_log_dir` suffix, these three hold the full path to a predictions pickle.
foundation_log_dir = 'logs/lightning_gtvp_foundation_linear_mask_nopool_weight7_22_dp2_v124/metric_dfs/test_predictions.pkl'
radiomics_log_dir = 'logs/lightning_gtvp_rad_linear_mask_nopool_weight7_22_dp2_v123/metric_dfs/test_predictions.pkl'
image_log_dir = 'logs/lightning_gtvp_image_vit_linear_mask_nopool_weight7_22_dp2_v120/metric_dfs/test_predictions.pkl'
# TODO: these were left as bare, undefined names (NameError on execution). They are
# placeholders until the corresponding prediction paths are filled in.
foundation_image_log_dir = None
rad_image_log_dir = None
foundation_rad_log_dir = None
# Prediction pickles, relative to a run directory.
test_pred_file = 'metric_dfs/test_predictions.pkl'
val_pred_file = 'metric_dfs/val_predictions.pkl'
# Project metric combining sensitivity and specificity (called as m_fn(sen, spe) below).
# NOTE(review): 0.6 / 0.4 are presumably the two weights -- confirm against MMetric.
m_fn = MMetric(0.6, 0.4)

In [None]:
# Test predictions of the two runs that are compared head-to-head below.
# NOTE: pd.read_pickle executes arbitrary code from the file; only load trusted run outputs.
test_graph_pred = pd.read_pickle(os.path.join(graph_log_dir, test_pred_file))
test_no_graph_pred = pd.read_pickle(os.path.join(no_graph_log_dir, test_pred_file))

# The original assignments were left incomplete (SyntaxError). The single-modality
# `*_log_dir` variables already contain the full path to test_predictions.pkl.
foundation_pred = pd.read_pickle(foundation_log_dir)
radiomics_pred = pd.read_pickle(radiomics_log_dir)
image_pred = pd.read_pickle(image_log_dir)

In [None]:
# Gather the five per-fold prediction vectors of both runs and derive one weight per fold
# from its M-metric (m_fn applied to sensitivity and specificity of that fold).
# NOTE(review): predictions are stored under the 'val_auc' key of the test pickle, and the
# graph run is scored against targets[0] while the no-graph run uses targets[idx].
#targets = test_graph_pred['targets'][0].reshape((test_graph_pred['val_auc'][0].size()[0], 4))
targets = test_graph_pred['targets'][0]
test_graph_avg, test_no_graph_avg = [], []
test_graph_m, test_no_graph_m = [], []
for idx in range(5):
    graph_fold_pred = test_graph_pred['val_auc'][idx]
    no_graph_fold_pred = test_no_graph_pred['val_auc'][idx]
    no_graph_fold_target = test_no_graph_pred['targets'][idx]

    test_graph_avg.append(list(graph_fold_pred))
    test_no_graph_avg.append(list(no_graph_fold_pred))

    sen_g = sen_fn(graph_fold_pred, targets)
    spe_g = spe_fn(graph_fold_pred, targets)
    test_graph_m.append(m_fn(sen_g, spe_g))

    sen_ng = sen_fn(no_graph_fold_pred, no_graph_fold_target)
    spe_ng = spe_fn(no_graph_fold_pred, no_graph_fold_target)
    test_no_graph_m.append(m_fn(sen_ng, spe_ng))

test_graph_m = torch.tensor(test_graph_m)
test_no_graph_m = torch.tensor(test_no_graph_m)

# Normalise the fold scores so the weights of each run sum to one.
test_graph_weight = test_graph_m / test_graph_m.sum()
test_no_graph_weight = test_no_graph_m / test_no_graph_m.sum()

# Repeat each fold weight once per sample so the weight table matches the prediction table.
n_graph_samples = len(test_graph_pred['targets'][0])
n_no_graph_samples = len(test_no_graph_pred['targets'][0])
test_graph_weight_update = [[test_graph_weight[idx]] * n_graph_samples for idx in range(5)]
test_no_graph_weight_update = [[test_no_graph_weight[idx]] * n_no_graph_samples for idx in range(5)]


In [None]:
# Collapse the per-fold prediction lists into one weighted average along the fold axis.
# NOTE(review): this rebinds test_graph_avg / test_no_graph_avg from the per-fold lists built
# in the previous cell to the averaged arrays, so this cell is not idempotent -- re-run the
# previous cell before re-running this one.
test_graph_avg = np.average(test_graph_avg, axis=0, weights=test_graph_weight_update)    
test_no_graph_avg = np.average(test_no_graph_avg, axis=0, weights=test_no_graph_weight_update)    

In [None]:
# List every predictions pickle under ./logs, ordered by the trailing version tag of the
# run directory (string order, e.g. 'v110' sorts before 'v99').
# `log_path` is defined here because the cell that originally defined it comes later in the
# notebook, which made this cell fail on a fresh kernel.
log_path = Path('./logs')
for pred_file in sorted(log_path.rglob('metric_dfs/*predictions.pkl'), key=lambda x: x.as_posix().split('/')[1].split('_')[-1]):
    print(pred_file)

In [None]:
# Show every row when a DataFrame is displayed (no truncation).
pd.set_option('display.max_rows', None)

In [None]:
# Scan ./logs for per-run test predictions, average the five fold predictions of each run,
# and tabulate AUC / AP / sensitivity / specificity of the averaged prediction.
# To restrict the comparison to specific runs, iterate over an explicit list of
# '<run>/metric_dfs/test_predictions.pkl' paths instead of the rglob below.
MIN_MODEL_VERSION = 90  # runs whose trailing '_v<number>' tag is below this are skipped
N_FOLDS = 5

log_path = Path('./logs')
metric_dict = {}
targets = None
target_folds = []
pred_dict = {}       # run name -> fold-averaged prediction array
pred_fold_dict = {}  # run name -> list of per-fold prediction lists
for pred_file in sorted(log_path.rglob('metric_dfs/*predictions.pkl'), key=lambda x: x.as_posix().split('/')[1].split('_')[-1]):
    if 'test_predictions' not in pred_file.as_posix():
        continue
    model = pred_file.as_posix().split('/')[1]
    version_tag = model.split('_')[-1].split('v')[-1]
    if not version_tag.isdigit():
        # Directory name does not end in '_v<number>'; int() used to raise ValueError here.
        continue
    if int(version_tag) < MIN_MODEL_VERSION:
        continue
    predictions = pd.read_pickle(pred_file)
    if 'val_auc' not in predictions.keys():
        continue
    # NOTE(review): targets is overwritten per run; this assumes every run shares the same
    # test targets -- confirm.
    targets = predictions['targets'][0]
    fold_preds = [list(predictions['val_auc'][idx]) for idx in range(N_FOLDS)]
    pred_fold_dict[model] = fold_preds
    pred_avg = np.average(fold_preds, axis=0)
    pred_dict[model] = pred_avg
    auc = auc_fn(torch.tensor(pred_avg), targets)
    ap = ap_fn(torch.tensor(pred_avg), targets.to(torch.long))
    sen = sen_fn(torch.tensor(pred_avg), targets)
    spe = spe_fn(torch.tensor(pred_avg), targets)
    metric_dict[model] = [float(auc), float(ap), float(sen), float(spe)]

# One row per run. Building the frame directly replaces the previous all-NaN frame that was
# filled row by row (which also left the columns with object dtype).
metric_df = pd.DataFrame.from_dict(metric_dict, orient='index', columns=['AUC', 'AP', 'SEN', 'SPE'])
metric_df



In [None]:
# Late-fusion ensembles: unweighted mean of the fold-averaged predictions of the
# single-modality runs, stored back into pred_dict under a descriptive key.
_foundation_run = 'lightning_gtvp_foundation_linear_mask_nopool_weight7_22_dp2_v124'
_rad_run = 'lightning_gtvp_rad_linear_mask_nopool_weight7_22_dp2_v123'
_image_run = 'lightning_gtvp_image_vit_linear_mask_nopool_weight7_22_dp2_v120'

_ensembles = {
    'foundation_rad_avg': (_foundation_run, _rad_run),
    'foundation_image_avg': (_foundation_run, _image_run),
    'rad_image_avg': (_rad_run, _image_run),
    'foundation_rad_image_avg': (_foundation_run, _rad_run, _image_run),
}
for _ensemble_name, _members in _ensembles.items():
    pred_dict[_ensemble_name] = np.average([pred_dict[_run] for _run in _members], axis=0)


In [None]:
# Runs compared pairwise with the DeLong test (MLstatkit) on their fold-averaged predictions.
model_1 = 'lightning_gtvp_rad_image_vit_linear_nomask_nopool_weight7_22_dp2_v114'
model_2 = 'lightning_gtvp_image_vit_linear_nomask_nopool_weight7_22_dp2_v114'
model_3 = 'lightning_gtvp_foundation_image_vit_linear_nomask_nopool_weight7_22_dp2_v115' 
model_4 = 'lightning_gtvp_rad_foundation_image_vit_linear_nomask_nopool_weight7_22_dp2_v116'
model_5 = 'lightning_gtvp_image_vit_linear_mask_nopool_weight7_22_dp2_v120'
model_6 = 'lightning_gtvp_rad_image_vit_linear_mask_nopool_weight7_22_dp2_v121'
model_7 = 'lightning_gtvp_foundation_image_vit_linear_mask_nopool_weight7_22_dp2_v122'
model_8 = 'lightning_gtvp_rad_foundation_image_vit_linear_mask_nopool_weight7_22_dp2_v123'
model_9 = 'lightning_gtvp_rad_linear_mask_nopool_weight7_22_dp2_v123'
model_10 = 'lightning_gtvp_foundation_linear_mask_nopool_weight7_22_dp2_v124'

# An explicit list replaces the globals().get(f'model_{i}') lookups, which silently yield
# None (and then a confusing KeyError) when a name is missing.
delong_models = [model_1, model_2, model_3, model_4, model_5,
                 model_6, model_7, model_8, model_9, model_10]
# Every ordered pair is tested, so each unordered pair appears twice (A vs B and B vs A).
for idx, model_a in enumerate(delong_models, start=1):
    for jdx, model_b in enumerate(delong_models, start=1):
        if idx == jdx:
            continue
        print(model_a, model_b)
        print(Delong_test(targets, pred_dict[model_a], pred_dict[model_b]))

In [None]:
# Paired t-test between the two averaged prediction vectors; prints the p-value.
# NOTE(review): this compares the predicted scores themselves, not a performance metric --
# confirm this is the intended comparison.
t, p = scipy.stats.ttest_rel(test_graph_avg, test_no_graph_avg)
print(p)

In [None]:
# DeLong test per fold between the graph and no-graph test predictions; prints each p-value.
for idx in range(5):
    # (validation-set variant, disabled; val_graph_pred / val_no_graph_pred are not loaded in this notebook)
    #t1, p1 = Delong_test(val_graph_pred['targets'][idx], val_graph_pred['val_auc'][idx], val_no_graph_pred['val_auc'][idx])
    t2, p2 = Delong_test(test_graph_pred['targets'][idx], test_graph_pred['val_auc'][idx], test_no_graph_pred['val_auc'][idx])
    print(p2)

In [None]:
# DeLong test between the two weighted-average predictions.
# NOTE(review): targets are taken from fold index 1 here while other cells use index 0;
# harmless only if every fold stores the same test targets -- confirm.
Delong_test(test_graph_pred['targets'][1], test_graph_avg, test_no_graph_avg)

In [None]:
# Score the weighted-average predictions of both runs by calling the metric objects directly.
#graph_mul_auc = mul_auc_fn(torch.tensor(test_graph_avg), targets.to(torch.long))
graph_scores = torch.tensor(test_graph_avg)
no_graph_scores = torch.tensor(test_no_graph_avg)
no_graph_targets = test_no_graph_pred['targets'][0]

graph_auc = auc_fn(graph_scores, targets)
no_graph_auc = auc_fn(no_graph_scores, no_graph_targets)
# Average precision is given integer targets.
graph_ap = ap_fn(graph_scores, targets.to(torch.long))
no_graph_ap = ap_fn(no_graph_scores, no_graph_targets.to(torch.long))
graph_sen = sen_fn(graph_scores, targets)
no_graph_sen = sen_fn(no_graph_scores, no_graph_targets)
graph_spe = spe_fn(graph_scores, targets)
no_graph_spe = spe_fn(no_graph_scores, no_graph_targets)

for _label, _graph_value, _no_graph_value in (
    ('auc:', graph_auc, no_graph_auc),
    ('ap:', graph_ap, no_graph_ap),
    ('sen:', graph_sen, no_graph_sen),
    ('spe:', graph_spe, no_graph_spe),
):
    print(_label, _graph_value, _no_graph_value)

In [None]:
def _fresh_metric(metric_fn, scores, target):
    """Reset `metric_fn`, feed it a single (scores, target) pair and return the computed value."""
    metric_fn.reset()
    metric_fn.update(torch.tensor(scores), target)
    return metric_fn.compute()


# Same metrics as the previous cell, but computed from a clean metric state each time.
_graph_target = test_graph_pred['targets'][0]
_no_graph_target = test_no_graph_pred['targets'][0]

graph_auc = _fresh_metric(auc_fn, test_graph_avg, _graph_target)
no_graph_auc = _fresh_metric(auc_fn, test_no_graph_avg, _no_graph_target)

# Average precision is given integer targets.
graph_ap = _fresh_metric(ap_fn, test_graph_avg, _graph_target.to(torch.long))
no_graph_ap = _fresh_metric(ap_fn, test_no_graph_avg, _no_graph_target.to(torch.long))

graph_sen = _fresh_metric(sen_fn, test_graph_avg, _graph_target)
no_graph_sen = _fresh_metric(sen_fn, test_no_graph_avg, _no_graph_target)

graph_spe = _fresh_metric(spe_fn, test_graph_avg, _graph_target)
no_graph_spe = _fresh_metric(spe_fn, test_no_graph_avg, _no_graph_target)

print('auc:', graph_auc, no_graph_auc)
print('ap:', graph_ap, no_graph_ap)
print('sen:', graph_sen, no_graph_sen)
print('spe:', graph_spe, no_graph_spe)

In [None]:
# Inspect the fold-averaged predictions of every run (displays the full dict).
pred_dict

In [None]:
# Per-fold AUC / AP / sensitivity / specificity for every run in pred_fold_dict,
# each stored as {run name: [value for fold 0..4]}.
fold_auc, fold_ap, fold_sen, fold_spe = {}, {}, {}, {}

for model, fold_preds in pred_fold_dict.items():
    auc_values, ap_values, sen_values, spe_values = [], [], [], []
    for idx in range(5):
        fold_scores = fold_preds[idx]
        auc_values.append(auc_fn(torch.tensor(fold_scores), targets))
        ap_values.append(ap_fn(torch.tensor(fold_scores), targets.to(torch.long)))
        sen_values.append(sen_fn(torch.tensor(fold_scores), targets))
        spe_values.append(spe_fn(torch.tensor(fold_scores), targets))
    fold_auc[model] = auc_values
    fold_ap[model] = ap_values
    fold_sen[model] = sen_values
    fold_spe[model] = spe_values

fold_auc

In [None]:
# Inspect the per-fold specificity of every run.
fold_spe

In [None]:
# AUC / AP / sensitivity / specificity of the fold-averaged prediction of every entry in
# pred_dict, keyed by entry name.
model_auc, model_ap, model_sen, model_spe = {}, {}, {}, {}
for model, avg_pred in pred_dict.items():
    model_auc[model] = auc_fn(torch.tensor(avg_pred), targets)
    model_ap[model] = ap_fn(torch.tensor(avg_pred), targets.to(torch.long))
    model_sen[model] = sen_fn(torch.tensor(avg_pred), targets)
    model_spe[model] = spe_fn(torch.tensor(avg_pred), targets)
model_auc

In [None]:
# Print the four metrics of every entry in pred_dict, indented under its name.
for model in pred_dict:
    print(model)
    for _label, _table in (('auc', model_auc), ('ap', model_ap), ('sen', model_sen), ('spe', model_spe)):
        print(f'    {_label}: {_table[model]}')

In [None]:
# ROC curves of the weighted-average predictions: graph run ('CNN+GNN') vs no-graph run
# ('CNN-only'), drawn on one axes and saved as roc_curve_GNN.pdf.
from sklearn.metrics import RocCurveDisplay
target = test_graph_pred['targets'][0]
#figure = plt.figure()
#axes = figure.add_subplot(111)
# NOTE(review): the name `display` shadows IPython's display() helper for the rest of the notebook.
display = RocCurveDisplay.from_predictions(target, test_graph_avg, name='CNN+GNN', plot_chance_level=True)
# Second curve is drawn onto the axes created by the first call.
RocCurveDisplay.from_predictions(target, test_no_graph_avg, name='CNN-only', plot_chance_level=False, ax=display.ax_)
plt.grid(visible=True, which='both')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
handles, labels = plt.gca().get_legend_handles_labels()
# Give each line on the axes its own line style, in drawing order.
line_styles = ['-', '--', '-.', ':']
for line, ls in zip(display.ax_.get_lines(), line_styles):
    line.set_linestyle(ls)
# Reorder the legend so that entry 1 is listed last.
plt.legend([handles[idx] for idx in [0, 2, 1]], [labels[idx] for idx in [0, 2, 1]])
display.figure_.savefig(f"roc_curve_GNN.pdf", dpi=600)
plt.show()

In [None]:
# AUC values written into the ROC legend by the next cell; indexed like the legend handles.
# NOTE(review): hard-coded -- presumably copied from computed results (index 1 looks like the
# chance level); confirm they match the current pred_dict.
metrics = [0.791, 0.5, 0.772, 0.753, 0.794, 0.788] 

In [None]:
# ROC curves of the fold-averaged predictions of five runs on one axes, with the legend AUC
# values replaced by the three-decimal numbers in `metrics` (defined in the previous cell).
from sklearn.metrics import RocCurveDisplay

# (pred_dict key, legend name) in drawing order.
roc_series = [
    ('lightning_gtvp_foundation_linear_mask_nopool_weight7_22_dp2_v124', 'Foundation Emb.'),
    ('lightning_gtvp_rad_linear_mask_nopool_weight7_22_dp2_v123', 'Radiomics'),
    ('lightning_gtvp_image_vit_linear_mask_nopool_weight7_22_dp2_v120', 'Image w/mask'),
    ('lightning_gtvp_rad_image_vit_linear_mask_nopool_weight7_22_dp2_v121', 'Rad. + Image w/mask'),
    ('lightning_gtvp_image_spottune_nograph_nopool_weight7_22_dp2_v143', 'Spottune: Image w/o mask'),
]

# The first call creates the figure and the chance-level line; the rest draw onto its axes.
# NOTE(review): the name `display` shadows IPython's display() helper; it is kept because
# later cells read `display.figure_`.
first_key, first_name = roc_series[0]
display = RocCurveDisplay.from_predictions(
    targets, np.ravel(pred_dict[first_key]), name=first_name, plot_chance_level=True)
for run_key, curve_name in roc_series[1:]:
    RocCurveDisplay.from_predictions(
        targets, np.ravel(pred_dict[run_key]), name=curve_name, plot_chance_level=False, ax=display.ax_)

plt.grid(visible=True, which='both')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
handles, labels = plt.gca().get_legend_handles_labels()
line_styles = ['-', '--', ':', '-.', 'dashed', '-.']
for line, ls in zip(display.ax_.get_lines(), line_styles):
    line.set_linestyle(ls)


def _relabel_auc(label, auc_value):
    """Return `label` with its trailing '(AUC = ...)' part rewritten to three decimals."""
    # Splitting on the '(AUC = ' marker keeps the spacing intact; replacing the last five
    # characters dropped the space in 'Chance level (AUC = 0.5)'.
    return f"{label.rsplit(' (AUC = ', 1)[0]} (AUC = {auc_value:.3f})"


legend_order = [0, 2, 3, 4, 5, 1]  # list entry 1 last
plt.legend([handles[idx] for idx in legend_order],
           [_relabel_auc(labels[idx], metrics[idx]) for idx in legend_order])
display.figure_.savefig("roc_curve_foundation_v3.png", dpi=600)
plt.show()

In [None]:
# Save the current `display` figure as a PNG.
# NOTE(review): when the notebook is run top to bottom, `display` has been rebound by the
# foundation ROC cell above, so this writes that figure under the GNN file name -- confirm.
display.figure_.savefig(f"roc_curve_GNN.png", dpi=600)

In [None]:
# List the run names for which per-fold metrics were computed.
fold_auc.keys()

In [None]:
# Standard deviation of the per-fold AUC values of one run.
torch.tensor(fold_auc['lightning_gtvp_rad_image_vit_linear_mask_nopool_weight7_22_dp2_v121']).std()

### Testing models

In [None]:
# Configuration for re-running saved checkpoints.
# NOTE: this rebinds graph_log_dir / no_graph_log_dir / test_pred_file set earlier in the notebook.
graph_log_dir = './logs/lightning_graph_feat64_weight9_balance_real_dp3_l21e6_11180_minmax_rot3_balance_newclinical_nolrfinder_nocensor_v50'
no_graph_log_dir = './logs/lightning_no_graph_feat64_weight7p5_balance_real_dp3_l21e6_11180_minmax_rot3_balance_newclinical_nolrfinder_nocensor_v52'
test_pred_file = 'metric_dfs/test_predictions.pkl'
# Hyper-parameter files written by the Lightning CSV logger for fold 0 of each run.
gnn_config_file = os.path.join(graph_log_dir, 'csvlog_fold_0','lightning_logs','version_1','hparams.yaml')
cnn_config_file = os.path.join(no_graph_log_dir, 'csvlog_fold_0','lightning_logs','version_0','hparams.yaml')
# Checkpoint file names of the graph run, one per fold (list index = fold index).
# File names starting with 'model_m_'.
graph_m_models = [
    'model_m_epoch=47_val_loss=1.33_val_auc=0.74_val_m=0.75.ckpt',
    'model_m_epoch=47_val_loss=1.24_val_auc=0.74_val_m=0.75.ckpt',
    'model_m_epoch=56_val_loss=1.41_val_auc=0.76_val_m=0.75.ckpt',
    'model_m_epoch=14_val_loss=1.23_val_auc=0.74_val_m=0.77.ckpt',
    'model_m_epoch=73_val_loss=1.46_val_auc=0.77_val_m=0.81.ckpt',
]
# File names starting with 'model_loss_'; these are the ones loaded by the cells below.
graph_loss_models = [
    'model_loss_epoch=04_val_loss=1.28_val_auc=0.70_val_m=0.62.ckpt',
    'model_loss_epoch=50_val_loss=1.12_val_auc=0.75_val_m=0.69.ckpt',
    'model_loss_epoch=58_val_loss=1.20_val_auc=0.78_val_m=0.60.ckpt',
    'model_loss_epoch=14_val_loss=1.23_val_auc=0.74_val_m=0.77.ckpt',
    'model_loss_epoch=58_val_loss=1.41_val_auc=0.75_val_m=0.72.ckpt',
]

In [None]:
# Load the saved hyper-parameters of both runs. Context managers close the files, which the
# previous bare open(...) calls never did.
with open(gnn_config_file, 'r') as gnn_fh:
    gnn_config = yaml.safe_load(gnn_fh)
with open(cnn_config_file, 'r') as cnn_fh:
    cnn_config = yaml.safe_load(cnn_fh)



In [None]:
# Build one RunModel per run from the 'config' section of its saved hyper-parameters.
gnn_model = RunModel(config=gnn_config['config'])
cnn_model = RunModel(config=cnn_config['config'])

# Optional override, disabled here (the next cell applies it to gnn_model only).
#gnn_model.config['remove_censored'] = False
#cnn_model.config['remove_censored'] = False

In [None]:
# Turn off 'remove_censored' for the GNN run and echo the value to confirm it.
# NOTE(review): the same override is not applied to cnn_model -- confirm that is intended.
gnn_model.config['remove_censored'] = False
gnn_model.config['remove_censored']

In [None]:
# Prepare the GNN run for evaluation by calling its setup methods in order, then attach a
# Lightning trainer that logs under fold index 5.
gnn_model.set_model()
gnn_model.set_data()
gnn_model.set_train_test_split_challenge()
gnn_model.set_data_module()
gnn_model.set_callbacks(5)
idx = 5

# Use the configured GPU device(s) only when CUDA is available.
_gnn_devices = gnn_model.config['gpu_device'] if torch.cuda.is_available() else None
_gnn_loggers = [
    L.loggers.CSVLogger(save_dir=os.path.join(gnn_model.log_dir, f"csvlog_fold_{idx}")),
    L.loggers.TensorBoardLogger(save_dir=os.path.join(gnn_model.log_dir, f"tb_fold_{idx}")),
]
gnn_model.trainer = L.Trainer(
    max_epochs=gnn_model.config['n_epochs'],
    accelerator="auto",
    devices=_gnn_devices,
    logger=_gnn_loggers,
    callbacks=gnn_model.callbacks,
)

In [None]:
# Prepare the CNN run for evaluation by calling its setup methods in order, then attach a
# Lightning trainer that logs under fold index 5.
cnn_model.set_model()
cnn_model.set_data()
cnn_model.set_train_test_split_challenge()
cnn_model.set_data_module()
cnn_model.set_callbacks(5)
idx = 5

# Use the configured GPU device(s) only when CUDA is available.
_cnn_devices = cnn_model.config['gpu_device'] if torch.cuda.is_available() else None
_cnn_loggers = [
    L.loggers.CSVLogger(save_dir=os.path.join(cnn_model.log_dir, f"csvlog_fold_{idx}")),
    L.loggers.TensorBoardLogger(save_dir=os.path.join(cnn_model.log_dir, f"tb_fold_{idx}")),
]
cnn_model.trainer = L.Trainer(
    max_epochs=cnn_model.config['n_epochs'],
    accelerator="auto",
    devices=_cnn_devices,
    logger=_cnn_loggers,
    callbacks=cnn_model.callbacks,
)

In [None]:
# Run the Lightning test loop for the GNN model on fold `model_idx`, loading that fold's
# 'model_loss_' checkpoint from the graph run directory.
model_idx = 0
gnn_model.trainer.test(gnn_model.model,
                       datamodule=gnn_model.data_module_cross_val[model_idx], 
                       ckpt_path=os.path.join(graph_log_dir, 
                                              f"top_models_fold_{model_idx}", 
                                              graph_loss_models[model_idx])
                      )

In [None]:
# Run the Lightning test loop for the CNN model on fold `model_idx`.
# NOTE(review): the checkpoint path is built from graph_log_dir and graph_loss_models, i.e.
# the GNN run's checkpoint is loaded into the CNN model. This looks like a copy-paste slip;
# it probably should use no_graph_log_dir and that run's checkpoint names -- confirm.
model_idx = 0
cnn_model.trainer.test(cnn_model.model,
                       datamodule=cnn_model.data_module_cross_val[model_idx], 
                       ckpt_path=os.path.join(graph_log_dir, 
                                              f"top_models_fold_{model_idx}", 
                                              graph_loss_models[model_idx])
                      )

In [None]:
# Collect the GNN model's predictions on the test dataloader of fold `model_idx`, using that
# fold's 'model_loss_' checkpoint. The result is stored in `test_pred`.
model_idx = 0
test_pred = gnn_model.trainer.predict(gnn_model.model,
                       gnn_model.data_module_cross_val[model_idx].test_dataloader(), 
                       ckpt_path=os.path.join(graph_log_dir, 
                                              f"top_models_fold_{model_idx}", 
                                              graph_loss_models[model_idx])
                      )

In [None]:
# Collect the CNN model's predictions on the test dataloader of fold `model_idx`.
# NOTE(review): this overwrites `test_pred` from the GNN cell above, and the checkpoint path
# again points at the graph run (graph_log_dir / graph_loss_models) -- confirm both.
model_idx = 0
test_pred = cnn_model.trainer.predict(cnn_model.model,
                       cnn_model.data_module_cross_val[model_idx].test_dataloader(), 
                       ckpt_path=os.path.join(graph_log_dir, 
                                              f"top_models_fold_{model_idx}", 
                                              graph_loss_models[model_idx])
                      )

In [None]:
# Collect the target tensor (`batch.y`) of every batch in the fold's test dataloader.
# The list was previously initialised as `tmp_gnn_test_targets` but appended to the
# undefined `tmp_test_targets`, which is the name the following cells read.
tmp_test_targets = [
    batch.y for batch in gnn_model.data_module_cross_val[model_idx].test_dataloader()
]
# Keep the originally declared name bound to the same list.
tmp_gnn_test_targets = tmp_test_targets

In [None]:
# AUC from the predictions/targets the model stored during trainer.test().
# Reset first: the metric keeps whatever earlier cells fed into it, and update() would
# otherwise mix those values into this result.
auc_fn.reset()
auc_fn.update(torch.cat(gnn_model.model.test_preds).to('cpu'), torch.cat(gnn_model.model.test_targets).to('cpu'))
auc_fn.compute()

In [None]:
# AUC accumulated batch by batch: each predicted batch is paired with the target batch at
# the same position, then the metric is computed over everything accumulated.
auc_fn.reset()
for idx in range(len(tmp_test_targets)):
    auc_fn.update(test_pred[idx], tmp_test_targets[idx])

auc_fn.compute()

In [None]:
# Same AUC, computed from the concatenated predictions and targets in a single update.
auc_fn.reset()
auc_fn.update(torch.cat(test_pred), torch.cat(tmp_test_targets))
auc_fn.compute()

In [None]:
# Same AUC again, this time by calling the metric object directly on the concatenated tensors.
auc_fn.reset()
auc_fn(torch.cat(test_pred), torch.cat(tmp_test_targets))

In [None]:
# Inspect the targets stored on the GNN model object.
gnn_model.model.test_targets

In [None]:
# Element-wise comparison of the trainer.predict() output with the predictions stored on
# the GNN model object.
# NOTE(review): `test_pred` was last assigned by the CNN predict cell when the notebook is
# run top to bottom -- confirm which predictions are meant to be compared.
torch.cat(test_pred) == torch.cat(gnn_model.model.test_preds).to('cpu')

In [None]:
# Cross-check the AUC with scikit-learn (arguments are targets first, then scores).
sklearn.metrics.roc_auc_score(torch.cat(tmp_test_targets), torch.cat(test_pred))

### Statistical testing in R (pROC via rpy2)

In [None]:
# Keys into pred_dict used by the R-based tests below.
# model_1 .. model_10 repeat the definitions from the earlier DeLong cell.
model_1 = 'lightning_gtvp_rad_image_vit_linear_nomask_nopool_weight7_22_dp2_v114'
model_2 = 'lightning_gtvp_image_vit_linear_nomask_nopool_weight7_22_dp2_v114'
model_3 = 'lightning_gtvp_foundation_image_vit_linear_nomask_nopool_weight7_22_dp2_v115' 
model_4 = 'lightning_gtvp_rad_foundation_image_vit_linear_nomask_nopool_weight7_22_dp2_v116'
model_5 = 'lightning_gtvp_image_vit_linear_mask_nopool_weight7_22_dp2_v120'
model_6 = 'lightning_gtvp_rad_image_vit_linear_mask_nopool_weight7_22_dp2_v121'
model_7 = 'lightning_gtvp_foundation_image_vit_linear_mask_nopool_weight7_22_dp2_v122'
model_8 = 'lightning_gtvp_rad_foundation_image_vit_linear_mask_nopool_weight7_22_dp2_v123'
model_9 = 'lightning_gtvp_rad_linear_mask_nopool_weight7_22_dp2_v123'
model_10 = 'lightning_gtvp_foundation_linear_mask_nopool_weight7_22_dp2_v124'
# model_11 .. model_14 are the averaged ensembles added to pred_dict earlier.
model_11 = 'foundation_rad_avg'
model_12 = 'foundation_image_avg'
model_13 = 'rad_image_avg'
model_14 = 'foundation_rad_image_avg'
model_15 = 'lightning_gtvp_image_spottune_nograph_nopool_weight7_22_dp2_v143'

In [None]:
# Bridge to R through rpy2 and load the pROC and stats packages used for the ROC tests.
from rpy2.robjects.packages import importr, isinstalled
import rpy2.robjects.lib.ggplot2 as gp
import rpy2.robjects as ro
from rpy2.robjects import numpy2ri, default_converter, pandas2ri, r

utils = importr('utils')
base = importr('base')
np_cv_rules = default_converter + numpy2ri.converter
# Enable automatic numpy / pandas <-> R conversion for all rpy2 calls.
# NOTE(review): global activate() is deprecated in recent rpy2 releases in favour of
# conversion contexts -- confirm against the installed version.
numpy2ri.activate()
pandas2ri.activate()

# 'stats' ships with base R and cannot be installed from CRAN, so it is no longer passed to
# install.packages. pROC is installed only when missing instead of on every run.
if not isinstalled('pROC'):
    utils.chooseCRANmirror(ind=1)
    utils.install_packages('pROC')

proc = importr('pROC')
stats = importr('stats')

In [None]:
# Build pROC ROC objects for selected runs and compare them with one-sided DeLong tests
# (alternative='greater': first ROC vs second ROC).
roc_foundation = proc.roc(targets.numpy(),pred_dict[model_10])
roc_rad = proc.roc(targets.numpy(), pred_dict[model_9])
roc_image = proc.roc(targets.numpy(), pred_dict[model_5])
roc_rad_image = proc.roc(targets.numpy(), pred_dict[model_6])
roc_spottune = proc.roc(targets.numpy(), pred_dict[model_15])
test1 = proc.roc_test(roc_foundation, roc_rad, method='delong', alternative='greater')
test2 = proc.roc_test(roc_foundation, roc_image, method='delong', alternative='greater')
test3 = proc.roc_test(roc_rad, roc_image, method='delong', alternative='greater')
test4 = proc.roc_test(roc_foundation, roc_rad_image, method='delong', alternative='greater')
test5 = proc.roc_test(roc_rad_image, roc_image, method='delong', alternative='greater')
test6 = proc.roc_test(roc_rad_image, roc_rad, method='delong', alternative='greater')
#test7 = proc.roc_test(roc_foundation, roc_spottune, method='delong', alternative='greater')
test7 = proc.roc_test(roc_image, roc_spottune, method='delong', alternative='greater')
# NOTE(review): element 7 of the roc.test result is presumably the p-value (it is collected
# into p_list below) -- confirm against the pROC result layout.
# test6 is computed but not printed.
print(test1[7])
print(test2[7])
print(test3[7])
print(test4[7])
print(test5[7])
print(test7[7])
# Only tests 1-4 are collected into p_list; tests 5 and 6 are left out.
p_list = []
p_list.extend(test1[7])
p_list.extend(test2[7])
p_list.extend(test3[7])
p_list.extend(test4[7])
#p_list.extend(test5[7])
#p_list.extend(test6[7])

In [None]:
# Adjust the collected values for multiple comparisons with R's p.adjust (method 'fdr').
stats.p_adjust(p_list, method='fdr')

In [None]:
# Inspect the full result of the image vs spottune comparison.
test7

In [None]:
# AUC of `roc2` as computed by pROC.
# NOTE(review): `roc2` is only assigned inside the pairwise loop further down, so this cell
# fails on a fresh top-to-bottom run -- move it below that loop or pick an ROC object
# defined above (e.g. roc_rad).
proc.auc(roc2)

In [None]:
# Print pROC's confidence interval for each of the four ROC objects, in the same order.
for _roc_obj in (roc_foundation, roc_rad, roc_image, roc_rad_image):
    print(proc.ci(_roc_obj))

In [None]:
# One-sided DeLong test (pROC) for every ordered pair of model_1 .. model_10, always testing
# the ROC with the larger AUC against the smaller one. Each result is stored as
# (name 1, name 2, test[7], auc 1, auc 2).
# An explicit list replaces the globals().get(f'model_{i}') lookups, and ROC/AUC objects are
# built once per model instead of once per pair.
comparison_models = [model_1, model_2, model_3, model_4, model_5,
                     model_6, model_7, model_8, model_9, model_10]
roc_by_model = {name: proc.roc(targets.numpy(), pred_dict[name]) for name in comparison_models}
auc_by_model = {name: proc.auc(roc_obj) for name, roc_obj in roc_by_model.items()}

p_values = []
for idx, name_1 in enumerate(comparison_models, start=1):
    for jdx, name_2 in enumerate(comparison_models, start=1):
        if idx == jdx:
            continue
        print(name_1, name_2)
        # roc1 / roc2 stay bound as globals because another cell reads `roc2`.
        roc1, roc2 = roc_by_model[name_1], roc_by_model[name_2]
        auc1, auc2 = auc_by_model[name_1], auc_by_model[name_2]
        if auc1 > auc2:
            test = proc.roc_test(roc1, roc2, method='delong', alternative='greater')
        elif auc2 > auc1:
            test = proc.roc_test(roc2, roc1, method='delong', alternative='greater')
        else:
            test = proc.roc_test(roc1, roc2, method='delong', alternative='greater')
            print('auc is the same')
        # NOTE(review): element 7 is presumably the p-value -- confirm against pROC.
        print(test[7])
        p_values.append((name_1, name_2, test[7], auc1, auc2))

In [None]:
# One row per ordered pair; columns 0-4 follow the tuple layout appended to p_values.
p_values_df = pd.DataFrame(p_values)

In [None]:
# Keep the first row for each distinct value in column 2 (the test result), which removes
# the mirrored (B, A) duplicate of each (A, B) pair.
# .copy() makes this an independent frame, so the column edits in later cells do not write
# into a slice of p_values_df (SettingWithCopyWarning).
p_values_unique_df = p_values_df[~p_values_df.duplicated(subset=[2])].copy()

In [None]:
# Inspect the de-duplicated comparison table.
p_values_unique_df

In [None]:
# Split the first run name (column 0, row label 0) into its underscore-separated tokens.
test_str = p_values_unique_df[0][0].split('_')

In [None]:
# Tokens kept when shortening a run name to a readable label.
# NOTE(review): the name `map` shadows the Python builtin; it is read by the next cell.
map = [
    'rad',
    'foundation',
    'image',
    'mask',
    'nomask',
]

# Try the shortening on the sample name: keep only the listed tokens, in original order.
'_'.join([sub_str for sub_str in test_str if sub_str in map])

In [None]:
for idx in [0,1]:
    p_values_unique_df[idx] = ['_'.join([sub_string for sub_string in string.split('_') if sub_string in map]) for string in p_values_unique_df[idx]]

In [None]:
# Give the positional columns descriptive names. Reassigning the result instead of using
# inplace=True avoids mutating a frame that may be a slice of p_values_df.
p_values_unique_df = p_values_unique_df.rename(
    columns={0: 'roc1', 1: 'roc2', 2: 'p-value', 3: 'auc1', 4: 'auc2'}
)

In [None]:
# Show the comparisons whose first model shortens to 'rad_image_mask'.
p_values_unique_df[p_values_unique_df['roc1'] == 'rad_image_mask']