# Handwritten Image Classification with Vision Transformers

In [1]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, random_split, DataLoader

from torchvision import datasets, transforms, models
import torchvision.transforms as transforms
from torch.utils.data import Dataset, random_split, DataLoader
from torchvision.utils import save_image

from torchsummary import summary

import spacy

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import os
import time
import math
from PIL import Image
import glob
from IPython.display import display

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [3]:
BATCH_SIZE = 64
LR = 1e-4
NUM_EPOCHES = 100

## Preprocessing

In [4]:
mean, std = (0.5,), (0.5,)

transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize(mean, std)
                              ])

In [5]:
trainset = datasets.MNIST('data/MNIST/', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)

testset = datasets.MNIST('data/MNIST/', download=True, train=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False)

In [18]:
#class VisionAttention(nn.Module):
#    def __init__(self):
#        super(VisionAttention, self).__init__()
#        
#    def forward(self, x):
#        return x
    
class VisionEncoder(nn.Module):
    def __init__(self, embed_size, num_heads, dropout=0.1):
        super(VisionEncoder, self).__init__()
    
        self.embed_size = embed_size
        self.num_heads = num_heads
        
        self.norm1 = None
        self.norm2 = None
        
        #self.attention = VisionAttention()
    
    def forward(self, i):
        print("hello", i.size())
        i = i + self.attention(self.norm1(i))
        i = i + self.ff(self.norm2(i))
        return i
    
class ViT(nn.Module):
    def __init__(self, image_size, channel_size, patch_size, embed_size, num_heads, classes, num_layers, dropout=0.1):
        super(ViT, self).__init__()
        
        self.p = patch_size
        self.image_size = image_size
        self.embed_size = embed_size
        self.num_patches = (image_size//patch_size)**2
        self.patch_size = channel_size*patch_size**2
        self.num_heads = num_heads
        self.classes = classes
        self.num_layers = num_layers
        self.dropout = nn.Dropout(dropout)
        
        self.embeddings = nn.Linear(self.patch_size, self.embed_size)
        self.class_token = nn.Parameter(torch.randn(1, 1, self.embed_size))
        
        self.temp = VisionEncoder(self.embed_size, self.num_heads, dropout)
        
        self.encoders = nn.ModuleList([])
        for layer in range(self.num_layers):
            temp = nn.ModuleList([
                VisionEncoder(self.embed_size, self.num_heads, dropout)
            ])
            self.encoders.append(temp)
        
        self.ff = nn.Sequential(
            nn.Linear(self.embed_size, self.classes)
        )
        
    def forward(self, x):
        b, c, h, w = x.size()        
        
        x = x.reshape(b, int((h/self.p)*(w/self.p)), c*self.p*self.p)
        x = self.embeddings(x)

        b, n, e = x.size()
        class_tokens = self.class_token.expand(b, 1, e)
        
        x = torch.cat((x, class_tokens), dim=1)
        
        #for encoder in self.encoders:
        #    print(len(encoder))
        #    x = encoder(x)
        
        x = self.temp(x)
        
        return x

In [19]:
image_size = 28
channel_size = 1
patch_size = 7
embed_size = 512
num_heads = 8
classes = 10
num_layers = 2

model = ViT(image_size, channel_size, patch_size, embed_size, num_heads, classes, num_layers).to(device)
model

ViT(
  (dropout): Dropout(p=0.1, inplace=False)
  (embeddings): Linear(in_features=49, out_features=512, bias=True)
  (temp): VisionEncoder()
  (encoders): ModuleList(
    (0): ModuleList(
      (0): VisionEncoder()
    )
    (1): ModuleList(
      (0): VisionEncoder()
    )
  )
  (ff): Sequential(
    (0): Linear(in_features=512, out_features=10, bias=True)
  )
)

In [20]:
for img, label in trainloader:
    img = img.to(device)
    label = label.to(device)
    
    print(img.size())
    print(label.size())
    print("-"*100)
    
    out = model(img)
    
    print(out.size())
    break

torch.Size([64, 1, 28, 28])
torch.Size([64])
----------------------------------------------------------------------------------------------------
hello torch.Size([64, 17, 512])


ModuleAttributeError: 'VisionEncoder' object has no attribute 'attention'