# Lightweight Transformer-based Optical Character Recognition 
(https://hal.science/hal-03685976/file/A_Light_Transformer_Based_Architecture_for_Handwritten_Text_Recognition.pdf)

## ***note: Has a CNN backbone***

-----------

## Architecture  
Build up with a double Transformer architecture:  
- Image transformer as encoder: Extracts the visual features
- Text transformer as decoder: Language modeling
- Encoder: 
- Decoder: Generates word-sections sequence using visual features and previous predictions

### Encoder:  

### Decoder: 

--------------


## Start pipeline

In [1]:
import os
from collections import OrderedDict

import pandas as pd
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torchvision.io import read_image

from PIL import Image

from sklearn.model_selection import train_test_split

### Constants

In [2]:
DATA_PATH = f"/home/hkolstee/uniprojects/DATA/HWR/IAM-data/IAM-data/"
TRAIN_TEST_SPLIT = 0.2
BATCH_SIZE = 64

### Prepare Data

In [3]:
raw_data = pd.read_fwf(DATA_PATH + "iam_lines_gt.txt", header = None)
raw_data = raw_data.values.tolist()

data = {'img_names': np.squeeze(raw_data[::2]),
        'labels': np.squeeze(raw_data[1::2])}

data = pd.DataFrame(data)
data

Unnamed: 0,img_names,labels
0,a03-017-07.png,into the pro-communist north and the
1,a03-017-05.png,"to 1958 kept the kingdom in peace, though"
2,a03-017-08.png,pro-western centre and south.
3,a03-017-02.png,in Phnom Penh indicate that he still regards
4,a03-017-06.png,at the cost of virtual partition of the country
...,...,...
7453,d06-000-08.png,fears are based upon completely
7454,d06-000-05.png,"is worrying them, to find the original"
7455,d06-000-09.png,irrational pre-conceived notions - or to
7456,d06-000-02.png,"already suggested, not to be silly or"


### TODO: Data augmentation

### Train test split

In [4]:
train, test = train_test_split(data, test_size = TRAIN_TEST_SPLIT)

# reset indices from current random state
train.reset_index(inplace = True)
test.reset_index(inplace = True)

### Create custom pytorch dataset

In [5]:
class HandWritingDataset(Dataset):
    def __init__(self, data: pd.DataFrame, batch_size):
        self.data = data
        self.batch_size = batch_size

    # function to get patches of tokens from original input image
    # def __getTokens__(self, image):
        # pass

    def __getitem__(self, index):
        # input image
        image_path = os.path.join(DATA_PATH, "img", self.data['img_names'][index])
        # torchvision read_image call
        image = read_image(image_path)
        
        # string label
        label = self.data['labels'][index]

        return image, label

    def __len__(self):
        # return length of column
        return len(self.data)

In [6]:
train_set = HandWritingDataset(train, BATCH_SIZE)
test_set = HandWritingDataset(test, BATCH_SIZE)

### Model

In [8]:
class OCRTransformer(nn.module):
    def __init__(self, input_width, input_height):
        super(OCRTransformer, self).__init()

        # convolutional block (5 convolutions)
        # first convolution
        self.conv1 = nn.Conv2d(in_channels = 1, out_channels = 8, kernel_size = (3,3))
        width = input_height - 2
        height = input_height - 2
        self.leakyRelu = nn.LeakyReLU()     # reuse in later layers
        self.layerNorm1 = nn.LayerNorm(normalized_shape = [8, width, height])
        self.maxPool = nn.MaxPool2d((2,2))  # reuse in later layers
        width = int(np.floor(width/2))
        height = int(np.floor(height/2))
        self.dropout = nn.Dropout(0.2)      # reuse in later layers

        # second convolutional layer
        self.conv2 = nn.Conv2d(in_channels = 8, out_channels = 16, kernel_size = (3, 3))
        width -= 2
        height -= 2
        self.layerNorm2 = nn.LayerNorm(normalized_shape = [16, width, height])
        # after maxpool
        width = int(np.floor(width/2))
        height = int(np.floor(height/2))

        # third convolutional layer
        self.conv3 = nn.Conv2d(in_channels = 16, out_channels = 32, kernel_size = (3, 3))
        width -= 2
        height -= 2
        self.layerNorm3 = nn.LayerNorm(normalized_shape = [32, width, height])
        # after maxpool
        width = int(np.floor(width/2))
        height = int(np.floor(height/2))

        # forth convolutional layer
        self.conv4 = nn.Conv2d(in_channels = 32, out_channels = 64, kernel_size = (3, 3))
        width -= 2
        height -= 2
        self.layerNorm4 = nn.LayerNorm(normalized_shape = [64, width, height])
        # no maxpool

        # fifth convolutional layer (kernel size to better match shape of character)
        self.conv5 = nn.Conv2d(in_channels = 64, out_channels = 128, kernel_size = (4, 2))
        width -= 1
        height -= 3
        self.layerNorm5 = nn.LayerNorm(normalized_shape = [128, width, height])
        # no maxpool

        # following is convolution with width 1 which is used to flatten the current output
        self.flattenConv = nn.Conv2d(in_channels = 128, out_channels = 128, kernel_size = (height, 1))
        self.layerNorm5 = nn.LayerNorm(normalized_shape = [128, width, 1])

        # dense layer to upscale from 128 to 256
        self.dense1 = nn.Linear(in_features = 128, out_features = 256)
        # sinusoidal positional encoding is added to the output of the dense layer

        # Here starts: transformer layers

    def forward(self, input):
        # through 5 convolutional layers
        # first conv
        out = self.layerNorm1(self.leakyRelu(self.conv1(input)))
        out = self.dropout(out)
        # second conv
        out = self.layerNorm2(self.leakyRelu(self.conv2(out)))
        out = self.dropout(out)
        # third conv
        out = self.layerNorm3(self.leakyRelu(self.conv3(out)))
        out = self.dropout(out)
        # forth conv
        out = self.layerNorm4(self.leakyRelu(self.conv4(out)))
        # fifth conv
        out = self.layerNorm5(self.leakyRelu(self.conv5(out)))

        # flatten layer
        out = self.layerNorm5(self.leakyRelu(self.flattenConv(out)))
        # dense layer
        out = self.dense1(out)






SyntaxError: unexpected EOF while parsing (3833818504.py, line 57)