# Lecture 5: Fine-Tuning a Vision Transformer using Lightning
Image classification: given an image, decide which of a fixed set of classes it belongs to.

## Lightning
The deep learning framework "with batteries included". Lightning is a layer on top of PyTorch that organizes your code and removes boilerplate: it abstracts away the engineering needed for scale (the training loop, device placement, multi-GPU training, checkpointing, logging).

The HuggingFace Trainer API can be seen as a framework similar to PyTorch Lightning in the sense that it also abstracts the training loop away behind a Trainer object. However, unlike PyTorch Lightning, it is not meant to be a general framework. Rather, it is made especially for fine-tuning Transformer-based models available in the HuggingFace Transformers library.
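To make "boilerplate" concrete, here is a minimal sketch of the kind of hand-written PyTorch loop that Lightning's `Trainer` takes over: moving data to the device, looping over epochs and batches, `zero_grad`/`backward`/`step`, and switching between training and evaluation mode. The function name `train_plain`, the choice of SGD and the toy data are illustrative assumptions, not part of Lightning; the real `Trainer` also handles logging, checkpointing and multi-device training on top of this.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset


def train_plain(model, train_loader, val_loader, epochs=5, lr=0.1):
    """Hypothetical hand-written loop: roughly what pl.Trainer.fit automates."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    history = []  # average validation loss after each epoch
    for _ in range(epochs):
        model.train()  # enable training behaviour (e.g. dropout)
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            loss = F.cross_entropy(model(x), y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        model.eval()  # disable dropout for evaluation
        total_loss, n = 0.0, 0
        with torch.no_grad():  # no gradients needed for validation
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                total_loss += F.cross_entropy(model(x), y, reduction="sum").item()
                n += y.numel()
        history.append(total_loss / n)
    return history


# Toy usage: a linearly separable 2-class problem (reused for validation for brevity)
torch.manual_seed(0)
X = torch.randn(512, 2)
y = (X[:, 0] > 0).long()
ds = TensorDataset(X, y)
history = train_plain(
    nn.Linear(2, 2),
    DataLoader(ds, batch_size=64, shuffle=True),
    DataLoader(ds, batch_size=256),
)
print(history)
```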

```python
!pip install --quiet "setuptools==59.5.0" "pytorch-lightning>=1.4" "matplotlib" "torch>=1.8" "ipython[notebook]" "torchmetrics>=0.7" "torchvision" "seaborn"
```

## Data Loading

```python
from torchvision import transforms
from torchvision.datasets import CIFAR10
import pytorch_lightning as pl
import os

DATASET_PATH = os.environ.get("PATH_DATASETS", "data/")
```


```python
# Training: random augmentations, then conversion to a normalized tensor
train_transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.RandomResizedCrop((32, 32), scale=(0.8, 1.0), ratio=(0.9, 1.1)),
        transforms.ToTensor(),
        # Per-channel mean and std of the CIFAR10 training images
        transforms.Normalize(
            [0.49139968, 0.48215841, 0.44653091],
            [0.24703223, 0.24348513, 0.26158784],
        ),
    ]
)

# Validation/test: no augmentation, only conversion and normalization
test_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            [0.49139968, 0.48215841, 0.44653091],
            [0.24703223, 0.24348513, 0.26158784],
        ),
    ]
)
```

`transforms.Compose` chains the transformations, which are applied in sequence:
* `RandomHorizontalFlip` flips the image horizontally with a given probability (default 0.5).
* `RandomResizedCrop` takes a crop with a random area (H*W) and a random aspect ratio and resizes it to the given size. `scale` specifies the lower and upper bounds for the area of the crop relative to the original image, and `ratio` specifies the lower and upper bounds for its aspect ratio, both before resizing.
* `ToTensor` converts a PIL image or NumPy `ndarray` of shape HxWxC with values in [0, 255] to a `torch.FloatTensor` of shape CxHxW with values in [0.0, 1.0].
* `Normalize` normalizes each channel of the tensor as (x - mean) / std, using the precomputed per-channel mean and standard deviation of the CIFAR10 dataset that we will use. These constants shift and scale the data to zero mean and a standard deviation of one.
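The normalization constants are per-channel statistics of the training images after `ToTensor` (values in [0, 1]). A minimal sketch of how such constants can be computed and what `Normalize` then does, using a random tensor as a stand-in for the CIFAR10 training images (with the real data you would run the same computation over all 50,000 training images); `channel_stats` and `normalize` are hypothetical helper names.

```python
import torch


def channel_stats(images):
    """images: float tensor [N, C, H, W] with values in [0, 1]."""
    mean = images.mean(dim=(0, 2, 3))  # one mean per channel
    std = images.std(dim=(0, 2, 3))    # one standard deviation per channel
    return mean, std


def normalize(images, mean, std):
    # Same arithmetic as transforms.Normalize: (x - mean) / std, channel by channel
    return (images - mean[None, :, None, None]) / std[None, :, None, None]


torch.manual_seed(0)
fake_images = torch.rand(100, 3, 32, 32)  # stand-in for the training images
mean, std = channel_stats(fake_images)
normed = normalize(fake_images, mean, std)
print(normed.mean(dim=(0, 2, 3)), normed.std(dim=(0, 2, 3)))  # close to 0 and 1 per channel
```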

The train and test transformations differ because the training transforms augment the data: in every epoch the network sees randomly flipped and cropped variants of each image, which acts like a larger, more varied dataset and helps against overfitting. At test time we don't want to corrupt the examples with random augmentations such as cropping, so we only convert and normalize them. Tip: test-time augmentation (TTA), where several augmented versions of an image are passed through the network and their outputs averaged, can give a more accurate model.
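A minimal sketch of test-time augmentation, assuming a model that maps [B, C, H, W] images to class logits: average the softmax probabilities over the original and the horizontally flipped batch. The helper `tta_predict` and the stand-in linear classifier are hypothetical; more views (for example a few crops) can be added in the same way.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def tta_predict(model, imgs):
    """Hypothetical helper: average class probabilities over two views of each image."""
    model.eval()
    views = [imgs, torch.flip(imgs, dims=[3])]  # dim 3 is the width in [B, C, H, W]
    probs = [model(v).softmax(dim=1) for v in views]
    return torch.stack(probs).mean(dim=0)  # [B, num_classes]


# Toy usage with a stand-in classifier (not the ViT defined later)
torch.manual_seed(0)
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
probs = tta_predict(toy_model, torch.randn(4, 3, 32, 32))
print(probs.shape)  # torch.Size([4, 10])
```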

```python
train_dataset = CIFAR10(
    root=DATASET_PATH, train=True, download=True, transform=train_transform
)
val_dataset = CIFAR10(
    root=DATASET_PATH, train=True, download=True, transform=test_transform
)
test_dataset = CIFAR10(
    root=DATASET_PATH, train=False, download=True, transform=test_transform
)
```

`val_dataset` loads the same CIFAR10 training images as `train_dataset`, but with the test transformations (no augmentation); the validation split is taken from it below. We apply the same transform to the validation set as to the test set because the validation set is meant to help us pick a model that will perform well on the test set, so it should be evaluated under the same conditions.

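```python
import torch
import torch.utils.data as data

# Reseed before each split so that both calls draw the same permutation of indices:
# train_set gets the first 45,000 shuffled images, val_set the remaining 5,000.
pl.seed_everything(42)
train_set, _ = torch.utils.data.random_split(train_dataset, [45000, 5000])
pl.seed_everything(42)
_, val_set = torch.utils.data.random_split(val_dataset, [45000, 5000])
```

`pl.seed_everything(42)` seeds the pseudo-random number generators (Python, NumPy and PyTorch) and sets the `PL_GLOBAL_SEED` environment variable. `train_dataset` and `val_dataset` contain the same images, transformed in two different ways. Because we reseed with the same value before each `random_split`, both calls shuffle the indices identically, so `train_set` and `val_set` use disjoint sets of images: the 5,000 validation images are never trained on, which is what we need to evaluate generalization.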
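`random_split` also accepts a `generator` argument, so an alternative that does not depend on the global RNG state is to pass an explicitly seeded `torch.Generator` to each call. A minimal sketch on a toy 100-sample dataset (the 90/10 sizes are illustrative), checking that the two splits are disjoint and together cover the dataset:

```python
import torch
from torch.utils.data import TensorDataset, random_split

full = TensorDataset(torch.arange(100))

# Identically seeded generators produce the same permutation in both calls
train_part, _ = random_split(full, [90, 10], generator=torch.Generator().manual_seed(42))
_, val_part = random_split(full, [90, 10], generator=torch.Generator().manual_seed(42))

train_idx, val_idx = set(train_part.indices), set(val_part.indices)
print(len(train_idx), len(val_idx), train_idx.isdisjoint(val_idx))  # 90 10 True
```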



```python
import matplotlib.pyplot as plt
import torchvision

# Visualize some examples
NUM_IMAGES = 4
CIFAR_IMAGES = torch.stack(
    [val_set[idx][0] for idx in range(NUM_IMAGES)], dim=0
)

img_grid = torchvision.utils.make_grid(CIFAR_IMAGES, nrow=4, normalize=True, pad_value=0.9)
img_grid = img_grid.permute(1, 2, 0)  # [C, H, W] -> [H, W, C] as expected by matplotlib

plt.figure(figsize=(10, 10))
plt.title("Image examples of the CIFAR10 dataset")
plt.imshow(img_grid)
plt.axis("off")
plt.show()
plt.close()
```

```python
train_loader = data.DataLoader(
    train_set, batch_size=128, shuffle=True, drop_last=True, pin_memory=True, num_workers=4
)
val_loader = data.DataLoader(
    val_set, batch_size=128, shuffle=False, drop_last=False, pin_memory=True, num_workers=4
)
test_loader = data.DataLoader(
    test_dataset, batch_size=128, shuffle=False, drop_last=False, pin_memory=True, num_workers=4
)
```

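A `DataLoader` combines a dataset and a sampler and provides an iterable over the dataset in mini-batches of `batch_size` samples. With `shuffle=True`, the data is reshuffled at every epoch, so each mini-batch is a different random subset of the data; this randomness in the gradient estimates is part of what makes stochastic gradient descent work well (it can help escape poor local minima). We only shuffle the training set, since the order does not matter for evaluation. `drop_last=True` drops the last, incomplete batch when the dataset size is not divisible by the batch size. `pin_memory=True` places batches in page-locked (pinned) memory, which speeds up copies to the GPU. `num_workers` specifies how many subprocesses to use for data loading.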
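A small sketch of what `drop_last` changes, on a hypothetical 10-sample toy dataset with `batch_size=4`:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.arange(10))  # 10 samples

keep_last = DataLoader(ds, batch_size=4, drop_last=False)
drop_last = DataLoader(ds, batch_size=4, drop_last=True)

print(len(keep_last), [len(b[0]) for b in keep_last])  # 3 [4, 4, 2]
print(len(drop_last), [len(b[0]) for b in drop_last])  # 2 [4, 4]
```

For the real training set, 45,000 images with `batch_size=128` give 351 full batches per epoch (351 × 128 = 44,928), and `drop_last=True` skips the remaining 72 images in that epoch; since the data is reshuffled, different images are skipped each epoch.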

How to choose `num_workers`?
* `num_workers=0` means ONLY the main process loads batches, which can be a bottleneck
* `num_workers=1` means a single worker subprocess (separate from the main process) loads the data, but it will still be slow
* How well a high `num_workers` performs depends on the batch size and on your machine
* A general place to start is to set it equal to the **number of CPU cores on the machine** (`os.cpu_count()`), but depending on your batch size you may run out of RAM
* Increasing the number increases CPU memory consumption
* The best approach is to increase it gradually and stop once training speed no longer improves. For debugging, or for dataloaders over very small datasets, it is preferable to set it to 0
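The "increase gradually and measure" advice can be turned into a rough benchmark. A minimal sketch, where `benchmark_num_workers` and `candidate_worker_counts` are hypothetical helpers: it times how long fetching a fixed number of batches takes for each candidate value. The timing includes worker start-up, so treat the numbers as a rough comparison rather than a precise measurement.

```python
import os
import time

import torch
from torch.utils.data import DataLoader, TensorDataset


def candidate_worker_counts(max_workers=None):
    """0, then powers of two up to max_workers (default: the number of CPU cores)."""
    cpus = max_workers or os.cpu_count() or 1
    counts, n = [0], 1
    while n <= cpus:
        counts.append(n)
        n *= 2
    return counts


def benchmark_num_workers(dataset, candidates, batch_size=128, max_batches=50):
    """Seconds needed to fetch up to max_batches batches for each num_workers value."""
    timings = {}
    for n in candidates:
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=n)
        start = time.perf_counter()
        for i, _ in enumerate(loader):
            if i + 1 >= max_batches:
                break
        timings[n] = time.perf_counter() - start
    return timings


print(candidate_worker_counts(8))  # [0, 1, 2, 4, 8]
```

For example, `benchmark_num_workers(train_set, candidate_worker_counts())` would compare the settings on the augmented CIFAR10 training set. When running it as a script on platforms that start workers with `spawn` (Windows, macOS), call it under an `if __name__ == "__main__":` guard.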

## Tokenization


```python
def img_to_patch(x, patch_size, flatten_channels=True):
    """
    Inputs:
        x - Tensor representing the image of shape [B, C, H, W]
        patch_size - Number of pixels per dimension of the patches (integer)
        flatten_channels - If True, the patches will be returned in a flattened
                           format as a feature vector instead of an image grid
    """
    B, C, H, W = x.shape
    # Split H into (H', p_H) and W into (W', p_W), where H' = H // patch_size
    x = x.reshape(
        B,
        C,
        H // patch_size,
        patch_size,
        W // patch_size,
        patch_size,
    )  # [B, C, H', p_H, W', p_W]

    x = x.permute(0, 2, 4, 1, 3, 5)  # [B, H', W', C, p_H, p_W]
    x = x.flatten(1, 2)  # [B, H' * W', C, p_H, p_W]

    if flatten_channels:
        x = x.flatten(2, 4)  # [B, H' * W', C * p_H * p_W]

    return x


img_patches = img_to_patch(CIFAR_IMAGES, patch_size=4, flatten_channels=False)
```


The Vision Transformer (ViT) is a model for image classification that views images as sequences of smaller patches. So, as a preprocessing step, we split an image of, for example, 32x32 pixels into an 8x8 grid of patches of 4x4 pixels each. The batch and channel dimensions are untouched; the reshape splits the height and width into four dimensions: H', p_H, W', p_W, where H' = H / patch_size is the number of patches along the height and p_H = patch_size is the height of each patch (likewise for the width).

The permute brings the dimensions into the order [B, H', W', C, p_H, p_W], and `flatten(1, 2)` merges H' and W', so each image now has H'*W' patches, each of which we can visualize as a small image of shape (C, p_H, p_W).

Finally, if `flatten_channels` is True, we flatten the channel and patch dimensions, so that each of the H'*W' patches becomes one vector of C*p_H*p_W elements (3*4*4 = 48 for CIFAR10 with 4x4 patches). Each of those patches is considered to be a "word"/"token".
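A quick way to check the shape bookkeeping is to run `img_to_patch` on a random batch and compare it with an alternative based on `Tensor.unfold`, which extracts sliding windows along one dimension (with the step equal to the window size, the windows don't overlap). The block repeats `img_to_patch` so that it runs on its own; `img_to_patch_unfold` is a hypothetical name for the alternative.

```python
import torch


def img_to_patch(x, patch_size, flatten_channels=True):
    # Same reshape/permute/flatten steps as in the lecture code
    B, C, H, W = x.shape
    x = x.reshape(B, C, H // patch_size, patch_size, W // patch_size, patch_size)
    x = x.permute(0, 2, 4, 1, 3, 5)  # [B, H', W', C, p_H, p_W]
    x = x.flatten(1, 2)  # [B, H' * W', C, p_H, p_W]
    if flatten_channels:
        x = x.flatten(2, 4)  # [B, H' * W', C * p_H * p_W]
    return x


def img_to_patch_unfold(x, patch_size):
    # Hypothetical alternative: non-overlapping windows along H, then along W
    B, C, H, W = x.shape
    x = x.unfold(2, patch_size, patch_size)  # [B, C, H', W, p_H]
    x = x.unfold(3, patch_size, patch_size)  # [B, C, H', W', p_H, p_W]
    x = x.permute(0, 2, 3, 1, 4, 5)  # [B, H', W', C, p_H, p_W]
    return x.reshape(B, -1, C * patch_size * patch_size)


x = torch.randn(2, 3, 32, 32)  # a batch of two 32x32 RGB "images"
patches = img_to_patch(x, patch_size=4)
print(patches.shape)  # torch.Size([2, 64, 48]): 8 * 8 patches of 3 * 4 * 4 values
print(torch.equal(patches, img_to_patch_unfold(x, patch_size=4)))  # True
```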

```python
fig, ax = plt.subplots(CIFAR_IMAGES.shape[0], 1, figsize=(14, 3))
fig.suptitle("Images as input sequences of patches")
for i in range(CIFAR_IMAGES.shape[0]):
    img_grid = torchvision.utils.make_grid(
        img_patches[i], nrow=64, normalize=True, pad_value=0.9
    )
    img_grid = img_grid.permute(1, 2, 0)
    ax[i].imshow(img_grid)
    ax[i].axis("off")
plt.show()
plt.close()
```

The `make_grid` function takes a 4D mini-batch tensor of shape (B x C x H x W), or a list of images all of the same size, and tiles them into one image. `nrow` sets the number of images displayed in each row of the grid, `normalize=True` rescales the pixel values to the range [0, 1] (using the minimum and maximum of the input), and `pad_value` sets the value of the padding pixels between images.

## Neural Net Module


```python
import torch.nn as nn


class AttentionBlock(nn.Module):
    def __init__(self, embed_dim, hidden_dim, num_heads, dropout=0.0):
        super().__init__()

        self.layer_norm_1 = nn.LayerNorm(embed_dim)
        # Expects inputs of shape [T, B, embed_dim] (batch_first=False by default)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.layer_norm_2 = nn.LayerNorm(embed_dim)
        # Feed-forward network (MLP) applied to every token independently
        self.linear = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        # Pre-LayerNorm residual sub-layers: x + Attn(LN(x)), then x + MLP(LN(x))
        inp_x = self.layer_norm_1(x)
        x = x + self.attn(inp_x, inp_x, inp_x)[0]  # [0]: outputs, [1]: attention weights
        x = x + self.linear(self.layer_norm_2(x))

        return x
```


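In the `forward` method we compute the output tensor from the input tensor. The block follows the pre-LayerNorm Transformer layout: the input is layer-normalized before the multi-head self-attention (query, key and value are all the normalized input) and before the feed-forward MLP, and each sub-layer's output is added back to its input through a residual connection. The output has the same shape as the input, so blocks can be stacked.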
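Inside `self.attn`, each head computes scaled dot-product attention, softmax(Q K^T / sqrt(d)) V. A minimal single-head sketch from scratch, leaving out the learned query/key/value and output projections that `nn.MultiheadAttention` adds (the function name is our own): each token's output is a weighted average of all value vectors, weighted by its similarity to every token. Note that `nn.MultiheadAttention` defaults to `batch_first=False`, i.e. inputs of shape [T, B, E], whereas this sketch uses [B, T, d].

```python
import math

import torch


def scaled_dot_product_attention(q, k, v):
    """q, k, v: [B, T, d]. Returns (output [B, T, d], attention weights [B, T, T])."""
    d = q.shape[-1]
    scores = q @ k.transpose(-2, -1) / math.sqrt(d)  # [B, T, T] token-to-token similarities
    weights = scores.softmax(dim=-1)  # each row sums to 1
    return weights @ v, weights


torch.manual_seed(0)
x = torch.randn(2, 65, 32)  # 2 images, 64 patch tokens + 1 CLS token, d = 32
out, w = scaled_dot_product_attention(x, x, x)  # self-attention: q = k = v = x
print(out.shape, w.shape)  # torch.Size([2, 65, 32]) torch.Size([2, 65, 65])
```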

```python
class VisionTransformer(nn.Module):
    def __init__(
        self,
        embed_dim,
        hidden_dim,
        num_channels,
        num_heads,
        num_layers,
        num_classes,
        patch_size,
        num_patches,
        dropout=0.0,
    ):
        super().__init__()

        self.patch_size = patch_size

        # Layers/Networks
        self.input_layer = nn.Linear(num_channels * (patch_size**2), embed_dim)
        self.transformer = nn.Sequential(
            *[AttentionBlock(embed_dim, hidden_dim, num_heads, dropout) for _ in range(num_layers)]
        )
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(embed_dim), nn.Linear(embed_dim, num_classes)
        )
        self.dropout = nn.Dropout(dropout)

        # Parameters/Embeddings
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, embed_dim))
```

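The `VisionTransformer` is made up of:
* A linear projection layer (`input_layer`) that maps each input patch, a vector of `num_channels * patch_size**2` values, to a feature vector of larger size (`embed_dim`)
* A stack of `num_layers` `AttentionBlock`s (`transformer`)
* A learned classification token (`cls_token`) and learned positional embeddings (`pos_embedding`, one per patch plus one for the CLS token)
* A multi-layer perceptron head (`mlp_head`) that takes the output feature vector of the CLS token and maps it to a classification prediction (one logit per class)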
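A quick shape check of the input projection and the head, using the hyperparameters from the training call at the end of this lecture (`patch_size=4`, `num_channels=3`, `embed_dim=256`, `num_classes=10`) and a random batch of flattened patches in place of real images. In the real model the head sees the CLS token after the transformer; here the first token simply stands in for it.

```python
import torch
import torch.nn as nn

patch_size, num_channels, embed_dim, num_classes = 4, 3, 256, 10

input_layer = nn.Linear(num_channels * patch_size**2, embed_dim)  # 48 -> 256
mlp_head = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, num_classes))

patches = torch.randn(2, 64, num_channels * patch_size**2)  # [B, T, C * p * p]
tokens = input_layer(patches)  # nn.Linear acts on the last dim: [B, T, embed_dim]
logits = mlp_head(tokens[:, 0])  # one token per image: [B, num_classes]
print(tokens.shape, logits.shape)  # torch.Size([2, 64, 256]) torch.Size([2, 10])
```

The projection has 48 × 256 weights plus 256 biases, i.e. 12,544 parameters, shared by all 64 patches.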

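```python
def forward(self, x):
    # Preprocess input: [B, C, H, W] -> [B, T, C * p_H * p_W]
    x = img_to_patch(x, self.patch_size)
    B, T, _ = x.shape
    x = self.input_layer(x)  # [B, T, embed_dim]

    # Prepend the CLS token and add positional embeddings
    cls_token = self.cls_token.repeat(B, 1, 1)  # [B, 1, embed_dim]
    x = torch.cat([cls_token, x], dim=1)  # [B, T + 1, embed_dim]
    x = x + self.pos_embedding[:, :T + 1]

    # Apply the Transformer (nn.MultiheadAttention expects [T + 1, B, embed_dim])
    x = self.dropout(x)
    x = x.transpose(0, 1)
    x = self.transformer(x)

    # Classify from the output of the CLS token
    cls = x[0]  # [B, embed_dim]
    out = self.mlp_head(cls)
    return out


# forward is defined outside the class body here, so attach it as the method
VisionTransformer.forward = forward
```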

We add the positional embeddings to `x`. Notice how `pos_embedding` is of shape [1, 65, 256] (64 patches plus the CLS token, with `embed_dim=256`) while `x` is of shape [B, 65, 256], and yet we're able to sum them: the size-1 batch dimension of `pos_embedding` is stretched to B, so the same positional embeddings are applied to every sample in the batch. This is broadcasting.
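A minimal sketch of the CLS-token concatenation and the broadcasting addition, using random tensors as stand-ins for the learned `cls_token` and `pos_embedding` parameters and the shapes from this lecture (64 patches, `embed_dim=256`, here a batch of 4):

```python
import torch

B, T, E = 4, 64, 256
x = torch.randn(B, T, E)  # projected patch tokens
cls_token = torch.randn(1, 1, E)  # stand-in for the learned nn.Parameter
pos_embedding = torch.randn(1, T + 1, E)  # stand-in for the learned nn.Parameter

x = torch.cat([cls_token.repeat(B, 1, 1), x], dim=1)  # [B, T + 1, E], CLS token first
y = x + pos_embedding  # [1, T + 1, E] broadcasts over the batch dimension

# Broadcasting gives the same result as explicitly expanding pos_embedding to B copies
assert torch.equal(y, x + pos_embedding.expand(B, -1, -1))
print(x.shape, y.shape)  # torch.Size([4, 65, 256]) torch.Size([4, 65, 256])
```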

## Lightning Module

```python
import torch.nn.functional as F
import torch.optim as optim


class ViT(pl.LightningModule):
    def __init__(self, model_kwargs, lr):
        super().__init__()
        self.save_hyperparameters()
        self.model = VisionTransformer(**model_kwargs)

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        optimizer = optim.Adam(self.model.parameters(), lr=self.hparams.lr)
        lr_scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[100, 150], gamma=0.1
        )
        return [optimizer], [lr_scheduler]

    def _calculate_loss(self, batch, mode="train"):
        imgs, labels = batch
        preds = self.model(imgs)
        loss = F.cross_entropy(preds, labels)
        acc = (preds.argmax(dim=1) == labels).float().mean()

        self.log("%s_loss" % mode, loss, prog_bar=True)
        self.log("%s_acc" % mode, acc, prog_bar=True)

        return loss

    def training_step(self, batch, batch_idx):
        loss = self._calculate_loss(batch, mode="train")
        return loss

    def validation_step(self, batch, batch_idx):
        self._calculate_loss(batch, mode="val")

    def test_step(self, batch, batch_idx):
        self._calculate_loss(batch, mode="test")
```

A LightningModule organizes your PyTorch code into sections:
* Computations (`__init__`): here, the wrapped `VisionTransformer`
* Forward (`forward`): used for inference only (separate from `training_step`)
* Optimizer and scheduler (`configure_optimizers`): the optimizer takes in the parameters and determines how they are updated; the scheduler wraps the optimizer and adjusts the learning rate of its parameter groups during training (here `MultiStepLR` multiplies it by `gamma=0.1` at epochs 100 and 150). We don't need to worry about these for now
* Training loop (`training_step`)
* Validation loop (`validation_step`)
* Test loop (`test_step`)
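To see what `MultiStepLR(milestones=[100, 150], gamma=0.1)` does, the sketch below records the learning rate over 200 scheduler steps for a dummy parameter. `lr_schedule` is a hypothetical helper; in the `ViT` module Lightning steps the scheduler once per epoch by default, so each step here corresponds to one epoch.

```python
import torch
import torch.nn as nn
import torch.optim as optim


def lr_schedule(base_lr=3e-4, steps=200, milestones=(100, 150), gamma=0.1):
    """Learning rate in effect at each step, for a dummy one-parameter model."""
    param = nn.Parameter(torch.zeros(1))
    optimizer = optim.Adam([param], lr=base_lr)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=list(milestones), gamma=gamma)
    lrs = []
    for _ in range(steps):
        lrs.append(optimizer.param_groups[0]["lr"])
        optimizer.step()  # in real training this follows loss.backward()
        scheduler.step()
    return lrs


lrs = lr_schedule()
print(lrs[0], lrs[100], lrs[150])  # roughly 3e-4, 3e-5 and 3e-6
```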

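All of the loops use `_calculate_loss`, which computes the cross-entropy loss between the model's predictions (`preds`, one unnormalized logit per class) and the labels, computes the accuracy, and logs both under a mode-specific name (`train_loss`, `val_acc`, and so on). `training_step` returns the loss so that Lightning can backpropagate it; the validation and test steps only log their metrics.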
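A small worked example of what `_calculate_loss` computes, on a hand-made batch of two samples and three classes: `F.cross_entropy` takes raw logits and equals the batch mean of the negative log-softmax probability of the true class, and the accuracy is the fraction of samples whose highest logit is the true class.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 3.0]])  # [B=2, num_classes=3], unnormalized
labels = torch.tensor([0, 1])

loss = F.cross_entropy(logits, labels)
# Same value by hand: mean over the batch of -log softmax(logits)[true class]
manual = -F.log_softmax(logits, dim=1)[torch.arange(2), labels].mean()
acc = (logits.argmax(dim=1) == labels).float().mean()

print(loss.item(), manual.item(), acc.item())  # acc is 0.5: only the first prediction is right
```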

```python
CHECKPOINT_PATH = os.environ.get(
    "PATH_CHECKPOINT", "saved_models/VisionTransformers/"
)


def train_model(**kwargs):
    trainer = pl.Trainer(
        default_root_dir=os.path.join(CHECKPOINT_PATH, "ViT"),
        # Debug run: only 5 train/val/test batches, no checkpointing or logging.
        # Remove it (and set e.g. max_epochs) to actually train the model.
        fast_dev_run=5,
    )

    pl.seed_everything(42)  # To be reproducible
    model = ViT(**kwargs)
    trainer.fit(model, train_loader, val_loader)
    test_result = trainer.test(model, dataloaders=test_loader, verbose=False)
    return model, test_result
```

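The basic use of the `Trainer` is to initialize it and then call `fit` with the model, the `train_loader` and the `val_loader`; afterwards, its `test` method evaluates the model on the `test_loader`. Note that `fast_dev_run=5` turns this into a quick smoke test: Lightning runs only 5 batches of training, validation and testing to check that the code runs end to end, so the reported test accuracy is not meaningful. To train for real, remove `fast_dev_run` and set the number of epochs (e.g. with `max_epochs`); the learning-rate milestones at epochs 100 and 150 only take effect in runs longer than 100 epochs.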

```python
model, results = train_model(
    model_kwargs={
        "embed_dim": 256,
        "hidden_dim": 512,
        "num_heads": 8,
        "num_layers": 6,
        "patch_size": 4,
        "num_channels": 3,
        "num_patches": 64,
        "num_classes": 10,
        "dropout": 0.2,
    },
    lr=3e-4,
)
print("Results", results)
```