In [None]:
# Show the GPU, driver and CUDA version visible to this runtime.
!nvidia-smi

Thu Mar 16 12:44:18 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   61C    P0    27W /  70W |      0MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
!pip install einops

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting einops
  Downloading einops-0.6.0-py3-none-any.whl (41 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m41.6/41.6 KB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.6.0


In [None]:
import numpy as np
from collections import defaultdict
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
import torchvision
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from tqdm import tqdm
from tqdm.notebook import trange, tqdm

In [None]:
# Relative directory for dataset files.
DATA_DIR = "./data"

In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device:", DEVICE)

device: cuda


In [None]:
# helpers
def pair(t):
    """Return `t` unchanged if it is a tuple, otherwise the 2-tuple `(t, t)`."""
    if isinstance(t, tuple):
        return t
    return (t, t)

In [None]:
class PreNorm(nn.Module):
    """Apply LayerNorm over the last `dim` features, then call the wrapped `fn`.

    Args:
        dim: size of the feature axis that is normalized.
        fn: callable invoked on the normalized tensor; any keyword arguments
            given to `forward` are passed through to it.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        normed = self.norm(x)
        return self.fn(normed, **kwargs)

In [None]:
class FeedForward(nn.Module):
    """Two-layer MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim: input and output feature size.
        hidden_dim: feature size between the two linear layers.
        dropout: probability used by both dropout layers.

    `forward` prints the output shape on every call (shape-tracing aid).
    """

    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        stages = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        projected = self.net(x)
        print(f'    FeedForward out.shape - {projected.shape}')
        return projected

In [None]:
class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention with shape tracing.

    Args:
        dim: feature size of the input (and of the output).
        heads: number of attention heads.
        dim_head: feature size of each head.
        dropout: dropout probability applied after the output projection.

    `forward` takes `x` laid out as (b, n, dim), per the rearrange patterns
    below, and prints the shape of each intermediate tensor.
    """

    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = heads * dim_head
        # A single head whose width already equals `dim` needs no output projection.
        skip_projection = heads == 1 and dim_head == dim

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        if skip_projection:
            self.to_out = nn.Identity()
        else:
            self.to_out = nn.Sequential(
                nn.Linear(inner_dim, dim),
                nn.Dropout(dropout),
            )

    def forward(self, x):
        # One fused projection, split into equal query / key / value chunks.
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        print(f'    Attention qkv.shape: {qkv[0].shape} , {qkv[1].shape} , {qkv[2].shape} ')
        # Split the feature axis into heads: (b, n, h*d) -> (b, h, n, d).
        q, k, v = (rearrange(t, 'b n (h d) -> b h n d', h = self.heads) for t in qkv)
        print(f'    Attention q, k, v after rearranging qkv: {q.shape} , {k.shape} , {v.shape}')

        # Scaled similarity between every query and every key.
        dots = (q @ k.transpose(-1, -2)) * self.scale
        print(f'    Attention dots.shape after matmul q and k.transpose: {dots.shape}')

        weights = self.attend(dots)

        out = weights @ v
        print(f'    Attention out.shape after matmul attn and v: {out.shape}')
        # Merge the heads back into one feature axis: (b, h, n, d) -> (b, n, h*d).
        out = rearrange(out, 'b h n d -> b n (h d)')
        print(f'    Attention out.shape after rearrange: {out.shape}')
        return self.to_out(out)

In [None]:
class Transformer(nn.Module):
    """Stack of `depth` layers, each a pre-norm attention block followed by a
    pre-norm feed-forward block, both wrapped in residual connections.

    Args:
        dim: feature size of the tokens (kept unchanged by every layer).
        depth: number of (attention, feed-forward) layers.
        heads: number of attention heads per layer.
        dim_head: feature size of each attention head.
        mlp_dim: hidden size of the feed-forward block.
        dropout: dropout probability passed to both sub-blocks.

    `forward` prints the layer number and the token shape around each
    residual step.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))

    def forward(self, x):
        # Number the layers from 1 so each trace header names the layer that is running.
        for cnt, (attn, ff) in enumerate(self.layers, start=1):
            print(f' **Transformer Layer - {cnt}**')
            print(f'    x.shape before attn(x) + x: {x.shape}')
            x = attn(x) + x
            print(f'    x.shape after attn(x) + x : {x.shape}')
            x = ff(x) + x
            print(f'    x.shape after ff(x) + x   : {x.shape}')
        return x

In [None]:
class ViT(nn.Module):
    """Vision Transformer classifier that prints a shape trace as it runs.

    The image is cut into non-overlapping patches, each patch is linearly
    embedded, a learned class token is prepended and learned positional
    embeddings are added. The token sequence goes through `Transformer`, is
    pooled, and is classified by a LayerNorm + Linear head.

    Keyword-only args:
        image_size: int or (height, width) of the input images.
        patch_size: int or (height, width) of one patch; must divide the image.
        num_classes: number of output logits.
        dim: token embedding size.
        depth, heads, mlp_dim, dim_head, dropout: forwarded to `Transformer`.
        pool: 'cls' to classify from the class token, 'mean' to average all tokens.
        channels: number of input image channels.
        emb_dropout: dropout probability applied to the embedded tokens.
    """

    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        print(f' vit-init image_size, image_height, image_width : {image_size}, {image_height}, {image_width}')
        patch_height, patch_width = pair(patch_size)
        print(f' vit-init patch_size, patch_height, patch_width : {patch_size}, {patch_height}, {patch_width}')

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        # Patches form a grid; every patch is flattened into one vector.
        grid_rows = image_height // patch_height
        grid_cols = image_width // patch_width
        num_patches = grid_rows * grid_cols
        print(f' vit-init num_patches : {num_patches}')
        patch_dim = channels * patch_height * patch_width
        print(f' vit-init patch_dim, channels : {patch_dim}, {channels}')
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
        print(f' vit-init num_classes, dim, depth, heads, mlp_dim, dim_head : {num_classes}, {dim}, {depth}, {heads}, {mlp_dim}, {dim_head}')

        # (b, c, H, W) -> (b, num_patches, patch_dim) -> (b, num_patches, dim)
        split_into_patches = Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width)
        embed_patch = nn.Linear(patch_dim, dim)
        self.to_patch_embedding = nn.Sequential(split_into_patches, embed_patch)

        # One positional vector per patch plus one for the class token.
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        head_norm = nn.LayerNorm(dim)
        head_linear = nn.Linear(dim, num_classes)
        self.mlp_head = nn.Sequential(head_norm, head_linear)

    def forward(self, img):
        print(f' vit-fwd img.size : {img.shape}')
        x = self.to_patch_embedding(img)
        print(f' vit-fwd self.to_patch_embedding.size : {x.shape}')
        b, n, _ = x.shape
        print(f' vit-fwd b, n, _ : {x.shape}')

        # Copy the single learned class token once per batch element.
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        print(f' vit-fwd self.cls_token.shape : {self.cls_token.shape} , cls_tokens.shape : {cls_tokens.shape}')
        x = torch.cat((cls_tokens, x), dim=1)
        print(f' vit-fwd After concatenating cls_tokens with x ->  x.shape : {x.shape}')
        positions = self.pos_embedding[:, :(n + 1)]
        x += positions
        print(f' vit-fwd self.pos_embedding.shape : {self.pos_embedding.shape},n : {n},  self.pos_embedding[:, :(n + 1)] : {self.pos_embedding[:, :(n + 1)].shape}')
        print(f' vit-fwd After adding self.pos_embedding with x ->  x.shape : {x.shape}')
        x = self.dropout(x)

        x = self.transformer(x)
        print(f' vit-fwd x.shape after transformer: {x.shape}')

        # Reduce the token axis: average all tokens, or keep only the class token (index 0).
        if self.pool == 'mean':
            x = x.mean(dim = 1)
        else:
            x = x[:, 0]
        print(f' vit-fwd x.shape after mean: {x.shape}, self.pool : {self.pool}')

        x = self.to_latent(x)
        print(f' vit-fwd x.shape after self.to_latent(x): {x.shape}')
        out = self.mlp_head(x)
        print(f' vit-fwd out.shape after self.mlp_head(x): {out.shape}')
        return out

In [None]:
# 32x32 inputs cut into 4x4 patches (64 patch tokens), 10 output classes.
model = ViT(
    image_size=32,
    patch_size=4,
    num_classes=10,
    dim=512,
    depth=6,
    heads=8,
    mlp_dim=512,
    dropout=0.1,
    emb_dropout=0.1,
)

 vit-init image_size, image_height, image_width : 32, 32, 32
 vit-init patch_size, patch_height, patch_width : 4, 4, 4
 vit-init num_patches : 64
 vit-init patch_dim, channels : 48, 3
 vit-init num_classes, dim, depth, heads, mlp_dim, dim_head : 10, 512, 6, 8, 512, 64


In [None]:
# Sanity check of the einops pattern ViT.forward uses for the class token:
# the leading singleton axis `()` is expanded to `b` copies,
# so (1, 1, 512) becomes (32, 1, 512).
s1 = torch.randn((1, 1, 512))
s2 = repeat(s1, '() n d -> b n d', b=32)
s2.shape

torch.Size([32, 1, 512])

In [None]:
# Move the model to DEVICE; as the last expression of the cell, the returned module's repr is displayed.
model.to(DEVICE)

ViT(
  (to_patch_embedding): Sequential(
    (0): Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=4, p2=4)
    (1): Linear(in_features=48, out_features=512, bias=True)
  )
  (dropout): Dropout(p=0.1, inplace=False)
  (transformer): Transformer(
    (layers): ModuleList(
      (0): ModuleList(
        (0): PreNorm(
          (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (fn): Attention(
            (attend): Softmax(dim=-1)
            (to_qkv): Linear(in_features=512, out_features=1536, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=512, out_features=512, bias=True)
              (1): Dropout(p=0.1, inplace=False)
            )
          )
        )
        (1): PreNorm(
          (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (fn): FeedForward(
            (net): Sequential(
              (0): Linear(in_features=512, out_features=512, bias=True)
              (1): GELU(approximate='none')
  

In [None]:
# Total number of elements across all parameters, printed with thousands separators.
print(f"Number of parameters: {sum(p.numel() for p in model.parameters()):,}")

Number of parameters: 9,523,722


In [None]:
!pip install torchsummary
from torchsummary import summary
summary(model, input_size=(3, 32, 32))

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
 vit-fwd img.size : torch.Size([2, 3, 32, 32])
 vit-fwd self.to_patch_embedding.size : torch.Size([2, 64, 512])
 vit-fwd b, n, _ : torch.Size([2, 64, 512])
 vit-fwd self.cls_token.shape : torch.Size([1, 1, 512]) , cls_tokens.shape : torch.Size([2, 1, 512])
 vit-fwd After concatenating cls_tokens with x ->  x.shape : torch.Size([2, 65, 512])
 vit-fwd self.pos_embedding.shape : torch.Size([1, 65, 512]),n : 64,  self.pos_embedding[:, :(n + 1)] : torch.Size([1, 65, 512])
 vit-fwd After adding self.pos_embedding with x ->  x.shape : torch.Size([2, 65, 512])
 **Transformer Layer - 1**
    x.shape before attn(x) + x: torch.Size([2, 65, 512])
    Attention qkv.shape: torch.Size([2, 65, 512]) , torch.Size([2, 65, 512]) , torch.Size([2, 65, 512]) 
    Attention q, k, v after rearranging qkv: torch.Size([2, 8, 65, 64]) , torch.Size([2, 8, 65, 64]) , torch.Size([2, 8, 65, 64])
    Attention dots.shap

In [None]:
# --- Experiment configuration -------------------------------------------------
IMAGE_SIZE = 32      # side length of the square CIFAR-10 images / training crops

NUM_CLASSES = 10
NUM_WORKERS = 4      # worker processes used by each DataLoader below
BATCH_SIZE = 128     # training batch size
EPOCHS = 10

# Per-channel statistics used to normalize the images.
cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2471, 0.2435, 0.2616)

# Training augmentation. The first four transforms run on the loaded image
# before ToTensor; Normalize and RandomErasing run on the resulting tensor.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(IMAGE_SIZE, scale=(0.75, 1.0), ratio=(1.0, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandAugment(num_ops=1, magnitude=8),
    transforms.ColorJitter(0.1, 0.1, 0.1),
    transforms.ToTensor(),
    transforms.Normalize(cifar10_mean, cifar10_std),
    transforms.RandomErasing(p=0.25)
])

# Evaluation uses no augmentation, only tensor conversion and normalization.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(cifar10_mean, cifar10_std)
])

# Both splits are stored under DATA_DIR and downloaded on first use.
trainset = torchvision.datasets.CIFAR10(root=DATA_DIR, train=True,
                                        download=True, transform=train_transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
                                          shuffle=True, num_workers=NUM_WORKERS)

testset = torchvision.datasets.CIFAR10(root=DATA_DIR, train=False,
                                       download=True, transform=test_transform)
# The test loader keeps a fixed order and a batch size of 64.
testloader = torch.utils.data.DataLoader(testset, batch_size=64,
                                         shuffle=False, num_workers=NUM_WORKERS)


Files already downloaded and verified
Files already downloaded and verified


In [None]:
import time

# Clip the global gradient norm to 1.0 before every optimizer step.
clip_norm = True
# Mixed precision is only used on CUDA; elsewhere autocast and the scaler are disabled.
use_amp = DEVICE.type == "cuda"

# Wrap only once so that re-running this cell does not nest DataParallel wrappers.
if not isinstance(model, nn.DataParallel):
    model = nn.DataParallel(model, device_ids=[0])
model = model.to(DEVICE)

opt = optim.Adam(model.parameters(), lr=1e-2)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, EPOCHS)
criterion = nn.CrossEntropyLoss()
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

# NOTE: every forward pass of the model prints its shape trace, so this loop
# produces a large amount of output.
for epoch in range(EPOCHS):
    start = time.time()
    train_loss, train_acc, n = 0, 0, 0

    model.train()
    # Iterate over the progress bar itself so that it advances with each batch.
    pbar = tqdm(trainloader)
    for i, (X, y) in enumerate(pbar):
        X, y = X.to(DEVICE), y.to(DEVICE)

        with torch.cuda.amp.autocast(enabled=use_amp):
            output = model(X)
            loss = criterion(output, y)

        scaler.scale(loss).backward()
        if clip_norm:
            # Gradients must be unscaled before their norm is measured and clipped.
            scaler.unscale_(opt)
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(opt)
        scaler.update()
        opt.zero_grad()

        # Accumulate sample-weighted sums; dividing by n gives running means.
        train_loss += loss.item() * y.size(0)
        train_acc += (output.max(1)[1] == y).sum().item()
        n += y.size(0)
        pbar.set_description(desc=f'Loss={train_loss / n:0.4f} Batch={i} Train Acc={train_acc / n:0.4f}')

    model.eval()
    test_acc, m = 0, 0
    with torch.no_grad():
        for X, y in testloader:
            X, y = X.to(DEVICE), y.to(DEVICE)
            with torch.cuda.amp.autocast(enabled=use_amp):
                output = model(X)
            test_acc += (output.max(1)[1] == y).sum().item()
            m += y.size(0)

    # Read the learning rate used during this epoch before the scheduler changes it.
    lr = opt.param_groups[0]['lr']
    # Advance the cosine schedule by exactly one epoch.
    scheduler.step()

    print(f' Epoch: {epoch} | Train Acc: {train_acc/n:.4f}, Test Acc: {test_acc/m:.4f}, Time: {time.time() - start:.1f}, lr: {lr:.6f}')
