# Transformer-based Optical Character Recognition with Pre-Trained Models  
(TrOCR: https://arxiv.org/abs/2109.10282)

note: no CNN backbone 

-----------

## Architecture  
Built from two Transformers:  
- Image Transformer as encoder: extracts the visual features by representing the image as a sequence of patches
- Text Transformer as decoder: language modeling; generates the wordpiece sequence using the visual features and its previous predictions

### Encoder:  
Input image: R^{3 x H_0 x W_0}   
-> Resize to a fixed size (H, W): R^{3 x H x W}.  
-> The Transformer encoder needs a sequence of input tokens: split the image into N = (H * W) / (P^2) square patches with a fixed size of (P x P), where H and W are guaranteed to be divisible by P.  
-> The patches are flattened to vectors and linearly projected to D-dimensional vectors, where D is the hidden size of the Transformer through all layers.  
-> Special token "[CLS]": aggregates the information from all patch embeddings and represents the whole image.  
-> Additionally, a distillation token: kept in the input sequence when DeiT pre-trained models are used to initialize the encoder, so the model can learn from the teacher model.  
  

Patches are used instead of CNN features so the Transformer can attend to any part of the image without image-specific inductive biases.
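
The encoder input steps above can be sketched in PyTorch as follows. This is a minimal illustration, not the TrOCR implementation: the class name `PatchEmbedding`, the default sizes, and the zero-initialised `[CLS]` parameter are assumptions made for the example, and position embeddings and the distillation token are left out for brevity.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class PatchEmbedding(nn.Module):
    """Resize an image, cut it into P x P patches, project each patch to D dims."""

    def __init__(self, height=384, width=384, patch_size=16, hidden_size=768):
        super().__init__()
        if height % patch_size or width % patch_size:
            raise ValueError("H and W must be divisible by P")
        self.size = (height, width)
        self.patch_size = patch_size
        # N = (H * W) / P^2
        self.num_patches = (height * width) // patch_size ** 2
        # a flattened RGB patch holds 3 * P * P values
        self.projection = nn.Linear(3 * patch_size ** 2, hidden_size)
        # learnable [CLS] token (hypothetical zero initialisation)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))

    def forward(self, images):
        # images: float tensor of shape (B, 3, H_0, W_0)
        images = F.interpolate(images, size=self.size, mode="bilinear",
                               align_corners=False)
        # non-overlapping patches: (B, 3 * P * P, N)
        patches = F.unfold(images, kernel_size=self.patch_size,
                           stride=self.patch_size)
        patches = patches.transpose(1, 2)           # (B, N, 3 * P * P)
        tokens = self.projection(patches)           # (B, N, D)
        cls = self.cls_token.expand(tokens.shape[0], -1, -1)
        return torch.cat([cls, tokens], dim=1)      # (B, N + 1, D)


# small hypothetical sizes so the example runs quickly
embedding = PatchEmbedding(height=64, width=256, patch_size=16, hidden_size=32)
sequence = embedding(torch.rand(2, 3, 50, 300))
print(embedding.num_patches, tuple(sequence.shape))
```

With H = 64, W = 256 and P = 16 the image yields N = (64 * 256) / 16^2 = 64 patches, so the encoder sees 65 tokens per image once `[CLS]` is prepended.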

### Decoder: 
Standard transformer decoder.
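
To make "visual features and previous predictions" concrete, here is a rough sketch of a standard decoder with greedy generation, built from PyTorch's `nn.TransformerDecoder`. Everything named here (`TinyDecoder`, `greedy_decode`, the vocabulary size, the token ids) is hypothetical, position embeddings are omitted, and the model is untrained, so the generated ids are meaningless; the point is only the data flow.

```python
import torch
import torch.nn as nn

# hypothetical sizes and special-token ids, chosen only for this example
VOCAB_SIZE, HIDDEN, BOS_ID, EOS_ID = 100, 32, 1, 2


class TinyDecoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(VOCAB_SIZE, HIDDEN)
        layer = nn.TransformerDecoderLayer(d_model=HIDDEN, nhead=4,
                                           dim_feedforward=64, batch_first=True)
        self.decoder = nn.TransformerDecoder(layer, num_layers=2)
        self.lm_head = nn.Linear(HIDDEN, VOCAB_SIZE)

    def forward(self, token_ids, memory):
        length = token_ids.shape[1]
        # causal mask: True marks positions a token may NOT attend to
        causal = torch.triu(torch.ones(length, length, dtype=torch.bool),
                            diagonal=1)
        # self-attention over previous tokens, cross-attention over memory
        hidden = self.decoder(self.embed(token_ids), memory, tgt_mask=causal)
        return self.lm_head(hidden)  # (B, T, VOCAB_SIZE)


@torch.no_grad()
def greedy_decode(model, memory, max_len=20):
    """Generate token ids one step at a time from encoder output `memory`."""
    tokens = torch.full((memory.shape[0], 1), BOS_ID, dtype=torch.long)
    for _ in range(max_len - 1):
        logits = model(tokens, memory)
        # pick the most likely next token from the last position
        next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
        tokens = torch.cat([tokens, next_token], dim=1)
        if (next_token == EOS_ID).all():
            break
    return tokens


model = TinyDecoder().eval()                 # eval() disables dropout
memory = torch.rand(2, 65, HIDDEN)           # stand-in for encoder output
generated = greedy_decode(model, memory, max_len=10)
print(tuple(generated.shape))
```

Each step feeds the tokens generated so far back into the decoder, while `memory` (the encoder's patch representations) stays fixed.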

--------------


## Start pipeline

In [9]:
import os

import pandas as pd
import numpy as np

import torch
from torch.utils.data import Dataset
from torchvision.io import read_image

from PIL import Image

from sklearn.model_selection import train_test_split

### Constants

In [10]:
DATA_PATH = "/home/hkolstee/uniprojects/DATA/HWR/IAM-data/IAM-data/"
TRAIN_TEST_SPLIT = 0.2
BATCH_SIZE = 64

### Prepare Data

In [11]:
raw_data = pd.read_fwf(DATA_PATH + "iam_lines_gt.txt", header = None)
raw_data = raw_data.values.tolist()

data = {'img_names': np.squeeze(raw_data[::2]),
        'labels': np.squeeze(raw_data[1::2])}

data = pd.DataFrame(data)
data

|      | img_names      | labels                                          |
|------|----------------|-------------------------------------------------|
| 0    | a03-017-07.png | into the pro-communist north and the            |
| 1    | a03-017-05.png | to 1958 kept the kingdom in peace, though       |
| 2    | a03-017-08.png | pro-western centre and south.                   |
| 3    | a03-017-02.png | in Phnom Penh indicate that he still regards    |
| 4    | a03-017-06.png | at the cost of virtual partition of the country |
| ...  | ...            | ...                                             |
| 7453 | d06-000-08.png | fears are based upon completely                 |
| 7454 | d06-000-05.png | is worrying them, to find the original          |
| 7455 | d06-000-09.png | irrational pre-conceived notions - or to        |
| 7456 | d06-000-02.png | already suggested, not to be silly or           |
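
The `[::2]` / `[1::2]` slicing in the cell above relies on the ground-truth file alternating between an image name and its transcription. A plain-Python sketch of the same pairing is shown below, assuming that layout (with optional blank separator lines); the helper name `parse_ground_truth` is made up for this example and is not part of the notebook.

```python
def parse_ground_truth(text):
    """Pair up alternating image-name / transcription lines."""
    # blank separator lines are dropped before pairing
    lines = [line.strip() for line in text.splitlines() if line.strip()]
    if len(lines) % 2:
        raise ValueError("expected an even number of non-empty lines")
    # even positions hold file names, odd positions the matching labels
    return list(zip(lines[::2], lines[1::2]))


sample = """a03-017-07.png
into the pro-communist north and the

a03-017-05.png
to 1958 kept the kingdom in peace, though
"""
pairs = parse_ground_truth(sample)
for name, label in pairs:
    print(name, "->", label)
```

Reading the file line by line like this keeps each transcription as a single string regardless of how its words are spaced.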


### TODO: Data augmentation

### Train test split

In [12]:
train, test = train_test_split(data, test_size = TRAIN_TEST_SPLIT)

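# reset indices to 0..n-1 after the random shuffle; drop = True discards
# the old index instead of keeping it as an extra "index" column
train.reset_index(drop = True, inplace = True)
test.reset_index(drop = True, inplace = True)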
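
As a side note, the split above is different on every run because no seed is given. A small sketch of a reproducible variant follows, using a toy DataFrame; the helper name `split`, the seed value, and the toy data are assumptions for illustration only.

```python
import pandas as pd
from sklearn.model_selection import train_test_split

# toy stand-in for the real (img_names, labels) table
toy = pd.DataFrame({"img_names": [f"line-{i}.png" for i in range(10)],
                    "labels": [f"text {i}" for i in range(10)]})


def split(frame, seed=42):
    # a fixed random_state makes the shuffle repeatable across runs
    train, test = train_test_split(frame, test_size=0.2, random_state=seed)
    # drop=True renumbers rows from 0 without adding an "index" column
    return train.reset_index(drop=True), test.reset_index(drop=True)


train_a, test_a = split(toy)
train_b, test_b = split(toy)
print(len(train_a), len(test_a), train_a.equals(train_b))
```

Fixing the seed keeps the train and test sets comparable between experiments.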

### Create custom pytorch dataset

In [13]:
class HandWritingDataset(Dataset):
    def __init__(self, data: pd.DataFrame, batch_size):
        self.data = data
        self.batch_size = batch_size

    # placeholder: split the original input image into patch tokens (not implemented yet)
    def __getTokens__(self, image):
        pass

    def __getitem__(self, index):
        # input image
        image_path = os.path.join(DATA_PATH, "img", self.data['img_names'][index])
        # torchvision read_image returns a uint8 tensor of shape (C, H, W)
        image = read_image(image_path)
        
        # string label
        label = self.data['labels'][index]

        return image, label

    def __len__(self):
        # return length of column
        return len(self.data)

In [14]:
train_set = HandWritingDataset(train, BATCH_SIZE)
test_set = HandWritingDataset(test, BATCH_SIZE)
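
Batching is normally done by a `DataLoader` rather than inside the `Dataset`. Since line images differ in size, the default collation cannot stack them, so the sketch below resizes every image to one fixed size in a custom `collate_fn`. It is an assumption-laden illustration: `ToyLines` is a synthetic stand-in for `HandWritingDataset` (random single-channel images, no files needed), and `IMAGE_SIZE` and `collate_lines` are hypothetical names and values.

```python
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

IMAGE_SIZE = (64, 256)  # hypothetical fixed (H, W)


class ToyLines(Dataset):
    """Synthetic stand-in: uint8 (C, H, W) images of varying width plus a label."""

    def __init__(self, widths):
        self.widths = widths

    def __getitem__(self, index):
        image = torch.randint(0, 256, (1, 48, self.widths[index]),
                              dtype=torch.uint8)
        return image, f"label {index}"

    def __len__(self):
        return len(self.widths)


def collate_lines(batch):
    """Resize each image to IMAGE_SIZE and stack; labels stay a list of str."""
    images, labels = zip(*batch)
    resized = [
        # scale to [0, 1] floats, add a batch dim for interpolate, remove it again
        F.interpolate(img.unsqueeze(0).float() / 255.0, size=IMAGE_SIZE,
                      mode="bilinear", align_corners=False).squeeze(0)
        for img in images
    ]
    return torch.stack(resized), list(labels)


loader = DataLoader(ToyLines([120, 300, 210, 95, 400]), batch_size=2,
                    shuffle=False, collate_fn=collate_lines)
images, labels = next(iter(loader))
print(tuple(images.shape), labels)
```

The same `collate_fn` could be passed along with `train_set` and `BATCH_SIZE`; plain resizing distorts the aspect ratio of long lines, so padding may be a better choice depending on the model.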