****************************************************************

In [39]:
# Make the bootcamp root importable so local modules such as `config` resolve.
# `osp.abspath('..')` is relative to the current working directory (the notebook's
# directory when it is launched from Jupyter).
import os.path as osp
import sys

BOOTCAMP_ROOT = osp.abspath('..')
# Guard the append so re-running this cell does not keep growing sys.path.
if BOOTCAMP_ROOT not in sys.path:
    sys.path.append(BOOTCAMP_ROOT)

# PyTorch Lightning abstraction

Putting it all together with the PyTorch Lightning (PL) abstractions.

Let's first load all the necessary parameters.

In [74]:
import numpy as np
import os
import os.path as osp
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch_geometric as pyg
import torch_optimizer as optim
import torchmetrics.functional as F

import config

# Timestep selected in the project config, and the matching parameters of the
# "flattened" dataset layout (the params table is keyed by the stringified timestep).
step = config.config['timestep']
step_params = config.params[str(step)]
params = step_params['flattened']

## Model and training logic

In [75]:
class ThreeDCorrectionModule(pl.LightningModule):
    """Shared Lightning training/validation/test logic for the 3D-correction models.

    Subclasses must set ``self.net`` (the network applied in ``forward``) and
    ``self.lr`` (the learning rate passed to the AdamP optimizer).
    """

    def forward(self, x: torch.Tensor):
        return self.net(x)

    def _common_step(self, batch, batch_idx, stage):
        """Run the model on one ``(x, y)`` batch and log the MSE loss.

        Args:
            batch: ``(x, y)`` pair of input features and targets.
            batch_idx: index of the batch (unused; kept for Lightning's hook signature).
            stage: prefix of the logged metric name ("train", "val" or "test").

        Returns:
            Tuple ``(y_hat, loss)`` of model predictions and the mean squared error.
        """
        x, y = batch
        y_hat = self(x)

        loss = F.mean_squared_error(y_hat, y)
        # `batch` is the (x, y) tuple, so len(batch) is always 2; the number of
        # samples is the leading dimension of x. Lightning uses batch_size to
        # weight the epoch-level average of the logged value.
        self.log(f"{stage}_loss", loss, prog_bar=True, on_step=True, batch_size=len(x))

        return y_hat, loss

    def training_step(self, batch, batch_idx):
        """Return the training loss so Lightning can backpropagate it."""
        _, loss = self._common_step(batch, batch_idx, "train")
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the validation loss (nothing is returned)."""
        self._common_step(batch, batch_idx, "val")

    def test_step(self, batch, batch_idx):
        """Log the test loss (nothing is returned)."""
        self._common_step(batch, batch_idx, "test")

    def configure_optimizers(self):
        """AdamP over all parameters, with the learning rate set by the subclass."""
        return optim.AdamP(self.parameters(), lr=self.lr)

In [76]:
class LitMLP(ThreeDCorrectionModule):
    """MLP with three SiLU hidden layers mapping flattened inputs to flattened targets.

    Inputs are standardized by the nested ``Normalize`` module before the first layer.

    Args:
        in_channels: number of input features.
        hidden_channels: width of each hidden layer.
        out_channels: number of output features.
        lr: learning rate used by ``configure_optimizers``.
    """

    class Normalize(nn.Module):
        """Standardize inputs as ``(x - x_mean) / (x_std + epsilon)``.

        ``x_mean`` and ``x_std`` are read from ``stats-flattened-{timestep}.pt`` under
        ``config.data_path`` for the timestep configured in ``config.config``.
        """

        def __init__(self):
            super().__init__()
            # Plain float rather than a CPU tensor, so nothing here is pinned to the CPU.
            self.epsilon = 1.e-8

            step = config.config['timestep']
            # map_location="cpu" lets statistics saved from a GPU session load on a
            # CPU-only machine; the buffers below then follow the module's .to()/device moves.
            stats = torch.load(
                os.path.join(config.data_path, f"stats-flattened-{step}.pt"),
                map_location="cpu",
            )

            # Buffers (unlike plain attributes) are moved with the module to the training
            # device, avoiding a CPU/GPU mismatch in forward(). persistent=False keeps them
            # out of state_dict, so existing checkpoints still load unchanged.
            self.register_buffer("x_mean", torch.as_tensor(stats["x_mean"]), persistent=False)
            self.register_buffer("x_std", torch.as_tensor(stats["x_std"]), persistent=False)

        def forward(self, x: torch.Tensor):
            return (x - self.x_mean) / (self.x_std + self.epsilon)

    def __init__(self,
                 in_channels: int,
                 hidden_channels: int,
                 out_channels: int,
                 lr: float = 1e-4):
        super().__init__()
        # Record the constructor arguments in checkpoints so load_from_checkpoint can rebuild the model.
        self.save_hyperparameters()

        self.lr = lr
        self.net = nn.Sequential(
            self.Normalize(),
            nn.Linear(in_channels, hidden_channels),
            nn.SiLU(),
            nn.Linear(hidden_channels, hidden_channels),
            nn.SiLU(),
            nn.Linear(hidden_channels, hidden_channels),
            nn.SiLU(),
            nn.Linear(hidden_channels, out_channels),
        )

In [77]:
# The feature count of each array is the last axis of its stored shape.
x_feats, y_feats = params['x_shape'][-1], params['y_shape'][-1]

In [78]:
for name, n_feats in (('x', x_feats), ('y', y_feats)):
    print(f'{name} number of features: {n_feats}')

x number of features: 4128
y number of features: 552


In [79]:
# Input/output widths come from the stored data shapes; 100 units per hidden layer.
mlp = LitMLP(in_channels=x_feats, hidden_channels=100, out_channels=y_feats)
mlp

LitMLP(
  (net): Sequential(
    (0): Normalize()
    (1): Linear(in_features=4128, out_features=100, bias=True)
    (2): SiLU()
    (3): Linear(in_features=100, out_features=100, bias=True)
    (4): SiLU()
    (5): Linear(in_features=100, out_features=100, bias=True)
    (6): SiLU()
    (7): Linear(in_features=100, out_features=552, bias=True)
  )
)

## Dataset creation and data loading mechanics

In [80]:
class FlattenedDataset(torch.utils.data.Dataset):
    """Map-style dataset over the sharded, flattened ``(x, y)`` arrays of one timestep.

    Shards live at ``<config.processed_data_path>/flattened-<step>/{x,y}/<shard>.npy``.
    Sample ``idx`` is row ``idx % shard_size`` of shard ``idx // shard_size``.
    """

    def __init__(self, step):
        super().__init__()
        self.step = step
        self.params = config.params[str(self.step)]['flattened']

    def __len__(self):
        return self.params['dataset_len']

    def __getitem__(self, idx):
        """Return the squeezed ``(x, y)`` tensors of sample ``idx``."""
        # NOTE(review): assumes every shard holds exactly len(self) // num_shards rows;
        # if dataset_len is not a multiple of num_shards, trailing indices map to a
        # non-existent shard — confirm against the script that wrote the shards.
        shard_size = len(self) // self.params['num_shards']
        fileidx = idx // shard_size
        rowidx = idx % shard_size
        main_path = osp.join(config.processed_data_path, f"flattened-{self.step}")

        def _load(name):
            # mmap_mode='r' memory-maps the shard so only the requested row is read
            # from disk, instead of loading the whole shard for every single sample.
            data = np.load(osp.join(main_path, name, f'{fileidx}.npy'), mmap_mode='r')
            # np.array copies the row out of the memory map into a regular array.
            return torch.squeeze(torch.from_numpy(np.array(data[rowidx, ...])))

        x = _load('x')
        y = _load('y')

        return x, y

In [81]:
class FlattenedDataModule(pl.LightningDataModule):
    """Serve a reproducible 80/10/10 train/val/test split of ``FlattenedDataset``.

    Args:
        batch_size: samples per batch for every dataloader.
        num_workers: worker processes per dataloader (0 = load in the main process).
        seed: seed of the generator used for the random split; a fixed seed makes
            the split identical across ``setup`` calls.
    """

    def __init__(self,
                 batch_size: int,
                 num_workers: int,
                 seed: int = 42):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.seed = seed
        self.timestep = config.config['timestep']

    def prepare_data(self):
        # No download or preprocessing step is performed here.
        pass

    def setup(self, stage: str):
        """Split the dataset into ``self.train``, ``self.val`` and ``self.test``."""
        dataset = FlattenedDataset(self.timestep)
        length = len(dataset)

        n_train = int(length * .8)
        n_val = int(length * .1)
        # The test split takes the rounding remainder so the lengths always sum to
        # `length`; random_split raises otherwise (e.g. when length is not a multiple of 10).
        n_test = length - n_train - n_val

        # Lightning calls setup() separately for fit and test; an unseeded split would
        # differ between calls and leak training samples into the test set.
        generator = torch.Generator().manual_seed(self.seed)
        self.train, self.val, self.test = torch.utils.data.random_split(
            dataset,
            [n_train, n_val, n_test],
            generator=generator)

    def train_dataloader(self):
        return torch.utils.data.DataLoader(
            self.train,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers)

    def val_dataloader(self):
        return torch.utils.data.DataLoader(
            self.val,
            batch_size=self.batch_size,
            num_workers=self.num_workers)

    def test_dataloader(self):
        return torch.utils.data.DataLoader(
            self.test,
            batch_size=self.batch_size,
            num_workers=self.num_workers)

In [82]:
# 256 samples per batch, loaded in the main process (no worker processes).
datamodule = FlattenedDataModule(batch_size=256, num_workers=0)

## Orchestrating the training

All the training instrumentation is handled by an object called the `Trainer`. You can set parameters such as `max_epochs`, the `accelerator` type and the `devices` to run on.

Also interesting: 
* `callbacks` to run custom logic at given points of the training loop (e.g. checkpointing, early stopping)
* `gradient_clip_val` and `gradient_clip_algorithm` to set up gradient clipping
* `logger` to interface with loss and metrics logging

In [83]:
# Spell out the epoch budget instead of relying on Lightning's implicit default
# (1000 epochs when neither max_epochs nor max_steps is set), which it warns about.
# Lower it for a quick run.
MAX_EPOCHS = 1000

trainer = pl.Trainer(max_epochs=MAX_EPOCHS)
trainer.fit(
    model=mlp,
    datamodule=datamodule
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name | Type       | Params
------------------------------------
0 | net  | Sequential | 488 K 
------------------------------------
488 K     Trainable params
0         Non-trainable params
488 K     Total params
1.955     Total estimated model params size (MB)


                                                                      

  f"The dataloader, {name}, does not have many workers which may be a bottleneck."


Epoch 0:   0%|          | 4/3816 [00:12<3:16:07,  3.09s/it, loss=1.93, v_num=17, train_loss=2.080]