## Initial notebook for project 

In [1]:
# Imports and set up
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader,TensorDataset
import matplotlib.pyplot as plt
from torcheeg.io.eeg_signal import EEGSignalIO
from sklearn.model_selection import train_test_split

## Path to dir with data (remember the last '/')
path = "../data/"

## Establish connection to datafile
IO = EEGSignalIO(io_path=str(path), io_mode='lmdb')
## Read metadata dataframeimports
metadata = pd.read_csv(path + 'sample_metadata.tsv', sep='\t')

In [2]:
# Verifying connextion to data
idxs = np.arange(len(metadata))

X = torch.FloatTensor(np.array([IO.read_eeg(str(i)) for i in idxs]))
print(f"nsamples: {X.shape[0]}  -  nchannels: {X.shape[1]}  -  t: {X.shape[2]}")

y = torch.tensor(metadata["value"].values, dtype=torch.long)

nsamples: 5184  -  nchannels: 22  -  t: 800


In [3]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=42, stratify=y)
X_val, X_test, y_val, y_test = train_test_split(X_test, y_test, test_size=0.5, random_state=42)

print(f"shape of X train: {X_train.shape}")
print(f"shape of X val: {X_val.shape}")
print(f"shape of X test: {X_test.shape}")

print(f"shape of y train: {y_train.shape}")
print(f"shape of y val: {y_val.shape}")
print(f"shape of y test: {y_test.shape}")

shape of X train: torch.Size([3110, 22, 800])
shape of X val: torch.Size([1037, 22, 800])
shape of X test: torch.Size([1037, 22, 800])
shape of y train: torch.Size([3110])
shape of y val: torch.Size([1037])
shape of y test: torch.Size([1037])


In [4]:
#set up of matrixies 
#number of samples, channels, and timesteps
nsamples_train, nchannels_train, t = X_train.shape
nsamples_val, nchannels_val, t = X_val.shape

train_dataset = TensorDataset(X_train, y_train)
val_dataset = TensorDataset(X_val, y_val)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=True)

adj_matrix_train = torch.eye(nchannels_train)
adj_matrix_val = torch.eye(nchannels_val)

In [5]:
class GraphConv(nn.Module):
    def __init__(self, in_features, out_features):
        super(GraphConv, self).__init__()
        self.weight = nn.Parameter(torch.randn(in_features, out_features))
        self.bias = nn.Parameter(torch.zeros(out_features))

    def forward(self, x, adj):
        x = torch.matmul(adj, x)  
        x = torch.matmul(x, self.weight) + self.bias  
        return torch.relu(x)

class EEG_GNN(nn.Module):
    def __init__(self, in_features, hidden_dim, nclasses, nchannels):
        super(EEG_GNN, self).__init__()
        self.conv1 = GraphConv(in_features, hidden_dim)
        self.batch_norm1 = nn.BatchNorm1d(nchannels) 
        self.conv2 = GraphConv(hidden_dim, hidden_dim)
        self.batch_norm2 = nn.BatchNorm1d(nchannels) 
        self.conv3 = GraphConv(hidden_dim, nclasses)  

    def forward(self, x, adj):
        x = self.conv1(x, adj)
        x = self.batch_norm1(x)  
        x = F.relu(x)
        x = self.conv2(x, adj)
        x = self.batch_norm2(x)  
        x = F.relu(x)
        x = self.conv3(x, adj)  
        x = x.mean(dim=1)  
        return x

In [6]:
print(adj_matrix_train.shape)
print(type(adj_matrix_train))
print(nsamples_train)

print(adj_matrix_val.shape)
print(type(adj_matrix_val))
print(nsamples_val)

torch.Size([22, 22])
<class 'torch.Tensor'>
3110
torch.Size([22, 22])
<class 'torch.Tensor'>
1037


In [7]:
def get_adj_by_corr(data, nsamples, threshold=0.5):

    # Compute Pearson correlation between channels
    tmp_lst=[]
    
    for i in range(nsamples): 
        correlation_matrix = np.corrcoef(data[i])
        tmp_lst.append((correlation_matrix > threshold).astype(np.float32))
        
    # Convert correlation to adjacency matrix
    adj_mat = torch.from_numpy(np.array(tmp_lst))
    
    return adj_mat

In [19]:
class TrainGNN():
    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def train_model(self, model, train_loader, val_loader, adjacency_matrix, learning_rate=0.0005, epochs=500, prints=True):
        model = model.to(self.device)
        adjacency_matrix = adjacency_matrix.to(self.device)  # Move adjacency to GPU if available

        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)

        highest_train_accuracy = 0.0
        
        losses_train = []; losses_val = []

        for epoch in range(epochs):
            model.train()
            running_loss = 0.0; running_loss_val = 0.0
            correct = 0; correct_val = 0
            total = 0; total_val = 0
            
            for inputs, labels in train_loader:
                inputs = inputs.to(self.device)
                labels = labels.to(self.device)

                optimizer.zero_grad()
                outputs = model(inputs, adjacency_matrix)  # Forward pass
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                running_loss += loss.item() * inputs.size(0)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

            epoch_loss = running_loss / len(train_loader.dataset)
            epoch_accuracy = correct/total
            losses_train.append(epoch_loss)
            
            for inputs, labels in val_loader:
                outputs = model(inputs, adjacency_matrix)
                loss = criterion(outputs, labels)
                
                running_loss_val += loss.item() * inputs.size(0)
                _, predicted = torch.max(outputs, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).sum().item()
                
            epoch_loss_val = running_loss_val / len(val_loader.dataset)
            epoch_accuracy_val = correct_val/total_val
            losses_val.append(epoch_loss_val)
                
            if epoch_accuracy > highest_train_accuracy:
                highest_train_accuracy = epoch_accuracy

            if prints:
                print(f"Epoch {epoch+1}/{epochs}, Train loss: {epoch_loss:.4f}, Train acc: {(epoch_accuracy*100):.2f}%" +
                     f"| Val loss: {epoch_loss_val:.4f}, Val acc: {(epoch_accuracy_val*100):.2f}%")

        print(f"Highest Train Accuracy {(highest_train_accuracy*100):.2f}")
        torch.save(model.state_dict(), 'eeg_gnn.pth')
        
        losses = [losses_train, losses_val]

        return model, losses

In [20]:
nclasses = y.max().item() + 1 
model = EEG_GNN(in_features=t, hidden_dim=32, nclasses=nclasses, nchannels=nchannels_train)

trainer = TrainGNN()
trained_model, losses = trainer.train_model(model, train_loader, val_loader, adj_matrix_train, epochs=20, prints=True)

Epoch 1/20, Train loss: 3.0868, Train acc: 16.72%| Val loss: 2.8919, Val acc: 16.39%
Epoch 2/20, Train loss: 2.6281, Train acc: 19.20%| Val loss: 2.5702, Val acc: 17.36%
Epoch 3/20, Train loss: 2.3599, Train acc: 21.41%| Val loss: 2.3630, Val acc: 18.42%
Epoch 4/20, Train loss: 2.1719, Train acc: 22.64%| Val loss: 2.2250, Val acc: 19.19%
Epoch 5/20, Train loss: 2.0324, Train acc: 23.38%| Val loss: 2.1025, Val acc: 19.48%
Epoch 6/20, Train loss: 1.9274, Train acc: 24.95%| Val loss: 2.0167, Val acc: 19.38%
Epoch 7/20, Train loss: 1.8470, Train acc: 25.69%| Val loss: 1.9496, Val acc: 19.58%
Epoch 8/20, Train loss: 1.7823, Train acc: 25.98%| Val loss: 1.8932, Val acc: 20.25%
Epoch 9/20, Train loss: 1.7348, Train acc: 26.56%| Val loss: 1.8464, Val acc: 19.86%
Epoch 10/20, Train loss: 1.6958, Train acc: 26.43%| Val loss: 1.8130, Val acc: 19.67%
Epoch 11/20, Train loss: 1.6678, Train acc: 27.04%| Val loss: 1.7913, Val acc: 19.96%
Epoch 12/20, Train loss: 1.6454, Train acc: 27.49%| Val loss: 1