In [1]:
import torch
import mne
import numpy as np
from torcheeg.models import LaBraM
import glob
import numpy as np
import pandas as pd


# Load metadata DataFrame
df_info = pd.read_pickle("data/FG_overview_df_v2.pkl")  # Update with actual path

# Extract unique experiment IDs
valid_experiments = set(df_info["Exp_id"].unique())

file_paths = glob.glob("data/*_FG_preprocessed-epo.fif")  # Update with actual data path
print(f"Found {len(file_paths)} EEG files.")
all_eeg_data, all_labels = [], []

for file_path in file_paths:
    # Extract filename (e.g., "301A")
    file_name = file_path.split("/")[-1].split("_")[0]  # Extract "301A" from "301A_FG_preprocessed-epo.fif"
    print(f"Processing {file_name}...")
    # Extract Experiment ID (e.g., "301")
    exp_id = file_name[:4]
    
    # Extract Subject ID (e.g., "301A")
    subject_id = file_name  # Full subject identifier

    # Check if this experiment is valid
    if exp_id not in valid_experiments:
        print(f"Skipping {file_name}: Experiment {exp_id} not in metadata.")
        continue  # Ignore invalid experiments

    # Check if this specific subject is a part of the experiment
    if subject_id not in df_info["Exp_id"].values:
        print(f"Skipping {file_name}: Subject {subject_id} was not part of {exp_id}.")
        continue

    # Load EEG file
    epochs = mne.read_epochs(file_path, preload=True)
    eeg_data = epochs.get_data()  # (n_epochs, n_channels, n_times)
    labels = epochs.events[:, -1]

    # Normalize per file
    eeg_data = (eeg_data - eeg_data.mean()) / eeg_data.std()

    # Convert labels to binary (feedback vs. no feedback)
    binary_labels = np.array([1 if label in {301, 303, 305, 307, 309} else 0 for label in labels])

    all_eeg_data.append(eeg_data)
    all_labels.append(binary_labels)
    print("eeg_data_shape = ", eeg_data.shape)

# Convert to PyTorch tensors
eeg_tensor = torch.tensor(np.concatenate(all_eeg_data, axis=0), dtype=torch.float32)
labels_tensor = torch.tensor(np.concatenate(all_labels, axis=0), dtype=torch.long)

print(f"Final EEG Tensor Shape: {eeg_tensor.shape}")  


KeyboardInterrupt: 

In [1]:
import torch
import mne
import numpy as np
from torcheeg.models import LaBraM
import glob
import pandas as pd

# Load metadata DataFrame
df_info = pd.read_pickle("data/FG_overview_df_v2.pkl")  # Update with actual path

# Define event IDs
event_labels = {'T1P': 301, 'T1Pn': 302, 'T3P': 303, 'T3Pn': 304,
                'T12P': 305, 'T12Pn': 306, 'T13P': 307, 'T13Pn': 308,
                'T23P': 309, 'T23Pn': 310}

file_paths = glob.glob("data/*_FG_preprocessed-epo.fif")  # Update with actual data path
print(f"Found {len(file_paths)} EEG files.")

all_eeg_data, all_labels = [], []

# Create a mapping: Exp_id → (Set of valid subject IDs)
experiment_participants = df_info.groupby("Exp_id")["Subject_id"].apply(set).to_dict()

for file_path in file_paths:
    # Extract filename (e.g., "301A")
    file_name = file_path.split("/")[-1].split("_")[0]  # Extract "301A"

    # Extract Experiment ID (e.g., "301")
    exp_id = file_name[:3]

    # Extract Subject ID (e.g., "301A" → 1049 from metadata)
    subject_entry = df_info[df_info["Exp_id"] == file_name]  # Match full "301A"
    if subject_entry.empty:
        print(f"Skipping {file_name}: Subject {file_name} not found in metadata.")
        continue

    subject_id = subject_entry["Subject_id"].values[0]

    # Check if subject is part of this experiment
    if subject_id not in experiment_participants.get(exp_id, set()):
        print(f"Skipping {file_name}: Subject {subject_id} was not part of {exp_id}.")
        continue

    # Get only the event types this subject participated in within THIS experiment
    subject_events = df_info[(df_info["Subject_id"] == subject_id) & (df_info["Exp_id"] == exp_id)]
    
    # Extract the relevant event labels for this subject
    valid_events = {event_labels[e] for e in event_labels if e in subject_events["Class_friends"].values}

    # Load EEG file
    epochs = mne.read_epochs(file_path, preload=True)
    eeg_data = epochs.get_data()  # Shape: (n_epochs, n_channels, n_times)
    labels = epochs.events[:, -1]  # Extract event labels

    # Filter out trials that are not relevant for this subject
    valid_trials = [i for i, label in enumerate(labels) if label in valid_events]

    if len(valid_trials) == 0:
        print(f"Skipping {file_name}: No relevant trials for subject {subject_id}.")
        continue  # Skip files where all trials are invalid

    # Keep only the valid trials
    eeg_data = eeg_data[valid_trials]  
    labels = labels[valid_trials]  

    # Normalize per file
    eeg_data = (eeg_data - eeg_data.mean()) / eeg_data.std()

    # Convert labels to binary classification (feedback vs. no feedback)
    binary_labels = np.array([1 if label in {301, 303, 305, 307, 309} else 0 for label in labels])

    all_eeg_data.append(eeg_data)
    all_labels.append(binary_labels)

# Convert to PyTorch tensors
eeg_tensor = torch.tensor(np.concatenate(all_eeg_data, axis=0), dtype=torch.float32)
labels_tensor = torch.tensor(np.concatenate(all_labels, axis=0), dtype=torch.long)

print(f"Final EEG Tensor Shape: {eeg_tensor.shape}")  


Found 15 EEG files.
Skipping 303A: Subject 1045 was not part of 303.
Skipping 301C: Subject 1028 was not part of 301.
Skipping 305C: Subject 2034 was not part of 305.
Skipping 302B: Subject 1024 was not part of 302.
Skipping 302A: Subject 1064 was not part of 302.
Skipping 303B: Subject 1041 was not part of 303.
Skipping 304C: Subject 1046 was not part of 304.
Skipping 302C: Subject 1036 was not part of 302.
Skipping 305B: Subject 2012 was not part of 305.
Skipping 301B: Subject 1029 was not part of 301.
Skipping 304A: Subject 1034 was not part of 304.
Skipping 304B: Subject 1024 was not part of 304.
Skipping 303C: Subject 1009 was not part of 303.
Skipping 305A: Subject 2041 was not part of 305.
Skipping 301A: Subject 1049 was not part of 301.


ValueError: need at least one array to concatenate

In [None]:
df_info["Exp_id"].values

array(['301A', '301B', '301C', '302A', '302B', '302C', '303A', '303B',
       '303C', '304A', '304B', '304C', '305A', '305B', '305C', '306A',
       '306B', '306C', '307A', '307B', '307C', '308A', '308B', '308C',
       '309A', '309B', '309C', '310A', '310B', '310C', '311A', '311B',
       '311C', '312A', '312B', '312C', '313A', '313B', '313C', '314A',
       '314B', '314C', '315A', '315B', '315C', '316A', '316B', '316C',
       '317A', '317B', '317C', '318A', '318B', '318C', '319A', '319B',
       '320A', '320B', '320C', '321A', '321B', '321C', '322A', '322B',
       '322C', '323A', '323B', '323C', '324A', '324B', '324C', '325A',
       '325B', '325C', '326A', '326B', '326C', '327A', '327B', '327C',
       '328A', '328B', '328C', '329A', '329B', '329C', '330A', '330B',
       '330C', '331A', '331B', '331C'], dtype=object)

In [None]:
4443 / 300


14.81

In [None]:
print(f"EEG Data Shape: {eeg_tensor.shape}")  
print(f"Number of Unique Labels: {len(set(labels_tensor.numpy()))}")  
print(f"Total Files Processed: {len(file_path)}")

EEG Data Shape: torch.Size([300, 64, 3000])
Number of Unique Labels: 2
Total Files Processed: 33


In [None]:
epochs

Unnamed: 0,General,General.1
,Filename(s),301A_FG_preprocessed-epo.fif
,MNE object type,EpochsFIF
,Measurement date,2023-10-11 at 14:54:47 UTC
,Participant,
,Experimenter,Unknown
,Acquisition,Acquisition
,Total number of events,300
,Events counts,T12P: 30  T12Pn: 30  T13P: 30  T13Pn: 30  T1P: 30  T1Pn: 30  T23P: 30  T23Pn: 30  T3P: 30  T3Pn: 30
,Time range,-0.500 – 5.498 s
,Baseline,off


In [None]:
import torch
import mne
import numpy as np
from torcheeg.models import LaBraM

# Load EEG data
file_path = "data/301A_FG_preprocessed-epo.fif"  # Update with your actual file
epochs = mne.read_epochs(file_path, preload=True)

# Extract EEG data and labels
eeg_data = epochs.get_data()  # Shape: (n_epochs, n_channels, n_times)
labels = epochs.events[:, -1]  # Assuming the last column contains event labels

# Normalize EEG data
eeg_data = (eeg_data - eeg_data.mean()) / eeg_data.std()

# Convert to PyTorch tensors
eeg_tensor = torch.tensor(eeg_data, dtype=torch.float32)

no_feedback_labels = {301, 303, 305, 307, 309}  # Feedback events
binary_labels = np.array([0 if label in no_feedback_labels else 1 for label in labels])

labels_tensor = torch.tensor(binary_labels, dtype=torch.long)

# Extract actual electrode names
electrode_names = [ch.upper() for ch in epochs.ch_names]  # Convert to uppercase

PATCH_SIZE = 500  

# Compute number of patches
num_time_steps = eeg_tensor.shape[2]
num_patches = num_time_steps // PATCH_SIZE  # Truncate remainder

# Reshape data into patches
eeg_tensor = eeg_tensor[:, :, :num_patches * PATCH_SIZE]  # Trim to fit patches
eeg_tensor = eeg_tensor.view(eeg_tensor.shape[0], eeg_tensor.shape[1], num_patches, PATCH_SIZE)

print(f"New EEG Tensor Shape: {eeg_tensor.shape}")  # Should be (batch_size, num_channels, num_patches, patch_size)

New EEG Tensor Shape: torch.Size([300, 64, 6, 500])


In [None]:
eeg_tensor.shape

torch.Size([300, 64, 6, 500])

In [None]:
epochs

Unnamed: 0,General,General.1
,Filename(s),301A_FG_preprocessed-epo.fif
,MNE object type,EpochsFIF
,Measurement date,2023-10-11 at 14:54:47 UTC
,Participant,
,Experimenter,Unknown
,Acquisition,Acquisition
,Total number of events,300
,Events counts,T12P: 30  T12Pn: 30  T13P: 30  T13Pn: 30  T1P: 30  T1Pn: 30  T23P: 30  T23Pn: 30  T3P: 30  T3Pn: 30
,Time range,-0.500 – 5.498 s
,Baseline,off


In [None]:
print(epochs.event_id)


{'T1P': 301, 'T1Pn': 302, 'T3P': 303, 'T3Pn': 304, 'T12P': 305, 'T12Pn': 306, 'T13P': 307, 'T13Pn': 308, 'T23P': 309, 'T23Pn': 310}


In [None]:
model = LaBraM(num_electrodes=len(electrode_names), electrodes=electrode_names)

# Load pre-trained weights (if available)
# model.load_state_dict(torch.load("path_to_pretrained_labram.pth"))

model.train()


LaBraM(
  (patch_embed): TemporalConv(
    (conv1): Conv2d(1, 8, kernel_size=(1, 15), stride=(1, 8), padding=(0, 7))
    (gelu1): GELU(approximate='none')
    (norm1): GroupNorm(4, 8, eps=1e-05, affine=True)
    (conv2): Conv2d(8, 8, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1))
    (gelu2): GELU(approximate='none')
    (norm2): GroupNorm(4, 8, eps=1e-05, affine=True)
    (conv3): Conv2d(8, 8, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1))
    (norm3): GroupNorm(4, 8, eps=1e-05, affine=True)
    (gelu3): GELU(approximate='none')
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): ModuleList(
    (0-11): 12 x Block(
      (norm1): LayerNorm((200,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=200, out_features=600, bias=False)
        (q_norm): LayerNorm((20,), eps=1e-06, elementwise_affine=True)
        (k_norm): LayerNorm((20,), eps=1e-06, elementwise_affine=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
    

In [None]:
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Create DataLoader
dataset = TensorDataset(eeg_tensor, labels_tensor)
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    for batch_X, batch_y in train_loader:
        optimizer.zero_grad()
        outputs = model(batch_X, electrodes=electrode_names)  # Pass electrodes explicitly
        loss = criterion(outputs, batch_y)
        loss.backward()
        optimizer.step()

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")

Epoch [1/10], Loss: 4.2277
Epoch [2/10], Loss: 2.5846
Epoch [3/10], Loss: 2.3767


KeyboardInterrupt: 