<div class="alert alert-info">
    <h2 align='center'>🐾 PyTorch Lightning Training Baseline for GPU & TPU + W&B Tracking 🚄</h2>
</div>

<p style='text-align: center'>
    To run the model on TPU, un-comment and run the below cell and change the <code>gpus=1</code> argument to <code>tpu_cores=1</code> or <code>tpu_cores=8</code> in the <code>Trainer</code> class.
</p>

<h1 style='color: #fc0362; font-family: Segoe UI; font-size: 1.5em; font-weight: 300; font-size: 24px'>If you liked this notebook, kindly leave an upvote ⬆️</h1>

In [None]:
# ! curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
# ! python pytorch-xla-env-setup.py --version 1.7 --apt-packages libomp5 libopenblas-dev

<h1 align='center' style='color: #8532a8; font-family: Segoe UI; font-size: 1.5em; font-weight: 300; font-size: 32px'>1. Installation & Imports</h1>

In [None]:
%%sh
# Install the extra packages this notebook needs; -q keeps pip's output short.
# Only pytorch-lightning is pinned (1.1.8). timm and albumentations are
# installed at whatever version pip resolves, and wandb is upgraded to the
# latest release, so a re-run at a later date may pick up different versions.
pip install -q pytorch-lightning==1.1.8
pip install -q timm
pip install -q albumentations
pip install -q --upgrade wandb

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import timm
import torch
import transformers
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

import os
import cv2
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import r2_score, mean_squared_error

import wandb
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint

from albumentations import (
    HorizontalFlip, VerticalFlip, IAAPerspective, ShiftScaleRotate, CLAHE, RandomRotate90,
    Transpose, ShiftScaleRotate, Blur, OpticalDistortion, GridDistortion, HueSaturationValue,
    IAAAdditiveGaussianNoise, GaussNoise, MotionBlur, MedianBlur, IAAPiecewiseAffine, RandomResizedCrop,
    IAASharpen, IAAEmboss, RandomBrightnessContrast, Flip, OneOf, Compose, Normalize, Cutout, CoarseDropout, ShiftScaleRotate, CenterCrop, Resize
)
from albumentations.pytorch import ToTensorV2

import warnings
warnings.simplefilter('ignore')

<center><img src="https://i.imgur.com/gb6B4ig.png" width="400" alt="Weights & Biases"/></center><br>
<p style="text-align:center">WandB is a developer tool that helps companies turn deep learning research projects into deployed software by helping teams track their models, visualize model performance, and easily automate training and improving models.
We will use its tools to log hyperparameters and output metrics from our runs, then visualize and compare results and quickly share findings with colleagues.<br><br></p>

![img](https://i.imgur.com/BGgfZj3.png)

To log in to W&B, you can use the snippet below.

```python
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
wb_key = user_secrets.get_secret("WANDB_API_KEY")

wandb.login(key=wb_key)
```
Make sure you have your W&B key stored as `WANDB_API_KEY` under Add-ons -> Secrets.

You can view [this](https://www.kaggle.com/ayuraj/experiment-tracking-with-weights-and-biases) notebook to learn more about W&B tracking.

If you don't want to log in to W&B, the kernel will still work and log everything to W&B in anonymous mode.

In [None]:
# Single source of truth for the run's settings. The same dict is also passed
# to the W&B logger as the run config, so every key below is recorded there.
Config = {
    # Cross-validation and schedule
    'NFOLDS': 5,                      # number of StratifiedKFold splits
    'EPOCHS': 5,                      # Trainer max_epochs per fold
    'LR': 2e-4,                       # Adam learning rate
    # Input and backbone
    'IMG_SIZE': (224, 224),           # size images are resized to
    'MODEL_NAME': 'tf_efficientnet_b6_ns',  # timm model identifier
    'DR_RATE': 0.35,                  # not read by the code in this notebook
    'NUM_LABELS': 1,                  # output width of the regression head
    # Data loading
    'TRAIN_BS': 32,                   # training batch size
    'VALID_BS': 16,                   # validation batch size
    # Learning-rate scheduler
    'min_lr': 1e-6,                   # CosineAnnealingLR eta_min
    'T_max': 20,                      # CosineAnnealingLR T_max
    'T_0': 25,                        # not read by the code in this notebook
    'NUM_WORKERS': 4,                 # DataLoader worker processes
    # Run metadata (only forwarded to W&B)
    'infra': "Kaggle",
    'competition': 'petfinder',
    '_wandb_kernel': 'tanaym',
    'wandb': False,
}

In [None]:
# Weights & Biases logger handed to the Lightning Trainer further down.
# `anonymous='allow'` lets the notebook log even when no W&B login happened;
# the full Config dict is attached to the run as its hyper-parameter record.
wandb_logger = WandbLogger(
    config=Config,
    project='pytorchlightning',
    group='vision',
    job_type='train',
    anonymous='allow',
)

<h1 align='center' style='color: #8532a8; font-family: Segoe UI; font-size: 1.5em; font-weight: 300; font-size: 32px'>2. Dataset Class</h1>

In [None]:
class PetfinderData(Dataset):
    """Image + tabular-metadata dataset built from a PetFinder dataframe.

    Args:
        df: DataFrame with an ``Id`` column (image file stem). Every column
            other than ``Id`` and ``Pawpularity`` is used as a metadata
            feature. ``Pawpularity`` is required unless ``is_test`` is true.
        is_test: When true, images are read from the test directory and
            ``__getitem__`` returns no target.
        augments: Optional callable invoked as ``augments(image=img)`` whose
            result's ``'image'`` entry replaces the image.

    Items are ``(img, meta_feats, target)`` for training data and
    ``(img, meta_feats)`` for test data.

    The input dataframe is never modified: image paths are built into a
    separate list, so the same frame can safely back several datasets.
    """

    TRAIN_DIR = "../input/petfinder-pawpularity-score/train"
    TEST_DIR = "../input/petfinder-pawpularity-score/test"

    def __init__(self, df, is_test=False, augments=None):
        self.df = df
        self.is_test = is_test
        self.augments = augments

        self.images, self.meta_features, self.targets = self._process_df(self.df)

    def __getitem__(self, index):
        img_path = self.images[index]
        meta_feats = torch.tensor(self.meta_features[index], dtype=torch.float32)

        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals a missing/unreadable file by returning None;
            # fail here with the path instead of a TypeError on the next line.
            raise FileNotFoundError(f"Could not read image: {img_path}")
        img = img[:, :, ::-1]  # reverse the channel axis (OpenCV loads BGR)
        img = cv2.resize(img, Config['IMG_SIZE'])

        if self.augments:
            img = self.augments(image=img)['image']

        if self.is_test:
            return img, meta_feats

        target = torch.tensor(self.targets[index], dtype=torch.float32)
        return img, meta_feats, target

    def __len__(self):
        return len(self.df)

    def _process_df(self, df):
        """Split ``df`` into image paths, a metadata array and targets.

        Returns:
            ``(paths, meta_features, targets)`` where ``targets`` is ``None``
            when ``df`` has no ``Pawpularity`` column (test data only).

        Raises:
            KeyError: if training data (``is_test=False``) lacks ``Pawpularity``.
        """
        has_target = 'Pawpularity' in df.columns
        if not has_target and not self.is_test:
            raise KeyError("'Pawpularity' column is required when is_test=False")

        image_dir = self.TEST_DIR if self.is_test else self.TRAIN_DIR
        # Build the paths in a new list rather than overwriting df['Id'], so the
        # caller's frame is untouched and repeated construction is idempotent.
        image_paths = [os.path.join(image_dir, f"{image_id}.jpg") for image_id in df['Id']]

        # Only drop the non-feature columns that are actually present; test
        # frames have no 'Pawpularity' column.
        non_feature_cols = [col for col in ('Id', 'Pawpularity') if col in df.columns]
        meta_features = df.drop(columns=non_feature_cols).values

        targets = df['Pawpularity'].tolist() if has_target else None
        return image_paths, meta_features, targets

<h1 align='center' style='color: #8532a8; font-family: Segoe UI; font-size: 1.5em; font-weight: 300; font-size: 32px'>3. Augmentations</h1>

In [None]:
class Augments:
    """
    Albumentations pipelines for the training and validation splits.

    Both pipelines resize to Config['IMG_SIZE'], normalize and convert to a
    tensor; the training pipeline additionally applies random horizontal and
    vertical flips (each with probability 0.5).
    """
    # Normalization arguments shared by both pipelines.
    _NORMALIZE_KWARGS = dict(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
        max_pixel_value=255.0,
        p=1.0,
    )

    train_augments = Compose(
        [
            Resize(*Config['IMG_SIZE'], p=1.0),
            HorizontalFlip(p=0.5),
            VerticalFlip(p=0.5),
            Normalize(**_NORMALIZE_KWARGS),
            ToTensorV2(p=1.0),
        ],
        p=1.,
    )

    valid_augments = Compose(
        [
            Resize(*Config['IMG_SIZE'], p=1.0),
            Normalize(**_NORMALIZE_KWARGS),
            ToTensorV2(p=1.0),
        ],
        p=1.,
    )

<h1 align='center' style='color: #8532a8; font-family: Segoe UI; font-size: 1.5em; font-weight: 300; font-size: 32px'>4. Pytorch Lightning Model Class</h1>

In [None]:
class PetFinderModel(pl.LightningModule):
    """timm image backbone plus tabular metadata, trained as an RMSE regressor.

    The backbone's classifier is removed, its pooled features are concatenated
    with the metadata vector, and one linear layer maps the result to
    ``Config['NUM_LABELS']`` outputs.

    Args:
        pretrained: Passed through to ``timm.create_model``.
        n_meta_features: Width of the metadata vector concatenated to the image
            features. Defaults to 12, the value previously hard-coded.
    """

    def __init__(self, pretrained=True, n_meta_features=12):
        super().__init__()
        self.model = timm.create_model(Config['MODEL_NAME'], pretrained=pretrained)

        # Record the backbone feature width before stripping its classifier.
        self.n_features = self.model.classifier.in_features
        self.model.reset_classifier(0)
        self.fc = nn.Linear(self.n_features + n_meta_features, Config['NUM_LABELS'])

        self.train_loss = nn.MSELoss()
        self.valid_loss = nn.MSELoss()

    def forward(self, images, meta):
        """Return predictions for a batch of images and their metadata."""
        features = self.model(images)
        features = torch.cat([features, meta], dim=1)
        output = self.fc(features)
        return output

    def _rmse(self, criterion, batch):
        """Run the model on ``batch`` and return sqrt(criterion(pred, target))."""
        imgs, meta, target = batch
        out = self(imgs, meta)
        # The linear head yields (batch, NUM_LABELS) while scalar targets are
        # collated to (batch,). Without aligning the shapes, MSELoss broadcasts
        # the pair to (batch, batch) and computes a meaningless loss.
        target = target.reshape(out.shape)
        return torch.sqrt(criterion(out, target))

    def training_step(self, batch, batch_idx):
        """Compute and log the training RMSE for one batch."""
        train_loss = self._rmse(self.train_loss, batch)
        # self.log replaces the legacy 'log' key in the returned dict.
        self.log('train_loss', train_loss)
        return {'loss': train_loss}

    def validation_step(self, batch, batch_idx):
        """Compute the validation RMSE for one batch and log it per epoch."""
        valid_loss = self._rmse(self.valid_loss, batch)
        # Logged as an epoch-level metric so callbacks that monitor "val_loss"
        # (e.g. ModelCheckpoint) can find it.
        self.log('val_loss', valid_loss, on_step=False, on_epoch=True, prog_bar=True)
        return {'val_loss': valid_loss}

    @staticmethod
    def _mean_val_loss(outputs):
        """Average the per-batch 'val_loss' tensors collected over an epoch."""
        return torch.stack([x['val_loss'] for x in outputs]).mean()

    def validation_epoch_end(self, outputs):
        """Print the mean validation loss at the end of each validation epoch."""
        if not outputs:
            return
        print(f"val_loss: {self._mean_val_loss(outputs)}")

    def validation_end(self, outputs):
        """Legacy epoch-end hook, kept for backward compatibility.

        NOTE(review): `validation_end` is the old hook name, superseded by
        `validation_epoch_end`; confirm the installed Lightning version no
        longer calls it before removing it.
        """
        avg_loss = self._mean_val_loss(outputs)
        logs = {'val_loss': avg_loss}

        print(f"val_loss: {avg_loss}")
        return {'avg_val_loss': avg_loss, 'log': logs}

    def configure_optimizers(self):
        """Adam with cosine-annealed learning rate, both configured from Config."""
        opt = torch.optim.Adam(self.parameters(), lr=Config['LR'])
        sch = torch.optim.lr_scheduler.CosineAnnealingLR(
            opt,
            T_max=Config['T_max'],
            eta_min=Config['min_lr']
        )

        return [opt], [sch]

<h1 align='center' style='color: #8532a8; font-family: Segoe UI; font-size: 1.5em; font-weight: 300; font-size: 32px'>5. KFolds Model Training</h1>

In [None]:
# Run the K-fold training loop: one model is trained and checkpointed per fold.
kf = StratifiedKFold(n_splits=Config['NFOLDS'])
train_file = pd.read_csv("../input/petfinder-pawpularity-score/train.csv")

for fold_, (train_idx, valid_idx) in enumerate(kf.split(X=train_file, y=train_file['Pawpularity'])):
    print(f"{'='*20} Fold: {fold_} {'='*20}")

    # kf.split yields *positional* indices, so select rows with .iloc; .loc is
    # label-based and only matches while the index is the default RangeIndex.
    # reset_index gives each fold its own independent, 0-based frame.
    train_df = train_file.iloc[train_idx].reset_index(drop=True)
    valid_df = train_file.iloc[valid_idx].reset_index(drop=True)

    train_set = PetfinderData(
        train_df,
        augments = Augments.train_augments
    )

    valid_set = PetfinderData(
        valid_df,
        augments = Augments.valid_augments
    )

    train = DataLoader(
        train_set,
        batch_size=Config['TRAIN_BS'],
        shuffle=True,
        num_workers=Config['NUM_WORKERS'],
        pin_memory=True
    )
    valid = DataLoader(
        valid_set,
        batch_size=Config['VALID_BS'],
        shuffle=False,
        num_workers=Config['NUM_WORKERS']
    )

    # Keep only the best checkpoint (lowest "val_loss") for this fold.
    checkpoint_callback = ModelCheckpoint(
        monitor="val_loss",
        dirpath="./",
        filename=f"fold_{fold_}_{Config['MODEL_NAME']}",
        save_top_k=1,
        mode="min",
    )

    # A fresh model and Trainer per fold; the W&B logger is shared by all folds.
    model = PetFinderModel()
    trainer = pl.Trainer(
        max_epochs=Config['EPOCHS'], 
        gpus=1, 
        callbacks=[checkpoint_callback], 
        logger= wandb_logger
    )
    trainer.fit(model, train, valid)

## [View the Complete Dashboard Here 🎯](https://wandb.ai/anony-mouse-138818/pytorchlightning/runs/1mgcybd2?apiKey=9c1b4ff53762c75e39d283f5434cc5552455b179)

![](https://imgur.com/XyvhYFZ.gif)