In [1]:
import os
import sys
import torch
import wandb
import GPUtil
import torch.optim as optim
from torchinfo import summary
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from ESTFormer import ESTFormer

sys.path.append('../../')
from utils.epoch_data_reader import EpochDataReader

In [2]:
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # Force CUDA to use the GPU
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # Use first GPU
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" # Enable memory optimization settings for PyTorch

In [3]:
# Check if CUDA is available
try:
    gpus = GPUtil.getGPUs()
    if gpus:
        print(f"GPUtil detected {len(gpus)} GPUs:")
        for i, gpu in enumerate(gpus):
            print(f"  GPU {i}: {gpu.name} (Memory: {gpu.memoryTotal}MB)")
        
        # Set default GPU
        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(i) for i in range(len(gpus))])
        print(f"Set CUDA_VISIBLE_DEVICES={os.environ['CUDA_VISIBLE_DEVICES']}")
    else:
        print("GPUtil found no available GPUs")
except Exception as e:
    print(f"Error checking GPUs with GPUtil: {e}")

GPUtil detected 1 GPUs:
  GPU 0: NVIDIA GeForce RTX 3070 Laptop GPU (Memory: 8192.0MB)
Set CUDA_VISIBLE_DEVICES=0


In [4]:
# Check for CUDA availability
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Print available GPU memory
if torch.cuda.is_available():
    print(f"Total GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    print(f"Available GPU memory: {torch.cuda.memory_reserved(0) / 1e9:.2f} GB")

Using device: cuda
Total GPU memory: 8.59 GB
Available GPU memory: 0.00 GB


In [5]:
all_channels = ['Fp1', 'AF7', 'AF3', 'F1', 'F3', 'F5', 'F7', 'FT7', 'FC5', 'FC3', 'FC1', 'C1', 'C3', 'C5', 'T7', 'TP7', 'CP5', 'CP3', 'CP1', 'P1', 'P3', 'P5', 'P7', 'P9', 'PO7', 'PO3', 'O1', 'Iz', 'Oz', 'POz', 'Pz', 'CPz', 'Fpz', 'Fp2', 'AF8', 'AF4', 'AFz', 'Fz', 'F2', 'F4', 'F6', 'F8', 'FT8', 'FC6', 'FC4', 'FC2', 'FCz', 'Cz', 'C2', 'C4', 'C6', 'T8', 'TP8', 'CP6', 'CP4', 'CP2', 'P2', 'P4', 'P6', 'P8', 'P10', 'PO8', 'PO4', 'O2']

hr_channel_names = [
    'Fp1', 'AF7', 'AF3', 'F1', 'F3', 'F5', 'F7', 'FT7', 'FC5', 'FC3', 'FC1', 
    'C1', 'C3', 'C5', 'T7', 'TP7', 'CP5', 'CP3', 'CP1', 'P1', 'P3', 'P5', 'P7', 
    'P9', 'PO7', 'PO3', 'O1', 'Iz', 'Oz', 'POz', 'Pz', 'CPz', 'Fpz', 'Fp2', 
    'AF8', 'AF4', 'AFz', 'Fz', 'F2', 'F4', 'F6', 'F8', 'FT8', 'FC6', 'FC4', 
    'FC2', 'FCz', 'Cz', 'C2', 'C4', 'C6', 'T8', 'TP8', 'CP6', 'CP4', 'CP2', 
    'P2', 'P4', 'P6', 'P8', 'P10', 'PO8', 'PO4', 'O2'
]

# Select 32 channels for the downsampled (low-resolution) set
# This selection preserves the overall spatial coverage while reducing density
lr_channel_names = [
    'Fp1', 'AF3', 'F3', 'F7', 'FC3', 'C1', 'C5', 'T7', 
    'CP3', 'P1', 'P7', 'PO7', 'O1', 'Oz', 'Pz', 'CPz',
    'Fp2', 'AF4', 'F4', 'F8', 'FC4', 'C2', 'C6', 'T8',
    'CP4', 'P2', 'P8', 'PO8', 'O2', 'POz', 'Fz', 'Cz'
]

# Model parameters
# hr_channel_names = all_channels # High-resolution setup (all channels)
# lr_channel_names = ['AF3', 'F7', 'F3', 'FC5', 'T7', 'P7', 'O1', 'O2', 'P8', 'T8', 'FC6', 'F4', 'F8', 'AF4'] # Low-resolution setup (fewer channels)
builtin_montage = 'standard_1020'
alpha_t = 0.60
alpha_s = 0.75
r_mlp = 4 # amplification factor for MLP layers
dropout_rate = 0.5
L_s = 1  # Number of spatial layers
L_t = 1  # Number of temporal layers

# Training parameters
epochs = 1

# Optimizer parameters
lr = 5e-5
weight_decay = 0.5
beta_1 = 0.9
beta_2 = 0.95

# Dataset parameters
# split = "70/25/5"
# epoch_type = "around_evoked"
# before = 0.05
# after = 0.6
# random_state = 97

# Data Loader parameter
batch_size = 30

In [6]:
# Create datasets
lo_res_dataset = EpochDataReader(
    channel_names=lr_channel_names
)

hi_res_dataset = EpochDataReader(
    channel_names=hr_channel_names
)

Creating new group: cross/ground-truth/AF3-AF4-C1-C2-C5-C6-CP3-CP4-CPz-Cz-F3-F4-F7-F8-FC3-FC4-Fp1-Fp2-Fz-O1-O2-Oz-P1-P2-P7-P8-PO7-PO8-POz-Pz-T7-T8/512/around_evoked/0.65/70_25_5/97
Opening raw data file s:\PolySecLabProjects\eeg-image-decode\code\utils\..\..\data\all-joined-1\eeg\preprocessed\ground-truth\subj01_session1_eeg.fif...
    Range : 1121 ... 1777926 =      2.189 ...  3472.512 secs
Ready.
Opening raw data file s:\PolySecLabProjects\eeg-image-decode\code\utils\..\..\data\all-joined-1\eeg\preprocessed\ground-truth\subj01_session1_eeg.fif...
    Range : 1121 ... 1777926 =      2.189 ...  3472.512 secs
Ready.
3839 events found on stim channel Status
Event IDs: [  1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17  18
  19  20  21  22  23  24  25  26  27  28  29  30  31  32  33  34  35  36
  37  38  39  40  41  42  43  44  45  46  47  48  49  50  51  52  53  54
  55  56  57  58  59  60  61  62  63  64  65  66  67  68  69  70  71  72
  73  74  75  76  77  78  79  80 

KeyboardInterrupt: 

In [None]:
lo_res_loader = DataLoader(
    lo_res_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=True
)

hi_res_loader = DataLoader(
    hi_res_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=True
)

len(lo_res_loader), len(hi_res_loader)

(1533, 1533)

In [None]:
# Get sample data to determine time_steps
sample_item = lo_res_dataset[0][0] if lo_res_dataset.epoch_type == 'around_evoked' else lo_res_dataset[0]
time_steps = sample_item.shape[1]
sfreq = lo_res_dataset.resample_freq

config = {
    "total_epochs_trained_on": epochs,
    "scale_factor": len(hr_channel_names) / len(lr_channel_names),
    "time_steps_in_seconds": time_steps / sfreq,
    "is_parieto_occipital_exclusive": all(ch.startswith('P') or ch.startswith('O') or ch.startswith('PO') or ch.startswith('CP') for ch in lr_channel_names) and all(ch.startswith('P') or ch.startswith('O') or ch.startswith('PO') or ch.startswith('CP') for ch in hr_channel_names),
    "model_params": {
        "model": "ESTformer",
        "num_lr_channels": len(lr_channel_names),
        "num_hr_channels": len(hr_channel_names),
        "builtin_montage": builtin_montage,
        "alpha_s": alpha_s,
        "alpha_t": alpha_t,
        "r_mlp": r_mlp,
        "dropout_rate": dropout_rate,
        "L_s": L_s,
        "L_t": L_t,
    },
    "dataset_params": {
        "subject_session_id": lo_res_dataset.subject_session_id,
        "epoch_type": lo_res_dataset.epoch_type,
        "split": lo_res_dataset.split,
        "duration": str((lo_res_dataset.before + lo_res_dataset.after) * 1000) + 'ms' if lo_res_dataset.epoch_type == 'around_evoked' else lo_res_dataset.fixed_length_duration,
        "batch_size": batch_size,
        "random_state": lo_res_dataset.random_state
    },
    "optimizer_params": {
        "optimizer": "Adam",
        "learning_rate": lr,
        "weight_decay": weight_decay,
        "betas": (beta_1, beta_2)
    }
}

[34m[1mwandb[0m: Currently logged in as: [33mdubs2310[0m ([33mdubs2310-cal-poly-pomona[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin




In [None]:
model = ESTFormer(
    device=device, 
    lr_channel_names=lr_channel_names,
    hr_channel_names=hr_channel_names,
    builtin_montage=builtin_montage,
    time_steps=time_steps,
    alpha_t=alpha_t,
    alpha_s=alpha_s,
    r_mlp=r_mlp,
    dropout_rate=dropout_rate,
    L_s=L_s,
    L_t=L_t
)

summary(model)

  return t.to(


Layer (type:depth-idx)                                                                Param #
ESTFormer                                                                             204
├─SigmaParameters: 1-1                                                                2
├─SIM: 1-2                                                                            --
│    └─Linear: 2-1                                                                    69,156
│    └─LayerNorm: 2-2                                                                 408
│    └─CAB: 2-3                                                                       --
│    │    └─ModuleList: 3-1                                                           507,112
│    └─MaskTokensInsert: 2-4                                                          --
│    │    └─MaskTokenExpander: 3-2                                                    204
│    └─Linear: 2-5                                                                    41,820
│

In [None]:
# Create optimizer with both model and sigma parameters
optimizer = optim.Adam(
    params=[{'params': model.parameters()}], 
    lr=lr,
    weight_decay=weight_decay,
    betas=(beta_1, beta_2)
)

with wandb.init(project="eeg-estformer", config=config) as run:
    history = model.fit(
        epochs=epochs,
        lo_res_loader=lo_res_loader,
        hi_res_loader=hi_res_loader,
        optimizer=optimizer,
        checkpoint_dir='checkpoints',
        identifier='test'
    )


Epoch 1/1: 100%|██████████| 1073/1073 [01:27<00:00, 12.32it/s, loss=233.5678]
                                                                                     

Epoch 1/1, train_loss: 290.8506val_loss: 246.7821sigma1: 1.0507, sigma2: 0.9466
Saved best model checkpoint to checkpoints\estformer_test_best.pt


In [None]:
# average_test_results = model.predict(test_loader)
# print("Average Results on Test Set: ", average_test_results)

NameError: name 'test_loader' is not defined

In [None]:
# def monitor_sigma_values_and_loss(history):
#     """
#     Monitor the values of sigma1 and sigma2 during training.
    
#     Args:
#         history: Training history dictionary
#     """
#     # Get the values of sigma1 and sigma2
#     sigma1_values = history['sigma1']
#     sigma2_values = history['sigma2']
    
#     print(f"Final sigma1 value: {sigma1_values[-1]}")
#     print(f"Final sigma2 value: {sigma2_values[-1]}")
    
#     # Plot the loss history
#     plt.figure(figsize=(12, 8))
    
#     # Plot loss
#     plt.subplot(2, 2, 1)
#     plt.plot(history['train_loss'], label='Training Loss')
#     plt.plot(history['val_loss'], label='Validation Loss')
#     plt.title('Model Loss')
#     plt.xlabel('Epoch')
#     plt.ylabel('Loss')
#     plt.legend()
    
#     # Plot MAE
#     plt.subplot(2, 2, 2)
#     plt.plot(history['train_mae'], label='Training MAE')
#     plt.plot(history['val_mae'], label='Validation MAE')
#     plt.title('Model MAE')
#     plt.xlabel('Epoch')
#     plt.ylabel('MAE')
#     plt.legend()

#     # Plot NMSE
#     plt.subplot(2, 2, 3)
#     plt.plot(history['train_nmse'], label='Training NMSE')
#     plt.plot(history['val_nmse'], label='Validation NMSE')
#     plt.title('Model NMSE')
#     plt.xlabel('Epoch')
#     plt.ylabel('NMSE')
#     plt.legend()

#     # Plot SNR
#     plt.subplot(2, 2, 4)
#     plt.plot(history['train_snr'], label='Training SNR')
#     plt.plot(history['val_snr'], label='Validation SNR')
#     plt.title('Model SNR')
#     plt.xlabel('Epoch')
#     plt.ylabel('SNR')
#     plt.legend()
    
#     # Plot PCC
#     plt.subplot(2, 2, 5)
#     plt.plot(history['train_pcc'], label='Training PCC')
#     plt.plot(history['val_pcc'], label='Validation PCC')
#     plt.title('Model PCC')
#     plt.xlabel('Epoch')
#     plt.ylabel('PCC')
#     plt.legend()
    
#     # Plot sigma values
#     plt.subplot(2, 2, 3)
#     plt.plot(sigma1_values, label='Sigma1')
#     plt.title('Sigma1 Value')
#     plt.xlabel('Epoch')
#     plt.ylabel('Value')
    
#     plt.subplot(2, 2, 4)
#     plt.plot(sigma2_values, label='Sigma2')
#     plt.title('Sigma2 Value')
#     plt.xlabel('Epoch')
#     plt.ylabel('Value')
    
#     plt.tight_layout()
    
#     # Save figure to wandb
#     if wandb.run is not None:
#         wandb.log({"training_history": wandb.Image(plt)})
    
#     plt.show()

# monitor_sigma_values_and_loss(history)

In [None]:
# def visualize_results(model, val_dataset, device, subject_idx=0, channel_idx=0):
#     """
#     Visualize the results of the model on a validation sample.
    
#     Args:
#         model: Trained ESTformer model
#         val_dataset: Validation dataset
#         device: Device to run inference on
#         subject_idx: Index of the subject to visualize
#         channel_idx: Index of the channel to visualize
#     """
#     # Set model to eval mode
#     model.eval()
    
#     # Get a validation sample
#     sample = val_dataset[subject_idx]
    
#     # Convert to tensors and add batch dimension
#     lo_res = torch.tensor(sample['lo_res'], dtype=torch.float32).unsqueeze(0).to(device)
#     hi_res = torch.tensor(sample['hi_res'], dtype=torch.float32)
    
#     # Get predictions
#     with torch.no_grad():
#         pred = model(lo_res).cpu().numpy()[0]
    
#     # Convert back to numpy for visualization
#     lo_res = lo_res.cpu().numpy()[0]
#     hi_res = hi_res.numpy()
    
#     # Plot the results
#     plt.figure(figsize=(12, 8))
    
#     # Plot low-res input
#     plt.subplot(3, 1, 1)
#     plt.plot(lo_res[channel_idx])
#     plt.title(f'Low-Res (Downsampled) Input (Channel {channel_idx})')
    
#     # Plot high-res ground truth
#     plt.subplot(3, 1, 2)
#     plt.plot(hi_res[channel_idx])
#     plt.title(f'High-Res (Ground Truth) (Channel {channel_idx})')
    
#     # Plot prediction
#     plt.subplot(3, 1, 3)
#     plt.plot(pred[channel_idx])
#     plt.title(f'Super-Res (Prediction) (Channel {channel_idx})')
    
#     plt.tight_layout()
    
#     # Save figure to wandb
#     if wandb.run is not None:
#         wandb.log({"prediction_visualization": wandb.Image(plt)})
    
#     plt.show()

# visualize_results(model, val_loader.dataset, device)