<a href="https://colab.research.google.com/github/Tensor-Reloaded/AI-Learning-Hub/blob/main/resources/beginner_pytorch/06_data_augmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 06. Simple Data Augmentation

Start with the official documentation:

1. [Getting started with transforms v2](https://docs.pytorch.org/vision/main/auto_examples/transforms/plot_transforms_getting_started.html)
2. [Illustration of transforms](https://docs.pytorch.org/vision/stable/auto_examples/transforms/plot_transforms_illustrations.html)

Work through these examples until you understand them, then continue with this notebook.

In [None]:
from PIL import Image

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

import torchvision
from torchvision.transforms import v2
from torchvision.transforms.v2.functional import hflip
from torchvision.datasets import OxfordIIITPet

from tqdm import tqdm
import matplotlib.pyplot as plt

In [None]:
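import os
# Workaround for the "OMP: Error #15" crash that occurs when two copies of the OpenMP runtime get loaded (typical with some conda setups)
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"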

In [None]:
img_size = 224
num_classes = 37  # Oxford-IIIT Pet: 37 cat and dog breeds

In [None]:
class ClassificationDataset(Dataset):
    def __init__(self, split, image_transforms):
        self.data = OxfordIIITPet(
            root="../data",
            download=True,
            split=split,
            target_types=("category",),
        )
        self.image_transforms = image_transforms

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        image, class_label = self.data[i]
        image = self.image_transforms(image)
        return image, class_label


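image_transforms_train = v2.Compose([
    v2.ToImage(),                                     # PIL image -> uint8 tensor image
    v2.Resize([img_size, img_size]),
    v2.RandomCrop([img_size, img_size], padding=12),  # pad by 12 px, crop back to 224: a random shift
    v2.RandomHorizontalFlip(p=0.5),
    v2.AutoAugment(),                                 # automatic policy, applied to the uint8 image
    v2.ToDtype(torch.float32, scale=True),            # uint8 [0, 255] -> float32 [0, 1]
])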

image_transforms_test = v2.Compose([
    v2.ToImage(),
    v2.Resize([img_size, img_size]),
    v2.ToDtype(torch.float32, scale=True),
])

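# Batch-level augmentation: CutMix and MixUp mix samples within a batch, so they act on batches, not single images
cutmix_or_mixup = v2.RandomChoice([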
    v2.CutMix(num_classes=num_classes),
    v2.MixUp(num_classes=num_classes),
])

train_dataset = ClassificationDataset("trainval", image_transforms_train)
test_dataset = ClassificationDataset("test", image_transforms_test)

train_loader = DataLoader(train_dataset, shuffle=True, batch_size=32, drop_last=True)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=32, drop_last=False)

## Automatic augmentation

1\. [AutoAugment: Learning Augmentation Policies from Data](https://arxiv.org/abs/1805.09501)
```py
from torchvision.transforms import v2

v2.AutoAugment(policy=v2.AutoAugmentPolicy.IMAGENET)  # default
v2.AutoAugment(policy=v2.AutoAugmentPolicy.SVHN)
v2.AutoAugment(policy=v2.AutoAugmentPolicy.CIFAR10)
```
2\. [RandAugment: Practical automated data augmentation with a reduced search space](https://arxiv.org/abs/1909.13719)
```py
from torchvision.transforms import v2

v2.RandAugment(num_ops=2, magnitude=9)  # default
v2.RandAugment(num_ops=4, magnitude=4)
v2.RandAugment(num_ops=3, magnitude=10)
```

In [None]:
for images, labels in DataLoader(train_dataset, shuffle=False, batch_size=8):
    break
v2.functional.to_pil_image(torchvision.utils.make_grid(images, nrow=4))

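Running the same cell again gives a different batch of augmentations, because every transform in the training pipeline is random.

In [None]: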
for images, labels in DataLoader(train_dataset, shuffle=False, batch_size=8):
    break
v2.functional.to_pil_image(torchvision.utils.make_grid(images, nrow=4))

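## CutMix and MixUp

CutMix and MixUp work on whole batches, so they are applied to a batch coming out of the DataLoader rather than inside the per-image transforms.

### CutMix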

In [None]:
for images, labels in DataLoader(train_dataset, shuffle=False, batch_size=8):
    break
images, labels = v2.CutMix(num_classes=num_classes)(images, labels)
v2.functional.to_pil_image(torchvision.utils.make_grid(images, nrow=4))

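With `shuffle=False`, the first images of the split all belong to the same breed, so mixing them leaves the labels one-hot. Shuffling pairs images from different classes:

In [None]: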
for images, labels in DataLoader(train_dataset, shuffle=True, batch_size=8):
    break
images, labels = v2.CutMix(num_classes=num_classes)(images, labels)
v2.functional.to_pil_image(torchvision.utils.make_grid(images, nrow=4))

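Let's see what happens to the labels. After CutMix, a label is no longer a class index but a vector of `num_classes` weights: the original class and the partner's class share the total weight in proportion to the area each image occupies.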

In [None]:
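def draw_all(to_draw, labels, rows=4):
    """Show each image with its label (a class index or a soft-label vector) as the title."""
    for image, label in zip(to_draw, labels):
        plt.imshow(v2.functional.to_pil_image(image))
        # A class index becomes a 1-element list; a soft label keeps all num_classes weights
        values = [round(x.item(), 4) for x in torch.atleast_1d(label)]
        # Ceil division, so that no class is dropped (37 weights -> rows of 10, 10, 10, 7)
        size = -(-len(values) // rows)
        chunks = [values[start:start + size] for start in range(0, len(values), size)]
        plt.title('\n'.join(str(chunk) for chunk in chunks))
        plt.axis("off")
        plt.show()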

In [None]:
draw_all(images, labels)
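To make the mechanics concrete, here is a simplified from-scratch sketch of the CutMix idea, not torchvision's actual implementation: draw λ ~ Beta(α, α), cut a box covering roughly 1 − λ of the image at a random position, paste the same region from a partner image (here the batch shifted by one, which is also how torchvision pairs samples), and weight the two one-hot labels by the area each image actually occupies. `cutmix_sketch` and its parameters are hypothetical names.

```python
import torch
import torch.nn.functional as F


def cutmix_sketch(images, labels, num_classes, alpha=1.0):
    """Simplified CutMix for a batch: images (B, C, H, W) float, labels (B,) int64."""
    # Mixing ratio lambda ~ Beta(alpha, alpha); alpha = 1 makes it uniform on [0, 1]
    lam = torch.distributions.Beta(alpha, alpha).sample().item()
    one_hot = F.one_hot(labels, num_classes).float()

    # Box with area of roughly (1 - lambda) * H * W, centred at a random pixel
    H, W = images.shape[-2:]
    cut_h, cut_w = int(H * (1 - lam) ** 0.5), int(W * (1 - lam) ** 0.5)
    cy, cx = torch.randint(H, (1,)).item(), torch.randint(W, (1,)).item()
    y1, y2 = max(cy - cut_h // 2, 0), min(cy + cut_h // 2, H)
    x1, x2 = max(cx - cut_w // 2, 0), min(cx + cut_w // 2, W)

    # Paste the box from each sample's partner: the batch shifted by one position
    mixed = images.clone()
    mixed[..., y1:y2, x1:x2] = images.roll(1, dims=0)[..., y1:y2, x1:x2]

    # Recompute lambda from the actual (clipped) box area, then mix the labels
    lam = 1.0 - (y2 - y1) * (x2 - x1) / (H * W)
    mixed_labels = lam * one_hot + (1.0 - lam) * one_hot.roll(1, dims=0)
    return mixed, mixed_labels


torch.manual_seed(0)
demo_images = torch.rand(8, 3, 32, 32)
demo_labels = torch.randint(0, 37, (8,))
cutmix_images, cutmix_labels = cutmix_sketch(demo_images, demo_labels, num_classes=37)
```

Recomputing λ after clipping matters: a box near the border is smaller than requested, and the label weights should match the pixels that were actually pasted.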

### MixUp

In [None]:
for images, labels in DataLoader(train_dataset, shuffle=False, batch_size=8):
    break
images, labels = v2.MixUp(num_classes=num_classes)(images, labels)
v2.functional.to_pil_image(torchvision.utils.make_grid(images, nrow=4))

In [None]:
for images, labels in DataLoader(train_dataset, shuffle=True, batch_size=8):
    break
images, labels = v2.MixUp(num_classes=num_classes)(images, labels)
v2.functional.to_pil_image(torchvision.utils.make_grid(images, nrow=4))

In [None]:
draw_all(images, labels)
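MixUp is even simpler than CutMix: instead of pasting a box, it blends whole images. The following simplified sketch (not torchvision's implementation) draws λ ~ Beta(α, α) and takes the same convex combination of each image and its partner (the batch shifted by one) and of their one-hot labels. `mixup_sketch` is a hypothetical name; it also returns λ so the result can be checked.

```python
import torch
import torch.nn.functional as F


def mixup_sketch(images, labels, num_classes, alpha=1.0):
    """Simplified MixUp: blend each image and its one-hot label with a partner sample's."""
    # Mixing ratio lambda ~ Beta(alpha, alpha); alpha = 1 makes it uniform on [0, 1]
    lam = torch.distributions.Beta(alpha, alpha).sample().item()
    one_hot = F.one_hot(labels, num_classes).float()
    # Partner of each sample: the batch shifted by one position
    mixed_images = lam * images + (1.0 - lam) * images.roll(1, dims=0)
    mixed_labels = lam * one_hot + (1.0 - lam) * one_hot.roll(1, dims=0)
    return mixed_images, mixed_labels, lam


torch.manual_seed(0)
demo_images = torch.rand(8, 3, 32, 32)
demo_labels = torch.randint(0, 37, (8,))
mixup_images, mixup_labels, lam = mixup_sketch(demo_images, demo_labels, num_classes=37)
```

Because images and labels use the same λ, the label says exactly how much of each class is visible in the blended image.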

## Other Data Augmentation techniques

[Mosaic](https://arxiv.org/pdf/2004.12432) data augmentation stitches several training images into one.
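A simplified sketch of the Mosaic idea, assuming a classification setting: four images are resized into the four quadrants around a random split point, and the soft label weights each image's class by the fraction of the canvas it covers. The area-weighted label is an assumption borrowed from CutMix; in object detection, where Mosaic is most common, the bounding boxes are shifted and rescaled instead. `mosaic_sketch` is a hypothetical helper.

```python
import torch
import torch.nn.functional as F


def mosaic_sketch(images, labels, num_classes):
    """Combine four float images (4, C, H, W) into one (C, H, W) mosaic with an area-weighted soft label."""
    assert images.shape[0] == 4, "expects exactly four images"
    _, C, H, W = images.shape
    # Random split point, kept away from the borders so every quadrant is non-empty
    cy = torch.randint(H // 4, 3 * H // 4, (1,)).item()
    cx = torch.randint(W // 4, 3 * W // 4, (1,)).item()
    quadrants = [(0, cy, 0, cx), (0, cy, cx, W), (cy, H, 0, cx), (cy, H, cx, W)]

    mosaic = torch.empty(C, H, W, dtype=images.dtype)
    one_hot = F.one_hot(labels, num_classes).float()
    soft_label = torch.zeros(num_classes)
    for k, (y1, y2, x1, x2) in enumerate(quadrants):
        # Resize the whole k-th image to fit its quadrant
        tile = F.interpolate(images[k:k + 1], size=(y2 - y1, x2 - x1),
                             mode="bilinear", align_corners=False)
        mosaic[:, y1:y2, x1:x2] = tile[0]
        # Weight the k-th label by the fraction of the canvas it covers
        soft_label += one_hot[k] * (y2 - y1) * (x2 - x1) / (H * W)
    return mosaic, soft_label


torch.manual_seed(0)
demo_images = torch.rand(4, 3, 64, 64)
demo_labels = torch.tensor([3, 3, 10, 20])
mosaic, mosaic_label = mosaic_sketch(demo_images, demo_labels, num_classes=37)
```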

For object detection and segmentation tasks, see [Albumentations](https://github.com/albumentations-team/albumentations), which transforms bounding boxes and masks together with the image.


# Exercises

1. Use data augmentation to improve your results on MNIST. Try using flipping. Does it help?
2. Implement an augmentation pipeline for CIFAR-10 and try to improve your results.