<a href="https://colab.research.google.com/github/aspiringastro/vit-step-by-step/blob/main/vit_version_3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline


In [None]:
from torch.utils.data import DataLoader
from torchvision import transforms as T
from torchvision.datasets import CIFAR10
from torch.utils.data.sampler import SubsetRandomSampler
import numpy as np

class CIFAR10DataSet():
    """Builds CIFAR-10 train / validation / test ``DataLoader`` objects.

    The official CIFAR-10 training set is divided into a training part and a
    validation part according to ``train_val_split``; the official test set
    is served untouched by :meth:`test_dataloader`.

    Args:
        data_dir: Directory passed to ``CIFAR10`` as ``root``.
        train_val_split: Fraction of the training set used for training.
            The remainder is used for validation.
    """

    def __init__(self, data_dir="data/cifar10", train_val_split=0.8):
        self.data_dir = data_dir
        # Instantiated with download=True; the loaders below build their own
        # CIFAR10 objects without the download flag.
        self.dataset = CIFAR10(root=self.data_dir, download=True)
        # Stored for reference only: the loader methods take their own
        # ``mean`` / ``std`` arguments and do not read these attributes.
        self.mean = (0.485, 0.456, 0.406)
        self.std = (0.229, 0.224, 0.225)
        self.train_val_split = train_val_split

    def _split_indices(self, num_samples):
        """Return ``(train_indices, val_indices)`` for ``num_samples`` items."""
        split = int(np.floor(self.train_val_split * num_samples))
        indices = list(range(num_samples))
        # The first ``train_val_split`` fraction is the training part.
        return indices[:split], indices[split:]

    def _augmented_train_set(self, resize, p, mean, std):
        """Return the CIFAR-10 training set with the random-augmentation pipeline."""
        tf = T.Compose([
                T.RandomResizedCrop(size=resize),
                T.RandomHorizontalFlip(p=p),
                T.RandomVerticalFlip(p=p),
                T.ToTensor(),
                T.Normalize(mean, std),
            ]
        )
        return CIFAR10(root=self.data_dir, train=True, transform=tf)

    def  train_dataloader(self, batch_size=32, resize=32, p=0.5, mean=(0.485, 0.456, 0.406),std=(0.229, 0.224, 0.225), num_workers=4):
        """Return a loader over the training part of the training set.

        Args:
            batch_size: Samples per batch (incomplete final batch is dropped).
            resize: Output size given to ``RandomResizedCrop``.
            p: Probability used for both the horizontal and vertical flip.
            mean, std: Per-channel statistics given to ``Normalize``.
            num_workers: Worker processes used by the ``DataLoader``.
        """
        ds = self._augmented_train_set(resize, p, mean, std)
        train_indices, _ = self._split_indices(len(ds))
        return DataLoader(
            ds,
            batch_size=batch_size,
            num_workers=num_workers,
            sampler=SubsetRandomSampler(train_indices),
            drop_last=True,
        )

    def  val_dataloader(self, batch_size=32, resize=32, p=0.5, mean=(0.485, 0.456, 0.406),std=(0.229, 0.224, 0.225), num_workers=4):
        """Return a loader over the validation part of the training set.

        Takes the same arguments as :meth:`train_dataloader`.
        """
        # NOTE(review): validation uses the same random augmentations as
        # training, so validation metrics are not deterministic - confirm
        # this is intended.
        ds = self._augmented_train_set(resize, p, mean, std)
        _, val_indices = self._split_indices(len(ds))
        return DataLoader(
            ds,
            batch_size=batch_size,
            num_workers=num_workers,
            sampler=SubsetRandomSampler(val_indices),
            drop_last=True,
        )

    def test_dataloader(self, batch_size=32, mean=(0.485, 0.456, 0.406) ,std=(0.229, 0.224, 0.225), num_workers=2, resize=None):
        """Return a loader over the official CIFAR-10 test set.

        Args:
            batch_size: Samples per batch (incomplete final batch is dropped).
            mean, std: Per-channel statistics given to ``Normalize``.
            num_workers: Worker processes used by the ``DataLoader``.
            resize: Optional size given to ``T.Resize`` before conversion to
                a tensor. ``None`` keeps the images at their stored size.
        """
        steps = []
        if resize is not None:
            # Lets test images match the size produced by the train / val
            # pipelines when those are asked to resize.
            steps.append(T.Resize(resize))
        steps.extend([T.ToTensor(), T.Normalize(mean, std)])
        ds = CIFAR10(root=self.data_dir, train=False, transform=T.Compose(steps))
        return DataLoader(ds, batch_size=batch_size, num_workers=num_workers, drop_last=True)

    def get_next(self, dl):
        """Return the first batch from a fresh iterator over ``dl``."""
        return next(iter(dl))

    def get_classes(self):
        """Return the tuple of the ten class names used for labelling."""
        classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
        return classes



In [None]:
import torchvision
import torchvision.utils
import torchvision.transforms.functional as TF
from torchvision.utils import make_grid

import matplotlib.pyplot as plt
import numpy as np

# Global matplotlib setting: use the 'tight' bounding box for saved figures.
plt.rcParams["savefig.bbox"] = 'tight'

def make_image(img, mean=(0., 0., 0.), std=(1., 1., 1.)):
    """Undo per-channel normalisation and return a channels-last array.

    Args:
        img: Tensor indexed as (channel, height, width).
        mean: Per-channel means that were subtracted during normalisation.
        std: Per-channel standard deviations that were divided out.

    Returns:
        A NumPy array of shape (height, width, channel).

    The input tensor is left untouched: all arithmetic runs on a detached
    CPU copy, so the function can be called repeatedly on the same tensor and
    also works for tensors that require grad or live on an accelerator.
    """
    img = img.detach().cpu()
    # Shape the statistics as (C, 1, 1) so they broadcast over H and W.
    mean_t = torch.as_tensor(mean, dtype=img.dtype).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=img.dtype).view(-1, 1, 1)
    restored = img * std_t + mean_t  # out-of-place: allocates a new tensor
    return np.transpose(restored.numpy(), (1, 2, 0))

def show_image(imgs, mean=(0., 0., 0.), std=(1., 1., 1.)):
    """Plot one image tensor, or a list of them, side by side.

    Args:
        imgs: A (channel, height, width) tensor or a list of such tensors.
        mean: Per-channel means used to undo normalisation.
        std: Per-channel standard deviations used to undo normalisation.

    The tensors passed in are not modified; each one is denormalised on a
    detached CPU copy before being converted to a PIL image.
    """
    if not isinstance(imgs, list):
        imgs = [imgs]
    if not imgs:
        # Nothing to draw; plt.subplots cannot build a zero-column grid.
        return
    fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)
    for i, img in enumerate(imgs):
        # Copy first so the in-place denormalisation never touches the
        # caller's data and never runs on a tensor that tracks gradients.
        img = img.detach().cpu().clone()
        for j, (channel_mean, channel_std) in enumerate(zip(mean, std)):
            img[j] = img[j] * channel_std + channel_mean
        # ``F`` is torch.nn.functional in this file and has no
        # ``to_pil_image``; the torchvision helper is imported as ``TF``.
        img = TF.to_pil_image(img)
        axs[0, i].imshow(np.asarray(img))
        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

def show_images(imgs, mean=(0., 0., 0.), std=(1., 1., 1.)):
    """Tile a batch of images into one grid and display it without axes.

    Args:
        imgs: Batch of image tensors accepted by ``make_grid``.
        mean: Per-channel means forwarded to ``make_image``.
        std: Per-channel standard deviations forwarded to ``make_image``.
    """
    # Build the grid first, then hand it to make_image together with the
    # normalisation statistics to obtain the array that is plotted.
    displayable = make_image(make_grid(imgs), mean, std)
    plt.imshow(displayable)
    plt.axis('off')

    


In [None]:
# Hyper-parameters (module-level globals read by the cells below)
eval_iters = 50  # batches averaged per split in estimate_loss()
dropout = 0.2  # probability passed to every nn.Dropout in the model
device = 'cuda' if torch.cuda.is_available() else 'cpu'  # prefer GPU when present
learning_rate = 1e-2  # lr passed to the AdamW optimizer
max_iters = 3000  # number of training iterations
eval_interval = 500  # evaluate train/val loss every this many iterations
n_embd = 384  # default embedding dimension of the model
img_size=128  # default image side length for get_batch() and the model

In [None]:
# Create the dataset wrapper; its constructor calls CIFAR10 with download=True.
dataset = CIFAR10DataSet()
# Build a training loader with default arguments. get_batch() below uses its
# own local variable of the same name rather than this global.
downloader = dataset.train_dataloader()

In [None]:
# data loading
def get_batch(dataset, split, batch_size=32, resize=img_size):
    """Fetch a single batch of images and one-hot targets for ``split``.

    Args:
        dataset: Object exposing ``train_dataloader``, ``val_dataloader``,
            ``test_dataloader``, ``get_next`` and ``get_classes``.
        split: One of ``'train'``, ``'val'`` or ``'test'``.
        batch_size: Forwarded to the loader factory.
        resize: Forwarded to the loader factory (every factory is called
            with a ``resize`` keyword).

    Returns:
        ``(x, y)`` on ``device``; ``y`` is the float one-hot encoding of the
        labels with one column per entry of ``dataset.get_classes()``.

    Raises:
        AttributeError: If ``split`` is not one of the three known names.
    """
    factory_by_split = {
        'train': 'train_dataloader',
        'val': 'val_dataloader',
        'test': 'test_dataloader',
    }
    if split not in factory_by_split:
        raise AttributeError(f'Invalid Split parameter ({split}) provided.')

    # A new loader is built on every call and only its first batch is used.
    build_loader = getattr(dataset, factory_by_split[split])
    loader = build_loader(batch_size=batch_size, resize=resize)
    images, label_ids = dataset.get_next(loader)

    class_count = len(dataset.get_classes())
    targets = F.one_hot(label_ids, num_classes=class_count).float()
    return images.to(device), targets.to(device)

# Re-create the dataset wrapper (a fresh instance replaces the earlier global).
dataset = CIFAR10DataSet()

# Smoke test: pull one batch from each split and print the tensor shapes.
x, y = get_batch(dataset, "train")
print(x.shape, y.shape)

x_val, y_val = get_batch(dataset, "val")
print(x_val.shape, y_val.shape)

x_test, y_test= get_batch(dataset, "test")
print(x_test.shape, y_test.shape)


In [None]:
class PatchEmbedding(nn.Module):
    """Split an image into square patches and embed each patch.

    A convolution whose kernel size and stride both equal ``patch_size``
    produces one ``embed_dim``-sized vector per non-overlapping patch.

    Args:
        img_size: Side length of the (square) input image.
        patch_size: Side length of each square patch.
        in_chans: Number of input channels.
        embed_dim: Size of each patch embedding.
    """

    def __init__(self, img_size, patch_size, in_chans=3, embed_dim=n_embd):
        super().__init__()
        self.img_size, self.patch_size = img_size, patch_size
        patches_per_side = img_size // patch_size
        self.n_patches = patches_per_side ** 2
        # Kernel == stride, so every patch is seen exactly once.
        self.patch_embd = nn.Conv2d(
            in_channels=in_chans,
            out_channels=embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
            device=device,
        )

    def forward(self, x):
        """Return patch embeddings with the patch axis before the channel axis.

        The convolution output has its two spatial axes flattened into one,
        after which axes 1 and 2 are swapped, giving
        (batch, patches, embed_dim).
        """
        return self.patch_embd(x).flatten(2).transpose(1, 2)

In [None]:

class Head(nn.Module):
    """One head of (unmasked) scaled dot-product self-attention.

    Args:
        head_size: Output size of the key, query and value projections.
        n_embd: Size of the last axis of the input.

    The dropout probability is read from the module-level ``dropout``.
    """

    def __init__(self, head_size, n_embd):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Attend over axis 1 of ``x`` (batch, tokens, n_embd).

        Returns:
            Tensor of shape (batch, tokens, head_size).
        """
        k = self.key(x)    # (B, T, head_size)
        q = self.query(x)  # (B, T, head_size)
        v = self.value(x)  # (B, T, head_size)

        # Scale by the size of the vectors that enter the dot product
        # (the head size), not by the width of the input embedding.
        scale = k.shape[-1] ** -0.5
        wei = (q @ k.transpose(-2, -1)) * scale  # (B, T, T): [query, key]

        # Normalise over the key axis so each query's weights sum to 1.
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        return wei @ v  # (B, T, T) @ (B, T, head_size) -> (B, T, head_size)

In [None]:
class MultiHeadAttention(nn.Module):
    """Run several attention heads side by side and merge their outputs.

    Args:
        num_heads: Number of ``Head`` modules to create.
        head_size: Output size of each head.
        n_embd: Input embedding size; also the in/out size of the projection.

    The dropout probability is read from the module-level ``dropout``.
    """

    def __init__(self, num_heads, head_size, n_embd):
        super().__init__()
        head_modules = [Head(head_size, n_embd) for _ in range(num_heads)]
        self.heads = nn.ModuleList(head_modules)
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply every head to ``x``, concatenate on the last axis, project."""
        per_head = [head(x) for head in self.heads]
        merged = torch.cat(per_head, dim=-1)
        # Linear mixing of the concatenated head outputs, then dropout.
        return self.dropout(self.proj(merged))

In [None]:
class FeedForward(nn.Module):
    """Position-wise MLP: expand to 4x width, ReLU, project back, dropout.

    Args:
        n_embd: Size of the last axis of the input and of the output.

    The dropout probability is read from the module-level ``dropout``.
    """

    def __init__(self, n_embd):
        super().__init__()
        hidden_dim = 4 * n_embd
        expand = nn.Linear(n_embd, hidden_dim)
        activation = nn.ReLU()
        project = nn.Linear(hidden_dim, n_embd)  # back to the input width
        regularise = nn.Dropout(dropout)
        self.nn = nn.Sequential(expand, activation, project, regularise)

    def forward(self, x):
        """Apply the MLP to ``x``."""
        return self.nn(x)

In [None]:
class Block(nn.Module):
    """Pre-norm transformer block: self-attention, then a feed-forward MLP.

    Both sub-layers are applied to a layer-normalised copy of their input
    and added back to it (residual connections).

    Args:
        n_embd: Embedding dimension.
        n_head: Number of attention heads; must divide ``n_embd``.
    """

    def __init__(self, n_embd, n_head):
        super().__init__()
        assert n_embd % n_head == 0, f'n_embd {n_embd}, n_head: {n_head} must be a divisor'
        # Each head gets an equal slice of the embedding width.
        self.sa = MultiHeadAttention(n_head, n_embd // n_head, n_embd)
        self.ffwd = FeedForward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        """Return ``x`` after the attention and feed-forward residual steps."""
        after_attention = x + self.sa(self.ln1(x))
        return after_attention + self.ffwd(self.ln2(after_attention))


In [None]:
class VisionTransformerModel(nn.Module):
    """Vision Transformer classifier built from the blocks defined above.

    Pipeline: patch embedding + learned position embedding, a prepended
    learnable CLS token, ``n_layers`` transformer blocks followed by a
    LayerNorm, and a linear classification head applied to the CLS token.

    Args:
        img_size: Side length of the (square) input images.
        patch_size: Side length of each square patch.
        in_chans: Number of input channels.
        embed_dim: Embedding dimension used throughout the model.
        n_classes: Number of output classes.
        n_layers: Number of transformer blocks.
        n_heads: Attention heads per block.
    """

    def __init__(self, img_size=img_size, patch_size=4, in_chans=3, embed_dim=n_embd, n_classes=10, n_layers=4, n_heads=6):
        super().__init__()
        # Every patch sequence begins with a CLS token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        self.patch_embedding_table = PatchEmbedding(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        # One learned position vector per patch. The CLS token is
        # concatenated after the position embedding is added, so it does not
        # receive one.
        self.position_embedding_table = nn.Embedding(self.patch_embedding_table.n_patches, embedding_dim=embed_dim)
        self.blocks = nn.Sequential(
            *[ Block(embed_dim, n_heads) for _ in range(n_layers)],
            nn.LayerNorm(embed_dim),
        )
        self.vm_head = nn.Linear(embed_dim, n_classes)
        # Kept as an attribute for callers that want probabilities; it is
        # deliberately NOT applied inside forward().
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, idx, targets=None):
        """Classify a batch of images.

        Args:
            idx: Image batch of shape (batch, channels, height, width).
            targets: Optional targets accepted by ``F.cross_entropy``
                (class indices, or per-class probabilities such as float
                one-hot vectors).

        Returns:
            ``(logits, loss)`` where ``logits`` are the raw, un-normalised
            class scores of shape (batch, n_classes) and ``loss`` is the
            cross-entropy against ``targets`` (``None`` without targets).
        """
        n_samples = idx.shape[0]

        patch_emb = self.patch_embedding_table(idx)  # (B, n_patches, embed_dim)
        n_patches = self.patch_embedding_table.n_patches
        # Build the position indices on the same device as the activations
        # instead of relying on a module-level global.
        positions = torch.arange(n_patches, device=patch_emb.device)
        x = patch_emb + self.position_embedding_table(positions)

        # Prepend the cls_token
        cls_token = self.cls_token.expand(n_samples, -1, -1)
        x = torch.cat((cls_token, x), dim=1)  # (B, n_patches + 1, embed_dim)

        x = self.blocks(x)

        # The CLS position summarises the sequence for classification.
        cls_token_final = x[:, 0]

        # Raw scores: F.cross_entropy applies log-softmax internally, so
        # passing already-softmaxed values would normalise twice and flatten
        # the training signal.
        logits = self.vm_head(cls_token_final)

        loss = None if targets is None else F.cross_entropy(logits, targets)
        return logits, loss
    


In [None]:
# Build the model with its default arguments.
model = VisionTransformerModel()
# nn.Module.to() returns the module itself, so ``m`` and ``model`` are two
# names for the same object.
m = model.to(device)

In [None]:
@torch.no_grad()
def estimate_loss():
    """Return the mean loss over ``eval_iters`` batches for 'train' and 'val'.

    The global model is put in eval mode for the measurement and switched
    back to train mode before returning. Gradients are disabled by the
    decorator.

    Returns:
        Dict mapping ``'train'`` and ``'val'`` to a scalar tensor.
    """
    model.eval()
    mean_loss = {}
    for split in ('train', 'val'):
        per_batch = []
        for _ in range(eval_iters):
            X, Y = get_batch(dataset, split)
            _, batch_loss = m(X.to(device), Y.to(device))
            per_batch.append(batch_loss.item())
        mean_loss[split] = torch.tensor(per_batch).mean()
    model.train()
    return mean_loss

In [None]:
# Print the total number of model parameters, expressed in millions.
print(sum(p.numel() for p in model.parameters())/1e6, 'M parameters')


In [None]:


# Create a PyTorch optimizer over all model parameters.
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
from tqdm import tqdm

# Training loop: one freshly sampled batch per iteration, with a progress bar.
for it in tqdm(range(max_iters)):
    
    # Every eval_interval iterations, and on the final iteration, report the
    # averaged train and val losses.
    if it % eval_interval == 0 or it == max_iters - 1:
        losses = estimate_loss()
        print(f"\tstep {it}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # Sample a batch of data (get_batch already returns tensors on `device`).
    xb, yb = get_batch(dataset, 'train')

    # Forward pass: the model returns (logits, loss) when targets are given.
    logits, loss = m(xb.to(device), yb.to(device))
    # Clear old gradients, back-propagate, then apply the parameter update.
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

In [None]:
# Save the whole module object (not just its state_dict) to disk.
# NOTE(review): this pickles the module itself, so loading the file needs the
# same class definitions to be importable; consider saving m.state_dict().
torch.save(m, "vit_base_attempt_1.pth")

In [None]:
# Inspect the top-k predictions for one batch of test images.
n_samples = 100
k = 3
labels = dataset.get_classes()

xt, yt = get_batch(dataset, 'test', n_samples)
xt, yt = xt.to(device), yt.to(device)

softmax = nn.Softmax(dim=-1)

# Inference must run in eval mode (dropout off) and needs no gradients.
# The previous mode is restored afterwards so later training is unaffected.
was_training = model.training
model.eval()
with torch.no_grad():
    logits, _ = model(xt)  # the model returns (scores, loss); loss is None here
model.train(was_training)
print(logits.shape)

for target_onehot, sample_scores in zip(yt, logits):
    target = labels[torch.argmax(target_onehot).item()]
    probs = softmax(sample_scores)
    top_probs, top_ics = probs.topk(k)

    for i, (ix_, prob_) in enumerate(zip(top_ics, top_probs)):
        ix = ix_.item()
        prob = prob_.item()
        cls = labels[ix].strip()
        print(f"{i}: {cls:<45} -- {prob*100.0:2.1f}% -- {target}")