In [None]:
import sys
import importlib
import subprocess

def install_if_colab(packages=("moabb", "braindecode", "torch_geometric")):
    """Install missing third-party libraries when running inside Google Colab.

    Colab is detected by the presence of ``google.colab`` in ``sys.modules``.
    Outside Colab the function only prints a message and returns.

    Args:
        packages: Importable module names to look for. Every name that cannot
            be resolved is handed to ``pip install`` unchanged, so each entry
            must also be an acceptable pip requirement name.

    Returns:
        None. Progress is reported on stdout.
    """
    # `import importlib` alone does not guarantee that the `util` submodule is
    # loaded, so import it explicitly before calling `find_spec`.
    import importlib.util

    if "google.colab" not in sys.modules:
        print("Not running in Google Colab. No installation needed.")
        return

    print("Running in Google Colab. Checking required libraries...")
    missing_packages = [pkg for pkg in packages if importlib.util.find_spec(pkg) is None]

    if not missing_packages:
        print("All required libraries are already installed.")
        return

    print(f"Installing missing libraries: {', '.join(missing_packages)}")
    # Invoke pip through the running interpreter so the packages are installed
    # into the kernel's own environment. An argument list (no shell) is used so
    # package names are never interpreted by a shell.
    result = subprocess.run(
        [sys.executable, "-m", "pip", "install", *missing_packages],
        check=False,
    )
    if result.returncode != 0:
        # Keep the best-effort behaviour of the shell escape: report, don't raise.
        print(f"pip exited with status {result.returncode}; some libraries may be missing.")

# Run the dependency check immediately so the packages are available before
# the later cells import them (e.g. braindecode).
install_if_colab()


In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive (Colab only).
drive.mount('/content/drive')
# Base directory for the data files. Later cells build file names with plain
# string concatenation (path + 'EEG_data/...'), so the trailing slash matters.
path = "/content/drive/MyDrive/"

In [None]:
import pickle


def _load_pickle(file_path):
    """Return the object stored in the pickle file at `file_path`."""
    with open(file_path, 'rb') as handle:
        return pickle.load(handle)


# SECURITY: unpickling executes arbitrary code embedded in the file. Only load
# pickles that you created yourself or that come from a trusted source.
test_set = _load_pickle(path + 'EEG_data/test_set.pkl')
train_set = _load_pickle(path + 'EEG_data/train_set.pkl')

In [None]:
import torch
#from shallow_fbcsp import ShallowFBCSPNet
from braindecode.util import set_random_seeds
import torch.nn.functional as F


# Pick the compute device: use the GPU whenever PyTorch can see one.
cuda = torch.cuda.is_available()
if cuda:
    device = "cuda"
    # Only meaningful on GPU, hence set inside this branch.
    torch.backends.cudnn.benchmark = True
else:
    device = "cpu"

# Seed the random number generators through braindecode's helper.
seed = 282828
set_random_seeds(seed=seed, cuda=cuda)

# Problem dimensions. These are fixed constants in this notebook; they are
# not read from the loaded dataset.
n_classes = 4
classes = list(range(n_classes))
n_channels = 22
input_window_samples = 1125

print("n_classes: ", n_classes)
print("n_channels:", n_channels)
print("input_window_samples size:", input_window_samples)

In [None]:
#from models_fbscp import CollapsedShallowNet
# The ShallowFBCSPNet is a `nn.Sequential` model
import importlib
import RGNN
importlib.reload(RGNN)
from RGNN import ShallowRGNNNet
from weight_init import init_weights
from CKA_functions import adjacency_matrix_motion,adjacency_matrix_distance_motion
# Build the electrode graph with the project helpers from CKA_functions.
# NOTE(review): `dm` is presumably a distance-based channel-by-channel matrix
# returned as a NumPy array - confirm against CKA_functions.
adj_m, pos = adjacency_matrix_motion()
adj_dis_m, dm = adjacency_matrix_distance_motion(pos, delta=5)

# Flatten the matrix into a 1-D tensor holding one weight per (row, column) pair.
torch_tensor = torch.from_numpy(dm)
edge_weight = torch_tensor.reshape(-1).to(device)
print(edge_weight.shape)

# The class imported from RGNN in this cell is `ShallowRGNNNet`; the name
# `ShallowSGCNNet` used previously is not defined anywhere in the notebook and
# raised a NameError. The dimensions come from the configuration cell
# (22 channels, 4 classes, 1125 samples).
# NOTE(review): the positional argument order (channels, outputs, times,
# edge weights) is taken over unchanged from the previous call - confirm it
# against the ShallowRGNNNet constructor.
model = ShallowRGNNNet(n_channels, n_classes, input_window_samples, edge_weight)

# Print the module structure.
print(model)

# Send model to GPU
if cuda:
    model.cuda()



In [None]:
# splitted = windows_dataset.split("session")
# train_set = splitted['0train']  # Session train
# test_set = splitted['1test']  # Session evaluation

from torch.nn import Module
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import DataLoader

#lr = 1e-4
#weight_decay = 1e-4
#batch_size = 64
#n_epochs = 200


In [None]:
#train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
#progress_bar = tqdm(enumerate(train_loader), total=len(train_loader))


#from collections import defaultdict

#counting_dict = defaultdict(int)  # Initialize class counter

#for batch_idx, (X, y, _) in progress_bar:
#   X, y = X.to(device), y.to(device)  # Move to device if needed

    # Count occurrences of each class in y
#    for value in y:
#        counting_dict[int(value.item())] += 1  # Convert tensor to int and update count

# Print class frequencies
#print("Class counts:", dict(counting_dict))


In [None]:

from tqdm import tqdm
# Define a method for training one epoch


def train_one_epoch(
        dataloader: DataLoader, model: Module,edge_index, loss_fn, optimizer,
        scheduler: LRScheduler, epoch: int, device, print_batch_stats=True
):
    """Run one full pass over `dataloader`, updating `model` after every batch.

    Args:
        dataloader: Yields ``(X, y, _)`` triples; the third element is ignored.
        model: Module called as ``model(X, edge_index)``.
        edge_index: Passed through to the model unchanged on every batch.
        loss_fn: Callable ``loss_fn(pred, y)`` returning a scalar loss tensor.
        optimizer: Optimizer stepped once per batch.
        scheduler: Learning-rate scheduler stepped once, after the last batch.
        epoch: Current epoch number; not used inside this function.
        device: Device the batch tensors are moved to.
        print_batch_stats: When False the tqdm progress bar is disabled.

    Returns:
        ``(mean_loss, accuracy)``: the batch loss averaged over the number of
        batches, and the fraction (0-1) of samples in ``dataloader.dataset``
        that were classified correctly.
    """
    model.train()
    running_loss = 0
    n_correct = 0

    batches = tqdm(
        enumerate(dataloader),
        total=len(dataloader),
        disable=not print_batch_stats,
    )

    for batch_idx, (X, y, _) in batches:
        X = X.to(device)
        y = y.to(device)

        # Clear gradients both before the forward pass and after the update,
        # so no gradients are left on the parameters when the function returns.
        optimizer.zero_grad()
        logits = model(X, edge_index)
        loss = loss_fn(logits, y)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        running_loss += loss.item()
        n_correct += (logits.argmax(1) == y).sum().item()

    # The schedule advances per epoch, not per batch.
    scheduler.step()

    mean_loss = running_loss / len(dataloader)
    accuracy = n_correct / len(dataloader.dataset)
    return mean_loss, accuracy


In [None]:
from collections import defaultdict
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
from sklearn.metrics import confusion_matrix
import numpy as np

@torch.no_grad()
def test_model(dataloader: DataLoader, model: torch.nn.Module,edge_index, loss_fn, print_batch_stats=True):
    device = next(model.parameters()).device  # Get model device
    size = len(dataloader.dataset)
    n_batches = len(dataloader)
    model.eval()  # Switch to evaluation mode
    test_loss, correct = 0, 0

    # Initialize dictionaries for per-class tracking
    class_correct = defaultdict(int)
    class_total = defaultdict(int)

    # Lists to store true and predicted labels for confusion matrix
    all_preds = []
    all_targets = []

    if print_batch_stats:
        progress_bar = tqdm(enumerate(dataloader), total=len(dataloader))
    else:
        progress_bar = enumerate(dataloader)

    for batch_idx, (X, y, _) in progress_bar:
        X, y = X.to(device), y.to(device)
        pred = model(X,edge_index)
        batch_loss = loss_fn(pred, y).item()

        test_loss += batch_loss
        correct += (pred.argmax(1) == y).sum().item()

        # Store predictions and true labels for confusion matrix
        all_preds.append(pred.argmax(1).cpu())
        all_targets.append(y.cpu())

        # Compute per-class accuracy
        preds_labels = pred.argmax(1)
        for label, pred_label in zip(y, preds_labels):
            class_total[label.item()] += 1
            class_correct[label.item()] += (label == pred_label).item()

        if print_batch_stats:
            progress_bar.set_description(
                f"Batch {batch_idx + 1}/{len(dataloader)}, Loss: {batch_loss:.6f}"
            )

    # Convert lists to tensors
    all_preds = torch.cat(all_preds)
    all_targets = torch.cat(all_targets)

    # Compute per-class accuracy
    class_accuracies = {
        cls: (class_correct[cls] / class_total[cls]) * 100 if class_total[cls] > 0 else 0
        for cls in class_total
    }

    # Compute overall accuracy
    test_loss /= n_batches
    overall_accuracy = (correct / size) * 100

    # Print per-class accuracy
    print("\nClass-wise Accuracy:")
    for cls, acc in class_accuracies.items():
        print(f"  Class {cls}: {acc:.2f}%")

    print(f"Test Accuracy: {overall_accuracy:.1f}%, Test Loss: {test_loss:.6f}\n")



    return test_loss, overall_accuracy, class_accuracies, all_preds, all_targets


In [None]:
import wandb
# Authenticate with Weights & Biases; the training cell below calls wandb.init.
wandb.login()

In [None]:


def get_e_index(dm, threshold=0, target_device=None):
    """Build an ``edge_index`` tensor from a 2-D node-by-node matrix.

    Every position ``(i, j)`` with ``dm[i, j] >= threshold`` becomes a directed
    edge from node ``i`` to node ``j``. Self-loops ``(i, i)`` and both
    directions of a pair are kept. Edges are listed in row-major order.

    Args:
        dm: 2-D array-like exposing ``.shape`` and ``dm[i, j]`` indexing.
        threshold: Minimum entry value for a pair to be included. With the
            default of 0, every non-negative entry yields an edge.
        target_device: Device for the returned tensor. When None, the
            notebook-level ``device`` variable is used.

    Returns:
        A ``torch.long`` tensor of shape ``(2, n_edges)``: row 0 holds the
        source node indices and row 1 the target node indices.
    """
    if target_device is None:
        # Fall back to the global `device` chosen in the configuration cell.
        target_device = device

    n_rows, n_cols = dm.shape[0], dm.shape[1]
    edges = [
        (i, j)
        for i in range(n_rows)
        for j in range(n_cols)
        if dm[i, j] >= threshold
    ]
    source_nodes = [i for i, _ in edges]
    target_nodes = [j for _, j in edges]

    edge_index = torch.tensor([source_nodes, target_nodes], dtype=torch.long).to(target_device)
    # The tensor used to be built and then dropped, so callers received None.
    return edge_index

In [None]:
import torch
import wandb
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.nn import CrossEntropyLoss
import numpy as np
import matplotlib.pyplot as plt
import math

# Hyperparameters are the same for every seed, so they are defined once.
lr = 1e-3
weight_decay = 1e-4
batch_size = 64  # Start with 124
n_epochs = 100

# Train and evaluate one freshly initialised model per seed; every seed gets
# its own Weights & Biases run and its own saved checkpoint files.
seeds = [99999,182726,91111222,44552222,12223111,100300,47456655,4788347,77766666,809890]
for seed in seeds:
    set_random_seeds(seed=seed, cuda=cuda)

    # Rebuild the electrode graph with the project helpers.
    adj_m, pos = adjacency_matrix_motion()
    adj_dis_m, dm = adjacency_matrix_distance_motion(pos, delta=5)
    torch_tensor = torch.from_numpy(dm)
    # NOTE(review): `edge_weight` is computed but not passed to the model here,
    # whereas the earlier model-construction cell passes it as a fourth
    # positional argument - confirm which call matches the constructor.
    edge_weight = torch_tensor.reshape(-1).to(device)

    # The class imported from RGNN is `ShallowRGNNNet`; `ShallowSGCNNet` is not
    # defined anywhere in the notebook and raised a NameError.
    model = ShallowRGNNNet(
        n_chans=n_channels,
        n_outputs=n_classes,
        n_times=input_window_samples,
    )

    if cuda:
        model.cuda()

    edge_index = get_e_index(dm)

    # Initialize Weights & Biases
    wandb.init(project="Master Thesis", name=f"{model.__class__.__name__} {seed}")
    try:
        model.apply(init_weights)

        final_acc = 0.0

        # Log hyperparameters to wandb
        wandb.config.update({
            "learning_rate": lr,
            "weight_decay": weight_decay,
            "batch_size": batch_size,
            "epochs": n_epochs
        })

        # Define optimizer and scheduler
        optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
        scheduler = CosineAnnealingLR(optimizer, T_max=n_epochs - 1)

        # Define loss function
        loss_fn = CrossEntropyLoss()

        # Create DataLoaders
        train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
        test_loader = DataLoader(test_set, batch_size=batch_size)

        # Predictions and labels of the most recent evaluation. Only the final
        # epoch is kept: concatenating every epoch would mix the outputs of
        # 100 different model states into a single confusion matrix.
        all_preds, all_targets = [], []

        # Training loop
        for epoch in range(1, n_epochs + 1):
            print(f"Epoch {epoch}/{n_epochs}: ", end="")

            train_loss, train_accuracy = train_one_epoch(
                train_loader, model, edge_index, loss_fn, optimizer, scheduler, epoch, device
            )

            # test_model prints the class-wise accuracy itself, so it is not
            # printed a second time here.
            test_loss, test_accuracy, class_accuracies, batch_preds, batch_targets = test_model(
                test_loader, model, edge_index, loss_fn
            )
            final_acc = test_accuracy
            all_preds, all_targets = batch_preds, batch_targets

            # Snapshot of the model's current edge weights as a square matrix.
            adj_matrix = model.rgnn.edge_weights.detach().cpu().numpy().reshape(n_channels, n_channels)

            # Render the matrix as a heat map (not as a graph drawing).
            fig, ax = plt.subplots(figsize=(8, 8))
            cax = ax.matshow(adj_matrix, cmap='viridis', interpolation='nearest', vmin=0, vmax=1)
            fig.colorbar(cax)
            fig.tight_layout()
            img_path = f"adj_matrix_epoch_{epoch}.png"
            fig.savefig(img_path)
            # Close this specific figure so one figure per epoch does not
            # accumulate in memory.
            plt.close(fig)

            # Log results to wandb
            wandb.log({
                "epoch": epoch,
                "train_loss": train_loss,
                "train_accuracy": train_accuracy * 100,
                "test_loss": test_loss,
                "test_accuracy": test_accuracy,
                "learning_rate": scheduler.get_last_lr()[0],
                **{f"class_{class_idx}_accuracy": acc for class_idx, acc in class_accuracies.items()},
                "adj_matrix": wandb.Image(img_path),
            })

            print(
                f"Train Accuracy: {100 * train_accuracy:.2f}%, "
                f"Average Train Loss: {train_loss:.6f}, "
                f"Test Accuracy: {test_accuracy:.2f}%, "
                f"Average Test Loss: {test_loss:.6f}\n"
            )

        # Convert to NumPy arrays
        all_preds = np.asarray(all_preds)
        all_targets = np.asarray(all_targets)

        # Save final-epoch predictions & true labels for later use (confusion matrix)
        wandb.log({"all_preds": all_preds.tolist(), "all_targets": all_targets.tolist()})
    finally:
        # Close the run even when training fails, so the next seed (or a
        # re-run of this cell) starts from a clean wandb state.
        wandb.finish()

    # Persist both the pickled module and its state dict; the file names carry
    # the rounded-up final test accuracy and the seed.
    print(seed)
    torch.save(model, f"{model.__class__.__name__}_{math.ceil(final_acc)}_{seed}.pth")
    torch.save(model.state_dict(), f"{model.__class__.__name__}_{math.ceil(final_acc)}_{seed}_state.pth")


