<a href="https://colab.research.google.com/github/Elman295/CvT-Introducing-Convolutions-to-Vision-Transformers/blob/main/CvT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch import nn

#`Data`

In [2]:
train_ds = datasets.MNIST(
    root = "data",
    train = True,
    download = True,
    transform = transforms.ToTensor()
)

test_ds = datasets.MNIST(
    root = "data",
    train = False,
    download = True,
    transform = transforms.ToTensor()
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 11837613.57it/s]


Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 349372.11it/s]


Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 3188507.27it/s]


Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 4230630.42it/s]

Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw






In [3]:
train_dl = DataLoader(dataset = train_ds, batch_size = 32, shuffle = True)

test_dl = DataLoader(dataset = test_ds, batch_size = 32, shuffle = False)

#`Model`

In [4]:
class Convolutional_Transformer_Block(nn.Module):

  def __init__(self,c,w,h):
    super(Convolutional_Transformer_Block, self).__init__()

    self.conv_q = nn.Conv2d(c,c,1,1)
    self.conv_k = nn.Conv2d(c,c,1,1)
    self.conv_v = nn.Conv2d(c,c,1,1)
    self.att = nn.MultiheadAttention(w*h,w // 2)
    self.norm = nn.LayerNorm(w*h)

    self.mlp = nn.Sequential(
        nn.Linear(w*h,256),
        nn.GELU(),
        nn.Linear(256,w*h)
    )


  def forward(self, x):

    b,c,w,h = x.shape

    q = self.conv_q(x)
    k = self.conv_k(x)
    v = self.conv_v(x)

    q = q.view(-1,w*h)
    k = k.view(-1,w*h)
    v = v.view(-1,w*h)
    att,_ = self.att(q,k,v)
    r1 = att.view(-1,c,w,h) + x
    r1_norm = self.norm(r1.view(-1,w*h))
    r2 = self.mlp(r1_norm).view(-1,c,w,h) + r1


    return r2

In [5]:
block = Convolutional_Transformer_Block(100,14,14)
x = torch.rand(size = (16,100,14,14))
y = block(x)
print(y.shape)

torch.Size([16, 100, 14, 14])


In [6]:
class CvT(nn.Module):

  def __init__(self):
    super(CvT, self).__init__()

    self.CTE_1 = nn.Conv2d(1,16,5,1)
    self.pool_1 = nn.MaxPool2d(2,2)
    self.block_1 = Convolutional_Transformer_Block(16,12,12)


    self.CTE_2 = nn.Conv2d(16,32,3,1)
    # self.pool_2 = nn.MaxPool2d(2,2)
    self.block_2 = Convolutional_Transformer_Block(32,10,10)

    self.conv_cls = nn.Sequential(
        nn.Conv2d(32,64,3,1),
        nn.ReLU(),
        nn.MaxPool2d(2,2)
    )

    self.mlp_head = nn.Sequential(
        nn.Linear(64*4*4,512),
        nn.ReLU(),
        nn.Linear(512,10)
    )


  def forward(self, x):

    f1 = self.block_1(self.pool_1(self.CTE_1(x)))
    f2 = self.block_2(self.CTE_2(f1))
    f3 = self.conv_cls(f2)
    f3 = f3.view(-1,64*4*4)
    f4 = self.mlp_head(f3)

    return f4



In [7]:
model = CvT()
x = torch.rand(size = (16,1,28,28))
y = model(x)
print(y.shape)

torch.Size([16, 10])


In [8]:
loss_fn = nn.CrossEntropyLoss()
opt = torch.optim.Adam(params = model.parameters(), lr = 1e-3)

In [9]:
def train(data, model, loss_fn, opt):

  model.train()
  size = len(data.dataset)

  for b, (x,y) in enumerate(data):
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    opt.zero_grad()
    loss.backward()
    opt.step()


    if b % 100 == 0:
      print(f"loss:{loss.item()} [{b*len(x)} | {size}]")




In [10]:
def test(data, model, loss_fn):

  model.eval()
  size = len(data.dataset)
  num_batch = len(data)
  test_loss, test_acc = 0,0

  with torch.no_grad():
    for x,y in data:
      y_pred = model(x)
      test_loss += loss_fn(y_pred, y).item()
      test_acc += (y_pred.argmax(1) == y).type(torch.float).sum().item()


    test_loss /= num_batch
    test_acc /= size

    print(f"test loss :{test_loss} accuracy:{test_acc * 100}")


In [11]:
for e in range(10):
  print(f"epoch:{e+1}=-=-=-=-=")
  train(train_dl, model, loss_fn, opt)
  test(test_dl, model,loss_fn)

epoch:1=-=-=-=-=
loss:2.3080899715423584 [0 | 60000]
loss:0.14006751775741577 [3200 | 60000]
loss:0.016916001215577126 [6400 | 60000]
loss:0.1797446459531784 [9600 | 60000]
loss:0.1605663150548935 [12800 | 60000]
loss:0.2651771605014801 [16000 | 60000]
loss:0.038033030927181244 [19200 | 60000]
loss:0.057509973645210266 [22400 | 60000]
loss:0.12970760464668274 [25600 | 60000]
loss:0.1586177945137024 [28800 | 60000]
loss:0.19121892750263214 [32000 | 60000]
loss:0.08086230605840683 [35200 | 60000]
loss:0.06593361496925354 [38400 | 60000]
loss:0.1956416815519333 [41600 | 60000]
loss:0.09584520757198334 [44800 | 60000]
loss:0.10751969367265701 [48000 | 60000]
loss:0.052194900810718536 [51200 | 60000]
loss:0.025682460516691208 [54400 | 60000]
loss:0.1555529236793518 [57600 | 60000]
test loss :0.07118231057897567 accuracy:97.76
epoch:2=-=-=-=-=
loss:0.006538978777825832 [0 | 60000]
loss:0.005823739804327488 [3200 | 60000]
loss:0.03055047243833542 [6400 | 60000]
loss:0.008304054848849773 [9600