In [1]:
import sys
sys.path.extend(['/users/students/r0749898/thesis/'])

In [2]:
from datasets.SHHS_dataset_timeonly import SHHS_dataset_1, EEGdataModule

from models.simclr_model import SimCLR
from models.supervised_model import SupervisedModel
import torch
import matplotlib.pyplot as plt
import numpy as np
from argparse import Namespace
from copy import deepcopy
from utils.helper_functions import load_model, SimCLRdataModule
from trainers.train_supervised import train_supervised


## Define dataset: 10 patients for training, 5 for validation and 30 for testing


In [5]:
data_args = {
  "DATA_PATH": "/esat/biomeddata/SHHS_Dataset/no_backup/",
  "data_split": [2, 1],
  "first_patient": 15,
  "num_patients_train": 50,
  "num_patients_test": 30,
  "batch_size": 64,
  "num_workers": 12
}

dm = EEGdataModule(**data_args)  # Load datamodule
dm.setup()

Couldn't find file at path:  /esat/biomeddata/SHHS_Dataset/no_backup/n0068_eeg.mat
Couldn't find file at path:  /esat/biomeddata/SHHS_Dataset/no_backup/n0086_eeg.mat
Couldn't find file at path:  /esat/biomeddata/SHHS_Dataset/no_backup/n0094_eeg.mat


In [6]:
device = torch.device("cpu") if not torch.cuda.is_available() else torch.device("cuda:0")
print(device)

cuda:0


## Run data through the pretrained SimCLR encoder to get the representations
- The SimCLR model was pretrained on 100 patients (+50 for validation)

In [None]:
encoder_path = "../trained_models/cnn_simclr01.ckpt"
pretrained_model = load_model(SimCLR, encoder_path)  # Load pretrained simclr model
simclr_dm = SimCLRdataModule(pretrained_model, dm, data_args['batch_size'], data_args['num_workers'], device)

In [None]:
## Analyse the features: histogram and t-SNE plot
feature = next(iter(simclr_dm.train_dataloader()))[0][0]
plt.hist(np.asarray(feature), bins=50)
plt.show()

## Train a logistic classifier on top


In [None]:
logistic_args = {
  "MODEL_TYPE": "SupervisedModel",
  "save_name": "logistic_on_simclr",
  "DATA_PATH": data_args['DATA_PATH'],
  "CHECKPOINT_PATH": "checkpoints",

  "encoder": "None",
  "encoder_hparams": {},

  "classifier": "logistic",
  "classifier_hparams":{
      "input_dim": 100
  },
  "data_hparams": data_args,

  "trainer_hparams":{
    "max_epochs": 15
  },
  "optim_hparams": {
    "lr": 1e-3,
    "weight_decay": 1e-4
  }
}
logistic_model, logistic_res = train_supervised(Namespace(**logistic_args), device=device, dm=simclr_dm)
print(logistic_res)

## Train a supervised model with the same dataset for comparison


In [7]:
supervised_args = {
  "MODEL_TYPE": "SupervisedModel",
  "save_name": "supervised_simclr",
  "DATA_PATH": data_args['DATA_PATH'],
  "CHECKPOINT_PATH": "checkpoints",

  "encoder": "CNN_head",
  "encoder_hparams": {
    "conv_filters": [32, 64, 64],
    "representation_dim": 100
  },

  "classifier": "logistic",
  "classifier_hparams":{
      "input_dim": 100
  },
  "data_hparams": data_args,

  "trainer_hparams":{
    "max_epochs": 40
  },
  "optim_hparams": {
    "lr": 1e-5,
    "weight_decay": 5e-4,
    "lr_hparams": None
  }
}
sup_model, sup_res = train_supervised(Namespace(**supervised_args), device, dm=dm)
print(sup_res)

Global seed set to 42


Couldn't find file at path:  /esat/biomeddata/SHHS_Dataset/no_backup/n0068_eeg.mat
Couldn't find file at path:  /esat/biomeddata/SHHS_Dataset/no_backup/n0086_eeg.mat
Couldn't find file at path:  /esat/biomeddata/SHHS_Dataset/no_backup/n0094_eeg.mat


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(
  rank_zero_warn(
Missing logger folder: checkpoints/supervised_simclr/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [5]


TypeError: AdamW.__init__() got an unexpected keyword argument 'lr_hparams'

## Finetuned supervised model

In [None]:
finetune_args = {
  "MODEL_TYPE": "SupervisedModel",
  "save_name": "finetuned_simclr",
  "DATA_PATH": data_args['DATA_PATH'],
  "CHECKPOINT_PATH": "checkpoints",

  "encoder": "CNN_head",
  "encoder_hparams": {
    "conv_filters": [32, 64, 64],
    "representation_dim": 100
  },

  "classifier": "logistic",
  "classifier_hparams":{
      "input_dim": 100
  },
  "data_hparams": data_args,

  "trainer_hparams":{
    "max_epochs": 30
  },
  "optim_hparams": {
    "lr": 1e-5,
    "weight_decay": 1e-5
  }
}
pretrained_encoder = type(pretrained_model.f)(**finetune_args['encoder_hparams'])
pretrained_encoder.load_state_dict(pretrained_model.f.state_dict())
fine_tuned_model, fine_tuned_res = train_supervised(Namespace(**finetune_args), device, dm=dm, pretrained_encoder=pretrained_encoder)
print(fine_tuned_res)

In [None]:
finetune_logistic_args = {
  "MODEL_TYPE": "SupervisedModel",
  "save_name": "finetuned_simclr",
  "DATA_PATH": data_args['DATA_PATH'],
  "CHECKPOINT_PATH": "checkpoints",

  "encoder": "CNN_head",
  "encoder_hparams": {
    "conv_filters": [32, 64, 64],
    "representation_dim": 100
  },

  "classifier": "logistic",
  "classifier_hparams":{
      "input_dim": 100
  },
  "data_hparams": data_args,

  "trainer_hparams":{
    "max_epochs": 50
  },
  "optim_hparams": {
    "lr": 1e-6,
    "weight_decay": 0
  }
}
pretrained_encoder = type(pretrained_model.f)(**finetune_logistic_args['encoder_hparams'])
pretrained_classifier = type(logistic_model.classifier)(finetune_logistic_args['classifier_hparams']['input_dim'], 5)
pretrained_classifier.load_state_dict(logistic_model.classifier.state_dict())
pretrained_encoder.load_state_dict(pretrained_model.f.state_dict())
fully_tuned_model, fully_tuned_res = train_supervised(Namespace(**finetune_logistic_args), device, dm=dm, pretrained_encoder=pretrained_encoder, pretrained_classifier=pretrained_classifier)
print(fully_tuned_res)

In [None]:
print(fully_tuned_res)