In [None]:
import numpy as np

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data import DataLoader

from torchvision.datasets.mnist import MNIST
from torchvision.transforms import ToTensor


# ViT from scratch

### Import MNIST dataset

In [None]:
# Every image is converted to a tensor before batching.
transform = ToTensor()

# Both splits are downloaded into ./../datasets on first use.
train_set = MNIST(root = './../datasets', train=True, download = True, transform=transform)
test_set = MNIST(root = './../datasets', train=False, download = True, transform=transform)

# Training batches are reshuffled every epoch.
train_loader = DataLoader(train_set, shuffle = True, batch_size = 16)
# The evaluation metrics do not depend on sample order, so the test split is
# iterated in a fixed order to keep runs reproducible.
test_loader = DataLoader(test_set, shuffle = False, batch_size = 16)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./../datasets/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./../datasets/MNIST/raw/train-images-idx3-ubyte.gz to ./../datasets/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./../datasets/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./../datasets/MNIST/raw/train-labels-idx1-ubyte.gz to ./../datasets/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./../datasets/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./../datasets/MNIST/raw/t10k-images-idx3-ubyte.gz to ./../datasets/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./../datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./../datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./../datasets/MNIST/raw



In [None]:
# from torchvision.datasets.cifar import CIFAR10

# transform = ToTensor()

# train_set_cifar = CIFAR10(root = './../datasets', train=True, download = True, transform=transform)
# test_set_cifar = CIFAR10(root = './../datasets', train=False, download = True, transform=transform)

# train_loader_cifar = DataLoader(train_set_cifar, shuffle = True, batch_size = 16)
# test_loader_cifar = DataLoader(test_set_cifar, shuffle = True, batch_size = 16)

### Train and Test function

In [None]:
def train_ViT_classify(model, optimizer, N_EPOCHS, train_loader, device = "cpu"):
  """Train a classifier with cross-entropy and print the mean loss of each epoch.

  Args:
    model: ``nn.Module`` mapping a batch ``x`` to per-class scores of shape
      ``(len(x), n_classes)``.
    optimizer: optimizer that updates ``model``'s parameters.
    N_EPOCHS: number of passes over ``train_loader``.
    train_loader: iterable yielding ``(x, y)`` batches.
    device: device each batch is moved to before the forward pass.

  The printed value is the per-sample mean loss over the epoch.
  """
  criterion = CrossEntropyLoss()
  model.train()  # make sure dropout/normalisation layers are in training mode
  for epoch in range(N_EPOCHS):
    loss_sum, n_seen = 0.0, 0
    for x, y in train_loader:
      x = x.to(device)
      y = y.to(device)

      # CrossEntropyLoss already averages over the batch; dividing by len(x)
      # again would shrink the loss (and its gradients) by the batch size.
      loss = criterion(model(x), y)

      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      # Weight each batch by its size so a smaller last batch is not over-counted.
      loss_sum += loss.item() * len(x)
      n_seen += len(x)

    train_loss = loss_sum / max(n_seen, 1)
    print(f"Epoch {epoch + 1}/{N_EPOCHS} loss: {train_loss:.2f}")


def test_ViT_classify(model, optimizer, test_loader):
    criterion = CrossEntropyLoss()
    correct, total = 0, 0
    test_loss = 0.0
    for batch in test_loader:
        x, y = batch
        x = x.to(device)
        y = y.to(device)
        
        y_hat = model(x)
        loss = criterion(y_hat, y) / len(x)
        test_loss += loss

        correct += torch.sum(torch.argmax(y_hat, dim=1) == y).item()
        total += len(x)
    print(f"Test loss: {test_loss:.2f}")
    print(f"Test accuracy: {correct / total * 100:.2f}%")

### Multi-Head Self-Attention (MSA) model

In [None]:
class MSA(nn.Module):
    """Multi-head self-attention over a batch of token sequences.

    The token dimension ``d`` is split into ``n_heads`` contiguous chunks of
    size ``d // n_heads``. Each head applies its own query/key/value linear
    maps to its chunk, computes scaled dot-product attention, and the head
    outputs are concatenated back to width ``d``.

    The per-head linear layers are stored in ``nn.ModuleList`` containers so
    that they are registered as sub-modules: they appear in ``parameters()``
    (and are therefore trained), in ``state_dict()`` (and are therefore
    saved/restored) and follow ``.to(device)``. A state dict saved from an
    implementation that kept these layers in plain Python lists does not
    contain their weights and cannot fully restore this module.

    Args:
        d: token dimension; must be divisible by ``n_heads``.
        n_heads: number of attention heads.
    """

    def __init__(self, d, n_heads=2):
        super(MSA, self).__init__()
        self.d = d
        self.n_heads = n_heads

        assert d % n_heads == 0, f"Can't divide dimension {d} into {n_heads} heads"

        d_head = d // n_heads
        self.q_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.k_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.v_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.d_head = d_head
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, sequences):
        """Apply self-attention.

        Args:
            sequences: tensor of shape ``(N, seq_length, d)``.

        Returns:
            Tensor of shape ``(N, seq_length, d)``.
        """
        head_outputs = []
        for head in range(self.n_heads):
            # Slice this head's chunk of the token dimension for the whole batch.
            chunk = sequences[..., head * self.d_head: (head + 1) * self.d_head]
            q = self.q_mappings[head](chunk)
            k = self.k_mappings[head](chunk)
            v = self.v_mappings[head](chunk)

            # Batched scaled dot-product attention: (N, L, d_head) @ (N, d_head, L).
            attention = self.softmax(q @ k.transpose(-2, -1) / (self.d_head ** 0.5))
            head_outputs.append(attention @ v)
        # Concatenate the heads back along the token dimension.
        return torch.cat(head_outputs, dim=-1)

### Position encoding


In [None]:
def get_positional_embeddings(sequence_length, d, device="cpu"):
    """Return the ``(sequence_length, d)`` sinusoidal position table.

    For position ``i`` and column ``j`` the entry is
    ``sin(i / 10000 ** (j / d))`` when ``j`` is even and
    ``cos(i / 10000 ** ((j - 1) / d))`` when ``j`` is odd.
    The table is filled on the CPU and then moved to ``device``.
    """
    def entry(pos, col):
        # Even columns carry the sine, odd columns the cosine of the same angle
        # as the even column just before them.
        if col % 2 == 0:
            return np.sin(pos / (10000 ** (col / d)))
        return np.cos(pos / (10000 ** ((col - 1) / d)))

    table = torch.ones(sequence_length, d)
    for pos in range(sequence_length):
        for col in range(d):
            table[pos][col] = entry(pos, col)

    return table.to(device)

### ViT Model

In [None]:
class ViT(nn.Module):
    """Minimal single-block Vision Transformer classifier.

    An image is cut into an ``n_patches x n_patches`` grid of non-overlapping
    patches. Each flattened patch is linearly mapped to ``hidden_d`` features,
    a learnable classification token is prepended, sinusoidal positional
    embeddings are added, and one encoder block (LayerNorm -> MSA -> residual,
    LayerNorm -> MLP -> residual) is applied. The classification token is then
    mapped to ``out_d`` class probabilities.

    Args:
        input_shape: ``(channels, height, width)`` of the input images.
        n_patches: number of patches along each spatial axis; must divide both
            the height and the width.
        hidden_d: token (embedding) dimension.
        n_heads: number of attention heads; must divide ``hidden_d``.
        out_d: number of output classes.
    """

    def __init__(self, input_shape, n_patches=7, hidden_d=8, n_heads=2, out_d=10):
        # Super constructor
        super(ViT, self).__init__()

        # Input and patches sizes
        self.input_shape = input_shape
        self.n_patches = n_patches
        self.n_heads = n_heads
        assert input_shape[1] % n_patches == 0, "Input shape not entirely divisible by number of patches"
        assert input_shape[2] % n_patches == 0, "Input shape not entirely divisible by number of patches"
        # Integer division: the asserts above guarantee an exact result, and the
        # sizes are used as tensor dimensions in forward().
        self.patch_size = (input_shape[1] // n_patches, input_shape[2] // n_patches)
        self.hidden_d = hidden_d

        # 1) Linear mapper
        self.input_d = int(input_shape[0] * self.patch_size[0] * self.patch_size[1])
        self.linear_mapper = nn.Linear(self.input_d, self.hidden_d)

        # 2) Classification token
        self.class_token = nn.Parameter(torch.rand(1, self.hidden_d))

        # 3) Positional embedding
        # The table is constant, so it is built once here instead of on every
        # forward pass. A non-persistent buffer follows the module across
        # devices without adding a key to the state dict.
        self.register_buffer(
            "pos_embed",
            get_positional_embeddings(self.n_patches ** 2 + 1, self.hidden_d),
            persistent=False,
        )

        # 4a) Layer normalization 1
        self.ln1 = nn.LayerNorm((self.n_patches ** 2 + 1, self.hidden_d))

        # 4b) Multi-head Self Attention (MSA) and classification token
        self.msa = MSA(self.hidden_d, n_heads)

        # 5a) Layer normalization 2
        self.ln2 = nn.LayerNorm((self.n_patches ** 2 + 1, self.hidden_d))

        # 5b) Encoder MLP
        self.enc_mlp = nn.Sequential(
            nn.Linear(self.hidden_d, self.hidden_d),
            nn.ReLU()
        )

        # 6) Classification MLP
        # NOTE(review): the training/evaluation helpers in this notebook feed
        # this output to CrossEntropyLoss, which expects unnormalised logits.
        # The Softmax is kept so the returned values stay probabilities, but
        # consider removing it when training with CrossEntropyLoss.
        self.mlp = nn.Sequential(
            nn.Linear(self.hidden_d, out_d),
            nn.Softmax(dim=-1)
        )

    def forward(self, images):
        """Classify a batch of images.

        Args:
            images: tensor of shape ``(N, channels, height, width)`` matching
                ``input_shape``.

        Returns:
            Tensor of shape ``(N, out_d)`` with per-class probabilities.
        """
        # Dividing images into patches
        n, c, h, w = images.shape
        patch_h, patch_w = self.patch_size
        # Split both spatial axes into (grid index, offset inside the patch),
        # bring the two grid axes to the front and flatten every patch. A plain
        # reshape of the image would yield runs of consecutive pixels instead
        # of spatial patches.
        patches = (
            images.reshape(n, c, self.n_patches, patch_h, self.n_patches, patch_w)
            .permute(0, 2, 4, 1, 3, 5)
            .reshape(n, self.n_patches ** 2, self.input_d)
        )

        # Running linear layer for tokenization
        tokens = self.linear_mapper(patches)

        # Adding classification token to the tokens
        class_tokens = self.class_token.unsqueeze(0).expand(n, -1, -1)
        tokens = torch.cat((class_tokens, tokens), dim=1)

        # Adding positional embedding (broadcast over the batch dimension)
        tokens = tokens + self.pos_embed

        # TRANSFORMER ENCODER BEGINS ###################################
        # NOTICE: MULTIPLE ENCODER BLOCKS CAN BE STACKED TOGETHER ######
        # Running Layer Normalization, MSA and residual connection
        out = tokens + self.msa(self.ln1(tokens))

        # Running Layer Normalization, MLP and residual connection
        out = out + self.enc_mlp(self.ln2(out))
        # TRANSFORMER ENCODER ENDS   ###################################

        # Getting the classification token only
        out = out[:, 0]

        return self.mlp(out)

### Train ViT model

In [None]:
# Everything runs on the CPU; `device` is a module-level name that later
# cells reference as well.
device = "cpu"
# Single-channel 28x28 input, 7 patches per side, 20-dimensional tokens,
# 2 attention heads and 10 output classes.
model = ViT((1, 28, 28), n_patches=7, hidden_d=20, n_heads=2, out_d=10)
model = model.to(device)

# Training hyper-parameters.
N_EPOCHS = 5
LR = 0.01
# The optimizer only updates the parameters returned by model.parameters().
optimizer = Adam(model.parameters(), lr=LR)

In [None]:
# Train the from-scratch model on the MNIST training loader; a loss line is
# printed after every epoch.
train_ViT_classify(model, optimizer, N_EPOCHS, train_loader, device)

### Test ViT model

In [None]:
# Print loss and accuracy on the MNIST test loader (the optimizer argument is
# part of the signature but is not used for evaluation).
test_ViT_classify(model, optimizer, test_loader)

### Load the trained model and test it

In [None]:
# Colab-only: mount Google Drive at /content/drive so that files stored there
# can be accessed from the notebook.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
ls

[0m[01;34mdrive[0m/  [01;34msample_data[0m/


In [None]:
cd drive/MyDrive/RTML/Final/

/content/drive/MyDrive/RTML/Final


In [None]:
ls

10_LSTM.ipynb              13_PPO_DDPG.ipynb  trained-pytorch-vit-imagenet.pt
11_seq2seq.ipynb           9_RNNS.ipynb       trained-vit_scratch_MNIST.pt
11_ViT.ipynb               chatDataset.txt
12_Reinforce_PG_A2C.ipynb  [0m[01;34mdata[0m/


In [None]:
def test_ViT_classify(model, optimizer, test_loader):
    criterion = CrossEntropyLoss()
    correct, total = 0, 0
    test_loss = 0.0
    all_losses = []
    for batch in test_loader:
        x, y = batch
        x = x.to(device)
        y = y.to(device)
        
        y_hat = model(x)
        loss = criterion(y_hat, y) / len(x)
        test_loss += loss

        all_losses.append(test_loss)
        # plt.figure()
        # plt.plot(all_losses)

        correct += torch.sum(torch.argmax(y_hat, dim=1) == y).item()
        total += len(x)
    print(f"Test loss: {test_loss:.2f}")
    print(f"Test accuracy: {correct / total * 100:.2f}%")

In [None]:
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# on another device (e.g. a GPU) can still be loaded here.
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
model.load_state_dict(torch.load('trained-vit_scratch_MNIST.pt', map_location=device))
test_ViT_classify(model, optimizer,test_loader)

Test loss: 92.12
Test accuracy: 10.28%


In [None]:
#import matplotlib.pyplot as plt
# plt.figure()
# plt.plot(test_ViT_classify.all_losses)

In [None]:
# model.load_state_dict(torch.load('trained-vit_scratch_MNIST.pt'))

# criterion = CrossEntropyLoss()
# correct, total = 0, 0
# test_loss = 0.0
# all_losses = []
# for batch in test_loader:
#     x, y = batch
#     x = x.to(device)
#     y = y.to(device)
    
#     y_hat = model(x)
#     loss = criterion(y_hat, y) / len(x)
#     test_loss += loss

#     # all_losses.append(loss)
#     # plt.figure()
#     # plt.plot(all_losses)

#     correct += torch.sum(torch.argmax(y_hat, dim=1) == y).item()
#     total += len(x)
# print(f"Test loss: {test_loss:.2f}")
# print(f"Test accuracy: {correct / total * 100:.2f}%")

# Pretrained ViT

In [None]:
!pip install vit-pytorch



In [None]:
import torch
# Imported under an alias so the library class does not rebind the name of the
# from-scratch `ViT` class defined earlier in this notebook.
from vit_pytorch import ViT as LibraryViT

# Library model built with freshly initialised weights; this cell only
# exercises the API on a random input.
v = LibraryViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

# One random 3-channel 256x256 image.
img = torch.randn(1, 3, 256, 256)

preds = v(img) # (1, 1000)
# print(preds)
# print(preds)

In [None]:
import torch
from torchvision.models import resnet50

from vit_pytorch.distill import DistillableViT, DistillWrapper

# Teacher network: ResNet-50 created with pretrained weights.
teacher = resnet50(pretrained = True)

# Student network that will be distilled from the teacher.
# NOTE: this rebinds the module-level name `v`.
v = DistillableViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

# Wrapper that computes the training loss from the student and the teacher.
distiller = DistillWrapper(
    student = v,
    teacher = teacher,
    temperature = 3,           # temperature of distillation
    alpha = 0.5,               # trade between main loss and distillation loss
    hard = False               # whether to use soft or hard distillation
)

# Two random images with random labels in [0, 1000), used as a dummy batch.
img = torch.randn(2, 3, 256, 256)
labels = torch.randint(0, 1000, (2,))

# Single forward/backward pass; no optimizer step is taken in this cell.
loss = distiller(img, labels)
loss.backward()

# after lots of training above ...

pred = v(img) # (2, 1000)

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth


  0%|          | 0.00/97.8M [00:00<?, ?B/s]

### Pretrain on ImageNet, fine-tune on CIFAR10

In [None]:
from torchvision.datasets.cifar import CIFAR10

# Every image is converted to a tensor before batching.
transform = ToTensor()

# Both splits are downloaded into ./../datasets on first use.
train_set_cifar = CIFAR10(root = './../datasets', train=True, download = True, transform=transform)
test_set_cifar = CIFAR10(root = './../datasets', train=False, download = True, transform=transform)

# Training batches are reshuffled every epoch.
train_loader_cifar = DataLoader(train_set_cifar, shuffle = True, batch_size = 16)
# The evaluation metrics do not depend on sample order, so the test split is
# iterated in a fixed order to keep runs reproducible.
test_loader_cifar = DataLoader(test_set_cifar, shuffle = False, batch_size = 16)

Files already downloaded and verified
Files already downloaded and verified


In [None]:
# Imported under an alias so the library class does not rebind the name of the
# from-scratch `ViT` class defined earlier in this notebook.
from vit_pytorch import ViT as LibraryViT


device = "cpu"
# NOTE: this rebinds the module-level name `model` to the library model.
model = LibraryViT(
    image_size = 224, patch_size = 16, num_classes = 1000, dim = 1024,
    depth = 6,
    heads = 16, mlp_dim = 2048, dropout = 0.1, emb_dropout = 0.1
).to(device)


# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
model.load_state_dict(torch.load('trained-pytorch-vit-imagenet.pt', map_location=torch.device('cpu')))
# NOTE(review): `optimizer` is whatever an earlier cell created and is not
# bound to this model. The model is configured for 224x224 inputs and 1000
# classes while it is evaluated on the CIFAR10 loader, and no fine-tuning step
# runs in this cell - confirm this evaluation is what is intended.
test_ViT_classify(model, optimizer,test_loader_cifar)

In [None]:
# device = "cpu"
# model = ViT((1, 28, 28), n_patches=7, hidden_d=20, n_heads=2, out_d=10)
# model = model.to(device)

# N_EPOCHS = 5
# LR = 0.01
# optimizer = Adam(model.parameters(), lr=LR)

# model.load_state_dict(torch.load('trained-vit_scratch_MNIST.pt'))
# test_ViT_classify(model, optimizer,test_loader)

Test loss: 90.09
Test accuracy: 15.22%
