## Initial notebook for project 

In [1]:
# Imports and set up
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import DataLoader,TensorDataset

from torcheeg.io.eeg_signal import EEGSignalIO

## Path to dir with data (remember the last '/')
path = "../data/"

## Establish connection to datafile
IO = EEGSignalIO(io_path=str(path), io_mode='lmdb')
## Read metadata dataframeimports
metadata = pd.read_csv(path + 'sample_metadata.tsv', sep='\t')

In [2]:
# Verifying connextion to data
idxs = np.arange(len(metadata))

eeg = torch.FloatTensor(np.array([IO.read_eeg(str(i)) for i in idxs]))
print(f"nsamples: {eeg.shape[0]}  -  nchannels: {eeg.shape[1]}  -  t: {eeg.shape[2]}")

nsamples: 5184  -  nchannels: 22  -  t: 800


Imports we might need cuz eeg ref code.

In [3]:
#set up of matrixies 
#number of samples, channels, and timesteps
nsamples, nchannels, t = eeg.shape  

labels = torch.tensor(metadata["value"].values, dtype=torch.long)

train_dataset = TensorDataset(eeg, labels)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

adj_matrix = torch.eye(eeg.shape[1])  #place holder change later FIX

In [4]:
class GraphConv(nn.Module):
    def __init__(self, in_features, out_features):
        super(GraphConv, self).__init__()
        self.weight = nn.Parameter(torch.randn(in_features, out_features))
        self.bias = nn.Parameter(torch.zeros(out_features))

    def forward(self, x, adj):
        x = torch.matmul(adj, x)  # Apply adjacency matrix
        x = torch.matmul(x, self.weight) + self.bias  # Linear transformation
        return torch.relu(x)

class EEG_GNN(nn.Module):
    def __init__(self, in_features, hidden_dim, nclasses):
        super(EEG_GNN, self).__init__()
        self.conv1 = GraphConv(in_features, hidden_dim)
        self.conv2 = GraphConv(hidden_dim, nclasses) 

    def forward(self, x, adj):
        x = self.conv1(x, adj)
        x = self.conv2(x, adj)  
        x = x.mean(dim=1) 
        return x 


In [5]:
class TrainGNN():
    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def train_model(self, model, train_loader, adjacency_matrix, learning_rate=0.001, epochs=500):
        model = model.to(self.device)
        adjacency_matrix = adjacency_matrix.to(self.device)  # Move adjacency to GPU if available

        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)

        highest_train_accuracy = 0.0

        for epoch in range(epochs):
            model.train()
            running_loss = 0.0
            correct = 0
            total = 0
            for inputs, labels in train_loader:
                inputs = inputs.to(self.device)
                labels = labels.to(self.device)

                optimizer.zero_grad()
                outputs = model(inputs, adjacency_matrix)  # Forward pass
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                running_loss += loss.item() * inputs.size(0)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

            epoch_loss = running_loss / len(train_loader.dataset)
            epoch_accuracy = correct / total
            if epoch_accuracy > highest_train_accuracy:
                highest_train_accuracy = epoch_accuracy

            print(f"Epoch {epoch+1}/{epochs}, Loss: {epoch_loss:.4f}, Accuracy: {(epoch_accuracy*100):.2f}%")

        print("Highest Train Accuracy:", highest_train_accuracy)
        torch.save(model.state_dict(), 'eeg_gnn.pth')

        return model


In [6]:
nclasses = labels.max().item() + 1 
model = EEG_GNN(in_features=t, hidden_dim=32, nclasses=nclasses)

trainer = TrainGNN()
trained_model = trainer.train_model(model, train_loader, adj_matrix,epochs=50)

Epoch 1/50, Loss: 39.8273, Accuracy: 21.82%
Epoch 2/50, Loss: 28.0527, Accuracy: 22.40%
Epoch 3/50, Loss: 19.4051, Accuracy: 22.74%
Epoch 4/50, Loss: 13.1676, Accuracy: 21.99%
Epoch 5/50, Loss: 8.9479, Accuracy: 20.83%
Epoch 6/50, Loss: 6.2899, Accuracy: 19.73%
Epoch 7/50, Loss: 4.6496, Accuracy: 18.15%
Epoch 8/50, Loss: 3.6531, Accuracy: 16.86%
Epoch 9/50, Loss: 3.0286, Accuracy: 15.39%
Epoch 10/50, Loss: 2.6226, Accuracy: 14.35%
Epoch 11/50, Loss: 2.3469, Accuracy: 13.27%
Epoch 12/50, Loss: 2.1549, Accuracy: 12.40%
Epoch 13/50, Loss: 2.0176, Accuracy: 11.71%
Epoch 14/50, Loss: 1.9174, Accuracy: 11.00%
Epoch 15/50, Loss: 1.8437, Accuracy: 10.55%
Epoch 16/50, Loss: 1.7867, Accuracy: 10.19%
Epoch 17/50, Loss: 1.7424, Accuracy: 9.84%
Epoch 18/50, Loss: 1.7077, Accuracy: 9.70%
Epoch 19/50, Loss: 1.6811, Accuracy: 9.38%
Epoch 20/50, Loss: 1.6592, Accuracy: 9.20%
Epoch 21/50, Loss: 1.6419, Accuracy: 9.18%
Epoch 22/50, Loss: 1.6275, Accuracy: 9.05%
Epoch 23/50, Loss: 1.6155, Accuracy: 9.01%
