<a href="https://colab.research.google.com/github/adam-mehdi/MuarAugment/blob/master/RandAugmentTutorial.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Tutorial for `AlbumentationsRandAugment`

`AlbumentationsRandAugment` is a straightforward implementation of [RandAugment](https://arxiv.org/abs/1909.13719). Use it just like a list of transforms. Here I provide an end-to-end pipeline for image classification in PyTorch Lightning, but if you are interested only in `AlbumentationsRandAugment`, skip to section 2: 'Creating the RandAugment Dataset'. Let's begin! 🎇

## 1. Install and import

In [1]:
%%capture 
!pip install albumentations --upgrade
!pip install timm
!pip install pytorch-lightning
!pip install git+https://github.com/adam-mehdi/MuarAugment.git

In [2]:
import os
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.model_selection import StratifiedKFold

import timm
import albumentations as A
from albumentations.pytorch import ToTensorV2

import torch
import torchvision
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import pytorch_lightning as pl

from muar.augmentations import AlbumentationsRandAugment



## 2. Creating the RandAugment Dataset

We add RandAugment to the Dataset. Here's a simplified version of the algorithm:

```python
import random
import albumentations as A

def rand_augment(N_TFMS, MAGN):
    # initialize the transform list; magnitudes scale with `MAGN`
    transforms = [A.HorizontalFlip(p=1),
                  A.Rotate(limit=MAGN*9, p=1),
                  A.RandomBrightnessContrast(brightness_limit=MAGN/20, contrast_limit=0, p=1)]
    # randomly choose `N_TFMS` distinct transforms from the list
    composition = random.sample(transforms, N_TFMS)
    return A.Compose(composition)
```

We initialize `AlbumentationsRandAugment` in the Dataset's `__init__` and call that object within `get_transform`. Each call to `rand_augment` returns a different, randomly chosen list of transforms. We want to apply a different composition to each image, so we build the transform every time we get an item. We can implement that using `AlbumentationsRandAugment` as follows.
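
The per-item sampling described above can be sketched without any augmentation library. This is a minimal illustration, not the `AlbumentationsRandAugment` implementation: the NumPy functions `hflip`, `vflip`, and `brighten` and the helpers `sample_transforms` and `apply_transforms` are hypothetical stand-ins, and the brightness scaling simply mirrors the `MAGN/20` factor from the simplified version.

```python
import random
import numpy as np

def hflip(img, magn):
    # mirror left-right; the magnitude is unused for flips
    return img[:, ::-1]

def vflip(img, magn):
    # mirror top-bottom; the magnitude is unused for flips
    return img[::-1, :]

def brighten(img, magn):
    # shift brightness by magn/20 of the full 0..255 range
    shifted = img.astype(np.float32) + 255 * magn / 20
    return np.clip(shifted, 0, 255).astype(np.uint8)

POOL = [hflip, vflip, brighten]

def sample_transforms(n_tfms, rng=random):
    # draw n_tfms distinct transforms; every call makes a fresh draw
    return rng.sample(POOL, n_tfms)

def apply_transforms(img, tfms, magn):
    # apply the sampled transforms one after another
    for tfm in tfms:
        img = tfm(img, magn)
    return img

image = (np.arange(28 * 28) % 256).astype(np.uint8).reshape(28, 28)
for idx in range(3):
    tfms = sample_transforms(n_tfms=2)
    out = apply_transforms(image, tfms, magn=4)
    print(idx, [t.__name__ for t in tfms], out.shape)
```

Because the draw happens inside the loop, each "item" can receive a different pair of transforms, which is the behavior the Dataset below relies on.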

In [3]:
class RandAugmentDataset(Dataset):
    def __init__(self, data, stage='train', image_size=(28,28), N_TFMS=0, MAGN=0):
        super().__init__()
        self.images,self.labels = list(zip(*data))
        self.stage, self.size = stage, image_size
        self.N_TFMS, self.MAGN = N_TFMS, MAGN
        if stage == 'train':
            self.rand_augment = AlbumentationsRandAugment(N_TFMS, MAGN)
        else: 
            self.rand_augment = None
        
    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        image,label = self.images[idx],self.labels[idx]
        image = np.array(image)[:,:,None]
        image = np.repeat(image, 3, axis=2) # image must be 3 channels
        
        transform = get_transform(self.rand_augment, self.stage, self.size)
        augmented = transform(image=image)['image']
        return augmented, torch.LongTensor([label])

In [5]:
def get_transform(rand_augment, stage='train', size=(28,28)):
    if stage == 'train':
        resize_tfm = [A.Resize(*size)]
        
        rand_tfms = rand_augment() # returns a list of transforms

        tensor_tfms = [A.Normalize(), ToTensorV2()]
        return A.Compose(resize_tfm + rand_tfms + tensor_tfms)

    elif stage=='valid':
        resize_tfm = [A.Resize(*size)]
        tensor_tfms = [A.Normalize(), ToTensorV2()]
        return A.Compose(resize_tfm + tensor_tfms)
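
To make the tensor shapes concrete, the fixed (non-random) part of the pipeline can be reproduced in NumPy. This sketch assumes the default `A.Normalize()` arguments (ImageNet mean and standard deviation with `max_pixel_value=255`) and mimics `ToTensorV2` with a plain transpose. The `preprocess` helper is hypothetical and skips the resize, since FashionMNIST images are already 28×28.

```python
import numpy as np

# defaults of A.Normalize(), assumed here
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def preprocess(gray):
    # (H, W) -> (H, W, 1) -> (H, W, 3), as in __getitem__
    image = np.repeat(gray[:, :, None], 3, axis=2)
    # scale to [0, 1], then standardize each channel
    image = (image.astype(np.float32) / 255.0 - MEAN) / STD
    # HWC -> CHW, the layout PyTorch models expect
    return image.transpose(2, 0, 1)

gray = np.zeros((28, 28), dtype=np.uint8)
chw = preprocess(gray)
print(chw.shape)  # (3, 28, 28)
```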

## 3. Defining the Model

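A standard `LightningModule` that supports a pretrained backbone for transfer learning: the backbone starts frozen and is unfrozen once training reaches epoch 2.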

In [6]:
# Backbone for transfer learning.
class Backbone(pl.LightningModule):
    def __init__(self, model_name='resnet18', pretrained=False):
        super().__init__()
        self.model = timm.create_model(model_name, pretrained=pretrained)
        self.in_features = self.model.get_classifier().in_features
        self.model.fc = nn.Identity()

    def forward(self, x):
        x = self.model(x)
        return x

class LitModule(pl.LightningModule):
    def __init__(self, 
                 model_name: str,
                 pretrained: bool,
                 num_classes: int,
                 lr: float
                 ):
        
        super().__init__()
        self.save_hyperparameters()

        self.backbone = Backbone(model_name=model_name, pretrained=pretrained)
        self.backbone.freeze()

        self.fc = nn.Linear(self.backbone.in_features, num_classes)
        self.metric = pl.metrics.F1(num_classes=num_classes)
        self.criterion = nn.CrossEntropyLoss()
        self.lr = lr

    def forward(self, x):
        x = self.backbone(x)
        return self.fc(x)

    def training_step(self, batch, batch_idx):
        input,target = batch[0],batch[1].squeeze(1)

        if self.current_epoch == 2: self.backbone.unfreeze()
        output = self(input)
        loss = self.criterion(output, target)
        score = self.metric(output.argmax(1), target)

        return loss

    def validation_step(self, batch, batch_idx):
        input,target = batch[0],batch[1].squeeze(1)

        output = self(input)
        loss = self.criterion(output, target)
        score = self.metric(output.argmax(1), target)

        return loss
    
    def configure_optimizers(self):
        self.optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
        return self.optimizer
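
The freeze-then-unfreeze schedule in `LitModule` can be illustrated in plain PyTorch. This is a rough sketch rather than what Lightning does internally: the tiny `nn.Linear` layers stand in for the timm backbone and the classification layer, and `set_trainable` and `count_trainable` are hypothetical helpers. Lightning's `freeze()` additionally puts the module in eval mode, which this sketch omits.

```python
from torch import nn

def set_trainable(module, trainable):
    # toggle gradient tracking for every parameter of the module
    for param in module.parameters():
        param.requires_grad = trainable

def count_trainable(module):
    # number of parameters that would receive gradients
    return sum(p.numel() for p in module.parameters() if p.requires_grad)

backbone = nn.Linear(8, 4)   # stand-in for the pretrained backbone
head = nn.Linear(4, 10)      # stand-in for the classification layer
model = nn.Sequential(backbone, head)

set_trainable(backbone, False)          # start with a frozen backbone
frozen_count = count_trainable(model)   # head only: 4*10 + 10 = 50

UNFREEZE_EPOCH = 2
for epoch in range(4):
    if epoch == UNFREEZE_EPOCH:
        set_trainable(backbone, True)   # fine-tune everything from here on
    print(epoch, count_trainable(model))

full_count = count_trainable(model)     # backbone 8*4 + 4 = 36, plus the head
```

While the backbone is frozen only the head's parameters are trainable, which matches the small "Trainable params" count that Lightning reports at the start of training.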

## 4. Training

We're going to train on the FashionMNIST dataset, classifying what kind of apparel an image portrays.

In [7]:
class cfg:
    batch_size = 64
    N_TFMS = 3
    MAGN = 4
    model_name = 'resnet18'
    pretrained = False
    num_classes = 10
    lr = 3e-3
    max_epochs = 20
    precision = 16

In [8]:
%%capture
train_data = torchvision.datasets.FashionMNIST('/content/', train=True, download=True)
valid_data = torchvision.datasets.FashionMNIST('/content/', train=False, download=True)

train_dataset = RandAugmentDataset(train_data, stage='train', N_TFMS=cfg.N_TFMS, MAGN=cfg.MAGN)
valid_dataset = RandAugmentDataset(valid_data, stage='valid')

train_loader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True, num_workers=2, drop_last=True, pin_memory=True)
valid_loader = DataLoader(valid_dataset, batch_size=cfg.batch_size, shuffle=False, num_workers=2, drop_last=False, pin_memory=True)
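
The effect of `drop_last` on these loaders is simple arithmetic. The sketch below assumes the standard FashionMNIST split sizes (60,000 training and 10,000 test images) and the `batch_size = 64` from `cfg`; `num_batches` is a hypothetical helper, not part of PyTorch.

```python
import math

def num_batches(n_samples, batch_size, drop_last):
    # drop_last=True discards the final, incomplete batch
    if drop_last:
        return n_samples // batch_size
    return math.ceil(n_samples / batch_size)

BATCH_SIZE = 64
for name, n in [("train", 60_000), ("valid", 10_000)]:
    kept = num_batches(n, BATCH_SIZE, drop_last=True)
    full = num_batches(n, BATCH_SIZE, drop_last=False)
    skipped = n - kept * BATCH_SIZE
    print(f"{name}: {kept} batches with drop_last, {full} without, {skipped} samples skipped")
```

Dropping the last batch is harmless for shuffled training data, but on a validation set it leaves a few samples unevaluated.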

In [9]:
model = LitModule(cfg.model_name, cfg.pretrained, cfg.num_classes, cfg.lr)



In [14]:
trainer = pl.Trainer(gpus=1, precision=cfg.precision, max_epochs=cfg.max_epochs)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Using native 16bit precision.
Running in fast_dev_run mode: will run a full train, val and test loop using 1 batch(es).


In [15]:
trainer.fit(model, train_dataloader=train_loader, val_dataloaders=valid_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type             | Params
-----------------------------------------------
0 | backbone  | Backbone         | 11.2 M
1 | fc        | Linear           | 5.1 K 
2 | metric    | F1               | 0     
3 | criterion | CrossEntropyLoss | 0     
-----------------------------------------------
5.1 K     Trainable params
11.2 M    Non-trainable params
11.2 M    Total params
44.727    Total estimated model params size (MB)






In [12]:
# save model
trainer.save_checkpoint("/content/image_classification_model.pt")

I'll leave inspecting the performance for you! Thank you for your attention. (｡･∀･)ﾉﾞ