In [2]:
import sys
import torch
import wandb
sys.path.append('../../')
import pytorch_lightning as pl
from utils.data_modules.super_resolution import EEGSuperResolutionDataModule
from models.trainers.super_resolution import SuperResolutionTrainerModel, SaveSuperResEpochsCallbackMinimal

/workspace/eeg-image-decoding/data/all-joined-1/eeg/epochs


In [3]:
# Datamodule for EEG channel super-resolution: 32 input electrodes -> 64 output
# electrodes, sfreq=250, windowed from 50 ms before to 600 ms after each event.
# NOTE(review): subject/session are literals here; the `subject`/`session` variables set
# in the next cell (used to name the saved super-resolution epochs) must be kept equal.
# test='All' routes every sample of this subject/session into the test loader (the logged
# split shows Test larger than Train + Val), so metrics from trainer.test() presumably
# include samples the checkpoint was trained on and are not a held-out estimate.
dm = EEGSuperResolutionDataModule(
    input_channels=[
        'Fp1', 'Fp2', 'AF3', 'AF4',
        'F7', 'F3', 'Fz', 'F4', 'F8',
        'FT7', 'FC3', 'FCz', 'FC4', 'FT8',
        'T7', 'C3', 'Cz', 'C4', 'T8',
        'TP7', 'CP3', 'CPz', 'CP4', 'TP8',
        'P7', 'P3', 'Pz', 'P4', 'P8',
        'O1', 'Oz', 'O2',
    ],
    output_channels=[
        'Fp1', 'AF7', 'AF3', 'F1', 'F3', 'F5', 'F7', 'FT7',
        'FC5', 'FC3', 'FC1', 'C1', 'C3', 'C5', 'T7', 'TP7',
        'CP5', 'CP3', 'CP1', 'P1', 'P3', 'P5', 'P7', 'P9',
        'PO7', 'PO3', 'O1', 'Iz', 'Oz', 'POz', 'Pz', 'CPz',
        'Fpz', 'Fp2', 'AF8', 'AF4', 'AFz', 'Fz', 'F2', 'F4',
        'F6', 'F8', 'FT8', 'FC6', 'FC4', 'FC2', 'FCz', 'Cz',
        'C2', 'C4', 'C6', 'T8', 'TP8', 'CP6', 'CP4', 'CP2',
        'P2', 'P4', 'P6', 'P8', 'P10', 'PO8', 'PO4', 'O2',
    ],
    sfreq=250,
    montage='standard_1020',
    window_before_event_ms=50,
    window_after_event_ms=600,
    subject=1,
    session=1,
    batch_size=64,
    num_workers=4,
    test='All',
)

open_clip_model.safetensors:   0%|          | 0.00/3.94G [00:00<?, ?B/s]

In [4]:
# Run configuration.
# NOTE(review): subject/session must equal the literals given to the datamodule above;
# they are reused later to name the saved super-resolution epochs file.
epochs = 100  # forwarded as pl.Trainer(max_epochs=...); does not affect trainer.test()
subject, session = 1, 1

# Channel names and timestep count as reported by the datamodule for a sample.
sample_data = dm.get_sample_info()
input_channels, output_channels = (
    sample_data[side]['channel_names'] for side in ('input', 'output')
)
timesteps = sample_data['input']['num_timesteps']

Using ALL data for subject=1, session=1 as TEST SET
Found 3839 total samples for test set
Data split - Train: 2879, Val: 320, Test: 3839
Creating Datasets...
Original dataframe size: 2879
Original dataframe size: 320
Original dataframe size: 3839


In [5]:
# Checkpoint of the trained super-resolution model for this recording.
# Built from `subject`/`session` (set in the config cell) instead of a hard-coded
# "subj1_session1", so changing the subject/session cannot silently evaluate another
# recording's weights while the outputs are labelled with the new subject/session.
# With subject=1, session=1 this resolves to .../super-resolution/subj1_session1_epoch=99.ckpt.
checkpoint_epoch = 99  # presumably the last epoch of the 100-epoch run (Lightning counts from 0) — confirm
checkpoint_path = (
    "/workspace/eeg-image-decoding/code/models/check_points/super-resolution/"
    f"subj{subject}_session{session}_epoch={checkpoint_epoch}.ckpt"
)

In [6]:
# Restore the trained super-resolution model from the checkpoint selected above.
lightning_model = SuperResolutionTrainerModel.load_from_checkpoint(checkpoint_path)

# Collects the model's test-time predictions and writes them to <save_dir>/<filename>.npy
# (see the "Saved ... super-resolution epochs to:" line in the output below).
# NOTE(review): the "650ms-250Hz" folder name appears to encode the 50 + 600 ms window and
# sfreq=250 given to the datamodule — keep it in sync if those settings change.
save_callback = SaveSuperResEpochsCallbackMinimal(
    save_dir="/workspace/eeg-image-decoding/data/all-joined-1/eeg/super-res-epochs/650ms-250Hz",
    filename=f"subj{subject}_session{session}_epochs",
)

# Single-device run: GPU when CUDA is visible, otherwise CPU.
# max_epochs only matters for fitting; trainer.test() performs one evaluation pass.
trainer = pl.Trainer(
    accelerator=('gpu' if torch.cuda.is_available() else 'cpu'),
    devices=1,
    callbacks=[save_callback],
    max_epochs=epochs,
)

# Caveat: the datamodule was built with test='All', so these metrics cover every sample of
# this subject/session (not only held-out data) — read them as a fit quality indicator.
test_results = trainer.test(lightning_model, dm)
test_results

💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/venv/main/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
You are using a CUDA device ('NVIDIA H100 NVL') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium

Datasets already initialized, skipping setup
Collecting super-resolution predictions to save as subj1_session1_epochs.npy...


Testing: |          | 0/? [00:00<?, ?it/s]


=== NMSE ANALYSIS ===
NMSE (variance): 0.527355
NMSE (mean square): 0.527331
Average MSE: 0.256146
Average Target Variance: 0.485718
Average Target Mean Square: 0.485741


=== SUPER RESOLUTION TEST RESULTS ===
Test Loss: 17.708557
Test SNR: 2.7417 dB
Test MAE: 0.357683
Test MSE: 0.256158
Test NMSE: 0.527355
Test Pearson Correlation: 0.7655

Saved 3839 super-resolution epochs to: /workspace/eeg-image-decoding/data/all-joined-1/eeg/super-res-epochs/650ms-250Hz/subj1_session1_epochs.npy
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_loss            17.70855712890625
        test_mae            0.3576827943325043
        test_mse            0.2561577260494232
        test_nmse           0.5273549556732178
      test_pearson          0.765539

[{'test_loss': 17.70855712890625,
  'test_snr': 2.7416789531707764,
  'test_mae': 0.3576827943325043,
  'test_mse': 0.2561577260494232,
  'test_nmse': 0.5273549556732178,
  'test_pearson': 0.7655397653579712}]