In [None]:
import os
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor

from datasets.SHHS_dataset_timeonly import SHHS_dataset_1, EEGdataModule

from models.logistic_regression import LogisticRegression
from models.simclr_model import SimCLR
from models.supervised_model import SupervisedModel
import torch
import torch.utils.data as data
import matplotlib.pyplot as plt
import numpy as np
from sklearn.manifold import TSNE
import seaborn as sns
import pandas as pd
from argparse import Namespace
from copy import deepcopy
from utils.helper_functions import load_model, SimCLRdataModule

## Define dataset: 10 patients for training, 5 for validation and 30 for testing


In [None]:
# Settings handed to the EEG datamodule; every key is forwarded as a keyword argument.
data_args = dict(
    DATA_PATH="../../thesis01/data/",
    # presumably a 2:1 train/validation ratio over the 15 "train" patients
    # (10 + 5, as stated in the section heading) — TODO confirm in EEGdataModule
    data_split=[2, 1],
    first_patient=1,
    num_patients_train=15,
    num_patients_test=30,
    batch_size=64,
    num_workers=12,
)

# Everything in this notebook runs on the CPU.
device = torch.device("cpu")

# Build the datamodule and prepare its datasets.
dm = EEGdataModule(**data_args)
dm.setup()

## Run data through the pretrained SimCLR encoder to get the representations
- The SimCLR model was pretrained on 100 patients (+50 for validation)

In [None]:
# Checkpoint of the pretrained SimCLR model.
encoder_path = "../trained_models/simclr05.ckpt"

# Restore the pretrained SimCLR model from the checkpoint.
pretrained_model = load_model(SimCLR, encoder_path)

# Wrap the raw datamodule together with the pretrained model; the resulting
# datamodule is what the later cells read features from.
simclr_dm = SimCLRdataModule(
    pretrained_model,
    dm,
    data_args["batch_size"],
    data_args["num_workers"],
    device,
)

In [None]:
## Analyse the features: histogram and t-SNE plot
# Take a single batch from the SimCLR training loader; element 0 of the batch
# is the feature tensor that gets plotted.
first_batch = next(iter(simclr_dm.train_dataloader()))
feature = first_batch[0]

# Histogram of the feature values of that one batch, using 50 bins.
plt.hist(np.asarray(feature), bins=50)
plt.show()

In [None]:
# Collect the SimCLR representations and their labels for the whole training split.
#
# The previous version of this cell unpacked `train_ds.tensors`, but `train_ds` is
# never defined in this notebook, so the cell raised NameError on a fresh kernel.
# It also concatenated a single raw (x, y) batch from `dm` instead of using the
# encoded features. The loader is materialised exactly once so that features and
# labels stay aligned even if the loader shuffles.
# Assumes every batch is a (features, labels) pair — element 0 is used as the
# feature tensor by the histogram cell; element 1 being the labels is a TODO confirm.
train_batches = list(simclr_dm.train_dataloader())
x = torch.cat([batch[0] for batch in train_batches]).detach().cpu().numpy()
y = torch.cat([batch[1] for batch in train_batches]).detach().cpu().numpy()

# 2-D t-SNE embedding of the representations. `random_state` pins the otherwise
# stochastic embedding so the plot is reproducible across runs.
# NOTE(review): recent scikit-learn releases rename `n_iter` to `max_iter`;
# switch the keyword if the installed version warns or errors on it.
tsne = TSNE(n_components=2, n_iter=300, verbose=1, perplexity=125, random_state=0)
tsne_results = tsne.fit_transform(x)

In [None]:
# Put the two t-SNE components into a frame, then attach the labels as column "y".
df = pd.DataFrame(tsne_results[:, :2], columns=["comp-1", "comp-2"])
df["y"] = y

# Scatter of the embedding, coloured by label (five-colour "hls" palette).
plt.figure(figsize=(12, 8))
point_labels = df["y"].tolist()
label_palette = sns.color_palette("hls", 5)
sns.scatterplot(
    data=df,
    x="comp-1",
    y="comp-2",
    hue=point_labels,
    palette=label_palette,
    legend="full",
    alpha=0.3,
)

## Train a logistic classifier on top of the SimCLR representations


In [None]:
from trainers.train_supervised import train_supervised

# Configuration for the supervised run: a logistic classifier trained on the
# data served by `simclr_dm`.
logistic_args = dict(
    MODEL_TYPE="SupervisedModel",
    save_name="logistic_on_simclr",
    DATA_PATH=data_args["DATA_PATH"],
    CHECKPOINT_PATH="checkpoints",
    # presumably no encoder is needed because the inputs are already encoded
    # by SimCLR — TODO confirm how train_supervised interprets the string "None"
    encoder="None",
    encoder_hparams={},
    classifier="logistic",
    # Same dict object as `data_args` (not a copy).
    data_hparams=data_args,
    trainer_hparams=dict(max_epochs=15),
    optim_hparams=dict(lr=1e-3, weight_decay=1e-4),
)

# train_supervised receives its settings as an attribute-style Namespace.
train_config = Namespace(**logistic_args)
model, res = train_supervised(train_config, device=device, dm=simclr_dm)
print(res)