# ViT implementation practice

#### 1. **Tokenizer**, which takes an image and splits it into several non-overlapping patches.

In [None]:
try:
    import einops
except:
    !pip install einops

In [None]:
import torch
from torch import nn

from einops import rearrange, repeat
from einops.layers.torch import Rearrange


class Image2Tokens(nn.Module):
    """Exercise template: turn an image batch into a sequence of patch tokens.

    The lines marked `#! >>>` are intentionally left blank for the reader, so
    this cell does not run until they are filled in.

    Args:
        image_size: `(height, width)` pair of the input images.
        dim: Hidden size every patch is projected to.
        in_dim: Number of input channels.
        patch_size: Side length of the square patches.
        emb_dropout: Dropout probability applied to the final token embeddings.
    """

    def __init__(self, image_size, dim, in_dim=3, patch_size=16, emb_dropout=0.):
        super().__init__()
        image_height, image_width = image_size
        # Floor division: rows/columns that do not fill a whole patch are not counted.
        # `num_patches` is not used by the code given here; presumably it is
        # needed to size `pos_embedding` below.
        num_patches = (image_height // patch_size) * (image_width // patch_size)
        self.to_patch_embedding = nn.Sequential(
            Rearrange('#! >>> fill the correct einops statement here for prepare patches', p1=patch_size, p2=patch_size),
            #! >>> fill the embedding layer declaration here (hint: transform from `patch_size * patch_size * in_dim` to `dim`)
        )
        self.pos_embedding = #! >>> fill the module declaration here
        self.cls_token = #! >>> fill the module declaration here
        self.dropout = nn.Dropout(emb_dropout)

    def forward(self, img):
        """Tokenize `img`; the result is passed through dropout before being returned."""
        # Steps 1 and 2 below are performed by `to_patch_embedding`.
        x = self.to_patch_embedding(img)

        #! >>> fill the necessary code here
        # Steps:
        #      1. Split the images into (patch_size x patch_size) non-overlapping patches
        #      2. Apply a linear transformation that maps each flattened patch to a 1D vector of the given hidden size: `dim`
        #      3. Prepend a class token at the very beginning (index 0)
        #      4. Add a learnable embedding to every token (including the added class token)
        
        return self.dropout(x)

In [None]:
# Smoke test. With 224x224 inputs and 16x16 patches there are 14 * 14 = 196
# patches; together with the class token the completed tokenizer should
# return shape torch.Size([1, 197, 768]).
tokenizer = Image2Tokens(image_size=(224,224), dim=768)
tokenizer(torch.randn(1,3,224,224)).shape

#### 2. **Multi-Head Self-Attention**. Implement the following equation: $\mathrm{Softmax}(\frac{QK^T}{\sqrt{d}})V$, where $Q$, $K$, $V$ are linear projections of the input $x$, i.e. $Q=w_Qx, K=w_Kx, V=w_Vx$, computed in a multi-head manner.

In [None]:
class Attention(nn.Module):
    """Exercise template: multi-head self-attention, Softmax(QK^T / sqrt(d)) V.

    The lines marked `#! >>>` are intentionally left blank for the reader.

    Args:
        dim: Token embedding size. The head split in `forward` requires it to
            be divisible by `heads`.
        heads: Number of attention heads.
        dropout: Dropout probability applied after the output projection.
    """

    def __init__(self, dim, heads=8, dropout=0.):
        super().__init__()
        self.heads = heads
        self.scale = #! >>> fill the correct scale here
        # Softmax over the last axis. The `Recorder` class further down hooks
        # this module to capture the attention maps, so the attention weights
        # should be computed by calling `self.attend(...)`.
        self.attend = nn.Softmax(dim=-1)
        # A single projection produces q, k and v concatenated on the last axis.
        self.to_qkv = nn.Linear(dim, dim*3)
        self.to_out = nn.Sequential(
            nn.Linear(dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        """Apply self-attention to `x` of shape (batch, num_tokens, dim)."""
        # Split the fused projection into three equally sized tensors.
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        # Move the head axis forward: (b, n, h*d) -> (b, h, n, d).
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv)
        
        #! >>> fill the implementation to complete the self-attention
        # Hint: q, k, v are already multi-headed, with the shape: (batch, heads, num_tokens, dim // heads).
        # The code filled in here must define `out` with shape (batch, heads, num_tokens, dim // heads).
        
        # Merge the heads back into a single feature axis before the output projection.
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

In [None]:
# Smoke test: the output projection maps back to `dim`, so the completed
# module should preserve the input shape, torch.Size([1, 197, 768]).
attn_op = Attention(dim=768)
attn_op(torch.randn(1,197,768)).shape

#### 3. **FeedForwardNetwork (FFN)**. Implement $\mathrm{FFN}(x) = w_2\,\mathrm{GELU}(w_1 x + b_1) + b_2$.

In [None]:
class FeedForwardNetwork(nn.Module):
    """Exercise template: position-wise MLP, FFN(x) = w2 GELU(w1 x + b1) + b2.

    Args:
        dim: Input and output feature size.
        dropout: Dropout probability. The template does not say where to apply
            it; that choice is left to the reader.
    """

    def __init__(self, dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            #! >>> fill the necessary modules to complete the FFN module; note that w1 transforms dim to 4*dim, and w2 transforms the intermediate output from 4*dim back to dim.
        )
        
    def forward(self, x):
        """Run `x` through the sequential stack stored in `self.net`."""
        return self.net(x)

In [None]:
# Smoke test: w2 maps back to `dim`, so the completed module should return
# shape torch.Size([1, 197, 768]).
ffn = FeedForwardNetwork(768)
ffn(torch.randn(1,197,768)).shape

#### 4. Implement the **transformer encoder**, using the pre-norm residual (shortcut) style.

In [None]:
class PreNorm(nn.Module):
    """Wrap a module so that its input is layer-normalised first.

    Args:
        dim: Size of the last axis that `nn.LayerNorm` normalises over.
        fn: Module (or callable) invoked on the normalised input.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return `fn(LayerNorm(x), **kwargs)`; extra keywords go to `fn` untouched."""
        normalised = self.norm(x)
        return self.fn(normalised, **kwargs)
    

class Transformer(nn.Module):
    """Exercise template: a stack of pre-norm transformer encoder layers.

    Args:
        layers: Number of encoder layers.
        dim: Token embedding size.
        heads: Number of attention heads per layer.
        dropout: Dropout probability passed to the attention and FFN modules.
    """

    def __init__(self, layers, dim, heads=8, dropout=0.):
        super().__init__()
        # Each entry is a two-element ModuleList: [pre-norm attention, pre-norm FFN].
        self.layers = nn.ModuleList([])
        for _ in range(layers):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads=heads, dropout=dropout)),
                PreNorm(dim, FeedForwardNetwork(dim, dropout=dropout))
            ]))
    def forward(self, x):
        #! >>> Implement the forward flow here; all the necessary modules are already declared.
        # Remember to return the result :)

In [None]:
# Smoke test: every layer maps `dim` back to `dim`, so the completed encoder
# should return the input shape, torch.Size([1, 197, 768]).
model = Transformer(layers=12, dim=768)
model(torch.randn(1,197,768)).shape

#### 5. Package the **ViT model** with configurable hyper-parameters: (i) number of layers; (ii) hidden size; (iii) number of attention heads; (iv) image size (for the tokenizer); (v) number of classes (for the classifier).

In [None]:
class ViT(nn.Module):
    """Exercise template: tokenizer + transformer encoder + classification head.

    Args:
        layers: Number of encoder layers.
        dim: Hidden size.
        heads: Number of attention heads.
        image_size: `(height, width)` pair forwarded to the tokenizer.
        num_classes: Output size of the classifier.
        patch_size: Side length of the square patches.
        in_dim: Number of image channels.
        dropout: Dropout probability for the encoder.
        emb_dropout: Dropout probability for the token embeddings.
    """

    def __init__(self, layers, dim, heads, image_size, num_classes, patch_size=16, in_dim=3, dropout=0., emb_dropout=0.):
        super().__init__()
        # Keep the attribute names: `Recorder` further down reads `vit.transformer`.
        # Both calls below are left open on purpose; finish the argument list and close with `)`.
        self.tokenizer = Image2Tokens(#! >>> use the arguments to call the components implemented in the previous cells
        self.transformer = Transformer(#! >>> use the arguments to call the components implemented in the previous cells
        self.classifier = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        #! >>> Implement the forward flow here
        # Note that the classifier should receive only the class token (index 0), not all the tokens.
        # The code filled in here must define `out`.
        return self.classifier(out)

In [None]:
# Smoke test: the classifier maps the class token to `num_classes` logits, so
# the completed model should return shape torch.Size([1, 1000]).
model = ViT(layers=12, dim=768, heads=12, image_size=(224,224), num_classes=1000)
model(torch.randn(1,3,224,224)).shape

#### Sanity check that the parameter counts match the numbers in the ViT paper: (i) 87M for ViT-Base; (ii) 304M for ViT-Large; (iii) 632M for ViT-Huge

In [None]:
def num_of_parameters(model):
    """Return the total number of scalar elements across `model.parameters()`."""
    return sum(tensor.numel() for tensor in model.parameters())

In [None]:
# Build the Base / Large / Huge configurations (they differ in depth, hidden
# size and head count) and print their parameter counts for comparison with
# the reference numbers listed below.
vit_base = ViT(layers=12, dim=768, heads=12, image_size=(224,224), num_classes=1000)
vit_large = ViT(layers=24, dim=1024, heads=16, image_size=(224,224), num_classes=1000)
vit_huge = ViT(layers=32, dim=1280, heads=16, image_size=(224,224), num_classes=1000)

print(num_of_parameters(vit_base))
print(num_of_parameters(vit_large))
print(num_of_parameters(vit_huge))

#### The outputs should be:
```
86567656
304326632
632199400
```

---

# Train the implemented ViT on "Cat vs Dog Classification"

In [None]:
import os
import zipfile
import glob
from sklearn.model_selection import train_test_split

def UnzipTrainSet():
    """Extract every member of `train.zip` into the `data_dogcat` directory."""
    archive = zipfile.ZipFile('train.zip')
    with archive:
        archive.extractall('data_dogcat')
        
def GetList():
    """Return the paths of all `*.jpg` files directly inside `data_dogcat/train`."""
    pattern = os.path.join('data_dogcat', 'train', '*.jpg')
    return glob.glob(pattern)

#### Train/Validation Split

In [None]:
import os
if not os.path.exists('train.zip'):
    !wget -O train.zip https://www.dropbox.com/s/wd28l8279yttbez/train.zip?dl=1

try:
    train_list = GetList()
    assert len(train_list) == 25000
except:
    os.makedirs('data_dogcat', exist_ok=True)
    UnzipTrainSet()
    train_list = GetList()
    assert len(train_list) == 25000

labels = [path.split('/')[-1].split('.')[0] for path in train_list]
split_train_list, split_val_list = train_test_split(train_list, 
                                                    test_size=0.2,
                                                    stratify=labels)
print(f"Train Data: {len(split_train_list)}")
print(f"Valid Data: {len(split_val_list)}")

#### Dataloader and Augmentations (resize, horizontal flip)

In [None]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image

# Training pipeline: resize to the 224x224 input used by the models in this
# notebook, apply a random horizontal flip as augmentation, then convert to a
# tensor. No normalisation step is applied.
train_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
)

# Validation pipeline: same resize and tensor conversion, but no random
# augmentation so that evaluation is deterministic.
val_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)

class CatsDogsDataset(Dataset):
    """Dataset of cat/dog images whose label is encoded in the file name.

    File names are expected to start with the label followed by a dot (for
    example `dog.42.jpg`). A `dog` prefix maps to label 1, anything else to 0.

    Args:
        file_list: Sequence of image file paths.
        transform: Optional callable applied to the loaded RGB PIL image. When
            it is `None` the PIL image is returned unchanged.
    """

    def __init__(self, file_list, transform=None):
        self.file_list = file_list
        self.transform = transform

    def __len__(self):
        """Return the number of image paths."""
        return len(self.file_list)

    def __getitem__(self, idx):
        """Return `(image, label)` for the file at position `idx`."""
        img_path = self.file_list[idx]

        # Open inside a context manager so the file handle is released, and
        # force three channels: grayscale, palette or RGBA files would
        # otherwise produce a channel count other than three.
        with Image.open(img_path) as img:
            img = img.convert("RGB")

        # `transform` defaults to None, so it must not be called unconditionally.
        if self.transform is not None:
            img = self.transform(img)

        # basename() honours the platform's path separator; splitting on "/"
        # would keep the directory part on Windows and label every file 0.
        prefix = os.path.basename(img_path).split(".")[0]
        label = 1 if prefix == "dog" else 0

        return img, label

## Training with implemented ViT 

#### To run the training we need at least 10 GB of GPU memory

In [None]:
from tqdm.auto import tqdm
import torch
from torch import nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR

batch_size = 32
epochs = 5
lr = 1e-3
weight_decay = 1e-6
# Fall back to the CPU so the cell still runs (slowly) without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = ViT(layers=9, dim=192, heads=12, image_size=(224, 224), num_classes=2)
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
# One cosine cycle over the whole run; the scheduler is stepped once per epoch.
scheduler = CosineAnnealingLR(optimizer, epochs)

train_data = CatsDogsDataset(split_train_list, transform=train_transforms)
val_data = CatsDogsDataset(split_val_list, transform=val_transforms)

# Pinned host memory only helps when batches are copied to a GPU.
use_pinned_memory = device == 'cuda'
train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=use_pinned_memory, persistent_workers=True)
val_loader = DataLoader(dataset=val_data, batch_size=batch_size, shuffle=False, num_workers=1, pin_memory=use_pinned_memory, persistent_workers=True)

num_train_batches = len(train_loader)
num_val_batches = len(val_loader)

for epoch in range(epochs):
    # Running per-batch means, kept as Python floats. Accumulating the loss
    # tensor itself would chain every iteration's autograd history together
    # and keep it alive for the whole epoch.
    epoch_loss = 0.0
    epoch_accuracy = 0.0

    model.train()
    for data, label in tqdm(train_loader):
        data = data.to(device)
        label = label.to(device)

        output = model(data)
        loss = criterion(output, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc = (output.argmax(dim=1) == label).float().mean()
        epoch_accuracy += acc.item() / num_train_batches
        epoch_loss += loss.item() / num_train_batches

    model.eval()
    with torch.inference_mode():
        epoch_val_accuracy = 0.0
        epoch_val_loss = 0.0
        for data, label in tqdm(val_loader):
            data = data.to(device)
            label = label.to(device)

            val_output = model(data)
            val_loss = criterion(val_output, label)

            acc = (val_output.argmax(dim=1) == label).float().mean()
            epoch_val_accuracy += acc.item() / num_val_batches
            epoch_val_loss += val_loss.item() / num_val_batches

    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
    )

    scheduler.step()

## (Optional) Using ImageNet pretrained weights, using `timm` library

#### We halve the batch size to fit the smallest pretrained model, ViT-Base; even so, this cell still needs at least 12 GB of GPU memory to run

In [None]:
try:
    import timm
except:
    !pip install timm

In [None]:
import timm

class PretrainedViT(nn.Module):
    """ImageNet-21k pretrained ViT-Base from `timm` with a new linear head.

    Args:
        num_class: Number of output classes of the new classifier.
    """

    def __init__(self, num_class):
        super().__init__()
        self.pretrained_model = timm.create_model('vit_base_patch16_224_in21k', pretrained=True)
        # 768 is the hidden size of ViT-Base.
        self.classifier = nn.Linear(768, num_class)

    def forward(self, x):
        """Return class logits of shape (batch, num_class)."""
        x = self.pretrained_model.forward_features(x)
        # Depending on the timm version, `forward_features` returns either the
        # pooled class-token feature (batch, 768) or the full token sequence
        # (batch, tokens, 768). In the latter case keep only the class token
        # (index 0); otherwise the head would emit per-token logits and the
        # cross-entropy loss would fail on the shape.
        if x.dim() == 3:
            x = x[:, 0]
        return self.classifier(x)

In [None]:
from tqdm.auto import tqdm
import torch
from torch import nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR

batch_size = 16
epochs = 5
lr = 1e-4
weight_decay = 1e-6
# Fall back to the CPU so the cell still runs (slowly) without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

pretrained_model = PretrainedViT(2)
pretrained_model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(pretrained_model.parameters(), lr=lr, weight_decay=weight_decay)
scheduler = CosineAnnealingLR(optimizer, epochs)

# Rebuild the loaders so the halved batch size actually takes effect; the
# loaders created in the earlier training cell still use its batch size.
use_pinned_memory = device == 'cuda'
train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=use_pinned_memory, persistent_workers=True)
val_loader = DataLoader(dataset=val_data, batch_size=batch_size, shuffle=False, num_workers=1, pin_memory=use_pinned_memory, persistent_workers=True)

num_train_batches = len(train_loader)
num_val_batches = len(val_loader)

for epoch in range(epochs):
    # Running per-batch means, kept as Python floats. Accumulating the loss
    # tensor itself would chain every iteration's autograd history together
    # and keep it alive for the whole epoch.
    epoch_loss = 0.0
    epoch_accuracy = 0.0

    pretrained_model.train()
    for data, label in tqdm(train_loader):
        data = data.to(device)
        label = label.to(device)

        output = pretrained_model(data)
        loss = criterion(output, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc = (output.argmax(dim=1) == label).float().mean()
        epoch_accuracy += acc.item() / num_train_batches
        epoch_loss += loss.item() / num_train_batches

    pretrained_model.eval()
    with torch.inference_mode():
        epoch_val_accuracy = 0.0
        epoch_val_loss = 0.0
        for data, label in tqdm(val_loader):
            data = data.to(device)
            label = label.to(device)

            val_output = pretrained_model(data)
            val_loss = criterion(val_output, label)

            acc = (val_output.argmax(dim=1) == label).float().mean()
            epoch_val_accuracy += acc.item() / num_val_batches
            epoch_val_loss += val_loss.item() / num_val_batches

    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
    )

    scheduler.step()

#### Since ViT is a data-hungry model, a significant accuracy improvement can be observed when using weights pretrained on a large dataset. This suggests that you should always consider using a pretrained model when one is available.

# (Extra) Visual inspect the learned attention weights

#### The following code is only applicable to our implemented ViT. Some modifications are needed for `timm` ViTs.

In [None]:
from functools import wraps
import torch
from torch import nn


def find_modules(nn_module, type):
    """Collect the modules yielded by `nn_module.modules()` that are instances of `type`."""
    matches = []
    for candidate in nn_module.modules():
        if isinstance(candidate, type):
            matches.append(candidate)
    return matches

class Recorder(nn.Module):
    """Wrap a ViT and capture the output of every `Attention.attend` module.

    Calling the recorder runs the wrapped model and returns its prediction
    together with the captured outputs stacked along a new axis 1.

    Args:
        vit: Model to wrap; its `transformer` attribute is searched for
            `Attention` modules.
        device: Optional device the recordings are moved to before stacking.
            When `None`, the device of the input image is used.
    """

    def __init__(self, vit, device = None):
        super().__init__()
        self.vit = vit
        self.device = device

        self.data = None
        self.recordings = []          # one detached tensor per hooked module call
        self.hooks = []               # handles, kept so the hooks can be removed
        self.hook_registered = False
        self.ejected = False

    def record(self, attn):
        """Append a detached copy of `attn` to the recordings."""
        self.recordings.append(attn.clone().detach())

    def _hook(self, _, input, output):
        # Forward-hook callback: only the module's output is stored.
        self.record(output)

    def _register_hook(self):
        """Attach `_hook` to the `attend` module of every `Attention` layer."""
        for attention in find_modules(self.vit.transformer, Attention):
            self.hooks.append(attention.attend.register_forward_hook(self._hook))
        self.hook_registered = True

    def eject(self):
        """Remove all hooks, mark the recorder unusable and return the wrapped model."""
        self.ejected = True
        for handle in self.hooks:
            handle.remove()
        del self.hooks[:]
        return self.vit

    def clear(self):
        """Drop all recordings collected so far."""
        del self.recordings[:]

    def forward(self, img):
        """Run the wrapped model on `img`.

        Returns:
            `(pred, attns)`, where `attns` is the stack of recordings along
            axis 1, or `None` when nothing was recorded.
        """
        assert not self.ejected, 'recorder has been ejected, cannot be used anymore'
        self.clear()
        # Hooks are attached lazily, on the first forward pass.
        if not self.hook_registered:
            self._register_hook()

        pred = self.vit(img)

        # Gather every recording on one device before stacking.
        target_device = img.device if self.device is None else self.device
        moved = [rec.to(target_device) for rec in self.recordings]
        if not moved:
            return pred, None
        return pred, torch.stack(moved, dim = 1)

In [None]:
import cv2
import matplotlib.pyplot as plt

patch_size = 16
image_size = 224
select_index = 0 # change the index to visualize different images

# Run on the CPU so the recorded tensors can be converted with .numpy() below.
model.cpu()
model.eval()
v = Recorder(model) # load our implemented ViT

data, label = next(iter(val_loader))
_, attns = v(data)

# presumably `attns` is (batch, layers, heads, tokens, tokens) - TODO confirm
# against the completed Attention implementation. Under that assumption this
# takes the class token's row (index 0), keeps its attention to the patch
# tokens (1:), and sums over heads, giving one flat map per layer.
m = attns[select_index][:, :, 0, 1:].sum(1)
plt.imshow(data[select_index].permute(1,2,0))
# Overlay: sum over the first axis of `m`, reshape to the patch grid and
# upsample to the image resolution.
plt.imshow(cv2.resize(m.sum(0).reshape(image_size//patch_size, image_size//patch_size).numpy(), (image_size, image_size)), cmap='hot', alpha=0.5)
plt.axis('off')
plt.title('Averaged Attention Weights')
plt.show()

# The 3x3 grid has room for nine maps, matching the `layers=9` model trained above.
_, arr = plt.subplots(3, 3, figsize=(10, 10))
for idx, i in enumerate(m):
    arr[idx//3, idx%3].imshow(cv2.resize(i.reshape(image_size//patch_size, image_size//patch_size).numpy(), (image_size, image_size)), cmap='hot')
    arr[idx//3, idx%3].axis('off')
    arr[idx//3, idx%3].set_title(f'Attention Layer {idx}')
plt.show()

#### Note that we trained our implemented ViT for only 5 epochs, so the model may not have converged and the visualization may be noisy.

### Here we provide the pretrained model (trained for 100 epochs) for visualization. (accuracy 0.7613)

In [None]:
if not os.path.exists('model.pt'):
    !wget https://www.dropbox.com/s/ffj5j1kdt30racl/model.pt?dl=1 # trained 100 epochs, acc: 0.7613


patch_size = 16
image_size = 224
select_index = 0 # change the index to visualize different images

model = ViT(layers=9, dim=192, heads=12, image_size=(224, 224), num_classes=2)
model.load_state_dict(torch.load('model.pt')['weight'])
model.cpu()
model.eval()
v = Recorder(model) # load our implemented ViT

data, label = next(iter(val_loader))
pred, attns = v(data)

m = attns[select_index][:, :, 0, 1:].sum(1)
plt.imshow(data[select_index].permute(1,2,0))
plt.imshow(cv2.resize(m[-1].reshape(image_size//patch_size, image_size//patch_size).numpy(), (image_size, image_size)), cmap='hot', alpha=0.5)
plt.axis('off')
plt.title('Averaged Attention Weights')
plt.show()

_, arr = plt.subplots(3, 3, figsize=(10, 10))
for idx, i in enumerate(m):
    arr[idx//3, idx%3].imshow(cv2.resize(i.reshape(image_size//patch_size, image_size//patch_size).numpy(), (image_size, image_size)), cmap='hot')
    arr[idx//3, idx%3].axis('off')
    arr[idx//3, idx%3].set_title(f'Attention Layer {idx}')
plt.show()