# Training GREEN with `lightning`

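This notebook is a minimal example of how to train the GREEN model with `lightning`. It uses dummy data in the form of `mne.Epochs` objects. It builds a small dataset, instantiates a GREEN model, and runs a single-fold cross-validation with `pl_crossval`.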

In [1]:
import mne
from green.data_utils import EpochsDataset
from green.wavelet_layers import RealCovariance
import torch

from research_code.pl_utils import get_green, GreenClassifierLM
from research_code.crossval_utils import pl_crossval
from tests.conftest import make_one_dummy_epoch

# Fix the torch RNG so weight initialisation and dropout are reproducible across
# "Restart & Run All". NOTE(review): `make_one_dummy_epoch` / `EpochsDataset` may draw
# from other RNGs (the dataset repr shows a NumPy `Generator`); seed those too if needed.
SEED = 42
torch.manual_seed(SEED)

# Silence MNE's lower-severity log messages so the cell outputs below stay readable.
mne.set_log_level('ERROR')

In [2]:
# Tiny dummy dataset: one dummy `mne.Epochs` object per subject.
n = 3  # number of dummy subjects (the dataset repr below reports len: 3)
dataset = EpochsDataset(
    epochs=[make_one_dummy_epoch() for _ in range(n)],
    # All-zero dummy targets, one row of 2 values per subject. `torch.zeros` builds the
    # float64 tensor directly instead of going through the legacy `torch.Tensor`
    # constructor (float32) followed by a cast.
    targets=torch.zeros((n, 2), dtype=torch.float64),
    subjects=[f'subject_{i}' for i in range(n)],
    # the repr below reports this as "n_epochs/sample: 2"
    n_epochs=2,
)
dataset

EpochsDataset
len: 3
n_epochs/sample: 2
num_channels/sample: 3
sampling frequency: 100.0
epoch duration (s): 1.99
padding: repeat
shuffle: False
random_state: Generator(PCG64)
use age: None

In [3]:
# GREEN network sized to the dummy data: 3 channels sampled at 100 Hz
# (see "num_channels/sample" and "sampling frequency" in the dataset repr above).
model = get_green(
    # input description
    n_ch=3,
    sfreq=100,
    # wavelet front-end: 2 complex wavelets with 0.5 s kernels (ComplexWavelet in the repr)
    n_freqs=2,
    kernel_width_s=.5,
    # covariance pooling, then BiMap 3 -> 2 and a log-Euclidean LogEig projection
    pool_layer=RealCovariance(),
    bi_out=[2],
    logref='logeuclid',
    # classification head: dropout -> Linear(6, 8) -> ... -> Linear(8, 2)
    hidden_dim=[8],
    dropout=.5,
    out_dim=2,
    # other options (NOTE(review): which layer `orth_weights` constrains isn't visible here)
    orth_weights=True,
    dtype=torch.float32,
)

# Wrap the network in its Lightning module and display its structure.
model_pl = GreenClassifierLM(model=model)
model_pl

GreenClassifierLM(
  (model): Green(
    (conv_layers): Sequential(
      (0): ComplexWavelet(kernel_width_s=0.5, sfreq=100, n_wavelets=2, stride=5, padding=0, scaling=oct)
    )
    (pooling_layers): RealCovariance()
    (spd_layers): Sequential(
      (0): LedoitWold(n_freqs=2, init_shrinkage=-3.0, learnable=True)
      (1): BiMap(d_in=3, d_out=2, n_freqs=2
    )
    (proj): LogEig(ref=logeuclid, reg=0.0001, n_freqs=2, size=2
    (head): Sequential(
      (0): BatchNorm1d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (1): Dropout(p=0.5, inplace=False)
      (2): Linear(in_features=6, out_features=8, bias=True)
      (3): GELU(approximate='none')
      (4): BatchNorm1d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): Dropout(p=0.5, inplace=False)
      (6): Linear(in_features=8, out_features=2, bias=True)
    )
  )
)

In [4]:
# Single hand-made fold: train on the first two samples, test on the third
# (presumably indices into `dataset`, which has len 3).
fold_train_idx = [[0, 1]]
fold_test_idx = [[2]]

pl_crossval_output, _ = pl_crossval(
    model,
    dataset=dataset,
    pl_module=GreenClassifierLM,  # Lightning module class used to train `model`
    train_splits=fold_train_idx,
    test_splits=fold_test_idx,
    n_epochs=2,  # the training log below reports `max_epochs=2`
    batch_size=64,
    num_workers=0,  # Lightning warns about few dataloader workers in the log below
    save_preds=True,
    ckpt_prefix='checkpoints/test',  # per-fold outputs land under checkpoints/test/fold<k>
)


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: checkpoints\test\fold0\lightning_logs
c:\Users\paillarj\AppData\Local\anaconda3\envs\riemann\lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
c:\Users\paillarj\AppData\Local\anaconda3\envs\riemann\lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Finding best initial lr:   0%|          | 0/20 [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=2` reached.
LR finder stopped early after 2 steps due to diverging loss.
Failed to compute suggestion for learning rate because there are not enough points. Increase the loop iteration limits or the size of your dataset/dataloader.
Restoring states from the checkpoint path at checkpoints\test\fold0\.lr_find_b66bc6a4-f71d-4bfe-9f8d-b3567bf11141.ckpt
Restored all states from the checkpoint at checkpoints\test\fold0\.lr_find_b66bc6a4-f71d-4bfe-9f8d-b3567bf11141.ckpt


Restored all states from the checkpoint at checkpoints\test\fold0\.lr_find_b66bc6a4-f71d-4bfe-9f8d-b3567bf11141.ckpt


Output()

`Trainer.fit` stopped: `max_epochs=2` reached.


Output()

c:\Users\paillarj\AppData\Local\anaconda3\envs\riemann\lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Output()

c:\Users\paillarj\AppData\Local\anaconda3\envs\riemann\lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
