## Introduction
Motivation - We go through many kernels and repos trying to understand what a model is doing. Apart from the fundamental differences in architecture and training procedure, we see a lot of variation in coding style, such as how the training loop is written and what happens at the end of each training or validation epoch. So you have to spend time understanding things that do not really contribute to the fundamental approach. Even the person developing the model spends a good deal of time writing boilerplate code.  

This kernel is a quick introduction to [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning#how-do-i-do-use-it), a PyTorch wrapper for ML researchers. It automates non-essential procedures and enforces a good coding style, making machine learning solutions much more consistent and reproducible. "More of a style guide than a framework".  

![lightning_logo.svg](attachment:lightning_logo.svg)

tl;dr just jump into [The Lightning Module](#The-Lightning-Module) section which has the stuff I wanted to share!

In [None]:
!pip install pytorch-lightning

Importing libraries

In [None]:
import numpy as np 
import pandas as pd 

import os
import gc 
import sys

from IPython.core.display import display
from ipywidgets import IntSlider, interact
from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid

import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import transforms, models
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.optim import Adam

from pytorch_lightning import Trainer
from pytorch_lightning.logging import WandbLogger
from pytorch_lightning.core import LightningModule


# notebook params
_ = plt.rcParams['figure.figsize'] = [15, 2]
np.random.seed(400)

In [None]:
package_path = '../input/efficientnet/efficientnet-pytorch/EfficientNet-PyTorch/'
sys.path.append(package_path)
from efficientnet_pytorch import EfficientNet

### Data Overview

In [None]:
import os
for dirname, _, filenames in os.walk('/kaggle/input/bengaliai-cv19'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

### Reading The CSVs

In [None]:
class_map_df = pd.read_csv('/kaggle/input/bengaliai-cv19/class_map.csv')
print('class map shape: ', class_map_df.shape)
class_map_df.sample(50).drop_duplicates(['component_type'])

In [None]:
train_df = pd.read_csv('/kaggle/input/bengaliai-cv19/train.csv')
print('train data shape: ', train_df.shape)
print('unique graphemes: ', train_df['grapheme'].nunique())
print('unique grapheme_root: ', train_df['grapheme_root'].nunique())
print('unique vowel_diacritic: ', train_df['vowel_diacritic'].nunique())
print('unique consonant_diacritic: ', train_df['consonant_diacritic'].nunique())

train_df.head()

In [None]:
test_df = pd.read_csv('/kaggle/input/bengaliai-cv19/test.csv')
print('test data shape: ', test_df.shape)
test_df.head(6)

In [None]:
sample_sub_df = pd.read_csv('/kaggle/input/bengaliai-cv19/sample_submission.csv')
sample_sub_df.head(6)

In [None]:
# clean up
del train_df
del test_df
del class_map_df
_ = gc.collect()

### Reading The Image files

Since loading the parquet files is a bit slow, we use another [public dataset](https://www.kaggle.com/corochann/bengaliaicv19feather) in [feather format](https://github.com/wesm/feather) which is around 30 times faster!

Some helper functions to handle and visualize the data

In [None]:
def get_image_data(mode='val', debug=False):
    '''
    helper function for PyTorch Dataset class
    
    Arguments:
        mode (str) -- 'train' and 'val' read the feather files with train in the filename
                      (the last of them is held out for 'val'); 'test' reads the files with test in the filename
        debug (bool) -- if True, keep only the first 25 rows to save time and memory
                           
    Returns:
        img_df (dataframe) -- images of the requested split, one row per image
    '''
    
    img_list = []
    file_type = mode # to fetch files
    if mode == 'val':
        file_type = 'train'
    for dirname, _, filenames in os.walk('/kaggle/input/bengaliaicv19feather'):
        # os.walk gives no ordering guarantee, so sort to always hold out the same file for 'val'
        for filename in sorted(filenames):
            if file_type in filename:
                img_list.append(pd.read_feather(os.path.join(dirname, filename)))
                           
    if mode == 'val':
        img_df = pd.DataFrame(img_list[-1])
    elif mode == 'train':
        img_df = pd.concat(img_list[0:-1])
    else:
        img_df = pd.concat(img_list)
        
    print(f"[Helper] {mode} image dataset: {img_df.shape}")
    
    if debug:
        # toy subset, enough for the 5x5 preview grid and a quick sanity check
        img_df = img_df[0:25]
    
    del img_list
    _ = gc.collect()
    
    return img_df
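
The split rule used by `get_image_data` can be tried in isolation. The sketch below is only an illustration: the file names are made up, and `split_files` is a hypothetical helper that is not part of the kernel. It sorts the names before splitting (an addition made here, since `os.walk` gives no ordering guarantee) and then holds out the last `train` file for validation.

```python
def split_files(filenames, mode):
    """Pick the files for a split; hypothetical helper, not part of the kernel."""
    file_type = 'train' if mode == 'val' else mode
    matching = sorted(name for name in filenames if file_type in name)
    if mode == 'val':
        return matching[-1:]      # the last train file is held out
    if mode == 'train':
        return matching[:-1]      # everything except the held-out file
    return matching               # test: use every matching file


# made-up names, for illustration only
names = ['train_2.feather', 'test_0.feather', 'train_0.feather', 'train_1.feather']
train_files = split_files(names, 'train')
val_files = split_files(names, 'val')
test_files = split_files(names, 'test')
```

With a fixed order, the validation images never leak into the training split between runs.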

In [None]:
def plot_images(rows, cols, img_df, train=True):
    """
    Grid of images
    
    Arguments:
        rows, cols (int, int) -- dimension of the image grid
        img_df (dataframe) -- Dataframe of all the images
        train (boolean) -- fetch meta data from the csv files accordingly
    """
    
    fig = plt.figure(figsize=(15., 12.))
    grid = ImageGrid(fig, 111,  # similar to subplot(111)
                 nrows_ncols=(rows, cols),  # creates a rows x cols grid of axes
                 axes_pad=0.3,  # pad between axes in inch.
                 )
    train_df = pd.read_csv('/kaggle/input/bengaliai-cv19/train.csv')
    test_df = pd.read_csv('/kaggle/input/bengaliai-cv19/test.csv')

    meta_df = train_df if train else test_df
    
    for ax, df_row in zip(grid, img_df.sample(rows*cols).values.tolist()):
        # Iterating over the grid returns the Axes.
        _ = ax.imshow(np.asarray(df_row[1:]).astype(int).reshape(137,236))
        # fetch the sample's labels from the csv file
        meta = meta_df[meta_df['image_id']==df_row[0]].values[0]
        
        if train:
            title =  f'{df_row[0]}_{meta[1]}_{meta[2]}_{meta[3]}'
        else: 
            title =  f'{df_row[0]}'
            
        _ = ax.set_title(title)
        _ = ax.axis('off')
   
    _ = plt.show()
    
    del train_df
    del test_df
    del meta_df
    _ = gc.collect()

In [None]:
%%time

# get the data
img_df = get_image_data(mode='train', debug=True)
# visualise a few images
plot_images(5, 5, img_df, train=True)

In [None]:
# clear the data after visualization
# Since we will be loading the data using PyTorch Dataset class again in the later cells
del img_df
_ = gc.collect()

## Preparing PyTorch Dataset

In [None]:
class BengaliAI(Dataset):
    """Bengali AI dataset for training PyTorch models"""

    def __init__(self, mode='val', transform=None, debug=False):
        """
        Arguments:
            mode (str) -- 'train', 'val' or 'test'; selects the metadata csv and the image files.
                          Defaults to val since it is the smaller labelled split
            transform (callable) -- Transform to be applied on each sample 
            debug (bool) -- forwarded to get_image_data
        """
        self.mode = mode
        if self.mode == 'train' or self.mode == 'val':         
            self.metadata = pd.read_csv('/kaggle/input/bengaliai-cv19/train.csv')
        else:
            self.metadata = pd.read_csv('/kaggle/input/bengaliai-cv19/test.csv')

        self.data = get_image_data(mode, debug)
        self.transform = transform
        
        if self.mode != 'test':
            # index the labels by image_id: the image rows are only a subset of train.csv,
            # so looking labels up by position would pair images with the wrong labels
            _categorical_columns = ['grapheme_root', 'vowel_diacritic', 'consonant_diacritic']
            self.labels = self.metadata.set_index('image_id')[_categorical_columns]
                                                
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        """
        Arguments:
            idx (int) -- Dataset class is a map-style dataset (https://pytorch.org/docs/stable/data.html#map-style-datasets) 
                         where idx is the index of a specific sample in the map

        Returns:
            sample (dict) -- Each sample is a dict with keys image_id, image and, except in test mode,
                             grapheme_root, vowel_diacritic, consonant_diacritic
        """
        # we will discard the image id prefix 'Train_/Test_' so it is easy to format it as a tensor
        # data = image data
        _data_at_idx = self.data.iloc[idx]
        _full_image_id = _data_at_idx.iloc[0]
        _image_id = int(_full_image_id.split('_')[1])
        # plain float array of pixel values, ready to be turned into a tensor
        _image = _data_at_idx.iloc[1:].to_numpy(dtype=np.float32)
        
        if self.mode == 'test':
            sample = {'image_id': _image_id,
              'image': _image
             }
        
        else:
            _labels = self.labels.loc[_full_image_id]

            sample = {'image_id': _image_id,
                      'image': _image,
                      'grapheme_root': _labels['grapheme_root'],
                      'vowel_diacritic': _labels['vowel_diacritic'],
                      'consonant_diacritic': _labels['consonant_diacritic']
                     }

        if self.transform:
            sample = self.transform(sample)
            
        return sample
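
Because each sample is a dict, it helps to see what a `DataLoader` does with such samples. The sketch below is a toy illustration, not the kernel's dataset: `ToyDictDataset` is a hypothetical stand-in that returns tiny fake images. With the default collate function, a batch is again a dict with the same keys, and every value gains a leading batch dimension.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyDictDataset(Dataset):
    """Hypothetical stand-in for BengaliAI: four tiny fake images with one label each."""

    def __len__(self):
        return 4

    def __getitem__(self, idx):
        return {'image_id': torch.tensor(idx),
                'image': torch.full((3, 2, 2), float(idx)),
                'grapheme_root': torch.tensor(idx % 2)}


loader = DataLoader(ToyDictDataset(), batch_size=2)
batch = next(iter(loader))

# every key now holds a tensor with a leading batch dimension
image_shape = tuple(batch['image'].shape)          # (2, 3, 2, 2)
label_shape = tuple(batch['grapheme_root'].shape)  # (2,)
```

This is why the training code can index a batch with `batch['image']` and `batch['grapheme_root']` directly.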

In [None]:
class ToTensor(object):
    """Convert every value in sample to a float Tensor and make the image 3 channel"""

    def __call__(self, sample):
        for key in sample.keys():
            sample[key] = torch.tensor(sample[key], dtype=torch.float32)
            if key == 'image':
                sample[key] = sample[key].reshape(137, 236).repeat(3, 1, 1)
        
        return sample
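
The reshape-and-repeat step is worth checking on its own. The sketch below uses a fake flattened image (made-up pixel values) of the same 137 x 236 size: `repeat(3, 1, 1)` copies the single grayscale channel three times, which is one simple way to feed grayscale data to a backbone that expects RGB input.

```python
import torch

flat = torch.arange(137 * 236, dtype=torch.float32)  # fake flattened grayscale image
image = flat.reshape(137, 236).repeat(3, 1, 1)        # copy the single channel three times

shape = tuple(image.shape)
channels_identical = bool(torch.equal(image[0], image[1])
                          and torch.equal(image[1], image[2]))
```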

#### Sanity Check

In [None]:
%%time 
bengali_dataset= BengaliAI(mode='test', transform=ToTensor(), debug=True)
sample = bengali_dataset[0]
sample

In [None]:
del bengali_dataset
gc.collect()

## The Lightning Module
This is "**The Juice**", where all your research and creativity goes!    
PyTorch Lightning also provides a [research seed](https://github.com/williamFalcon/pytorch-lightning-conference-seed), a cookie-cutter-like template for your project repository.

<div class="alert alert-block alert-info">
<b>💡</b>
A [LightningModule](https://pytorch-lightning.readthedocs.io/en/latest/lightning-module.html) is a subclass of torch.nn.Module that adds an interface to standardize the “ingredients” for a research or production system. 
</div>
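
The module in the next cell follows a common multi-task pattern: one shared feature vector, three classification heads, and a training loss that is the sum of the three cross-entropy terms. Here is a minimal sketch of that pattern in plain PyTorch. The random `features` tensor stands in for the backbone output, and the batch size and feature size are arbitrary assumptions; only the head sizes (168, 11, 7) come from the kernel.

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)
batch_size, in_features = 4, 32                  # arbitrary sizes for the sketch
features = torch.randn(batch_size, in_features)  # stands in for the backbone output

heads = {'grapheme_root': torch.nn.Linear(in_features, 168),
         'vowel_diacritic': torch.nn.Linear(in_features, 11),
         'consonant_diacritic': torch.nn.Linear(in_features, 7)}

# random integer class labels, one per sample and head
targets = {name: torch.randint(0, head.out_features, (batch_size,))
           for name, head in heads.items()}

# one cross-entropy term per head, summed into a single training loss
losses = {name: F.cross_entropy(head(features), targets[name])
          for name, head in heads.items()}
total_loss = sum(losses.values())
total_loss.backward()                            # gradients reach all three heads
```

Summing gives every head equal weight; weighting the terms differently is a design choice this sketch does not explore.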

In [None]:
class ResNext3(LightningModule):

    def __init__(self):
        super(ResNext3, self).__init__()
        # The class keeps its ResNext name, but the backbone used here is EfficientNet-b7
        # backbone_model = torch.hub.load('pytorch/vision:v0.5.0', 'resnext50_32x4d', pretrained=False)
        self.toy_data = True
        backbone_model = EfficientNet.from_name('efficientnet-b7') 
        backbone_model.load_state_dict(torch.load('../input/efficientnet-pytorch/efficientnet-b7-dcc49843.pth'))
        # Keep the whole backbone and call only its convolutional feature extractor in forward().
        # Its children cannot simply be chained in a Sequential because the blocks live in a ModuleList
        self.backbone = backbone_model
        # Pool the feature map to one value per channel so it matches the input size of the heads
        self.pool = torch.nn.AdaptiveAvgPool2d(1)
        in_features = backbone_model._fc.in_features
        self.fc_grapheme_root = torch.nn.Linear(in_features, 168)
        self.fc_vowel_diacritic = torch.nn.Linear(in_features, 11)
        self.fc_consonant_diacritic = torch.nn.Linear(in_features, 7)

    def forward(self, x):
        # features before the backbone's own pooling and classifier
        x = self.backbone.extract_features(x)
        x = torch.flatten(self.pool(x), 1)
        grapheme = self.fc_grapheme_root(x)
        vowel = self.fc_vowel_diacritic(x)
        consonant = self.fc_consonant_diacritic(x)    
        return grapheme, vowel, consonant
        
    def training_step(self, batch, batch_idx):
        grapheme, vowel, consonant = self(batch['image'])
        loss_grapheme = F.cross_entropy(grapheme, batch['grapheme_root'].long())
        loss_vowel = F.cross_entropy(vowel, batch['vowel_diacritic'].long())
        loss_consonant = F.cross_entropy(consonant, batch['consonant_diacritic'].long())
        logger_logs = {"tl_grapheme": loss_grapheme, 
                       "tl_vowel": loss_vowel, 
                       "tl_consonant": loss_consonant}
        
        return {'loss': loss_grapheme+loss_vowel+loss_consonant, 'log': logger_logs}

    def validation_step(self, batch, batch_idx):
        grapheme, vowel, consonant = self(batch['image'])
        loss_grapheme = F.cross_entropy(grapheme, batch['grapheme_root'].long())
        loss_vowel = F.cross_entropy(vowel, batch['vowel_diacritic'].long())
        loss_consonant = F.cross_entropy(consonant, batch['consonant_diacritic'].long())
        logger_logs = {"vl_grapheme": loss_grapheme, 
                       "vl_vowel": loss_vowel, 
                       "vl_consonant": loss_consonant}
        return {'val_loss': loss_grapheme+loss_vowel+loss_consonant, 'log': logger_logs}
                                                                                            
    def validation_end(self, outputs):
        logger_logs = {'avg_val_loss': torch.stack([x['val_loss'] for x in outputs]).mean(),
                       "avl_grapheme": torch.stack([x['log']['vl_grapheme'] for x in outputs]).mean(), 
                       "avl_vowel": torch.stack([x['log']['vl_vowel'] for x in outputs]).mean(), 
                       "avl_consonant": torch.stack([x['log']['vl_consonant'] for x in outputs]).mean()
                       }
        # must return 'val_loss' as key
        return {'val_loss': logger_logs['avg_val_loss'], 'log': logger_logs} 
    
    def configure_optimizers(self):
        optimizer = Adam(self.parameters(), lr=0.001)
        scheduler = ReduceLROnPlateau(optimizer)
        return [optimizer],[scheduler]

    def train_dataloader(self):
        # debug=self.toy_data requests the small toy subset while this kernel is only a demo
        return DataLoader(BengaliAI(mode='train', transform=ToTensor(), debug=self.toy_data), batch_size=64, pin_memory=True)
    
    def val_dataloader(self):
        return DataLoader(BengaliAI(mode='val', transform=ToTensor(), debug=self.toy_data), batch_size=64, pin_memory=True)

    def test_dataloader(self):
        return DataLoader(BengaliAI(mode='test', transform=ToTensor(), debug=self.toy_data), batch_size=64, pin_memory=True)
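
The aggregation in `validation_end` can be checked with made-up numbers. The sketch below builds two fake per-batch outputs in the same shape that `validation_step` returns and averages them with `torch.stack(...).mean()`; the values are invented for the example.

```python
import torch

# fake per-batch outputs in the same shape validation_step returns
outputs = [{'val_loss': torch.tensor(1.0), 'log': {'vl_grapheme': torch.tensor(0.5)}},
           {'val_loss': torch.tensor(3.0), 'log': {'vl_grapheme': torch.tensor(1.5)}}]

# stack the 0-dim tensors into a vector and take the mean over batches
avg_val_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
avg_grapheme = torch.stack([x['log']['vl_grapheme'] for x in outputs]).mean()
```

Note that this is a mean of per-batch means, so a smaller final batch counts as much as a full one; for a rough validation signal that is usually acceptable.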

In [None]:
model = ResNext3()
wandb_logger = WandbLogger(name='The-run', project='bengali-ai')
trainer = Trainer(gpus=0, fast_dev_run=True, checkpoint_callback=None)

- PyTorch Lightning supports many experiment tracking platforms such as wandb, comet and test tube.   
- The logger should be set to wandb_logger, but wandb needs an API key that cannot be provided when committing the kernel, so it is not passed to the Trainer here.
- Simply set gpus=1 or more to train on GPUs.  
- Setting fast_dev_run to True runs every step of training once to make sure everything is in place.  

  You have tons of these elegant abstractions. Isn't it cool?

In [None]:
trainer.fit(model)

*This is a work in progress. Should I improve this, showing more features of PL or is it too late for example kernels for this competition?*

In [None]:
del trainer