# Fine-tuning DTrOCR on IAM dataset
This is an example of fine-tuning DTrOCR on IAM dataset handwritten words from [Kaggle](https://www.kaggle.com/datasets/teykaicong/iamondb-handwriting-dataset). IAM Aachen splits can be downloaded [here](https://www.openslr.org/56/).

# Dataset folder structure
```
iam_words/
│
├── words/                              # Folder containing word images as PNGs
│   ├── a01/                            # First folder
│   │   ├── a01-000u/
│   │   │   ├── a01-000u-00-00.png
│   │   │   └── a01-000u-00-01.png
│   │   .
│   │   .
│   │   .
│   └── r06/                            # Last folder
│       ├── r06-000/
│       │   ├── r06-000-00-00.png
│       │   └── r06-000-00-01.png
│
├── xml/                                # XML files
│	├── a01-000u.xml
│	.
│	.
│	.
│	└── r06-143.xml
│
└── splits/                             # IAM Aachen splits
    ├── train.uttlist
    ├── validation.uttlist
    └── test.uttlist
```

# Build lists of images and texts

In [1]:
#!pip install triton==2.0.0 # Install Triton explicitly with a version compatible with PyTorch Inductor.

import torch
torch.set_float32_matmul_precision('high')

In [2]:
import glob
from pathlib import Path

dataset_path = Path('iam_words')

xml_files = sorted(glob.glob(str(dataset_path / 'xml' / '*.xml')))
word_image_files = sorted(glob.glob(str(dataset_path / 'words' / '**' / '*.png'), recursive=True))

print(f"{len(xml_files)} XML files and {len(word_image_files)} word image files")

1539 XML files and 115320 word image files


In [3]:
import re
from pathlib import Path
from PIL import Image

# Define the Word class
class Word:
    def __init__(self, word_id, file_path, writer_id, transcription):
        self.id:str = word_id
        self.file_path:Path = file_path
        self.writer_id:str = writer_id
        self.transcription:str = transcription

    def __repr__(self):
        return (f"Word(id='{self.id}', file_path=PosixPath('{self.file_path}'), "
                f"writer_id='{self.writer_id}', transcription='{self.transcription}')")
pattern = r"Word\(id='([^']+)',\s*file_path=PosixPath\('([^']+)'\),\s*writer_id='([^']+)',\s*transcription='([^']+)'\)"

def extract_word_obj(line:str):
  match = re.search(pattern, line)

# Check if the match was successful
  if match:
    word_id = match.group(1)
    file_path = Path(match.group(2))  # Convert the string to a Path object
    writer_id = match.group(3)
    transcription = match.group(4)

    # Create the Word object
    word_object = Word(word_id, file_path, writer_id, transcription)
    return word_object
words = []
# Example: Reading the file and loading each line as a Word object
file_path = 'words_local.txt'  # Replace with your file path

with open(file_path, 'r') as file:
    # Read the lines of the file and convert each to a Word instance
    for line in file.readlines():
      word = extract_word_obj(line)
      words.append(word)
# Print the loaded Word instan

# Train test split

In [4]:
with open('iam_words/splits/train.uttlist') as fp:
    train_ids = [line.replace('\n', '') for line in fp.readlines()]

with open('iam_words/splits/test.uttlist') as fp:
    test_ids = [line.replace('\n', '') for line in fp.readlines()]

with open('iam_words/splits/validation.uttlist') as fp:
    validation_ids = [line.replace('\n', '') for line in fp.readlines()]

print(f"Train size: {len(train_ids)}; Validation size: {len(validation_ids)}; Test size: {len(test_ids)}")

Train size: 747; Validation size: 116; Test size: 336


In [6]:
i = 0
for word in words:
  if words[i] == None:
    words.pop(i)
  i += 1
print(len(words))

113332


In [7]:
train_word_records = [word for word in words if word.id in train_ids]
validation_word_records = [word for word in words if word.id in validation_ids]
test_word_records = [word for word in words if word.id in test_ids]

print(f'Train size: {len(train_word_records)}; Validation size: {len(validation_word_records)}; Test size: {len(test_word_records)}')

Train size: 54315; Validation size: 8727; Test size: 25491


# Build dataset and dataloader

In [8]:
from dtrocr.processor import DTrOCRProcessor
from dtrocr.config import DTrOCRConfig
from torch.utils.data import Dataset

class IAMDataset(Dataset):
    def __init__(self, words: list[Word], config: DTrOCRConfig):
        super(IAMDataset, self).__init__()
        self.words = words
        self.processor = DTrOCRProcessor(config, add_eos_token=True, add_bos_token=True)

    def __len__(self):
        return len(self.words)

    def __getitem__(self, item):
        inputs = self.processor(
            images=Image.open(self.words[item].file_path).convert('RGB'),
            texts=self.words[item].transcription,
            padding='max_length',
            return_tensors="pt",
            return_labels=True,
        )
        return {
            'pixel_values': inputs.pixel_values[0],
            'input_ids': inputs.input_ids[0],
            'attention_mask': inputs.attention_mask[0],
            'labels': inputs.labels[0]
        }

config = DTrOCRConfig(
    #use_rnnt_loss=True
)

train_data = IAMDataset(words=train_word_records, config=config)
validation_data = IAMDataset(words=validation_word_records, config=config)
test_data = IAMDataset(words=test_word_records, config=config)

preprocessor_config.json:   0%|          | 0.00/806 [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


processing_florence2.py:   0%|          | 0.00/48.7k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/microsoft/Florence-2-large:
- processing_florence2.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


tokenizer_config.json:   0%|          | 0.00/34.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.10M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/2.44k [00:00<?, ?B/s]

configuration_florence2.py:   0%|          | 0.00/15.1k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/microsoft/Florence-2-large:
- configuration_florence2.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


In [16]:
import tqdm
import multiprocessing as mp
import xml.etree.ElementTree as ET

from PIL import Image
from dataclasses import dataclass
from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_data, batch_size=32, shuffle=True, num_workers=mp.cpu_count())
validation_dataloader = DataLoader(validation_data, batch_size=32, shuffle=False, num_workers=mp.cpu_count())
test_dataloader = DataLoader(test_data, batch_size=32, shuffle=False, num_workers=mp.cpu_count())

# Model

In [18]:
import torch
torch.set_float32_matmul_precision('high')

from dtrocr.model import DTrOCRLMHeadModel

model = DTrOCRLMHeadModel(config)
model = torch.compile(model)
model.to(device=0)

OptimizedModule(
  (_orig_mod): DTrOCRLMHeadModel(
    (transformer): DTrOCRModel(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (token_embedding): Embedding(50257, 768)
      (positional_embedding): Embedding(256, 768)
      (hidden_layers): ModuleList(
        (0-23): 24 x GPT2Block(
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (attn): GPT2Attention(
            (c_attn): Conv1D(nf=2304, nx=768)
            (c_proj): Conv1D(nf=768, nx=768)
            (attn_dropout): Dropout(p=0.1, inplace=False)
            (resid_dropout): Dropout(p=0.1, inplace=False)
          )
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): GPT2MLP(
            (c_fc): Conv1D(nf=3072, nx=768)
            (c_proj): Conv1D(nf=768, nx=3072)
            (act): NewGELUActivation()
            (dropout): Dropout(p=0.1, inplace=False)
          )
      

# Training

In [19]:
import torch._dynamo
torch._dynamo.config.suppress_errors = True

In [20]:
import torch
print(torch.cuda.is_available())  # Should print True
print(torch.version.cuda)         # Should match your CUDA toolkit
print(torch.cuda.current_device()) # Should print 0 (or your GPU index)
print(torch.cuda.get_device_name(0)) # Should print "NVIDIA T1200 Laptop GPU"

True
12.4
0
NVIDIA T1200 Laptop GPU


In [None]:
from typing import Tuple

def evaluate_model(model: torch.nn.Module, dataloader: DataLoader) -> Tuple[float, float]:
    # set model to evaluation mode
    model.eval()

    losses, accuracies = [], []
    with torch.no_grad():
        for inputs in tqdm.tqdm(dataloader, total=len(dataloader), desc=f'Evaluating test set'):
            inputs = send_inputs_to_device(inputs, device=0)
            outputs = model(**inputs)

            losses.append(outputs.loss.item())
            accuracies.append(outputs.accuracy.item())

    loss = sum(losses) / len(losses)
    accuracy = sum(accuracies) / len(accuracies)

    # set model back to training mode
    model.train()

    return loss, accuracy

def send_inputs_to_device(dictionary, device):
    return {key: value.to(device=device) if isinstance(value, torch.Tensor) else value for key, value in dictionary.items()}

use_amp = True
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
optimiser = torch.optim.Adam(params=model.parameters(), lr=1e-4)

EPOCHS = 8
train_losses, train_accuracies = [], []
validation_losses, validation_accuracies = [], []
for epoch in range(EPOCHS):
    epoch_losses, epoch_accuracies = [], []
    for inputs in tqdm.tqdm(train_dataloader, total=len(train_dataloader), desc=f'Epoch {epoch + 1}'):

        # set gradients to zero
        optimiser.zero_grad()

        # send inputs to same device as model
        inputs = send_inputs_to_device(inputs, device=0)

        # forward pass
        with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=use_amp):
            outputs = model(**inputs)

        # calculate gradients
        scaler.scale(outputs.loss).backward()

        # update weights
        scaler.step(optimiser)
        scaler.update()

        epoch_losses.append(outputs.loss.item())
        epoch_accuracies.append(outputs.accuracy.item())

    # store loss and metrics
    train_losses.append(sum(epoch_losses) / len(epoch_losses))
    train_accuracies.append(sum(epoch_accuracies) / len(epoch_accuracies))

    # tests loss and accuracy
    validation_loss, validation_accuracy = evaluate_model(model, validation_dataloader)
    validation_losses.append(validation_loss)
    validation_accuracies.append(validation_accuracy)

    print(f"Epoch: {epoch + 1} - Train loss: {train_losses[-1]}, Train accuracy: {train_accuracies[-1]}, Validation loss: {validation_losses[-1]}, Validation accuracy: {validation_accuracies[-1]}")

  scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
Epoch 1:   0%|          | 0/1698 [00:00<?, ?it/s]

# Test

In [None]:
torch.save(model, 'full_model.pth')

In [None]:
from DTrOCR.dtrocr.model import DTrOCRLMHeadModel
from DTrOCR.dtrocr.config import DTrOCRConfig
from DTrOCR.dtrocr.processor import DTrOCRProcessor

# model = DTrOCRLMHeadModel(DTrOCRConfig())
model.eval()
model.to('cpu')
test_processor = DTrOCRProcessor(DTrOCRConfig())

In [None]:
from PIL import Image

import numpy as np
import matplotlib.pyplot as plt

from PIL import Image

for test_word_record in test_word_records[:50]:
    image_file = test_word_record.file_path
    image = Image.open(image_file).convert('RGB')

    inputs = test_processor(
        images=image,
        texts=test_processor.tokeniser.bos_token,
        return_tensors='pt'
    )

    model_output = model.generate(
        inputs,
        test_processor,
        num_beams=3
    )

    predicted_text = test_processor.tokeniser.decode(model_output[0], skip_special_tokens=True)

    plt.figure(figsize=(10, 5))
    plt.title(predicted_text, fontsize=24)
    plt.imshow(np.array(image, dtype=np.uint8))
    plt.xticks([]), plt.yticks([])
    plt.show()