In [1]:
!pip install einops



## Testing Code Blocks

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import numpy as np

import torchvision
import torchvision.transforms as transforms

import os
import argparse
import pandas as pd
import csv
import time

In [4]:
size = 32
bs = 512

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.Resize(size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.Resize(size),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=bs, shuffle=True, num_workers=8)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=8)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


def pair(t):
    return t if isinstance(t, tuple) else (t, t)

Files already downloaded and verified




Files already downloaded and verified


In [5]:
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# Patch Embedding Old VIT

for batch_idx, (inputs, targets) in enumerate(trainloader):
    break

print(f'input shape is {inputs.shape}')

dim = 512

channels = 3

image_height, image_width = pair(32)
patch_height, patch_width = pair(4)

num_patches = (image_height // patch_height) * (image_width // patch_width)
patch_dim = channels * patch_height * patch_width

print(f'number of patches : {num_patches}')
print(f'patch dim = {patch_dim}')

to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.Linear(patch_dim, dim),
        )


x_emb_old = to_patch_embedding(inputs)
print(f'patch embedded shape : {x_emb_old.shape}')

input shape is torch.Size([512, 3, 32, 32])
number of patches : 64
patch dim = 48
patch embedded shape : torch.Size([512, 64, 512])


In [6]:
# new patch embedding  part 1 X_ij

channels = 3

image_height, image_width = pair(32)
patch_height, patch_width = pair(4)

num_patches = (image_height // patch_height) * (image_width // patch_width)
patch_dim = channels * patch_height * patch_width

print(f'number of patches : {num_patches} = {np.sqrt(num_patches)}x{np.sqrt(num_patches)}')
print(f'patch dim = {patch_dim}  = {channels} x {patch_height} x {patch_width}')


Rearrange_L1 = Rearrange('b c (h q1) (w q2) -> b q1 q2 h w c', h = patch_height, w = patch_width)
x = Rearrange_L1(inputs)
print(f'First Rearrange X_ij : {x.shape}')

number of patches : 64 = 8.0x8.0
patch dim = 48  = 3 x 4 x 4
First Rearrange X_ij : torch.Size([512, 8, 8, 4, 4, 3])


In [7]:
# https://pytorch.org/docs/stable/generated/torch.einsum.html

w = nn.Parameter(torch.randn(4, 32), requires_grad=True)

temp = torch.einsum("abcdef,dg,eh->abcghf", (x, w, w))
print(f'Shape after n-mode product is : {temp.shape}')

Shape after n-mode product is : torch.Size([512, 8, 8, 32, 32, 3])


In [8]:
# new patch embedding  part 2 X_ij - bar

# from einops import rearrange, repeat
# from einops.layers.torch import Rearrange

# for batch_idx, (inputs, targets) in enumerate(trainloader):
#     break

print(f'input shape is {inputs.shape}')

dim = 32

channels = 3

image_height, image_width = pair(32)
patch_height, patch_width = pair(4)

num_patches = (image_height // patch_height) * (image_width // patch_width)
patch_dim = channels * patch_height * patch_width

print(f'number of patches : {num_patches} = {np.sqrt(num_patches)}x{np.sqrt(num_patches)}')
print(f'patch dim = {patch_dim}  = {channels} x {patch_height} x {patch_width}')


Rearrange_L1 = Rearrange('b c (h q1) (w q2) -> b q1 q2 h w c', h = patch_height, w = patch_width)
x = Rearrange_L1(inputs)
print(f'First Rearrange X_ij : {x.shape}')

W1 =  nn.Parameter(torch.randn(patch_height, dim), requires_grad=True)
W2 =  nn.Parameter(torch.randn(patch_height, dim), requires_grad=True)
W3 =  nn.Parameter(torch.randn(channels, 1), requires_grad=True)

xbar_ = torch.einsum("abcdef,dg,eh, fi->abcghi", (x, W1, W2, W3))
print(f'Shape after n-mode product is : {xbar_.shape}')

xbar = xbar_.squeeze()
print(f'xbar squeezed size is : {xbar.shape}')



input shape is torch.Size([512, 3, 32, 32])
number of patches : 64 = 8.0x8.0
patch dim = 48  = 3 x 4 x 4
First Rearrange X_ij : torch.Size([512, 8, 8, 4, 4, 3])
Shape after n-mode product is : torch.Size([512, 8, 8, 32, 32, 1])
xbar squeezed size is : torch.Size([512, 8, 8, 32, 32])


In [9]:
# Calculate Q K V Old Method

# x_emb_old
# dim = 512
# inner_dim = dim_head *  heads = 512 * 8
dim = 512
inner_dim = 512*8

to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
qkv = to_qkv(x_emb_old)

print(f'qkv shape is  : {qkv.shape} = ({bs}, {num_patches}, 512*8*3)')




qkv shape is  : torch.Size([512, 64, 12288]) = (512, 64, 512*8*3)


In [10]:
# Calculate Q K V new method
# xbar squeezed size is : torch.Size([512, 8, 8, 32, 32])


W1Q = nn.Parameter(torch.randn(xbar.shape[3], 64), requires_grad=True)
W2Q = nn.Parameter(torch.randn(xbar.shape[4], 64), requires_grad=True)

W1K = nn.Parameter(torch.randn(xbar.shape[3], 64), requires_grad=True)
W2K = nn.Parameter(torch.randn(xbar.shape[4], 64), requires_grad=True)

W1V = nn.Parameter(torch.randn(xbar.shape[3], 64), requires_grad=True)
W2V = nn.Parameter(torch.randn(xbar.shape[4], 64), requires_grad=True)


Q  = torch.einsum("abcde,dg,eh->abcgh", (xbar, W1Q, W2Q))
K  = torch.einsum("abcde,dg,eh->abcgh", (xbar, W1K, W2K))
V  = torch.einsum("abcde,dg,eh->abcgh", (xbar, W1V, W2V))

In [11]:
print(f'Q shape is : {Q.shape}')
print(f'K shape is : {K.shape}')
print(f'V shape is : {V.shape}')

Q shape is : torch.Size([512, 8, 8, 64, 64])
K shape is : torch.Size([512, 8, 8, 64, 64])
V shape is : torch.Size([512, 8, 8, 64, 64])


In [12]:
# Softmax dimension

## TO DO .....

attend = nn.Softmax(dim = -1)


temp = torch.rand(3,3,4,6)
temp2 = attend(temp)

torch.sum(temp2[0])

tensor(12.)

In [20]:
# Calculate Attention new method :

# Q shape is : torch.Size([512, 8, 8, 64, 64])
# K shape is : torch.Size([512, 8, 8, 64, 64])
# V shape is : torch.Size([512, 8, 8, 64, 64])

attend = nn.Softmax(dim = -1)


A = torch.einsum("abcde,aghed->abcgh", (Q,K))
print(f'A : QK^T shape is : {A.shape}')
# QK^T shape is : torch.Size([512, 8, 8, 8, 8])

A_ = attend(A)/8  # ??????????????  TO DO


Attention_out =  torch.einsum("abcde,adefg->abcfg", (A_, V))

print(f'Attention Block Output shape is : {Attention_out.shape}')


A : QK^T shape is : torch.Size([512, 8, 8, 8, 8])
Attention Block Output shape is : torch.Size([512, 8, 8, 64, 64])


# The Rest is Old VIT - New VIT is Not Implemented Yet

<b>  Needs Further Analysis </b>



<font size="+200"><font color= "red"> THE REST OF THE CODE IS IRRELEVANT TO THIS PROCESS.
NO NEED TO REVIEW THEM

</font>
</font>

# VIT

In [None]:
# https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py

import torch
from torch import nn


from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# classes

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.Linear(patch_dim, dim),
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x)

        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)

# Train On CIFAR10

## Libraries

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import numpy as np

import torchvision
import torchvision.transforms as transforms

import os
import argparse
import pandas as pd
import csv
import time

# from utils import progress_bar
# from randomaug import RandAugment

## Hyper parameters

In [3]:
# Hyper parameters
lr = 1e-4
bs = 512
size = 32
n_epochs = 20
patch = 4
dimhead = 512


device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)


cuda


## Load Dataset - Train/Test Loaders

In [4]:
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.Resize(size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.Resize(size),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=bs, shuffle=True, num_workers=8)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=8)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:03<00:00, 43007646.68it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data




Files already downloaded and verified


## Build Model

In [None]:
net = ViT(
    image_size = size,
    patch_size = patch,
    num_classes = 10,
    dim = dimhead,
    depth = 6,
    heads = 8,
    mlp_dim = 512,
    dropout = 0.1,
    emb_dropout = 0.1
)

net

ViT(
  (to_patch_embedding): Sequential(
    (0): Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=4, p2=4)
    (1): Linear(in_features=48, out_features=512, bias=True)
  )
  (dropout): Dropout(p=0.1, inplace=False)
  (transformer): Transformer(
    (layers): ModuleList(
      (0-5): 6 x ModuleList(
        (0): PreNorm(
          (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (fn): Attention(
            (attend): Softmax(dim=-1)
            (to_qkv): Linear(in_features=512, out_features=1536, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=512, out_features=512, bias=True)
              (1): Dropout(p=0.1, inplace=False)
            )
          )
        )
        (1): PreNorm(
          (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (fn): FeedForward(
            (net): Sequential(
              (0): Linear(in_features=512, out_features=512, bias=True)
              (1): GELU(approximate='non

## Optimizer / Criterion

In [None]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=lr)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, n_epochs)


num_parameters = sum(p.numel() for p in net.parameters() if p.requires_grad)

print(f'number of parameters : {num_parameters}')

scaler = torch.cuda.amp.GradScaler(enabled=True)

number of parameters : 9523722


## Def Train/Test

In [None]:
best_acc = 0

def train(epoch):
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        # Train with amp
        with torch.cuda.amp.autocast(enabled=True):
            outputs = net(inputs)
            loss = criterion(outputs, targets)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()


    acc = 100.*correct/total
    # print(f'acc : {acc}')

    return train_loss/(batch_idx+1), acc


def test(epoch):
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()


    acc = 100.*correct/total
    # print(f'acc : {acc}')

    return test_loss, acc


## Train Model

In [None]:
train_list_loss = []
train_list_acc = []

test_list_loss = []
test_list_acc = []

net.cuda()
for epoch in range(0, n_epochs):
    start = time.time()
    trainloss,train_acc = train(epoch)
    val_loss, val_acc = test(epoch)

    scheduler.step(epoch-1) # step cosine scheduling

    train_list_loss.append(trainloss)
    train_list_acc.append(train_acc)

    test_list_loss.append(val_loss)
    test_list_acc.append(val_acc)

    print(f'epoch {epoch}, train loss = {trainloss}, train acc = {train_acc}, test loss = {val_loss}, test acc = {val_acc}, epoch time : {time.time()-start}, lr = {optimizer.param_groups[0]["lr"]}')



Epoch: 0
epoch 0, train loss = 7.023257639943337, train acc = 14.928, test loss = 697.6976375579834, test acc = 14.9, epoch time : 42.170133113861084, lr = 9.938441702975689e-05

Epoch: 1
epoch 1, train loss = 6.552092226184144, train acc = 14.7, test loss = 604.46044921875, test acc = 15.67, epoch time : 39.909051179885864, lr = 0.0001

Epoch: 2
epoch 2, train loss = 5.50928918682799, train acc = 15.236, test loss = 502.34032249450684, test acc = 16.88, epoch time : 41.441991567611694, lr = 9.938441702975689e-05

Epoch: 3
epoch 3, train loss = 4.554507211763031, train acc = 17.284, test loss = 437.2686836719513, test acc = 20.75, epoch time : 41.283379793167114, lr = 9.755282581475769e-05

Epoch: 4
epoch 4, train loss = 4.053258569873109, train acc = 21.468, test loss = 404.33879351615906, test acc = 24.56, epoch time : 40.15654253959656, lr = 9.45503262094184e-05

Epoch: 5
epoch 5, train loss = 3.7678627821863913, train acc = 23.69, test loss = 376.4938209056854, test acc = 25.8, ep