In [1]:
import meld_graph.experiment
import os
import numpy as np
import h5py
import matplotlib_surface_plotting as msp
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import nibabel as nb
from meld_classifier.paths import BASE_PATH
from meld_classifier.meld_cohort import MeldCohort,MeldSubject

from meld_graph.experiment import Experiment
from meld_graph.dataset import GraphDataset
import torch
import torch_geometric


Setting EXPERIMENT_PATH to /rds/project/kw350/rds-kw350-meld/experiments_graph/co-spit1
Setting MELD_DATA_PATH to /home/co-spit1/meld_data
Setting BASE_PATH to /home/co-spit1/meld_data
Setting EXPERIMENT_PATH to /home/co-spit1/meld_experiments/co-spit1
Setting FS_SUBJECTS_PATH to /home/co-spit1/meld_data/output/fs_outputs


In [2]:
# Root folder of the trained experiment; the per-fold sub-folders are derived from it below.
model_path1 = '/rds/project/kw350/rds-kw350-meld/experiments_graph/kw350/23-02-09_MYCZ_baseline/s_2/'

# Cohort object that is handed to GraphDataset when the evaluation data is built.
cohort = MeldCohort(
    dataset='MELD_dataset_V6.csv',
    hdf5_file_root='{site_code}_{group}_featurematrix_combat_6.hdf5',
)

In [3]:
# One experiment directory per fold index (fold_00 ... fold_04 below model_path1).
folds = np.arange(5)

exp_dirs = [
    os.path.join(model_path1, f'fold_0{fold_idx}')
    for fold_idx in folds
]

In [4]:
# Build one Experiment per fold directory, in the same order as exp_dirs.
exps = list(map(Experiment.from_folder, exp_dirs))

Initialised Experiment 23-02-09_MYCZ_baseline/s_2
Initialised Experiment 23-02-09_MYCZ_baseline/s_2
Initialised Experiment 23-02-09_MYCZ_baseline/s_2
Initialised Experiment 23-02-09_MYCZ_baseline/s_2
Initialised Experiment 23-02-09_MYCZ_baseline/s_2


In [5]:
# Load the best checkpoint of every fold into its Experiment.
# The loop variable `exp` stays bound to the last experiment afterwards.
for exp in exps:
    fold_dir = os.path.join(exp.experiment_path, exp.experiment_name, f'fold_0{exp.fold}')
    exp.load_model(os.path.join(fold_dir, 'best_model.pt'))

Creating model
Creating model
Creating model
Creating model
Creating model


In [6]:
# Use the last fold's experiment explicitly instead of relying on the `exp`
# variable leaked by the checkpoint-loading loop: `exp` is rebound later in the
# notebook, so re-running this cell out of order would silently pick a
# different experiment.
reference_exp = exps[-1]
train_ids, val_ids, test_ids = reference_exp.get_train_val_test_ids()

# Small evaluation set: the first 10 validation subjects, loaded in 'test' mode.
dataset = GraphDataset(val_ids[:10], cohort, reference_exp.data_parameters, mode='test')

Loading and preprocessing test data
Z-scoring data for MELD_H17_3T_FCD_0035
Z-scoring data for MELD_H21_15T_FCD_0039
Z-scoring data for MELD_H2_3T_C_0024
Z-scoring data for MELD_H19_3T_C_027
Z-scoring data for MELD_H15_3T_C_0028
Z-scoring data for MELD_H23_15T_FCD_0007
Z-scoring data for MELD_H2_3T_C_0015
Z-scoring data for MELD_H10_3T_C_0010
Z-scoring data for MELD_H19_3T_C_022
Z-scoring data for MELD_H14_3T_FCD_0019


In [7]:
# Iterate the dataset one sample at a time, in a fixed order.
data_loader = torch_geometric.loader.DataLoader(dataset, batch_size=1, shuffle=False)

In [8]:
from torch import nn

In [9]:
class Ensemble(nn.Module):
    """
    Ensemble of models whose outputs are averaged key by key.

    Every wrapped model is called with the same input and must return a dict of
    tensors. The keys 'log_softmax', 'hemi_log_softmax' and 'non_lesion_logits'
    are ensembled when they are present in the first model's output; any other
    key is dropped.

    Args:
        models: iterable of ``nn.Module`` instances; stored in ``self.models``
            (an ``nn.ModuleList``), so their parameters appear in the ensemble's
            ``state_dict`` under ``models.<index>.``.
    """
    def __init__(self, models):
        super().__init__()
        self.models = nn.ModuleList(models)

    def forward(self, x):
        """
        Forward pass.

        Args:
            x: input passed unchanged to every model of the ensemble.

        Returns:
            dict mapping each ensembled key to the averaged tensor. Keys that
            contain 'log_softmax' hold log-probabilities, so they are averaged
            in probability space (log of the mean probability); all other keys
            are averaged directly.
        """
        estimates = [model(x) for model in self.models]
        ensembled_estimates = {}
        for key in ['log_softmax', 'hemi_log_softmax', 'non_lesion_logits']:
            if key not in estimates[0]:
                continue
            # new trailing axis indexes the ensemble members (works for any output rank)
            stacked = torch.stack([est[key] for est in estimates], dim=-1)
            if 'log_softmax' in key:
                # log(mean(exp(v))) == logsumexp(v) - log(n); computing it this way
                # stays finite when exp(v) would underflow to 0
                log_n = torch.log(
                    torch.tensor(float(len(estimates)), dtype=stacked.dtype, device=stacked.device)
                )
                mean_val = torch.logsumexp(stacked, dim=-1) - log_n
            else:
                mean_val = torch.mean(stacked, dim=-1)
            ensembled_estimates[key] = mean_val
        return ensembled_estimates

# Wrap the model of every loaded fold experiment into a single ensemble module.
ensemble = Ensemble(
    [fold_exp.model for fold_exp in exps]
)

In [10]:
# predict on the first batch only
for i, data in enumerate(data_loader):
    estimates = ensemble(data.x)

    # single folds for comparison: the comparison cells further down read
    # `estimates_folds`, so it has to be computed here (it was commented out,
    # which raises NameError on a fresh kernel)
    estimates_folds = [exp.model(data.x) for exp in exps]
    break

In [13]:
# ensemble output ...
print(estimates['log_softmax'])

# ... versus the same quantity recomputed by hand from the single-fold outputs:
# exponentiate, average over the folds, take the log again
estimates_exp = [fold_out['log_softmax'].exp() for fold_out in estimates_folds]
print(torch.stack(estimates_exp, dim=2).mean(dim=2).log())

tensor([[-0.6839, -0.7025],
        [-0.7104, -0.6762],
        [-0.7167, -0.6701],
        ...,
        [-0.6948, -0.6915],
        [-0.7212, -0.6659],
        [-0.7076, -0.6789]], grad_fn=<LogBackward0>)
tensor([[-0.6839, -0.7025],
        [-0.7104, -0.6762],
        [-0.7167, -0.6701],
        ...,
        [-0.6948, -0.6915],
        [-0.7212, -0.6659],
        [-0.7076, -0.6789]], grad_fn=<LogBackward0>)


In [14]:
# ensembled logits first, then the logits of the second fold on its own
print(
    estimates['non_lesion_logits'],
    estimates_folds[1]['non_lesion_logits'],
    sep='\n',
)

tensor([[0.0771],
        [0.1123],
        [0.6275],
        ...,
        [0.2871],
        [0.3094],
        [0.3253]], grad_fn=<MeanBackward1>)
tensor([[-0.0008],
        [-0.0036],
        [ 0.0858],
        ...,
        [ 0.5935],
        [ 0.3564],
        [ 0.5201]], grad_fn=<ViewBackward0>)


In [11]:
# Try saving: only the ensemble's weights (its state_dict) are written, not the
# module object itself. Ensemble keeps its members in the `models` ModuleList,
# so the saved parameter names carry a `models.<index>.` prefix.
# `fname` is a relative path, i.e. the file lands in the current working directory.
fname = 'ensemble_model.pt'
torch.save(ensemble.state_dict(), fname)

In [12]:
# load model - function of experiment
import copy
def load_ensemble_model(self, checkpoint_path=None, force=False, n_models=5):
    """
    Replace ``self.model`` with an Ensemble and optionally load its weights.

    The base model is created via ``self.load_model()`` (no checkpoint), copied
    ``n_models`` times, wrapped in ``Ensemble`` and stored in ``self.model``.
    If ``checkpoint_path`` points to an existing file, its state dict is loaded
    into the ensemble and the ensemble is switched to eval mode.

    Args:
        self: experiment object providing ``model``, ``log`` and ``load_model()``.
        checkpoint_path: path to a state dict saved from an Ensemble; ``None``
            leaves the freshly initialised weights in place.
        force: if False and ``self.model`` already exists, nothing is done.
        n_models: number of ensemble members; has to match the checkpoint.

    Raises:
        ValueError: if ``n_models`` is smaller than 1.
    """
    if self.model is not None and not force:
        self.log.info("Model already exists. Specify force=True to force reloading and initialisation")
        # without this return the existing model was rebuilt despite the message
        return
    if n_models < 1:
        raise ValueError(f"n_models must be at least 1, got {n_models}")
    # create model without checkpoint
    # NOTE(review): whether load_model() re-initialises an already existing model
    # without an explicit force argument is not visible here - confirm.
    self.load_model()
    self.log.info('Creating ensemble model')
    # independent copies, so every member can receive its own weights
    models = [copy.deepcopy(self.model) for _ in range(n_models)]
    self.model = Ensemble(models)
    # load weights from checkpoint
    if checkpoint_path is None:
        return
    if not os.path.isfile(checkpoint_path):
        # keep the best-effort behaviour, but do not stay silent about it
        self.log.info(f"Checkpoint {checkpoint_path} not found, ensemble weights were not loaded")
        return
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    self.log.info(f"Loading ensemble model weights from checkpoint {checkpoint_path}")
    state_dict = torch.load(checkpoint_path, map_location=device)
    # strict=False tolerates key mismatches; report them so a checkpoint that
    # does not fit the ensemble is noticed
    incompatible = self.model.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        self.log.info(
            f"Checkpoint does not fully match the ensemble: "
            f"missing keys {list(incompatible.missing_keys)}, "
            f"unexpected keys {list(incompatible.unexpected_keys)}"
        )
    self.model.eval()
    


In [13]:
# Fresh experiment (no model loaded yet) that receives the saved ensemble.
exp = Experiment.from_folder(exp_dirs[0])
# Reuse `fname` from the saving cell so the save and load paths cannot drift apart.
load_ensemble_model(exp, checkpoint_path=fname)

Initialised Experiment 23-02-09_MYCZ_baseline/s_2
Creating model
Creating ensemble model
Loading ensemble model weights from checkpoint ensemble_model.pt


In [14]:
# Sanity check: the first encoder conv layer should hold different weights in
# each ensemble member after loading the checkpoint.
for name, param in exp.model.named_parameters():
    if 'encoder_conv_layers.0.0.layer.weight' not in name:
        continue
    print(name, param)

models.0.encoder_conv_layers.0.0.layer.weight Parameter containing:
tensor([[ 0.0970, -0.1243, -0.0667,  ...,  0.0817, -0.0645, -0.0773],
        [ 0.1273,  0.0729,  0.1507,  ...,  0.1059,  0.0259,  0.1334],
        [ 0.0711, -0.1097,  0.0649,  ..., -0.0778, -0.0976,  0.0843],
        ...,
        [-0.1498, -0.0035,  0.0912,  ...,  0.0214,  0.0946,  0.0228],
        [-0.0873, -0.0023, -0.0966,  ...,  0.0932, -0.0044, -0.0792],
        [-0.0366,  0.0820,  0.0872,  ..., -0.1319, -0.1302,  0.0937]],
       requires_grad=True)
models.1.encoder_conv_layers.0.0.layer.weight Parameter containing:
tensor([[ 0.1487, -0.1123, -0.0906,  ..., -0.0787,  0.0127,  0.1028],
        [-0.0862,  0.1353, -0.1198,  ...,  0.0290, -0.0596, -0.0157],
        [-0.0119, -0.0816,  0.0578,  ...,  0.0009, -0.0595, -0.1305],
        ...,
        [-0.0145, -0.1298,  0.1177,  ..., -0.0879,  0.0247, -0.0711],
        [ 0.1373,  0.0732, -0.0994,  ..., -0.1396, -0.0774,  0.0269],
        [-0.0647, -0.0460,  0.0085,  ...

In [15]:
# predict on the first batch with the ensemble that was reloaded from disk
# (exp.model). Calling the in-memory `ensemble` here would compare that model
# with itself and say nothing about the save/load round trip.
for i, data in enumerate(data_loader):
    estimates2 = exp.model(data.x)

    break

In [19]:
# element-wise equality of both predictions, reduced to a single boolean tensor
torch.all(torch.eq(estimates['log_softmax'], estimates2['log_softmax']))

tensor(True)