## Fine-tune TrOCR on a Custom Dataset

This notebook demonstrates how to fine-tune the HuggingFace version of TrOCR using a custom dataset of handwritten text images and transcriptions. It includes optional support for freezing encoder or decoder layers during training.

## Set-up environment
First, let's install all required libraries:

In [None]:
# Install dependencies into the running kernel's environment.
# `%pip` (unlike `!pip`) always targets the interpreter backing this kernel.
%pip install -q transformers
%pip install -q sentencepiece
%pip install -q jiwer
%pip install -q datasets
%pip install -q evaluate
%pip install -q -U accelerate

%pip install -q matplotlib
%pip install -q protobuf==3.20.1
%pip install -q tensorboard

## Upload and extract your dataset (image folder) in Colab
Upload `cropped.zip` to `/content` before running the next cell.

In [None]:
# -o: overwrite existing files without prompting, so re-running this cell cannot block
# on unzip's interactive "replace?" question; -q: keep the output short.
!unzip -q -o "/content/cropped.zip"

## Imports

In [None]:
import os
import random
import glob as glob
from dataclasses import dataclass
from urllib.request import urlretrieve
from zipfile import ZipFile

import evaluate
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.optim as optim
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset
from tqdm.notebook import tqdm
from transformers import (
    VisionEncoderDecoderModel,
    TrOCRProcessor,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    default_data_collator
)

# Default size for every matplotlib figure in this notebook.
plt.rcParams['figure.figsize'] = (12, 9)
# NOTE(review): appears unused in the visible notebook.
block_plot = False

In [None]:
# ANSI escape sequences: switch terminal text to bold, and reset all styling.
bold, reset = "\033[1m", "\033[0m"

In [None]:
def seed_everything(seed_value):
    """Seed every random number generator the notebook relies on.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and all
    CUDA devices), and configures cuDNN for deterministic behaviour.

    Args:
        seed_value: Integer seed shared by all generators.
    """
    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    torch.cuda.manual_seed_all(seed_value)
    # Ask cuDNN for deterministic kernels and turn off its autotuner, which may pick
    # different algorithms from run to run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything(42)

In [None]:
# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Annotation table: one row per image, with its transcription and metadata.
df = pd.read_csv(filepath_or_buffer="/content/final_data.csv")

In [None]:
# Keep only rows whose `total` lies in the closed interval [30, 50].
df = df[df["total"].between(30, 50)]

### Shuffle the data and split it into train, test, and validation sets

In [None]:
# Shuffle all rows with a fixed random_state so that re-running this cell (or the
# notebook) always yields the same order, and therefore the same split below.
shuffled_df = df.sample(frac=1, random_state=42).reset_index(drop=True)
# Keep only the first five columns.
shuffled_df = shuffled_df.iloc[:, :5]
shuffled_df.sample(4)

In [None]:
# (rows, columns) left after the `total` filter and column selection.
tuple(shuffled_df.shape)

In [None]:
# 80 % train / 10 % test / 10 % validation, by position in the shuffled frame.
n_rows = len(shuffled_df)
part1 = int(n_rows * 0.8)  # end of the training slice
part2 = int(n_rows * 0.9)  # end of the test slice; the remainder is validation
train_df = shuffled_df.iloc[:part1]
test_df = shuffled_df.iloc[part1:part2]
valid_df = shuffled_df.iloc[part2:]
assert len(train_df) + len(test_df) + len(valid_df) == n_rows
train_df.shape, test_df.shape, valid_df.shape

## Training and Dataset Configurations

In [None]:
@dataclass(frozen=True)
class TrainingConfig:
    """Hyper-parameters used by the optimizer and the trainer."""
    BATCH_SIZE:    int = 10
    EPOCHS:        int = 10
    LEARNING_RATE: float = 0.00005

@dataclass(frozen=True)
class DatasetConfig:
    """Location of the cropped text-line images."""
    # NOTE: assumes cropped.zip extracts to a `cropped/` folder under /content (the
    # Colab working directory); adjust if your archive layout differs. Keep the
    # trailing slash in case image paths are built by plain string concatenation.
    DATA_ROOT:     str = '/content/cropped/'

@dataclass(frozen=True)
class ModelConfig:
    """Pretrained checkpoint to fine-tune (Hugging Face Hub id)."""
    MODEL_NAME: str = 'microsoft/trocr-small-handwritten'

In [None]:
# First five training rows.
train_df.iloc[:5]

In [None]:
# First five test rows.
test_df.iloc[:5]

In [None]:
# Photometric augmentations for training images: random brightness/hue jitter,
# followed by a random Gaussian blur.
train_transforms = transforms.Compose(transforms=[
    transforms.ColorJitter(brightness=0.5, hue=0.3),
    transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5)),
])

In [None]:
class CustomOCRDataset(Dataset):
    def __init__(self, root_dir, df, processor, max_target_length=128):
        self.root_dir = root_dir
        self.df = df
        self.processor = processor
        self.max_target_length = max_target_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # The image file name.
        file_name = self.df['imscanno'].iloc[idx]
        text = self.df['words'].iloc[idx]
       # Read the image, apply augmentations, and get the transformed pixels.
        image = Image.open(self.root_dir+ str(file_name) + "_cropped_0.jpg").convert('RGB')
        image = train_transforms(image)
        pixel_values = self.processor(image, return_tensors='pt').pixel_values
        # Pass the text through the tokenizer and get the labels,
        labels = self.processor.tokenizer(
            text,
            padding='max_length',
            max_length=self.max_target_length
        ).input_ids
        # We are using -100 as the padding token.
        labels = [label if label != self.processor.tokenizer.pad_token_id else -100 for label in labels]
        encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
        # print(encoding)
        return encoding

In [None]:
processor = TrOCRProcessor.from_pretrained(ModelConfig.MODEL_NAME)
train_dataset = CustomOCRDataset(
    root_dir=os.path.join(DatasetConfig.DATA_ROOT),
    df=train_df,
    processor=processor
)
valid_dataset = CustomOCRDataset(
    root_dir=os.path.join(DatasetConfig.DATA_ROOT),
    df=test_df,
    processor=processor
)

In [None]:
print(f"Number of training examples: {len(train_dataset)}")
print(f"Number of validation examples: {len(valid_dataset)}")

In [None]:
# Inspect the tensor shapes of the first training sample.
encoding = train_dataset[0]
for name, tensor in encoding.items():
    print(name, tensor.shape)

## Dataset Preview

In [None]:
# Show the first training image after the (random) training augmentation.
# `.iloc[0]` is positional, so this does not depend on the frame's index labels.
preview_path = os.path.join(train_dataset.root_dir, f"{train_df['imscanno'].iloc[0]}_cropped_0.jpg")
with Image.open(preview_path) as img:
    image = img.convert("RGB")
image = train_transforms(image)
plt.imshow(image)
plt.axis('off')
plt.show()

In [None]:
# Decode the first sample's labels back to text as a sanity check. Work on a copy
# so `encoding['labels']` keeps its -100 markers.
labels = encoding['labels'].clone()
labels[labels == -100] = processor.tokenizer.pad_token_id
label_str = processor.decode(labels, skip_special_tokens=True)
print(label_str)

## TrOCR Architecture Summary
1. Encoder: A Vision Transformer (ViT) processes the image and converts it into visual embeddings.

2. Decoder: A Transformer text decoder (initialized from a pretrained language model such as RoBERTa or MiniLM in the released TrOCR checkpoints) generates text token by token, attending to the encoder output.

## Initialize the Model

## Strategy 1: Train All Parameters (Encoder + Decoder)
#### When to Use:
1. Your dataset is very different from the original (e.g., new language, handwriting, new domain).

2. You want maximum model flexibility.

3. You have enough compute resources.

In [None]:
# Load the pretrained TrOCR checkpoint onto the selected device; nothing is frozen here.
model = VisionEncoderDecoderModel.from_pretrained(ModelConfig.MODEL_NAME).to(device)
print(model)
# Compare overall model size with the part that will receive gradients.
total_params = sum(param.numel() for param in model.parameters())
total_trainable_params = sum(param.numel() for param in model.parameters() if param.requires_grad)
print(f"{total_params:,} total parameters.")
print(f"{total_trainable_params:,} training parameters.")

## Strategy 2: Freeze Encoder, Train Only Decoder
#### When to Use:
1. You’re dealing with similar image types (e.g., printed English text) but different output formats. Example: converting scanned English forms to structured JSON.

2. You want the model to learn new language tasks.

In [None]:
# # Strategy 2: keep the vision encoder fixed and fine-tune only the text decoder.
# # Uncommenting this cell reloads the model, replacing the one created for Strategy 1.
# # Load the pretrained model
# model = VisionEncoderDecoderModel.from_pretrained(ModelConfig.MODEL_NAME)
# # model.to(device)

# # Freeze the encoder parameters to train only the decoder
# for param in model.encoder.parameters():
#     param.requires_grad = False

# # Print the model to verify which parameters are frozen
# print(model)

# # Calculate and display total parameters and trainable parameters
# total_params = sum(p.numel() for p in model.parameters())
# print(f"{total_params:,} total parameters.")

# total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
# print(f"{total_trainable_params:,} training parameters (decoder only).")

## Strategy 3: Freeze Decoder, Train Only Encoder
#### When to Use:
1. You’re using new types of images (e.g., handwritten text, new fonts), but the language remains similar. Example: printed-to-handwritten transfer within the same language.

2. You want to preserve language modeling quality.

In [None]:
# # Strategy 3: keep the text decoder fixed and fine-tune only the vision encoder.
# # Uncommenting this cell reloads the model, replacing the one created for Strategy 1.
# # Load the pretrained model
# model = VisionEncoderDecoderModel.from_pretrained(ModelConfig.MODEL_NAME)
# # model.to(device)  # Uncomment and move the model to GPU or CPU as needed

# # Freeze decoder parameters (i.e., do NOT train decoder).
# # `model.decoder` is the whole text decoder, so this loop also freezes its output
# # layer and the cross-attention blocks inside every decoder layer.
# for param in model.decoder.parameters():
#     param.requires_grad = False

# # NOTE: VisionEncoderDecoderModel has no top-level `lm_head`, and cross-attention
# # lives in each decoder layer rather than in `model.decoder.encoder_attn`, so
# # `hasattr` checks for those names never match. Both are already frozen above.
# # (Optional) To let cross-attention adapt to the new encoder features instead,
# # unfreeze it layer by layer:
# # for layer in model.decoder.model.decoder.layers:
# #     for param in layer.encoder_attn.parameters():
# #         param.requires_grad = True

# # Print model structure (to verify frozen parts)
# print(model)

# # Count total parameters
# total_params = sum(p.numel() for p in model.parameters())
# print(f"{total_params:,} total parameters.")

# # Count only trainable parameters (encoder part)
# total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
# print(f"{total_trainable_params:,} training parameters (encoder only).")

## Strategy 4: Train Only Last N Layers of Both
#### When to Use:
1. You want to fine-tune efficiently with fewer trainable parameters.

2. You’re adapting the model slightly (e.g., slight font/layout change or domain shift).

3. You want to reduce overfitting or train on limited hardware.

In [None]:
# # Strategy 4: freeze everything, then re-enable gradients for the last N transformer
# # blocks of both the encoder and the decoder. All other parameters stay frozen.
# # Uncommenting this cell reloads the model, replacing the one created for Strategy 1.
# # Load pretrained TrOCR model
# model = VisionEncoderDecoderModel.from_pretrained(ModelConfig.MODEL_NAME)

# # Freeze all parameters
# for param in model.parameters():
#     param.requires_grad = False

# # Unfreeze last N layers of encoder
# N = 4  # Number of last layers to train

# # Check and unfreeze encoder layers
# # (the hasattr guards print a warning instead of raising if the layout differs)
# if hasattr(model.encoder, "encoder") and hasattr(model.encoder.encoder, "layer"):
#     encoder_layers = model.encoder.encoder.layer
#     for layer in encoder_layers[-N:]:
#         for param in layer.parameters():
#             param.requires_grad = True
# else:
#     print("Warning: Could not find encoder layers. Check encoder structure.")

# # Unfreeze last N layers of decoder
# if hasattr(model.decoder, "model") and hasattr(model.decoder.model, "decoder"):
#     decoder_layers = model.decoder.model.decoder.layers
#     for layer in decoder_layers[-N:]:
#         for param in layer.parameters():
#             param.requires_grad = True
# else:
#     print("Warning: Could not find decoder layers. Check decoder structure.")

# # Print parameter summary
# total_params = sum(p.numel() for p in model.parameters())
# trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
# print(f"Total Parameters: {total_params:,}")
# print(f"Trainable Parameters (last {N} layers only): {trainable_params:,}")


## Model Configurations

In [None]:
# Special tokens used to build decoder_input_ids from the labels: decoding starts
# at the tokenizer's CLS token, pads with PAD and ends at SEP.
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.eos_token_id = processor.tokenizer.sep_token_id
# Mirror the decoder's vocabulary size on the top-level config.
model.config.vocab_size = model.config.decoder.vocab_size

# `generate()` (used for evaluation through predict_with_generate) reads its special
# tokens from `model.generation_config`, so keep it in sync with the config above.
if getattr(model, "generation_config", None) is not None:
    model.generation_config.decoder_start_token_id = model.config.decoder_start_token_id
    model.generation_config.pad_token_id = model.config.pad_token_id
    model.generation_config.eos_token_id = model.config.eos_token_id

# Optional decoding settings; on recent transformers they belong on generation_config:
# model.generation_config.max_length = 20
# model.generation_config.early_stopping = True
# model.generation_config.no_repeat_ngram_size = 3
# model.generation_config.length_penalty = 2.0
# model.generation_config.num_beams = 4

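We use the AdamW optimizer with a weight decay of 0.0005. It only takes effect when it is passed to the trainer via `optimizers=(optimizer, None)`; otherwise `Seq2SeqTrainer` creates its own optimizer from the training arguments.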

In [None]:
# AdamW (decoupled weight decay) over all model parameters.
optimizer = optim.AdamW(
    params=model.parameters(),
    lr=TrainingConfig.LEARNING_RATE,
    weight_decay=0.0005,
)

## Evaluation Metric

In [None]:
# Character error rate metric from the Hugging Face `evaluate` hub.
cer_metric = evaluate.load(path="cer")

In [None]:
def compute_cer(pred):
    labels_ids = pred.label_ids
    pred_ids = pred.predictions

    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    labels_ids[labels_ids == -100] = processor.tokenizer.pad_token_id
    label_str = processor.batch_decode(labels_ids, skip_special_tokens=True)

    cer = cer_metric.compute(predictions=pred_str, references=label_str)

    return {"cer": cer}

## Training and Validation Loops

In [None]:
training_args = Seq2SeqTrainingArguments(
    output_dir='seq2seq_model_printed/',
    # Generate text during evaluation so compute_cer receives decoded predictions.
    predict_with_generate=True,
    # `evaluation_strategy` was renamed `eval_strategy` (transformers 4.41) and the
    # old name has since been removed; the unpinned install pulls a recent release.
    eval_strategy='epoch',
    logging_strategy='epoch',
    save_strategy='epoch',
    save_total_limit=5,
    per_device_train_batch_size=TrainingConfig.BATCH_SIZE,
    per_device_eval_batch_size=TrainingConfig.BATCH_SIZE,
    learning_rate=TrainingConfig.LEARNING_RATE,
    num_train_epochs=TrainingConfig.EPOCHS,
    # fp16 mixed precision needs a CUDA device; fall back to fp32 on CPU.
    fp16=torch.cuda.is_available(),
    report_to='tensorboard',
)

In [None]:
# Initialize trainer.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    data_collator=default_data_collator,
    compute_metrics=compute_cer,
    # `tokenizer=` is deprecated in favour of `processing_class=` (transformers 4.46+);
    # `image_processor` is the non-deprecated name of the processor's feature extractor.
    processing_class=processor.image_processor,
    # Use the AdamW optimizer defined earlier (weight decay 0.0005); None lets the
    # trainer build its default learning-rate scheduler around it.
    optimizers=(optimizer, None),
)

## Train

In [None]:
# Fine-tune from the freshly loaded weights (no checkpoint to resume from).
res = trainer.train(resume_from_checkpoint=None)

In [None]:
# Show the summary object returned by `trainer.train()`; its `global_step` is used
# below to locate the final checkpoint folder.
res

In [None]:
# Training step count at the end of training; the final checkpoint folder is named after it.
res.global_step

## Inference

In [None]:
# Call is_available(): the bare function object is always truthy, which would
# select 'cuda:0' even on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [None]:
processor = TrOCRProcessor.from_pretrained(ModelConfig.MODEL_NAME)
# Reload the weights saved at the final training step. To run inference on the
# GPU, chain `.to(device)` onto the call below.
trained_model = VisionEncoderDecoderModel.from_pretrained(
    f'seq2seq_model_printed/checkpoint-{res.global_step}'
)

In [None]:
# Display the architecture of the reloaded checkpoint.
trained_model

## Save the model checkpoint as a zip archive

In [None]:
import zipfile

def zip_folder(folder_path, zip_path):
    """Write every file under ``folder_path`` into a deflate-compressed zip archive.

    Member names are relative to the parent of ``folder_path``, so the folder itself
    becomes the top-level directory inside the archive. Empty sub-directories are
    not recorded.

    Args:
        folder_path: Directory whose contents are archived.
        zip_path: Destination ``.zip`` file (created or overwritten).
    """
    parent_dir = os.path.join(folder_path, '..')
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as archive:
        for dir_path, _, file_names in os.walk(folder_path):
            for file_name in file_names:
                full_path = os.path.join(dir_path, file_name)
                archive.write(full_path, os.path.relpath(full_path, parent_dir))

# Archive the checkpoint saved at the final training step (its folder is named after
# `res.global_step`) instead of a hard-coded step number, so it can be downloaded.
checkpoint_name = f'checkpoint-{res.global_step}'
zip_folder(os.path.join(training_args.output_dir, checkpoint_name), f'{checkpoint_name}.zip')
