In [None]:
import sys
import importlib
import subprocess

def install_if_colab():
    """Install the notebook's third-party dependencies when running in Google Colab.

    Colab is detected by the presence of ``google.colab`` in ``sys.modules``.
    Every required package is probed with ``importlib.util.find_spec`` and only
    the ones that cannot be found are installed. Outside Colab nothing is
    installed and a message is printed instead.

    Raises:
        subprocess.CalledProcessError: if the ``pip install`` command fails.
    """
    # Import the submodule explicitly: a bare ``import importlib`` does not
    # guarantee that ``importlib.util`` has been loaded.
    from importlib.util import find_spec

    if "google.colab" not in sys.modules:
        print("Not running in Google Colab. No installation needed.")
        return

    print("Running in Google Colab. Checking required libraries...")

    packages = ["moabb", "braindecode", "torch_geometric"]  # Add required libraries
    missing_packages = [pkg for pkg in packages if find_spec(pkg) is None]

    if not missing_packages:
        print("All required libraries are already installed.")
        return

    print(f"Installing missing libraries: {', '.join(missing_packages)}")
    # Run pip through the interpreter that backs this kernel so the packages
    # land in the environment the notebook actually imports from. Unlike the
    # IPython-only ``!pip`` shell escape this is plain Python and fails loudly.
    subprocess.check_call([sys.executable, "-m", "pip", "install", *missing_packages])

# Run the dependency check as soon as this cell executes.
install_if_colab()


Running in Google Colab. Checking required libraries...
Installing missing libraries: moabb, braindecode, torch_geometric
Collecting moabb
  Downloading moabb-1.2.0-py3-none-any.whl.metadata (14 kB)
Collecting braindecode
  Downloading braindecode-0.8.1-py3-none-any.whl.metadata (8.1 kB)
Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
Collecting coverage<8.0.0,>=7.0.1 (from moabb)
  Downloading coverage-7.7.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (8.5 kB)
Collecting edfio<0.5.0,>=0.4.2 (from moabb)
  Downloading edfio-0.4.7-py3-none-any.whl.metadata (3.9 kB)
Collecting edflib-python<2.0.0,>=1.0.6 (from moabb)
  Downloading EDFlib_Python-1.0.8-py3-none-any.whl.metadata (1.3 kB)
Collecting memory-profiler<0.62.0,>=0.61.0 (from moabb)
  Downloading memory

In [None]:
# Mount Google Drive under /content/drive.
# NOTE(review): `google.colab` is normally only importable inside Colab, so this
# cell is expected to fail elsewhere — confirm whether a local fallback is wanted.
from google.colab import drive
drive.mount('/content/drive')
# Base directory that the data-loading cell prepends to its file names
# (plain string concatenation, so the trailing slash is required).
path = "/content/drive/MyDrive/"

In [None]:
import pickle


def _load_pickle(file_path):
    """Return the object deserialized from the pickle file at ``file_path``."""
    with open(file_path, 'rb') as pickle_file:
        return pickle.load(pickle_file)


# NOTE(security): pickle.load can execute arbitrary code while deserializing,
# so only load these files when they come from a trusted source.
# `path` is the base directory defined in the Drive-mount cell.
test_set = _load_pickle(path + 'EEG_data/test_set.pkl')
train_set = _load_pickle(path + 'EEG_data/train_set.pkl')

In [None]:
# Sanity check: show the first item of the loaded training set.
print(train_set[0])

In [None]:
# Sanity check: show the first item of the loaded test set.
print(test_set[0])

In [None]:
import torch
#from shallow_fbcsp import ShallowFBCSPNet
from braindecode.util import set_random_seeds


# Use the GPU whenever PyTorch reports one, otherwise stay on the CPU.
cuda = torch.cuda.is_available()
if cuda:
    device = "cuda"
    torch.backends.cudnn.benchmark = True
else:
    device = "cpu"

# Seed the random number generators through braindecode's helper.
seed = 11999
set_random_seeds(seed=seed, cuda=cuda)

# Problem dimensions. These are literals, not values read from the dataset.
# NOTE(review): presumably they match the shape of the pickled windows — confirm.
n_classes, n_channels, input_window_samples = 4, 22, 1125
classes = [*range(n_classes)]

print("n_classes: ", n_classes)
print("n_channels:", n_channels)
print("input_window_samples size:", input_window_samples)

In [None]:
#from models_fbscp import CollapsedShallowNet
import importlib
from weight_init import init_weights
import GIN_model
# Reload the module so edits to GIN_model are picked up without restarting the kernel.
importlib.reload(GIN_model)
from GIN_model import ShallowGNNNet

# Build the network from the dimensions configured in the setup cell instead of
# repeating the channel count as a literal.
model = ShallowGNNNet(
    n_chans=n_channels,
    n_outputs=n_classes,
    n_times=input_window_samples,
    device=device,
)

#model.apply(init_weights)

# Print the module structure
print(model)

# Send model to GPU
if cuda:
    model.cuda()



In [None]:
# splitted = windows_dataset.split("session")
# train_set = splitted['0train']  # Session train
# test_set = splitted['1test']  # Session evaluation

from torch.nn import Module
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import DataLoader

#lr = 1e-4
#weight_decay = 1e-4
#batch_size = 64
#n_epochs = 200


In [None]:

from tqdm import tqdm
# Define a method for training one epoch


def _find_edge_weights(model):
    """Return the model's edge-weight tensor, or ``None`` when it has none.

    ``model.gnn.edge_weights`` is tried first, then ``model.edge_weights``.
    """
    edge_weights = getattr(getattr(model, "gnn", None), "edge_weights", None)
    if edge_weights is None:
        edge_weights = getattr(model, "edge_weights", None)
    return edge_weights


def train_one_epoch(
        dataloader: DataLoader, model: Module, loss_fn, optimizer,
        scheduler: LRScheduler, epoch: int, device, print_batch_stats=True
):
    """Train ``model`` for one pass over ``dataloader``.

    Args:
        dataloader: yields ``(X, y, _)`` batches; the third item is ignored.
        model: network to train; it is switched to training mode.
        loss_fn: callable ``loss_fn(pred, y)`` returning a scalar loss tensor.
        optimizer: optimizer stepped once per batch.
        scheduler: learning-rate scheduler stepped once, after all batches.
        epoch: current epoch number (accepted for the caller's convenience;
            not used inside the function).
        device: device the batches are moved to.
        print_batch_stats: when False the tqdm progress bar is disabled.

    Returns:
        ``(mean_loss, accuracy)``: the loss averaged over batches and the
        fraction (0..1) of correctly classified samples in
        ``dataloader.dataset``. Both are ``0.0`` for an empty dataloader.
    """
    model.train()  # Set the model to training mode
    train_loss, correct = 0, 0

    progress_bar = tqdm(dataloader, total=len(dataloader),
                        disable=not print_batch_stats)

    for X, y, _ in progress_bar:
        X, y = X.to(device), y.to(device)
        optimizer.zero_grad()
        pred = model(X)
        loss = loss_fn(pred, y)
        loss.backward()

        # Report the gradient of the learnable adjacency only when the model
        # actually has one, so models without it can be trained as well.
        edge_weights = _find_edge_weights(model)
        if edge_weights is not None:
            print(f"Adjacency Matrix Gradient: {getattr(edge_weights, 'grad', None)}")

        optimizer.step()  # update the model weights
        optimizer.zero_grad()  # drop the gradients once they have been applied

        train_loss += loss.item()
        correct += (pred.argmax(1) == y).sum().item()

    # Update the learning rate
    scheduler.step()

    # Guard the averages so an empty dataloader does not divide by zero.
    n_batches = len(dataloader)
    n_samples = len(dataloader.dataset)
    mean_loss = train_loss / n_batches if n_batches else 0.0
    accuracy = correct / n_samples if n_samples else 0.0
    return mean_loss, accuracy


In [None]:
from collections import defaultdict
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
from sklearn.metrics import confusion_matrix
import numpy as np

@torch.no_grad()
def test_model(dataloader: DataLoader, model: torch.nn.Module, loss_fn, print_batch_stats=True):
    device = next(model.parameters()).device  # Get model device
    size = len(dataloader.dataset)
    n_batches = len(dataloader)
    model.eval()  # Switch to evaluation mode
    test_loss, correct = 0, 0

    # Initialize dictionaries for per-class tracking
    class_correct = defaultdict(int)
    class_total = defaultdict(int)

    # Lists to store true and predicted labels for confusion matrix
    all_preds = []
    all_targets = []

    if print_batch_stats:
        progress_bar = tqdm(enumerate(dataloader), total=len(dataloader))
    else:
        progress_bar = enumerate(dataloader)

    for batch_idx, (X, y, _) in progress_bar:
        X, y = X.to(device), y.to(device)
        pred = model(X)
        batch_loss = loss_fn(pred, y).item()

        test_loss += batch_loss
        correct += (pred.argmax(1) == y).sum().item()

        # Store predictions and true labels for confusion matrix
        all_preds.append(pred.argmax(1).cpu())
        all_targets.append(y.cpu())

        # Compute per-class accuracy
        preds_labels = pred.argmax(1)
        for label, pred_label in zip(y, preds_labels):
            class_total[label.item()] += 1
            class_correct[label.item()] += (label == pred_label).item()

        if print_batch_stats:
            progress_bar.set_description(
                f"Batch {batch_idx + 1}/{len(dataloader)}, Loss: {batch_loss:.6f}"
            )

    # Convert lists to tensors
    all_preds = torch.cat(all_preds)
    all_targets = torch.cat(all_targets)

    # Compute per-class accuracy
    class_accuracies = {
        cls: (class_correct[cls] / class_total[cls]) * 100 if class_total[cls] > 0 else 0
        for cls in class_total
    }

    # Compute overall accuracy
    test_loss /= n_batches
    overall_accuracy = (correct / size) * 100

    # Print per-class accuracy
    print("\nClass-wise Accuracy:")
    for cls, acc in class_accuracies.items():
        print(f"  Class {cls}: {acc:.2f}%")

    print(f"Test Accuracy: {overall_accuracy:.1f}%, Test Loss: {test_loss:.6f}\n")

    return test_loss, overall_accuracy, class_accuracies, all_preds, all_targets


In [None]:
import wandb
# Log in to Weights & Biases before the training cell calls wandb.init.
wandb.login()

In [None]:
import torch
import wandb
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.nn import CrossEntropyLoss
import numpy as np
import matplotlib.pyplot as plt

# Initialize Weights & Biases
wandb.init(project="Master Thesis", name=f"GIN {seed}")

# Define hyperparameters
lr = 1e-3
weight_decay = 1e-4
batch_size = 64  # Start with 124
n_epochs = 100
patience = 30  # Patience for early stopping

# Test accuracy of the most recently evaluated epoch (used when saving the model)
final_acc = 0.0

# Log hyperparameters to wandb
wandb.config.update({
    "learning_rate": lr,
    "weight_decay": weight_decay,
    "batch_size": batch_size,
    "epochs": n_epochs,
    "patience": patience,
    "seed": seed,
})

# Define optimizer and scheduler
optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
scheduler = CosineAnnealingLR(optimizer, T_max=n_epochs - 1)

# Define loss function
loss_fn = CrossEntropyLoss()

# Create DataLoaders
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_set, batch_size=batch_size)

# Predictions & targets of every evaluated epoch, appended epoch after epoch.
# NOTE(review): these mix all epochs; a confusion matrix of the final model
# would need only the last epoch's slice — confirm what is intended.
all_preds, all_targets = [], []

# Early Stopping setup
# NOTE(review): early stopping is driven by the test set here; a separate
# validation split would avoid tuning on the reported metric.
best_val_acc = 0.0
epochs_without_improvement = 0

# Training loop
for epoch in range(1, n_epochs + 1):
    print(f"Epoch {epoch}/{n_epochs}: ", end="")

    train_loss, train_accuracy = train_one_epoch(
        train_loader, model, loss_fn, optimizer, scheduler, epoch, device
    )

    # test_model already prints the class-wise accuracy, so it is not repeated here.
    test_loss, test_accuracy, class_accuracies, epoch_preds, epoch_targets = test_model(test_loader, model, loss_fn)
    final_acc = test_accuracy

    # Store plain Python ints rather than one 0-d tensor per sample.
    all_preds.extend(epoch_preds.tolist())
    all_targets.extend(epoch_targets.tolist())

    # The training step reads the learnable adjacency from
    # ``model.gnn.edge_weights`` while this loop read ``model.edge_weights``;
    # accept either location so the two cells cannot disagree.
    edge_weights = getattr(model, "edge_weights", None)
    if edge_weights is None:
        edge_weights = model.gnn.edge_weights
    adj_matrix = edge_weights.detach().cpu().numpy().reshape(n_channels, n_channels)

    # Render the adjacency matrix as a heat map
    fig, ax = plt.subplots(figsize=(8, 8))
    cax = ax.matshow(adj_matrix, cmap='viridis', interpolation='nearest')  # Adjust color map
    fig.colorbar(cax)
    ax.set_title(f"Adjacency matrix (epoch {epoch})")

    # Save the adjacency matrix as an image (not as a graph)
    fig.tight_layout()
    img_path = f"adj_matrix_epoch_{epoch}.png"
    fig.savefig(img_path)
    plt.close(fig)  # close this figure explicitly so figures do not pile up

    # Log results to wandb
    wandb.log({
        "epoch": epoch,
        "train_loss": train_loss,
        "train_accuracy": train_accuracy * 100,
        "test_loss": test_loss,
        "test_accuracy": test_accuracy,
        "learning_rate": scheduler.get_last_lr()[0],
        **{f"class_{class_idx}_accuracy": acc for class_idx, acc in class_accuracies.items()},
        "adj_matrix": wandb.Image(img_path),
    })

    print(
        f"Train Accuracy: {100 * train_accuracy:.2f}%, "
        f"Average Train Loss: {train_loss:.6f}, "
        f"Test Accuracy: {test_accuracy:.2f}%, "
        f"Average Test Loss: {test_loss:.6f}\n"
    )

    # Early stopping check
    if test_accuracy > best_val_acc:
        best_val_acc = test_accuracy
        epochs_without_improvement = 0  # Reset counter if we have improvement
    else:
        epochs_without_improvement += 1

    # If no improvement for 'patience' epochs, stop training early
    if epochs_without_improvement >= patience:
        # `epoch` is already 1-based, so it equals the number of epochs run.
        print(f"Early stopping triggered after {epoch} epochs.")
        break

# Convert lists to NumPy arrays
all_preds = np.array(all_preds)
all_targets = np.array(all_targets)

# Save predictions & true labels for later use (confusion matrix)
wandb.log({"all_preds": all_preds.tolist(), "all_targets": all_targets.tolist()})


In [None]:
# Close the Weights & Biases run opened by wandb.init in the training cell.
wandb.finish()

In [None]:
import math
# Echo the seed that is embedded in the checkpoint file names below.
print(seed)
# Save both the whole model object and just its state_dict. The file names encode
# the rounded-up test accuracy of the last evaluated epoch (`final_acc`) and the seed.
torch.save(model, f"GIN_{math.ceil(final_acc)}_{seed}.pth")
torch.save(model.state_dict(), f"GIN_{math.ceil(final_acc)}_{seed}_state.pth")
