In [1]:
import sys
import wandb
import torch
sys.path.append('../../')
import pytorch_lightning as pl
from utils.data_modules.contrastive import EEGContrastiveDataModule
from models.trainers.contrastive import ContrastiveTrainerModel, PlottingCallback

/workspace/eeg-image-decoding/data/all-joined-1/eeg/epochs


In [6]:
# EEG electrode names handed to the data module as its input channels
# (64 names, kept in the original order).
eeg_channel_names = [
    'Fp1', 'AF7', 'AF3', 'F1', 'F3', 'F5', 'F7', 'FT7',
    'FC5', 'FC3', 'FC1', 'C1', 'C3', 'C5', 'T7', 'TP7',
    'CP5', 'CP3', 'CP1', 'P1', 'P3', 'P5', 'P7', 'P9',
    'PO7', 'PO3', 'O1', 'Iz', 'Oz', 'POz', 'Pz', 'CPz',
    'Fpz', 'Fp2', 'AF8', 'AF4', 'AFz', 'Fz', 'F2', 'F4',
    'F6', 'F8', 'FT8', 'FC6', 'FC4', 'FC2', 'FCz', 'Cz',
    'C2', 'C4', 'C6', 'T8', 'TP8', 'CP6', 'CP4', 'CP2',
    'P2', 'P4', 'P6', 'P8', 'P10', 'PO8', 'PO4', 'O2',
]

# Data module for a single subject/session of the "super-res" epochs.
# NOTE(review): test='All' presumably routes every sample of this
# subject/session into the test split (the run log prints
# "Using ALL data ... as TEST SET") — confirm against the data module.
dm = EEGContrastiveDataModule(
    input_channels=eeg_channel_names,
    sfreq=250,
    montage='standard_1020',
    window_before_event_ms=50,
    window_after_event_ms=600,
    subject=1,
    session=1,
    batch_size=64,
    num_workers=4,
    epochs_dir="/workspace/eeg-image-decoding/data/all-joined-1/eeg/super-res-epochs",
    test='All',
)

In [7]:
# Run-level constants.
epochs = 200
subject = 1
session = 1

# Ask the data module for a description of one sample and pull out the
# dimensions needed downstream.
sample_data = dm.get_sample_info()
input_info = sample_data['input']
output_info = sample_data['output']

num_channels = input_info['num_channels']
timesteps = input_info['num_timesteps']
# NOTE(review): the key is named 'fine_labels_shape', so this may be a
# shape rather than a scalar count — confirm before using it as a number.
num_fine_labels = output_info['fine_labels_shape']

Using ALL data for subject=1, session=1 as TEST SET
Found 3839 total samples for test set
Data split - Train: 2879, Val: 320, Test: 3839
Creating Datasets...
Original dataframe size: 2879
Original dataframe size: 320
Original dataframe size: 3839


In [8]:
# Checkpoint of the contrastive encoder to evaluate.
# NOTE(review): the filename suggests subject 1 / session 1 at epoch index 199;
# confirm it matches the subject/session configured on the data module.
checkpoint_path = "/workspace/eeg-image-decoding/code/models/check_points/contrastive_encoder/subj1_session1_epoch=199.ckpt"

In [9]:
# Restore the trained contrastive model from disk.
# strict=False relaxes state-dict matching while loading. The run log shows
# `train_image_features` and `test_image_features` being skipped because
# their shapes differ from the checkpoint, so verify that those tensors are
# populated correctly for this dataset before trusting the reported metrics.
lightning_model = ContrastiveTrainerModel.load_from_checkpoint(checkpoint_path, strict=False)

# Callback imported from the trainers module; given its name it presumably
# produces the result plots/reports after testing — confirm in its source.
results_callback = PlottingCallback()

logger = pl.loggers.WandbLogger(project="super_res_contrastive_eeg_net")

# Only `trainer.test` is invoked below, so `max_epochs` is presumably
# irrelevant here; it is kept to mirror the training configuration.
trainer = pl.Trainer(
    max_epochs=epochs,
    callbacks=[results_callback],
    logger=logger,
    accelerator='gpu' if torch.cuda.is_available() else 'cpu',
    devices=1
)

# Always close the W&B run, even when testing fails: otherwise the run stays
# open in the live kernel. On failure the run is closed with a non-zero exit
# code and the original exception is re-raised.
try:
    trainer.test(lightning_model, datamodule=dm)
except BaseException:
    wandb.finish(exit_code=1)
    raise
else:
    logger.finalize('success')
    wandb.finish()

💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Skip loading parameter: train_image_features, required shape: torch.Size([4002, 1024]), loaded shape: torch.Size([4014, 1024])
Skip loading parameter: test_image_features, required shape: torch.Size([802, 1024]), loaded shape: torch.Size([790, 1024])


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Datasets already initialized, skipping setup


Testing: |          | 0/? [00:00<?, ?it/s]


=== TEST RESULTS ===
Test Precision: 0.0339
Test Recall: 0.7720
Test F1-Score: 0.0601
Test AUROC: 0.5016

Confusion matrix visualization completed!

DETAILED CLASSIFICATION REPORT (Top 20 Classes by Precision)
Class 49 - fork                     : P=0.510, R=0.800, F1=0.623, Acc=0.506, Support=1956
Class 18 - potted plant             : P=0.104, R=0.848, F1=0.185, Acc=0.227, Support=396
Class 27 - bicycle                  : P=0.096, R=0.789, F1=0.172, Acc=0.333, Support=336
Class 26 - bed                      : P=0.070, R=0.742, F1=0.127, Acc=0.331, Support=252
Class 22 - donut                    : P=0.069, R=0.779, F1=0.126, Acc=0.237, Support=272
Class 14 - orange                   : P=0.062, R=0.805, F1=0.115, Acc=0.235, Support=236
Class 74 - toaster                  : P=0.057, R=0.708, F1=0.106, Acc=0.326, Support=216
Class 13 - toilet                   : P=0.053, R=0.755, F1=0.099, Acc=0.226, Support=216
Class 36 - giraffe                  : P=0.047, R=0.784, F1=0.089, Acc=0.261,

0,1
epoch,▁
test_auroc,▁
test_f1,▁
test_precision,▁
test_recall,▁
trainer/global_step,▁▁▁▁▁

0,1
epoch,0.0
test_auroc,0.50161
test_f1,0.06014
test_precision,0.03387
test_recall,0.772
trainer/global_step,0.0
