In [None]:
import os
import sys
import torch
import wandb
import GPUtil
import torch.optim as optim
from models.estformer.ESTFormer import ESTFormer
from torchinfo import summary
from torch.utils.data import DataLoader

sys.path.append('../../')
from utils.hdf5_data_split_generator import HDF5DataSplitGenerator

In [3]:
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # Force CUDA to use the GPU
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # Use first GPU
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" # Enable memory optimization settings for PyTorch

In [4]:
# Check if CUDA is available
try:
    gpus = GPUtil.getGPUs()
    if gpus:
        print(f"GPUtil detected {len(gpus)} GPUs:")
        for i, gpu in enumerate(gpus):
            print(f"  GPU {i}: {gpu.name} (Memory: {gpu.memoryTotal}MB)")
        
        # Set default GPU
        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(i) for i in range(len(gpus))])
        print(f"Set CUDA_VISIBLE_DEVICES={os.environ['CUDA_VISIBLE_DEVICES']}")
    else:
        print("GPUtil found no available GPUs")
except Exception as e:
    print(f"Error checking GPUs with GPUtil: {e}")

GPUtil detected 1 GPUs:
  GPU 0: NVIDIA GeForce RTX 3070 Laptop GPU (Memory: 8192.0MB)
Set CUDA_VISIBLE_DEVICES=0


In [5]:
# Check for CUDA availability
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Print available GPU memory
if torch.cuda.is_available():
    print(f"Total GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    print(f"Available GPU memory: {torch.cuda.memory_reserved(0) / 1e9:.2f} GB")

Using device: cuda
Total GPU memory: 8.59 GB
Available GPU memory: 0.00 GB


In [6]:
all_channels = ['Fp1', 'AF7', 'AF3', 'F1', 'F3', 'F5', 'F7', 'FT7', 'FC5', 'FC3', 'FC1', 'C1', 'C3', 'C5', 'T7', 'TP7', 'CP5', 'CP3', 'CP1', 'P1', 'P3', 'P5', 'P7', 'P9', 'PO7', 'PO3', 'O1', 'Iz', 'Oz', 'POz', 'Pz', 'CPz', 'Fpz', 'Fp2', 'AF8', 'AF4', 'AFz', 'Fz', 'F2', 'F4', 'F6', 'F8', 'FT8', 'FC6', 'FC4', 'FC2', 'FCz', 'Cz', 'C2', 'C4', 'C6', 'T8', 'TP8', 'CP6', 'CP4', 'CP2', 'P2', 'P4', 'P6', 'P8', 'P10', 'PO8', 'PO4', 'O2']

# Model parameters
hr_channel_names = all_channels # High-resolution setup (all channels)
lr_channel_names = ['AF3', 'F7', 'F3', 'FC5', 'T7', 'P7', 'O1', 'O2', 'P8', 'T8', 'FC6', 'F4', 'F8', 'AF4'] # Low-resolution setup (fewer channels)
builtin_montage = 'standard_1020'
alpha_t = 0.60
alpha_s = 0.75
r_mlp = 4 # amplification factor for MLP layers
dropout_rate = 0.5
L_s = 1  # Number of spatial layers
L_t = 1  # Number of temporal layers

# Training parameters
epochs = 30

# Optimizer parameters
lr = 5e-5
weight_decay = 0.5
beta_1 = 0.9
beta_2 = 0.95

# Dataset parameters
batch_size = 30
dataset_split = "70/25/5"
eeg_epoch_mode = "fixed_length"
fixed_length_duration = 6
duration_before_onset = 0.05
duration_after_onset = 0.6
random_state = 97

In [7]:
# Create datasets
train_dataset = HDF5DataSplitGenerator(
    dataset_type="train",
    dataset_split=dataset_split,
    eeg_epoch_mode=eeg_epoch_mode,
    random_state=random_state,
    fixed_length_duration=fixed_length_duration,
    duration_before_onset=duration_before_onset,
    duration_after_onset=duration_after_onset,
    lr_channel_names=lr_channel_names,
    hr_channel_names=hr_channel_names
)

val_dataset = HDF5DataSplitGenerator(
    dataset_type="val",
    dataset_split=dataset_split,
    eeg_epoch_mode=eeg_epoch_mode,
    random_state=random_state,
    fixed_length_duration=fixed_length_duration,
    duration_before_onset=duration_before_onset,
    duration_after_onset=duration_after_onset,
    lr_channel_names=lr_channel_names,
    hr_channel_names=hr_channel_names
)

test_dataset = HDF5DataSplitGenerator(
    dataset_type="test",
    dataset_split=dataset_split,
    eeg_epoch_mode=eeg_epoch_mode,
    random_state=random_state,
    fixed_length_duration=fixed_length_duration,
    duration_before_onset=duration_before_onset,
    duration_after_onset=duration_after_onset,
    lr_channel_names=lr_channel_names,
    hr_channel_names=hr_channel_names
)

In [8]:
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
    pin_memory=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
    pin_memory=True
)

test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
    pin_memory=True
)

len(train_loader), len(val_loader), len(test_loader)

(162, 58, 12)

In [9]:
# Get sample data to determine time_steps
sample_item = train_loader.dataset[0]
time_steps = sample_item["lo_res"].shape[1]
sfreq = sample_item["sfreq"]

config = {
    "total_epochs_trained_on": epochs,
    "subject": "all",
    "scale_factor": len(hr_channel_names) / len(lr_channel_names),
    "time_steps_in_seconds": time_steps / sfreq,
    "is_parieto_occipital_exclusive": all(ch.startswith(('P', 'O')) for ch in lr_channel_names) and all(ch.startswith(('P', 'O')) for ch in hr_channel_names),
    "model_params": {
        "model": "ESTformer",
        "num_lr_channels": len(lr_channel_names),
        "num_hr_channels": len(hr_channel_names),
        "builtin_montage": builtin_montage,
        "alpha_s": alpha_s,
        "alpha_t": alpha_t,
        "r_mlp": r_mlp,
        "dropout_rate": dropout_rate,
        "L_s": L_s,
        "L_t": L_t,
    },
    "dataset_params": {
        "eeg_epoch_mode": eeg_epoch_mode,
        "dataset_split": dataset_split,
        "fixed_length_duration": fixed_length_duration,
        "duration_before_onset": duration_before_onset,
        "duration_after_onset": duration_after_onset,
        "batch_size": batch_size,
        "random_state": random_state
    },
    "optimizer_params": {
        "optimizer": "Adam",
        "learning_rate": lr,
        "weight_decay": weight_decay,
        "betas": (beta_1, beta_2)
    }
}

wandb.init(project="eeg-estformer", config=config)

[34m[1mwandb[0m: Currently logged in as: [33mdubs2310[0m ([33mdubs2310-cal-poly-pomona[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
model = ESTFormer(
    device=device, 
    lr_channel_names=lr_channel_names,
    hr_channel_names=hr_channel_names,
    builtin_montage=builtin_montage,
    time_steps=time_steps,
    alpha_t=alpha_t,
    alpha_s=alpha_s,
    r_mlp=r_mlp,
    dropout_rate=dropout_rate,
    L_s=L_s,
    L_t=L_t
)

summary(model)

  return t.to(


Layer (type:depth-idx)                                                                Param #
ESTFormer                                                                             1,845
├─SigmaParameters: 1-1                                                                2
├─SIM: 1-2                                                                            --
│    └─Linear: 2-1                                                                    5,669,685
│    └─LayerNorm: 2-2                                                                 3,690
│    └─CAB: 2-3                                                                       --
│    │    └─ModuleList: 3-1                                                           40,877,353
│    └─MaskTokensInsert: 2-4                                                          --
│    │    └─MaskTokenExpander: 3-2                                                    1,845
│    └─Linear: 2-5                                                                

In [11]:
# Create optimizer with both model and sigma parameters
optimizer = optim.Adam(
    params=[{'params': model.parameters()}], 
    lr=lr,
    weight_decay=weight_decay,
    betas=(beta_1, beta_2)
)

history = model.fit(
    epochs=epochs,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    checkpoint_dir='checkpoints'
)

Epoch 1/30:   0%|          | 0/162 [00:28<?, ?it/s]


KeyboardInterrupt: 

In [None]:
monitor_sigma_values_and_loss(history)
visualize_results(model, val_loader.dataset, device)