# Vision Transformers (ViT)

Based on "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale", Dosovitskiy et al.

In [1]:
from vision.attention2D import Attention
import torch

import pickle
import numpy as np

from torch.utils.data import Dataset, DataLoader
import torch.nn as nn 

In [2]:
# Pick the compute device. Apple's Metal backend ("mps") stays the first choice,
# as in the original notebook; when it is not available we fall back to CUDA and
# then CPU so that every later `.to(device)` call still works.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [3]:
# Data boilerplate: read the five pickled training batches, stack them into
# single arrays, move them to `device` as tensors and carve out an 80/20
# train/validation split (no shuffling before the split).

all_batches_data = []
all_batches_labels = []

for batch_no in range(1, 6):
    batch_path = f'generative/autoencoders/data/cifar-10-batches-py/data_batch_{batch_no}'
    with open(batch_path, 'rb') as batch_file:
        # encoding='bytes' is why the keys below are looked up as b'...'
        dataset_dict = pickle.load(batch_file, encoding='bytes')
    all_batches_data.append(dataset_dict[b'data'])
    all_batches_labels.append(dataset_dict[b'labels'])

stacked_data = np.vstack(all_batches_data)
stacked_labels = np.hstack(all_batches_labels)

# Each sample is viewed as 3 x 32 x 32 and scaled by 1/255
# (presumably raw 0-255 pixel values -> [0, 1]; confirm against the dataset).
data = (
    torch.tensor(stacked_data, dtype=torch.float32)
    .view(-1, 3, 32, 32)
    .to(device)
    / 255.
)
labels = torch.tensor(stacked_labels, dtype=torch.long).to(device)

# First 80% of the samples are used for training, the remainder for validation.
split_idx = int(0.8 * len(data))

x_train, y_train = data[:split_idx], labels[:split_idx]
x_valid, y_valid = data[split_idx:], labels[split_idx:]

class CIFARCustomDataset(Dataset):
    """
    Map-style dataset pairing an indexable collection of inputs `x`
    with a parallel collection of targets `y`.
    """
    def __init__(self, x, y):
        # Both collections are stored as given; nothing is copied or validated.
        self.x, self.y = x, y

    def __len__(self):
        # The dataset size is taken from the inputs only.
        return len(self.x)

    def __getitem__(self, idx):
        # Return the (input, target) pair found at the same index.
        sample = self.x[idx]
        target = self.y[idx]
        return sample, target

# Wrap each split in a dataset, then serve both in shuffled mini-batches of 32.
# NOTE(review): the validation loader is shuffled as well; confirm that is intended.
train_ds, valid_ds = CIFARCustomDataset(x_train, y_train), CIFARCustomDataset(x_valid, y_valid)

train_dl = DataLoader(dataset=train_ds, batch_size=32, shuffle=True)
valid_dl = DataLoader(dataset=valid_ds, batch_size=32, shuffle=True)

In [None]:
def patchify(x: torch.Tensor, patch_size: int):
    """
    Cut a 4-D batch into non-overlapping square patches and flatten each patch.

    Args:
        x: tensor of shape (batch, channels, d2, d3). Both d2 and d3 must be
            divisible by `patch_size`, otherwise the reshape fails.
        patch_size: side length of every square patch.

    Returns:
        Tensor of shape (batch, num_patches, channels * patch_size * patch_size),
        where num_patches = (d2 // patch_size) * (d3 // patch_size). Patches are
        ordered with the dim-2 grid index varying slowest; inside a patch the
        values are laid out channel first.

    The shape is printed before and after each intermediate step.
    """
    print(x.shape)
    n, c, extent_a, extent_b = x.shape
    grid_a = extent_a // patch_size
    grid_b = extent_b // patch_size

    # Split each of the two trailing axes into (grid position, offset inside patch).
    tiled = x.reshape(n, c, grid_a, patch_size, grid_b, patch_size)
    print(f"reshaped: {tiled.shape}")

    # Bring the two grid axes forward: (n, grid_a, grid_b, c, patch, patch).
    patch_major = tiled.permute(0, 2, 4, 1, 3, 5)
    print(f"permuted: {patch_major.shape}")

    # Merge the grid axes into one patch axis and flatten every patch.
    return patch_major.reshape(n, -1, c * patch_size * patch_size)

# Rebind x_train to its flattened 8x8 patch sequences. Objects built from the
# previous x_train tensor (e.g. an already constructed dataset) are not updated.
x_train = patchify(x=x_train, patch_size=8)
x_train.shape  # notebook display of the resulting shape

torch.Size([40000, 3, 32, 32])
reshaped: torch.Size([40000, 3, 4, 8, 4, 8])
permuted: torch.Size([40000, 4, 4, 3, 8, 8])


torch.Size([40000, 16, 192])

## Transformer related modules

In [4]:
class MLP(nn.Module):
    """
    Position-wise MLP with pre-LayerNorm, dropout and a residual connection.

    forward(x) computes: dropout(mlp_layers(layer_norm(x))) + x

    Args:
        d_model: size of the input and output features (last dimension of x).
        d_ff: size of the hidden features.
        dropout: dropout probability applied to the MLP output before the residual add.
        num_layers: number of linear layers in the stack. Values below 2 behave
            like 2 (one expansion layer plus one projection layer).

    Every linear layer except the final projection is followed by a ReLU.
    Previously the first linear layer had no activation, so with the default
    num_layers=2 the stack collapsed to a single affine map.
    NOTE: inserting that ReLU shifts the indices inside `mlp_layers`, so a
    state_dict saved from the earlier layout will not load by key.
    """
    def __init__(self, d_model: int, d_ff: int, dropout: float, num_layers: int = 2) -> None:
        super().__init__()

        # Expansion layer d_model -> d_ff, followed by its non-linearity.
        layers = [nn.Linear(d_model, d_ff, bias=True), nn.ReLU()]

        # Optional hidden layers d_ff -> d_ff, each with its own non-linearity.
        for _ in range(1, num_layers - 1):
            layers.append(nn.Linear(d_ff, d_ff, bias=True))
            layers.append(nn.ReLU())

        # Final projection back to d_model; no activation so the output is unconstrained.
        layers.append(nn.Linear(d_ff, d_model, bias=True))
        self.mlp_layers = nn.Sequential(*layers)

        self.layer_norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the pre-norm MLP to x and add the input back as a residual."""
        residual = x
        x = self.mlp_layers(self.layer_norm(x))
        return self.dropout(x) + residual


class EncoderLayer(nn.Module):
    """
    Single ViT encoder block: normalised attention with a skip connection,
    followed by a normalised position-wise MLP with a skip connection.

    Args:
        num_heads: forwarded to `Attention`.
        num_channels: feature size used by both LayerNorms, `Attention` and the MLP.
        d_linear: hidden size of the MLP.
        num_linear_layers: number of linear layers in the MLP.
        num_groups: forwarded to `Attention`.
        dropout: dropout probability forwarded to `Attention` and the MLP.
        is_masked: accepted but not used anywhere in this class.

    NOTE(review): `MLP` already applies its own LayerNorm and residual add, so
    the MLP branch here is normalised twice and carries two skip connections
    (`norm2(h)` inside the MLP plus `h` below) - confirm this is intended.
    """
    def __init__(
        self,
        num_heads: int,
        num_channels: int,
        d_linear: int,
        num_linear_layers: int = 2,
        num_groups: int = 8,
        dropout: float = 0.1,
        is_masked: bool = False
    ):
        super().__init__()
        self.norm1 = nn.LayerNorm(num_channels)
        self.norm2 = nn.LayerNorm(num_channels)
        self.mha = Attention(dropout, num_heads, num_channels, num_groups)
        self.mlp = MLP(
            d_model=num_channels,
            d_ff=d_linear,
            dropout=dropout,
            num_layers=num_linear_layers,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the attention sub-block and then the MLP sub-block on x."""
        # Attention on the normalised input, plus the skip connection.
        attended = self.mha(self.norm1(x))
        h = attended + x

        # MLP on the normalised intermediate, plus the skip connection.
        transformed = self.mlp(self.norm2(h))
        return transformed + h


class Encoder(nn.Module):
    """
    Stack of `num_layers` identical `EncoderLayer` blocks applied in sequence.

    Args:
        num_heads: forwarded to every `EncoderLayer`.
        num_channels: forwarded to every `EncoderLayer`.
        num_layers: number of encoder layers in the stack.
        d_linear: forwarded to every `EncoderLayer`.
        num_linear_layers: forwarded to every `EncoderLayer`.
        num_groups: forwarded to every `EncoderLayer`.
        dropout: forwarded to every `EncoderLayer`.
        is_masked: forwarded to every `EncoderLayer`.

    The layers are held in an `nn.ModuleList` so they are registered as
    sub-modules. With a plain Python list their parameters were missing from
    `parameters()` and `state_dict()`, and `.to()`, `.train()` and `.eval()`
    on the encoder never reached them.
    """
    def __init__(
        self,
        num_heads: int,
        num_channels: int,
        num_layers: int,
        d_linear: int,
        num_linear_layers: int = 2,
        num_groups: int = 8,
        dropout: float = 0.1,
        is_masked: bool = False
    ):
        super().__init__()
        self.layers = nn.ModuleList([
            EncoderLayer(
                num_heads, num_channels, d_linear, num_linear_layers, num_groups, dropout, is_masked
            )
            for _ in range(num_layers)
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pass x through every encoder layer in order and return the result."""
        # Kept for backward compatibility: callers that never moved the encoder
        # explicitly relied on the layers following the input's device. Once the
        # parameters already live on x.device this is effectively a no-op.
        self.layers.to(x.device)
        for layer in self.layers:
            x = layer(x)
        return x