# PyTorch Lightning abstraction basics

*Putting it all together with PL abstraction mechanics.*

[PyTorch Lightning](https://www.pytorchlightning.ai/) is a framework built on top of PyTorch that takes care of the boilerplate and simplifies parallel training. It is often compared to what Keras is for TensorFlow. We will integrate our models and data modules directly with the PL mechanics, so that distributed training becomes easier.

Lightning has grown into a large framework with functionality missing from vanilla PyTorch, but for a basic understanding of the abstraction logic, only a few components are needed, most notably:
* `pl.LightningModule`, a wrapper for a PyTorch model, with implementable train, test, and validation loops
* `pl.LightningDataModule`, a wrapper for a PyTorch Dataset, with implementable data splitting logic
* `pl.Trainer` to orchestrate the training and testing phases, as well as inference. It also exposes training options such as gradient clipping
* `pl.callbacks.base.Callback` to organize the runtime workflow. Lightning ships with built-in callbacks such as `EarlyStopping`, `ModelCheckpoint`, `LearningRateMonitor`, and `ModelPruning`, among others
* A Profiler to debug resource utilization

Let's first load all the necessary params.

In [1]:
from kosmoss import CONFIG, LOGS_PATH, METADATA
from kosmoss.parallel.data import FlattenedDataModule
from kosmoss.parallel.models import LitMLP

In [4]:
import numpy as np
import os
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger

# Ensures this Notebook's reproducibility
pl.seed_everything(42, workers=True)

step = CONFIG['timestep']
params = METADATA[str(step)]['flattened']

Global seed set to 42


## Model and training logic

In [5]:
!cat models.py

cat: models.py: No such file or directory


In [6]:
x_feats = params['x_shape'][-1]
y_feats = params['y_shape'][-1]

In [7]:
print(f'x number of features: {x_feats}')
print(f'y number of features: {y_feats}')

x number of features: 4128
y number of features: 552


In [8]:
mlp = LitMLP(
    in_channels=x_feats,
    hidden_channels=100,
    out_channels=y_feats
)
mlp

LitMLP(
  (normalization_layer): Normalize()
  (net): Sequential(
    (0): Normalize()
    (1): Linear(in_features=4128, out_features=100, bias=True)
    (2): SiLU()
    (3): Linear(in_features=100, out_features=100, bias=True)
    (4): SiLU()
    (5): Linear(in_features=100, out_features=100, bias=True)
    (6): SiLU()
    (7): Linear(in_features=100, out_features=552, bias=True)
  )
)
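
As a sanity check, the parameter count can be worked out by hand from the layer sizes printed above. The sketch below assumes that each `Linear` layer holds `in_features * out_features` weights plus one bias per output feature, that `Normalize` and `SiLU` carry no trainable parameters, and that parameters are stored as 32-bit floats.

```python
# Layer widths read off the printed LitMLP: 4128 -> 100 -> 100 -> 100 -> 552
layer_sizes = [4128, 100, 100, 100, 552]


def linear_params(n_in: int, n_out: int) -> int:
    """Weights (n_in * n_out) plus one bias per output feature."""
    return n_in * n_out + n_out


total_params = sum(
    linear_params(n_in, n_out)
    for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:])
)

# Assumption: 32-bit floats, i.e. 4 bytes per parameter
size_mb = total_params * 4 / 1e6

print(total_params)       # 488852
print(round(size_mb, 3))  # 1.955
```

The result agrees with the `488 K` parameters and `1.955` MB that Lightning reports in its model summary when training starts.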

## Dataset creation and data loading mechanics

In [9]:
!cat data.py

import numpy as np
import os.path as osp
from pytorch_lightning import LightningDataModule
import torch
from typing import Tuple, Union

from kosmoss import CONFIG, METADATA, PROCESSED_DATA_PATH

class FlattenedDataset(torch.utils.data.Dataset):
    
    def __init__(self, 
                 step: int, 
                 mode: Union['efficient', 'controlled'] = 'controlled') -> None:
        super().__init__()
        self.step = step
        self.mode = mode
        self.params = METADATA[str(self.step)]['flattened']
    
    def __len__(self) -> int:
        
        return self.params['dataset_len']
    
    def __getitem__(self, idx: int) -> Tuple[torch.Tensor]:
        
        shard_size = len(self) // self.params['num_shards']
        fileidx = idx // shard_size
        rowidx = idx % shard_size
        
        def _load(name: Union['x', 'y']) -> Tuple[torch.Tensor]:
            main_path = osp.join(PROCESSED_DATA_PATH, f"flattened-{self.step}")
            
            if self.m
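
The index arithmetic in `__getitem__` above maps a global sample index to a shard file and a row inside that shard. Here is a minimal standalone sketch of that mapping. The function name `locate` and the example sizes are hypothetical and not taken from the kosmoss code.

```python
from typing import Tuple


def locate(idx: int, dataset_len: int, num_shards: int) -> Tuple[int, int]:
    """Return (file index, row index) for a global sample index."""
    shard_size = dataset_len // num_shards
    return idx // shard_size, idx % shard_size


# Hypothetical dataset: 1000 samples spread evenly over 4 shards of 250 rows
print(locate(0, 1000, 4))    # (0, 0)
print(locate(249, 1000, 4))  # (0, 249)
print(locate(250, 1000, 4))  # (1, 0)
print(locate(999, 1000, 4))  # (3, 249)
```

Note that this scheme relies on the dataset length being a multiple of the number of shards. Otherwise the last few indices would map to a file index equal to `num_shards`, which does not exist.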

* `batch_size` sets the number of elements in a batch of data.
* `num_workers` sets the number of workers the DataLoader can spawn to handle data loading and Dataset batching.

In [10]:
import psutil
cores = psutil.cpu_count(logical=False)

In [11]:
datamodule = FlattenedDataModule(
    batch_size=1024,
    
    # In a CPU-only setup, make sure enough cores remain for the training itself,
    # not just for data loading, otherwise training becomes the bottleneck
    num_workers=cores
)
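
One possible way to act on the comment above is to hold back a few cores for the training loop when choosing `num_workers`. The helper `pick_num_workers` below is hypothetical, and the number of reserved cores is an assumption to tune for each machine.

```python
import os


def pick_num_workers(total_cores: int, reserved_for_training: int = 1) -> int:
    """Hypothetical helper: keep some cores free for the training loop.

    Returns 0 when no core can be spared; a PyTorch DataLoader then loads
    data in the main process.
    """
    return max(0, total_cores - reserved_for_training)


# os.cpu_count() counts logical cores; the notebook uses
# psutil.cpu_count(logical=False) to count physical cores instead
total = os.cpu_count() or 1
num_workers = pick_num_workers(total, reserved_for_training=2)
print(num_workers)
```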

## Orchestrating the training

In [12]:
logger = TensorBoardLogger(
    save_dir=LOGS_PATH,
    name='flattened_mlp_logs',
    log_graph=True
)

All the training instrumentation is done by an object called the `Trainer`. You can set parameters such as:
* `max_epochs`, the maximum number of epochs to run unless early stopping ends training sooner
* `accelerator` type and the logical number of `devices`

Other notable options:
* `callbacks` to run custom logic at specific points of the training loop
* `gradient_clip_val` and `gradient_clip_algorithm` to set up gradient clipping
* `logger` to interface with loss and metrics logging
* `resume_from_checkpoint` to resume a previously started training run
* `amp_backend` to switch to the NVIDIA Apex framework for Automatic Mixed Precision support
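
To make the two clipping modes concrete, here is a plain-Python sketch of what clipping by value and clipping by norm do to a gradient vector. It illustrates the idea only and is not Lightning's implementation. The small `1e-6` term in the norm variant is an assumption borrowed from common practice to avoid dividing by zero.

```python
import math
from typing import List


def clip_by_value(grads: List[float], clip_val: float) -> List[float]:
    """Clamp every gradient component to the range [-clip_val, clip_val]."""
    return [max(-clip_val, min(clip_val, g)) for g in grads]


def clip_by_norm(grads: List[float], clip_val: float) -> List[float]:
    """Rescale the whole vector so that its L2 norm does not exceed clip_val."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    scale = min(1.0, clip_val / (total_norm + 1e-6))
    return [g * scale for g in grads]


grads = [3.0, -4.0]  # L2 norm is 5
print(clip_by_value(grads, 1.0))  # [1.0, -1.0]
print(clip_by_norm(grads, 1.0))   # approximately [0.6, -0.8]
```

Clipping by value changes the direction of the gradient, whereas clipping by norm keeps the direction and only shrinks the magnitude.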

In [13]:
cpu_trainer = Trainer(
    max_epochs=1,
    logger=logger,
    deterministic=True,
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


Training on CPU is a one-liner.

In [14]:
cpu_trainer.fit(model=mlp, datamodule=datamodule)

Missing logger folder: /home/jupyter/.kosmoss/logs/flattened_mlp_logs

  | Name                | Type       | Params
---------------------------------------------------
0 | normalization_layer | Normalize  | 0     
1 | net                 | Sequential | 488 K 
---------------------------------------------------
488 K     Trainable params
0         Non-trainable params
488 K     Total params
1.955     Total estimated model params size (MB)




Global seed set to 42                                                 
Epoch 0:  89%|████████▉ | 848/954 [01:55<00:14,  7.32it/s, loss=0.592, v_num=0, train_loss=0.567]
Validating: 0it [00:00, ?it/s]
Validating:   0%|          | 0/106 [00:00<?, ?it/s]
Validating:   1%|          | 1/106 [00:01<03:07,  1.78s/it]
Epoch 0:  90%|████████▉ | 856/954 [01:57<00:13,  7.26it/s, loss=0.592, v_num=0, train_loss=0.567]
Epoch 0:  91%|█████████ | 866/954 [01:57<00:11,  7.34it/s, loss=0.592, v_num=0, train_loss=0.567]
Epoch 0:  92%|█████████▏| 876/954 [01:58<00:10,  7.42it/s, loss=0.592, v_num=0, train_loss=0.567]
Epoch 0:  93%|█████████▎| 886/954 [01:58<00:09,  7.50it/s, loss=0.592, v_num=0, train_loss=0.567]
Epoch 0:  94%|█████████▍| 896/954 [01:58<00:07,  7.58it/s, loss=0.592, v_num=0, train_loss=0.567]
Validating:  45%|████▌     | 48/106 [00:02<00:01, 49.03it/s]
Epoch 0:  95%|█████████▍| 906/954 [01:58<00:06,  7.65it/s, loss=0.592, v_num=0, train_loss=0.567]
Epoch 0:  96%|█████████▋| 9

Never forget to test. The handy thing with the `Trainer` is that if `.test()` is called after `.fit()` and training gets interrupted, for instance by a `KeyboardInterrupt` (Ctrl+C, which sends `SIGINT`), Lightning catches the interruption, tries to gracefully release resources and stop training, and the test still runs.
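
The interrupt-then-test pattern can be pictured with a plain-Python sketch. The functions `fit` and `run_test` below are hypothetical stand-ins, not Lightning's API, and the interruption is simulated by raising `KeyboardInterrupt` at a chosen step.

```python
def fit(num_steps: int, interrupt_at: int) -> int:
    """Hypothetical training loop that survives a Ctrl+C."""
    completed = 0
    try:
        for step in range(num_steps):
            if step == interrupt_at:
                raise KeyboardInterrupt  # simulates the user pressing Ctrl+C
            completed += 1
    except KeyboardInterrupt:
        # Resources would be released here before returning
        print("Interrupted, shutting down gracefully...")
    return completed


def run_test() -> str:
    """Hypothetical test phase."""
    return "test ran"


steps_done = fit(num_steps=10, interrupt_at=3)
result = run_test()  # still reached, because fit() handled the interruption
print(steps_done, result)  # 3 test ran
```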

In [15]:
cpu_trainer.test(model=mlp, datamodule=datamodule)



Testing:  99%|█████████▉| 105/106 [00:02<00:00, 83.03it/s]
--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_loss': 0.5979500412940979, 'test_loss_epoch': 0.5979500412940979}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 106/106 [00:03<00:00, 34.88it/s]


[{'test_loss': 0.5979500412940979, 'test_loss_epoch': 0.5979500412940979}]