# Fine-Tuning TrOCR with TensorFlow (Handwriting Recognition)

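This notebook walks through the complete pipeline for fine-tuning a Transformer-based OCR model (`TrOCR`) using TensorFlow on the IAM and Imgur5K handwriting datasets: environment setup, image and label preprocessing, data loading, model setup, training, evaluation (CER/WER), and saving the fine-tuned model.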

In [None]:
# Install necessary packages
!pip install transformers datasets tensorflow opencv-python jiwer -q
!pip install -U sentencepiece

In [None]:
# Check TensorFlow & GPU
import tensorflow as tf

print("TensorFlow version:", tf.__version__)

# Physical devices of type 'GPU' that TensorFlow can see in this runtime.
gpu_devices = tf.config.list_physical_devices('GPU')
print("GPU available:", gpu_devices)

In [None]:
# Mount Google Drive
from google.colab import drive

# Directory passed to drive.mount as the mount point.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
# Image preprocessing
IMG_SIZE = (384, 384)


def preprocess_image_tf(image_path):
    """Read an image file and return it resized, grayscale and divided by 255.

    Args:
        image_path: Path of the image file to read.

    Returns:
        A single-channel image tensor resized to ``IMG_SIZE`` whose values are the
        resized pixel values divided by 255.
    """
    raw_bytes = tf.io.read_file(image_path)
    rgb = tf.image.decode_jpeg(raw_bytes, channels=3)  # always decode to three channels
    resized = tf.image.resize(rgb, IMG_SIZE)
    gray = tf.image.rgb_to_grayscale(resized)
    # NOTE(review): the result has one channel and is only divided by 255; the TrOCR
    # processor normally produces the model's pixel_values itself — confirm this output
    # is what the consumer of this function expects.
    return gray / 255.0

In [None]:
# Tokenize text labels using TrOCR processor
from transformers import TrOCRProcessor
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")


def tokenize_label(text, max_length=128):
    """Convert a transcription into fixed-length token ids with the TrOCR tokenizer.

    Args:
        text: Transcription to tokenize.
        max_length: Length every sequence is padded or truncated to. Defaults to 128,
            the value that used to be hard-coded.

    Returns:
        The tokenizer's ``input_ids`` as a TensorFlow tensor.
        NOTE(review): for a single string this is presumably shaped
        ``(1, max_length)`` — confirm before batching with tf.data.
    """
    encoded = processor.tokenizer(
        text,
        padding="max_length",  # pad short labels up to max_length
        truncation=True,       # cut labels longer than max_length
        max_length=max_length,
        return_tensors="tf",
    )
    return encoded.input_ids

In [None]:
# Load IAM and Imgur5K datasets
# -----------------------------

import os

import pandas as pd
from datasets import Image as DatasetImage, load_dataset


def _image_paths_and_texts(dataset):
    """Return ``(image_paths, texts)`` for a dataset with 'image' and 'text' columns.

    A decoded ``datasets.Image`` column yields PIL images, which cannot be indexed
    with ``['path']``; such a column is switched to ``decode=False`` so every entry
    is a dict exposing 'path'. Columns that are not an Image feature are read as-is.
    Reading whole columns also avoids iterating (and decoding) each example twice.
    """
    if isinstance(dataset.features.get("image"), DatasetImage):
        dataset = dataset.cast_column("image", DatasetImage(decode=False))
    image_paths = [entry["path"] for entry in dataset["image"]]
    texts = list(dataset["text"])
    return image_paths, texts


# IAM handwriting dataset (line-level annotations): first 90% for training, rest for validation.
# NOTE(review): "iam_dataset" must resolve to a real Hub dataset id or a local loading
# script — confirm the identifier and that its 'image' entries carry a usable file path.
iam_dataset = load_dataset("iam_dataset", split="train[:90%]")
iam_val_dataset = load_dataset("iam_dataset", split="train[90%:]")

# Download Imgur5K manually and unzip to a folder in Drive or local path.
# Assuming /content/Imgur5K/ folder with 'images' and 'labels.csv'.
IMGUR_PATH = "/content/Imgur5K"
labels_df = pd.read_csv(os.path.join(IMGUR_PATH, "labels.csv"))  # columns: ['image_path', 'text']

# Rows without a path or a transcription would break os.path.join or tokenization later,
# so they are skipped; labels_df itself is left untouched.
imgur_rows = labels_df.dropna(subset=["image_path", "text"])
imgur_image_paths = [os.path.join(IMGUR_PATH, p) for p in imgur_rows["image_path"]]
imgur_labels = imgur_rows["text"].astype(str).tolist()

# Convert IAM to filepaths and labels (line-level)
iam_image_paths, iam_labels = _image_paths_and_texts(iam_dataset)
val_image_paths, val_labels = _image_paths_and_texts(iam_val_dataset)

# Combine datasets for training
train_image_paths = iam_image_paths + imgur_image_paths
train_texts = iam_labels + imgur_labels

assert len(train_image_paths) == len(train_texts), "training images and labels are misaligned"
assert len(val_image_paths) == len(val_labels), "validation images and labels are misaligned"

In [None]:
# Load TrOCR model (TensorFlow version)
from transformers import TFVisionEncoderDecoderModel

TROCR_CHECKPOINT = "microsoft/trocr-base-handwritten"

# from_pt=True requests loading from the checkpoint's PyTorch weights.
# NOTE(review): TrOCR appears to be listed as PyTorch-only in the transformers framework
# support table — confirm that a TensorFlow decoder can be instantiated for this checkpoint.
model = TFVisionEncoderDecoderModel.from_pretrained(TROCR_CHECKPOINT, from_pt=True)

In [None]:
# Compile model
# Labels are padded to a fixed length by the tokenizer, so padding must not count as a target.
PAD_TOKEN_ID = processor.tokenizer.pad_token_id
# Built once instead of on every loss call; "none" keeps one loss value per token.
_token_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction="none")


def compute_loss(y_true, y_pred):
    """Mean token-level cross-entropy over the label positions that are real targets.

    Positions equal to the tokenizer's pad id or to -100 (the usual "ignore" label)
    are excluded. When no position is excluded the result equals the plain mean
    cross-entropy over all tokens.

    Args:
        y_true: Integer token ids.
        y_pred: Unnormalised logits with a trailing vocabulary axis.

    Returns:
        Scalar loss tensor; 0 when every position is excluded.
    """
    labels = tf.cast(y_true, tf.int32)
    keep = tf.not_equal(labels, -100)
    if PAD_TOKEN_ID is not None:
        keep = tf.logical_and(keep, tf.not_equal(labels, PAD_TOKEN_ID))
    # Excluded positions get a valid class id so the cross-entropy stays finite;
    # their contribution is removed by the zero weight below.
    safe_labels = tf.where(keep, labels, tf.zeros_like(labels))
    per_token = _token_loss(safe_labels, y_pred)
    weights = tf.cast(keep, per_token.dtype)
    # Guard the denominator for batches where every position is excluded.
    return tf.reduce_sum(per_token * weights) / tf.maximum(tf.reduce_sum(weights), 1.0)


model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=5e-5), loss=compute_loss)

In [None]:
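# Train model (build train_ds and val_ds first)
# NOTE: `load_dataset` imported earlier is the Hugging Face loader and takes a dataset
# name, not lists of paths and texts. `make_tf_dataset` below is a placeholder for a
# helper you still need to write: it should turn image paths and texts into a
# tf.data.Dataset of model inputs and labels.
# train_ds = make_tf_dataset(train_image_paths, train_texts).batch(4)
# val_ds = make_tf_dataset(val_image_paths, val_labels).batch(4)
# model.fit(train_ds, validation_data=val_ds, epochs=10)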

In [None]:
# Evaluation (CER and WER)
from jiwer import cer, wer


def evaluate_model(model, processor, test_images, test_labels):
    """Transcribe each test image with the model and score the output against the labels.

    Each image is decoded to raw 3-channel pixels and handed to ``processor`` as-is, so
    the processor is the only place where the image is prepared for the model; feeding
    it an already resized, grayscale, rescaled tensor would apply preprocessing twice.

    Args:
        model: Model exposing ``generate(pixel_values)``.
        processor: Processor used to build ``pixel_values`` and to decode generated ids.
        test_images: Iterable of image file paths.
        test_labels: Iterable of reference transcriptions, aligned with ``test_images``.

    Returns:
        Tuple ``(cer, wer)`` as computed by jiwer over all samples.

    Raises:
        ValueError: If the number of images and labels differ (previously the longer
            input was silently truncated, which skews the metrics).
    """
    image_paths = list(test_images)
    ground_truth = list(test_labels)
    if len(image_paths) != len(ground_truth):
        raise ValueError(
            f"Got {len(image_paths)} images but {len(ground_truth)} labels; they must be aligned."
        )

    predictions = []
    for img_path in image_paths:
        encoded = tf.io.read_file(img_path)
        # expand_animations=False keeps the result 3-D (height, width, channels) for any format.
        image = tf.io.decode_image(encoded, channels=3, expand_animations=False)
        pixel_values = processor(images=image.numpy(), return_tensors="tf").pixel_values
        generated = model.generate(pixel_values)
        pred_text = processor.batch_decode(generated, skip_special_tokens=True)[0]
        predictions.append(pred_text)
    return cer(ground_truth, predictions), wer(ground_truth, predictions)

In [None]:
# Save model
OUTPUT_DIR = "/content/trocr-tf-finetuned"

# The model and the processor are written to the same directory, model first.
# NOTE(review): this path is outside the Drive mount point ('/content/drive') used
# earlier — presumably it will not outlive the runtime; confirm where it should go.
for artifact in (model, processor):
    artifact.save_pretrained(OUTPUT_DIR)