In [None]:
import sys
import wandb
import torch
sys.path.append('../../')
import pytorch_lightning as pl
from utils.data_modules.contrastive import EEGContrastiveDataModule
from models.trainers.contrastive import ContrastiveTrainerModel, PlottingCallback

In [None]:
# Data module feeding the evaluation run below. The 64 electrode names are
# wrapped eight per row purely for readability; their order is unchanged.
dm = EEGContrastiveDataModule(
    input_channels=[
        'Fp1', 'AF7', 'AF3', 'F1', 'F3', 'F5', 'F7', 'FT7',
        'FC5', 'FC3', 'FC1', 'C1', 'C3', 'C5', 'T7', 'TP7',
        'CP5', 'CP3', 'CP1', 'P1', 'P3', 'P5', 'P7', 'P9',
        'PO7', 'PO3', 'O1', 'Iz', 'Oz', 'POz', 'Pz', 'CPz',
        'Fpz', 'Fp2', 'AF8', 'AF4', 'AFz', 'Fz', 'F2', 'F4',
        'F6', 'F8', 'FT8', 'FC6', 'FC4', 'FC2', 'FCz', 'Cz',
        'C2', 'C4', 'C6', 'T8', 'TP8', 'CP6', 'CP4', 'CP2',
        'P2', 'P4', 'P6', 'P8', 'P10', 'PO8', 'PO4', 'O2',
    ],
    # Signal description.
    sfreq=250,
    montage='standard_1020',
    # Epoch window around each event, in milliseconds.
    window_before_event_ms=50,
    window_after_event_ms=600,
    # Recording selection.
    subject=1,
    session=1,
    # Loader settings.
    batch_size=64,
    num_workers=4,
    # NOTE(review): absolute, machine-specific path — consider making it configurable.
    epochs_dir="/workspace/eeg-image-decoding/data/all-joined-1/eeg/lo-res-epochs",
    test='All',
)

In [None]:
# Run constants. `subject` and `session` repeat the literals passed to the
# data module above, so keep the two places in sync when changing them.
epochs = 200
subject, session = 1, 1

# Pull the input/output descriptors reported by the data module.
sample_data = dm.get_sample_info()
num_channels, timesteps = (
    sample_data['input']['num_channels'],
    sample_data['input']['num_timesteps'],
)
# NOTE(review): the key is named 'fine_labels_shape' while the variable reads
# as a count — confirm which one get_sample_info() actually returns.
num_fine_labels = sample_data['output']['fine_labels_shape']

In [None]:
# Checkpoint file handed to ContrastiveTrainerModel.load_from_checkpoint below.
# NOTE(review): the file name hardcodes subj1 / session1 / epoch=199, which
# presumably mirrors `subject`, `session` and `epochs - 1` — update it by hand
# whenever those values change, and confirm the absolute path on this machine.
checkpoint_path = "/workspace/eeg-image-decoding/code/models/check_points/contrastive_encoder/subj1_session1_epoch=199.ckpt"

In [None]:
# Restore the trained model from disk. `strict=False` is forwarded to
# Lightning's load_from_checkpoint; per the Lightning docs it relaxes the
# requirement that checkpoint keys match the module's state dict exactly.
lightning_model = ContrastiveTrainerModel.load_from_checkpoint(checkpoint_path, strict=False)

results_callback = PlottingCallback()

logger = pl.loggers.WandbLogger(project="lo_res_contrastive_eeg_net")

# Track the outcome so the W&B run is always closed, even when building the
# trainer or running the test loop raises. Without this, a failed cell leaves
# the run open in the kernel and the next execution inherits it.
run_status = 'failed'
try:
    trainer = pl.Trainer(
        max_epochs=epochs,  # only relevant for fitting; kept for parity with training
        callbacks=[results_callback],
        logger=logger,
        accelerator='gpu' if torch.cuda.is_available() else 'cpu',
        devices=1,
    )

    # Keep the returned metrics so later cells can inspect them.
    test_results = trainer.test(lightning_model, dm)
    run_status = 'success'
finally:
    logger.finalize(run_status)
    # A non-zero exit code marks the run as failed in W&B; None keeps the
    # default behaviour of a plain `wandb.finish()` on success.
    wandb.finish(exit_code=None if run_status == 'success' else 1)