# Training a GREEN submodel without learning wavelets

This notebook provides a minimal example of training a GREEN submodel without learning wavelets. The difference from the full model is that the inputs are covariance matrices instead of raw EEG data.

In [1]:
import numpy as np
from torch.utils.data import TensorDataset, DataLoader
import torch

from research_code.pl_utils import get_green_g2, GreenClassifierLM
from research_code.crossval_utils import pl_crossval

In [2]:
# Create a dummy dataset: random SPD matrices as inputs, random binary targets.
# A fixed seed makes the data identical on every "Restart Kernel -> Run All".
SEED = 0
rng = np.random.default_rng(SEED)
# Seed torch's global RNG as well, so later torch-based randomness starts
# from a known state.
torch.manual_seed(SEED)

n = 50 # subjects
f = 4 # filterbank size
c = 3 # channels

# Strictly positive values (exp of Gaussian draws), shape (n, f, c).
U = np.exp(rng.standard_normal((n, f, c)))
# Multiplying by the identity places U on the diagonal: shape (n, f, c, c).
diag = U[..., np.newaxis] * np.eye(c)
V = rng.standard_normal((n, f, c, c))
# V @ diag @ V^T with a positive diagonal is symmetric positive definite
# whenever V is invertible, which holds almost surely for Gaussian V.
spd = V @ diag @ np.transpose(V, (0, 1, 3, 2))

# One stack of f (c x c) matrices per subject, stored as float32.
X = torch.as_tensor(spd, dtype=torch.float32)
# NOTE(review): each row holds two independent 0/1 draws, so rows such as
# [0, 0] and [1, 1] occur -- confirm this is the target format the classifier
# expects (as opposed to one-hot rows).
y = torch.as_tensor(rng.integers(2, size=(n, 2)), dtype=torch.float32)

dataset = TensorDataset(X, y)

In [3]:
# Hyper-parameters of the GREEN submodel, gathered in one place so they are
# easy to read and tweak. `c` and `f` are the channel count and filterbank
# size defined together with the dummy dataset.
green_config = dict(
    n_ch=c,
    n_freqs=f,
    bi_out=[2],
    logref='logeuclid',
    hidden_dim=[8],
    dropout=0.5,
    orth_weights=True,
    out_dim=2,
    dtype=torch.float32,
)
model = get_green_g2(**green_config)

# Wrap the network in the Lightning module; the bare expression on the last
# line displays the module's repr as the cell output.
model_pl = GreenClassifierLM(model=model)
model_pl

GreenClassifierLM(
  (model): GreenG2(
    (spd_layers): Sequential(
      (0): LedoitWold(n_freqs=4, init_shrinkage=-3.0, learnable=True)
      (1): BiMap(d_in=3, d_out=2, n_freqs=4
    )
    (proj): LogEig(ref=logeuclid, reg=0.0001, n_freqs=4, size=2
    (head): Sequential(
      (0): BatchNorm1d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (1): Dropout(p=0.5, inplace=False)
      (2): Linear(in_features=12, out_features=8, bias=True)
      (3): GELU(approximate='none')
      (4): BatchNorm1d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): Dropout(p=0.5, inplace=False)
      (6): Linear(in_features=8, out_features=2, bias=True)
    )
  )
)

In [4]:
# Settings for a very short cross-validation run (2 epochs).
# NOTE(review): the split lists reference only indices 0, 1 and 2, while the
# dataset holds n = 50 items -- presumably these are sample indices for a
# single fold; confirm against pl_crossval.
crossval_config = dict(
    dataset=dataset,
    pl_module=GreenClassifierLM,
    train_splits=[[0, 1]],
    test_splits=[[2]],
    n_epochs=2,
    batch_size=4,
    num_workers=0,
    save_preds=True,
    ckpt_prefix='checkpoints/test',
)

# The bare `model` is passed together with `pl_module`; the second returned
# value is not needed here and is discarded.
pl_crossval_output, _ = pl_crossval(model, **crossval_config)


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
c:\Users\paillarj\AppData\Local\anaconda3\envs\riemann\lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
c:\Users\paillarj\AppData\Local\anaconda3\envs\riemann\lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Finding best initial lr:   0%|          | 0/20 [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=2` reached.
LR finder stopped early after 2 steps due to diverging loss.
Failed to compute suggestion for learning rate because there are not enough points. Increase the loop iteration limits or the size of your dataset/dataloader.
Restoring states from the checkpoint path at checkpoints\test\fold0\.lr_find_53d79eea-094a-4aad-8615-13ab86532a9f.ckpt
Restored all states from the checkpoint at checkpoints\test\fold0\.lr_find_53d79eea-094a-4aad-8615-13ab86532a9f.ckpt

  | Name  | Type    | Params
----------------------------------
0 | model | GreenG2 | 244   
----------------------------------
190       Trainable params
54        Non-trainable params
244       Total params
0.001     Total estimated model params size (MB)
Restored all states from the checkpoint at checkpoints\test\fold0\.lr_find_53d79eea-094a-4aad-8615-13ab86532a9f.ckpt


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

c:\Users\paillarj\AppData\Local\anaconda3\envs\riemann\lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=2` reached.
c:\Users\paillarj\AppData\Local\anaconda3\envs\riemann\lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Predicting: |          | 0/? [00:00<?, ?it/s]



pred_acc =  0.0


c:\Users\paillarj\AppData\Local\anaconda3\envs\riemann\lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Testing: |          | 0/? [00:00<?, ?it/s]

