## Fine-tune TrOCR on the IAM Handwriting Database

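Download `char_dataset.zip` from [Google Drive](https://drive.google.com/file/d/1SHqYfdrIdHYza1AwBPy3Zqhcds73qZL-/view?usp=share_link) and place it in the root of your Google Drive (`MyDrive`). The cells below copy it from there into the Colab runtime and unzip it.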

In [None]:
%env CLEARML_WEB_HOST=https://app.clear.ml
%env CLEARML_API_HOST=https://api.clear.ml
%env CLEARML_FILES_HOST=https://files.clear.ml
# ClearML credentials for the Colab runtime: paste your own access/secret key pair here
%env CLEARML_API_ACCESS_KEY=
%env CLEARML_API_SECRET_KEY=

env: CLEARML_WEB_HOST=https://app.clear.ml
env: CLEARML_API_HOST=https://api.clear.ml
env: CLEARML_FILES_HOST=https://files.clear.ml
env: CLEARML_API_ACCESS_KEY=<redacted>
env: CLEARML_API_SECRET_KEY=<redacted>


In [None]:
!pip install -q transformers

In [None]:
!pip install -q datasets jiwer clearml

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!cp /content/drive/MyDrive/char_dataset.zip .

In [None]:
!unzip /content/char_dataset.zip

In [None]:
from clearml import Task, Logger

In [None]:
from datetime import datetime

In [None]:
now_str = datetime.today().strftime('%Y%m%d%H%M')

In [None]:
task = Task.init(project_name='ocr-handwritten', task_name='trocr-exp-'+now_str)

ClearML Task: created new task id=f2fb3eb021f8408b9da26fd91ac65c20
ClearML results page: https://app.clear.ml/projects/d60f182ac5104605afd7f1c45ff2a927/experiments/f2fb3eb021f8408b9da26fd91ac65c20/output/log
2023-02-08 00:21:32,217 - clearml.Task - INFO - Storing jupyter notebook directly as code


## Prepare data

We first load the data from the extracted `char_dataset.zip` archive. The code below expects it under `archive/`, which is where `DATA_DIR` points. The archive contains `english.csv`, which maps image paths to their labels, and an `augmented/` folder whose subfolders are named after the label of the images they contain.

Let's make a [regular PyTorch dataset](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html). We first create a Pandas dataframe with 2 columns: `image`, the path to an image file, and `label`, the corresponding text.

In [None]:
import os
import glob
import pandas as pd

DATA_DIR = "archive"

# base dataset: english.csv lists image paths (relative to DATA_DIR) and their labels
d_english = pd.read_csv(os.path.join(DATA_DIR, 'english.csv'))
d_english.image = d_english.image.apply(lambda x: os.path.join(DATA_DIR, x).replace('\\', '/'))

# augmented images are stored as archive/augmented/<label>/<file>,
# so the label is the name of the parent folder
augmented_paths = glob.glob(os.path.join(DATA_DIR, 'augmented', '*', '*'))
d_augmented = pd.DataFrame(augmented_paths, columns=['image'])
d_augmented.image = d_augmented.image.apply(lambda x: x.replace('\\', '/'))
d_augmented['label'] = d_augmented.image.apply(lambda x: x.split('/')[-2])

# one dataframe with an `image` (path) and a `label` (text) column
d_dataset = pd.concat((d_english, d_augmented), axis=0, ignore_index=True)
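
The augmented labels come entirely from the folder layout, so it is worth checking that convention on a couple of paths. The sketch below uses hypothetical file names that mimic `archive/augmented/<label>/<file>`, including a Windows-style path. It also shows pandas' vectorized `.str` accessors as an alternative to `apply(lambda ...)`; this is an illustrative variant, not the notebook's own code.

```python
import os
import pandas as pd

# hypothetical paths mimicking archive/augmented/<label>/<file>
paths = [
    os.path.join("archive", "augmented", "A", "img_0001.png"),
    "archive\\augmented\\b\\img_0002.png",  # Windows-style separators
]
df = pd.DataFrame(paths, columns=["image"])

# normalize separators, then take the parent folder name as the label
df["image"] = df["image"].str.replace("\\", "/", regex=False)
df["label"] = df["image"].str.split("/").str[-2]

print(df["label"].tolist())  # ['A', 'b']
```

If a label folder name does not match the text the model should output, the label column would need an explicit mapping instead.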

In [None]:
# optional: uncomment to keep at most 3 images per label (quick smoke test)
# d_dataset = d_dataset.groupby('label').head(3).copy()

In [None]:
from sklearn.model_selection import train_test_split

train_df, test_df = train_test_split(d_dataset, test_size=0.2, random_state=42)
# we reset the indices to start from zero
train_df.reset_index(drop=True, inplace=True)
test_df.reset_index(drop=True, inplace=True)

In [None]:
import torch
from torch.utils.data import Dataset
from PIL import Image

In [None]:
class CharDataset(Dataset):
    def __init__(self, df, processor, max_target_length=128):
        self.df = df
        self.processor = processor
        self.max_target_length = max_target_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # get file name + text
        file_name = self.df.loc[idx, 'image']
        text = str(self.df.loc[idx, 'label'])

        # downscale the raw image 10x (max(1, ...) guards against very small images);
        # the processor then resizes to the model's input size and normalizes
        image = Image.open(file_name).convert("RGB")
        w, h = image.size
        image = image.resize((max(1, w // 10), max(1, h // 10)))
        pixel_values = self.processor(image, return_tensors="pt").pixel_values

        # add labels (input_ids) by encoding the text
        labels = self.processor.tokenizer(text, padding="max_length", truncation=True,
                                          max_length=self.max_target_length).input_ids

        # important: make sure that PAD tokens are ignored by the loss function
        labels = [label if label != self.processor.tokenizer.pad_token_id else -100 for label in labels]

        encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
        return encoding
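
Replacing padding ids with `-100` works because Hugging Face seq2seq models typically compute the loss with PyTorch's cross-entropy, whose `ignore_index` defaults to `-100`. The small sketch below uses random logits and made-up label ids to show that masked positions contribute nothing. The loss over all four positions equals the loss over the two real ones.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)                 # 4 target positions, vocabulary of 10
labels = torch.tensor([3, 7, -100, -100])   # last two positions are padding

# ignore_index defaults to -100, and the mean is taken over non-ignored targets only
loss_masked = F.cross_entropy(logits, labels)
loss_real_only = F.cross_entropy(logits[:2], labels[:2])

print(torch.allclose(loss_masked, loss_real_only))  # True
```
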

Let's initialize the training and evaluation datasets:

In [None]:
from transformers import TrOCRProcessor

processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
train_dataset = CharDataset(df=train_df, processor=processor)
eval_dataset = CharDataset(df=test_df, processor=processor)

Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.


In [None]:
print("Number of training examples:", len(train_dataset))
print("Number of validation examples:", len(eval_dataset))

Number of training examples: 6732
Number of validation examples: 1683


Let's verify an example from the training dataset:

In [None]:
encoding = train_dataset[0]
for k,v in encoding.items():
    print(k, v.shape)

pixel_values torch.Size([3, 384, 384])
labels torch.Size([128])


We can also decode the labels back to text. To do so, we first map the `-100` values back to the padding token id:

In [None]:
labels = encoding['labels']
labels[labels == -100] = processor.tokenizer.pad_token_id
label_str = processor.decode(labels, skip_special_tokens=True)
print(label_str)

75


Let's create corresponding dataloaders:

In [None]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_dataset, batch_size=3, shuffle=True)
eval_dataloader = DataLoader(eval_dataset, batch_size=3)
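
The default `collate_fn` of `DataLoader` stacks the per-example dictionaries returned by `__getitem__` into one dictionary of batched tensors. The sketch below shows this with `ToyDataset`, a hypothetical stand-in for `CharDataset` that has the same keys and shapes but constant content.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Hypothetical stand-in for CharDataset: same dict keys and shapes, constant content."""

    def __len__(self):
        return 7

    def __getitem__(self, idx):
        return {
            "pixel_values": torch.zeros(3, 384, 384),
            "labels": torch.full((128,), -100, dtype=torch.long),
        }


loader = DataLoader(ToyDataset(), batch_size=3)
shapes = [{k: tuple(v.shape) for k, v in batch.items()} for batch in loader]
print(len(shapes), shapes[0], shapes[-1])
```

With 7 examples and `batch_size=3`, the last batch holds a single example unless `drop_last=True` is passed. Here 6732 / 3 = 2244 and 1683 / 3 = 561, so every batch is full. These counts match the progress bars in the training output.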

## Train a model

Here, we initialize the TrOCR model from the pre-trained `microsoft/trocr-base-stage1` checkpoint. Note that the weights of the language modeling head are already initialized, because the model was already trained to generate text during its pre-training stage. Refer to the TrOCR paper for details.

In [None]:
from transformers import VisionEncoderDecoderModel
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-stage1")
model.to(device)

2023-02-08 00:26:41,756 - clearml.model - INFO - Selected model id: 5c7bd38ab2f0406fbb9a9e89ce98e66b


Some weights of VisionEncoderDecoderModel were not initialized from the model checkpoint at microsoft/trocr-base-stage1 and are newly initialized: ['encoder.pooler.dense.weight', 'encoder.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



Importantly, we need to set a couple of attributes, namely:
* the attributes required for creating the `decoder_input_ids` from the `labels` (the model will automatically create the `decoder_input_ids` by shifting the `labels` one position to the right, prepending the `decoder_start_token_id`, and replacing ids that are -100 with the `pad_token_id`)
* the vocabulary size of the model (for the language modeling head on top of the decoder)
* beam-search-related parameters, which are used when generating text.
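
To make the first bullet concrete, here is a minimal, illustrative re-implementation of that shift in plain PyTorch. `shift_right` is a hypothetical helper written for this example and is not the Transformers function. The token ids are made up.

```python
import torch


def shift_right(labels, pad_token_id, decoder_start_token_id):
    """Illustrative sketch: derive decoder inputs from labels (not the library function)."""
    shifted = labels.new_zeros(labels.shape)
    shifted[:, 1:] = labels[:, :-1].clone()    # move every token one position to the right
    shifted[:, 0] = decoder_start_token_id     # the decoder starts from the start token
    # -100 marks positions ignored by the loss; they cannot be fed in as token ids
    shifted.masked_fill_(shifted == -100, pad_token_id)
    return shifted


labels = torch.tensor([[0, 5, 6, 2, -100, -100]])
print(shift_right(labels, pad_token_id=1, decoder_start_token_id=0))
# tensor([[0, 0, 5, 6, 2, 1]])
```

At each position the decoder sees the previous target token and is trained to predict the current one. This is why the config needs `decoder_start_token_id` and `pad_token_id`.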

In [None]:
# set special tokens used for creating the decoder_input_ids from the labels
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
# make sure vocab size is set correctly
model.config.vocab_size = model.config.decoder.vocab_size

# set beam search parameters
model.config.eos_token_id = processor.tokenizer.sep_token_id
model.config.max_length = 64
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
model.config.length_penalty = 2.0
model.config.num_beams = 4

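We will evaluate the model with the Character Error Rate (CER), which is available as a Hugging Face metric (see [here](https://huggingface.co/metrics/cer)). As the warning below shows, `datasets.load_metric` is deprecated. With newer library versions, load the same metric through 🤗 Evaluate with `evaluate.load("cer")` instead.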
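
CER is the character-level edit distance between prediction and reference divided by the number of reference characters: CER = (S + D + I) / N, with S substitutions, D deletions, I insertions and N reference characters. The sketch below is a plain-Python illustration of that definition, assuming corpus-level aggregation (total edits over total reference characters). It is not the code behind the Hugging Face metric, which relies on the `jiwer` package installed above.

```python
def levenshtein(ref: str, hyp: str) -> int:
    """Minimum number of character substitutions, deletions and insertions turning ref into hyp."""
    prev = list(range(len(hyp) + 1))  # distances from the empty reference prefix
    for i, r in enumerate(ref, start=1):
        curr = [i]
        for j, h in enumerate(hyp, start=1):
            curr.append(min(
                prev[j] + 1,              # deletion of r
                curr[j - 1] + 1,          # insertion of h
                prev[j - 1] + (r != h),   # substitution (free if the characters match)
            ))
        prev = curr
    return prev[-1]


def char_error_rate(refs, hyps):
    """Total edit distance divided by the total number of reference characters."""
    edits = sum(levenshtein(r, h) for r, h in zip(refs, hyps))
    total = sum(len(r) for r in refs)
    return edits / total


print(char_error_rate(["kitten"], ["sitting"]))  # 0.5 (3 edits / 6 characters)
```
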

In [None]:
from datasets import load_metric

cer_metric = load_metric("cer")


load_metric is deprecated and will be removed in the next major version of datasets. Use 'evaluate.load' instead, from the new library 🤗 Evaluate: https://huggingface.co/docs/evaluate



In [None]:
def compute_cer(pred_ids, label_ids):
    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
    label_str = processor.batch_decode(label_ids, skip_special_tokens=True)

    cer = cer_metric.compute(predictions=pred_str, references=label_str)

    return cer

In [None]:
import torch
from tqdm.notebook import tqdm

# transformers.AdamW is deprecated in favor of the PyTorch implementation
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
num_epochs = 1
logger = Logger.current_logger()

for epoch in range(num_epochs):  # loop over the dataset multiple times
    # train
    model.train()
    train_loss = 0.0
    for batch_idx, batch in enumerate(tqdm(train_dataloader), start=1):
        # move the inputs to the GPU (if available)
        batch = {k: v.to(device) for k, v in batch.items()}

        # forward + backward + optimize
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # running mean of the loss over the epoch
        train_loss += (loss.item() - train_loss) / batch_idx

        # global step, so that later epochs do not overwrite earlier points in the plot
        step = epoch * len(train_dataloader) + batch_idx
        logger.report_scalar("train", "loss", iteration=step, value=train_loss)

    print(f"Loss after epoch {epoch}:", train_loss)

    # evaluate
    model.eval()
    valid_cer = 0.0
    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(eval_dataloader), start=1):
            # run batch generation
            outputs = model.generate(batch["pixel_values"].to(device))

            # compute metrics (running mean of the per-batch CER)
            cer = compute_cer(pred_ids=outputs, label_ids=batch["labels"])
            valid_cer += (cer - valid_cer) / batch_idx

            step = epoch * len(eval_dataloader) + batch_idx
            logger.report_scalar("validation", "valid cer", iteration=step, value=valid_cer)

    print("Validation CER:", valid_cer)

model.save_pretrained("./trocr-finetuned-augmented")
task.upload_artifact(name='trocr-finetuned-augmented', artifact_object='/content/trocr-finetuned-augmented')





  0%|          | 0/2244 [00:00<?, ?it/s]

Loss after epoch 0: 0.9007217773433915


  0%|          | 0/561 [00:00<?, ?it/s]


Neither `max_length` nor `max_new_tokens` has been set, `max_length` will default to 64 (`generation_config.max_length`). Controlling `max_length` via the config is deprecated and `max_length` will be removed from the config in v5 of Transformers -- we recommend using `max_new_tokens` to control the maximum length of the generation.



Validation CER: 0.11304218657159841
2023-02-08 01:03:05,383 - clearml.frameworks - INFO - Found existing registered model id=6b7dcf70598842f18854ab6cb807f562 [/content/trocr-finetuned-augmented/pytorch_model.bin] reusing it.


True
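
The loop tracks `train_loss` and `valid_cer` with an incremental running mean, m_k = m_{k-1} + (x_k - m_{k-1}) / k. This equals the plain average of all values seen so far without storing them. A quick check on made-up values:

```python
import math

values = [0.9, 1.4, 0.3, 0.75]  # made-up per-batch losses
running = 0.0
for k, x in enumerate(values, start=1):
    running += (x - running) / k  # same update as train_loss / valid_cer in the loop

print(math.isclose(running, sum(values) / len(values)))  # True
```

Because every evaluation batch here has the same size (3), the average of per-batch CERs weights all batches equally. It can still differ slightly from a CER computed over the whole validation set at once, since that version weights by reference length.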

In [None]:
!cp -r trocr-finetuned-augmented/ /content/drive/MyDrive

## Inference

Note that after training, you can load the model with the `.from_pretrained(output_dir)` method, e.g. `VisionEncoderDecoderModel.from_pretrained("./trocr-finetuned-augmented")`.

For inference on new images, see the inference notebook in Niels Rogge's [Transformers Tutorials repository](https://github.com/NielsRogge/Transformers-Tutorials) on GitHub.
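
A single image can be transcribed by repeating the `CharDataset` preprocessing (including its 10x downscale) and calling `generate`. The sketch below assumes the model was loaded with `VisionEncoderDecoderModel.from_pretrained("./trocr-finetuned-augmented")` and the processor with `TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")`, since the training cell saves only the model. `recognize` is a hypothetical helper, not part of the original notebook.

```python
import torch
from PIL import Image


def recognize(image_path, model, processor):
    """Hypothetical helper: transcribe one image with a fine-tuned TrOCR model."""
    image = Image.open(image_path).convert("RGB")
    # mirror the preprocessing used in CharDataset during training
    w, h = image.size
    image = image.resize((max(1, w // 10), max(1, h // 10)))
    pixel_values = processor(image, return_tensors="pt").pixel_values.to(model.device)

    model.eval()
    with torch.no_grad():
        generated_ids = model.generate(pixel_values)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
```

For example, `recognize("path/to/image.png", model, processor)` returns the decoded string; the path is a placeholder. Keeping the preprocessing identical to training matters, because the model only saw images that went through the same downscale before the processor's resize.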