In [1]:
import sys
import torch
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from ESTFormer import ESTFormer

sys.path.append('../../')
from utils.epoch_data_module import SuperResEpochDataModule

pl.seed_everything(42, workers=True)

Seed set to 42


42

In [2]:
# data module configuration
subject_session_id = 'subj01_session1'
resample_freq = 256

hi_res_channel_names = [
    'TP7', 'CP5', 'CP3', 'CP1', 'P1', 'P3', 'P5', 'P7', 'P9', 'PO7', 'PO3', 'O1', 
    'Iz', 'Oz', 'POz', 'Pz', 'CPz', 
    'TP8', 'CP6', 'CP4', 'CP2', 'P2', 'P4', 'P6', 'P8', 'P10', 'PO8', 'PO4', 'O2'
]

lo_res_channel_count = 14
hi_res_channel_count = len(hi_res_channel_names)

if lo_res_channel_count == 4:
    lo_res_channel_names = [
        'PO7',
        'POz', 'Oz',
        'PO8',
    ]

if lo_res_channel_count == 8:
    lo_res_channel_names = [
        'P7', 'PO7', 'O1',
        'POz', 'Oz',
        'P8', 'PO8', 'O2'
    ]

if lo_res_channel_count == 14:
    lo_res_channel_names = [
        'CP3', 'P3', 'P7', 'PO7', 'O1',
        'CPz', 'Pz', 'POz', 'Oz',
        'CP4', 'P4', 'P8', 'PO8', 'O2'
    ]


data_module = SuperResEpochDataModule(lo_res_channel_names=lo_res_channel_names, hi_res_channel_names=hi_res_channel_names, subject_session_id='subj01_session1', resample_freq=resample_freq, shuffle=True, num_workers=4, pin_memory=True)

Opening raw data file s:\PolySecLabProjects\eeg-image-decoding\code\utils\..\..\data\all-joined-1\eeg\preprocessed\ground-truth\subj01_session1_eeg.fif...
    Range : 1121 ... 1777926 =      2.189 ...  3472.512 secs
Ready.
3839 events found on stim channel Status
Event IDs: [  1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17  18
  19  20  21  22  23  24  25  26  27  28  29  30  31  32  33  34  35  36
  37  38  39  40  41  42  43  44  45  46  47  48  49  50  51  52  53  54
  55  56  57  58  59  60  61  62  63  64  65  66  67  68  69  70  71  72
  73  74  75  76  77  78  79  80  81  82  83  84  85  86  87  88  89  90
  91  92  93  94  95  96  97  98  99 100 101 102 103 104 105 106 107 108
 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126
 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144
 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162
 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179

In [3]:
# model configuration
time_steps = data_module.lo_res_data_reader[0][0].shape[1]
model = ESTFormer(lo_res_channel_names=lo_res_channel_names, hi_res_channel_names=hi_res_channel_names, time_steps=time_steps)

In [4]:
model_checkpoint = ModelCheckpoint(
    monitor='val_loss',
    dirpath='check_points/',
    mode='min',
    filename=f'{subject_session_id}_PO_650ms_{lo_res_channel_count}_{hi_res_channel_count}_{{epoch:02d}}_{{val_loss:.2f}}',
    auto_insert_metric_name=False,
)

lr_rate_monitor = LearningRateMonitor(logging_interval="epoch")

trainer = pl.Trainer(
    accelerator="auto", # Auto select the best hardware accelerator available
    devices="auto", # Auto select available devices for the accelerator (For eg. mutiple GPUs)
    strategy="auto", # Auto select the distributed training strategy.
    max_epochs=20, # Maximum number of epoch to train for.
    deterministic=True, # For deteministic and reproducible training.
    enable_model_summary=True,
    callbacks=[model_checkpoint, lr_rate_monitor],  # Declaring callbacks to use.
    precision="16", # Using Mixed Precision training.
    logger=True, # Auto generate TensorBoard logs.
)

s:\PolySecLabProjects\eeg-image-decoding\env\Lib\site-packages\lightning_fabric\connector.py:571: `precision=16` is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
s:\PolySecLabProjects\eeg-image-decoding\env\Lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default


In [5]:
trainer.fit(model, data_module)

You are using a CUDA device ('NVIDIA GeForce RTX 3070 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name               | Type                           | Params | Mode 
-------------------------------------------------------------------------------
0  | sigmas             | SigmaParameters                | 2      | train
1  | sim                | SIM                            | 324 K  | train
2  | trm                | TRM                            | 14.7 K | train
3  | norm               | LayerNorm                      | 9.8 K  | train
4  | mean_train_loss    | MeanMetric                     | 0      | train
5  | mean_val_loss      | MeanMetric                     |

Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

ValueError: Expected both predictions and target to be either 1- or 2-dimensional tensors, but got 3 and 3.

In [None]:
model_checkpoint.best_model_path

'S:\\PolySecLabProjects\\eeg-image-decoding\\code\\lightning\\eegnet\\check_points\\subj01_session1_PO_650ms_29_14_0.43.ckpt'

In [None]:
pretrained = ESTFormer.load_from_checkpoint(model_checkpoint.best_model_path)

In [None]:
import numpy as np

loader = data_module.test_dataloader()

all_preds = []
all_targets = []

for batch in loader:
    lo_res, hi_res = batch
    
    lo_res = lo_res.cuda()
    super_res = pretrained.predict(lo_res)

    all_preds.append(super_res)
    all_targets.append(hi_res)

all_preds = np.concatenate(all_preds, axis=0)  
all_targets = np.concatenate(all_targets, axis=0)