In [7]:
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import Dataset

In [8]:
num_classes = 5
batchsize = 16
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
learning_rate = 0.001
num_epochs = 50

In [9]:
class EEGDataset(Dataset):
    def __init__(self, data_path, labels_path):
        self.data = np.load(data_path)
        self.labels = np.load(labels_path)
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        # Get the EEG data and corresponding label
        eeg = torch.tensor(self.data[idx], dtype=torch.float32)
        label = torch.tensor(self.labels[idx], dtype=torch.int64)
        return eeg, label

# Example usage
train_dataset = EEGDataset("eeg_dataset/train_epochs.npy", "eeg_dataset/train_labels.npy")
val_dataset = EEGDataset("eeg_dataset/val_epochs.npy", "eeg_dataset/val_labels.npy")
test_dataset = EEGDataset("eeg_dataset/test_epochs.npy", "eeg_dataset/test_labels.npy")

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batchsize, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batchsize, shuffle=False)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batchsize, shuffle=False)

print(len(train_loader))
# Check one batch
for eeg_batch, label_batch in train_loader:
    print("EEG Batch Shape:", eeg_batch.shape)  # (batch_size, 1, 640)
    print("Label Batch Shape:", label_batch.shape)  # (batch_size,)
    break


65
EEG Batch Shape: torch.Size([16, 1, 640])
Label Batch Shape: torch.Size([16])


In [10]:
print(len(train_dataset))

1039


In [15]:
class MLP(nn.Module):
    def __init__(self, input_size, num_classes):
        super(MLP, self).__init__()
        
        self.conv1 = nn.Conv1d(in_channels=1, out_channels=8, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv1d(in_channels=8, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv1d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
        
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool1d(kernel_size=2, stride=2)
        self.dropout = nn.Dropout(p=0.5)
        self.bnfc1 = nn.BatchNorm1d(128)
        self.bnfc2 = nn.BatchNorm1d(64)
        self.bn1 = nn.BatchNorm1d(8)  # After conv1
        self.bn2 = nn.BatchNorm1d(16)  # After conv2
        self.bn3 = nn.BatchNorm1d(32)  # After conv3
        
        self.fc1 = nn.Linear(32 * (input_size // 8), 128)  # Adjust based on input length
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, num_classes)
        
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.pool(x)
        
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.pool(x)
        
        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu(x)
        x = self.pool(x)
        
        x = x.view(x.size(0), -1)  # Flatten for FC layers
        x = self.fc1(x)
        x = self.bnfc1(x)
        x = self.relu(x)
        x = self.dropout(x)
        
        x = self.fc2(x)
        x = self.bnfc2(x)
        x = self.relu(x)
        x = self.dropout(x)
        
        x = self.fc3(x)
        
        return x
        

In [16]:
model = MLP(640, num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
step_lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.1)

n_total_steps = len(train_loader)

for epoch in range(num_epochs):
    # print('Epoch {}/{}'.format(epoch, num_epochs - 1))
    # print('-' * 10)
    
    # for phase in ['train', 'val']:
    #     if phase == 'train':
    #         model.train()  # Set model to training mode
    #     else:
    #         model.eval()   # Set model to evaluate mode

    #     running_loss = 0.0
    #     running_corrects = 0

    #     loader = train_loader if phase == 'train' else val_loader
        
    #     for i, (eeg, labels) in enumerate(loader):
    #         eeg = eeg.to(device)
    #         labels = labels.to(device)
            
    #         with torch.set_grad_enabled(phase == 'train'):
    #             # forward pass
    #             outputs = model(eeg)
    #             _, preds = torch.max(outputs, 1)
    #             loss = criterion(outputs, labels)
                
    #             if phase == 'train':
    #                 optimizer.zero_grad()
    #                 loss.backward()
    #                 optimizer.step()
                    
    #         # statistics
    #         running_loss += loss.item() * eeg.size(0)
    #         running_corrects += torch.sum(preds == labels.data)
            
    #     if phase == 'train':
    #         step_lr_scheduler.step()
            
    #     dataset_sizes = len(train_dataset) if phase == 'train' else len(val_dataset)
            
    #     epoch_loss = running_loss / dataset_sizes
    #     epoch_acc = running_corrects.double() / dataset_sizes
                    
    #     print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

    model.train()
    running_loss = 0.0
    for eeg, labels in train_loader:
        
        eeg = eeg.to(device)
        labels = labels.to(device)
        
        optimizer.zero_grad()
        output = model(eeg)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
    
    # Validation loop
    model.eval()
    val_accuracy = 0
    with torch.no_grad():
        for eeg, labels in val_loader:
            eeg = eeg.to(device)
            labels = labels.to(device)
            
            output = model(eeg)
            pred = output.argmax(dim=1, keepdim=True)
            val_accuracy += pred.eq(labels.view_as(pred)).sum().item()

    val_accuracy /= len(val_loader.dataset)
    print(f'Epoch {epoch+1}, Training Loss: {running_loss/len(train_loader)}, Validation Accuracy: {val_accuracy * 100}%')
    
    # Step scheduler
    step_lr_scheduler.step(running_loss)
            


Epoch 1, Training Loss: 1.743458812053387, Validation Accuracy: 23.766816143497756%
Epoch 2, Training Loss: 1.5617719155091505, Validation Accuracy: 23.318385650224215%
Epoch 3, Training Loss: 1.481669099514301, Validation Accuracy: 21.524663677130047%
Epoch 4, Training Loss: 1.3086539011735183, Validation Accuracy: 20.62780269058296%
Epoch 5, Training Loss: 1.1666739051158612, Validation Accuracy: 19.730941704035875%
Epoch 6, Training Loss: 1.0111059794059167, Validation Accuracy: 23.766816143497756%
Epoch 7, Training Loss: 0.8370797148117652, Validation Accuracy: 19.730941704035875%
Epoch 8, Training Loss: 0.713405814079138, Validation Accuracy: 24.2152466367713%
Epoch 9, Training Loss: 0.5956890729757456, Validation Accuracy: 21.076233183856502%
Epoch 10, Training Loss: 0.49052752691965834, Validation Accuracy: 23.318385650224215%
Epoch 11, Training Loss: 0.502615805543386, Validation Accuracy: 23.766816143497756%
Epoch 12, Training Loss: 0.4121878912815681, Validation Accuracy: 21.

In [None]:
# test
with torch.no_grad():
    n_correct = 0
    n_samples = 0
    for eeg, labels in test_loader:
        eeg = eeg.view(-1, 640).to(device)
        labels = labels.to(device)
        outputs = model(eeg)
        
        # # print(outputs)
        # print(torch.max(outputs, 1))
        # print("ACC")
        # print(labels)
        # break
        
        _, predictions = torch.max(outputs, 1)
        n_samples += labels.shape[0]
        
        n_correct += (predictions==labels).sum().item()
        
    acc = 100 * (n_correct/n_samples)
    print(f'accuracy = {acc}')

tensor([[ 0.4128,  0.8285, -0.1117,  1.0184, -0.4285],
        [ 0.7556, -0.1825, -0.1907,  0.7619, -0.5506],
        [-1.4447,  0.1858, -0.6263, -0.3076, -0.5893],
        [-0.3405, -0.0728,  0.2519, -0.3739,  0.6909],
        [ 0.1231, -0.3118,  0.5803, -0.5326, -0.4991],
        [-0.1691,  0.1540, -0.2083,  0.4978,  0.9697],
        [ 0.4857,  0.7541, -0.0427, -0.2517,  1.1587],
        [ 1.2802, -1.5254,  0.9511, -0.0391, -0.1489],
        [-1.4681,  0.0307, -0.5298, -0.2597,  0.4749],
        [ 0.3191, -0.8715,  0.3086, -0.5079,  0.4744],
        [ 0.0897, -1.8945,  0.5087, -0.2311, -0.6535],
        [ 1.5925, -0.3468, -0.1320, -0.3495,  0.4048],
        [ 0.5215, -0.4503,  0.3164,  1.2285,  0.2300],
        [ 0.8337, -0.5260, -0.3096, -0.4274, -0.0776],
        [-0.1658,  0.6187,  0.4250,  0.8674, -0.1612],
        [ 0.8629, -0.5693,  1.0497, -0.4761, -0.2668]], device='cuda:0')
tensor([1.0184, 0.7619, 0.1858, 0.6909, 0.5803, 0.9697, 1.1587, 1.2802, 0.4749,
        0.4744, 0.5087