# Handwritten Image Classification with Vision Transformers

In [1]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, random_split, DataLoader

from torchvision import datasets, transforms, models
import torchvision.transforms as transforms
from torch.utils.data import Dataset, random_split, DataLoader
from torchvision.utils import save_image

from torchsummary import summary

import spacy

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import os
import time
import math
from PIL import Image
import glob
from IPython.display import display

In [2]:
# Run on the GPU whenever CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [3]:
# Global hyperparameters.
# Number of samples per mini-batch; passed to the train and test DataLoaders.
BATCH_SIZE = 64
# Presumably the optimizer learning rate -- not used in the visible cells; confirm.
LR = 1e-4
# Presumably the number of training epochs -- not used in the visible cells; confirm.
NUM_EPOCHES = 100

## Preprocessing

In [4]:
# One-element tuples: a single mean/std entry, i.e. statistics for one channel.
mean = (0.5,)
std = (0.5,)

# Convert each image to a tensor, then apply (x - mean) / std per channel.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]
)

In [5]:
# MNIST training split: downloaded on first use, reshuffled every epoch.
trainset = datasets.MNIST(
    root='data/MNIST/',
    train=True,
    download=True,
    transform=transform,
)
trainloader = DataLoader(dataset=trainset, batch_size=BATCH_SIZE, shuffle=True)

# MNIST test split: same preprocessing, iterated in a fixed order.
testset = datasets.MNIST(
    root='data/MNIST/',
    train=False,
    download=True,
    transform=transform,
)
testloader = DataLoader(dataset=testset, batch_size=BATCH_SIZE, shuffle=False)

In [8]:
class ViT(nn.Module):
    """Patch-embedding front end of a Vision Transformer (work in progress).

    The module splits an image batch into non-overlapping square patches,
    flattens each patch and projects it linearly to ``embed_size`` features.
    A learnable class token, a dropout layer and a linear classification head
    (``ff``) are created in ``__init__`` but are not yet applied in
    ``forward``.

    Args:
        image_size: Side length of the (square) input images, in pixels.
        channel_size: Number of input channels.
        patch_size: Side length of each square patch, in pixels.
        embed_size: Dimension of the patch embeddings.
        classes: Number of output classes of the ``ff`` head.
        dropout: Drop probability of the ``dropout`` layer.
    """

    def __init__(self, image_size, channel_size, patch_size, embed_size, classes, dropout=0.1):
        super().__init__()

        self.p = patch_size  # patch side length in pixels
        self.image_size = image_size
        self.embed_size = embed_size
        self.num_patches = (image_size // patch_size) ** 2
        # NOTE: despite its name, ``self.patch_size`` holds the number of
        # values in one flattened patch (channels * side * side); the side
        # length itself lives in ``self.p``.
        self.patch_size = channel_size * patch_size ** 2
        self.classes = classes
        self.dropout = nn.Dropout(dropout)

        self.embeddings = nn.Linear(self.patch_size, self.embed_size)
        self.cls_token = nn.Parameter(torch.randn(1, 1, self.embed_size))

        self.ff = nn.Sequential(
            nn.Linear(self.embed_size, self.classes)
        )

    def _patchify(self, x):
        """Cut a (b, c, h, w) batch into flattened, non-overlapping p x p patches.

        Returns a tensor of shape (b, (h // p) * (w // p), c * p * p). Patches
        are ordered row-major over the patch grid; the values inside a patch
        are ordered (channel, row, column).

        Raises:
            ValueError: If ``h`` or ``w`` is not a multiple of the patch side.
        """
        b, c, h, w = x.size()
        p = self.p

        if h % p != 0 or w % p != 0:
            raise ValueError(
                f"image size ({h}x{w}) must be divisible by the patch size ({p})"
            )

        rows, cols = h // p, w // p

        # A bare reshape of (b, c, h, w) would only chop the row-major pixel
        # stream into runs of c*p*p values, which are not spatial patches.
        # Split both spatial axes into (grid index, offset inside the patch)
        # and bring the grid axes forward before flattening.
        x = x.reshape(b, c, rows, p, cols, p)
        x = x.permute(0, 2, 4, 1, 3, 5)  # -> (b, rows, cols, c, p, p)
        return x.reshape(b, rows * cols, c * p * p)

    def forward(self, x):
        """Embed the patches of an image batch.

        Args:
            x: Tensor of shape (b, c, h, w).

        Returns:
            Tensor of shape (b, num_patches, embed_size) with the patch
            embeddings. The class token is not concatenated and the
            classification head is not applied yet.
        """
        x = self._patchify(x)

        x = self.embeddings(x)

        b, n, e = x.size()

        # One copy of the learnable class token per batch element.
        cls_tokens = self.cls_token.expand(b, 1, e)

        # NOTE: debugging output kept so the observable behaviour (printed
        # shape, returned tensor shape) is unchanged; cls_tokens is not yet
        # merged into the sequence.
        print(cls_tokens.size())
        return x

In [9]:
# Model hyperparameters.
image_size = 28     # side length of the square input images
channel_size = 1    # number of input channels
patch_size = 7      # side length of each square patch
embed_size = 512    # dimension of each patch embedding
classes = 10        # number of output classes

# Build the network, move its parameters to the selected device, and let the
# notebook display the module summary.
model = ViT(
    image_size=image_size,
    channel_size=channel_size,
    patch_size=patch_size,
    embed_size=embed_size,
    classes=classes,
)
model = model.to(device)
model

ViT(
  (dropout): Dropout(p=0.1, inplace=False)
  (embeddings): Linear(in_features=49, out_features=512, bias=True)
  (ff): Sequential(
    (0): Linear(in_features=512, out_features=10, bias=True)
  )
)

In [10]:
# Smoke test: push one training batch through the model and report the shapes.
for img, label in trainloader:
    img, label = img.to(device), label.to(device)

    print(img.size(), label.size(), sep="\n")
    print("-" * 100)

    out = model(img)

    print(out.size())
    break  # one batch is enough for a shape check

torch.Size([64, 1, 28, 28])
torch.Size([64])
----------------------------------------------------------------------------------------------------
torch.Size([64, 1, 512])
torch.Size([64, 16, 512])
