In [9]:
import meld_graph
import meld_graph.models
import meld_graph.experiment
import meld_graph.dataset
import meld_graph.data_preprocessing
import meld_graph.evaluation


import importlib
import logging
import os
import json

from meld_graph.dataset import GraphDataset
from meld_classifier.meld_cohort import MeldCohort, MeldSubject
import numpy as np
from meld_graph.paths import EXPERIMENT_PATH

from meld_graph.evaluation import Evaluator



### Generate dataset

In [10]:
# Cohort backed by the '{site_code}_{group}_featurematrix_combat_6.hdf5' files
# and the 'MELD_dataset_V6.csv' subject table.
cohort = MeldCohort(
    hdf5_file_root='{site_code}_{group}_featurematrix_combat_6.hdf5',
    dataset='MELD_dataset_V6.csv',
)

# Subject ids as provided by the dataset csv (all / train+val / test, going by
# the names the three return values are bound to).
subject_ids, trainval_ids, test_ids = cohort.read_subject_ids_from_dataset()

# Selection 1: the csv test ids minus two excluded subjects.
# `subjects` is the same list object as `test_ids`, so the removals below also
# change `test_ids` in place (and raise ValueError if an id is not present).
subjects = test_ids
for excluded_id in ('MELD_H10_3T_FCD_0008', 'MELD_H23_15T_FCD_0002'):
    subjects.remove(excluded_id)

# Selection 2 (the one that takes effect): a hand-picked list that replaces the
# selection made above.
subjects = [
    # 'MELD_H15_3T_FCD_0008',
    # 'MELD_H16_3T_FCD_004',
    # 'MELD_H16_3T_FCD_005',
    # 'MELD_H18_3T_FCD_0006',
    'MELD_H18_3T_FCD_0109',
    'MELD_H23_15T_FCD_0027',
    'MELD_H23_15T_FCD_0031',
    'MELD_H4_3T_FCD_0010',
    'MELD_H14_3T_FCD_0035',
]


# The feature list is the cross product of three processing variants and eleven
# base surface features; each entry is the variant prefix followed by the base
# feature name.
feature_variants = [
    '.combat',
    '.inter_z.intra_z.combat',
    '.inter_z.asym.intra_z.combat',
]
base_features = [
    '.on_lh.pial.K_filtered.sm20.mgh',
    '.on_lh.thickness.sm10.mgh',
    '.on_lh.w-g.pct.sm10.mgh',
    '.on_lh.sulc.sm5.mgh',
    '.on_lh.curv.sm5.mgh',
    '.on_lh.gm_FLAIR_0.75.sm10.mgh',
    '.on_lh.gm_FLAIR_0.5.sm10.mgh',
    '.on_lh.gm_FLAIR_0.25.sm10.mgh',
    '.on_lh.gm_FLAIR_0.sm10.mgh',
    '.on_lh.wm_FLAIR_0.5.sm10.mgh',
    '.on_lh.wm_FLAIR_1.sm10.mgh',
]

# Variant-major order: every base feature for the first variant, then the
# second variant, then the third.
features = [
    variant + base_feature
    for variant in feature_variants
    for base_feature in base_features
]

# Synthetic-data generation is switched off; the remaining synthetic options
# are kept here, disabled, for reference.
synthetic_data_settings = {
    'run_synthetic': False,
    # 'n_subs': 200,
    # 'use_controls': True,
    # 'radius': 0.5,
    # 'n_subtypes': 25,
    # 'jitter_factor': 2,
    # 'bias': 1,
    # 'proportion_features_abnormal': 0.2,
    # 'proportion_hemispheres_lesional': 0.9,
}

# No scaling; z-scoring reads its parameters from the given json file.
preprocessing_settings = {
    "scaling": None,
    "zscore": '../data/feature_means.json',
}

# Parameter dictionary handed to GraphDataset.
params = {
    "features": features,
    "augment_data": {},
    "lesion_bias": 0,
    "lobes": False,
    "smooth_labels": False,
    "synthetic_data": synthetic_data_settings,
    "number_of_folds": 5,
    "preprocessing_parameters": preprocessing_settings,
    "combine_hemis": None,
}

# Build the dataset for the selected subjects in test mode.
dataset = GraphDataset(subjects, cohort, params, mode='test')


Loading and preprocessing test data


dataset using distance_maps


Z-scoring data for MELD_H18_3T_FCD_0109
Z-scoring data for MELD_H23_15T_FCD_0027
Z-scoring data for MELD_H23_15T_FCD_0031
Z-scoring data for MELD_H4_3T_FCD_0010
Z-scoring data for MELD_H14_3T_FCD_0035


In [42]:
# Sanity check: number of subject ids held by the dataset built above.
len(dataset.subject_ids)


5

### Load models and predict

In [11]:
# Root folder of the experiments used in this notebook. This rebinds the
# EXPERIMENT_PATH name imported from meld_graph.paths at the top of the file.
EXPERIMENT_PATH = '/rds/project/kw350/rds-kw350-meld/experiments_graph/co-spit1'

# Models used previously, kept disabled for reference:
#   'augment_finetune_small': '/rds/project/kw350/rds-kw350-meld/experiments_graph/kw350/22-10-24_synth/unet/augment_finetune_real/fold_00/',
#   '3layer_augment_finetune_small': '/rds/project/kw350/rds-kw350-meld/experiments_graph/kw350/22-10-24_synth/3layers/finetune_real/fold_00/',

# Models to run: one checkpoint folder per fold (fold0..fold4), each given
# relative to EXPERIMENT_PATH.
model_base_paths = {
    f'fold{fold}': f'23-02-17_SCAY_mask_augmentation/s_2/fold_{fold:02d}/'
    for fold in range(5)
}

In [13]:
from pyexpat import model

for model_name, model_base_path in model_base_paths.items():

    # Rebuild the already-trained experiment from its checkpoint folder.
    checkpoint_path = os.path.join(EXPERIMENT_PATH, model_base_path)
    exp = meld_graph.experiment.Experiment.from_folder(checkpoint_path)

    # Every model gets its own output folder, named after the model.
    save_dir = os.path.join(EXPERIMENT_PATH, f'23-02-20_TEST_SCAY_mask_augmentation/{model_name}')

    eva = Evaluator(
        experiment=exp,
        checkpoint_path=checkpoint_path,
        save_dir=save_dir,
        make_images=True,
        dataset=dataset,
        cohort=cohort,
    )

    # The steps are run one by one instead of through eva.evaluate()
    # (predict, stats and plot). The stats step, eva.stat_subjects(), is
    # currently disabled.
    eva.load_predict_data()
    eva.plot_subjects_prediction()

    

Initialised Experiment 23-02-17_SCAY_mask_augmentation/s_2
Creating model
Loading model weights from checkpoint /rds/project/kw350/rds-kw350-meld/experiments_graph/co-spit1/23-02-17_SCAY_mask_augmentation/s_2/fold_00/best_model.pt
loading data and predicting model
saving prediction for MELD_H18_3T_FCD_0109
saving distance_map for MELD_H18_3T_FCD_0109
saving prediction for MELD_H23_15T_FCD_0027
saving distance_map for MELD_H23_15T_FCD_0027
saving prediction for MELD_H23_15T_FCD_0031
saving distance_map for MELD_H23_15T_FCD_0031
saving prediction for MELD_H4_3T_FCD_0010
saving distance_map for MELD_H4_3T_FCD_0010
saving prediction for MELD_H14_3T_FCD_0035
saving distance_map for MELD_H14_3T_FCD_0035
Initialised Experiment 23-02-17_SCAY_mask_augmentation/s_2
Creating model
Loading model weights from checkpoint /rds/project/kw350/rds-kw350-meld/experiments_graph/co-spit1/23-02-17_SCAY_mask_augmentation/s_2/fold_01/best_model.pt
loading data and predicting model
saving prediction for MELD_H

### Compare models

In [26]:
# load results 
import pandas as pd

# Per-model results tables, keyed by model name.
df = {}
predicted = {}
for model_name in model_base_paths:
    # NOTE(review): results are read from '22-10-21_synth_evaluation', whereas
    # the evaluation loop in this notebook saves under
    # '23-02-20_TEST_SCAY_mask_augmentation' — confirm this is the intended folder.
    results_file = os.path.join(
        EXPERIMENT_PATH,
        f'22-10-21_synth_evaluation/{model_name}',
        'results',
        'test_results.csv',
    )
    df[model_name] = pd.read_csv(results_file)
    
    

In [27]:
# Display the results table of one model; `model_name` is whatever value the
# most recently executed `for model_name in ...` loop left behind.
df[model_name]

Unnamed: 0,ID,group,detected,dice_tp,dice_fp,dice_fn,dice_tn
0,MELD_H5_3T_C_0008,True,True,5923,43,43,287795
1,MELD_H3_3T_C_0008,True,True,3831,30,16,289927
2,MELD_H19_3T_C_016,True,True,1387,13,5720,286684
3,MELD_H4_15T_C_0017,True,True,380,54,5906,287464
4,MELD_H5_3T_C_0022,True,True,2183,60,0,291561
...,...,...,...,...,...,...,...
157,MELD_H4_3T_C_0019,True,True,1439,109,81,292175
158,MELD_H3_3T_C_0002,True,True,4814,2691,551,285748
159,MELD_H14_3T_C_0019,True,True,13269,87,22,280426
160,MELD_H5_3T_C_0028,True,True,29340,88,27,264349


In [28]:
# stats
import pandas as pd

for model_name in model_base_paths:

    dfsub = df[model_name]
    # Flag the rows whose dice_tp is non-zero; sensitivity is reported as the
    # fraction of flagged rows.
    predicted = dfsub['dice_tp'] > 0
    sensitivity = predicted.sum() / len(predicted)
    print(f'Model {model_name}: \n {sensitivity} sensitivity')

Model unet_deepsuper: 
 0.9382716049382716 sensitivity
Model unet_small_scratch: 
 0.9259259259259259 sensitivity
Model unet_small_finetune: 
 0.9691358024691358 sensitivity


Unnamed: 0,ID,group,detected,dice_tp,dice_fp,dice_fn,dice_tn
0,MELD_H5_3T_C_0008,True,True,5923,43,43,287795
1,MELD_H3_3T_C_0008,True,True,3831,30,16,289927
2,MELD_H19_3T_C_016,True,True,1387,13,5720,286684
3,MELD_H4_15T_C_0017,True,True,380,54,5906,287464
4,MELD_H5_3T_C_0022,True,True,2183,60,0,291561
...,...,...,...,...,...,...,...
157,MELD_H4_3T_C_0019,True,True,1439,109,81,292175
158,MELD_H3_3T_C_0002,True,True,4814,2691,551,285748
159,MELD_H14_3T_C_0019,True,True,13269,87,22,280426
160,MELD_H5_3T_C_0028,True,True,29340,88,27,264349


### Plot prediction and lesion 


In [14]:
import h5py
import matplotlib_surface_plotting as msp
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import nibabel as nb
from meld_classifier.paths import BASE_PATH

def load_prediction(subject,hdf5,dset='prediction'):
    """Read one dataset for both hemispheres of a subject from an hdf5 file.

    Parameters
    ----------
    subject : key of the subject's group in the file.
    hdf5 : file to open (read-only) with h5py.
    dset : name of the dataset stored under `<subject>/<hemi>/`.

    Returns
    -------
    dict mapping 'lh' and 'rh' to the fully-read dataset contents.
    """
    with h5py.File(hdf5, "r") as f:
        subject_group = f[subject]
        # `[:]` copies the data out of the file before it is closed.
        return {hemi: subject_group[hemi][dset][:] for hemi in ('lh', 'rh')}

def create_surface_plots(coords,faces,overlay,flat_map=True):
    """Render an overlay on a surface and return the picture as an array.

    The surface is drawn with `plot_surf` into the file 'tmp.png' in the
    current working directory (colour range fixed to 0.4-0.6, rotations 90 and
    270). That file is then re-opened, passed through meld_classifier's `trim`
    and converted to RGBA.

    Parameters
    ----------
    coords, faces : surface geometry forwarded to `plot_surf`.
    overlay : per-vertex values forwarded to `plot_surf`.
    flat_map : forwarded to `plot_surf` unchanged.

    Returns
    -------
    numpy array built from the trimmed RGBA image.
    """
    from meld_classifier.meld_plotting import trim
    import matplotlib_surface_plotting.matplotlib_surface_plotting as msp
    from PIL import Image

    image_path = 'tmp.png'
    msp.plot_surf(
        coords,
        faces,
        overlay,
        flat_map=flat_map,
        rotate=[90, 270],
        filename=image_path,
        vmin=0.4,
        vmax=0.6,
    )
    # Read the rendered picture back, trim it and convert it to RGBA.
    rendered = trim(Image.open(image_path)).convert("RGBA")
    return np.array(rendered)

In [19]:
cohort = MeldCohort(hdf5_file_root='{site_code}_{group}_featurematrix_combat_6.hdf5',
               dataset='MELD_dataset_V6.csv')


# Predictions to plot. A list of files is treated as an ensemble (one file per
# fold) by the plotting cell; a single path is used as-is.
predictions_root = '/rds/project/kw350/rds-kw350-meld/experiments_graph/co-spit1/23-02-20_TEST_SCAY_mask_augmentation'
predictions_file = [
    f'{predictions_root}/fold{fold}/results/predictions.hdf5'
    for fold in range(5)
]

# Single-file alternative:
# predictions_file = os.path.join('/rds/project/kw350/rds-kw350-meld/experiments/co-ripa1/iteration_21-09-15/ensemble_21-09-15/fold_all/results/predictions_ensemble_iteration.hdf5')

# h5py.File expects a single file, so when an ensemble (list) is configured the
# subject ids are read from the first fold's file.
if isinstance(predictions_file, (list, tuple)):
    subjects_source_file = predictions_file[0]
else:
    subjects_source_file = predictions_file

with h5py.File(subjects_source_file, "r") as f:
    subjects = list(f.keys())

# Hand-picked subjects; this replaces the list read from the file above.
subjects = [
            'MELD_H18_3T_FCD_0109',
            'MELD_H23_15T_FCD_0031',
            'MELD_H4_3T_FCD_0010',
            'MELD2_H7_3T_FCD_009',
            'MELD2_H7_3T_FCD_004'
]

In [27]:

from numpy import False_

# Cut-off applied to the fold-averaged prediction to obtain a binary map.
threshold = 0.09

# NOTE(review): 'none' looks like a placeholder so that only the lesion labels
# are used from load_feature_lesion_data below — confirm against MeldSubject.
features = ['none']

# The flat template surface does not depend on the subject, so load it once.
flat = nb.load(os.path.join(BASE_PATH, "fsaverage_sym", "surf", "lh.full.patch.flat.gii"))
coords, faces = flat.darrays[0].data, flat.darrays[1].data

# Folder the per-subject figures are written to; create it if it is missing so
# that savefig does not fail on a fresh output tree.
images_dir = '/rds/project/kw350/rds-kw350-meld/experiments_graph/co-spit1/23-02-20_TEST_SCAY_mask_augmentation/results_pervertex/images'
os.makedirs(images_dir, exist_ok=True)

for subject in subjects:
    print(subject)
    if isinstance(predictions_file, list):
        # Ensemble: average the predictions of all folds, per hemisphere ...
        n_cortex_vertices = cohort.cortex_mask.sum()
        result_hemis = {hemi: np.zeros(n_cortex_vertices) for hemi in ['lh', 'rh']}
        for pred_file in predictions_file:
            result_hemis_temp = load_prediction(subject, pred_file, dset='prediction')
            for hemi in ['lh', 'rh']:
                result_hemis[hemi] += result_hemis_temp[hemi] / len(predictions_file)
        # ... then threshold the *averaged* map (not the last fold's map) to get
        # the binary ensemble prediction.
        for hemi in ['lh', 'rh']:
            result_hemis[hemi] = result_hemis[hemi] > threshold
    else:
        # Single predictions file: use its predictions unchanged.
        result_hemis = load_prediction(subject, predictions_file)

    # Lesion labels of the subject, per hemisphere (second return value).
    subj = MeldSubject(subject, cohort=cohort)
    labels_hemis = {}
    for hemi in ['lh', 'rh']:
        _, labels_hemis[hemi] = subj.load_feature_lesion_data(
            features, hemi=hemi, features_to_ignore=[]
        )

    # 2x2 figure: predictions on the first row, labels on the second.
    fig = plt.figure(figsize=(8, 8), constrained_layout=True, facecolor='white')
    gs1 = GridSpec(2, 2, width_ratios=[1, 1], wspace=0.1, hspace=0.1)
    data_to_plot = [result_hemis['lh'], result_hemis['rh'], labels_hemis['lh'], labels_hemis['rh']]
    titles = ['predictions left hemi', 'predictions right hemi', 'labels left hemi', 'labels right hemi']
    for i, overlay in enumerate(data_to_plot):
        # Overlays shorter than the cortex mask are expanded to the length of
        # the mask, leaving zeros wherever the mask is False.
        if len(overlay) < len(cohort.cortex_mask):
            overlay_tmp = np.zeros(len(cohort.cortex_mask))
            overlay_tmp[cohort.cortex_mask] = overlay
            overlay = overlay_tmp
        ax = fig.add_subplot(gs1[i])
        im = create_surface_plots(coords, faces, overlay, flat_map=True)
        ax.imshow(im)
        ax.axis('off')
        ax.set_title(titles[i], loc='left', fontsize=20)
    fig.savefig(f'{images_dir}/{subject}', bbox_inches='tight', transparent=False, facecolor=fig.get_facecolor())
    plt.close("all")
    # NOTE(review): show() is called after the figure has been closed — confirm
    # whether it is still needed.
    fig.show()

MELD_H18_3T_FCD_0109


  fig.savefig(f'/rds/project/kw350/rds-kw350-meld/experiments_graph/co-spit1/23-02-20_TEST_SCAY_mask_augmentation/results_pervertex/images/{subject}', bbox_inches='tight', transparent=False, facecolor=fig.get_facecolor())


MELD_H23_15T_FCD_0031
MELD_H4_3T_FCD_0010
MELD2_H7_3T_FCD_009
MELD2_H7_3T_FCD_004
