This kernel is based on the excellent [⚡Plant2021 PyTorch Lightning Starter [ Training ]⚡](https://www.kaggle.com/pegasos/plant2021-pytorch-lightning-starter-training) by [Sh1r0](https://www.kaggle.com/pegasos). It is intended to showcase [Weights and Biases](https://wandb.ai/site) integration with PyTorch Lightning.

# ⚡ PyTorch Lightning

PyTorch is an extremely powerful framework for your deep learning research. But once the research gets complicated and things like 16-bit precision, multi-GPU training, and TPU training get mixed in, users are likely to introduce bugs. **PyTorch Lightning lets you decouple research from engineering.**

**PyTorch Lightning ⚡ is not another framework but a style guide for PyTorch.**

To learn more about PyTorch Lightning, check out my blog posts on Weights and Biases [Fully Connected](https://wandb.ai/fully-connected):

* [Image Classification using PyTorch Lightning](https://wandb.ai/wandb/wandb-lightning/reports/Image-Classification-using-PyTorch-Lightning--VmlldzoyODk1NzY)
* [Transfer Learning Using PyTorch Lightning](https://wandb.ai/wandb/wandb-lightning/reports/Transfer-Learning-Using-PyTorch-Lightning--VmlldzoyODk2MjA)
* [Multi-GPU Training Using PyTorch Lightning](https://wandb.ai/wandb/wandb-lightning/reports/Multi-GPU-Training-Using-PyTorch-Lightning--VmlldzozMTk3NTk)

# <img src="https://i.imgur.com/gb6B4ig.png" width="400" alt="Weights & Biases" />

Weights & Biases helps you build better models faster with a central dashboard for your machine learning projects. It logs your training metrics, hyperparameters, and output metrics, then lets you visualize and compare results and quickly share findings with your teammates. Track everything you need to make your models reproducible with Weights & Biases, from hyperparameters and code to model weights and dataset versions.

### [Check this Kaggle kernel to learn more about Weights and Biases$\rightarrow$](https://www.kaggle.com/ayuraj/experiment-tracking-with-weights-and-biases)
![img](https://i.imgur.com/BGgfZj3.png)

# PyTorch Lightning + Weights and Biases 

PyTorch Lightning provides a lightweight wrapper for organizing your PyTorch code and easily adding advanced features such as distributed training and 16-bit precision. W&B provides a lightweight wrapper for logging your ML experiments. It is incorporated directly into the PyTorch Lightning library, so you can check out [their documentation](https://pytorch-lightning.readthedocs.io/en/stable/extensions/generated/pytorch_lightning.loggers.WandbLogger.html#pytorch_lightning.loggers.WandbLogger) for the API and reference info.

### Use the integration in a few lines of code

```python
from pytorch_lightning.loggers import WandbLogger  # new line 1: import the logger
from pytorch_lightning import Trainer

wandb_logger = WandbLogger()  # new line 2: create the logger
trainer = Trainer(logger=wandb_logger)
```

[![thumbnail](https://i.imgur.com/M7xZ04g.png)](https://www.youtube.com/watch?v=hUXQm46TAKc)


```python
import wandb
from pytorch_lightning.loggers import WandbLogger

wandb.login()
```

```
wandb: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter:
wandb: Appending key for api.wandb.ai to your netrc file: /root/.netrc
True
```

```python
import cv2
import timm
import torch
import numpy as np
import pandas as pd

import torch.nn as nn
import albumentations as A
import pytorch_lightning as pl
import matplotlib.pyplot as plt

from torch.utils.data import Dataset, DataLoader
from albumentations.core.composition import Compose, OneOf
from albumentations.augmentations.transforms import CLAHE, GaussNoise, ISONoise
from albumentations.pytorch import ToTensorV2

from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning import Callback
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

from sklearn.model_selection import train_test_split
```

# 📀 Hyperparameters

```python
# Config dictionary that will be logged to W&B.
CONFIG = dict(
    seed=42,
    train_val_split=0.2,
    model_name='resnet50',
    pretrained=True,
    img_size=256,
    num_classes=12,
    lr=5e-4,
    min_lr=1e-6,
    t_max=20,
    num_epochs=10,
    batch_size=32,
    accum=1,
    precision=16,
    n_fold=5,
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
)

# Directories
PATH = "../input/plant-pathology-2021-fgvc8/"

image_size = CONFIG['img_size']
TRAIN_DIR = f'../input/resized-plant2021/img_sz_{image_size}/'
TEST_DIR = PATH + 'test_images/'

# Seed everything
seed_everything(CONFIG['seed'])
```

```
Global seed set to 42
42
```

# 🔧 DataModule

```python
# Read CSV file
df = pd.read_csv(PATH + "train.csv")

# Label-encode the label strings (most frequent label -> 0)
labels = list(df['labels'].value_counts().keys())
labels_dict = dict(zip(labels, range(len(labels))))
df = df.replace({"labels": labels_dict})
df.head()
```

|   | image | labels |
|---|---|---|
| 0 | 800113bb65efe69e.jpg | 1 |
| 1 | 8002cb321f8bfcdf.jpg | 7 |
| 2 | 80070f7fb5e2ccaa.jpg | 0 |
| 3 | 80077517781fb94f.jpg | 0 |
| 4 | 800cbf0ff87721f8.jpg | 4 |
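
Because `value_counts()` sorts labels from most to least frequent, the most common label string is encoded as 0. Keeping the inverse mapping around makes it easy to turn predicted indices back into readable label names, for example in W&B image captions. A minimal pure-Python sketch of the same idea, using a small made-up list of label strings in place of `df['labels']` (`label_to_idx` and `idx_to_label` are illustrative names, not from the kernel):

```python
from collections import Counter

# Made-up label strings standing in for df['labels']
raw_labels = ["scab", "healthy", "scab", "frog_eye_leaf_spot", "healthy", "scab"]

# Order labels by descending frequency, like pandas' value_counts()
# (ties keep first-seen order here; pandas may break ties differently)
ordered = [label for label, _ in Counter(raw_labels).most_common()]

# Forward mapping (label -> index) and inverse mapping (index -> label)
label_to_idx = {label: i for i, label in enumerate(ordered)}
idx_to_label = {i: label for label, i in label_to_idx.items()}

encoded = [label_to_idx[label] for label in raw_labels]
print(encoded)          # [0, 1, 0, 2, 1, 0]
print(idx_to_label[2])  # frog_eye_leaf_spot
```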


```python
class PlantDataset(Dataset):
    def __init__(self, df, transform=None):
        self.image_id = df['image'].values
        self.labels = df['labels'].values
        self.transform = transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        image_id = self.image_id[idx]
        label = self.labels[idx]

        # Load the image and convert OpenCV's BGR channel order to RGB
        image_path = TRAIN_DIR + image_id
        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Apply the augmentation pipeline, if one was given
        if self.transform is not None:
            image = self.transform(image=image)['image']
        return {'image': image, 'target': label}
```

```python
class PlantDataModule(pl.LightningDataModule):
    def __init__(self, batch_size, data_dir: str = './'):
        super().__init__()
        self.batch_size = batch_size

        # Train augmentation policy
        self.train_transform = Compose([
            A.RandomResizedCrop(height=CONFIG['img_size'], width=CONFIG['img_size']),
            A.HorizontalFlip(p=0.5),
            A.ShiftScaleRotate(p=0.5),
            A.RandomBrightnessContrast(p=0.5),
            A.Normalize(),
            ToTensorV2(),
        ])

        # Validation/Test augmentation policy
        self.test_transform = Compose([
            A.Resize(height=CONFIG['img_size'], width=CONFIG['img_size']),
            A.Normalize(),
            ToTensorV2(),
        ])

    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        if stage == 'fit' or stage is None:
            # Random train-validation split
            train_df, valid_df = train_test_split(df, test_size=CONFIG['train_val_split'])

            # Train dataset
            self.train_dataset = PlantDataset(train_df, self.train_transform)
            # Validation dataset
            self.valid_dataset = PlantDataset(valid_df, self.test_transform)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=4, drop_last=True)

    def val_dataloader(self):
        return DataLoader(self.valid_dataset, batch_size=self.batch_size, num_workers=4)
```
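
With a float `test_size`, scikit-learn's `train_test_split` rounds the validation size up (`ceil(test_size * n)`) and assigns the remaining rows to training. `drop_last=True` in `train_dataloader` then discards the final incomplete training batch, while the validation loader keeps it. The sketch below reproduces that arithmetic for a hypothetical dataset of 1,003 images (not the real dataset size); `split_sizes` and `batches_per_epoch` are illustrative helpers, not part of any library:

```python
import math


def split_sizes(n_samples, test_size):
    """Mirror scikit-learn's sizing for a float test_size: the test split is rounded up."""
    n_val = math.ceil(test_size * n_samples)
    return n_samples - n_val, n_val


def batches_per_epoch(n_samples, batch_size, drop_last):
    """Number of batches a DataLoader yields for a dataset of n_samples."""
    if drop_last:
        return n_samples // batch_size
    return math.ceil(n_samples / batch_size)


n_train, n_val = split_sizes(1003, 0.2)  # hypothetical dataset of 1,003 images
print(n_train, n_val)                                 # 802 201
print(batches_per_epoch(n_train, 32, drop_last=True))  # 25 (the last 2 images are dropped)
print(batches_per_epoch(n_val, 32, drop_last=False))   # 7 (the last batch holds 9 images)
```

Dropping the ragged last training batch keeps every optimizer step on a full batch of 32, which avoids a tiny, noisy final batch; the validation loader keeps every image so no sample is skipped during evaluation.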

# 🎺 LightningModule - Define the System

```python
class CustomResNet(nn.Module):
    def __init__(self, model_name='resnet18', pretrained=False):
        super().__init__()
        self.model = timm.create_model(model_name, pretrained=pretrained)
        # Replace the 1000-class ImageNet head with one sized for our classes
        in_features = self.model.get_classifier().in_features
        self.model.fc = nn.Linear(in_features, CONFIG['num_classes'])

    def forward(self, x):
        x = self.model(x)
        return x
```
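
Replacing the 1000-class ImageNet head with a 12-class one is what brings the parameter count down to the 23.5 M (94.130 MB) reported in the model summary when training starts. A quick check, assuming the standard ResNet-50 total of 25,557,032 parameters and a final `fc` layer with 2048 inputs:

```python
def linear_params(in_features, out_features):
    """Number of weights plus biases in an nn.Linear layer."""
    return in_features * out_features + out_features


RESNET50_TOTAL = 25_557_032  # assumed: standard ResNet-50 with its 1000-class ImageNet head
IN_FEATURES = 2048           # input width of ResNet-50's final `fc` layer

backbone = RESNET50_TOTAL - linear_params(IN_FEATURES, 1000)
custom_total = backbone + linear_params(IN_FEATURES, 12)  # 12 = CONFIG['num_classes']
size_mb = custom_total * 4 / 1e6                          # 4 bytes per fp32 parameter

print(custom_total)      # 23532620
print(f"{size_mb:.3f}")  # 94.130
```

Removing the 2,049,000-parameter ImageNet head and adding a 24,588-parameter one accounts for the whole difference.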

```python
class LitCassava(pl.LightningModule):
    def __init__(self, model):
        super(LitCassava, self).__init__()
        self.model = model
        self.metric = pl.metrics.F1(num_classes=CONFIG['num_classes'])
        self.criterion = nn.CrossEntropyLoss()
        self.lr = CONFIG['lr']

    def forward(self, x, *args, **kwargs):
        return self.model(x)

    def configure_optimizers(self):
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=CONFIG['t_max'], eta_min=CONFIG['min_lr'])

        return {'optimizer': self.optimizer, 'lr_scheduler': self.scheduler}

    def training_step(self, batch, batch_idx):
        image = batch['image']
        target = batch['target']
        output = self.model(image)
        loss = self.criterion(output, target)
        score = self.metric(output.argmax(1), target)
        logs = {'train_loss': loss, 'train_f1': score, 'lr': self.optimizer.param_groups[0]['lr']}
        self.log_dict(
            logs,
            on_step=False, on_epoch=True, prog_bar=True, logger=True
        )
        return loss

    def validation_step(self, batch, batch_idx):
        image = batch['image']
        target = batch['target']
        output = self.model(image)
        loss = self.criterion(output, target)
        score = self.metric(output.argmax(1), target)
        logs = {'valid_loss': loss, 'valid_f1': score}
        self.log_dict(
            logs,
            on_step=False, on_epoch=True, prog_bar=True, logger=True
        )
        return loss
```
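
Both steps pass hard predictions (`output.argmax(1)`) to the same `pl.metrics.F1` object. For single-label multiclass predictions, a micro-averaged F1 equals plain accuracy, because every wrong prediction counts once as a false positive (for the predicted class) and once as a false negative (for the true class). A macro average instead weighs every class equally, which matters when some label combinations are rare. Which average `pl.metrics.F1` computes depends on its `average` argument and your Lightning version, so check the docs before reading `valid_f1` as one or the other. A pure-Python sketch of both averages on made-up predictions (`f1_scores` is an illustrative helper):

```python
def f1_scores(preds, targets, num_classes):
    """Return (micro, macro) F1 for single-label multiclass predictions."""
    tp = [0] * num_classes
    fp = [0] * num_classes
    fn = [0] * num_classes
    for p, t in zip(preds, targets):
        if p == t:
            tp[t] += 1
        else:
            fp[p] += 1  # a false positive for the predicted class
            fn[t] += 1  # and a false negative for the true class

    # Micro average: pool the counts over all classes, then compute F1 once
    total_tp, total_fp, total_fn = sum(tp), sum(fp), sum(fn)
    micro = 2 * total_tp / (2 * total_tp + total_fp + total_fn)

    # Macro average: F1 per class, then an unweighted mean
    # (a class never predicted and never present scores 0.0 here, by convention)
    per_class = []
    for c in range(num_classes):
        denom = 2 * tp[c] + fp[c] + fn[c]
        per_class.append(2 * tp[c] / denom if denom else 0.0)
    macro = sum(per_class) / num_classes
    return micro, macro


# Made-up predictions and targets for three classes
preds = [0, 0, 0, 1, 2, 0]
targets = [0, 0, 0, 1, 1, 2]

micro, macro = f1_scores(preds, targets, num_classes=3)
accuracy = sum(p == t for p, t in zip(preds, targets)) / len(targets)
print(round(micro, 4), round(accuracy, 4), round(macro, 4))  # 0.6667 0.6667 0.5079
```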

# 📲 Callbacks


```python
# Checkpoint
checkpoint_callback = ModelCheckpoint(monitor='valid_loss',
                                      save_top_k=1,
                                      save_last=True,
                                      save_weights_only=True,
                                      filename='checkpoint/{epoch:02d}-{valid_loss:.4f}-{valid_f1:.4f}',
                                      verbose=False,
                                      mode='min')

# Early stopping
earlystopping = EarlyStopping(monitor='valid_loss', patience=3, mode='min')
```
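
`EarlyStopping(monitor='valid_loss', patience=3, mode='min')` stops training once `valid_loss` has failed to improve for three consecutive validation checks. The rule can be sketched in a few lines of pure Python; this is a simplified illustration that assumes no `min_delta`, not Lightning's actual implementation, and the loss values are made up:

```python
def early_stop_epoch(losses, patience=3):
    """Return the 0-based epoch after which training would stop, or None if it never stops."""
    best = float("inf")
    wait = 0
    for epoch, loss in enumerate(losses):
        if loss < best:  # mode='min': a lower loss counts as an improvement
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


# Made-up losses: the best value is at epoch 2, followed by three epochs without improvement
print(early_stop_epoch([0.90, 0.70, 0.60, 0.62, 0.61, 0.65, 0.50]))  # 5
print(early_stop_epoch([0.90, 0.80, 0.70]))                          # None
```

Note that the 0.50 at epoch 6 is never seen: once patience runs out, training ends even if the loss would have improved later, so a larger `patience` trades compute for robustness to noisy validation curves.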

```python
# Custom Callback
class ImagePredictionLogger(Callback):
    def __init__(self, val_samples, num_samples=32):
        super().__init__()
        self.num_samples = num_samples
        self.val_imgs, self.val_labels = val_samples['image'], val_samples['target']

    def on_validation_epoch_end(self, trainer, pl_module):
        # Move the tensors to the same device as the model
        val_imgs = self.val_imgs.to(device=pl_module.device)
        val_labels = self.val_labels.to(device=pl_module.device)
        # Get model predictions (no gradients needed)
        with torch.no_grad():
            logits = pl_module(val_imgs)
        preds = torch.argmax(logits, -1)
        # Log the images as wandb Images
        trainer.logger.experiment.log({
            "examples": [wandb.Image(x, caption=f"Pred:{pred}, Label:{y}")
                         for x, pred, y in zip(val_imgs[:self.num_samples],
                                               preds[:self.num_samples],
                                               val_labels[:self.num_samples])]
            }, commit=False)
```

> 📌 Tip: When logging manually through `wandb.log` or `trainer.logger.experiment.log`, pass `commit=False` so that W&B does not advance its step counter; the logged data is then committed together with the next metrics logged at that step.
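
The images this callback passes to `wandb.Image` come straight from the validation dataloader, i.e. after `A.Normalize()` and `ToTensorV2()`, so their colors may look off in the dashboard. A hypothetical `denormalize` helper could undo the normalization before logging. The sketch assumes Albumentations' default `Normalize` statistics (ImageNet mean and standard deviation, `max_pixel_value=255`) and a channel-first `(C, H, W)` array as produced by `ToTensorV2`:

```python
import numpy as np

# Assumed: Albumentations' default Normalize statistics (ImageNet mean/std, max_pixel_value=255)
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def denormalize(chw):
    """Hypothetical helper: undo A.Normalize on a (C, H, W) array, return an (H, W, C) uint8 image."""
    hwc = np.transpose(chw, (1, 2, 0))  # move channels last
    img = hwc * STD + MEAN              # back to pixel values in [0, 1]
    return (np.clip(img, 0.0, 1.0) * 255).round().astype(np.uint8)


# Round trip on a random 4x4 RGB image: mimic Normalize + ToTensorV2, then undo it
rng = np.random.default_rng(0)
original = rng.integers(0, 256, size=(4, 4, 3)).astype(np.uint8)
normalized = ((original / 255.0 - MEAN) / STD).transpose(2, 0, 1)
restored = denormalize(normalized)
print(np.array_equal(restored, original))  # True
```

Inside `on_validation_epoch_end`, each `x` could be passed as `denormalize(x.cpu().numpy())` before wrapping it in `wandb.Image`; this is a suggestion, not part of the original kernel.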

## ⚡ Train and Evaluate the Model with W&B


```python
# Init our data pipeline
datamodule = PlantDataModule(batch_size=CONFIG['batch_size'])
datamodule.setup()

# Samples required by the custom ImagePredictionLogger callback to log image predictions.
val_samples = next(iter(datamodule.val_dataloader()))
val_imgs, val_labels = val_samples['image'], val_samples['target']
val_imgs.shape, val_labels.shape
```

```
(torch.Size([32, 3, 256, 256]), torch.Size([32]))
```

```python
# Init our model
model = CustomResNet(model_name=CONFIG['model_name'], pretrained=CONFIG['pretrained'])
lit_model = LitCassava(model)
```

```
Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50_ram-a26f946b.pth" to /root/.cache/torch/hub/checkpoints/resnet50_ram-a26f946b.pth
```


Check out the documentation for WandbLogger [here](https://pytorch-lightning.readthedocs.io/en/stable/extensions/generated/pytorch_lightning.loggers.WandbLogger.html#pytorch_lightning.loggers.WandbLogger).

> 📌 Tip: Additional arguments used by `wandb.init()`, such as `entity`, `group`, and `tags`, can be passed as keyword arguments to this logger.

```python
# Initialize wandb logger
wandb_logger = WandbLogger(project='plant-pathology-lightning',
                           config=CONFIG,
                           group='ResNet',
                           job_type='train')

# Initialize a trainer
trainer = Trainer(
    max_epochs=CONFIG['num_epochs'],
    gpus=1,
    accumulate_grad_batches=CONFIG['accum'],
    precision=CONFIG['precision'],
    # Pass the ModelCheckpoint through `callbacks`; newer Lightning versions
    # treat the `checkpoint_callback` argument as a boolean switch only
    callbacks=[checkpoint_callback,
               earlystopping,
               ImagePredictionLogger(val_samples)],
    logger=wandb_logger,
    weights_summary='top',
)

# Train the model ⚡🚅⚡
trainer.fit(lit_model, datamodule)

# Close wandb run
wandb.finish()
```

```
GPU available: True, used: True
TPU available: None, using: 0 TPU cores
Using native 16bit precision.
wandb: Currently logged in as: ayush-thakur (use `wandb login --relogin` to force relogin)

  | Name      | Type             | Params
-----------------------------------------------
0 | model     | CustomResNet     | 23.5 M
1 | metric    | F1               | 0
2 | criterion | CrossEntropyLoss | 0
-----------------------------------------------
23.5 M    Trainable params
0         Non-trainable params
23.5 M    Total params
94.130    Total estimated model params size (MB)
```



Run summary:

| Metric | Value |
|---|---|
| _runtime | 1353.0 |
| _timestamp | 1619808406.0 |
| _step | 4649.0 |
| train_loss | 0.64505 |
| train_f1 | 0.79147 |
| lr | 0.00029 |
| epoch | 9.0 |
| valid_loss | 0.45967 |
| valid_f1 | 0.86021 |


Run history:

| Metric | History |
|---|---|
| _runtime | ▁▂▃▃▄▅▆▆▇██ |
| _timestamp | ▁▂▃▃▄▅▆▆▇██ |
| _step | ▁▂▂▃▄▄▅▆▇▇█ |
| train_loss | █▅▄▃▃▂▂▂▁▁ |
| train_f1 | ▁▄▅▆▆▇▇▇██ |
| lr | ███▇▆▆▅▃▂▁ |
| epoch | ▁▂▃▃▄▅▆▆▇█ |
| valid_loss | ██▆▅▅▅▄▃▁▁ |
| valid_f1 | ▁▁▃▄▃▃▄▆██ |
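
The `lr` history follows `CosineAnnealingLR`, which Lightning steps once per epoch by default. PyTorch documents its closed form as

$$\eta_t = \eta_{\min} + \frac{1}{2}\left(\eta_{\max} - \eta_{\min}\right)\left(1 + \cos\left(\frac{T_{cur}}{T_{\max}}\pi\right)\right)$$

where $\eta_{\max}$ is the initial learning rate (`lr`), $\eta_{\min}$ is `min_lr`, and $T_{cur}$ is the number of scheduler steps taken so far. With `t_max = 20` but only `num_epochs = 10`, the run covers just the first half of the cosine, so during the last epoch the learning rate is still far above `min_lr`. The sketch below evaluates the formula with the values from `CONFIG` (`cosine_lr` is an illustrative helper); epoch 9 gives about 0.00029, matching the final logged `lr`:

```python
import math


def cosine_lr(epoch, base_lr=5e-4, eta_min=1e-6, t_max=20):
    """Closed-form CosineAnnealingLR value after `epoch` scheduler steps (defaults from CONFIG)."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max))


print(f"{cosine_lr(0):.5f}")   # 0.00050 (starting learning rate)
print(f"{cosine_lr(9):.5f}")   # 0.00029 (last of the 10 training epochs)
print(f"{cosine_lr(20):.6f}")  # 0.000001 (min_lr, only reached after t_max epochs)
```

Setting `t_max` equal to `num_epochs` would let the schedule anneal all the way to `min_lr` by the end of training.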


## Visualize Metrics

![img](https://i.imgur.com/n6P7K4M.gif)

## Visualize Model Predictions

![img](https://i.imgur.com/lgkLnrt.gif)

## Visualize CPU and GPU Metrics

![img](https://i.imgur.com/ZLjrbhj.gif)

# ❄️ Resources

I hope you find this kernel useful and that it encourages you to try out Weights and Biases. Here are some relevant links that you might want to check out:

* Check out the [official documentation](https://docs.wandb.ai/) to learn more about the best practices and advanced features. 

* Check out the [examples GitHub repository](https://github.com/wandb/examples) for curated and minimal examples. This can be a good starting point. 

* [Weights and Biases Fully Connected](https://wandb.ai/fully-connected) is a home for curated tutorials, free-form discussions, paper summaries, advice from industry experts, and more.