In [1]:
import os
import sys
import wandb
import torch
import GPUtil
from EEGNet import EEGNet
from torchinfo import summary
from torch.utils.data import DataLoader

sys.path.append('../../')
from utils.coco_data_handler import COCODataHandler
from utils.epoch_data_reader import EpochDataReader

In [2]:
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # Force CUDA to use the GPU
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # Use first GPU
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" # Enable memory optimization settings for PyTorch

In [3]:
# Check if CUDA is available
try:
    gpus = GPUtil.getGPUs()
    if gpus:
        print(f"GPUtil detected {len(gpus)} GPUs:")
        for i, gpu in enumerate(gpus):
            print(f"  GPU {i}: {gpu.name} (Memory: {gpu.memoryTotal}MB)")
        
        # Set default GPU
        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(i) for i in range(len(gpus))])
        print(f"Set CUDA_VISIBLE_DEVICES={os.environ['CUDA_VISIBLE_DEVICES']}")
    else:
        print("GPUtil found no available GPUs")
except Exception as e:
    print(f"Error checking GPUs with GPUtil: {e}")

GPUtil found no available GPUs


In [4]:
# Check for CUDA availability
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Print available GPU memory
if torch.cuda.is_available():
    print(f"Total GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    print(f"Available GPU memory: {torch.cuda.memory_reserved(0) / 1e9:.2f} GB")

Using device: cpu


In [5]:
# if torch.backends.mps.is_available():
#     device = torch.device("mps")
# else:
# Use CPU on Mac, there is a know bug with PyTorch
device = torch.device("cpu")

In [6]:
coco_data = COCODataHandler.get_instance()
coco_data.category_index

{'accessory': 0,
 'animal': 1,
 'appliance': 2,
 'electronic': 3,
 'food': 4,
 'furniture': 5,
 'indoor': 6,
 'kitchen': 7,
 'outdoor': 8,
 'person': 9,
 'sports': 10,
 'vehicle': 11}

In [7]:
channels = ['Fp1', 'AF7', 'AF3', 'F1', 'F3', 'F5', 'F7', 'FT7', 'FC5', 'FC3', 'FC1', 'C1', 'C3', 'C5', 'T7', 'TP7', 'CP5', 'CP3', 'CP1', 'P1', 'P3', 'P5', 'P7', 'P9', 'PO7', 'PO3', 'O1', 'Iz', 'Oz', 'POz', 'Pz', 'CPz', 'Fpz', 'Fp2', 'AF8', 'AF4', 'AFz', 'Fz', 'F2', 'F4', 'F6', 'F8', 'FT8', 'FC6', 'FC4', 'FC2', 'FCz', 'Cz', 'C2', 'C4', 'C6', 'T8', 'TP8', 'CP6', 'CP4', 'CP2', 'P2', 'P4', 'P6', 'P8', 'P10', 'PO8', 'PO4', 'O2']

num_classes = len(coco_data.category_index.keys())
model_type = "low-resolution (downsampled)" # "super-resolution (upsampled)" | "high-resolution (ground-truth)"

# Training parameters
epochs = 100

# Optimizer parameters
lr = 5e-5
# weight_decay = 0.5
# beta_1 = 0.9
# beta_2 = 0.95

batch_size = 64

# Dataset parameters
split = "70/25/5"
epoch_type = "around_evoked_event"
before = 0.05
after = 0.6
random_state = 97

In [8]:
dataset = EpochDataReader(
    channel_names=channels
)

In [9]:
loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=False,
    num_workers=4,
    pin_memory=True
)

len(loader)

1437

In [10]:
sample_item = dataset[0][0]
num_channels = sample_item.shape[0]
time_steps = sample_item.shape[1]
sfreq = dataset.resample_freq

config = {
    "total_epochs_trained_on": epochs,
    "model_type": model_type,
    "time_steps_in_seconds": time_steps / sfreq,
    "model_params": {
        "model": "EEGNet",
        "num_channels": num_channels,
        "num_classes": num_classes,
        "time_steps": time_steps,
        "builtin_montage": "standard_1020",
    },
    "dataset_params": {
        "subject_session_id": dataset.subject_session_id,
        "epoch_type": dataset.epoch_type,
        "split": dataset.split,
        "duration": str((dataset.before + dataset.after) * 1000) + 'ms' if dataset.epoch_type == 'around_evoked' else dataset.fixed_length_duration,
        "batch_size": batch_size,
        "random_state": dataset.random_state
    },
    "optimizer_params": {
        "optimizer": "Adam",
        "learning_rate": lr,
        # "weight_decay": weight_decay,
        # "betas": (beta_1, beta_2)
    }
}

In [11]:
model = EEGNet(device, num_channels, time_steps, num_classes)
summary(model)

Layer (type:depth-idx)                   Param #
EEGNet                                   --
├─Ensure4d: 1-1                          --
├─Expression: 1-2                        --
├─Conv2d: 1-3                            512
├─BatchNorm2d: 1-4                       16
├─Conv2dWithConstraint: 1-5              1,024
├─BatchNorm2d: 1-6                       32
├─Expression: 1-7                        --
├─AvgPool2d: 1-8                         --
├─Dropout: 1-9                           --
├─Conv2d: 1-10                           256
├─Conv2d: 1-11                           256
├─BatchNorm2d: 1-12                      32
├─Expression: 1-13                       --
├─AvgPool2d: 1-14                        --
├─Dropout: 1-15                          --
├─Flatten: 1-16                          --
├─Linear: 1-17                           82,432
├─Linear: 1-18                           6,156
Total params: 90,716
Trainable params: 90,716
Non-trainable params: 0

In [12]:
optimizer = torch.optim.Adam(
    params=[{'params': model.parameters()}], 
    lr=lr
)

with wandb.init(project="eeg-eegnet", config=config) as run:
    history = model.fit(loader, 1, optimizer, 'cpoints', 'classification', use_checkpoint=True)

[34m[1mwandb[0m: Currently logged in as: [33mdubs2310[0m ([33mdubs2310-cal-poly-pomona[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Loading checkpoint from cpoints/eegnet_classification_best.pt


In [None]:
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix
loader.dataset.set_split_type('all')

all_preds = []
all_targets = []

for batch in loader:
    epochs = batch[0]
    one_hot_encoding = batch[1]
    y_pred = model.predict(epochs)

    all_preds.append(y_pred)
    all_targets.append(one_hot_encoding)

all_preds = np.concatenate(all_preds, axis=0)  
all_targets = np.concatenate(all_targets, axis=0)

In [26]:
labels = list(coco_data.category_index.keys())
print(classification_report(all_targets, all_preds, target_names=labels))

              precision    recall  f1-score   support

   accessory       0.00      0.00      0.00      5599
      animal       0.00      0.00      0.00     10825
   appliance       0.00      0.00      0.00      3160
  electronic       0.00      0.00      0.00      3878
        food       0.00      0.00      0.00      6086
   furniture       0.00      0.00      0.00      9770
      indoor       0.00      0.00      0.00      5651
     kitchen       0.00      0.00      0.00      6558
     outdoor       0.00      0.00      0.00      4358
      person       0.51      0.84      0.64     23413
      sports       0.00      0.00      0.00      9479
     vehicle       0.00      0.00      0.00     10721

   micro avg       0.51      0.20      0.29     99498
   macro avg       0.04      0.07      0.05     99498
weighted avg       0.12      0.20      0.15     99498
 samples avg       0.43      0.18      0.25     99498

