In [5]:
import pickle

import mne
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

from opm_thesis.classifiers.classifier import DeepConvNet, MyDataset

# Path to the pickled, preprocessed MNE Epochs object (a single file, despite the name).
EPOCHS_DIR = (
    r"/Users/martin.iniguez/Desktop/master_thesis/opm-thesis/data/data_nottingham_preprocessed/all_epochs.pkl"
)

# When True, keep only every DECIMATION_FACTOR-th time sample.
decimate = False
DECIMATION_FACTOR = 10

# NOTE(review): pickle.load can execute arbitrary code; only load files you trust.
with open(EPOCHS_DIR, "rb") as f:
    epochs = pickle.load(f)

# MEG channels only, skipping channels listed as bad in epochs.info.
picks = mne.pick_types(epochs.info, meg=True, exclude="bads")

# Extract the epoch data for the selected channels (axis 1 indexes channels).
x = epochs.get_data()[:, picks]

# Decimate BEFORE reading the shape so that num_samples matches the array
# that is actually fed to the model.
# NOTE(review): plain slicing applies no anti-aliasing low-pass filter; make sure
# the data were low-passed below the new Nyquist frequency beforehand.
if decimate:
    x = x[:, :, ::DECIMATION_FACTOR]

num_channels = x.shape[1]
num_samples = x.shape[2]

# Map event codes to integer class labels via log2(code) - 2.
# NOTE(review): assumes event codes are powers of two starting at 4 (4 -> 0, 8 -> 1, ...) — confirm.
log_codes = np.log2(epochs.events[:, 2])
if not np.allclose(log_codes, np.round(log_codes)):
    raise ValueError("Event codes must be powers of two to map onto class labels.")
y = (np.round(log_codes) - 2).astype(int)

In [6]:
# Split the data into training and testing sets.
# Stratifying on y keeps the per-class proportions identical in both splits,
# so the held-out accuracy is not skewed by an unlucky class mix.
TEST_SIZE = 0.2
SPLIT_SEED = 50

x_train, x_test, y_train, y_test = train_test_split(
    x, y, test_size=TEST_SIZE, random_state=SPLIT_SEED, stratify=y
)

dataset_train = MyDataset(x_train, y_train)
dataset_test = MyDataset(x_test, y_test)

# Batch size for training; adjust based on available memory.
batch_size = 16

# A seeded generator makes the training-batch shuffle order reproducible across runs.
loader_generator = torch.Generator().manual_seed(SPLIT_SEED)

train_loader = DataLoader(
    dataset_train, batch_size=batch_size, shuffle=True, generator=loader_generator
)
test_loader = DataLoader(dataset_test, batch_size=batch_size, shuffle=False)

In [7]:
# Derive the number of output classes from the integer labels (0..K-1)
# instead of hard-coding it, so the model matches whatever label set was loaded.
num_classes = int(y.max()) + 1

# Seed torch's global RNG so weight initialisation is reproducible.
# NOTE(review): assumes DeepConvNet initialises its weights from torch's global RNG — confirm.
MODEL_SEED = 50
torch.manual_seed(MODEL_SEED)

classifier = DeepConvNet(
    num_channels=num_channels, num_samples=num_samples, num_classes=num_classes
)

# Train the classifier on the training DataLoader.
# NOTE(review): in the recorded run the loss stays at ~1.609 = ln(5) for all epochs,
# i.e. the predictions are no better than uniform guessing. Check the learning rate,
# input scaling and label encoding before trusting any results.
classifier.train(train_loader, num_epochs=50, learning_rate=0.001)

Epoch [1/50], Loss: 1.6085
Epoch [2/50], Loss: 1.6093
Epoch [3/50], Loss: 1.6095
Epoch [4/50], Loss: 1.6100
Epoch [5/50], Loss: 1.6098
Epoch [6/50], Loss: 1.6099
Epoch [7/50], Loss: 1.6102
Epoch [8/50], Loss: 1.6092
Epoch [9/50], Loss: 1.6117
Epoch [10/50], Loss: 1.6092
Epoch [11/50], Loss: 1.6090
Epoch [12/50], Loss: 1.6048
Epoch [13/50], Loss: 1.6089
Epoch [14/50], Loss: 1.6128
Epoch [15/50], Loss: 1.6087
Epoch [16/50], Loss: 1.6107
Epoch [17/50], Loss: 1.6091
Epoch [18/50], Loss: 1.6091
Epoch [19/50], Loss: 1.6123
Epoch [20/50], Loss: 1.6052
Epoch [21/50], Loss: 1.6207
Epoch [22/50], Loss: 1.6095
Epoch [23/50], Loss: 1.6096
Epoch [24/50], Loss: 1.6064
Epoch [25/50], Loss: 1.6186
Epoch [26/50], Loss: 1.6041
Epoch [27/50], Loss: 1.6110
Epoch [28/50], Loss: 1.6063
Epoch [29/50], Loss: 1.6128
Epoch [30/50], Loss: 1.6121
Epoch [31/50], Loss: 1.6071
Epoch [32/50], Loss: 1.6130
Epoch [33/50], Loss: 1.6158
Epoch [34/50], Loss: 1.6100
Epoch [35/50], Loss: 1.6086
Epoch [36/50], Loss: 1.6165
E

In [8]:
# Score the trained model on the held-out split. Chance level is 1 / number of
# classes (0.2 for five balanced classes); the recorded 0.1795 is at or below chance.
classifier.evaluate(test_loader)

Test Accuracy: 0.1795
