In [1]:
import argparse
import importlib
import logging
import operator
import os
from copy import deepcopy
from functools import reduce

import numpy as np

import meld_graph
import meld_graph.models
import meld_graph.experiment
import meld_graph.dataset
import meld_graph.data_preprocessing
import meld_graph.icospheres
# meld_graph.training.Trainer is used later in this notebook; import the submodule
# explicitly instead of relying on another meld_graph module importing it as a side effect.
import meld_graph.training

# Re-execute the meld_graph modules so source edits are picked up without a kernel restart.
# NOTE(review): importlib.reload does not update names that other modules bound with
# `from x import y`, so reload order matters; IPython's %autoreload may be more reliable.
importlib.reload(meld_graph)
importlib.reload(meld_graph.models)
importlib.reload(meld_graph.dataset)
importlib.reload(meld_graph.experiment)
importlib.reload(meld_graph.data_preprocessing)
importlib.reload(meld_graph.icospheres)
importlib.reload(meld_graph.training)

from meld_graph.paths import EXPERIMENT_PATH
import operator

Setting MELD_DATA_PATH to /home/co-spit1/meld_data
Setting BASE_PATH to /home/co-spit1/meld_data
Setting EXPERIMENT_PATH to /home/co-spit1/meld_experiments/co-spit1
Setting FS_SUBJECTS_PATH to /home/co-spit1/meld_data/output/fs_outputs
Setting EXPERIMENT_PATH to /rds/project/kw350/rds-kw350-meld/experiments_graph/co-spit1


In [2]:
def load_config(config_file):
    """Execute a Python config file and return it as a module object.

    The file is loaded under the module name ``"config"`` with an explicit
    ``SourceFileLoader`` (so the path need not end in ``.py``). Its top-level
    names (e.g. ``data_parameters``, ``network_parameters``) become attributes
    of the returned module. The module is not registered in ``sys.modules``.

    Args:
        config_file: path to the config source file.

    Returns:
        The executed config module.
    """
    from importlib.machinery import SourceFileLoader
    from importlib.util import module_from_spec, spec_from_loader

    source_loader = SourceFileLoader("config", config_file)
    module_spec = spec_from_loader(source_loader.name, source_loader)
    config_module = module_from_spec(module_spec)
    # Running the module body is what populates config_module's namespace
    source_loader.exec_module(config_module)
    return config_module

In [4]:
# Relative path: resolved against the kernel's current working directory
CONFIG_FILE = "../scripts/config_files/experiment_config_hannah_synth.py"
config = load_config(CONFIG_FILE)

In [5]:
# Show the data parameters loaded from the config file.
print(config.data_parameters)

{'hdf5_file_root': '{site_code}_{group}_featurematrix_combat_6.hdf5', 'site_codes': ['H4'], 'scanners': ['15T', '3T'], 'dataset': 'MELD_dataset_V6.csv', 'group': 'both', 'features_to_exclude': [], 'subject_features_to_exclude': [], 'features': ['.on_lh.lesion.mgh', '.combat.on_lh.pial.K_filtered.sm20.mgh'], 'features_to_replace_with_0': [], 'number_of_folds': 10, 'fold_n': 0, 'preprocessing_parameters': {'scaling': None, 'zscore': False}, 'icosphere_parameters': {'distance_type': 'exact'}, 'augment_data': {}, 'combine_hemis': None, 'lobes': False, 'lesion_bias': 10, 'synthetic_data': {'n_subs': 100, 'bias': 0.5, 'radius': 0.5}}


## Train baseline model on synthetic data

In [6]:
# Show the data and network parameter dicts that are passed to the Experiment below
for param_dict in (config.data_parameters, config.network_parameters):
    print(param_dict)

{'hdf5_file_root': '{site_code}_{group}_featurematrix_combat_6.hdf5', 'site_codes': ['H4'], 'scanners': ['15T', '3T'], 'dataset': 'MELD_dataset_V6.csv', 'group': 'both', 'features_to_exclude': [], 'subject_features_to_exclude': [], 'features': ['.on_lh.lesion.mgh', '.combat.on_lh.pial.K_filtered.sm20.mgh'], 'features_to_replace_with_0': [], 'number_of_folds': 10, 'fold_n': 0, 'preprocessing_parameters': {'scaling': None, 'zscore': False}, 'icosphere_parameters': {'distance_type': 'exact'}, 'augment_data': {}, 'combine_hemis': None, 'lobes': False, 'lesion_bias': 10, 'synthetic_data': {'n_subs': 100, 'bias': 0.5, 'radius': 0.5}}
{'network_type': 'MoNet', 'model_parameters': {'layer_sizes': [16, 16, 16], 'activation_fn': 'leaky_relu', 'conv_type': 'SpiralConv', 'dim': 2, 'kernel_size': 3, 'spiral_len': 7}, 'training_parameters': {'max_patience': 400, 'num_epochs': 20, 'optimiser': 'sgd', 'optimiser_parameters': {'lr': 0.0001, 'momentum': 0.99, 'nesterov': True}, 'lr_decay': 0, 'loss_dict

In [9]:
# Put this run in a sub-directory of the configured experiment name.
# Guarded so that re-running the cell does not append the suffix a second time
# (which would silently change the experiment output directory).
# NOTE(review): 0.5 mirrors data_parameters['synthetic_data']['radius'] in the
# printed config — keep the two in sync.
PRETRAIN_SUFFIX = '/pretrain_radius0.5'
if not config.network_parameters['name'].endswith(PRETRAIN_SUFFIX):
    config.network_parameters['name'] += PRETRAIN_SUFFIX
print(config.network_parameters['name'])

22-07-29_synth_3layer/pretrain_radius0.5


In [10]:
# Build the experiment from the network and data parameter dicts, logging at INFO level
exp = meld_graph.experiment.Experiment(
    config.network_parameters,
    config.data_parameters,
    verbose=logging.INFO,
)

saving parameter files to /rds/project/kw350/rds-kw350-meld/experiments_graph/co-spit1/22-07-29_synth_3layer/pretrain_radius0.5/fold_00


In [11]:
# Wrap the experiment in a Trainer; it is used below via trainer.train() and its
# train/val data loaders.
trainer = meld_graph.training.Trainer(exp)

In [12]:
# Run training for the configured experiment.
# NOTE(review): the recorded run of this cell ended with KeyError: 'n_subtypes' —
# some parameter dict read during training lacks that key (likely the config); confirm.
# The outputs of the cells below come from earlier kernel states (see execution counts).
trainer.train()

Using coord type exact


conv 2 16
conv 16 16
conv 16 16


getting train val test split
total number of subjects: 86
total number of subjects after restricting to subjects from MELD_dataset_V6.csv: 86
total number of subjects: 950
total number of subjects after restricting to subjects from MELD_dataset_V6.csv: 942
total number after filtering by scanner ['3T', '15T'], features, lesional_only True: 911
full_feature_list: ['.combat.on_lh.curv.sm5.mgh', '.combat.on_lh.gm_FLAIR_0.25.sm10.mgh', '.combat.on_lh.gm_FLAIR_0.5.sm10.mgh', '.combat.on_lh.gm_FLAIR_0.75.sm10.mgh', '.combat.on_lh.gm_FLAIR_0.sm10.mgh', '.combat.on_lh.pial.K_filtered.sm20.mgh', '.combat.on_lh.sulc.sm5.mgh', '.combat.on_lh.thickness.sm10.mgh', '.combat.on_lh.w-g.pct.sm10.mgh', '.combat.on_lh.wm_FLAIR_0.5.sm10.mgh', '.combat.on_lh.wm_FLAIR_1.sm10.mgh', '.inter_z.asym.intra_z.combat.on_lh.curv.sm5.mgh', '.inter_z.asym.intra_z.combat.on_lh.gm_FLAIR_0.25.sm10.mgh', '.inter_z.asym.intra_z.combat.on_lh.gm_FLAIR_0.5.sm10.mgh', '.inter_z.asym.intra_z.combat.on_lh.gm_FLAIR_0.75.sm10.mgh

KeyError: 'n_subtypes'

In [9]:
# List each trainable tensor of the model, labelled with its name and shape so it is
# clear which layer every printed tensor belongs to.
for param_name, param in exp.model.named_parameters():
    print(f"{param_name} {tuple(param.shape)}")
    print(param)


Parameter containing:
tensor([[-0.1594],
        [-0.0392]], device='cuda:0', requires_grad=True)
Parameter containing:
tensor([ 0.3608, -0.4987], device='cuda:0', requires_grad=True)


In [9]:
# Load the lesional subjects of the validation split
val_dataset = trainer.val_data_loader.dataset
idxs = val_dataset.lesional_idxs
lesions = list(val_dataset[idxs])

# Mean feature vector of the first subject over vertices with y > 0, then y == 0
# (assumed: y > 0 marks lesion vertices)
subject = lesions[0]
for vertex_mask in (subject.y > 0, subject.y == 0):
    print(subject.x[vertex_mask].mean(axis=0))

tensor([10.9818, 11.1576])
tensor([0.0000, 0.4085])


In [10]:
# Load the lesional subjects of the training split (replaces `lesions` from the val cell)
train_dataset = trainer.train_data_loader.dataset
idxs = train_dataset.lesional_idxs
lesions = list(train_dataset[idxs])

# Mean feature vector of the first subject over vertices with y > 0, then y == 0
# (assumed: y > 0 marks lesion vertices)
subject = lesions[0]
for vertex_mask in (subject.y > 0, subject.y == 0):
    print(subject.x[vertex_mask].mean(axis=0))

tensor([10.9865, 10.2297])
tensor([0.0000, 0.2891])


In [46]:
# Per-vertex labels of the first subject in `lesions`
data_true = lesions[0].y
# Vectorised sum of the labels (count of lesion vertices if labels are 0/1 — TODO confirm);
# builtin sum() would iterate the tensor element by element in Python
data_true.sum()

tensor(650)

In [47]:
# Forward pass on the first subject's features; element 0 of the model's output is
# taken as the per-vertex class scores (NOTE(review): confirm what index 0 holds).
model_out = exp.model(lesions[0].x.to(exp.model.device))[0]
# Predicted class per vertex = argmax over axis 1. Stay in torch: np.argmax on a tensor
# only works via numpy's fallback conversion and __array_wrap__ back to a tensor.
data_pred = model_out.detach().cpu().argmax(dim=1)

In [48]:
# Vectorised sum of predicted class indices (count of vertices predicted as class 1
# when there are two classes — TODO confirm class count)
data_pred.sum()

tensor(733)

In [49]:
# Number of vertices where the prediction disagrees with the label (vectorised)
(data_true != data_pred).sum()

tensor(83)

In [50]:
# Feature matrix (one row per vertex) of the first subject in `lesions`;
# column order presumably follows data_parameters['features'] — TODO confirm.
lesions[0].x

tensor([[-0.0469, -0.9361],
        [-0.0469,  2.6119],
        [-0.0469, -0.1069],
        ...,
        [-0.0469, -0.2335],
        [-0.0469, -0.1417],
        [-0.0469, -0.0278]])

In [51]:
# Feature matrix (one row per vertex) of the second subject in `lesions`, for
# comparison with the first; column order presumably follows
# data_parameters['features'] — TODO confirm.
lesions[1].x

tensor([[-0.0126, -1.1752],
        [-0.0126, -1.1540],
        [-0.0126, -0.2428],
        ...,
        [-0.0126,  0.8933],
        [-0.0126,  0.9001],
        [-0.0126,  1.0584]])