In [1]:
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import Dataset
import copy

In [7]:
num_classes = 4
batchsize = 16
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
learning_rate = 0.001
num_epochs = 200

# Seed torch's global RNG so weight initialisation, dropout masks and
# DataLoader shuffling repeat across Restart & Run All.
# NOTE: bitwise-identical GPU runs may additionally need deterministic
# cuDNN settings; this only fixes the RNG streams.
SEED = 42
torch.manual_seed(SEED)

In [8]:
class EEGDataset(Dataset):
    """Map-style dataset over pre-computed EEG samples stored as ``.npy`` files.

    Args:
        data_path: path to a ``.npy`` array with one sample per entry along axis 0.
        labels_path: path to a ``.npy`` array with one class label per sample
            (cast to int64 so it can be used as a class-index target).

    Raises:
        ValueError: if the data and label arrays differ in length along axis 0.
    """

    def __init__(self, data_path, labels_path):
        self.data = np.load(data_path)
        self.labels = np.load(labels_path)
        # Fail fast on mismatched files: otherwise surplus labels are silently
        # ignored, or indexing past the end of ``labels`` fails mid-epoch.
        if len(self.data) != len(self.labels):
            raise ValueError(
                f"data/labels length mismatch: {len(self.data)} samples in "
                f"{data_path!r} vs {len(self.labels)} labels in {labels_path!r}"
            )

    def __len__(self):
        """Number of samples (length of the data array along axis 0)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(eeg, label)`` as a float32 tensor and an int64 tensor."""
        eeg = torch.tensor(self.data[idx], dtype=torch.float32)
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return eeg, label

# Build the train / validation / test splits from the pre-saved arrays.
train_dataset = EEGDataset("eeg_dataset/train_epochs.npy", "eeg_dataset/train_labels.npy")
val_dataset = EEGDataset("eeg_dataset/val_epochs.npy", "eeg_dataset/val_labels.npy")
test_dataset = EEGDataset("eeg_dataset/test_epochs.npy", "eeg_dataset/test_labels.npy")

# Only the training batches are shuffled; evaluation order stays fixed.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batchsize, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batchsize, shuffle=False)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batchsize, shuffle=False)

print(len(train_loader))  # number of training batches per epoch
# Sanity-check the shape of a single training batch.
eeg_batch, label_batch = next(iter(train_loader))
print("EEG Batch Shape:", eeg_batch.shape)  # (batch_size, n_features); 1076 features in the recorded run
print("Label Batch Shape:", label_batch.shape)  # (batch_size,)


27
EEG Batch Shape: torch.Size([16, 1076])
Label Batch Shape: torch.Size([16])


In [9]:
class MLP(nn.Module):
    """Fully connected classifier for flat feature vectors.

    Three hidden stages of Linear -> BatchNorm1d -> ReLU -> Dropout(p=0.1)
    with widths 512, 256 and 64, followed by a Linear output layer. The
    output is returned as raw logits (no softmax); the training cell feeds
    it to ``nn.CrossEntropyLoss``.

    Args:
        input_size: number of features per sample.
        num_classes: number of output logits.
    """

    def __init__(self, input_size, num_classes):
        super().__init__()

        # Registration order (relu, dropouts, batch norms, then linears) is
        # kept unchanged so state_dict keys/ordering and the RNG order of
        # weight initialisation match earlier checkpoints.
        self.relu = nn.ReLU()
        self.dropout1, self.dropout2, self.dropout3 = (nn.Dropout(p=0.1) for _ in range(3))
        self.bn1, self.bn2, self.bn3 = (nn.BatchNorm1d(width) for width in (512, 256, 64))

        self.l1 = nn.Linear(input_size, 512)
        self.l2 = nn.Linear(512, 256)
        self.l3 = nn.Linear(256, 64)
        self.l4 = nn.Linear(64, num_classes)

    def forward(self, x):
        """Map ``x`` of shape (batch, input_size) to logits (batch, num_classes).

        NOTE: in training mode BatchNorm1d needs more than one sample per batch.
        """
        hidden_stages = (
            (self.l1, self.bn1, self.dropout1),
            (self.l2, self.bn2, self.dropout2),
            (self.l3, self.bn3, self.dropout3),
        )
        h = x
        for linear, norm, drop in hidden_stages:
            h = drop(self.relu(norm(linear(h))))
        return self.l4(h)
        

In [10]:
# Derive the feature width from the loaded training array instead of the
# magic number 1076 (the batch check above shows one flat vector per sample).
input_size = train_dataset.data.shape[1]
model = MLP(input_size, num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

n_total_steps = len(train_loader)

# Snapshot of the weights with the best validation accuracy seen so far;
# the test cell reloads it.
best_model_wts = copy.deepcopy(model.state_dict())
best_acc = 0.0

for epoch in range(num_epochs):
    # --- training pass ---
    model.train()
    running_loss = 0.0
    for eeg, labels in train_loader:

        eeg = eeg.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        output = model(eeg)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # --- validation pass (eval mode: dropout off, BatchNorm uses running stats) ---
    model.eval()
    val_correct = 0
    with torch.no_grad():
        for eeg, labels in val_loader:
            eeg = eeg.to(device)
            labels = labels.to(device)

            pred = model(eeg).argmax(dim=1)
            val_correct += (pred == labels).sum().item()

    val_accuracy = val_correct / len(val_loader.dataset)
    # Training loss is the mean of the per-batch losses for this epoch.
    print(f'Epoch {epoch+1}, Training Loss: {running_loss/len(train_loader)}, Validation Accuracy: {val_accuracy * 100}%')

    if val_accuracy > best_acc:
        best_acc = val_accuracy
        best_model_wts = copy.deepcopy(model.state_dict())


Epoch 1, Training Loss: 1.400624151583071, Validation Accuracy: 37.362637362637365%
Epoch 2, Training Loss: 1.2909097406599257, Validation Accuracy: 50.54945054945055%
Epoch 3, Training Loss: 1.1912864755701136, Validation Accuracy: 58.24175824175825%
Epoch 4, Training Loss: 1.100425159489667, Validation Accuracy: 70.32967032967034%
Epoch 5, Training Loss: 1.0383639909602977, Validation Accuracy: 73.62637362637363%
Epoch 6, Training Loss: 0.9827218011573509, Validation Accuracy: 74.72527472527473%
Epoch 7, Training Loss: 0.9305962213763485, Validation Accuracy: 79.12087912087912%
Epoch 8, Training Loss: 0.8465247485372756, Validation Accuracy: 83.51648351648352%
Epoch 9, Training Loss: 0.7862480900905751, Validation Accuracy: 81.31868131868131%
Epoch 10, Training Loss: 0.7918329857013844, Validation Accuracy: 81.31868131868131%
Epoch 11, Training Loss: 0.7476351879261158, Validation Accuracy: 86.81318681318682%
Epoch 12, Training Loss: 0.7324011789427863, Validation Accuracy: 89.010989

In [11]:

model.load_state_dict(best_model_wts)
# load_state_dict does not change the train/eval mode. Set eval mode
# explicitly instead of relying on the training cell having finished: if it
# was interrupted mid-epoch, the model would still be in train mode, with
# dropout active and BatchNorm using and updating batch statistics.
model.eval()
# Accuracy (%) of the best-validation weights on the held-out test split.
with torch.no_grad():
    n_correct = 0
    n_samples = 0
    for eeg, labels in test_loader:
        eeg = eeg.to(device)
        labels = labels.to(device)
        outputs = model(eeg)

        # Predicted class = index of the largest logit (same rule as validation).
        predictions = outputs.argmax(dim=1)
        n_samples += labels.shape[0]
        n_correct += (predictions == labels).sum().item()

    acc = 100 * (n_correct / n_samples)
    print(f'accuracy = {acc}')

accuracy = 92.22222222222223
