In [6]:
import mne
import torch
import numpy as np
from torcheeg.models import LaBraM

# Load a sample EEG file
file_path = "data/301A_FG_preprocessed-epo.fif"  # Update with your actual file
epochs = mne.read_epochs(file_path, preload=True)

# Extract data and labels
eeg_data = epochs.get_data()  # Shape: (n_epochs, n_channels, n_times)
labels = epochs.events[:, -1]  # Assuming the last column contains event labels

# Normalize EEG data
eeg_data = (eeg_data - eeg_data.mean()) / eeg_data.std()

# Convert to PyTorch tensors
eeg_tensor = torch.tensor(eeg_data, dtype=torch.float32)
labels_tensor = torch.tensor(labels, dtype=torch.long)

# Extract the actual electrode names
electrode_names = [ch.upper() for ch in epochs.ch_names ] # List of channel names


In [7]:
electrode_names

['FP1',
 'AF7',
 'AF3',
 'F1',
 'F3',
 'F5',
 'F7',
 'FT7',
 'FC5',
 'FC3',
 'FC1',
 'C1',
 'C3',
 'C5',
 'T7',
 'TP7',
 'CP5',
 'CP3',
 'CP1',
 'P1',
 'P3',
 'P5',
 'P7',
 'P9',
 'PO7',
 'PO3',
 'O1',
 'IZ',
 'OZ',
 'POZ',
 'PZ',
 'CPZ',
 'FPZ',
 'FP2',
 'AF8',
 'AF4',
 'AFZ',
 'FZ',
 'F2',
 'F4',
 'F6',
 'F8',
 'FT8',
 'FC6',
 'FC4',
 'FC2',
 'FCZ',
 'CZ',
 'C2',
 'C4',
 'C6',
 'T8',
 'TP8',
 'CP6',
 'CP4',
 'CP2',
 'P2',
 'P4',
 'P6',
 'P8',
 'P10',
 'PO8',
 'PO4',
 'O2']

In [8]:
model = LaBraM(num_electrodes=len(electrode_names), electrodes=electrode_names)

# Load pre-trained weights (if available)
# model.load_state_dict(torch.load("path_to_pretrained_labram.pth"))

model.train()


LaBraM(
  (patch_embed): TemporalConv(
    (conv1): Conv2d(1, 8, kernel_size=(1, 15), stride=(1, 8), padding=(0, 7))
    (gelu1): GELU(approximate='none')
    (norm1): GroupNorm(4, 8, eps=1e-05, affine=True)
    (conv2): Conv2d(8, 8, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1))
    (gelu2): GELU(approximate='none')
    (norm2): GroupNorm(4, 8, eps=1e-05, affine=True)
    (conv3): Conv2d(8, 8, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1))
    (norm3): GroupNorm(4, 8, eps=1e-05, affine=True)
    (gelu3): GELU(approximate='none')
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): ModuleList(
    (0-11): 12 x Block(
      (norm1): LayerNorm((200,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=200, out_features=600, bias=False)
        (q_norm): LayerNorm((20,), eps=1e-06, elementwise_affine=True)
        (k_norm): LayerNorm((20,), eps=1e-06, elementwise_affine=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
    

In [9]:
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Create a DataLoader
dataset = TensorDataset(eeg_tensor, labels_tensor)
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    for batch_X, batch_y in train_loader:
        optimizer.zero_grad()
        outputs = model(batch_X, electrodes=electrode_names)  # Pass electrodes explicitly
        loss = criterion(outputs, batch_y)
        loss.backward()
        optimizer.step()

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")

ValueError: not enough values to unpack (expected 4, got 3)