<a href="https://colab.research.google.com/github/Div12345/SleepStaging-TransferLearning/blob/main/RP_SSL_Model_Sleep_TL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Self-Supervised Learning on EEG with Relative Positioning


Uncovering the structure of clinical EEG signals with self-supervised learning - 
[Paper](https://arxiv.org/pdf/2007.16104.pdf)

Options for SSL - 
1. Autoencoder
2. Relative Positioning
3. Temporal Shuffling
4. Contrastive Predictive Coding

From everything, the paper shows better performance with Relative Positioning I think

In [None]:
#@title Dependencies Install
!pip install wandb
!pip install git+https://github.com/sylvchev/beetl-competition
!pip install moabb
!pip install braindecode
!pip install git+https://github.com/pyRiemann/pyRiemann
!pip install matplotlib
!pip install https://github.com/ufoym/imbalanced-dataset-sampler/archive/master.zip

In [None]:
#@title Imports
from mne import get_config, set_config
import os.path as osp
import os
from beetl.task_datasets import BeetlSleepLeaderboard, BeetlSleepSource
import numpy as np
import pandas as pd
import mne
import logging

import braindecode
from braindecode import EEGClassifier
from braindecode.util import np_to_var, set_random_seeds
from braindecode.models import SleepStagerChambon2018
from braindecode.datautil.preprocess import preprocess, Preprocessor, zscore
from braindecode.samplers.ssl import RelativePositioningSampler
#from braindecode.datautil import create_from_X_y
from braindecode.datasets import BaseDataset

from braindecode.datasets import BaseDataset, BaseConcatDataset
from braindecode.datautil import create_fixed_length_windows

import sklearn
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
from sklearn.metrics import confusion_matrix, classification_report, balanced_accuracy_score, cohen_kappa_score

import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

import skorch
import skorch.dataset
from skorch.callbacks import EarlyStopping, Checkpoint, EpochScoring, WandbLogger, TrainEndCheckpoint
from skorch.dataset import Dataset
from skorch.helper import predefined_split

import time

import pickle
import torch
from torch import nn
from torch.utils.data import DataLoader

from torchsampler import ImbalancedDatasetSampler as IDS

import wandb

import joblib

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


Wandb setup

Documentation for Skorch Adaption - [Link](https://gitbook-docs.wandb.ai/guides/integrations/other/skorch), [Simple Colab](https://colab.research.google.com/drive/1Bo8SqN1wNPMKv5Bn9NjwGecBxzFlaNZn?usp=sharing#scrollTo=9AoMDvXXpaUT), [Step by Step Tut](https://wandb.ai/cayush/uncategorized/reports/Automate-Kaggle-model-training-with-Skorch-and-W-B--Vmlldzo4NTQ1NQ)

In [None]:
!wandb login
# Specific to user - Put it here for convinience
# 7018335e3eae8802cd03eeeb536cd1bc36ccbc7b # should remove it later when publishing ig

[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter: 
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [None]:
mne.set_log_level(False) # Equivalent to WARNING
path = "/content/drive/MyDrive/mne_data"
set_config("MNE_DATA", path)
set_config("MNE_DATASETS_BEETLSLEEPLEADERBOARD_PATH",path)
set_config("MNE_DATASETS_BEETLSLEEPSOURCE_PATH",path)
get_config()

  set_config("MNE_DATASETS_BEETLSLEEPLEADERBOARD_PATH",path)
  set_config("MNE_DATASETS_BEETLSLEEPSOURCE_PATH",path)


{'MNE_DATA': '/content/drive/MyDrive/mne_data',
 'MNE_DATASETS_BEETLSLEEPLEADERBOARD_PATH': '/content/drive/MyDrive/mne_data',
 'MNE_DATASETS_BEETLSLEEPSOURCE_PATH': '/content/drive/MyDrive/mne_data'}

In [None]:
cuda = torch.cuda.is_available()  # check if GPU is available
device = 'cuda' if cuda else 'cpu'
print(device)
if cuda:
    torch.backends.cudnn.benchmark = True
# Set random seed to be able to reproduce results
set_random_seeds(seed=87, cuda=cuda)
random_state = 87
# print num_workers available
# 
!nvidia-smi -L
# !nvidia-smi -q

# gpu = cuda.get_current_device()
# print("maxThreadsPerBlock = %s" % str(gpu.MAX_THREADS_PER_BLOCK))

cuda
GPU 0: Tesla P100-PCIE-16GB (UUID: GPU-e69cecc5-d4a3-e7f5-0384-2ff76d35e84a)


In [None]:
#@title Helper Functions

def label_count(y_train):
  labels= np.unique(y_train)
  labelsize=labels.shape[0]
  #print('labelsize:',labelsize)
  label_count = np.zeros(labelsize).astype(int)
  for i in range(labelsize):
      # tempy = ys1[ys1==labels[i]]
      label_count[i]=y_train[y_train==labels[i]].shape[0]
  maxsize = label_count.max()
  print(label_count)
  return np_to_var(label_count)

def label_viz(y):
  # Another Nice func with viz
  classes_mapping = {0: 'W', 1: 'S1', 2: 'S2', 3: 'S3', 4: 'S4', 5:'REM'}
  # This might be a time consuming method though
  y_train = pd.Series([y for _, y in train_ds]).map(classes_mapping)
  ax = y_train.value_counts().plot(kind='barh')
  ax.set_xlabel('Number of training examples');
  ax.set_ylabel('Sleep stage');

# For trained Skorch model
def training_viz(clf):
  # For Trained Skorch Classifier
  df = pd.DataFrame(clf.history.to_list())
  df[['train_mis_clf', 'valid_mis_clf']] = 100 - df[
      ['train_bacc', 'valid_bacc']] * 100

  # get percent of misclass for better visual comparison to loss
  plt.style.use('seaborn-talk')
  fig, ax1 = plt.subplots(figsize=(20, 7))
  df.loc[:, ['train_loss', 'valid_loss']].plot(
      ax=ax1, style=['-', ':'], marker='o', color='tab:blue', legend=False,
      fontsize=14)

  ax1.tick_params(axis='y', labelcolor='tab:blue', labelsize=14)
  ax1.set_ylabel("Loss", color='tab:blue', fontsize=14)

  ax2 = ax1.twinx()  # instantiate a second axes that shares the same x-axis

  df.loc[:, ['train_mis_clf', 'valid_mis_clf']].plot(
      ax=ax2, style=['-', ':'], marker='o', color='tab:red', legend=False)
  ax2.tick_params(axis='y', labelcolor='tab:red', labelsize=14)
  ax2.set_ylabel('Balanced misclassification rate [%]', color='tab:red',
                fontsize=14)
  ax2.set_ylim(ax2.get_ylim()[0], 85)  # make some room for legend
  ax1.set_xlabel('Epoch', fontsize=14)

  # where some data has already been plotted to ax
  handles = []
  handles.append(
      Line2D([0], [0], color='black', linewidth=1, linestyle='-', label='Train'))
  handles.append(
      Line2D([0], [0], color='black', linewidth=1, linestyle=':', label='Valid'))
  plt.legend(handles, [h.get_label() for h in handles], fontsize=14)
  plt.tight_layout()

# Scaling the data
class TrainObject(object):
    # Scaling the data
    def __init__(self, X, y, nps= False):
        assert len(X) == len(y)
        # mean = np.mean(X, axis=2, keepdims=True)
        # # Here normalise across the window, when channel size is not large enough
        # # In motor imagery kit, we put axis = 1, across channel as an example
        # std = np.std(X, axis=2, keepdims=True)
        # X = (X - mean) / std
        X = zscore(X)
        # we scale it to 1000 as a better training scale of the shallow CNN
        # according to the orignal work of the paper referenced above
        if (nps == False):
          self.X = np_to_var(X.astype(np.float32)*1e3).to(device)
          self.y = np_to_var(y.astype(np.int8)).to(device)
        else:
          self.X = X.astype(np.float32)*1e3
          self.y = y.astype(np.int8)
          # Keeping y in float here will be better for the direct skorch case

# To update for the cases data is already available
def predict_leaderboard_unlabelled(clf,save_fname,x_test_data = None,emb = False,nps=False):
  # Test Data - 6 to 17 - 12 subjects 
  if x_test_data is None:
    _, _, X_test, _ = dsl.get_data(subjects=range(6, 18)) 
    print("Sleep leaderboard - Test Data : There are {} trials with {} electrodes and {} time samples".format(*X_test.shape))

    # print(X_test.shape[0])

    x_test_mean = TrainObject(X_test, y = np.zeros(X_test.shape[0]),nps=nps)
    if (nps==False):
      # means torch must've been the input to the classifier
      x_test_data = Dataset(x_test_mean.X,x_test_mean.y)
      # Maybe put a tqdm bar?
      y_pred = clf.predict(x_test_data)
    else:
      x_test_data = x_test_mean 
      y_pred = clf.predict(x_test_data.X)
  else:
    y_pred = clf.predict(x_test_data)
  print(X_test.shape)

  print("Checking if all classes have been predicted")
  print(np.unique(y_pred))

  np.savetxt("/content/drive/MyDrive/mne_data/predict/"+save_fname+".txt",y_pred,delimiter=',',fmt="%d")

def custom_create_from_X_y(
  
        X, y, sfreq , drop_last_window=False, ch_names=None, window_size_samples=None,
        window_stride_samples=None,preload=False,n_jobs=1):
    """Create a BaseConcatDataset of WindowsDatasets from X and y to be used for
    decoding with skorch and braindecode, where X is a list of pre-cut trials
    and y are corresponding targets.
    Parameters
    ----------
    X: array-like
        list of pre-cut trials as n_trials x n_channels x n_times
    y: array-like
        targets corresponding to the trials
    drop_last_window: bool
        whether or not have a last overlapping window, when
        windows/windows do not equally divide the continuous signal
    sfreq: float
        Sampling frequency of signals.
    ch_names: array-like
        Names of the channels.
    window_size_samples: int
        window size
    window_stride_samples: int
        stride between windows
    Returns
    -------
    windows_datasets: BaseConcatDataset
        X and y transformed to a dataset format that is compatible with skorch
        and braindecode
    """
    # Prevent circular import
    # from braindecode.preprocessing.windowers import (
    #     create_fixed_length_windows, )
    n_samples_per_x = []
    base_datasets = []
    if ch_names is None:
        ch_names = [str(i) for i in range(X.shape[1])]
        log.info(f"No channel names given, set to 0-{X.shape[1]}).")

    for x, target in zip(X, y):
        n_samples_per_x.append(x.shape[1])
        info = mne.create_info(ch_names=ch_names, sfreq=sfreq)
        raw = mne.io.RawArray(x, info)
        base_dataset = BaseDataset(raw, pd.Series({"target": target}),
                                   target_name="target")
        # Also add option to give further description in line above for subject,run,session
        base_datasets.append(base_dataset)
    base_datasets = BaseConcatDataset(base_datasets)

    if window_size_samples is None and window_stride_samples is None:
        if not len(np.unique(n_samples_per_x)) == 1:
            raise ValueError("if 'window_size_samples' and "
                             "'window_stride_samples' are None, "
                             "all trials have to have the same length")
        window_size_samples = n_samples_per_x[0]
        window_stride_samples = n_samples_per_x[0]
    windows_datasets = create_fixed_length_windows(
        base_datasets,
        start_offset_samples=0,
        stop_offset_samples=None, # this makes it consider it the end of the recording
        window_size_samples=window_size_samples,
        window_stride_samples=window_stride_samples,
        drop_last_window=drop_last_window,
        preload = preload,
        n_jobs = n_jobs
    )
    return windows_datasets

def data_save(BrainDecode_data,fname):
  # convert all trainB,valB,testB into dict and save
  # BrainDecode_data = dict()
  # BrainDecode_data["train"] = trainB
  # BrainDecode_data["valid"] = valB
  # BrainDecode_data["test"] = testB

  save_path = "/content/drive/MyDrive/mne_data/" 
  with open(save_path+fname+'.pkl', 'wb') as f:
    pickle.dump(BrainDecode_data, f, pickle.HIGHEST_PROTOCOL)

from time import time as t
def load_obj(path,name):
  target = path + name + '.pkl'
  with open(target, 'rb') as f:
    unpickler = pickle.Unpickler(f)
    a = unpickler.load()
    return a


In [None]:
#@title Helper functions for loading data
# B - phase 1, C - phase 2
# get_testC will be same simple similar to testB
# But get_valC has to be the similar long-winded route like trainB, valB


def get_trainB():
  start = time.time()
  dss = BeetlSleepSource()
  X_train_list, y_train_list, trsubj = [], [], []
  for i in range(39):
    a, b, _ = dss.get_data(subjects = [i])
    X_train_list.append(a)
    y_train_list.append(b)
    trsubj.append([i]*y_train_list[i].shape[0])
    # print(i)
    # print(y_train_list[i].shape[0])
  end = time.time()
  print(f"training data load time = {(end-start)/60} min")

  print(len(trsubj))
  print(len(X_train_list))
  print(y_train_list)
  print(X_train_list[0].shape)
  print(X_train_list[1].shape)
  X_train = np.concatenate(X_train_list,0)
  y_train = np.concatenate(y_train_list,0)

  print(X_train.shape,y_train.shape)
  del X_train_list,y_train_list
  # For the Sleep Physionet in Braindecode, this is the desc used
  # desc = pd.Series({'subject': subj_nb, 'recording': sess_nb}, name='')

  sfreq = 100
  # z-scoring or normalizing

  # start = time.time()
  trainX = TrainObject(X_train, y = y_train, nps = True)
  end = time.time()
  print(f"Z-scoring time {end-start}")

  # train = Dataset(trainX.X,trainX.y)
  # For braindecode - returns BaseConcatDatset 
  # - Seems this is also compatible with skorch

  start = time.time()
  trainB = custom_create_from_X_y(trainX.X,trainX.y, sfreq = 100, preload = True, n_jobs = -1)
  # Nah this mne type structure seems way to wasteful.. Can rather just write own implementation of the sampler from pytorch Dataset I think
  # It took a long time, but RAM sort of cleared out after running, not bad
  end = time.time()
  print(f"Braindecode dataset time {end-start} sec = {(end-start)/60} min")
  # 10 sec + 7 min
  len(trainB.datasets)
  # trainB.datasets
  print(dir(trainB))
  print(type(trainB.description)) # pd.Series
  train_subj = np.concatenate(trsubj)
  print(train_subj.shape)

  y = 0
  for i,ds in enumerate(trainB.datasets):
    ds.description["subject"] = train_subj[i]
    if (i==0):
      y = 0
    elif (train_subj[i]!=train_subj[i-1]):
      y = 0
    ds.windows.metadata['i_window_in_trial'] = y
    ds.windows.metadata['i_start_in_trial'] = y*3000
    ds.windows.metadata['i_stop_in_trial'] = (y+1)*3000
    y+=1
  # _compute_window_inds from braindecode.preprocessing.windowers
  # Raise issue to include this in create_from_X_y if the param set that subject data is cont?
  # trainB.description["subject"] = train_subj
  print(trainB.description)

  # 3 sec + 1.5 min
  return trainB

def get_valB():
  # valB
  # Labelled Leaderboard Data
  start = time.time()
  dsl = BeetlSleepLeaderboard()
  # dsl.get_data()

  # Validation Data - 5? 6? subjects from the test group - Competition says 5, looks like 6

  X_target_list, y_target_list, tssubj = [], [], []
  for i in range(6):
    a, b, _,_ = dsl.get_data(subjects = [i])
    X_target_list.append(a)
    y_target_list.append(b)
    tssubj.append([i]*y_target_list[i].shape[0])
    # print(y_target_list[i].shape[0])
  end = time.time()
  print(f"leaderboard labelled load time = {(end-start)/60} min")

  print(len(tssubj))
  print(len(X_target_list))
  print(y_target_list)

  # X_target, y_target, _, _ = dsl.get_data(subjects=range(0,6))
  # label_count(y_target_list[5]) # Seems like there are 6 valid subjects in the labelled leaderboard
  # print(X_target_list[0].shape)
  # print(X_target_list[1].shape)

  X_target = np.concatenate(X_target_list,0)
  y_target = np.concatenate(y_target_list,0)
  sfreq = 100
  print(X_target.shape,y_target.shape)
  del X_target_list,y_target_list
  label_count(y_target)
  # z-scoring or normalizing
  start = time.time()
  valX = TrainObject(X_target, y = y_target, nps = True)
  end = time.time()
  print(f"Z-scoring time {end-start} sec")
  # val = Dataset(valX.X,valX.y)
  # for braindecode
  start = time.time()
  valB = custom_create_from_X_y(valX.X,valX.y, sfreq = 100, preload = True, n_jobs = -1)
  end = time.time()
  print(f"Braindecode dataset time {end-start} sec = {(end-start)/60} min")

  # Since I need to use this for RP task training purposes too, 
  # Need to add the time modification stuff here too
  # 10 sec + 7 min
  len(valB.datasets)
  # trainB.datasets
  print(dir(valB))
  print(type(valB.description)) # pd.Series
  ts_subj = np.concatenate(tssubj)
  print(ts_subj.shape)

  y = 0
  for i,ds in enumerate(valB.datasets):
    ds.description["subject"] = ts_subj[i]
    if (i==0):
      y = 0
    elif (ts_subj[i]!=ts_subj[i-1]):
      y = 0
    ds.windows.metadata['i_window_in_trial'] = y
    ds.windows.metadata['i_start_in_trial'] = y*3000
    ds.windows.metadata['i_stop_in_trial'] = (y+1)*3000
    y+=1
  # _compute_window_inds from braindecode.preprocessing.windowers
  # Raise issue to include this in create_from_X_y if the param set that subject data is cont?
  # valB.description["subject"] = ts_subj
  print(valB.description)

  return valB

def get_testB():
  dsl = BeetlSleepLeaderboard()
  _, _, X_test, _ = dsl.get_data(subjects=range(6, 18)) 
  print("Sleep leaderboard - Test Data : There are {} trials with {} electrodes and {} time samples".format(*X_test.shape))
  # print(X_test.shape[0])
  # x_test_mean = TrainObject(X_test, y = np.zeros(X_test.shape[0]),nps=nps)

  # z-scoring or normalizing
  start = time.time()
  testX = TrainObject(X_test, y = np.zeros(X_test.shape[0]), nps = True)
  end = time.time()
  print(f"Z-scoring time {end-start} sec")
  # val = Dataset(valX.X,valX.y)
  # for braindecode
  start = time.time()
  testB = custom_create_from_X_y(testX.X,testX.y, sfreq = 100, preload = True, n_jobs = -1)
  end = time.time()
  print(f"Braindecode dataset time {end-start} sec = {(end-start)/60} min")
  return testB


def get_valC():
  # Phase 2 Target Data
  # Need to download by myself first
  target_savebase = '/content/drive/MyDrive/mne_data/MNE-beetlsleeptest-data/sleep_target/'
  X_sleep_target = []
  y_sleep_target = []
  tssubj = []
  #from s0-s4 in final set
  start = time.time()
  for subj in range(0, 5):
    for session in range(1, 3):
      # "testing_s{}r{}X.npy", replacing "leaderboard_s{}r{}X.npy" before
      with open(target_savebase + "testing_s{}r{}X.npy".format(subj, session), 'rb') as f:
        X_sleep_target.append(pickle.load(f))
      with open(target_savebase + "testing_s{}r{}y.npy".format(subj, session), 'rb') as g:
        y_sleep_target.append(pickle.load(g))
      
      tssubj.append([subj]*y_sleep_target[2*subj+session-1].shape[0])
  
  end = time.time()
  print(f"phase 2 labelled load time = {(end-start)/60} min")

  print(len(tssubj))
  print(len(X_sleep_target))
  print(y_sleep_target)
  # print(X_sleep_target[0].shape)
  # print(X_sleep_target[1].shape)

  X_sleep_target = np.concatenate(X_sleep_target)
  y_sleep_target = np.concatenate(y_sleep_target)
  sfreq = 100

  print("There are {} trials with {} electrodes and {} time samples".format(*X_sleep_target.shape))
  print(X_sleep_target.shape, y_sleep_target.shape)
  label_count(y_sleep_target)

  # package in torch dataset with mean normalizing
  # z-scoring or normalizing
  start = time.time()
  valX = TrainObject(X_sleep_target, y = y_sleep_target, nps = True)
  end = time.time()
  print(f"Z-scoring time {end-start} sec")

  
  # for braindecode
  start = time.time()
  # valC = Dataset(valX.X,valX.y)
  valC = custom_create_from_X_y(valX.X,valX.y, sfreq = 100, preload = True, n_jobs = -1)
  end = time.time()
  print(f"Braindecode dataset time {end-start} sec = {(end-start)/60} min")
  # 

  # Since I need to use this for RP task training purposes too, 
  # Need to add the time modification stuff here too
  # 10 sec + 7 min
  len(valC.datasets)
  # trainB.datasets
  print(dir(valC))
  print(type(valC.description)) # pd.Series
  ts_subj = np.concatenate(tssubj)
  print(ts_subj.shape)

  y = 0
  for i,ds in enumerate(valC.datasets):
    ds.description["subject"] = ts_subj[i]
    if (i==0):
      y = 0
    elif (ts_subj[i]!=ts_subj[i-1]):
      y = 0
    ds.windows.metadata['i_window_in_trial'] = y
    ds.windows.metadata['i_start_in_trial'] = y*3000
    ds.windows.metadata['i_stop_in_trial'] = (y+1)*3000
    y+=1
  # _compute_window_inds from braindecode.preprocessing.windowers
  # Raise issue to include this in create_from_X_y if the param set that subject data is cont?
  # valC.description["subject"] = ts_subj
  print(valC.description)

  return valC

def get_testC():
  # Phase 2 Test Data - Need to download by myself first
  test_savebase = '/content/drive/MyDrive/mne_data/MNE-beetlsleeptest-data/testing/'
  X_sleep_test = []
  start = time.time()

  #starts from s5 in final set
  for subj in range(5, 14):
      for session in range(1, 3):
          # "testing_s{}r{}X.npy", replacing "leaderboard_s{}r{}X.npy" before
          with open(test_savebase + "testing_s{}r{}X.npy".format(subj, session), 'rb') as f:
              X_sleep_test.append(pickle.load(f))
  X_sleep_test = np.concatenate(X_sleep_test)
  end = time.time()
  print(f"phase 2 final test set load time = {(end-start)/60} min")
  print ("There are {} trials with {} electrodes and {} time samples".format(*X_sleep_test.shape))

  # package in torch dataset with mean normalizing
  # z-scoring or normalizing
  start = time.time()
  testX = TrainObject(X_sleep_test, y = np.zeros(X_sleep_test.shape[0]), nps = True)
  end = time.time()
  print(f"Z-scoring time {end-start} sec")
  sfreq = 100
  # for braindecode
  start = time.time()
  # testC = Dataset(testX.X,testX.y) # Torch
  testC = custom_create_from_X_y(testX.X,testX.y, sfreq = 100, preload = True, n_jobs = -1)
  end = time.time()
  print(f"Torch dataset time {end-start} sec = {(end-start)/60} min")

  return testC

In [None]:
log = logging.getLogger(__name__)

In [None]:
#@title RPD

# Train-Test Split - Not necessary here - We have our train and validation
# train and val - skorch dataset class

# TODO - Mods required - I think done
# New Dataset class which can receive a pair of indices and return the corresponding epochs
# RPD - RelativePositioningDataset
from braindecode.datasets import BaseConcatDataset
class RPD(BaseConcatDataset):
    """BaseConcatDataset with __getitem__ that expects 2 indices and a target.
    """
    def __init__(self, list_of_ds):
        super().__init__(list_of_ds)
        self.return_pair = True

    def __getitem__(self, index):
        if self.return_pair:
            ind1, ind2, y = index
            # TODO - Check if the 0 indexing is required
            return (super().__getitem__(ind1)[0],
                    super().__getitem__(ind2)[0]), y
            # return(super())
        else:
            return super().__getitem__(index)

    @property
    def return_pair(self):
        return self._return_pair

    @return_pair.setter
    def return_pair(self, value):
        self._return_pair = value

In [None]:
#@title SleepStagerChambon2018 model and Contrastive Net
# Extract number of channels and time steps from dataset
# n_channels = X_train[0].shape[0] #2
n_channels = 2 #@param
#print(n_channels)
# input_size_samples = X_train[0].shape[1]
input_size_samples = 3000 #@param
#print(input_size_samples)
emb_size = 100 #@param
n_conv_chs = 16 #@param
# dropout = 0.5
sfreq = 100 #@param

emb = SleepStagerChambon2018(
    n_channels,
    sfreq,
    n_classes=emb_size,
    n_conv_chs=n_conv_chs,
    input_size_s=input_size_samples / sfreq,
    dropout=0,
    apply_batch_norm=True
)


class ContrastiveNet(nn.Module):
    """Contrastive module with linear layer on top of siamese embedder.

    Parameters
    ----------
    emb : nn.Module
        Embedder architecture.
    emb_size : int
        Output size of the embedder.
    dropout : float
        Dropout rate applied to the linear layer of the contrastive module.
    """
    def __init__(self, emb, emb_size, dropout=0.5):
        super().__init__()
        self.emb = emb
        self.clf = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(emb_size, 1)
        )

    def forward(self, x):
        x1, x2 = x
        z1, z2 = self.emb(x1), self.emb(x2)
        return self.clf(torch.abs(z1 - z2)).flatten()



  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


I don't think it's right to train the SSL on the data from the test, considering that the pre-text task is temporal.. Even though it is given as subjects in the data, I'm not sure I can be sure that it is temporally un-shuffled data. 

# Function based generation of trainB, valB

In [None]:
trainB = get_trainB()
valB = get_valB()
# 19min

Repeat for Test Dataset

In [None]:
# convert all trainB,valB,testB into dict and save
BrainDecode_data = dict()
BrainDecode_data["train"] = trainB
BrainDecode_data["valid"] = valB
# BrainDecode_data["test"] = testB

save_path = "/content/drive/MyDrive/mne_data/" 
with open(save_path+'BraindecodeData_Dict2.pkl', 'wb') as f:
    pickle.dump(BrainDecode_data, f, pickle.HIGHEST_PROTOCOL)

In [None]:
# BrainDecode_data.to_pickle(save_path+'BraindecodeData_Dict3.pkl')
del BrainDecode_data

In [None]:
# Should run this still sometime later
# z_dat = dict()
# z_dat["train"] = trainX
# z_dat["valid"] = valX
# z_dat["test"] = testX

# save_path = "/content/drive/MyDrive/mne_data/" 
# with open(save_path+'zScored_np_Dict.pkl', 'wb') as f:
#     pickle.dump(z_dat, f, pickle.HIGHEST_PROTOCOL)

##
# Just call this file and do the follow to get tensors 
# self.X = np_to_var(X.astype(np.float32)*1e3).to(device)
# self.y = np_to_var(y.astype(np.int8)).to(device)

In [None]:
# trainB = get_trainB

valB = get_valB()

In [None]:
valC = get_valC()

testB = get_testB()
testC = get_testC()

phase 2 labelled load time = 0.3801357587178548 min
10
10
[array([2, 2, 2, ..., 0, 2, 1], dtype=int32), array([0, 2, 0, ..., 1, 1, 2], dtype=int32), array([0, 5, 0, ..., 5, 2, 5], dtype=int32), array([2, 5, 0, ..., 2, 2, 5], dtype=int32), array([0, 5, 5, ..., 0, 5, 0], dtype=int32), array([0, 0, 0, ..., 0, 3, 5], dtype=int32), array([0, 1, 2, ..., 0, 0, 0], dtype=int32), array([0, 0, 0, ..., 1, 2, 2], dtype=int32), array([0, 0, 2, ..., 3, 2, 0], dtype=int32), array([4, 0, 2, ..., 4, 1, 4], dtype=int32)]
There are 16568 trials with 2 electrodes and 3000 time samples
(16568, 2, 3000) (16568,)
[7809 1639 4704  689  145 1582]
Z-scoring time 0.9398367404937744 sec
Braindecode dataset time 89.99127793312073 sec = 1.4998546322186788 min
['__add__', '__annotations__', '__class__', '__class_getitem__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__getitem__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__len__', '

In [None]:
print(valB)
print(valC)
print(testB)
print(testC)

<braindecode.datasets.base.BaseConcatDataset object at 0x7f26459a5650>
<braindecode.datasets.base.BaseConcatDataset object at 0x7f2649750b50>
<braindecode.datasets.base.BaseConcatDataset object at 0x7f2628d787d0>
<braindecode.datasets.base.BaseConcatDataset object at 0x7f264bc8e690>


# Relative Positioning Dataset and Sampler

In [None]:
# a = trainB.datasets[0]
# dir(a)
# a.windows.metadata

In [None]:
# split_ids = {'train': subj_train, 'valid': subj_valid, 'test': subj_test}
splitted = dict()
# for name, values in split_ids.items():
#     splitted[name] = RelativePositioningDataset(
#         [ds for ds in windows_dataset.datasets
#          if ds.description['subject'] in values])
from time import time as t
s = t()
splitted['train'] = RPD([ds for ds in trainB.datasets])
e = t()
print(e-s)
s = t()
splitted['valid'] = RPD([ds for ds in valB.datasets])
e = t()
print(e-s)
s = t()
splitted['test'] = RPD([ds for ds in valC.datasets])
e = t()
print(e-s)
# get data from phase 2 validation as test sampler here

# should do del(all prev data related var.s)


2.0512423515319824
2.1328721046447754


In [None]:
a = splitted['test'].get_metadata()[0:100]
print(a)

    i_window_in_trial  i_start_in_trial  i_stop_in_trial  target  subject
0                   0                 0             3000       2        0
0                   1              3000             6000       2        0
0                   2              6000             9000       2        0
0                   3              9000            12000       5        0
0                   4             12000            15000       2        0
..                ...               ...              ...     ...      ...
0                  95            285000           288000       0        0
0                  96            288000           291000       0        0
0                  97            291000           294000       1        0
0                  98            294000           297000       0        0
0                  99            297000           300000       2        0

[100 rows x 5 columns]


In [None]:
del a

In [None]:
# RP samplers - randomly sample pairs of examples to train and validate model with SSL
# 2 main hyperparams - 
# Pairs of windows that are separated by less than tau_pos samples will be given a label of 1, 
# while pairs of windows that are separated by more than tau_neg samples will be given a label of 0. 
# another param - n_examples - # pairs to sample
# Higher number - better regularization 2000 pairs per rec in paper

# Log these params in wandb
n_examples = 2000 #250
sfreq = 100
tau_pos, tau_neg = int(sfreq * 60), int(sfreq * 15 * 60)
# Should be n_examples per recording
# print(len(splitted['train'].datasets))

# n_examples_train = n_examples * len(splitted['train'].datasets)
# n_examples_valid = n_examples * len(splitted['valid'].datasets)

# n_examples_train = n_examples * 39
n_examples_valid = n_examples * 6
n_examples_test = n_examples * 5

# print(n_examples_train,n_examples_valid)
random_state = 87 # remove in next run - added line in cuda cell
# train_sampler = RelativePositioningSampler(
#     splitted['train'].get_metadata(), tau_pos=tau_pos, tau_neg=tau_neg,
#     n_examples=n_examples_train, same_rec_neg=True, random_state=random_state)
valid_sampler = RelativePositioningSampler(
    splitted['valid'].get_metadata(), tau_pos=tau_pos, tau_neg=tau_neg,
    n_examples=n_examples_valid, same_rec_neg=True, random_state=random_state).presample()
test_sampler = RelativePositioningSampler(
    splitted['test'].get_metadata(), tau_pos=tau_pos, tau_neg=tau_neg,
    n_examples=n_examples_test, same_rec_neg=True, random_state=random_state).presample()

In [None]:
print(test_sampler.info)

                                                     index                                   i_start_in_trial
subject                                                                                                      
0        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...  [0, 3000, 6000, 9000, 12000, 15000, 18000, 210...
1        [2570, 2571, 2572, 2573, 2574, 2575, 2576, 257...  [0, 3000, 6000, 9000, 12000, 15000, 18000, 210...
2        [5757, 5758, 5759, 5760, 5761, 5762, 5763, 576...  [0, 3000, 6000, 9000, 12000, 15000, 18000, 210...
3        [9062, 9063, 9064, 9065, 9066, 9067, 9068, 906...  [0, 3000, 6000, 9000, 12000, 15000, 18000, 210...
4        [12086, 12087, 12088, 12089, 12090, 12091, 120...  [0, 3000, 6000, 9000, 12000, 15000, 18000, 210...


In [None]:
print(train_sampler.info)

                                                     index                                   i_start_in_trial
subject                                                                                                      
0        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...  [0, 3000, 6000, 9000, 12000, 15000, 18000, 210...
1        [1931, 1932, 1933, 1934, 1935, 1936, 1937, 193...  [0, 3000, 6000, 9000, 12000, 15000, 18000, 210...
2        [4398, 4399, 4400, 4401, 4402, 4403, 4404, 440...  [0, 3000, 6000, 9000, 12000, 15000, 18000, 210...
3        [6830, 6831, 6832, 6833, 6834, 6835, 6836, 683...  [0, 3000, 6000, 9000, 12000, 15000, 18000, 210...
4        [8781, 8782, 8783, 8784, 8785, 8786, 8787, 878...  [0, 3000, 6000, 9000, 12000, 15000, 18000, 210...
5        [10988, 10989, 10990, 10991, 10992, 10993, 109...  [0, 3000, 6000, 9000, 12000, 15000, 18000, 210...
6        [13104, 13105, 13106, 13107, 13108, 13109, 131...  [0, 3000, 6000, 9000, 12000, 15000, 18000, 210...
7        [

# SleepStagerChambon Training for RP task - Based on [this tutorial](https://braindecode.org/auto_examples/plot_relative_positioning.html#creating-the-model)

In [None]:
model = ContrastiveNet(emb, emb_size).to(device)

In [None]:
# wandb

wandb_run = wandb.init(name = "RP-SSL SSC2018 train", project='RP-fulltrain', entity='sleep_hacking')

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

In [None]:

# Train the network
lr = 5e-2 # 5e-4
batch_size = 1024 # increase this next time
n_epochs = 150
num_workers = 0 # if n_jobs <= 1 else n_jobs

# Callback WandbLogger logs the "best trained model", etc. after every epoch
# Not sure of definition of best trained

# TODO - 
# Add early stopping with 10 on valid_bacc I guess - 
# This fails if big oscillations happens, but well the lr is low

# should make this function of parameters for easiness
save_path = "/content/drive/MyDrive/mne_data/rp_ssc_checkpoints3" 

# Log hyperparameters
wandb_run.config.update({"Embedder_size":emb_size, "learning rate": lr, 
                         "batch size": batch_size, "n_conv_chs_SSC": n_conv_chs,
                         "tau_pos":tau_pos, "tau_neg":tau_neg, "n_examples_RPsampler":n_examples,
                         "n_examples_train":n_examples_train,"n_examples_val":n_examples_valid,
                         "save_path":save_path})

# org paper used a weight decay of 1e-3 on all trainable params
cp = Checkpoint(monitor = 'valid_acc_best',
                f_params = None, f_optimizer = None, f_criterion = None,
                f_pickle = "model_{last_epoch[epoch]}.pkl",
                dirname = save_path)
train_acc = EpochScoring(
    scoring='accuracy', on_train=True, name='train_acc',
    lower_is_better=False)
valid_acc = EpochScoring(
    scoring='accuracy', on_train=False, name='valid_acc',
    lower_is_better=False)
callbacks = [('train_acc', train_acc),
             ('valid_acc', valid_acc),
             ("checkpoint",cp),
             ("wandb",WandbLogger(wandb_run))
             ]


clf = EEGClassifier(
    model,
    criterion=torch.nn.BCEWithLogitsLoss,
    optimizer=torch.optim.Adam,
    max_epochs=n_epochs,
    iterator_train__shuffle=False,
    iterator_train__sampler=train_sampler,
    iterator_valid__sampler=valid_sampler,
    iterator_train__num_workers=num_workers,
    iterator_valid__num_workers=num_workers,
    train_split=predefined_split(splitted['valid']),
    optimizer__lr=lr,
    batch_size=batch_size,
    callbacks=callbacks,
    device=device
)

print(clf)

<class 'braindecode.classifier.EEGClassifier'>[uninitialized](
  module=ContrastiveNet(
    (emb): SleepStagerChambon2018(
      (spatial_conv): Conv2d(1, 2, kernel_size=(2, 1), stride=(1, 1))
      (feature_extractor): Sequential(
        (0): Conv2d(1, 16, kernel_size=(1, 50), stride=(1, 1), padding=(0, 25))
        (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
        (3): MaxPool2d(kernel_size=(1, 13), stride=(1, 13), padding=0, dilation=1, ceil_mode=False)
        (4): Conv2d(16, 16, kernel_size=(1, 50), stride=(1, 1), padding=(0, 25))
        (5): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (6): ReLU()
        (7): MaxPool2d(kernel_size=(1, 13), stride=(1, 13), padding=0, dilation=1, ceil_mode=False)
      )
      (fc): Sequential(
        (0): Dropout(p=0, inplace=False)
        (1): Linear(in_features=544, out_features=100, bias=True)
      )
    )
    (clf): Sequential(
     

Classifier Fit for 2k pairs, batch_size 1024 with increased learning rate 5e-2

In [None]:
# del clf
mne.set_log_level(True)
clf.fit(splitted["train"], y=None)

  epoch    train_acc    train_loss    valid_acc    valid_loss    cp       dur
-------  -----------  ------------  -----------  ------------  ----  --------
      1       [36m0.7347[0m        [32m0.7485[0m       [35m0.5259[0m        [31m1.0389[0m     +  442.1357
      2       [36m0.8034[0m        [32m0.4372[0m       0.5228        1.5883        447.0001
      3       [36m0.8240[0m        [32m0.4050[0m       [35m0.5264[0m        1.3782     +  449.5075
      4       [36m0.8276[0m        [32m0.3979[0m       0.5230        1.5310        447.5935
      5       [36m0.8350[0m        [32m0.3818[0m       [35m0.5302[0m        1.4573     +  444.6362
      6       [36m0.8394[0m        [32m0.3742[0m       0.5268        1.4676        446.8178
      7       [36m0.8441[0m        [32m0.3635[0m       0.5222        1.9785        443.0689
      8       [36m0.8563[0m        [32m0.3435[0m       [35m0.5331[0m        1.4493     +  443.6451
      9       [36m0.8589[0m  

<class 'braindecode.classifier.EEGClassifier'>[initialized](
  module_=ContrastiveNet(
    (emb): SleepStagerChambon2018(
      (spatial_conv): Conv2d(1, 2, kernel_size=(2, 1), stride=(1, 1))
      (feature_extractor): Sequential(
        (0): Conv2d(1, 16, kernel_size=(1, 50), stride=(1, 1), padding=(0, 25))
        (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
        (3): MaxPool2d(kernel_size=(1, 13), stride=(1, 13), padding=0, dilation=1, ceil_mode=False)
        (4): Conv2d(16, 16, kernel_size=(1, 50), stride=(1, 1), padding=(0, 25))
        (5): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (6): ReLU()
        (7): MaxPool2d(kernel_size=(1, 13), stride=(1, 13), padding=0, dilation=1, ceil_mode=False)
      )
      (fc): Sequential(
        (0): Dropout(p=0, inplace=False)
        (1): Linear(in_features=544, out_features=100, bias=True)
      )
    )
    (clf): Sequential(
      

In [None]:
with open(save_path+'/final_model.pkl', 'wb') as f:
    pickle.dump(clf, f)

In [None]:
# Use the val itself to see the matrix
y_pred = clf.forward(splitted['valid'], training=False) > 0

y_true = [y for _, _, y in valid_sampler]

print(confusion_matrix(y_true, y_pred))
print(classification_report(y_true, y_pred))

[[5156  765]
 [4997 1082]]
              precision    recall  f1-score   support

         0.0       0.51      0.87      0.64      5921
         1.0       0.59      0.18      0.27      6079

    accuracy                           0.52     12000
   macro avg       0.55      0.52      0.46     12000
weighted avg       0.55      0.52      0.45     12000



In [None]:
wandb.finish()

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
dur,450.05605
train_loss,0.30546
valid_loss,2.40528
train_acc,0.87683
valid_acc,0.52158
_runtime,11278.0
_timestamp,1631228205.0
_step,24.0


0,1
dur,▂▆█▆▄▆▃▃▅▅▄▃▂▃▃▄▄▃▃▂▁▃▄▆█
train_loss,█▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
valid_loss,▁▃▂▃▃▃▅▃▃▅▃▅▅▆█▃▆▇▆▅▅▃▇▅▇
train_acc,▁▄▅▅▆▆▆▇▇▇▇▇▇████████████
valid_acc,▄▃▄▃▆▄▃▇▇▃▆▃▅▃▁█▅▂▅▄▄▆▃▆▂
_runtime,▁▁▂▂▂▂▃▃▃▄▄▄▅▅▅▅▆▆▆▇▇▇▇██
_timestamp,▁▁▂▂▂▂▃▃▃▄▄▄▅▅▅▅▆▆▆▇▇▇▇██
_step,▁▁▂▂▂▂▃▃▃▄▄▄▅▅▅▅▆▆▆▇▇▇▇██


# Other SSL-CNN training tries

Classifier Fit for 2k pairs, batch_size 1024

In [None]:
# del clf
mne.set_log_level(True)
clf.fit(splitted["train"], y=None)

  epoch    train_acc    train_loss    valid_acc    valid_loss    cp       dur
-------  -----------  ------------  -----------  ------------  ----  --------
      1       [36m0.8249[0m        [32m0.4070[0m       [35m0.5135[0m        [31m1.5999[0m     +  451.3815
      2       [36m0.8299[0m        [32m0.4006[0m       0.5111        1.6242        445.3153
      3       [36m0.8359[0m        [32m0.3893[0m       0.5129        1.6680        443.0258
      4       [36m0.8391[0m        [32m0.3824[0m       0.5108        1.8031        442.6547
      5       [36m0.8429[0m        [32m0.3756[0m       0.5122        1.7254        443.6592


<class 'braindecode.classifier.EEGClassifier'>[initialized](
  module_=ContrastiveNet(
    (emb): SleepStagerChambon2018(
      (spatial_conv): Conv2d(1, 2, kernel_size=(2, 1), stride=(1, 1))
      (feature_extractor): Sequential(
        (0): Conv2d(1, 16, kernel_size=(1, 50), stride=(1, 1), padding=(0, 25))
        (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
        (3): MaxPool2d(kernel_size=(1, 13), stride=(1, 13), padding=0, dilation=1, ceil_mode=False)
        (4): Conv2d(16, 16, kernel_size=(1, 50), stride=(1, 1), padding=(0, 25))
        (5): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (6): ReLU()
        (7): MaxPool2d(kernel_size=(1, 13), stride=(1, 13), padding=0, dilation=1, ceil_mode=False)
      )
      (fc): Sequential(
        (0): Dropout(p=0, inplace=False)
        (1): Linear(in_features=544, out_features=100, bias=True)
      )
    )
    (clf): Sequential(
      

In [None]:
with open('/content/drive/MyDrive/mne_data/rp_ssc_checkpoints2/final_model.pkl', 'wb') as f:
    pickle.dump(clf, f)

In [None]:
# Use the val itself to see the matrix
y_pred = clf.forward(splitted['valid'], training=False) > 0

y_true = [y for _, _, y in valid_sampler]

print(confusion_matrix(y_true, y_pred))
print(classification_report(y_true, y_pred))

[[4626 1295]
 [4540 1539]]
              precision    recall  f1-score   support

         0.0       0.50      0.78      0.61      5921
         1.0       0.54      0.25      0.35      6079

    accuracy                           0.51     12000
   macro avg       0.52      0.52      0.48     12000
weighted avg       0.52      0.51      0.48     12000



Probably should try increasing the learning rate

Classifier Fit for 2000 pairs in RP Sampler and batch_size 256




In [None]:
mne.set_log_level(True)
clf.fit(splitted["train"], y=None)

  epoch    train_acc    train_loss    valid_acc    valid_loss    cp       dur
-------  -----------  ------------  -----------  ------------  ----  --------
      1       [36m0.7621[0m        [32m0.4982[0m       [35m0.5065[0m        [31m1.6180[0m     +  448.4228
      2       [36m0.8071[0m        [32m0.4340[0m       [35m0.5132[0m        1.7131     +  453.6100


<class 'braindecode.classifier.EEGClassifier'>[initialized](
  module_=ContrastiveNet(
    (emb): SleepStagerChambon2018(
      (spatial_conv): Conv2d(1, 2, kernel_size=(2, 1), stride=(1, 1))
      (feature_extractor): Sequential(
        (0): Conv2d(1, 16, kernel_size=(1, 50), stride=(1, 1), padding=(0, 25))
        (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
        (3): MaxPool2d(kernel_size=(1, 13), stride=(1, 13), padding=0, dilation=1, ceil_mode=False)
        (4): Conv2d(16, 16, kernel_size=(1, 50), stride=(1, 1), padding=(0, 25))
        (5): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (6): ReLU()
        (7): MaxPool2d(kernel_size=(1, 13), stride=(1, 13), padding=0, dilation=1, ceil_mode=False)
      )
      (fc): Sequential(
        (0): Dropout(p=0, inplace=False)
        (1): Linear(in_features=544, out_features=100, bias=True)
      )
    )
    (clf): Sequential(
      

In [None]:
# Use the val itself to see the matrix
y_pred = clf.forward(splitted['valid'], training=False) > 0

y_true = [y for _, _, y in valid_sampler]

print(confusion_matrix(y_true, y_pred))
print(classification_report(y_true, y_pred))

[[4385 1536]
 [4299 1780]]
              precision    recall  f1-score   support

         0.0       0.50      0.74      0.60      5921
         1.0       0.54      0.29      0.38      6079

    accuracy                           0.51     12000
   macro avg       0.52      0.52      0.49     12000
weighted avg       0.52      0.51      0.49     12000



Classifier Fitting Table for 250 pairs in RP sampler

In [None]:
# Model training for a specified number of epochs. `y` is None as it is already
# supplied in the dataset.
# The idotic thing doesn't even seem to get converted to tensor or gpu.. idk
# It's loading everything at the sampler, even though everything should be in memory
# See if the data instead of me
# Memory access is within couple of seconds when preload is done 
# still not sure if the tensor conversion can be activated at a point before the dataloader
# Follow up with an issue citing this - https://github.com/braindecode/braindecode/issues/63#issuecomment-584035555 
# Checked the follow up PRs, etc. Didn't see any discussion regarding this tensor cast clearly
mne.set_log_level(True)
clf.fit(splitted["train"], y=None)
# clf.load_params(checkpoint=cp)  # Load the model with the lowest valid_loss

Re-initializing optimizer because the following parameters were re-set: lr.
  epoch    train_acc    train_loss    valid_acc    valid_loss    cp      dur
-------  -----------  ------------  -----------  ------------  ----  -------
      1       [36m0.7188[0m        [32m0.5638[0m       [35m0.4967[0m        [31m1.1182[0m     +  52.8966
      2       [36m0.7631[0m        [32m0.5085[0m       [35m0.5080[0m        1.2368     +  55.3495
      3       [36m0.7839[0m        [32m0.4825[0m       0.5053        1.2735        56.1307
      4       [36m0.7876[0m        [32m0.4672[0m       0.5000        1.4313        52.9659
      5       [36m0.8000[0m        [32m0.4534[0m       0.5013        1.4262        53.1660
      6       0.7941        0.4584       0.4980        1.3995        52.8803
      7       0.7947        0.4596       0.4987        1.5176        53.0050
      8       0.7988        [32m0.4489[0m       0.4993        1.4433        53.6070
      9       [36m0.8022[0

<class 'braindecode.classifier.EEGClassifier'>[initialized](
  module_=ContrastiveNet(
    (emb): SleepStagerChambon2018(
      (spatial_conv): Conv2d(1, 2, kernel_size=(2, 1), stride=(1, 1))
      (feature_extractor): Sequential(
        (0): Conv2d(1, 16, kernel_size=(1, 50), stride=(1, 1), padding=(0, 25))
        (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
        (3): MaxPool2d(kernel_size=(1, 13), stride=(1, 13), padding=0, dilation=1, ceil_mode=False)
        (4): Conv2d(16, 16, kernel_size=(1, 50), stride=(1, 1), padding=(0, 25))
        (5): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (6): ReLU()
        (7): MaxPool2d(kernel_size=(1, 13), stride=(1, 13), padding=0, dilation=1, ceil_mode=False)
      )
      (fc): Sequential(
        (0): Dropout(p=0, inplace=False)
        (1): Linear(in_features=544, out_features=100, bias=True)
      )
    )
    (clf): Sequential(
      

In [None]:
# Save last model - Likely that if it isn't already checkpointed, it's probably overfitted
# Still save it I guess
# Also check if wandb is auto logging the checkpointed, final
with open('/content/drive/MyDrive/mne_data/rp_ssc_checkpoints2/final_model.pkl', 'wb') as f:
    pickle.dump(clf, f) # I think I've saved the wrong one previously

In [None]:
# training_viz(clf)
# wandb.finish()

In [None]:
# Switch to the test sampler
# clf.iterator_valid__sampler = test_sampler

# wandb got closed, so disabling wandb callback
# clf.set_params(callbacks="disable")

# Use the val itself to see the matrix
y_pred = clf.forward(splitted['valid'], training=False) > 0

y_true = [y for _, _, y in valid_sampler]

print(confusion_matrix(y_true, y_pred))
print(classification_report(y_true, y_pred))

# wandb.finish()

[[575 166]
 [556 203]]
              precision    recall  f1-score   support

         0.0       0.51      0.78      0.61       741
         1.0       0.55      0.27      0.36       759

    accuracy                           0.52      1500
   macro avg       0.53      0.52      0.49      1500
weighted avg       0.53      0.52      0.49      1500



In [None]:
wandb.finish()

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

# Use the CNN as Feature Extractor and train the actual Sleep Stage Classifier


In [None]:
print(emb)
print(clf.module.emb)

SleepStagerChambon2018(
  (spatial_conv): Conv2d(1, 2, kernel_size=(2, 1), stride=(1, 1))
  (feature_extractor): Sequential(
    (0): Conv2d(1, 16, kernel_size=(1, 50), stride=(1, 1), padding=(0, 25))
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool2d(kernel_size=(1, 13), stride=(1, 13), padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(16, 16, kernel_size=(1, 50), stride=(1, 1), padding=(0, 25))
    (5): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU()
    (7): MaxPool2d(kernel_size=(1, 13), stride=(1, 13), padding=0, dilation=1, ceil_mode=False)
  )
  (fc): Sequential(
    (0): Dropout(p=0, inplace=False)
    (1): Linear(in_features=544, out_features=100, bias=True)
  )
)
SleepStagerChambon2018(
  (spatial_conv): Conv2d(1, 2, kernel_size=(2, 1), stride=(1, 1))
  (feature_extractor): Sequential(
    (0): Conv2d(1, 16, kernel_size=(1, 50), stride=(1, 1), padding

In [None]:
# Load the CNN
# rp_ssc_checkpoints3 (2k, 1024, inc LR)
# rp_ssc_checkpoints2 (2k, 1024)
# rp_ssc_checkpoints2k_256 (2k, 256)
# rp_ssc_checkpoints (250, 256)
def load_obj(path,name):
    with open(path + name + '.pkl', 'rb') as f:
        return pickle.load(f)

# Need the trained embedder not the classifier with the contrastive model
emb_c = clf.module.emb

In [None]:
from torch.utils.data import DataLoader
from sklearn.metrics import balanced_accuracy_score
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline

# Extract features with the trained embedder
data = dict()
# splitted has the RPDataset 
# For unlabelled a simple dataset should be equivalent since single win
for name, split in splitted.items():
    # run the till the RP sampler above
    split.return_pair = False  # Return single windows
    loader = DataLoader(split, batch_size=batch_size, num_workers=num_workers)
    with torch.no_grad():
        feats = [emb_c(batch_x.to(device)).cpu().numpy()
                 for batch_x, _, _ in loader]
    data[name] = (np.concatenate(feats), split.get_metadata()['target'].values)


In [None]:
print(data.keys())
print(len(data["train"]))

print(data["train"][0].shape)
print(data["train"][1].shape)

dict_keys(['train', 'valid'])
2
(90545, 100)
(90545,)


Saving the data from the model at the end of the last embedder trained (2k pairs, 1024 batch_size for 25 epochs)

In [None]:
# Save data - Then simple load for next time
save_path = "/content/drive/MyDrive/mne_data/RPClassifier/" 
with open(save_path+'Data_Dict.pkl', 'wb') as f:
    pickle.dump(data, f, pickle.HIGHEST_PROTOCOL)

# Too lazy to convert to csv and save
# Just load data and train the classifier
def load_obj(path,name):
    with open(path + name + '.pkl', 'rb') as f:
        return pickle.load(f)

name = "Data_Dict"
data2 = load_obj(save_path,name)
# For unlabelled dataset alone you need to reload the embedder

# Might also have to try SSL training the model itself on the validation data
# either new one or using the same net and just continuing the training

In [None]:
print(data2.keys())
print(len(data2["train"]))

print(data2["train"][0].shape)
print(data2["train"][1].shape)

dict_keys(['train', 'valid'])
2
(90545, 100)
(90545,)


In [None]:
# For loading from next time
data = data2

Have to try different Sklearn Classifiers (Could run Auto-Sklearn also I guess)

Also should probably try a wandb search space atleast now

newton-cg with max-iter 1.5k

In [None]:
# wandb for the classifier - same project diff name should be fine
wandb_run = wandb.init(name = "Classifier train", project='RP-fulltrain', entity='sleep_hacking')
max_iter = 1500
solver = "newton-cg"
wandb.config.update({'classifier type':'log_reg','solver':solver,"max_iter":max_iter})

[34m[1mwandb[0m: [32m[41mERROR[0m Problem finishing run
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/wandb/sdk/wandb_run.py", line 1579, in _atexit_cleanup
    self._on_finish()
  File "/usr/local/lib/python3.7/dist-packages/wandb/sdk/wandb_run.py", line 1715, in _on_finish
    self.history._flush()
  File "/usr/local/lib/python3.7/dist-packages/wandb/sdk/wandb_history.py", line 59, in _flush
    self._callback(row=self._data, step=self._step)
  File "/usr/local/lib/python3.7/dist-packages/wandb/sdk/wandb_run.py", line 903, in _history_callback
    row, step, publish_step=not_using_tensorboard
  File "/usr/local/lib/python3.7/dist-packages/wandb/sdk/interface/interface.py", line 223, in publish_history
    item.value_json = json_dumps_safer_history(v)  # type: ignore
  File "/usr/local/lib/python3.7/dist-packages/wandb/util.py", line 749, in json_dumps_safer_history
    return json.dumps(obj, cls=WandBHistoryJSONEncoder, **kwargs)
  File "/usr

In [None]:
# max_iter = 1500
# solver = "newton-cg"
log_reg = LogisticRegression(
    penalty='l2', C=1.0, class_weight='balanced', solver=solver,
    multi_class='multinomial', random_state=random_state,verbose=1,
    n_jobs=-1,max_iter=max_iter)
clf_pipe = make_pipeline(StandardScaler(), log_reg)

In [None]:
s = t()
clf_pipe.fit(*data['train'])
e = t()
print(f"Fit time {e-s}")

train_y_pred = clf_pipe.predict(data['train'][0])
train_bal_acc = balanced_accuracy_score(data['train'][1], train_y_pred)
print(f'Train bal acc: {train_bal_acc:0.4f}')
wandb.log({"Train bal Acc": train_bal_acc})

[Parallel(n_jobs=-1)]: Using backend LokyBackend with 4 concurrent workers.
[Parallel(n_jobs=-1)]: Done   1 out of   1 | elapsed:  2.2min finished


Fit time 131.59850764274597
Train bal acc: 0.6739


In [None]:
valid_y_pred = clf_pipe.predict(data['valid'][0])
valid_bal_acc = balanced_accuracy_score(data['valid'][1], valid_y_pred)
wandb.log({"Val bal Acc":valid_bal_acc})
# print('Sleep staging performance with logistic regression:')
print(f'Valid bal acc: {valid_bal_acc:0.4f}')

print('Results on test set:')
print(confusion_matrix(data['valid'][1], valid_y_pred))
print(classification_report(data['valid'][1], valid_y_pred))
wandb.sklearn.plot_confusion_matrix(data['valid'][1], valid_y_pred)
# wandb run end 
# wandb.finish()

Valid bal acc: 0.5592
Results on test set:
[[4558  921   26   25   25  455]
 [ 171  956  251   25    1  268]
 [  20 1011 2907  576   37  484]
 [   0   15  166  429   86    8]
 [   0    0   15  243  154    2]
 [  77  649  132    0    0  749]]
              precision    recall  f1-score   support

           0       0.94      0.76      0.84      6010
           1       0.27      0.57      0.37      1672
           2       0.83      0.58      0.68      5035
           3       0.33      0.61      0.43       704
           4       0.51      0.37      0.43       414
           5       0.38      0.47      0.42      1607

    accuracy                           0.63     15442
   macro avg       0.54      0.56      0.53     15442
weighted avg       0.74      0.63      0.66     15442



In [None]:
# Save the model, also log to wandb I guess
save_path = "/content/drive/MyDrive/mne_data/RPClassifier/" 
with open(save_path+'LogReg_newton-cg_1500.pkl', 'wb') as f:
    pickle.dump(clf_pipe, f)

In [None]:
wandb.finish()

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Train bal Acc,0.67393
_runtime,190.0
_timestamp,1631232121.0
_step,2.0
Val bal Acc,0.55916


0,1
Train bal Acc,▁
_runtime,▁██
_timestamp,▁██
_step,▁▅█
Val bal Acc,▁


In [None]:
# Predict on unlabelled

newton-cg solver - failed to converge

In [None]:
# wandb for the classifier - same project diff name should be fine
wandb_run = wandb.init(name = "Classifier train", project='RP-fulltrain', entity='sleep_hacking')
max_iter = 100
wandb.config.update({'classifier type':'log_reg','solver':"newton-cg","max_iter":max_iter})


In [None]:
# wandb.sklearn.plot_class_proportions()
# Initialize the logistic regression model
# penalty = 'l2'
# C = 1.0
# class_weight = 'balanced'
# solver = 'saga'
log_reg = LogisticRegression(
    penalty='l2', C=1.0, class_weight='balanced', solver='newton-cg',
    multi_class='multinomial', random_state=random_state)
clf_pipe = make_pipeline(StandardScaler(), log_reg)

In [None]:
# Fit and score the logistic regression
s = t()
clf_pipe.fit(*data['train'])
e = t()
print(f"Fit time {e-s}")

train_y_pred = clf_pipe.predict(data['train'][0])
train_bal_acc = balanced_accuracy_score(data['train'][1], train_y_pred)
print(f'Train bal acc: {train_bal_acc:0.4f}')
wandb.log({"Train bal Acc": train_bal_acc})



Fit time 131.3616509437561
Train bal acc: 0.6739


In [None]:
valid_y_pred = clf_pipe.predict(data['valid'][0])
valid_bal_acc = balanced_accuracy_score(data['valid'][1], valid_y_pred)
wandb.log({"Val bal Acc":valid_bal_acc})
# print('Sleep staging performance with logistic regression:')
print(f'Valid bal acc: {valid_bal_acc:0.4f}')

print('Results on test set:')
print(confusion_matrix(data['valid'][1], valid_y_pred))
print(classification_report(data['valid'][1], valid_y_pred))
wandb.sklearn.plot_confusion_matrix(data['valid'][1], valid_y_pred)
# wandb run end 
# wandb.finish()

Valid bal acc: 0.5592
Results on test set:
[[4558  921   26   25   25  455]
 [ 171  956  251   25    1  268]
 [  20 1011 2907  576   37  484]
 [   0   15  166  429   86    8]
 [   0    0   15  243  154    2]
 [  77  649  132    0    0  749]]
              precision    recall  f1-score   support

           0       0.94      0.76      0.84      6010
           1       0.27      0.57      0.37      1672
           2       0.83      0.58      0.68      5035
           3       0.33      0.61      0.43       704
           4       0.51      0.37      0.43       414
           5       0.38      0.47      0.42      1607

    accuracy                           0.63     15442
   macro avg       0.54      0.56      0.53     15442
weighted avg       0.74      0.63      0.66     15442



In [None]:
# Save the model, also log to wandb I guess
save_path = "/content/drive/MyDrive/mne_data/RPClassifier" 
with open(save_path+'/LogReg_newton-cg.pkl', 'wb') as f:
    pickle.dump(clf_pipe, f)

In [None]:
# wandb.log({"fitted classifier":clf_pipe})

saga solver - failed to converge

In [None]:
# wandb for the classifier - same project diff name should be fine
wandb_run = wandb.init(name = "Classifier train", project='RP-fulltrain', entity='sleep_hacking')

wandb.config.update({'classifier type':'log_reg','solver':"saga"})


In [None]:
log_reg = LogisticRegression(
    penalty='l2', C=1.0, class_weight='balanced', solver='saga',
    multi_class='multinomial', random_state=random_state,verbose=1,n_jobs=-1)
clf_pipe = make_pipeline(StandardScaler(), log_reg)

In [None]:
# Fit and score the logistic regression
s = t()
clf_pipe.fit(*data['train'])
e = t()
print(f"Fit time {e-s}")

train_y_pred = clf_pipe.predict(data['train'][0])
train_bal_acc = balanced_accuracy_score(data['train'][1], train_y_pred)
print(f'Train bal acc: {train_bal_acc:0.4f}')
wandb.log({"Train bal Acc": train_bal_acc})

Fit time 39.08369469642639
Train bal acc: 0.6729




In [None]:
valid_y_pred = clf_pipe.predict(data['valid'][0])
valid_bal_acc = balanced_accuracy_score(data['valid'][1], valid_y_pred)
wandb.log({"Val bal Acc":valid_bal_acc})
# print('Sleep staging performance with logistic regression:')
print(f'Valid bal acc: {valid_bal_acc:0.4f}')

print('Results on test set:')
print(confusion_matrix(data['valid'][1], valid_y_pred))
print(classification_report(data['valid'][1], valid_y_pred))
wandb.sklearn.plot_confusion_matrix(data['valid'][1], valid_y_pred)
# wandb run end 
wandb.finish()

Valid bal acc: 0.5611
Results on test set:
[[4559  922   27   24   25  453]
 [ 171  962  247   27    1  264]
 [  20 1004 2909  571   36  495]
 [   0   14  168  432   83    7]
 [   0    0   15  242  155    2]
 [  77  648  132    0    0  750]]
              precision    recall  f1-score   support

           0       0.94      0.76      0.84      6010
           1       0.27      0.58      0.37      1672
           2       0.83      0.58      0.68      5035
           3       0.33      0.61      0.43       704
           4       0.52      0.37      0.43       414
           5       0.38      0.47      0.42      1607

    accuracy                           0.63     15442
   macro avg       0.55      0.56      0.53     15442
weighted avg       0.74      0.63      0.66     15442



VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Train bal Acc,0.67295
_runtime,779.0
_timestamp,1631230685.0
_step,2.0
Val bal Acc,0.56107


0,1
Train bal Acc,▁
_runtime,▁██
_timestamp,▁██
_step,▁▅█
Val bal Acc,▁


In [None]:
# forgot to save the classifer - Takes very less time, might as well run again
with open(save_path+'/final_model.pkl', 'wb') as f:
    pickle.dump(clf, f)

In [None]:
# run for test - unlabelled data - Have to do the embedding for it also

# split must be of RP dataset class
# loader = DataLoader(split, batch_size=batch_size, num_workers=num_workers)
#     with torch.no_grad():
#         feats = [emb(batch_x.to(device)).cpu().numpy()
#                  for batch_x, _, _ in loader]
#     data[name] = (np.concatenate(feats), split.get_metadata()['target'].values)


# test_y_pred = clf_pipe.predict(data['test'][0])
# Save to txt

# test_bal_acc = balanced_accuracy_score(data['test'][1], test_y_pred)
# print(f'Test bal acc: {test_bal_acc:0.4f}')

# Loading and predicting on the Unlabelled dataset

In [None]:
testB = get_testB()
split = RPD([ds for ds in testB.datasets])

In [None]:
# Loading BD data trainB and valB from dict
# I need to reload only if I want to either train the Contrastive model, CNN or
# if I want to create a dict will all BD data for better save. - Too much RAM 

save_path = "/content/drive/MyDrive/mne_data/NeuroIPS_Hack_Sleep/" 
# Load BrainDecode_data for trainB and valB
bd_dat_file = "BraindecodeData_Dict2"
BrainDecode_data = {}
s = t()
BrainDecode_data = load_obj(save_path,bd_dat_file)
e = t()
print(f"Loading BD data {e-s}")

print(BrainDecode_data)

trainB = BrainDecode_data["train"]
# # valB = BrainDecode_data["valid"]
# del BrainDecode_data

Loading BD data 107.36659359931946
{'train': <braindecode.datasets.base.BaseConcatDataset object at 0x7f25fcd7a290>, 'valid': <braindecode.datasets.base.BaseConcatDataset object at 0x7f2477790550>}


In [None]:
del BrainDecode_data

In [None]:
splitted['train'] = RPD([ds for ds in trainB.datasets])

In [None]:
# load embedder
# Load the CNN
# rp_ssc_checkpoints3 (2k, 1024, inc LR)
# rp_ssc_checkpoints2 (2k, 1024)
# rp_ssc_checkpoints2k_256 (2k, 256)
# rp_ssc_checkpoints (250, 256)

# Need the trained embedder not the classifier with the contrastive model
SSLnet_path = save_path+"rp_ssc_checkpoints3/"
clf = load_obj(SSLnet_path,"final_model")
emb_c = clf.module.emb
# del clf

# load the other train and val data passed through embedder
# SSL_classifier_path = save_path+"RPClassifier/" 
# name = "Data_Dict"
# data = load_obj(SSL_classifier_path,name)

#pass through embedder
# For unlabelled a simple dataset should be equivalent since single win

split.return_pair = False  # Return single windows
batch_size = 1024
num_workers = 0
loader = DataLoader(split, batch_size=batch_size, num_workers=num_workers)
with torch.no_grad():
    feats = [emb_c(batch_x.to(device)).cpu().numpy()
              for batch_x, _, _ in loader]
# data['test'] = (np.concatenate(feats), split.get_metadata()['target'].values)
emb_test_dat = (np.concatenate(feats), split.get_metadata()['target'].values)
# full embedded dict
# save_path = "/content/drive/MyDrive/mne_data/RPClassifier/" 
# with open(SSL_classifier_path+'Data_Dict_withTest.pkl', 'wb') as f:
#     pickle.dump(data, f, pickle.HIGHEST_PROTOCOL)
# data_save(data,'RPClassifier/Data_Dict_withTest')

# emb_test_dat = data['test']
# del data


In [None]:
# max_iter = 1500
# solver = "newton-cg"
# log_reg = LogisticRegression(
#     penalty='l2', C=1.0, class_weight='balanced', solver=solver,
#     multi_class='multinomial', random_state=random_state,verbose=1,
#     n_jobs=-1,max_iter=max_iter)
# clf_pipe = make_pipeline(StandardScaler(), log_reg)

# load clf_pipe
SSL_classifier_path = save_path+"RPClassifier/" 
clf_pipe = load_obj(SSL_classifier_path,"LogReg_newton-cg_1500")

# X_test = emb_c()
# give as input to classifier.predict
y_pred = clf_pipe.predict(emb_test_dat[0])
print("Checking if all classes have been predicted")
print(np.unique(y_pred))
save_fname = "LogReg_newton-cg_1500"
np.savetxt("/content/drive/MyDrive/mne_data/predict/"+save_fname+".txt",y_pred,delimiter=',',fmt="%d")

Checking if all classes have been predicted
[0 1 2 3 4 5]


Other loading and ways of processing. Replaced by more simpler methods since these data need not be loaded

In [None]:
# BrainDecode_data["test"] = get_testB()

# Last time I had to interrupt and stop the saving because the RAM consumption was 
# bordering overflow on the High-RAM itself

# data_save(BrainDecode_data,'BraindecodeData_Dict_withTest') 

# split = RPD([ds for ds in BrainDecode_data["test"].datasets])

In [None]:
# get splitted by running train and val till RPD creation

# Create RP Dataset
# Only RP sampler requires subject split, dataset alone doesn't
# splitted = dict()
# splitted['train'] = RPD([ds for ds in BrainDecode_data["train"].datasets])
# splitted['valid'] = RPD([ds for ds in BrainDecode_data["valid"].datasets])
# splitted['test'] = RPD([ds for ds in BrainDecode_data["test"].datasets])


# Save splitted as a dict
# save_path = "/content/drive/MyDrive/mne_data/" 
# with open(save_path+'RPDData_Dict.pkl', 'wb') as f:
#     pickle.dump(splitted, f, pickle.HIGHEST_PROTOCOL)
# del BrainDecode_data
# data_save(splitted,'RPDData_Dict')
# split = splitted['test']
# del splitted

In [None]:
# clf train on 2nd group data with RP sampler


Training the SSL on the group 1 data has pretty poor performance at 47% on the unlabelled phase 1 leaderboard data - expected I guess considering that validation on the target dataset was pretty poor

One more step for betterment of this - push more into functions so that local var.s are nicely destroyed automatically

Total time for all data loading = 4+1.5+7+1.5 = ~15min

# Load the model and do TL on the valB or target data from Phase 1

In [None]:
# load embedder
# Load the CNN
# rp_ssc_checkpoints3 (2k, 1024, inc LR)
# rp_ssc_checkpoints2 (2k, 1024)
# rp_ssc_checkpoints2k_256 (2k, 256)
# rp_ssc_checkpoints (250, 256)

# Need the trained embedder not the classifier with the contrastive model
save_path = "/content/drive/MyDrive/mne_data/NeuroIPS_Hack_Sleep"
SSLnet_path = save_path+"/rp_ssc_checkpoints3/"

clf = load_obj(SSLnet_path,"final_model")
# emb_c = clf.module.emb
# del clf

# Need to modify the classifier for further training on the target valB
# data without any validation dataset

print(clf.get_params())
# Activate warm_start or can also use partial_fit, but better warm_start
# set train_split = None
# changing the iterator_train_sampler

{'cropped': False, 'module': ContrastiveNet(
  (emb): SleepStagerChambon2018(
    (spatial_conv): Conv2d(1, 2, kernel_size=(2, 1), stride=(1, 1))
    (feature_extractor): Sequential(
      (0): Conv2d(1, 16, kernel_size=(1, 50), stride=(1, 1), padding=(0, 25))
      (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): MaxPool2d(kernel_size=(1, 13), stride=(1, 13), padding=0, dilation=1, ceil_mode=False)
      (4): Conv2d(16, 16, kernel_size=(1, 50), stride=(1, 1), padding=(0, 25))
      (5): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (6): ReLU()
      (7): MaxPool2d(kernel_size=(1, 13), stride=(1, 13), padding=0, dilation=1, ceil_mode=False)
    )
    (fc): Sequential(
      (0): Dropout(p=0, inplace=False)
      (1): Linear(in_features=544, out_features=100, bias=True)
    )
  )
  (clf): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=100, out_features

In [None]:
trained_model = clf.module
print(trained_model)
print(type(trained_model))

ContrastiveNet(
  (emb): SleepStagerChambon2018(
    (spatial_conv): Conv2d(1, 2, kernel_size=(2, 1), stride=(1, 1))
    (feature_extractor): Sequential(
      (0): Conv2d(1, 16, kernel_size=(1, 50), stride=(1, 1), padding=(0, 25))
      (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): MaxPool2d(kernel_size=(1, 13), stride=(1, 13), padding=0, dilation=1, ceil_mode=False)
      (4): Conv2d(16, 16, kernel_size=(1, 50), stride=(1, 1), padding=(0, 25))
      (5): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (6): ReLU()
      (7): MaxPool2d(kernel_size=(1, 13), stride=(1, 13), padding=0, dilation=1, ceil_mode=False)
    )
    (fc): Sequential(
      (0): Dropout(p=0, inplace=False)
      (1): Linear(in_features=544, out_features=100, bias=True)
    )
  )
  (clf): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=100, out_features=1, bias=True)
  )
)
<class '

In [None]:
# Currently prediciton on the valB is pretty much chance level
# That's not expected. But let's go on

In [None]:
# start wandb
# wandb
# should I set reinit = true?
wandb_run = wandb.init(name = "RP-SSL SSC2018 TL1 train", project='RP-fulltrain', entity='sleep_hacking')


[34m[1mwandb[0m: Currently logged in as: [33mdiv12345[0m (use `wandb login --relogin` to force relogin)


In [None]:
# I can define a new BD classifier simply and just give this module as input and it will work fine
# or I can modify the parameters in this model itself

# Train the network
lr = 5e-2 # 5e-4
batch_size = 1024 # increase this next time
n_epochs = 150
num_workers = 0 # if n_jobs <= 1 else n_jobs

# Callback WandbLogger logs the "best trained model", etc. after every epoch
# Not sure of definition of best trained

# TODO - 
# Add early stopping with 10 on valid_bacc I guess - 
# This fails if big oscillations happens, but well the lr is low

# should make this function of parameters for easiness
save_path = "/content/drive/MyDrive/mne_data/NeuroIPS_Hack_Sleep/rp_ssc_TL1_checkpoints1" 

# Log hyperparameters
wandb_run.config.update({"Embedder_size":emb_size, "learning rate": lr, 
                         "batch size": batch_size, "n_conv_chs_SSC": n_conv_chs,
                         "tau_pos":tau_pos, "tau_neg":tau_neg, "n_examples_RPsampler":n_examples,
                         "n_examples_val":n_examples_valid,
                         "save_path":save_path})

# org paper used a weight decay of 1e-3 on all trainable params
cp = Checkpoint(monitor = 'val_acc_best',
                f_params = "params_{last_epoch[epoch]}.pt", 
                f_optimizer = "optimizer_{last_epoch[epoch]}.pt", 
                f_criterion = "criterion_{last_epoch[epoch]}.pt",
                # f_pickle = "model_{last_epoch[epoch]}.pkl",
                dirname = save_path)
train_end_cp = TrainEndCheckpoint(
                f_params = "params_{last_epoch[epoch]}.pt", 
                f_optimizer = "optimizer_{last_epoch[epoch]}.pt", 
                f_criterion = "criterion_{last_epoch[epoch]}.pt",
                dirname=save_path)
# giving this name 'val_acc' since I want this logged in val acc in Wandb
train_acc = EpochScoring(
    scoring='accuracy', on_train=True, name='val_acc', 
    lower_is_better=False)
# valid_acc = EpochScoring(
#     scoring='accuracy', on_train=False, name='valid_acc',
#     lower_is_better=False)
callbacks = [('val_acc', train_acc),
            #  ('valid_acc', valid_acc),
             ("train_end_cp",train_end_cp),
             ("checkpoint",cp),
             ("wandb",WandbLogger(wandb_run, save_model=False))
             ]

clf2 = EEGClassifier(
    trained_model,
    criterion=torch.nn.BCEWithLogitsLoss,
    optimizer=torch.optim.Adam,
    max_epochs=n_epochs,
    iterator_train__shuffle=False,
    iterator_train__sampler=valid_sampler,
    # iterator_valid__sampler=valid_sampler,
    iterator_train__num_workers=num_workers,
    # iterator_valid__num_workers=num_workers,
    train_split = None,
    # train_split=predefined_split(splitted['valid']),
    optimizer__lr=lr,
    batch_size=batch_size,
    callbacks=callbacks,
    device=device,
    # warm_start = True
    # **arg_dict
)

print(clf2)

<class 'braindecode.classifier.EEGClassifier'>[uninitialized](
  module=ContrastiveNet(
    (emb): SleepStagerChambon2018(
      (spatial_conv): Conv2d(1, 2, kernel_size=(2, 1), stride=(1, 1))
      (feature_extractor): Sequential(
        (0): Conv2d(1, 16, kernel_size=(1, 50), stride=(1, 1), padding=(0, 25))
        (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
        (3): MaxPool2d(kernel_size=(1, 13), stride=(1, 13), padding=0, dilation=1, ceil_mode=False)
        (4): Conv2d(16, 16, kernel_size=(1, 50), stride=(1, 1), padding=(0, 25))
        (5): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (6): ReLU()
        (7): MaxPool2d(kernel_size=(1, 13), stride=(1, 13), padding=0, dilation=1, ceil_mode=False)
      )
      (fc): Sequential(
        (0): Dropout(p=0, inplace=False)
        (1): Linear(in_features=544, out_features=100, bias=True)
      )
    )
    (clf): Sequential(
     

In [None]:
# del clf
mne.set_log_level(True)
clf2.fit(splitted['valid'], y=None)

  epoch    train_loss    val_acc    cp     dur
-------  ------------  ---------  ----  ------
      1        [36m1.9929[0m     [32m0.5124[0m     +  2.8557
      2        [36m1.0810[0m     0.5112        2.8396
      3        [36m0.8254[0m     [32m0.5144[0m     +  2.8365
      4        [36m0.7781[0m     [32m0.5198[0m     +  2.8430
      5        [36m0.7431[0m     [32m0.5309[0m     +  2.8216
      6        [36m0.7257[0m     [32m0.5397[0m     +  2.8428
      7        [36m0.7090[0m     [32m0.5504[0m     +  2.8109
      8        [36m0.7041[0m     [32m0.5507[0m     +  2.9017
      9        [36m0.6913[0m     [32m0.5653[0m     +  2.8378
     10        [36m0.6825[0m     [32m0.5774[0m     +  2.8198
     11        [36m0.6710[0m     [32m0.5889[0m     +  2.8064
     12        [36m0.6653[0m     [32m0.5996[0m     +  2.8146
     13        [36m0.6630[0m     0.5961        2.7997
     14        [36m0.6562[0m     [32m0.6008[0m     +  2.8133
     15       

[34m[1mwandb[0m: Network error resolved after 0:00:38.734775, resuming normal operation.


    143        0.4225     0.8038        2.8019
    144        0.4237     0.8043        2.8088
    145        0.4202     0.8016        2.8036
    146        [36m0.4125[0m     [32m0.8111[0m     +  2.8197
    147        0.4150     0.8078        2.8452
    148        [36m0.4105[0m     0.8093        2.8154
    149        [36m0.4091[0m     0.8096        2.8376
    150        0.4099     0.8106        2.8442


<class 'braindecode.classifier.EEGClassifier'>[initialized](
  module_=ContrastiveNet(
    (emb): SleepStagerChambon2018(
      (spatial_conv): Conv2d(1, 2, kernel_size=(2, 1), stride=(1, 1))
      (feature_extractor): Sequential(
        (0): Conv2d(1, 16, kernel_size=(1, 50), stride=(1, 1), padding=(0, 25))
        (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
        (3): MaxPool2d(kernel_size=(1, 13), stride=(1, 13), padding=0, dilation=1, ceil_mode=False)
        (4): Conv2d(16, 16, kernel_size=(1, 50), stride=(1, 1), padding=(0, 25))
        (5): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (6): ReLU()
        (7): MaxPool2d(kernel_size=(1, 13), stride=(1, 13), padding=0, dilation=1, ceil_mode=False)
      )
      (fc): Sequential(
        (0): Dropout(p=0, inplace=False)
        (1): Linear(in_features=544, out_features=100, bias=True)
      )
    )
    (clf): Sequential(
      

In [None]:
# Confusion matrix of ValB
# Use the val itself to see the matrix
clf2.iterator_valid__sampler = valid_sampler
y_pred = clf2.forward(splitted['valid'], training=False) > 0

y_true = [y for _, _, y in valid_sampler]

print(confusion_matrix(y_true, y_pred))
print(classification_report(y_true, y_pred))

[[5136  785]
 [1572 4507]]
              precision    recall  f1-score   support

         0.0       0.77      0.87      0.81      5921
         1.0       0.85      0.74      0.79      6079

    accuracy                           0.80     12000
   macro avg       0.81      0.80      0.80     12000
weighted avg       0.81      0.80      0.80     12000



In [None]:
wandb.finish()

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
dur,2.84418
train_loss,0.40989
val_acc,0.81058


0,1
dur,▇▆▄▄▅▆▅▄▆█▂▄▁▄▃▁▂▄▂▃▄▂▆▂▄▃▅▂▅▄▆▆▆▅▃▅▄▇▃▆
train_loss,█▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_acc,▁▁▂▃▃▄▄▄▄▅▅▅▅▅▆▅▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇████


Before overfitting predict on the leaderboard and see the performance

In [None]:
# Get the features from the val and test now and see the classifier performance


In [None]:
print(splitted['valid'].__dict__)

In [None]:
print(len(splitted['valid'].datasets))

15442


In [None]:
# Need the trained embedder not the classifier with the contrastive model
emb_c = clf2.module.emb

In [None]:
# Extract features with the trained embedder
data = dict()
# splitted has the RPDataset 
# For unlabelled a simple dataset should be equivalent since single win
for name, split in splitted.items():
    # run the till the RP sampler above
    split.return_pair = False  # Return single windows
    loader = DataLoader(split, batch_size=batch_size, num_workers=num_workers)
    with torch.no_grad():
        feats = [emb_c(batch_x.to(device)).cpu().numpy()
                 for batch_x, _, _ in loader]
    data[name] = (np.concatenate(feats), split.get_metadata()['target'].values)


In [None]:
print(data.keys())

dict_keys(['valid', 'test', 'train'])


In [None]:
# wandb for the classifier - same project diff name should be fine
wandb_run = wandb.init(name = "Classifier train", project='RP-fulltrain', entity='sleep_hacking')
max_iter = 1500
solver = "newton-cg"
save_path = "/content/drive/MyDrive/mne_data/NeuroIPS_Hack_Sleep/RPClassifier/" 
clf_path = save_path+'TL1_1_LogReg_newton-cg_1500.pkl'
wandb.config.update({'classifier type':'log_reg','solver':solver,"max_iter":max_iter,"TL on Phase":1,"clf_path":clf_path})

In [None]:
# max_iter = 1500
# solver = "newton-cg"
log_reg = LogisticRegression(
    penalty='l2', C=1.0, class_weight='balanced', solver=solver,
    multi_class='multinomial', random_state=random_state,verbose=1,
    n_jobs=-1,max_iter=max_iter)
clf_pipe = make_pipeline(StandardScaler(), log_reg)

In [None]:
s = t()
clf_pipe.fit(*data['valid'])
e = t()
print(f"Fit time {e-s}")

train_y_pred = clf_pipe.predict(data['valid'][0])
train_bal_acc = balanced_accuracy_score(data['valid'][1], train_y_pred)
print(f'Train bal acc: {train_bal_acc:0.4f}')
wandb.log({"Train bal Acc": train_bal_acc})

print('Results on valid set(trained on valid set here):')
print(confusion_matrix(data['valid'][1], train_y_pred))
print(classification_report(data['valid'][1], train_y_pred))
wandb.sklearn.plot_confusion_matrix(data['valid'][1], train_y_pred)

[Parallel(n_jobs=-1)]: Using backend LokyBackend with 4 concurrent workers.
[Parallel(n_jobs=-1)]: Done   1 out of   1 | elapsed:   14.0s finished


Fit time 14.112605571746826
Train bal acc: 0.4241
Results on valid set(trained on valid set here):
[[3711  750  169  255  384  741]
 [ 400  424  171  122  131  424]
 [ 573  677 1281  901  631  972]
 [  42   37  104  258  225   38]
 [  11    6   36   85  272    4]
 [ 299  288  176  113   95  636]]
              precision    recall  f1-score   support

           0       0.74      0.62      0.67      6010
           1       0.19      0.25      0.22      1672
           2       0.66      0.25      0.37      5035
           3       0.15      0.37      0.21       704
           4       0.16      0.66      0.25       414
           5       0.23      0.40      0.29      1607

    accuracy                           0.43     15442
   macro avg       0.35      0.42      0.34     15442
weighted avg       0.56      0.43      0.45     15442



In [None]:
# Save the model, also log to wandb I guess
with open(clf_path, 'wb') as f:
    pickle.dump(clf_pipe, f)

In [None]:
wandb.finish()

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Train bal Acc,0.42412


0,1
Train bal Acc,▁


If the representation learnt on the second stage had been good, I could've tried things like training the clf on the first group and phase 1 val data combined and sought to improve overall clf performance, but IDK what to do in this case

Maybe try loading the trainB anyways, try using this embedder on that, then try teaching this clf on that embedded rep? OR combine that embedded rep with the one from

In [None]:
# Train clf on only Train first
# wandb for the classifier - same project diff name should be fine
wandb_run = wandb.init(name = "Classifier train", project='RP-fulltrain', entity='sleep_hacking')
max_iter = 1500
solver = "newton-cg"
save_path = "/content/drive/MyDrive/mne_data/NeuroIPS_Hack_Sleep/RPClassifier/" 
clf_path = save_path+'TL1_2_LogReg_newton-cg_1500.pkl'
wandb.config.update({'classifier type':'log_reg','solver':solver,"max_iter":max_iter,"TL on Phase":1,"clf_path":clf_path})

In [None]:
# max_iter = 1500
# solver = "newton-cg"
log_reg = LogisticRegression(
    penalty='l2', C=1.0, class_weight='balanced', solver=solver,
    multi_class='multinomial', random_state=random_state,verbose=1,
    n_jobs=-1,max_iter=max_iter)
clf_pipe = make_pipeline(StandardScaler(), log_reg)

In [None]:
s = t()
clf_pipe.fit(*data['train'])
e = t()
print(f"Fit time {e-s}")

train_y_pred = clf_pipe.predict(data['train'][0])
train_bal_acc = balanced_accuracy_score(data['train'][1], train_y_pred)
print(f'Train bal acc: {train_bal_acc:0.4f}')
wandb.log({"Train bal Acc": train_bal_acc})

print('Results on valid set(trained on valid set here):')
print(confusion_matrix(data['train'][1], train_y_pred))
print(classification_report(data['train'][1], train_y_pred))
wandb.sklearn.plot_confusion_matrix(data['train'][1], train_y_pred)
###


[Parallel(n_jobs=-1)]: Using backend LokyBackend with 4 concurrent workers.
[Parallel(n_jobs=-1)]: Done   1 out of   1 | elapsed:  2.3min finished


Fit time 140.67068433761597
Train bal acc: 0.4180
Results on valid set(trained on valid set here):
[[15366  4121   668   753   714  2421]
 [ 1421  2959   879   460   447  1775]
 [ 2641  6865  9602  5628  4392  6855]
 [  109   261   797  1643  1956   481]
 [   16    97   220   688  1872   164]
 [ 2222  3844  1895  1037   941  4335]]
              precision    recall  f1-score   support

           0       0.71      0.64      0.67     24043
           1       0.16      0.37      0.23      7941
           2       0.68      0.27      0.38     35983
           3       0.16      0.31      0.21      5247
           4       0.18      0.61      0.28      3057
           5       0.27      0.30      0.29     14274

    accuracy                           0.40     90545
   macro avg       0.36      0.42      0.34     90545
weighted avg       0.53      0.40      0.42     90545



In [None]:

valid_y_pred = clf_pipe.predict(data['valid'][0])
valid_bal_acc = balanced_accuracy_score(data['valid'][1], valid_y_pred)
print(f'Train bal acc: {valid_bal_acc:0.4f}')
wandb.log({"Train bal Acc": valid_bal_acc})

print('Results on valid set(trained on valid set here):')
print(confusion_matrix(data['valid'][1], valid_y_pred))
print(classification_report(data['valid'][1], valid_y_pred))
wandb.sklearn.plot_confusion_matrix(data['valid'][1], valid_y_pred)

Train bal acc: 0.3426
Results on valid set(trained on valid set here):
[[3121 1376  182  127  351  853]
 [ 388  510  190  101   92  391]
 [ 464  933 1411  623  535 1069]
 [  25   44  129  189  187  130]
 [   4    7   50  118  171   64]
 [ 356  479  210   65   64  433]]
              precision    recall  f1-score   support

           0       0.72      0.52      0.60      6010
           1       0.15      0.31      0.20      1672
           2       0.65      0.28      0.39      5035
           3       0.15      0.27      0.20       704
           4       0.12      0.41      0.19       414
           5       0.15      0.27      0.19      1607

    accuracy                           0.38     15442
   macro avg       0.32      0.34      0.30     15442
weighted avg       0.53      0.38      0.42     15442



In [None]:
wandb.finish()

[34m[1mwandb[0m: Network error resolved after 0:03:25.123336, resuming normal operation.


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
Train bal Acc,0.34259


0,1
Train bal Acc,█▁


In [None]:
# Train clf on concatenated train and val embed