In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
import cv2
import math
from tqdm import tqdm
from torchsummary import summary

In [2]:
import torch
from torch import nn
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader

## Compiling the dataset

In [50]:
# Side length (in pixels) that every image is resized to by data_compiler;
# also passed to the model as the image height/width.
IMG_SIZE = 64

In [51]:
# Dataset roots. data_compiler expects each root to contain one sub-directory
# per character class, holding that class's image files.
training_dataset_path = '../ocr_dataset/new_data/training_data/'
testing_dataset_path = '../ocr_dataset/new_data/testing_data/'

In [52]:
# Class names are the sub-directory names of the training root.
# os.listdir returns entries in arbitrary order, so sort them to keep the
# label -> index mapping reproducible across runs and machines, and ignore
# stray non-directory entries (e.g. OS metadata files).
characters = sorted(
    entry
    for entry in os.listdir(training_dataset_path)
    if os.path.isdir(os.path.join(training_dataset_path, entry))
)

In [53]:
# Map every character (class directory name) to its integer class index,
# i.e. its position in `characters`.
labels_dict = {character: index for index, character in enumerate(characters)}

In [54]:
# Number of classes = number of character entries found in the training root.
num_labels = len(characters)
num_labels

36

In [55]:
def data_compiler(dataset_path):
    '''
    Load every image below `dataset_path` into an in-memory dataset.

    `dataset_path` must contain one sub-directory per entry of the global
    `characters` list. Each image is read as grayscale, resized to
    (IMG_SIZE, IMG_SIZE), reshaped to (1, IMG_SIZE, IMG_SIZE) and divided
    by 255.

    Returns a list of [image_tensor, one_hot_label] pairs, where the label is
    a numpy row vector of length `num_labels` with a 1 at the class index
    given by `labels_dict`.

    Files that OpenCV cannot decode are skipped with a warning instead of
    failing later inside cv2.resize.
    '''
    import warnings

    # build the identity matrix once instead of once per image
    one_hot = np.eye(num_labels)
    dataset = []

    for character in characters:
        character_path = os.path.join(dataset_path, character)
        label = one_hot[labels_dict[character]]

        for character_image in os.listdir(character_path):
            image_path = os.path.join(character_path, character_image)

            img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
            if img is None:
                # cv2.imread signals an unreadable / non-image file with None
                warnings.warn(f'skipping unreadable image: {image_path}')
                continue

            # resizing the image
            img = cv2.resize(img, (IMG_SIZE, IMG_SIZE))
            # converting to tensor and adding the channel dimension
            img = torch.tensor(img).view(1, IMG_SIZE, IMG_SIZE)
            # scaling the data
            img = img/255.0

            # copy so that samples never share label storage
            dataset.append([img, label.copy()])

    return dataset


In [56]:
training_dataset = data_compiler(training_dataset_path)
len(training_dataset)

20628

In [57]:
testing_dataset = data_compiler(testing_dataset_path)
len(testing_dataset)

1008

## Vision Transformer Implementation

In [11]:
def patchify(images, patch_size):
    '''
    Split a batch of square images into flattened, non-overlapping patches.

    images shape: (batch_size, num_channels, img_size, img_size)
    number of patches = img_size*img_size*num_channels // (patch_size*patch_size)
    output => patches shape: (batch_size, num_patches, patch_size*patch_size)

    Patches are ordered channel by channel and, inside a channel, row-major
    over the patch grid. If img_size is not a multiple of patch_size the
    right/bottom remainder is ignored and the unused trailing rows of the
    output stay zero.
    '''
    batch_size, num_channels, img_size, img_size = images.shape
    num_patches = img_size*img_size*num_channels // (patch_size*patch_size)
    sqrt_num_patches = img_size // patch_size
    patches_per_channel = sqrt_num_patches * sqrt_num_patches
    covered = sqrt_num_patches * patch_size

    # (B, C, H, W) -> (B, C, grid_h, patch, grid_w, patch)
    #              -> (B, C, grid_h, grid_w, patch, patch)
    #              -> (B, C*grid_h*grid_w, patch*patch)
    patches = (
        images[:, :, :covered, :covered]
        .reshape(batch_size, num_channels, sqrt_num_patches, patch_size, sqrt_num_patches, patch_size)
        .permute(0, 1, 2, 4, 3, 5)
        .reshape(batch_size, num_channels * patches_per_channel, patch_size * patch_size)
    )

    # allocate on the input's device so the function also works for GPU batches
    output = torch.zeros(batch_size, num_patches, patch_size*patch_size, device=images.device)
    # channel c occupies rows [c*patches_per_channel, (c+1)*patches_per_channel)
    output[:, :num_channels * patches_per_channel] = patches

    return output


In [12]:
def generate_positional_embedding(sequence_length, embed_size):
    '''
    Build the fixed sinusoidal position table of shape
    (sequence_length, embed_size).

    Even columns hold sin(pos / 10000**(col/embed_size)); every odd column
    holds the cosine of the same angle as the even column to its left.

    sequence length in our case: no. of patches + 1 (due to cls token)
    '''
    table = torch.zeros((sequence_length, embed_size))

    # walk the columns in (sin, cos) pairs that share one frequency
    for sin_col in range(0, embed_size, 2):
        cos_col = sin_col + 1
        divisor = 10000**(sin_col/embed_size)

        for position in range(sequence_length):
            table[position][sin_col] = np.sin( position / divisor )
            # the last sin column has no cos partner when embed_size is odd
            if cos_col < embed_size:
                table[position][cos_col] = np.cos( position / divisor )

    return table


In [13]:
class PatchEmbedding(nn.Module):
    '''
    Turns a batch of images into the token sequence fed to the encoder:
    patchify -> linear projection -> prepend CLS token -> add positions.

    img_shape: (batch_size, num_channels, img_size, img_size); only the
               channel count and image size are used.
    patch_size: side length of the square patches.
    embed_size: size of every output token.
    '''
    def __init__(self, img_shape, patch_size, embed_size):
        super().__init__()
        self.patch_size = patch_size

        batch_size, num_channels, img_size, img_size = img_shape
        num_patches = img_size*img_size*num_channels // (patch_size*patch_size)

        # linear mapping: patch_size * patch_size => embed_size
        # patchify emits one token per (channel, patch) pair with
        # patch_size*patch_size features each, so the channel count must not
        # be folded into the input width (identical for single-channel input).
        input_size = patch_size*patch_size

        self.linear_mapping = nn.Linear(input_size, embed_size)

        # In Vision Transformers, the "CLS token" provides a global
        # representation of the image for classification: its final state is
        # what the classification head reads.
        self.cls_token = nn.Parameter(torch.rand(1, embed_size))

        # fixed (non-trainable) sinusoidal positions, one row per token incl. CLS
        self.positional_embedding = nn.Parameter(
            generate_positional_embedding(num_patches+1, embed_size),
            requires_grad=False
        )

    def forward(self, images):
        '''
        input: torch tensor of image batch, shape
               (batch_size, num_channels, img_size, img_size)
        output: tokens of shape (batch_size, num_patches + 1, embed_size)
        '''
        batch_size = images.shape[0]
        patches = patchify(images, self.patch_size)
        embeddings = self.linear_mapping(patches)    # tokens

        # prepend the shared cls token to every sample (expand makes no copy)
        cls_tokens = self.cls_token.unsqueeze(0).expand(batch_size, -1, -1)
        tokens = torch.cat([cls_tokens, embeddings], dim=1)

        # adding positional embedding; broadcasts over the batch dimension
        out = tokens + self.positional_embedding

        return out
        

In [14]:
class Attention(nn.Module):
    '''
    Multi-head attention.

    embed_size: size of every input/output token; must be divisible by
                num_heads.
    num_heads: number of attention heads; each head works on
               head_dim = embed_size // num_heads features.

    The value/key/query projections are head_dim -> head_dim linear maps
    whose weights are shared by all heads.
    '''
    def __init__(self, embed_size, num_heads):
        super().__init__()
        self.embed_size = embed_size
        self.num_heads = num_heads
        self.head_dim = embed_size // num_heads

        assert self.head_dim*num_heads == embed_size, 'embedding size should be divisible by the number of heads'

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(embed_size, embed_size)

    def forward(self, values, keys, queries, mask=None):
        '''
        values, keys, queries: (batch size, sequence length, embed_size);
            keys and values must share their sequence length.
        mask: optional tensor broadcastable to
            (batch size, no. of heads, query length, key length); positions
            where it equals 0 are excluded from attention.
        returns: (batch size, query sequence length, embed_size)
        '''
        batch_size = queries.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], queries.shape[1]

        # multi-head splitting
        values = values.reshape(batch_size, value_len, self.num_heads, self.head_dim)
        keys = keys.reshape(batch_size, key_len, self.num_heads, self.head_dim)
        queries = queries.reshape(batch_size, query_len, self.num_heads, self.head_dim)

        # learned per-head projections (act on the trailing head_dim axis);
        # without these the three Linear layers above would never be trained
        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        # queries and keys multiplied when finding attention
        # n -> batch size, h -> no. of heads,
        # q,k -> sequence length for queries and keys
        energy = torch.einsum('nqhd,nkhd->nhqk', queries, keys)

        if mask is not None:
            # a very large negative score becomes ~0 weight after the softmax
            energy = energy.masked_fill(mask == 0, float('-1e20'))

        # normalizing over the key axis
        # NOTE: scores are scaled by sqrt(embed_size), kept as in the original
        # code; the usual convention is sqrt(head_dim).
        attention = torch.softmax(energy/(self.embed_size**0.5), dim=-1)

        # attention shape: (batch size, no. of heads, query length, key length)
        # values shape: (batch size, value length, no. of heads, head dimension)
        # output shape: (batch size, query length, no. of heads, head dimension)
        # l -> sequence length shared by keys and values
        out = torch.einsum('nhql,nlhd->nqhd', attention, values)

        # flattening the last two dimensions (concatenating the heads)
        out = out.reshape(batch_size, query_len, self.num_heads*self.head_dim)

        out = self.fc_out(out)

        return out
        

In [15]:
class FeedForwardBlock(nn.Sequential):
    '''
    Position-wise MLP: Linear -> GELU -> Dropout -> Linear.

    embed_size: input and output feature size.
    expansion: the hidden layer has embed_size*expansion features.
    drop_prob: dropout probability applied after the activation.
    '''
    def __init__(self, embed_size, expansion=4, drop_prob=0):
        hidden_size = embed_size*expansion
        layers = [
            nn.Linear(embed_size, hidden_size),
            nn.GELU(),
            nn.Dropout(drop_prob),
            nn.Linear(hidden_size, embed_size),
        ]
        super().__init__(*layers)


In [16]:
class EncoderBlock(nn.Module):
    '''
    Pre-norm transformer encoder block:
        x = x + dropout(attention(norm(x)))
        x = x + dropout(feedforward(norm(x)))

    embed_size: token size.
    num_heads: number of attention heads.
    drop_prob: dropout applied to the output of both sub-layers.
    forward_expansion: hidden-size multiplier of the feed-forward block.
    forward_drop_prob: dropout inside the feed-forward block.
    '''
    def __init__(self, embed_size, num_heads, drop_prob=0, forward_expansion=4, forward_drop_prob=0):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_size)
        self.attention = Attention(embed_size, num_heads)
        self.dropout1 = nn.Dropout(drop_prob)

        self.norm2 = nn.LayerNorm(embed_size)
        self.feedforward = FeedForwardBlock(embed_size, expansion=forward_expansion, drop_prob=forward_drop_prob)
        self.dropout2 = nn.Dropout(drop_prob)

    def forward(self, x):
        '''
        x: tokens of shape (batch size, sequence length, embed_size)
        returns a tensor of the same shape.
        '''
        # The residual branches must stay attached to the autograd graph:
        # detaching the sub-layer input would stop gradients from reaching
        # earlier layers through attention / feed-forward. None of the
        # operations below modify x in place, so no defensive copy is needed.
        normed = self.norm1(x)
        x = x + self.dropout1(self.attention(normed, normed, normed))

        x = x + self.dropout2(self.feedforward(self.norm2(x)))

        return x
        

In [41]:
class VisionTransformer(nn.Module):
    '''
    Vision Transformer classifier: patch embedding -> stack of encoder
    blocks -> linear classification head applied to the CLS token.
    '''
    def __init__(self, img_shape, patch_size, num_heads, embed_size, encoder_depth, num_classes):
        '''
        img_shape: shape of the batch of images passed
        patch_size: side length of the square patches
        num_heads: attention heads per encoder block
        embed_size: token size used throughout the model
        encoder_depth: number of encoder blocks
        num_classes: number of output classes
        '''
        super().__init__()
        self.patch_embedding = PatchEmbedding(
            img_shape=img_shape,
            patch_size=patch_size,
            embed_size=embed_size
        )

        self.encoder = nn.ModuleList(
            [
                EncoderBlock(embed_size, num_heads) for _ in range(encoder_depth)
            ]
        )

        # classification head
        # Emits raw logits: CrossEntropyLoss applies log-softmax itself, so a
        # Softmax layer here would normalise twice and flatten the gradients.
        # Use torch.softmax(output, dim=-1) when probabilities are needed.
        self.classification_mlp = nn.Sequential(
            nn.Linear(embed_size, num_classes)
        )

    def forward(self, images):
        '''
        images: (batch size, num_channels, img_size, img_size)
        returns: class logits of shape (batch size, num_classes)
        '''
        out = self.patch_embedding(images)

        for encoder_block in self.encoder:
            out = encoder_block(out)

        # getting the classification tokens
        out = out[:, 0]
        out = self.classification_mlp(out)

        return out

    def save(self, path='best_model.pth'):
        '''Write the model's state dict to `path`.'''
        torch.save(self.state_dict(), path)

    def load(self, path='best_model.pth'):
        '''Restore the state dict stored at `path` and switch to eval mode.'''
        self.load_state_dict(torch.load(path))
        self.eval()

    def fit(self, train_loader, optimizer=Adam, loss_function=CrossEntropyLoss(), learning_rate=0.001, epochs=5, tqdm_show=False):
        '''
        Train the model and keep the best checkpoint on disk.

        train_loader: iterable of (inputs, targets) batches
        optimizer: optimizer class, built as optimizer(parameters, lr=learning_rate)
        loss_function: callable(predictions, targets) -> scalar loss
        learning_rate: learning rate handed to the optimizer
        epochs: number of passes over train_loader
        tqdm_show: wrap the epoch and batch loops in tqdm progress bars

        After every epoch the mean batch loss is printed; whenever it improves
        on the best value seen so far the weights are written through
        self.save(). The best checkpoint is not overwritten at the end.
        '''
        min_loss = float('inf')
        self.optimizer = optimizer(self.parameters(), lr=learning_rate)
        self.loss_func = loss_function
        self.epochs = epochs

        # load() leaves the model in eval mode, so re-enable training behaviour
        self.train()
        checkpoint_written = False

        epoch_iter = tqdm(range(self.epochs)) if tqdm_show else range(self.epochs)
        for epoch in epoch_iter:
            running_loss = 0.0
            num_batches = 0

            batch_iter = tqdm(train_loader) if tqdm_show else train_loader
            for train_X, train_y in batch_iter:
                pred_train_y = self(train_X)

                # one-hot targets built with numpy arrive as float64; match
                # the prediction dtype (class-index targets are left alone)
                if torch.is_tensor(train_y) and train_y.is_floating_point():
                    train_y = train_y.to(pred_train_y.dtype)

                loss = self.loss_func(pred_train_y, train_y)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # .item() keeps a plain float and releases the autograd graph
                running_loss += loss.item()
                num_batches += 1

            epoch_loss = running_loss / max(num_batches, 1)
            print(f'epoch={epoch+1}: loss={epoch_loss}')
            if epoch_loss < min_loss:
                self.save()
                min_loss = epoch_loss
                checkpoint_written = True

        # only fall back to saving the final weights when no epoch produced a
        # checkpoint (e.g. epochs == 0 or a NaN loss), so test() can still load
        if not checkpoint_written:
            self.save()

    def test(self, test_loader):
        '''
        Reload the checkpoint written by fit() and report the accuracy.

        test_loader: DataLoader yielding (inputs, one-hot targets) batches.
        Prints and returns the fraction of correctly classified samples
        (0.0 for an empty dataset). Note that the weights currently held in
        memory are replaced by the checkpoint on disk.
        '''
        self.load()

        total = len(test_loader.dataset)
        correct = 0

        # evaluation needs no gradients
        with torch.no_grad():
            for test_X, test_y in test_loader:
                pred = torch.argmax(self(test_X), dim=1)
                true = torch.argmax(test_y, dim=1)
                correct += int((pred == true).sum())

        accuracy = correct/total if total else 0.0
        print(f'accuracy results: {correct}/{total} => {accuracy}')
        return accuracy
    

## Training the Model

In [58]:
# Number of samples per batch for both data loaders.
BATCH_SIZE = 128

In [59]:
# Shuffle only the training batches; evaluation does not benefit from a random
# order and stays deterministic without it.
train_dataloader = DataLoader(training_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(testing_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [60]:
# Images are loaded as grayscale by data_compiler, i.e. a single channel.
NUM_CHANNELS = 1

In [61]:
# Model configuration: 64x64 single-channel images cut into 8x8 patches,
# 128-dimensional tokens, 4 attention heads, 6 encoder blocks and one output
# per character class.
vit = VisionTransformer(
    img_shape=(BATCH_SIZE, NUM_CHANNELS, IMG_SIZE, IMG_SIZE),
    patch_size = 8,
    num_heads = 4,
    embed_size = 128,
    encoder_depth=6,
    num_classes=num_labels,
)

In [62]:
# Show the per-layer parameter counts of the model (torchsummary).
summary(vit)

Layer (type:depth-idx)                   Param #
├─PatchEmbedding: 1-1                    --
|    └─Linear: 2-1                       8,320
├─ModuleList: 1-2                        --
|    └─EncoderBlock: 2-2                 --
|    |    └─LayerNorm: 3-1               256
|    |    └─Attention: 3-2               19,584
|    |    └─Dropout: 3-3                 --
|    |    └─LayerNorm: 3-4               256
|    |    └─FeedForwardBlock: 3-5        131,712
|    |    └─Dropout: 3-6                 --
|    └─EncoderBlock: 2-3                 --
|    |    └─LayerNorm: 3-7               256
|    |    └─Attention: 3-8               19,584
|    |    └─Dropout: 3-9                 --
|    |    └─LayerNorm: 3-10              256
|    |    └─FeedForwardBlock: 3-11       131,712
|    |    └─Dropout: 3-12                --
|    └─EncoderBlock: 2-4                 --
|    |    └─LayerNorm: 3-13              256
|    |    └─Attention: 3-14              19,584
|    |    └─Dropout: 3-15                

Layer (type:depth-idx)                   Param #
├─PatchEmbedding: 1-1                    --
|    └─Linear: 2-1                       8,320
├─ModuleList: 1-2                        --
|    └─EncoderBlock: 2-2                 --
|    |    └─LayerNorm: 3-1               256
|    |    └─Attention: 3-2               19,584
|    |    └─Dropout: 3-3                 --
|    |    └─LayerNorm: 3-4               256
|    |    └─FeedForwardBlock: 3-5        131,712
|    |    └─Dropout: 3-6                 --
|    └─EncoderBlock: 2-3                 --
|    |    └─LayerNorm: 3-7               256
|    |    └─Attention: 3-8               19,584
|    |    └─Dropout: 3-9                 --
|    |    └─LayerNorm: 3-10              256
|    |    └─FeedForwardBlock: 3-11       131,712
|    |    └─Dropout: 3-12                --
|    └─EncoderBlock: 2-4                 --
|    |    └─LayerNorm: 3-13              256
|    |    └─Attention: 3-14              19,584
|    |    └─Dropout: 3-15                

In [None]:
# Train for 20 epochs with a learning rate of 5e-4 and tqdm progress bars;
# fit() writes its checkpoint to 'best_model.pth'.
vit.fit(train_dataloader, epochs=20, learning_rate=0.0005, tqdm_show=True)

In [None]:
# Accuracy on the training set. test() reloads 'best_model.pth' first, so this
# measures the saved checkpoint rather than the weights currently in memory.
vit.test(train_dataloader)

In [None]:
# Accuracy on the testing set, again evaluated from the reloaded checkpoint.
vit.test(test_dataloader)