In [13]:
import os

import h5py
import matplotlib.pyplot as plt
import matplotlib_surface_plotting as msp
import nibabel as nb
import numpy as np
import torch
from matplotlib.gridspec import GridSpec

import meld_graph.experiment
from meld_classifier.meld_cohort import MeldCohort,MeldSubject
from meld_classifier.paths import BASE_PATH
def load_prediction(subject,hdf5,dset='prediction'):
    """Read one dataset for both hemispheres of a subject from an HDF5 file.

    Parameters
    ----------
    subject : str
        Key of the subject's group at the top level of the file.
    hdf5 : str
        Path of the HDF5 file, opened read-only.
    dset : str
        Name of the dataset stored under each hemisphere group.

    Returns
    -------
    dict
        ``{'lh': values, 'rh': values}`` with each dataset fully read into memory.
    """
    with h5py.File(hdf5, "r") as handle:
        subject_group = handle[subject]
        # [:] copies the data out so it stays valid after the file is closed
        return {hemi: subject_group[hemi][dset][:] for hemi in ('lh', 'rh')}


In [2]:
# Root folder of the trained experiment; the per-fold sub-folders (fold_0N) are joined onto it.
# NOTE(review): hard-coded absolute cluster path — the notebook only runs where this mount exists.
model_path = '/rds/project/kw350/rds-kw350-meld/experiments_graph/kw350/23-01-13_QXFB_kernel_spiral_fold/s_2/'

In [5]:
# Cohort handle: passed to MeldSubject and used for its cortex_label / cortex_mask attributes.
cohort = MeldCohort(hdf5_file_root='{site_code}_{group}_featurematrix_combat_6.hdf5',
               dataset='MELD_dataset_V6.csv')


In [7]:
# Fold indices whose saved predictions are combined in the evaluation loop.
# NOTE(review): this yields 9 folds (0-8), whereas the model-loading cells in this notebook
# load ten folds (0-9) — confirm whether fold_09 should be included here.
folds = np.arange(0,9)

# Results directory of every fold, keyed by model name and ordered by fold index.
save_dirs = {
    'spiral': [os.path.join(model_path,f'fold_0{fold}', 'results') for fold in folds] 
}

In [15]:
# Length of the concatenated lh+rh vector used to accumulate per-subject predictions.
# assumes len(cohort.cortex_label) equals the number of vertices kept by cohort.cortex_mask — TODO confirm
n_vert = len(cohort.cortex_label)*2

In [10]:
# The subject list is taken from the top-level keys of the first fold's predictions file.
first_fold_predictions = os.path.join(save_dirs['spiral'][0], 'predictions.hdf5')
with h5py.File(first_fold_predictions, "r") as f:
    subjects = list(f.keys())


In [77]:
def roc_curves(subject_dictionary,roc_dictionary,thresholds=None):
    """Accumulate one subject's detection outcome at multiple thresholds.

    Parameters
    ----------
    subject_dictionary : dict
        Must hold ``'result'`` (per-vertex scores, compared with ``>= threshold``) and
        ``'input_labels'`` (per-vertex lesion labels). ``'borderzone'`` (per-vertex
        mask) is only read when the subject has a label.
    roc_dictionary : dict
        Holds the arrays ``'sensitivity'``, ``'sensitivity_plus'`` and ``'specificity'``,
        each with one entry per threshold. They are incremented in place.
    thresholds : sequence of float, optional
        Thresholds to evaluate. Defaults to ``np.linspace(0, 1, 21)``.

    Returns
    -------
    None
        Results are written into ``roc_dictionary``.
    """
    if thresholds is None:
        thresholds = np.linspace(0,1,21)
    result = subject_dictionary['result']
    labels = subject_dictionary['input_labels']
    # A subject with any labelled vertex counts as a patient, otherwise as a control.
    # This does not depend on the threshold, so it is evaluated once.
    is_patient = labels.sum()>0
    for t_i,threshold in enumerate(thresholds):
        predicted = result >= threshold
        if is_patient:
            # detected if any predicted vertex overlaps the lesion label ...
            roc_dictionary['sensitivity'][t_i] += np.logical_and(predicted, labels).any()
            # ... or, more leniently, the border zone around it
            roc_dictionary['sensitivity_plus'][t_i] += np.logical_and(predicted, subject_dictionary['borderzone']).any()
        else:
            # a control is correctly classified only if nothing is predicted at all;
            # `not` rather than `~` so the result is a plain boolean whatever .any() returns
            roc_dictionary['specificity'][t_i] += not predicted.any()

In [82]:
# Per-threshold counters, updated in place by roc_curves(); 21 entries match its default grid.
roc_dictionary={'sensitivity':np.zeros(21),
'sensitivity_plus':np.zeros(21),
'specificity':np.zeros(21)}
for si,subj in enumerate(subjects):
    # coarse progress indicator
    if si%100==0:
        print(si)
    s = MeldSubject(subj,cohort=cohort)
    labels_hemis = {}
    dists={}
    subject_results = np.zeros(n_vert)
    for hemi in ['lh','rh']:
        dists[hemi], labels_hemis[hemi] = s.load_feature_lesion_data(
                    features=['.on_lh.boundary_zone.mgh'], hemi=hemi, features_to_ignore=[]
                )
        # An all-zero boundary-zone map would put every vertex under the <20 cut-off below;
        # shifting it to 200 means no vertex of that hemisphere counts as border zone.
        # presumably this is the case for hemispheres without a lesion — TODO confirm
        if np.sum(dists[hemi])==0:
            dists[hemi] +=200
    # Concatenate both hemispheres (cortex vertices only) in lh, rh order.
    labels = np.hstack([labels_hemis['lh'][cohort.cortex_mask],labels_hemis['rh'][cohort.cortex_mask]])
    borderzones = np.vstack([dists['lh'][cohort.cortex_mask,:],dists['rh'][cohort.cortex_mask,:]]).ravel()<20
    for fold in folds:
        save_dir = save_dirs['spiral'][fold]
        pred_file = os.path.join(save_dir, 'predictions.hdf5')
        result_hemis = load_prediction(subj,pred_file, dset='prediction')
        # Divide by the number of folds actually summed so the accumulated value is a
        # true mean of the fold predictions, whatever len(folds) is.
        subject_results += np.hstack([result_hemis['lh'],result_hemis['rh']])/len(folds)
    subject_dictionary={'input_labels':labels,'borderzone':borderzones,'result':subject_results}
    roc_curves(subject_dictionary,roc_dictionary)
    

0
10
20
30
40
50
60
70
80
90
100
110
120
130
140
150
160
170
180
190
200
210
220
230
240
250
260
270
280
290
300
310
320
330
340
350
360
370
380
390
400
410
420
430
440
450


In [83]:
def optimal_threshold(b, thresholds=None):
    """Pick the threshold that maximises normalised sensitivity_plus + specificity.

    Parameters
    ----------
    b : dict
        Holds the arrays ``'sensitivity_plus'`` and ``'specificity'`` with one entry per
        threshold.
    thresholds : sequence of float, optional
        Thresholds matching the entries of ``b``. Defaults to ``np.linspace(0, 1, 21)``.

    Returns
    -------
    float
        The selected threshold (also printed). When several thresholds tie for the
        maximum, the highest one is chosen.
    """
    if thresholds is None:
        thresholds = np.linspace(0,1,21)
    thresholds = np.asarray(thresholds)
    sensitivity = np.asarray(b['sensitivity_plus'], dtype=float)
    specificity = np.asarray(b['specificity'], dtype=float)

    def _normalise(counts):
        # Scale to a maximum of 1. An all-zero curve (e.g. no controls were evaluated)
        # stays at zero instead of becoming 0/0 = NaN, which would leave no maximum to find.
        peak = counts.max()
        return counts / peak if peak > 0 else np.zeros_like(counts)

    youden = _normalise(sensitivity) + _normalise(specificity)
    # last index attaining the maximum -> highest threshold among ties
    best_index = np.flatnonzero(youden == youden.max())[-1]
    optimal_thresh = thresholds[best_index]
    print(optimal_thresh)
    return optimal_thresh

In [1]:
class EnsembleModel:
    """Ensemble of the per-fold models of one experiment, with averaged predictions."""

    def __init__(self, model_path):
        """Load the best checkpoint of folds 0-9 found under ``model_path``.

        ``model_path`` does not contain the fold part; ``fold_0N`` is appended for each
        fold in turn.
        """
        self.models=[]
        for fold in range(10):
            fold_path = os.path.join(model_path,f'fold_0{fold}')
            exp = meld_graph.experiment.Experiment.from_folder(fold_path)
            exp.load_model(
                        checkpoint_path=os.path.join(fold_path, "best_model.pt"),
                        force=True,
                    )
            self.models.append(exp.model)
        # Parameters are taken from the last fold loaded.
        # assumes all folds share the same network parameters — TODO confirm
        self.network_parameters = exp.network_parameters

    def predict(self,data):
        """Run every fold model on ``data`` and average the outputs.

        Parameters
        ----------
        data
            Input passed unchanged to each model.

        Returns
        -------
        tuple of numpy.ndarray
            ``(prediction, distance_map)``. ``prediction`` is the mean over models of
            ``exp(log_softmax)[:, 1]``. ``distance_map`` is the mean over models of
            ``non_lesion_logits[:, 0]`` when ``'distance_regression'`` is in the loss
            dictionary, otherwise an array of NaN with the same length.
        """
        loss_dictionary = self.network_parameters['training_parameters']['loss_dictionary']
        has_distance = 'distance_regression' in loss_dictionary
        predictions=[]
        distance_maps=[]
        # NOTE(review): assumes load_model leaves the networks in eval mode — confirm.
        # no_grad: inference only, and the outputs must be convertible to numpy.
        with torch.no_grad():
            for model in self.models:
                estimates = model(data)
                # column 1 of the exponentiated log-softmax
                # presumably the lesion-class probability — verify against the model
                prediction = torch.exp(estimates['log_softmax'])[:,1]
                #get distance map if exist in loss, otherwise return array of NaN
                if has_distance:
                    distance_map = estimates['non_lesion_logits'][:,0]
                else:
                    distance_map = torch.full_like(prediction, float('nan'))
                predictions.append(prediction)
                distance_maps.append(distance_map)
        mean_prediction = torch.stack(predictions).mean(dim=0).cpu().numpy()
        mean_distance_map = torch.stack(distance_maps).mean(dim=0).cpu().numpy()
        return mean_prediction, mean_distance_map

In [7]:
# Load the best checkpoint of each fold (0-9) and collect the networks in a list.
models=[]
for fold in np.arange(10):
    fold_path = os.path.join(model_path,f'fold_0{fold}')
    checkpoint_path = os.path.join(fold_path, "best_model.pt")
    exp = meld_graph.experiment.Experiment.from_folder(fold_path)
    exp.load_model(checkpoint_path=checkpoint_path, force=True)
    models.append(exp.model)

Initialised Experiment 23-01-13_QXFB_kernel_GMM_fold/s_2
Creating model
Loading model weights from checkpoint /rds/project/kw350/rds-kw350-meld/experiments_graph/kw350/23-01-13_QXFB_kernel_GMM_fold/s_2/fold_00/best_model.pt
Initialised Experiment 23-01-13_QXFB_kernel_GMM_fold/s_2
Creating model
Loading model weights from checkpoint /rds/project/kw350/rds-kw350-meld/experiments_graph/kw350/23-01-13_QXFB_kernel_GMM_fold/s_2/fold_01/best_model.pt
Initialised Experiment 23-01-13_QXFB_kernel_GMM_fold/s_2
Creating model
Loading model weights from checkpoint /rds/project/kw350/rds-kw350-meld/experiments_graph/kw350/23-01-13_QXFB_kernel_GMM_fold/s_2/fold_02/best_model.pt
Initialised Experiment 23-01-13_QXFB_kernel_GMM_fold/s_2
Creating model
Loading model weights from checkpoint /rds/project/kw350/rds-kw350-meld/experiments_graph/kw350/23-01-13_QXFB_kernel_GMM_fold/s_2/fold_03/best_model.pt
Initialised Experiment 23-01-13_QXFB_kernel_GMM_fold/s_2
Creating model
Loading model weights from check

NameError: name 'data' is not defined