In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import torch.nn.functional as F
from tqdm import tqdm

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

num_classes = 10
num_epochs = 10
batch_size = 16
learning_rate = 0.01

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

train_dataset = torchvision.datasets.CIFAR10(root = './data', download = True, train = True, transform = transform)
test_dataset = torchvision.datasets.CIFAR10(root = './data', train = False, transform = transform)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:02<00:00, 80764732.21it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data


In [3]:
train_loader = torch.utils.data.DataLoader(dataset = train_dataset, batch_size = batch_size, shuffle = True)
test_loader = torch.utils.data.DataLoader(dataset = test_dataset, batch_size = batch_size, shuffle = False)

In [4]:
class self_attention_cnn(nn.Module):

    def __init__(self, old, new, gamma):
        super(self_attention_cnn, self).__init__()
        self.query = nn.Linear(old, new)
        self.key = nn.Linear(old, new)
        self.value = nn.Linear(old, new)
        self.residue = nn.Linear(new, old)
        self.gamma = gamma

    def forward(self, x):

        ret = torch.zeros(x.shape).to(device)

        for i in range(x.shape[0]):
            size = x[i].shape

#             print(x[i].shape)

            f = self.query((x[i].permute(1, 2, 0)).view(-1, *size[0:1]))
            g = self.key((x[i].permute(1, 2, 0)).view(-1, *size[0:1]))
            h = self.value((x[i].permute(1, 2, 0)).view(-1, *size[0:1]))
            qkt = f@g.transpose(0,1)
            sm = F.softmax(qkt, dim = 1)
            y = sm@h
            transformed_y = self.residue(y)
            o = self.gamma*transformed_y + (x[i].permute(1, 2, 0)).view(-1, *size[0:1])
            ret[i] = o.view(tuple(size[i] for i in [1,2,0])).permute(2, 0, 1)

        return ret

class conv_sa(nn.Module):
    def __init__(self):
        super(conv_sa, self).__init__()
        
        self.conv1 = nn.Conv2d(3, 32, 5)  
        self.Relu = nn.ReLU()
        self.SA1 = self_attention_cnn(32, 4, 1)
        self.conv2 = nn.Conv2d(32, 64, 5)
        self.SA2 = self_attention_cnn(64, 8, 1)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv3 = nn.Conv2d(64, 128, 5)
        self.SA3 = self_attention_cnn(128, 16, 1)
        self.conv4 = nn.Conv2d(128, 256, 5)
        self.SA4 = self_attention_cnn(256, 32, 1)
        
        self.gap = nn.AvgPool2d(4,4)
        
        self.l1 = nn.Linear(256, 120)
        self.l2 = nn.Linear(120, 84)
        self.l3 = nn.Linear(84, 10)
        
    def forward(self, x):
        out = self.Relu(self.conv1(x))
        out = self.SA1(out)
        out = self.Relu(self.conv2(out))
        out = self.SA2(out)
        out = self.pool(out)
        out = self.Relu(self.conv3(out))
        out = self.SA3(out)
        out = self.Relu(self.conv4(out))
        out = self.SA4(out)
        out = self.gap(out)
        out = out.view((batch_size, 256))
        
        out = self.Relu(self.l1(out))
        out = self.Relu(self.l2(out))
        out = self.l3(out)
        return out
    

Model = conv_sa().to(device)

In [5]:
Model

conv_sa(
  (conv1): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1))
  (Relu): ReLU()
  (SA1): self_attention_cnn(
    (query): Linear(in_features=32, out_features=4, bias=True)
    (key): Linear(in_features=32, out_features=4, bias=True)
    (value): Linear(in_features=32, out_features=4, bias=True)
    (residue): Linear(in_features=4, out_features=32, bias=True)
  )
  (conv2): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1))
  (SA2): self_attention_cnn(
    (query): Linear(in_features=64, out_features=8, bias=True)
    (key): Linear(in_features=64, out_features=8, bias=True)
    (value): Linear(in_features=64, out_features=8, bias=True)
    (residue): Linear(in_features=8, out_features=64, bias=True)
  )
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv2d(64, 128, kernel_size=(5, 5), stride=(1, 1))
  (SA3): self_attention_cnn(
    (query): Linear(in_features=128, out_features=16, bias=True)
    (key): Linear(in_features=128, out_fea

In [6]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(Model.parameters(), lr =learning_rate)


total_steps = len(train_loader)

for epoch in range(num_epochs):
    for i , (images, label) in enumerate(train_loader):

        images = images.to(device)
        label = label.to(device)

        outputs = Model(images)
        loss = criterion(outputs, label)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
#         print(i)

        if(i+1) % 1000 == 0:
            print(f'epoch {epoch+1} / {num_epochs}, step {i+1}/{total_steps}, loss = {loss.item():5f}')

epoch 1 / 10, step 1000/3125, loss = 2.291237
epoch 1 / 10, step 2000/3125, loss = 2.012730
epoch 1 / 10, step 3000/3125, loss = 1.949450
epoch 2 / 10, step 1000/3125, loss = 1.730939
epoch 2 / 10, step 2000/3125, loss = 1.922977
epoch 2 / 10, step 3000/3125, loss = 1.479888
epoch 3 / 10, step 1000/3125, loss = 1.648847
epoch 3 / 10, step 2000/3125, loss = 1.814000
epoch 3 / 10, step 3000/3125, loss = 1.129269
epoch 4 / 10, step 1000/3125, loss = 1.220394
epoch 4 / 10, step 2000/3125, loss = 1.048416
epoch 4 / 10, step 3000/3125, loss = 1.024386
epoch 5 / 10, step 1000/3125, loss = 1.205450
epoch 5 / 10, step 2000/3125, loss = 0.730433
epoch 5 / 10, step 3000/3125, loss = 1.156644
epoch 6 / 10, step 1000/3125, loss = 1.245951
epoch 6 / 10, step 2000/3125, loss = 0.982236
epoch 6 / 10, step 3000/3125, loss = 1.050127
epoch 7 / 10, step 1000/3125, loss = 0.656318
epoch 7 / 10, step 2000/3125, loss = 0.557695
epoch 7 / 10, step 3000/3125, loss = 0.679904
epoch 8 / 10, step 1000/3125, loss

In [7]:
with torch.no_grad():

    n_c = 0
    n_t = 0
    
    op = 0

    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = Model(images)
#         op = outputs
#         break

        _, predictions = torch.max(outputs.data, 1)
        n_t += labels.shape[0]
        n_c += (predictions == labels).sum().item()

    acc = 100*n_c/n_t
    print(f'acc: {acc:4f}')

acc: 81.682000


In [8]:
with torch.no_grad():

    n_c = 0
    n_t = 0
    
    op = 0

    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = Model(images)
#         op = outputs
#         break

        _, predictions = torch.max(outputs.data, 1)
        n_t += labels.shape[0]
        n_c += (predictions == labels).sum().item()

    acc = 100*n_c/n_t
    print(f'acc: {acc:4f}')

acc: 73.750000


In [72]:
class patchify(nn.Module):
    def __init__(self, patch_size, embedding_size):
        super(patchify, self).__init__()
        
        self.p = patch_size
        self.lin = nn.Linear(3*(patch_size**2), embedding_size)
        
    def forward(self, x):
        
        b, c, h, w = x.size()
        out = x.view(b, c, h//self.p, self.p, w//self.p, self.p)
        out = out.permute(0, 2, 4, 1, 3, 5).contiguous()
        out = out.view(b, -1, c*(self.p**2))
        out = self.lin(out)
        
        return out

class single_self_attention(nn.Module):
    
    def __init__(self, D):
        super(single_self_attention, self).__init__()
        
        self.Normalizer = nn.LayerNorm(normalized_shape = D, device = device)
        self.query = nn.Linear(D, D//8, device = device)
        self.key = nn.Linear(D, D//8, device = device)
        self.value = nn.Linear(D, D//8, device = device)
        self.residue = nn.Linear(D//8, D, device = device)
        
    def forward(self, x):
        
        out = self.Normalizer(x)          
        q = self.query(out)
        k = self.key(out)
        
        qkt = q@k.transpose(0,1)
        qktsm = F.softmax(qkt, dim = 1)
        v = self.value(out)
        
        y = qktsm@v
        o = self.residue(y)
        
        return o       

class self_attention_multihead(nn.Module):
    def __init__(self, embedding_size, heads):
        super(self_attention_multihead, self).__init__()
        
        self.attention_heads = []
        self.embedding_size = embedding_size
        self.heads = heads
        
        for i in range(heads):            
            self.attention_heads.append(single_self_attention(embedding_size//heads))
            
        
    def forward(self, x):
        
        ret = torch.zeros(x.shape).to(device)
        
        for i in range(x.shape[0]):
            
            size = x[i].shape

            multihead_ret = torch.zeros(size[0], self.heads, self.embedding_size//self.heads).to(device)

            temp = x[i].view(size[0], self.heads, self.embedding_size//self.heads)        
            for j in range(self.heads):   
                multihead_ret[:,j,:] = self.attention_heads[j](temp[:,j,:])

            ret[i] = multihead_ret.view(size) + temp.view(size)
        
        return ret
    
class encoder_block(nn.Module):
    def __init__(self, embedding_size, heads, MLP_hidden_size):
        super(encoder_block, self).__init__()
        
        self.MHSA = self_attention_multihead(embedding_size, heads)
        self.Normalizer = nn.LayerNorm(normalized_shape = embedding_size)
        self.Lin1 = nn.Linear(embedding_size, MLP_hidden_size)
        self.Lin2 = nn.Linear(MLP_hidden_size, embedding_size)
        self.Relu = nn.ReLU()
        
    def forward(self, x):
        
        out = self.MHSA(x)
        out = self.Normalizer(out)
        out = self.Relu(self.Lin1(out))
        out = self.Relu(self.Lin2(out))
        
        return out

class ViTLike(nn.Module):
    def __init__(self, patch_size, embedding_size, heads):
        super(ViTLike, self).__init__()
        
        self.patchit = patchify(patch_size, embedding_size)
        
        self.clse = nn.Parameter(data=torch.randn(1, 1, embedding_size), requires_grad=True)
        
        self.pos = nn.Parameter(data=torch.randn(1, 65, embedding_size),
                                               requires_grad=True)
        
        self.eb1 = encoder_block(embedding_size, heads, 2000)
        self.eb2 = encoder_block(embedding_size, heads, 2000)
        self.eb3 = encoder_block(embedding_size, heads, 2000)
        
        self.l1 = nn.Linear(embedding_size, 256)
        self.l2 = nn.Linear(256, 84)
        self.l3 = nn.Linear(84, 10)
        self.Relu = nn.ReLU()
        
        
    def forward(self, x):
        
        class_token = self.clse.expand(batch_size, -1, -1)
        
        out = self.patchit(x)
        
        out = torch.cat((class_token, out), dim=1)
        out = self.pos + out                
        
        out = self.eb1(out)
        out = self.eb2(out)
        out = self.eb3(out)
        
        out = out[:,0,:]
        
        out = self.Relu(self.l1(out))
        out = self.Relu(self.l2(out))
        out = self.l3(out)
        
        return out
    
Model2 = ViTLike(4, 128, 4).to(device)

In [73]:
Model2

ViTLike(
  (patchit): patchify(
    (lin): Linear(in_features=48, out_features=128, bias=True)
  )
  (eb1): encoder_block(
    (MHSA): self_attention_multihead()
    (Normalizer): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (Lin1): Linear(in_features=128, out_features=2000, bias=True)
    (Lin2): Linear(in_features=2000, out_features=128, bias=True)
    (Relu): ReLU()
  )
  (eb2): encoder_block(
    (MHSA): self_attention_multihead()
    (Normalizer): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (Lin1): Linear(in_features=128, out_features=2000, bias=True)
    (Lin2): Linear(in_features=2000, out_features=128, bias=True)
    (Relu): ReLU()
  )
  (eb3): encoder_block(
    (MHSA): self_attention_multihead()
    (Normalizer): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (Lin1): Linear(in_features=128, out_features=2000, bias=True)
    (Lin2): Linear(in_features=2000, out_features=128, bias=True)
    (Relu): ReLU()
  )
  (l1): Linear(in_features=128,

In [75]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(Model2.parameters(), lr =learning_rate)
num_epochs = 5


total_steps = len(train_loader)

for epoch in tqdm(range(num_epochs)):
    for i , (images, label) in enumerate(train_loader):

        images = images.to(device)
        label = label.to(device)

        outputs = Model2(images)
        loss = criterion(outputs, label)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
#         print(i)

        if(i+1) % 1000 == 0:
            print(f'epoch {epoch+1} / {num_epochs}, step {i+1}/{total_steps}, loss = {loss.item():5f}')

  0%|          | 0/5 [00:00<?, ?it/s]

epoch 1 / 5, step 1000/3125, loss = 1.852911
epoch 1 / 5, step 2000/3125, loss = 1.674202
epoch 1 / 5, step 3000/3125, loss = 1.575201


 20%|██        | 1/5 [12:14<48:59, 734.78s/it]

epoch 2 / 5, step 1000/3125, loss = 1.469540
epoch 2 / 5, step 2000/3125, loss = 1.704446
epoch 2 / 5, step 3000/3125, loss = 1.633661


 40%|████      | 2/5 [24:31<36:48, 736.16s/it]

epoch 3 / 5, step 1000/3125, loss = 1.363106
epoch 3 / 5, step 2000/3125, loss = 1.922007
epoch 3 / 5, step 3000/3125, loss = 1.220276


 60%|██████    | 3/5 [36:47<24:32, 736.08s/it]

epoch 4 / 5, step 1000/3125, loss = 1.577268
epoch 4 / 5, step 2000/3125, loss = 1.934319
epoch 4 / 5, step 3000/3125, loss = 1.212567


 80%|████████  | 4/5 [49:06<12:16, 736.98s/it]

epoch 5 / 5, step 1000/3125, loss = 1.498479
epoch 5 / 5, step 2000/3125, loss = 1.787369
epoch 5 / 5, step 3000/3125, loss = 1.596136


100%|██████████| 5/5 [1:01:17<00:00, 735.59s/it]


In [76]:
with torch.no_grad():

    n_c = 0
    n_t = 0
    
    op = 0

    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = Model2(images)

        _, predictions = torch.max(outputs.data, 1)
        n_t += labels.shape[0]
        n_c += (predictions == labels).sum().item()

    acc = 100*n_c/n_t
    print(f'acc: {acc:4f}')

acc: 50.564000


In [77]:
with torch.no_grad():

    n_c = 0
    n_t = 0
    
    op = 0

    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = Model2(images)
#         op = outputs
#         break

        _, predictions = torch.max(outputs.data, 1)
        n_t += labels.shape[0]
        n_c += (predictions == labels).sum().item()

    acc = 100*n_c/n_t
    print(f'acc: {acc:4f}')

acc: 49.050000


We can see that the accuracy of the Vision Transformer like model is lesser than that of the model involving convolutions and self attention.

It is known that encoder only models are data hungry and require a huge amount of data for them to overpower the convolutional models.

Also, only 5 epochs have been done for the vision transformer model owing for its lesser accuracy.