# Pretraining EEG2VEC on sleep data

In [1]:
from eeg2vec.train.train import train
from eeg2vec.data_loader import get_dataloader
from eeg2vec.models.eeg2vec import EEG2Vec
from eeg2vec.contrastive_loss import ContrastiveLoss

import numpy as np
import torch
from sklearn.metrics import f1_score, accuracy_score
from sklearn.model_selection import train_test_split

In [2]:
from pathlib import Path
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.signal import butter, lfilter
import pandas as pd
import pickle
import mne

In [3]:
# Load one recording from the Sleep-EDF database with MNE.
# The path is relative, so the notebook has to be run from the directory that
# contains the `eeg2vec/` folder.
data_path = Path('eeg2vec/data/sleep-edf-database-1.0.0/sc4002e0.edf')
raw = mne.io.read_raw_edf(data_path)

Extracting EDF parameters from c:\Users\Emile\Documents\Polytechnique\4A\ParisSaclay\Cours\ML\Sleep_EEG_Kaggle\eeg2vec\data\sleep-edf-database-1.0.0\sc4002e0.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


In [4]:
# Display the recording metadata (measurement date, channel count, sampling rate, filter settings).
raw.info

0,1
Measurement date,"April 25, 1989 14:50:00 GMT"
Experimenter,Unknown
Participant,X

0,1
Digitized points,Not available
Good channels,7 EEG
Bad channels,
EOG channels,Not available
ECG channels,Not available

0,1
Sampling frequency,100.00 Hz
Highpass,0.00 Hz
Lowpass,50.00 Hz


In [5]:
# List the names of the channels stored in the recording.
raw.ch_names

['EEG Fpz-Cz',
 'EEG Pz-Oz',
 'EOG horizontal',
 'Resp oro-nasal',
 'EMG Submental',
 'Temp body',
 'Event marker']

In [6]:
# Extract the two EEG derivations as 1-D sample arrays.
# `get_data(picks=...)` returns only the requested channel, so there is no need for a
# throwaway `raw.copy()` or for the legacy `pick_channels()` call that MNE warns about.
# The result has shape (1, n_samples); `[0]` drops the singleton channel axis.
ts_fpz = raw.get_data(picks=['EEG Fpz-Cz'])[0]
ts_pz = raw.get_data(picks=['EEG Pz-Oz'])[0]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


In [7]:
# Only two EEG derivations are available but 5 channels are needed, so stack
# 3 copies of Fpz-Cz and 2 copies of Pz-Oz as separate rows.
ts_fpz = np.vstack([ts_fpz] * 3)
ts_pz = np.vstack([ts_pz] * 2)

In [8]:
# Stack the duplicated derivations row-wise into a single (channels, samples) array,
# then keep only the second half of the recording.
# To work on the first half instead, slice with ts_all[:, :midpoint].
ts_all = np.vstack((ts_fpz, ts_pz))
midpoint = ts_all.shape[1] // 2
ts = ts_all[:, midpoint:]

In [9]:
ts.shape # (channels, samples); the recording's sampling frequency is 100 Hz

(5, 4245000)

In [10]:
def reshape_array_into_windows(x, sample_rate, window_duration_in_seconds):
    """
    Split a (C, N) signal into consecutive, non-overlapping windows of fixed duration.

    Trailing samples that do not fill a complete window are discarded: the data is
    truncated, never padded.

    Parameters:
    x (numpy array): The input data array of shape (C, N), with time on the last axis.
    sample_rate (int or float): The number of samples per second.
    window_duration_in_seconds (float): The duration of each window in seconds.

    Returns:
    reshaped_x (numpy array): The reshaped array with shape (C, T, window), where
        'window' is the number of samples in one window and T = N // window.

    Raises:
    ValueError: If the requested window would contain less than one sample.
    """
    # Number of samples in one window. The product is computed in floating point,
    # so e.g. 0.29 * 100 evaluates to 28.999999999999996; snap to the nearest integer
    # when the product is an integer up to round-off, and only floor values that are
    # genuinely fractional.
    exact_size = window_duration_in_seconds * sample_rate
    nearest_size = round(exact_size)
    if abs(exact_size - nearest_size) < 1e-6:
        window_size = int(nearest_size)
    else:
        window_size = int(exact_size)

    # Fail early with a clear message instead of a ZeroDivisionError or an obscure
    # reshape error further down.
    if window_size < 1:
        raise ValueError(
            "window_duration_in_seconds * sample_rate must be at least one sample, "
            f"got {exact_size!r}"
        )

    # Drop the trailing samples that do not fill a whole window.
    n_windows = x.shape[-1] // window_size
    x = x[..., :n_windows * window_size]

    # Reshape x into (C, T, window)
    reshaped_x = x.reshape(x.shape[0], -1, window_size)

    return reshaped_x

In [11]:
# Cut the continuous signal into 2-second windows (sampling rate 100 Hz) and move the
# window axis to the front so that the data has the layout
# [num_samples, num_channels (5), sequence_length].
reshaped_data = np.swapaxes(
    reshape_array_into_windows(ts, sample_rate=100, window_duration_in_seconds=2),
    0,
    1,
)

ts_reshaped = reshaped_data

In [12]:
# Sanity check: expect (num_windows, num_channels, samples_per_window).
ts_reshaped.shape

(21225, 5, 200)

## Model

In [13]:
# Rebuild the architecture used for the saved checkpoint (the "8_2_5_2" in the file name
# mirrors the constructor arguments) and restore its weights.
model_path = "eeg2vec/data/saved_models/eeg2vec_8_2_5_2_11dec_10000points.pth"
eeg2vec_model = EEG2Vec(8,2,5,2)
# map_location="cpu" lets a checkpoint that was saved from a GPU session load on a machine
# without CUDA (the model is moved to the training device later on).
# weights_only=True restricts unpickling to tensors and plain containers, which is all a
# state_dict needs, and avoids the torch.load FutureWarning about the unsafe default.
eeg2vec_model.load_state_dict(torch.load(model_path, map_location="cpu", weights_only=True))

  eeg2vec_model.load_state_dict(torch.load(model_path))


<All keys matched successfully>

## Training

In [14]:
# Hold out 20% of the windows. The same array is passed as inputs and as targets,
# so y_train / y_test contain exactly the windows of X_train / X_test.
X_train, X_test, y_train, y_test = train_test_split(
    ts_reshaped,
    ts_reshaped,
    test_size=0.2,
    random_state=42,
)

In [15]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

device(type='cuda')

In [16]:
# Build the training loader from the training windows, 32 per batch.
# NOTE(review): y_train holds the same data as X_train (both come from ts_reshaped in the
# split) — presumably get_dataloader/train expect the inputs repeated as targets; confirm.
data_loader = get_dataloader(X_train, y_train, batch_size=32)

In [17]:
# Continue training the restored model on the selected device for 100 epochs.
num_epochs = 100
eeg2vec_model = eeg2vec_model.to(device)
train(eeg2vec_model, data_loader, num_epochs, device)

Epoch 1/100 completed.
Epoch 2/100 completed.
Epoch 3/100 completed.
Epoch 4/100 completed.
Epoch 5/100 completed.
Epoch 6/100 completed.
Epoch 7/100 completed.
Epoch 8/100 completed.
Epoch 9/100 completed.
Epoch 10/100 completed.
Epoch 11/100 completed.
Epoch 12/100 completed.
Epoch 13/100 completed.
Epoch 14/100 completed.
Epoch 15/100 completed.
Epoch 16/100 completed.
Epoch 17/100 completed.
Epoch 18/100 completed.
Epoch 19/100 completed.
Epoch 20/100 completed.
Epoch 21/100 completed.
Epoch 22/100 completed.
Epoch 23/100 completed.
Epoch 24/100 completed.
Epoch 25/100 completed.
Epoch 26/100 completed.
Epoch 27/100 completed.
Epoch 28/100 completed.
Epoch 29/100 completed.
Epoch 30/100 completed.
Epoch 31/100 completed.
Epoch 32/100 completed.
Epoch 33/100 completed.
Epoch 34/100 completed.
Epoch 35/100 completed.
Epoch 36/100 completed.
Epoch 37/100 completed.
Epoch 38/100 completed.
Epoch 39/100 completed.
Epoch 40/100 completed.
Epoch 41/100 completed.
Epoch 42/100 completed.
E

In [20]:
# Switch the model to evaluation mode (this disables the Dropout layers listed in the repr).
eeg2vec_model.eval()

EEG2Vec(
  (cnn_encoder): CNNEncoder(
    (conv_layers): Sequential(
      (0): Conv1d(5, 8, kernel_size=(1,), stride=(1,))
      (1): ReLU()
      (2): MaxPool1d(kernel_size=5, stride=5, padding=0, dilation=1, ceil_mode=False)
      (3): Conv1d(8, 8, kernel_size=(2,), stride=(1,))
      (4): ReLU()
    )
  )
  (transformer_encoder): TransformerEncoder(
    (transformer_encoder): TransformerEncoder(
      (layers): ModuleList(
        (0-4): 5 x TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=8, out_features=8, bias=True)
          )
          (linear1): Linear(in_features=8, out_features=2, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=2, out_features=8, bias=True)
          (norm1): LayerNorm((8,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((8,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, in

In [21]:
# Persist the further-trained weights (state_dict only) under a new checkpoint name.
checkpoint_path = 'eeg2vec/data/saved_models/eeg2vec_8_2_5_2_11dec_100000+pretrainedtotal.pth'
torch.save(eeg2vec_model.state_dict(), checkpoint_path)