# PyTorch Lightning abstraction basics

*Putting it all together with PL abstraction mechanics.*

[PyTorch Lightning](https://www.pytorchlightning.ai/) is a framework built on top of PyTorch that takes care of the boilerplate and simplifies parallel training. It is often compared to Keras for TensorFlow. We will integrate our models and data modules directly with the PL mechanics, so that distributed training becomes easier.

Lightning has grown into a massive framework with functionalities missing from vanilla PyTorch, but for a basic understanding of the abstraction logic, only a few components are useful, most notably:
* `pl.LightningModule`, a wrapper for a PyTorch model, with implementable train, test, and validation loops
* `pl.LightningDataModule`, a wrapper for a PyTorch `Dataset`, with implementable data splitting logic
* `pl.Trainer` to orchestrate training + testing phases, as well as inference. It also handles features such as gradient clipping
* `pl.callbacks.base.Callback` to organize runtime workflow. Comes standard with `EarlyStopping`, `ModelCheckpoint`, `LearningRateMonitor`, and `ModelPruning`, among others
* A Profiler to debug resource utilization
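
To make the callback mechanism more concrete, here is a toy sketch in plain Python of how a training loop might call hooks on its callbacks. `ToyTrainer`, `ToyEarlyStopping`, and the `on_epoch_end` hook are hypothetical names used for illustration only. This is not Lightning's implementation, just the general pattern that a trainer with callbacks follows.

```python
from typing import List


class ToyEarlyStopping:
    """Toy callback: asks the trainer to stop once the loss stops improving."""

    def __init__(self, patience: int = 2) -> None:
        self.patience = patience
        self.best = float("inf")
        self.wait = 0

    def on_epoch_end(self, trainer: "ToyTrainer") -> None:
        loss = trainer.val_losses[-1]
        if loss < self.best:
            # Improvement: remember it and reset the patience counter
            self.best, self.wait = loss, 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                trainer.should_stop = True


class ToyTrainer:
    """Toy trainer: runs epochs and gives every callback a chance to react."""

    def __init__(self, max_epochs: int, callbacks: list) -> None:
        self.max_epochs = max_epochs
        self.callbacks = callbacks
        self.val_losses: List[float] = []
        self.should_stop = False

    def fit(self, simulated_val_losses: List[float]) -> int:
        epochs_run = 0
        for loss in simulated_val_losses[: self.max_epochs]:
            self.val_losses.append(loss)  # stands in for a real epoch
            epochs_run += 1
            for callback in self.callbacks:
                callback.on_epoch_end(self)
            if self.should_stop:
                break
        return epochs_run


trainer = ToyTrainer(max_epochs=10, callbacks=[ToyEarlyStopping(patience=2)])
epochs_run = trainer.fit([1.0, 0.8, 0.9, 0.85, 0.7])
print(epochs_run)  # 4: two epochs without improvement after the best loss of 0.8
```

The point of the pattern is that the stopping logic lives outside the loop: the trainer only exposes a flag, and any number of callbacks can set it.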

Let's first load all the necessary params.

In [1]:
from kosmoss import CONFIG, LOGS_PATH, METADATA
from kosmoss.parallel.data import FlattenedDataModule
from kosmoss.parallel.mlp import LitMLP



In [2]:
import numpy as np
import os
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger

# Ensures this Notebook's reproducibility
pl.seed_everything(42, workers=True)

step = CONFIG['timestep']
params = METADATA[str(step)]['flattened']

Global seed set to 42


## Model and training logic

In [3]:
!cat mlp.py

import os.path as osp
from pytorch_lightning import LightningModule
import torch
import torch.nn as nn
import torch_optimizer as optim
import torchmetrics.functional as F
from typing import List, Tuple, Union

from kosmoss import CONFIG, DATA_PATH

class ThreeDCorrectionModule(LightningModule):
    
    def __init__(self):
        super().__init__()
        
        # 'The LightningModule knows what device it is on. You can access the reference
        # via self.device. Sometimes it is necessary to store tensors as module attributes.
        # However, if they are not parameters they will remain on the CPU even if the module
        # gets moved to a new device. To prevent that and remain device agnostic, register
        # the tensor as a buffer in your module's __init__ method with register_buffer().'
        self.register_buffer("epsilon", torch.tensor(1.e-8))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)
    
    def _common_step(self, 
                     batch: List[torch.Tensor

In [4]:
x_feats = params['x_shape'][-1]
y_feats = params['y_shape'][-1]

In [5]:
print(f'x number of features: {x_feats}')
print(f'y number of features: {y_feats}')

x number of features: 4128
y number of features: 552


In [6]:
mlp = LitMLP(
    in_channels=x_feats,
    hidden_channels=100,
    out_channels=y_feats
)
mlp

LitMLP(
  (normalization_layer): Normalize()
  (net): Sequential(
    (0): Normalize()
    (1): Linear(in_features=4128, out_features=100, bias=True)
    (2): SiLU()
    (3): Linear(in_features=100, out_features=100, bias=True)
    (4): SiLU()
    (5): Linear(in_features=100, out_features=100, bias=True)
    (6): SiLU()
    (7): Linear(in_features=100, out_features=552, bias=True)
  )
)

## Dataset creation and data loading mechanics

In [7]:
!cat data.py

import numpy as np
import os.path as osp
from pytorch_lightning import LightningDataModule
import torch
from typing import Tuple, Union

from kosmoss import CONFIG, METADATA, PROCESSED_DATA_PATH

class FlattenedDataset(torch.utils.data.Dataset):
    
    def __init__(self, 
                 step: int, 
                 mode: Union['efficient', 'controlled'] = 'controlled') -> None:
        super().__init__()
        self.step = step
        self.mode = mode
        self.params = METADATA[str(self.step)]['flattened']
    
    def __len__(self) -> int:
        
        return self.params['dataset_len']
    
    def __getitem__(self, idx: int) -> Tuple[torch.Tensor]:
        
        shard_size = len(self) // self.params['num_shards']
        fileidx = idx // shard_size
        rowidx = idx % shard_size
        
        def _load(name: Union['x', 'y']) -> Tuple[torch.Tensor]:
            main_path = osp.join(PROCESSED_DATA_PATH, f"flattened-{self.step}")
            
            if self.m
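
The index arithmetic at the start of `__getitem__` can be tried in isolation. The sketch below reproduces only the three lines that compute `shard_size`, `fileidx`, and `rowidx`; the helper name `locate` and the dataset sizes are hypothetical, chosen to make the numbers easy to follow.

```python
from typing import Tuple


def locate(idx: int, dataset_len: int, num_shards: int) -> Tuple[int, int]:
    """Map a global sample index to (shard file index, row index in that shard)."""
    shard_size = dataset_len // num_shards
    # divmod(idx, shard_size) == (idx // shard_size, idx % shard_size)
    return divmod(idx, shard_size)


# Hypothetical dataset: 1000 samples spread over 10 shards of 100 rows each
print(locate(0, 1000, 10))    # (0, 0)
print(locate(257, 1000, 10))  # (2, 57)
print(locate(999, 1000, 10))  # (9, 99)
```

This mapping relies on the dataset length being a multiple of the number of shards. With 1003 samples and 10 shards, for instance, index 1002 would map to shard 10, which does not exist, so the preprocessing presumably guarantees evenly sized shards.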

* `batch_size` sets the number of elements in a batch of data.
* `num_workers` sets the number of workers the DataLoader can spawn to handle data loading and Dataset batching.

In [8]:
import psutil
cores = psutil.cpu_count(logical=False)

In [9]:
datamodule = FlattenedDataModule(
    batch_size=1024,
    
    # In a CPU-only setup, make sure enough cores remain for the training itself,
    # not just for data loading; otherwise training becomes the bottleneck
    num_workers=cores
)
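
One possible way to act on the comment above is to keep a few cores aside for the training process when choosing `num_workers`. The helper below is a hypothetical sketch, not part of `kosmoss`. It uses `os.cpu_count()`, which counts logical cores, whereas the `psutil.cpu_count(logical=False)` call above counts physical ones.

```python
import os


def pick_num_workers(total_cores: int, reserved_for_training: int = 1) -> int:
    """Number of DataLoader workers that leaves some cores free for training.

    Returns 0 when no core can be spared; in PyTorch, num_workers=0 means
    the data is loaded in the main process.
    """
    return max(0, total_cores - reserved_for_training)


# os.cpu_count() may return None on some platforms, hence the fallback to 1
available = os.cpu_count() or 1
print(pick_num_workers(available, reserved_for_training=2))
```

How many cores to reserve depends on the model and on how expensive data loading is, so the value is best found by profiling.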

## Orchestrating the training

In [10]:
logger = TensorBoardLogger(
    save_dir=LOGS_PATH,
    name='flattened_mlp_logs',
    log_graph=True
)

All the training instrumentation is done by an object called the `Trainer`. You can set parameters such as:
* `max_epochs`, the maximum number of epochs to run unless early stopping happens first
* `accelerator` type and `devices`, the number or indices of devices to use

Notably interesting: 
* `callbacks` to hook custom logic in between training stages
* `gradient_clip_val` and `gradient_clip_algorithm` to set up gradient clipping
* `logger` to interface with loss and metrics logging
* `resume_from_checkpoint` to resume a previously started training run
* `amp_backend` to choose the Automatic Mixed Precision backend, for example Nvidia Apex instead of native PyTorch AMP
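
As an illustration, several of these options could be combined as sketched below. All values are hypothetical, and the `val_loss` metric name is an assumption, since the model's validation logging is not shown in this notebook. Argument names follow the Lightning 1.x `Trainer` signature and may differ in other versions. The `build_trainer` helper is not part of `kosmoss`.

```python
from typing import Any, Dict

# Hypothetical values, for illustration only
TRAINER_KWARGS: Dict[str, Any] = {
    "max_epochs": 10,
    "accelerator": "cpu",
    "devices": 1,
    "gradient_clip_val": 0.5,           # clipping threshold
    "gradient_clip_algorithm": "norm",  # "norm" or "value"
    "deterministic": True,
}


def build_trainer(**overrides: Any):
    """Build a Trainer from the defaults above, with optional overrides."""
    # Imported lazily so the configuration can be inspected without Lightning
    import pytorch_lightning as pl
    from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

    kwargs = {**TRAINER_KWARGS, **overrides}
    # Assumes the LightningModule logs a metric named "val_loss"
    kwargs.setdefault("callbacks", [
        EarlyStopping(monitor="val_loss", mode="min", patience=3),
        ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1),
    ])
    return pl.Trainer(**kwargs)
```

Calling `build_trainer(max_epochs=1)` would then give a short run with the same clipping and callback setup.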

In [11]:
cpu_trainer = Trainer(
    max_epochs=1,
    logger=logger,
    deterministic=True,
)

GPU available: True, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs




Training on CPU is a one-liner.

In [12]:
cpu_trainer.fit(model=mlp, datamodule=datamodule)


  | Name                | Type       | Params
---------------------------------------------------
0 | normalization_layer | Normalize  | 0     
1 | net                 | Sequential | 488 K 
---------------------------------------------------
488 K     Trainable params
0         Non-trainable params
488 K     Total params
1.955     Total estimated model params size (MB)
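
The figures in this summary can be checked by hand from the layer shapes printed earlier. The sketch below assumes that each `Linear` layer holds `in_features * out_features` weights plus `out_features` biases, that parameters are stored as 32-bit floats, and that 1 MB is counted as 10^6 bytes.

```python
def linear_params(in_features: int, out_features: int) -> int:
    """Weights plus biases of a fully connected layer."""
    return in_features * out_features + out_features


# (in_features, out_features) of the four Linear layers of the MLP
layers = [(4128, 100), (100, 100), (100, 100), (100, 552)]

total = sum(linear_params(i, o) for i, o in layers)
size_mb = total * 4 / 1e6  # 4 bytes per float32 parameter

print(total)             # 488852, displayed as "488 K"
print(f"{size_mb:.3f}")  # 1.955
```

Almost 85% of the parameters sit in the first layer, which maps the 4128 input features to 100 hidden units.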




Global seed set to 42                                     
Epoch 0:  89% 848/954 [02:25<00:18,  5.83it/s, loss=0.594, v_num=3, train_loss=0.569]
Validating: 0it [00:00, ?it/s]
Validating:   0% 0/106 [00:00<?, ?it/s]
Epoch 0:  89% 850/954 [03:09<00:23,  4.48it/s, loss=0.594, v_num=3, train_loss=0.569]
Epoch 0:  89% 853/954 [03:09<00:22,  4.49it/s, loss=0.594, v_num=3, train_loss=0.569]
Epoch 0:  90% 858/954 [03:09<00:21,  4.52it/s, loss=0.594, v_num=3, train_loss=0.569]
Epoch 0:  90% 863/954 [03:10<00:20,  4.54it/s, loss=0.594, v_num=3, train_loss=0.569]
Epoch 0:  91% 870/954 [03:10<00:18,  4.57it/s, loss=0.594, v_num=3, train_loss=0.569]
Epoch 0:  92% 877/954 [03:10<00:16,  4.61it/s, loss=0.594, v_num=3, train_loss=0.569]
Epoch 0:  93% 884/954 [03:10<00:15,  4.64it/s, loss=0.594, v_num=3, train_loss=0.569]
Validating:  34% 36/106 [00:45<00:15,  4.43it/s]
Epoch 0:  93% 891/954 [03:10<00:13,  4.68it/s, loss=0.594, v_num=3, train_loss=0.569]
Epoch 0:  94% 901/954 [03:10<00:11,  4

Never forget to test. The handy thing with the `Trainer` is that if a `.test()` is called somewhere at runtime and training gets interrupted, for example by a `KeyboardInterrupt` (the exception Python raises on `SIGINT`, typically Ctrl+C), Lightning catches the interruption, tries to gracefully release resources and terminate training, and the test still runs.
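
The interruption handling described above follows a common Python pattern, sketched here with toy functions. `toy_fit` and `toy_test` are hypothetical stand-ins and do not reproduce Lightning's internals. The interruption is caught inside the fit call, cleanup runs in a `finally` block, and control returns normally so that the test still executes.

```python
from typing import List, Optional


def toy_fit(num_steps: int, log: List[str],
            interrupt_at: Optional[int] = None) -> int:
    """Toy training loop that survives a KeyboardInterrupt."""
    completed = 0
    try:
        for step in range(num_steps):
            if step == interrupt_at:
                raise KeyboardInterrupt  # simulates the user pressing Ctrl+C
            completed += 1
    except KeyboardInterrupt:
        log.append("interrupt caught")
    finally:
        log.append("resources released")
    return completed


def toy_test(log: List[str]) -> None:
    log.append("test run")


log: List[str] = []
steps_done = toy_fit(100, log, interrupt_at=42)
toy_test(log)  # still reached, because the interrupt did not propagate

print(steps_done)  # 42
print(log)         # ['interrupt caught', 'resources released', 'test run']
```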

In [13]:
cpu_trainer.test(model=mlp, datamodule=datamodule)



Testing:  95% 101/106 [00:56<00:00, 20.14it/s]
--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_loss': 0.5971405506134033, 'test_loss_epoch': 0.5971405506134033}
--------------------------------------------------------------------------------
Testing: 100% 106/106 [00:56<00:00,  1.86it/s]


[{'test_loss': 0.5971405506134033, 'test_loss_epoch': 0.5971405506134033}]