In [None]:
# Show which GPU (if any) this runtime has; later cells move tensors with .cuda().
!nvidia-smi

Fri Mar 17 15:50:39 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   56C    P0    27W /  70W |      0MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
import numpy as np
from collections import defaultdict
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
import torchvision
from tqdm import tqdm
from tqdm.notebook import trange, tqdm

In [None]:
DATA_DIR = "./data"
# Prefer the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device:", DEVICE)

device: cuda


In [None]:
# helpers
def pair(t):
    """Normalize a size argument to a tuple.

    A tuple is returned unchanged; any other value `t` becomes `(t, t)`.
    """
    if isinstance(t, tuple):
        return t
    return (t, t)

In [None]:
class PreNorm(nn.Module):
    """Pre-normalization wrapper: LayerNorm the input, then apply `fn`.

    The LayerNorm normalizes jointly over the last three dimensions, which must be
    (dim, numb_patch, numb_patch), i.e. a square numb_patch x numb_patch map.

    Args:
        dim: number of channels of the input feature map.
        numb_patch: spatial side length of the input feature map.
        fn: module applied to the normalized tensor; extra keyword arguments given
            to `forward` are forwarded to it.
    """
    def __init__(self, dim, numb_patch, fn):
        super().__init__()
        normalized_shape = [dim, numb_patch, numb_patch]
        self.norm = nn.LayerNorm(normalized_shape)
        self.fn = fn

    def forward(self, x, **kwargs):
        normed = self.norm(x)
        return self.fn(normed, **kwargs)

In [None]:
class FeedForward(nn.Module):
    """MLP over channels built from bias-free 1x1 convolutions: dim -> hidden_dim -> dim.

    Layout of `self.net`: conv(dim->hidden_dim), GELU, Dropout, conv(hidden_dim->dim), Dropout.
    Because the kernels are 1x1, the channel mixing is applied independently at every
    spatial position, and the output has the same shape as the input.

    Args:
        dim: number of input/output channels.
        hidden_dim: number of channels in the hidden layer.
        dropout: dropout probability used after the activation and after the output conv.
    """
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        expand = nn.Conv2d(dim, hidden_dim, kernel_size=(1, 1), padding=0, bias=False)
        project = nn.Conv2d(hidden_dim, dim, kernel_size=(1, 1), padding=0, bias=False)
        self.net = nn.Sequential(expand, nn.GELU(), nn.Dropout(dropout), project, nn.Dropout(dropout))

    def forward(self, x):
        return self.net(x)

In [None]:
class Attention(nn.Module):
    """Multi-head self-attention between the channels of a (b, c, h, w) feature map.

    Each of the `c` channels is a token whose feature vector is its flattened h*w
    spatial map; that feature vector is split evenly across `heads`, so head j sees
    the contiguous flattened positions [j*d, (j+1)*d) with d = h*w // heads.
    Queries, keys and values come from 1x1 convolutions, and the merged result is
    passed through `to_out`. Output shape equals input shape.

    Args:
        dim: number of input/output channels (tokens).
        heads: number of attention heads; h*w of the input must be divisible by it.
        dim_head: kept for interface compatibility; it only decides whether the output
            projection is used (skipped when heads == 1 and dim_head == dim). The
            softmax temperature is derived from the real per-head feature size d.
        dropout: dropout applied after the output projection.

    Raises:
        ValueError: in `forward`, if h*w is not divisible by `heads`.
    """
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.head_channels = dim  # number of tokens (channels) attended over in each head
        in_channels = dim
        out_channels = dim
        self.attend = nn.Softmax(dim = -1)
        self.to_keys = nn.Conv2d(in_channels, out_channels, 1)
        self.to_queries = nn.Conv2d(in_channels, out_channels, 1)
        self.to_values = nn.Conv2d(in_channels, out_channels, 1)

        self.to_out = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(1, 1), padding=0, bias=False),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def _split_heads(self, t):
        """(b, c, h, w) -> (b, heads, c, h*w // heads); exact inverse of the merge in `forward`."""
        b, c = t.shape[:2]
        return t.reshape(b, c, self.heads, -1).transpose(1, 2)

    def forward(self, x):
        b, c, h, w = x.shape
        if (h * w) % self.heads != 0:
            raise ValueError(f"h*w ({h * w}) must be divisible by heads ({self.heads})")

        q = self._split_heads(self.to_queries(x))
        k = self._split_heads(self.to_keys(x))
        v = self._split_heads(self.to_values(x))

        # Temperature 1/sqrt(d), where d is the length of the dimension being contracted.
        dots = torch.matmul(q, k.transpose(-1, -2)) * q.shape[-1] ** -0.5   # (b, heads, c, c)
        attn = self.attend(dots)

        out = torch.matmul(attn, v)                                          # (b, heads, c, d)
        # Merge heads back so every token's output lands on the positions it was read from.
        out = out.transpose(1, 2).reshape(b, c, h, w)
        return self.to_out(out)

In [None]:
class Transformer(nn.Module):
    """A stack of `depth` pre-norm residual blocks.

    Each block applies `x = attention(x) + x` and then `x = feed_forward(x) + x`,
    where both sub-layers are wrapped in `PreNorm` (LayerNorm over
    (dim, numb_patch, numb_patch)) before the wrapped module runs.
    """
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, numb_patch, dropout = 0.):
        super().__init__()

        def build_block():
            # Attention is built first, then FeedForward; this fixes the order random initializers run in.
            attention = PreNorm(dim, numb_patch, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout))
            feed_forward = PreNorm(dim, numb_patch, FeedForward(dim, mlp_dim, dropout=dropout))
            return nn.ModuleList([attention, feed_forward])

        self.layers = nn.ModuleList([build_block() for _ in range(depth)])

    def forward(self, x):
        for attention_block, feed_forward_block in self.layers:
            x = attention_block(x) + x
            x = feed_forward_block(x) + x
        return x

In [None]:
class ViT(nn.Module):
    """Convolutional ViT variant whose class token is an extra channel.

    Pipeline:
      1. A strided conv turns the image into a (patch_dim, n, n) map, where
         patch_dim = channels * patch_height * patch_width and n = image_size // patch_size.
      2. A learnable class-token map of shape (1, n, n) is prepended along the channel
         axis, giving dim = patch_dim + 1 channels, and a learnable positional
         embedding of the same shape is added.
      3. The Transformer runs on the (dim, n, n) map.
      4. Each channel is flattened to n*n values; channel 0 (the class token) is
         selected for pool='cls', or the mean over channels for pool='mean'.
      5. A LayerNorm + 1x1 conv head maps the n*n values to `num_classes` logits.

    Raises:
        AssertionError: if image and patch sizes are inconsistent, the patch grid is not
            square, `numb_patch` != image_size // patch_size, `dim` != patch_dim + 1,
            or `pool` is not 'cls'/'mean'.
    """
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, numb_patch, pool = 'cls',
                 channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        self.num_patch = (image_height // patch_height)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
        # The embeddings and the PreNorm LayerNorms all assume an n x n patch grid with
        # dim = patch_dim + 1 channels; fail early with a clear message otherwise.
        assert image_height // patch_height == image_width // patch_width, 'Only square patch grids are supported.'
        assert numb_patch == self.num_patch, f'numb_patch must equal image_size // patch_size ({self.num_patch}).'
        assert dim == patch_dim + 1, f'dim must equal channels * patch_height * patch_width + 1 ({patch_dim + 1}).'

        self.to_patch_embedding = nn.Conv2d(in_channels=channels,
                                            out_channels=patch_dim,
                                            kernel_size=patch_size,
                                            stride=patch_size,
                                            padding=0)

        self.pos_embedding = nn.Parameter(torch.randn(1, patch_dim + 1, self.num_patch, self.num_patch))
        # Registered (hence learned and moved with the module); initialized to ones like the old per-call constant.
        self.cls_token = nn.Parameter(torch.ones(1, 1, self.num_patch, self.num_patch))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, numb_patch, dropout)

        self.pool = pool
        self.num_classes = num_classes
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm([num_patches, 1, 1]),
            nn.Conv2d(in_channels=num_patches, out_channels=num_classes, kernel_size=(1, 1), padding=0, bias=False)
        )
        self.flatten = nn.Flatten(start_dim=2, end_dim=3)

    def forward(self, img):
        x = self.to_patch_embedding(img)                      # (b, patch_dim, n, n)
        b = x.shape[0]

        cls_tokens = self.cls_token.expand(b, -1, -1, -1)     # (b, 1, n, n), same device as the module
        x = torch.cat((cls_tokens, x), dim=1)                 # (b, dim, n, n)
        x = x + self.pos_embedding
        x = self.dropout(x)

        x = self.transformer(x)
        x = self.flatten(x)                                   # (b, dim, n*n)
        x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]
        x = self.to_latent(x)

        out = self.mlp_head(x[:, :, None, None])              # (b, num_classes, 1, 1)
        return out.view(b, self.num_classes)

In [None]:
model = ViT(
    image_size=32,
    patch_size=4,
    num_classes=10,
    dim=49,          # 3 * 4 * 4 patch channels + 1 class-token channel
    depth=6,
    heads=8,
    mlp_dim=147,
    numb_patch=8,    # 32 // 4 patches per side
    dropout=0.1,
    emb_dropout=0.1,
)

In [None]:
# Move the model's parameters to DEVICE; as the cell's last expression this also displays the module tree.
model.to(DEVICE)

ViT(
  (to_patch_embedding): Conv2d(3, 48, kernel_size=(4, 4), stride=(4, 4))
  (dropout): Dropout(p=0.1, inplace=False)
  (transformer): Transformer(
    (layers): ModuleList(
      (0): ModuleList(
        (0): PreNorm(
          (norm): LayerNorm((49, 8, 8), eps=1e-05, elementwise_affine=True)
          (fn): Attention(
            (attend): Softmax(dim=-1)
            (to_keys): Conv2d(49, 49, kernel_size=(1, 1), stride=(1, 1))
            (to_queries): Conv2d(49, 49, kernel_size=(1, 1), stride=(1, 1))
            (to_values): Conv2d(49, 49, kernel_size=(1, 1), stride=(1, 1))
            (to_out): Sequential(
              (0): Conv2d(49, 49, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): Dropout(p=0.1, inplace=False)
            )
          )
        )
        (1): PreNorm(
          (norm): LayerNorm((49, 8, 8), eps=1e-05, elementwise_affine=True)
          (fn): FeedForward(
            (net): Sequential(
              (0): Conv2d(49, 147, kernel_size=(1, 1), 

In [None]:
# Total number of scalar entries across all of the model's parameter tensors.
print("Number of parameters: {:,}".format(sum(map(torch.numel, model.parameters()))))

Number of parameters: 226,462


In [None]:
!pip install torchsummary
from torchsummary import summary
summary(model, input_size=(3, 32, 32))

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1             [-1, 48, 8, 8]           2,352
           Dropout-2             [-1, 49, 8, 8]               0
         LayerNorm-3             [-1, 49, 8, 8]           6,272
            Conv2d-4             [-1, 49, 8, 8]           2,450
            Conv2d-5             [-1, 49, 8, 8]           2,450
            Conv2d-6             [-1, 49, 8, 8]           2,450
           Softmax-7            [-1, 8, 49, 49]               0
            Conv2d-8             [-1, 49, 8, 8]           2,401
           Dropout-9             [-1, 49, 8, 8]               0
        Attention-10             [-1, 49, 8, 8]               0
          PreNorm-11             [-1, 49, 8, 8]               0
        LayerNorm-12             [-1, 49, 8, 8]           6,272
    

In [None]:
IMAGE_SIZE = 32
NUM_CLASSES = 10
NUM_WORKERS = 4        # DataLoader worker processes (the loaders previously hard-coded 4)
BATCH_SIZE = 128       # training batch size
TEST_BATCH_SIZE = 64   # evaluation batch size
cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2471, 0.2435, 0.2616)

# Training-time augmentation; RandomErasing runs after ToTensor because it operates on tensors.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(IMAGE_SIZE, scale=(0.75, 1.0), ratio=(1.0, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandAugment(num_ops=1, magnitude=8),
    transforms.ColorJitter(0.1, 0.1, 0.1),
    transforms.ToTensor(),
    transforms.Normalize(cifar10_mean, cifar10_std),
    transforms.RandomErasing(p=0.25)
])

# Evaluation: normalization only, no augmentation.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(cifar10_mean, cifar10_std)
])

# Page-locked host memory speeds up the .cuda() copies in the training loop.
PIN_MEMORY = torch.cuda.is_available()

trainset = torchvision.datasets.CIFAR10(root=DATA_DIR, train=True,
                                        download=True, transform=train_transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
                                          shuffle=True, num_workers=NUM_WORKERS,
                                          pin_memory=PIN_MEMORY)

testset = torchvision.datasets.CIFAR10(root=DATA_DIR, train=False,
                                       download=True, transform=test_transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=TEST_BATCH_SIZE,
                                         shuffle=False, num_workers=NUM_WORKERS,
                                         pin_memory=PIN_MEMORY)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./data/cifar-10-python.tar.gz to ./data




Files already downloaded and verified


In [None]:
import time

EPOCHS = 2        # must be defined before the scheduler, which uses it as T_max
clip_norm = True  # clip the global gradient norm to 1.0 before each optimizer step

# Wrap only once so re-running this cell does not nest DataParallel wrappers.
if not isinstance(model, nn.DataParallel):
    model = nn.DataParallel(model, device_ids=[0]).cuda()
opt = optim.Adam(model.parameters(), lr=1e-2)
# One scheduler.step() per epoch with T_max=EPOCHS anneals the LR over the whole run.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, EPOCHS)
criterion = nn.CrossEntropyLoss()
scaler = torch.cuda.amp.GradScaler()

for epoch in range(EPOCHS):
    start = time.time()
    train_loss, train_acc, n = 0.0, 0, 0
    model.train()
    pbar = tqdm(trainloader)
    for i, (X, y) in enumerate(pbar):  # iterate the bar itself so it advances
        X, y = X.cuda(), y.cuda()

        with torch.cuda.amp.autocast():
            output = model(X)
            loss = criterion(output, y)

        scaler.scale(loss).backward()
        if clip_norm:
            # Gradients must be unscaled before clipping so the threshold applies to true magnitudes.
            scaler.unscale_(opt)
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(opt)
        scaler.update()
        opt.zero_grad()

        train_loss += loss.item() * y.size(0)
        train_acc += (output.argmax(dim=1) == y).sum().item()
        n += y.size(0)
        pbar.set_description(desc=f'Loss={train_loss / n:0.4f} Batch={i} Train Acc={train_acc/n :0.4f}')

    model.eval()
    test_acc, m = 0, 0
    with torch.no_grad():
        for X, y in testloader:
            X, y = X.cuda(), y.cuda()
            with torch.cuda.amp.autocast():
                output = model(X)
            test_acc += (output.argmax(dim=1) == y).sum().item()
            m += y.size(0)

    # The epoch argument to step() is deprecated; the scheduler tracks the epoch itself.
    scheduler.step()

    print(f' Epoch: {epoch} | Train Acc: {train_acc/n:.4f}, Test Acc: {test_acc/m:.4f}, Time: {time.time() - start:.1f}')


  0%|          | 0/391 [00:00<?, ?it/s]

 Epoch: 0 | Train Acc: 0.3169, Test Acc: 0.4276, Time: 64.6




  0%|          | 0/391 [00:00<?, ?it/s]

 Epoch: 1 | Train Acc: 0.3885, Test Acc: 0.4570, Time: 66.5
