In [1]:
from eeg2vec.train.train import train
from eeg2vec.data_loader import get_dataloader
from eeg2vec.models.eeg2vec import EEG2Vec
from eeg2vec.contrastive_loss import ContrastiveLoss

import numpy as np
import torch
from sklearn.metrics import f1_score, accuracy_score
from sklearn.model_selection import train_test_split
import xgboost as xgb
from sklearn.multioutput import MultiOutputClassifier

In [2]:
## First let's load the training data
from pathlib import Path
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.signal import butter, lfilter
import pandas as pd
import pickle

# The training set is stored under ROOT_PATH as four file pairs, indexed 0-3:
# data_<i>.npy and target_<i>.npy. Each list entry is the (data, target) pair of one index.
ROOT_PATH = Path("train/")
training_data = [
    (
        np.load(ROOT_PATH / f"data_{recording_idx}.npy"),
        np.load(ROOT_PATH / f"target_{recording_idx}.npy"),
    )
    for recording_idx in range(4)
]


In [3]:
def butter_bandpass(lowcut, highcut, fs, order=5, output='ba'):
    """
    Design a Butterworth band-pass filter with scipy.signal.butter.

    Parameters:
    lowcut (float): Lower cutoff frequency, expressed in the same units as fs.
    highcut (float): Upper cutoff frequency, expressed in the same units as fs.
    fs (float): Sampling frequency of the signal the filter will be applied to.
    order (int): Filter order handed to scipy.signal.butter.
    output (str): Coefficient format handed to scipy.signal.butter. The default
        'ba' returns the (b, a) pair; 'sos' returns an array of second-order
        sections.

    Returns:
    The filter coefficients in the requested format.
    """
    return butter(order, [lowcut, highcut], fs=fs, btype='band', output=output)


def butter_bandpass_filter(data, lowcut, highcut, fs, order=5):
    """
    Band-pass filter data along its last axis with a Butterworth filter.

    The filter is applied as cascaded second-order sections rather than as a
    single (b, a) transfer function: for higher orders combined with a cutoff
    that is tiny relative to fs, the (b, a) form loses precision and the
    output can be distorted or blow up.

    Parameters:
    data (numpy array): Signal(s) to filter; filtering runs along the last axis.
    lowcut (float): Lower cutoff frequency, in the same units as fs.
    highcut (float): Upper cutoff frequency, in the same units as fs.
    fs (float): Sampling frequency of data.
    order (int): Order of the Butterworth design.

    Returns:
    y (numpy array): The filtered signal, same shape as data.
    """
    # Imported locally so this cell works without touching the import cell.
    from scipy.signal import sosfilt

    sos = butter_bandpass(lowcut, highcut, fs, order=order, output='sos')
    return sosfilt(sos, data)

In [4]:
# Apply the 0.1-18 band-pass filter to every recording; targets pass through unchanged.
# The cutoffs are only meaningful if `fs` matches the real sampling rate of the data.
# The windowing cell of this notebook treats the recordings as 250 samples/second
# (2-second windows of 500 samples), so the filter is designed for that same rate.
# NOTE(review): this call previously used fs=100 - confirm the acquisition rate.
FILTER_SAMPLE_RATE = 250
training_data_filtered = [
    (butter_bandpass_filter(data, 0.1, 18, FILTER_SAMPLE_RATE), target)
    for data, target in training_data
]

In [5]:
# First we need to split each recording into fixed-length windows of points, so that each window can be paired with a label

def reshape_array_into_windows(x, sample_rate, window_duration_in_seconds):
    """
    Split the last axis of x into consecutive, non-overlapping windows.

    Trailing samples that do not fill a complete window are dropped; no
    padding is performed.

    Parameters:
    x (numpy array): The input data array. The last axis is the one that is
        windowed and the first axis is kept as-is (intended for 2-D input of
        shape (C, total_samples)).
    sample_rate (int): The number of samples per second.
    window_duration_in_seconds (float): The duration of each window in seconds.

    Returns:
    reshaped_x (numpy array): The reshaped array with shape (C, T, window),
        where C is x.shape[0], window is the number of samples per window and
        T is the number of complete windows.

    Raises:
    ValueError: If the requested window would hold less than one sample.
    """
    # Products such as 0.29 * 100 evaluate to 28.999999999999996 in floating
    # point; snapping to 6 decimals before truncating yields the intended 29
    # rather than 28, while genuinely fractional sizes are still floored.
    window_size = int(round(window_duration_in_seconds * sample_rate, 6))
    if window_size < 1:
        raise ValueError(
            "window_duration_in_seconds * sample_rate must be at least one sample, "
            f"got {window_duration_in_seconds} * {sample_rate}"
        )

    # Drop the incomplete trailing window so the last axis divides evenly.
    total_samples = x.shape[-1]
    leftover = total_samples % window_size
    if leftover:
        x = x[..., :total_samples - leftover]

    # Reshape x into (C, T, window)
    return x.reshape(x.shape[0], -1, window_size)

In [6]:
# Window every recording and stack the results into two model-ready arrays:
#   all_data    -> [num_samples, num_channels (5), sequence_length]
#   all_targets -> [num_samples, 5]
# NOTE(review): this loop reads `training_data`, not `training_data_filtered`,
# so the band-pass filtered copy is not what gets windowed - confirm this is intended.
SAMPLE_RATE = 250          # samples per second assumed for the recordings
WINDOW_SECONDS = 2         # duration of one window
LABELS_PER_WINDOW = 5      # width of one target row

window_batches = []
label_batches = []
for recording, recording_targets in training_data:
    windows = reshape_array_into_windows(recording, SAMPLE_RATE, WINDOW_SECONDS)
    # (channels, windows, samples) -> (windows, channels, samples)
    window_batches.append(windows.transpose(1, 0, 2))
    label_batches.append(recording_targets.reshape(-1, LABELS_PER_WINDOW))

all_data = np.concatenate(window_batches, axis=0)
all_targets = np.concatenate(label_batches, axis=0)

In [7]:
# Sanity check on the stacked windows: (num_samples, num_channels, sequence_length).
print(np.shape(all_data))

(52351, 5, 500)


In [8]:
# Use every available window. For a quicker experiment, slice both arrays
# identically instead, e.g. all_data[:10000] and all_targets[:10000].
data = all_data
labels = all_targets

In [9]:
# No hold-out split is made in this cell: everything selected above becomes the training set.
X_train = data
Y_train = labels

In [None]:
# Guard against NaNs before training: one flag for the inputs, one for the targets.
tuple(torch.isnan(torch.tensor(array)).any() for array in (X_train, Y_train))

(tensor(False), tensor(False))

## Train the EEG2Vec model

In [13]:
# Hyperparameters for EEG2Vec; they are handed to the constructor positionally,
# in exactly the order listed below.
# Optional warm start (disabled): load a previously saved state dict after construction.
# model_path = "eeg2vec/data/saved_models/eeg2vec_8_2_5_2_11dec_10000+pretrainedtotal.pth"
# eeg2vec_model.load_state_dict(torch.load(model_path))
d_model, n_heads, n_layers, dim_feedforward = 8, 2, 2, 1
conv_size, out_size = 8, 8
kernel_size, kernel_size_2 = 2, 2
eeg2vec_model = EEG2Vec(
    d_model,
    n_heads,
    n_layers,
    dim_feedforward,
    conv_size,
    out_size,
    kernel_size,
    kernel_size_2,
)

In [14]:
# Build the loader over the training arrays (presumably yields mini-batches of BATCH_SIZE - verify in get_dataloader).
BATCH_SIZE = 100
data_loader = get_dataloader(X_train, Y_train, batch_size=BATCH_SIZE)

In [15]:
# Pick the compute device: CUDA when PyTorch reports one as available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

device(type='cuda')

In [None]:
# Move the model to the selected device, then run training for NUM_EPOCHS epochs.
NUM_EPOCHS = 100
eeg2vec_model = eeg2vec_model.to(device)
train(eeg2vec_model, data_loader, NUM_EPOCHS, device)

Epoch 1/100 completed.
Epoch 2/100 completed.
Epoch 3/100 completed.
Epoch 4/100 completed.
Epoch 5/100 completed.
Epoch 6/100 completed.
Epoch 7/100 completed.
Epoch 8/100 completed.
Epoch 9/100 completed.
Epoch 10/100 completed.
Epoch 11/100 completed.
Epoch 12/100 completed.
Epoch 13/100 completed.
Epoch 14/100 completed.
Epoch 15/100 completed.


In [None]:
# Save the model weights (state dict only).
# The target directory is created first so that a missing folder cannot make the
# save fail after a long training run.
checkpoint_path = Path('eeg2vec/data/saved_models/eeg2vec_8_2_2_1_27nov_allpoints.pth')
checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(eeg2vec_model.state_dict(), checkpoint_path)