## Basic Imports

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import shutil
import glob
import random
import json
%matplotlib inline

from PIL import Image
import io

import cv2

import torch
import torch.nn.functional as F
from torch import optim, nn, utils, Tensor
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torch.utils.data.dataset import random_split
from torchvision.datasets import ImageFolder

import torchvision
from torchvision.datasets import MNIST, CIFAR10

from torchvision import transforms
from torchvision.transforms import Compose, ToTensor, Resize, Normalize
from torchvision.models import resnet18, ResNet18_Weights

import torchmetrics

import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.loggers import CSVLogger

from datasets import load_dataset

from typing import Callable, Optional, List, Dict

from transformers import TrOCRProcessor, VisionEncoderDecoderModel

import evaluate

# Select the compute device once: CUDA when a GPU is visible, else CPU.
_cuda_ok = torch.cuda.is_available()
device = torch.device('cuda' if _cuda_ok else 'cpu')
print(f"Using device: {device}")
if not _cuda_ok:
    print("No GPU available, using CPU")
else:
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"CUDA Version: {torch.version.cuda}")

## Load CORD-v2 Dataset

In [None]:
# Load CORD-v2 from the Hugging Face Hub and report the size of each split.
print("Loading CORD-v2 dataset from Hugging Face:")
ds = load_dataset("naver-clova-ix/cord-v2")

print("Dataset loaded successfully!")
print("Dataset structure:")
print(ds)

# One line per split: "<name>: <num samples> samples"
print("Available splits:")
for split, split_rows in ds.items():
    print(f"    - {split}: {len(split_rows)} samples")

## Custom Transform Classes

Define preprocessing transforms used in the training pipeline.

In [None]:
class CLAHETransform:
    """
    Apply CLAHE (Contrast Limited Adaptive Histogram Equalization) to boost local contrast.

    The image is converted RGB -> LAB, CLAHE is applied to the L (lightness)
    channel only, and the result is converted back to RGB.

    Args:
        clip_limit: Contrast clip limit passed to ``cv2.createCLAHE``.
        tile_grid_size: Tile grid over which histograms are equalised.
    """

    def __init__(self, clip_limit=2.0, tile_grid_size=(8, 8)):
        self.clip_limit = clip_limit
        self.tile_grid_size = tile_grid_size

    def __call__(self, img):
        """Return a CLAHE-enhanced copy of ``img`` as an RGB PIL image."""
        # COLOR_RGB2LAB requires exactly three channels: normalise grayscale,
        # RGBA and palette PIL images to RGB instead of failing inside OpenCV.
        if isinstance(img, Image.Image) and img.mode != "RGB":
            img = img.convert("RGB")

        img_np = np.array(img)
        lab = cv2.cvtColor(img_np, cv2.COLOR_RGB2LAB)

        # Built per call (not stored on self) so the transform only holds
        # plain Python attributes.
        clahe = cv2.createCLAHE(clipLimit=self.clip_limit, tileGridSize=self.tile_grid_size)
        lab[:, :, 0] = clahe.apply(lab[:, :, 0])

        img_clahe = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)
        return Image.fromarray(img_clahe)


class SharpenTransform:
    """
    Unsharp-mask sharpening to make text edges crisper.

    Computes ``(1 + amount) * image - amount * GaussianBlur(image)`` with
    ``cv2.addWeighted`` and returns the result as a uint8 PIL image.

    Args:
        kernel_size: Gaussian kernel size passed to ``cv2.GaussianBlur``.
        sigma: Gaussian standard deviation (sigmaX).
        amount: Sharpening strength.
    """

    def __init__(self, kernel_size=(5, 5), sigma=1.0, amount=1.5):
        self.kernel_size, self.sigma, self.amount = kernel_size, sigma, amount

    def __call__(self, img):
        source = np.array(img)
        smoothed = cv2.GaussianBlur(source, self.kernel_size, self.sigma)
        weight = self.amount
        mixed = cv2.addWeighted(source, 1.0 + weight, smoothed, -weight, 0)
        return Image.fromarray(np.clip(mixed, 0, 255).astype(np.uint8))


# Shared preprocessing instances: CLAHE contrast boost, then mild sharpening.
clahe_transform = CLAHETransform(tile_grid_size=(8, 8), clip_limit=2.0)
sharpen_transform = SharpenTransform(amount=1.0)

print("Custom transforms defined successfully")

# TrOCR: Transformer-based OCR for Receipt Text Extraction
---

TrOCR (Transformer-based Optical Character Recognition) is a model developed by Microsoft that combines computer vision and natural language processing to extract text from images. Traditional OCR systems rely on separate components (detection, recognition and post-processing). TrOCR instead uses a single Transformer architecture that converts image pixels directly into text sequences. Reading text from receipts is challenging for several reasons:
- Variable image quality: photos are taken under different lighting conditions and from different angles.
- Complex layouts: multiple columns, tables and different font sizes.
- Noise: shadows, creases, faded text and background patterns.
- Business-specific vocabulary: store names, product codes, symbols, etc.

TrOCR is very helpful for this because:
- Pre-trained on large-scale data, which is why it learns different visual representations from millions of document images.
- Focuses on relevant text regions, while ignoring noise.
- Contains a sequence modeling that understands context and can correct errors based on linguistic patterns.
- Supports transfer learning, leveraging knowledge learned from ImageNet and large text corpora.

TrOCR follows an encoder-decoder architecture where:

1. **Encoder: Vision Transformer (ViT)**:
    - **Input**: Receipt image in RGB and resized to fixed dimensions
    - **Process**: The image is split into patches, each of which is treated as a token
    - **Output**: Sequence of visual embeddings representing image content
    - **Pre-trained on**: ImageNet-21K
    - **What it does**: Extracts visual features and understands spatial relationships in the image
2. **Decoder: RoBERTa (Transformer Language Model)**:
    - **Input**: Visual embeddings from encoder
    - **Process**: Generates text tokens autoregressively (one subword token at a time)
    - **Output**: Text sequence representing all text in the receipt
    - **Pre-trained on**: Large text corpora, like books, articles and web text.
    - **What it does**: Converts visual features into coherent text using language understanding.
    - It uses RoBERTa since it has strong language modeling capabilities and can correct OCR errors by leveraging linguistic context. An example of this can be predicting missing letters based on word structure.
3. **Cross-Attention Mechanism**: the decoder attends to encoder outputs at each decoding step, allowing it to
    - Focus on specific image regions when generating each character/word
    - Align visual features with text tokens
    - Handle variable-length inputs and outputs

In the following steps, we are going to implement three fine-tuning strategies to adapt the pre-trained TrOCR model to receipts:
1. **Strategy 1: Frozen Encoder (Feature Extraction)**:
    - **Freeze**: All encoder layers (ViT)
    - **Train**: Only decoder layers (RoBERTa)
    - **Reason**: The encoder already knows how to extract visual features from documents, since it is pre-trained on printed text. We only need to adapt the decoder to receipt-specific vocabulary and layout.
    - **Advantages**: Fast training, low memory usage, less chance of overfitting
    - **Disadvantages**: Limited adaptation to receipt-specific visual patterns, such as shadows and creases.
    - **Best for**: Small datasets, limited compute resources (Our dataset has 1000 receipts).
2. **Strategy 2: Partial Unfreezing (Progressive Fine-Tuning)**:
    - **Freeze**: First N-3 encoder layers
    - **Train**: Last 3 encoder layers + all decoder layers
    - **Reason**: Lower layers learn generic features, like edges and textures, while higher layers learn task-specific patterns. By unfreezing the last layers, we allow the model to adapt to receipt-specific visual characteristics.
    - **Advantages**: Better domain adaptation than frozen encoder
    - **Disadvantages**: Requires more memory and compute than Strategy 1
    - **Best for**: Medium-sized datasets with some computational budget
3. **Strategy 3: Full Fine-Tuning (End-to-End Training)**:
    - **Freeze**: Nothing
    - **Train**: All encoder + decoder layers
    - **Reason**: Maximum adaptation to receipt domain. The model can learn receipt-specific visual features and text patterns simultaneously.
    - **Advantages**: Best performance potential, due to full customization
    - **Disadvantages**: Requires a large dataset and high memory/compute, and carries a higher risk of overfitting
    - **Best for**: Large datasets, sufficient compute, when domain shift is significant

From the TrOCR available models, we are going to use `microsoft/trocr-base-printed` because:
- Pre-trained specifically on printed text, like receipts.
- Base size of 334M parameters, which balances performance and efficiency.
- Strong performance on printed-text OCR benchmarks, such as SROIE (scanned receipts).

An alternative is `microsoft/trocr-large-printed`, a larger model with approximately 558M parameters. It can produce better accuracy, but at a higher compute cost.

By fine-tuning TrOCR on CORD-v2, we aim for:
- High accuracy on clean receipts
- Robustness to common receipt variations, such as lighting, rotation, blur, etc.
- Fast inference
- A comparison of which fine-tuning strategy works best for our dataset size.

## TrOCR Dataset Preparation

As mentioned before, TrOCR is trained to extract text from images.
- **Input**: Receipt image (PIL Image)
- **Output**: Plain text sequence (all text in the receipt, reading order)

**CORD-v2 Dataset Structure**:

The CORD-v2 dataset provides:
- **Images**: Receipt images in PNG format
- **Ground truth**: JSON annotations with:
  - `valid_line`: List of text lines in the receipt
  - `words`: Individual words with bounding boxes and text content
  - `category`: Semantic labels (store name, date, total, etc.)

**Data Extraction Process**:

We need to convert the structured JSON annotations into plain text sequences. To achieve this, we will follow these steps:
1. **Parse JSON**: Load the `ground_truth` string and parse it as JSON
2. **Extract words**: Iterate through `valid_line`, then `words`, and then `text`
3. **Concatenate**: Join all words with spaces to form a single text string
4. **Example**:
   - JSON: `{"valid_line": [{"words": [{"text": "STORE"}, {"text": "NAME"}]}]}`
   - Output text: `"STORE NAME"`

**Preprocessing Pipeline**:

For each image, we will:
1. Apply our custom transforms (CLAHE and sharpening) to improve text clarity.
2. Use TrOCR processor to handle resizing, normalization, and conversion to tensors internally.
3. Apply tokenization to convert text into token IDs using RoBERTa tokenizer.
4. Pad or truncate the label sequences to a `max_length` of 512 tokens.

Some special cases include:
- **Padding tokens**: Replaced with `-100` in labels so they're ignored during loss computation
- **Max length**: Set to 512 tokens to accommodate long receipts while fitting in GPU memory
- **Error handling**: If JSON parsing fails, return an empty string to prevent training crashes.

The following class is created to combine CORD's structured annotations and TrOCR's expected input and output formats.

In [None]:
class TrOCRDataset(Dataset):
    """
    Dataset for TrOCR fine-tuning on CORD-v2.

    Each item pairs a (preprocessed) receipt image with the plain text obtained
    by concatenating every word found in the sample's ``ground_truth`` JSON.

    Args:
        hf_dataset: Indexable dataset whose samples are mappings with an
            ``'image'`` entry and a ``'ground_truth'`` JSON string
            (e.g. one split of ``naver-clova-ix/cord-v2``).
        processor: TrOCR processor; it is called on the image to obtain
            ``pixel_values`` and its ``tokenizer`` encodes the labels.
        image_transform: Optional callable applied to the image before the
            processor (e.g. CLAHE + sharpening).
        max_length: Labels are padded/truncated to this many tokens.
    """
    def __init__(
        self,
        hf_dataset,
        processor,
        image_transform: Optional[Callable] = None,
        max_length: int = 512,
    ):
        self.hf_dataset = hf_dataset
        self.processor = processor
        self.image_transform = image_transform
        self.max_length = max_length

        print(f"TrOCRDataset initialized with {len(self.hf_dataset)} samples")

    def __len__(self) -> int:
        return len(self.hf_dataset)

    def extract_text_from_ground_truth(self, ground_truth_str: str) -> str:
        """
        Extract all word texts from a CORD ground-truth JSON string.

        Words are visited in ``valid_line`` -> ``words`` order and joined with
        single spaces. Malformed input (invalid JSON, ``None``, or an unexpected
        structure) yields ``""`` so one bad annotation does not crash training.
        Unlike a bare ``except:``, interpreter signals such as KeyboardInterrupt
        still propagate.
        """
        try:
            gt_dict = json.loads(ground_truth_str)
            words = [
                word['text']
                for line in gt_dict.get('valid_line', [])
                for word in line.get('words', [])
                if 'text' in word
            ]
            # join stays inside the try: a non-string 'text' raises TypeError.
            return ' '.join(words).strip()
        except (ValueError, TypeError, AttributeError, KeyError):
            # ValueError covers json.JSONDecodeError.
            return ""

    def __getitem__(self, idx: int) -> Dict[str, object]:
        """
        Return ``{"pixel_values": Tensor, "labels": Tensor, "text": str}`` for sample ``idx``.

        ``labels`` has padding replaced by -100 so it is ignored by the loss.
        """
        sample = self.hf_dataset[idx]

        image = sample['image']
        # CLAHE (RGB->LAB) and the TrOCR processor expect 3-channel RGB input;
        # normalise grayscale / RGBA / palette images first.
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Apply custom preprocessing (CLAHE + Sharpening)
        if self.image_transform is not None:
            image = self.image_transform(image)

        text = self.extract_text_from_ground_truth(sample['ground_truth'])

        # squeeze(0) removes only the batch axis added by return_tensors="pt".
        pixel_values = self.processor(image, return_tensors="pt").pixel_values.squeeze(0)

        labels = self.processor.tokenizer(
            text,
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt",
        ).input_ids.squeeze(0)

        # Replace padding token id with -100 (ignored in loss)
        labels[labels == self.processor.tokenizer.pad_token_id] = -100

        return {
            "pixel_values": pixel_values,
            "labels": labels,
            "text": text,  # For evaluation
        }

# Smoke-test TrOCRDataset on the first training receipt.
print("Loading TrOCR processor:")
trocr_processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")

# TrOCR preprocessing: only CLAHE and Sharpening (no Normalize, no ToTensor).
# The TrOCR processor handles conversion to tensor and normalization internally.
trocr_preprocess = transforms.Compose([
    clahe_transform,
    sharpen_transform,
])

# Note: built from the *train* split; this is a sanity check, not the test set.
trocr_test_dataset = TrOCRDataset(
    hf_dataset=ds['train'],
    processor=trocr_processor,
    image_transform=trocr_preprocess,
    max_length=512,
)

test_sample = trocr_test_dataset[0]
# Only add an ellipsis when the preview actually truncates the text.
text_preview = test_sample['text'][:100]
if len(test_sample['text']) > 100:
    text_preview += "..."

print("Sample output:")
print(f"    - Pixel values shape: {test_sample['pixel_values'].shape}")
print(f"    - Labels shape: {test_sample['labels'].shape}")
print(f"    - Text preview: {text_preview}")
print("TrOCR Dataset ready!")

## TrOCR Model Implementation with PyTorch Lightning

We will use PyTorch Lightning, since it's a high-level framework that simplifies training models:
- **Organized code**: Separates research code from engineering code
- **Built-in features**: Automatic logging, checkpointing, early stopping, multi-GPU support
- **Reproducibility**: Handles random seeds, deterministic training
- **Less boilerplate**: No need to manually write training loops, GPU transfer logic

**Model Class Structure**: The following implementation of our `TrOCRLightningModel` inherits from `LightningModule` and implements:
1. Initialization (`__init__`):
    - Load pre-trained `VisionEncoderDecoderModel` from Hugging Face
    - Load corresponding `TrOCRProcessor` to handle image preprocessing and tokenization.
    - Configure generation parameters, like start token, padding, EOS token.
    - Apply freezing strategy (freeze/unfreeze layers based on strategy)
    - Initialize metrics, such as CER (Percentage of characters that are wrong), WER (Percentage of words that are wrong), and accuracy
2. Freezing Strategy (`_apply_freezing_strategy`):
    - Iterate through `model.encoder.parameters()` and set `requires_grad = False` to freeze
    - For partial unfreezing: Access encoder layers via `model.encoder.encoder.layer[-N:]` and unfreeze
    - Decoder is always trainable: `model.decoder.parameters()` have `requires_grad = True`
    - We freeze parameters because frozen parameters do not need gradients, which means faster training and lower memory usage. Freezing also preserves the pre-trained weights so that only the task-specific layers adapt, which helps generalization on small datasets.
3. Forward Pass (`forward`):
    - Takes `pixel_values` (image tensors) and `labels` (text token IDs)
    - Returns model outputs including loss and logits
    - Loss is automatically computed by comparing logits with labels (cross-entropy)
4. Training Step (`training_step`):
    - This is what happens in each training iteration:
        1. forward pass: `outputs = self(pixel_values, labels=labels)`
        2. extract loss: `loss = outputs.loss`
        3. generate predictions for accuracy: `model.generate()` creates text sequences
        4. decode predictions and references to text strings.
        5. Compute CER (Character Error Rate) and convert to accuracy (1 - CER)
        6. Log metrics: `self.log('train_loss', loss)`
        7. Return loss (Lightning automatically calls `loss.backward()` and optimizer step)
    - Generating predictions during training is expensive, but very useful for monitoring.
5. Validation Step (`validation_step`):
    - Evaluate model on validation set to monitor overfitting
        1. Forward pass, same as training.
        2. Generate predictions: `model.generate(pixel_values, max_length=384)`
        3. Decode predictions and ground truth
        4. Compute metrics: CER, WER (Word Error Rate), accuracy
        5. Log metrics: `self.log('val_loss', loss)`, `self.log('val_cer', cer)`, ...
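    - CER: measures character-level errors, i.e. substitutions, deletions and insertions.
        - Formula: $\text{CER} = \frac{S + D + I}{N}$, where $S$, $D$, $I$ are the character substitutions, deletions and insertions, and $N$ is the number of characters in the reference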
        - Lower is better (0 = perfect)
    - WER: same, but at word level
        - More interpretable for human readers
        - Stricter: One wrong character = entire word wrong
6. Test Step (`test_step`):
    - Same as validation step but runs on the test data
    - Final evaluation after training is complete
7. Optimizer Configuration (`configure_optimizers`)
    - AdamW Optimizer:
        - Variant of Adam with decoupled weight decay to prevent overfitting
        - Learning rate: 5e-5, a common choice for fine-tuning Transformers
        - Weight decay: 0.01 (decoupled weight decay, which acts as a regularizer similar to L2)
    - ReduceLROnPlateau Scheduler:
        - Reduces the learning rate when the validation loss plateaus
        - Factor: 0.5 (reduce LR by half)
        - Patience: 3 epochs, which means that it waits 3 epochs before reducing.
        - Helps model converge to better local minima

Some key variables are:
1. Generation `max_length` (256 in training, 384 in validation and test): limits inference time and memory usage, and caps runaway generation. We chose these values because receipts usually have 100-200 tokens.
2. We replace padding with `-100` in labels because PyTorch's `CrossEntropyLoss` ignores index `-100`. This ensures padding tokens don't contribute to the loss and prevents the model from learning to predict padding.
3. We log metrics with `prog_bar=True` to display them in real time during training, which helps us monitor progress.

This implementation follows best practices for fine-tuning vision-language models and provides a clean, maintainable, codebase for experimentation.

In [None]:
class TrOCRLightningModel(L.LightningModule):
    """
    TrOCR (VisionEncoderDecoderModel) wrapped as a LightningModule for receipt OCR.

    The fine-tuning strategy is selected with the constructor arguments:
    - Frozen encoder:      freeze_encoder=True,  unfreeze_last_n_layers=0
    - Partial unfreezing:  freeze_encoder=True,  unfreeze_last_n_layers=N > 0
    - Full fine-tuning:    freeze_encoder=False (unfreeze_last_n_layers is ignored)

    Args:
        model_name: Hugging Face checkpoint used for both model and processor.
        learning_rate: AdamW learning rate.
        freeze_encoder: If True, freeze all encoder parameters.
        unfreeze_last_n_layers: With a frozen encoder, number of trailing
            encoder layers to make trainable again.
        train_generation_max_length: ``max_length`` passed to ``generate`` when
            computing the training accuracy.
        eval_generation_max_length: ``max_length`` passed to ``generate`` in
            validation and test.
    """
    def __init__(
        self,
        model_name: str = "microsoft/trocr-base-printed",
        learning_rate: float = 5e-5,
        freeze_encoder: bool = True,
        unfreeze_last_n_layers: int = 0,
        train_generation_max_length: int = 256,
        eval_generation_max_length: int = 384,
    ):
        super().__init__()
        self.save_hyperparameters()

        # Load pre-trained TrOCR model
        self.model = VisionEncoderDecoderModel.from_pretrained(model_name)
        self.processor = TrOCRProcessor.from_pretrained(model_name)

        # Configure model generation parameters
        tokenizer = self.processor.tokenizer
        self.model.config.decoder_start_token_id = tokenizer.cls_token_id
        self.model.config.pad_token_id = tokenizer.pad_token_id
        self.model.config.eos_token_id = tokenizer.sep_token_id

        self._apply_freezing_strategy(freeze_encoder, unfreeze_last_n_layers)

        # Metrics
        self.cer_metric = evaluate.load("cer")  # Character Error Rate
        self.wer_metric = evaluate.load("wer")  # Word Error Rate

        # Epoch-level mean of per-batch accuracy (1 - CER)
        self.train_acc = torchmetrics.MeanMetric()
        self.val_acc = torchmetrics.MeanMetric()

    def _apply_freezing_strategy(self, freeze_encoder: bool, unfreeze_last_n_layers: int):
        """
        Apply the layer freezing strategy.

        Args:
            freeze_encoder: If True, freeze the encoder backbone.
            unfreeze_last_n_layers: Number of last encoder layers to unfreeze
                (only meaningful when ``freeze_encoder`` is True).
        """
        if freeze_encoder:
            for param in self.model.encoder.parameters():
                param.requires_grad = False
            print("Encoder frozen (transfer learning mode)")

            if unfreeze_last_n_layers > 0:
                # Access ViT encoder layers
                encoder_layers = self.model.encoder.encoder.layer
                for layer in encoder_layers[-unfreeze_last_n_layers:]:
                    for param in layer.parameters():
                        param.requires_grad = True
                print(f"Unfroze last {unfreeze_last_n_layers} encoder layers")
        else:
            print("Encoder unfrozen (full fine-tuning mode)")
            if unfreeze_last_n_layers > 0:
                print("unfreeze_last_n_layers ignored: the whole encoder is trainable")

        # Decoder is always trainable
        for param in self.model.decoder.parameters():
            param.requires_grad = True
        print("Decoder trainable")

        total_params = sum(p.numel() for p in self.model.parameters())
        trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        print(f"Trainable parameters: {trainable_params:,} / {total_params:,} "
              f"({100 * trainable_params / total_params:.2f}%)")

    def forward(self, pixel_values, labels=None):
        return self.model(pixel_values=pixel_values, labels=labels)

    # ------------------------------------------------------------------ helpers
    def _generate_texts(self, pixel_values: torch.Tensor, max_length: int) -> List[str]:
        """Run ``model.generate`` on the batch and decode the ids to strings."""
        generated_ids = self.model.generate(pixel_values, max_length=max_length)
        return self.processor.batch_decode(generated_ids, skip_special_tokens=True)

    def _decode_labels(self, labels: torch.Tensor) -> List[str]:
        """Decode label ids to text, mapping the -100 loss mask back to the pad token first."""
        restored = labels.clone()
        restored[restored == -100] = self.processor.tokenizer.pad_token_id
        return self.processor.batch_decode(restored, skip_special_tokens=True)

    def _error_rates(self, predictions: List[str], references: List[str], with_wer: bool = True):
        """
        Compute ``(cer, wer)`` over a batch; ``wer`` is None when ``with_wer`` is False.

        Pairs with an empty/whitespace reference are skipped: CER/WER have a zero
        denominator there, and jiwer (which backs ``evaluate``'s cer/wer) may raise
        on empty references. Returns ``(None, None)`` if no pair remains.
        """
        pairs = [(p, r) for p, r in zip(predictions, references) if r.strip()]
        if not pairs:
            return None, None
        preds = [p for p, _ in pairs]
        refs = [r for _, r in pairs]
        cer = self.cer_metric.compute(predictions=preds, references=refs)
        wer = self.wer_metric.compute(predictions=preds, references=refs) if with_wer else None
        return cer, wer

    # ------------------------------------------------------------------- steps
    def training_step(self, batch, batch_idx):
        pixel_values = batch['pixel_values']
        labels = batch['labels']
        batch_size = pixel_values.size(0)

        outputs = self(pixel_values, labels=labels)
        loss = outputs.loss

        # generate() does not switch to eval mode by itself; without this, dropout
        # would perturb the predictions used for the training accuracy.
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                generated_texts = self._generate_texts(
                    pixel_values, self.hparams.train_generation_max_length
                )
        finally:
            self.model.train(was_training)

        reference_texts = self._decode_labels(labels)
        cer, _ = self._error_rates(generated_texts, reference_texts, with_wer=False)
        if cer is not None:
            self.train_acc.update(max(0.0, 1.0 - cer))  # accuracy = 1 - CER, floored at 0

        self.log('train_loss', loss, prog_bar=True, on_step=True, on_epoch=True, batch_size=batch_size)
        self.log('train_acc', self.train_acc, prog_bar=True, on_step=False, on_epoch=True, batch_size=batch_size)
        return loss

    def validation_step(self, batch, batch_idx):
        pixel_values = batch['pixel_values']
        labels = batch['labels']
        batch_size = pixel_values.size(0)

        loss = self(pixel_values, labels=labels).loss

        generated_texts = self._generate_texts(pixel_values, self.hparams.eval_generation_max_length)
        reference_texts = self._decode_labels(labels)
        cer, wer = self._error_rates(generated_texts, reference_texts)

        self.log('val_loss', loss, prog_bar=True, on_epoch=True, batch_size=batch_size)
        results = {'val_loss': loss}
        if cer is not None:
            acc = max(0.0, 1.0 - cer)  # Convert CER to accuracy
            self.val_acc.update(acc)
            self.log('val_cer', cer, prog_bar=True, on_epoch=True, batch_size=batch_size)
            self.log('val_wer', wer, prog_bar=True, on_epoch=True, batch_size=batch_size)
            results.update(val_cer=cer, val_wer=wer, val_acc=acc)
        self.log('val_acc', self.val_acc, prog_bar=True, on_epoch=True, batch_size=batch_size)

        return results

    def test_step(self, batch, batch_idx):
        pixel_values = batch['pixel_values']
        batch_size = pixel_values.size(0)

        generated_texts = self._generate_texts(pixel_values, self.hparams.eval_generation_max_length)
        reference_texts = self._decode_labels(batch['labels'])
        cer, wer = self._error_rates(generated_texts, reference_texts)
        if cer is None:
            return {}

        acc = max(0.0, 1.0 - cer)  # Convert CER to accuracy
        self.log('test_cer', cer, prog_bar=True, batch_size=batch_size)
        self.log('test_wer', wer, prog_bar=True, batch_size=batch_size)
        self.log('test_acc', acc, prog_bar=True, batch_size=batch_size)

        return {'test_cer': cer, 'test_wer': wer, 'test_acc': acc}

    def configure_optimizers(self):
        # Frozen parameters never receive gradients, so AdamW skips them.
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=0.01
        )

        # Halve the LR after 3 epochs without val_loss improvement.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=0.5,
            patience=3
        )

        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'monitor': 'val_loss'
            }
        }

print("TrOCR Lightning Model defined successfully")

## TrOCR DataModule with Custom Preprocessing

Now, we will implement a `LightningDataModule` to encapsulate all data-related logic in one place, such as:
- Dataset creation (train/val/test splits)
- Data transformations and augmentations
- DataLoader configuration, like batch size, shuffling and workers.
- Reproducibility to ensure consistent data handling across experiments.

Rather than relying only on the TrOCR processor, we first apply our custom preprocessing:
- **CLAHE (Contrast Limited Adaptive Histogram Equalization)**: enhances local contrast without amplifying noise.
- **Sharpening**: improves text edge definition for better OCR accuracy.
- Applying our custom transforms first is important because they improve the image quality of the receipts. The TrOCR processor then resizes, normalizes and converts the images to tensors to prepare them as model input.

**Augmentation Strategy**: 
- Training step: we will apply augmentations to increase the dataset diversity.
    - Random 5 degree rotation to simulate camera angles.
    - Color jitter for brightness and contrast to simulate lighting variations.
- Validation and Test sets: no augmentation, only preprocessing.
    - We want to measure true performance on clean, processed images.
    - Augmentation during validation would give misleading metrics.

**DataLoader Configuration**:
- **Batch size**: Set to 1, due to memory issues.
- **Shuffle**: True for training for random order, but False for validation and test
- **num_workers**: Set to 0 to avoid multiprocessing issues on some systems
- **collate_fn**: Custom function to properly batch mixed data types (tensors and text strings)


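**Collate Function Explained**: Each sample mixes tensors (`pixel_values`, `labels`) with a plain string (`text`). PyTorch's default collate already stacks tensors and returns strings as a list. We still define an explicit collate function so this batching of mixed types is transparent and under our control. The function: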
1. Stacks `pixel_values` tensors into batch: `torch.stack([item['pixel_values'] for item in batch])`
2. Stacks `labels` tensors into batch: `torch.stack([item['labels'] for item in batch])`
3. Keeps text as list: `[item['text'] for item in batch]` (used for evaluation decoding)

This ensures proper batching while preserving the text references for metric computation.

In [None]:
class TrOCRDataModule(L.LightningDataModule):
    """DataModule for TrOCR training with CORD-v2 dataset.

    Wraps the ``train``, ``validation`` and ``test`` splits of a Hugging Face
    dataset in ``TrOCRDataset`` objects. Every split is preprocessed with CLAHE
    followed by sharpening. When ``use_augmentation`` is True, the training split
    additionally gets a small random rotation and brightness/contrast jitter.

    Args:
        hf_dataset: Dataset dict indexed by ``'train'``, ``'validation'`` and ``'test'``.
        processor: TrOCR processor passed through to every ``TrOCRDataset``.
        batch_size: Batch size used by all three dataloaders.
        num_workers: Number of DataLoader worker processes.
        use_augmentation: Whether to append augmentation to the training transform.
    """

    def __init__(
        self,
        hf_dataset,
        processor,
        batch_size: int = 1,
        num_workers: int = 0,
        use_augmentation: bool = True,
    ):
        super().__init__()
        self.hf_dataset = hf_dataset
        self.processor = processor
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.use_augmentation = use_augmentation

        # Training and evaluation share the same preprocessing. Augmentation is
        # appended only when requested, instead of padding the pipeline with
        # identity lambdas (lambdas cannot be pickled into spawned DataLoader workers).
        self.train_transform = transforms.Compose(self._image_steps(augment=use_augmentation))
        self.val_transform = transforms.Compose(self._image_steps(augment=False))

    @staticmethod
    def _image_steps(augment: bool) -> list:
        """Return fresh preprocessing steps, optionally followed by training augmentation."""
        steps = [
            CLAHETransform(clip_limit=2.0, tile_grid_size=(8, 8)),
            SharpenTransform(amount=1.0),
        ]
        if augment:
            steps += [
                # fill=255 paints the corners exposed by the rotation white (for 8-bit images)
                transforms.RandomRotation(degrees=5, fill=255),
                transforms.ColorJitter(brightness=0.1, contrast=0.1),
            ]
        return steps

    def _make_dataset(self, split: str, image_transform):
        """Wrap one split of the HF dataset in a TrOCRDataset with the given transform."""
        return TrOCRDataset(
            hf_dataset=self.hf_dataset[split],
            processor=self.processor,
            image_transform=image_transform,
        )

    def setup(self, stage: Optional[str] = None):
        """Build the train/validation/test datasets (all three, regardless of *stage*)."""
        self.train_dataset = self._make_dataset('train', self.train_transform)
        self.val_dataset = self._make_dataset('validation', self.val_transform)
        self.test_dataset = self._make_dataset('test', self.val_transform)

        print(f"TrOCRDataModule Train: {len(self.train_dataset)}, "
              f"Val: {len(self.val_dataset)}, Test: {len(self.test_dataset)}")

    def _make_loader(self, dataset, shuffle: bool) -> DataLoader:
        """Create a DataLoader with the module's shared batching settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
        )

    def train_dataloader(self):
        # Shuffle only the training data; evaluation order stays fixed
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._make_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        return self._make_loader(self.test_dataset, shuffle=False)

    @staticmethod
    def collate_fn(batch):
        """Custom collate function to handle batching.

        This is necessary to properly batch mixed data types: tensors (pixel_values, labels)
        are stacked, while text strings are kept as a list.
        """
        # torch.stack requires equal shapes, so labels are presumably padded to a
        # fixed length by TrOCRDataset (not visible here — confirm there).
        pixel_values = torch.stack([item['pixel_values'] for item in batch])
        labels = torch.stack([item['labels'] for item in batch])
        texts = [item['text'] for item in batch]

        return {
            'pixel_values': pixel_values,
            'labels': labels,
            'text': texts,
        }

print("TrOCR DataModule defined successfully")

## TrOCR Training Setup

We train three independent TrOCR models using different fine-tuning strategies to compare performance vs computational cost.

**Three Strategies**:

| Strategy | Trainable Params | Frozen Params | Epochs | Use Case |
|----------|-----------------|---------------|--------|----------|
| **1. Frozen Encoder** | 89M (decoder) | 245M (encoder) | 30 | Fast experiments, limited resources |
| **2. Partial Unfreezing** | 120M (decoder + last 3 encoder layers) | 214M (first encoder layers) | 25 | Balanced performance/efficiency |
| **3. Full Fine-Tuning** | 334M (all layers) | 0 | 20 | Maximum performance |

**Shared Training Configuration**:

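All strategies share the same training configuration; only the learning rate differs per strategy:
- **Batch size**: 1 (limited by GPU memory, see the DataLoader configuration above)
- **Gradient accumulation**: enough steps to reach an effective batch of 32 samples (e.g. 16 steps at batch size 2, 32 steps at batch size 1)
- **Learning rate**: `5e-5` (frozen), `3e-5` (partial) and `2e-5` (full), i.e. smaller as more parameters are unfrozen, with cosine annealing and 500-step warmup
- **Optimizer**: AdamW, since it's adaptive and works well across all three strategies
- **Precision**: Mixed FP16, which is typically up to 2x faster and uses roughly half the activation memory
- **Gradient clipping**: 1.0 to prevent instability
- **Validation**: once per epoch, to detect overfitting early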

The strategies have different epoch budgets because:
- **Strategy 1 (30 epochs)**: Decoder-only training is fast and can afford more iterations
- **Strategy 2 (25 epochs)**: Moderate training time for balanced approach
- **Strategy 3 (20 epochs)**: Fewer epochs to prevent overfitting on small dataset, in this case of 800 receipts.

Each strategy has detailed explanation and training configuration in the following sections.

In [None]:
# One shared data module feeds all three fine-tuning strategies
trocr_dm = TrOCRDataModule(
    ds,
    trocr_processor,
    use_augmentation=True,
    num_workers=0,
    batch_size=1,
)
trocr_dm.setup()

# Strategy 1: the whole encoder stays frozen, only the decoder learns
print("----- Strategy 1: Freeze Encoder - Train Decoder Only -----")
trocr_model_frozen = TrOCRLightningModel(
    model_name="microsoft/trocr-base-printed",
    freeze_encoder=True,
    unfreeze_last_n_layers=0,
    learning_rate=5e-5,
)

# Strategy 2: frozen encoder except for its last 3 layers
print("----- Strategy 2: Freeze Encoder - Unfreeze Last 3 Layers -----")
trocr_model_partial = TrOCRLightningModel(
    model_name="microsoft/trocr-base-printed",
    freeze_encoder=True,
    unfreeze_last_n_layers=3,
    learning_rate=3e-5,
)

# Strategy 3: nothing frozen, every weight is fine-tuned (smallest learning rate)
print("----- Strategy 3: Full Fine-Tuning (All Layers Trainable) -----")
trocr_model_full = TrOCRLightningModel(
    model_name="microsoft/trocr-base-printed",
    freeze_encoder=False,
    unfreeze_last_n_layers=0,
    learning_rate=2e-5,
)

print("All TrOCR models initialized")

## TrOCR Training - Strategy 1: Frozen Encoder

In this first strategy, we freeze the encoder (ViT) and train only the decoder (RoBERTa). The configuration we'll implement is:
- Epochs 30: more iterations, since this should be the fastest strategy
- Batch size 1: to fit on a machine with approximately 12GB of VRAM
- Gradient accumulation: effective batch = 32 (stable gradients)
- Learning rate 5e-5: with cosine annealing and 500-step warmup
- Precision FP16 mixed: typically up to 2 times faster and needs roughly half the memory
- Gradient clipping 1.0: prevents exploding gradients
- Validation once per epoch: early overfitting detection

**Callbacks & Logging**:
- ModelCheckpoint: 
    - Saves the top 3 models by `val_loss`, plus the last checkpoint
    - Format: `trocr-frozen-{epoch:02d}-{val_loss:.4f}.ckpt` (Lightning expands it to e.g. `trocr-frozen-epoch=03-val_loss=0.4123.ckpt`)
    - Directory: `./trocr_checkpoints/strategy1_frozen/`
    - Enables resuming interrupted training
- EarlyStopping:
    - Patience: 5 epochs. Stops if `val_loss` doesn't improve
    - Prevents overfitting and saves compute time
- CSVLogger:
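    - Saves all logged metrics (e.g. epoch, train_loss, val_loss and any accuracy/CER metrics the model logs)
    - Directory: `./trocr_logs/strategy1_frozen/version_<N>/metrics.csv` (a new version folder is created each time the logger is re-created)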

Additionally, we use gradient accumulation because a batch size of 32 won't fit on a 12GB VRAM machine. Instead, we process a few samples at a time (e.g. 2 samples x 16 steps) and only then update the weights. This gives approximately the same gradient as a batch of 32, while using only the memory of the small batch.

In [None]:
# Keep the 3 best checkpoints by val_loss, plus last.ckpt for resuming
trocr_checkpoint_s1 = ModelCheckpoint(
    dirpath='./trocr_checkpoints/strategy1_frozen',
    filename='trocr-frozen-{epoch:02d}-{val_loss:.4f}',
    monitor='val_loss',
    mode='min',
    save_top_k=3,
    save_last=True,
    verbose=True,
)

# Validation runs once per epoch, so patience=5 means 5 epochs without improvement
trocr_early_stop_s1 = EarlyStopping(
    monitor='val_loss',
    patience=5,
    mode='min',
    verbose=True,
)

trocr_csv_logger_s1 = CSVLogger(save_dir='./trocr_logs', name='strategy1_frozen')

trocr_trainer_s1 = L.Trainer(
    # Decoder-only training is the cheapest, so Strategy 1 gets the largest epoch budget
    max_epochs=30,
    # Accumulate gradients up to the documented effective batch of 32 samples
    accumulate_grad_batches=max(1, 32 // trocr_dm.batch_size),
    gradient_clip_val=1.0,
    # Mixed precision requires CUDA; fall back to full precision on CPU
    precision='16-mixed' if torch.cuda.is_available() else '32-true',
    callbacks=[trocr_checkpoint_s1, trocr_early_stop_s1],
    logger=trocr_csv_logger_s1,
    accelerator='gpu' if torch.cuda.is_available() else 'cpu',
    devices=1 if torch.cuda.is_available() else 'auto',
    log_every_n_steps=10,
)

print("TrOCR Strategy 1 Trainer configured")
print(f"    - Checkpoints: {trocr_checkpoint_s1.dirpath}")
print(f"    - Logs: trocr_logs/strategy1_frozen")

checkpoint_files = glob.glob('./trocr_checkpoints/strategy1_frozen/*.ckpt')
if checkpoint_files:
    print(f"    - Found {len(checkpoint_files)} existing checkpoint(s)")
else:
    print(f"    - No existing checkpoints found")

In [None]:
# Continue from last.ckpt when a previous run left one behind; otherwise train fresh
checkpoint_path = './trocr_checkpoints/strategy1_frozen/last.ckpt'
resume_ckpt_s1 = checkpoint_path if os.path.exists(checkpoint_path) else None
if resume_ckpt_s1 is None:
    print("Starting training from scratch")
else:
    print(f"Resuming training from {checkpoint_path}")
trocr_trainer_s1.fit(trocr_model_frozen, datamodule=trocr_dm, ckpt_path=resume_ckpt_s1)

## TrOCR Evaluation - Strategy 1

After training, we need to evaluate the final performance on the test set, which consists of 100 receipts. For this, we'll load the best checkpoint, with the lowest `val_loss` and run on the test set to compute the metrics.

**Metrics**:
- **Validation Loss (Cross-Entropy)**: measures model confidence and accuracy
    - **< 0.5**: Excellent (very confident)
    - **0.5-1.0**: Good
    - **1.0-2.0**: Fair
    - **> 2.0**: Poor
- **Token-Level Accuracy**: percentage of correctly predicted tokens:
    - Example:
        ```
        Ground truth: ["TO", "TAL", ":", " ", "$", "45", ".", "99"]  (8 tokens)
        Prediction:   ["TO", "TAL", ":", " ", "$", "45", ".", "9"]   (7 correct)
        Accuracy: 7/8 = 87.5%
        ```
    - **≥ 0.95**: Excellent
    - **0.90-0.95**: Good
    - **0.85-0.90**: Fair
    - **< 0.85**: Poor
- **Optional Manual Metrics**:
    - **Character Error Rate (CER)**:
        - Formula: `(Substitutions + Deletions + Insertions) / Total Characters`
        - Example: `"TOTAL: $45.99"` → `"TOTAL: $45.9"` = 1/13 = 7.7% error
    - **Word Error Rate (WER)**:
        - Same as CER but at word level
        - More sensitive because one wrong character means the entire word is wrong.

The expected results of this strategy, based on CORD-v2 benchmarks, are:
- **Val loss**: 0.3-0.6
- **Token accuracy**: 88-93%
- **Training time**: 5-7 hours (with early stopping)

In [None]:
# Test the checkpoint that ModelCheckpoint tracked as having the lowest val_loss
best_model_path = trocr_checkpoint_s1.best_model_path
print(f"Evaluating best model: {best_model_path}")
trocr_trainer_s1.test(model=trocr_model_frozen, datamodule=trocr_dm, ckpt_path=best_model_path)

In [None]:
# Reload the best Strategy 1 checkpoint and inspect a few test predictions qualitatively
best_checkpoint_s1 = trocr_checkpoint_s1.best_model_path
if not best_checkpoint_s1:
    raise FileNotFoundError(
        "No best checkpoint recorded for Strategy 1; run the Strategy 1 training cell first."
    )
trocr_model_s1_loaded = TrOCRLightningModel.load_from_checkpoint(best_checkpoint_s1)
trocr_model_s1_loaded.eval()
# load_from_checkpoint may place the module on the device the checkpoint was saved
# from (e.g. CUDA), while the DataLoader yields CPU tensors, so inputs follow the model.
model_device = trocr_model_s1_loaded.device

test_dataloader = trocr_dm.test_dataloader()
test_batch = next(iter(test_dataloader))

with torch.no_grad():
    pixel_values = test_batch['pixel_values'].to(model_device)
    generated_ids = trocr_model_s1_loaded.model.generate(pixel_values, max_length=512)
    predictions = trocr_model_s1_loaded.processor.batch_decode(generated_ids, skip_special_tokens=True)

    # -100 marks ignored label positions; swap it for the pad token so it can be decoded
    labels = test_batch['labels'].clone()
    labels[labels == -100] = trocr_model_s1_loaded.processor.tokenizer.pad_token_id
    ground_truths = trocr_model_s1_loaded.processor.batch_decode(labels, skip_special_tokens=True)

num_samples = min(3, len(predictions))
for i in range(num_samples):
    print(f"Sample {i+1}:")
    print(f"Prediction: {predictions[i][:200]}...")
    print(f"Ground Truth: {ground_truths[i][:200]}...")

## TrOCR Training - Strategy 2: Partial Unfreezing

Now, we'll implement the second strategy for TrOCR model. Here, we'll unfreeze the last 3 encoder layers, while keeping the rest frozen. This allows the model to adapt, in a higher level, to the visual features of receipt-specific patterns, such as printed numbers, tables, stamps, etc, while preserving the low-level feature extraction learned from the pre-training.

**Changes compared to Strategy 1**:
- **Trainable parameters**: 120M (decoder + last 3 encoder layers)
- **Frozen parameters**: 214M (first encoder layers)
- **Epochs**: 25, i.e. 5 fewer epochs, since each epoch takes longer
- **Training time**: approximately 6-8 hours, slower than Strategy 1 because more layers are trained

We chose this approach because:
- The lower encoder layers capture generic features like edges, textures and basic shapes, which transfer well to receipts, so we keep them frozen.
- The higher encoder layers capture more task-specific patterns, so we unfreeze them to adapt the model to receipts.
- Balances domain adaptation with training efficiency.

The rest of the configuration (batch size, gradient accumulation, precision, callbacks) remains the same; only the learning rate is lowered to `3e-5`.

The expected results of this strategy, based on CORD-v2 benchmarks, are:
- **Val loss**: 0.2-0.5, which is lower than Strategy 1
- **Token accuracy**: 91-95%, which is better than Strategy 1
- **Training time**: 6-8 hours with early stopping

In [None]:
# Keep the 3 best checkpoints by val_loss, plus last.ckpt for resuming
trocr_checkpoint_s2 = ModelCheckpoint(
    dirpath='./trocr_checkpoints/strategy2_partial',
    filename='trocr-partial-{epoch:02d}-{val_loss:.4f}',
    monitor='val_loss',
    mode='min',
    save_top_k=3,
    save_last=True,
    verbose=True,
)

# Validation runs once per epoch, so patience=5 means 5 epochs without improvement
trocr_early_stop_s2 = EarlyStopping(
    monitor='val_loss',
    patience=5,
    mode='min',
    verbose=True,
)

trocr_csv_logger_s2 = CSVLogger(save_dir='./trocr_logs', name='strategy2_partial')

trocr_trainer_s2 = L.Trainer(
    # Partial unfreezing costs more per epoch, so it gets a smaller budget than Strategy 1
    max_epochs=25,
    # Accumulate gradients up to the documented effective batch of 32 samples
    accumulate_grad_batches=max(1, 32 // trocr_dm.batch_size),
    gradient_clip_val=1.0,
    # Mixed precision requires CUDA; fall back to full precision on CPU
    precision='16-mixed' if torch.cuda.is_available() else '32-true',
    callbacks=[trocr_checkpoint_s2, trocr_early_stop_s2],
    logger=trocr_csv_logger_s2,
    accelerator='gpu' if torch.cuda.is_available() else 'cpu',
    devices=1 if torch.cuda.is_available() else 'auto',
    log_every_n_steps=10,
)

print("TrOCR Strategy 2 Trainer configured")
print(f"    - Checkpoints: {trocr_checkpoint_s2.dirpath}")

checkpoint_files = glob.glob('./trocr_checkpoints/strategy2_partial/*.ckpt')
if checkpoint_files:
    print(f"    - Found {len(checkpoint_files)} existing checkpoint(s)")
else:
    print(f"    - No existing checkpoints found")

In [None]:
# Continue from last.ckpt when a previous run left one behind; otherwise train fresh
checkpoint_path = './trocr_checkpoints/strategy2_partial/last.ckpt'
resume_ckpt_s2 = checkpoint_path if os.path.exists(checkpoint_path) else None
if resume_ckpt_s2 is None:
    print("Starting training from scratch")
else:
    print(f"Resuming training from {checkpoint_path}")
trocr_trainer_s2.fit(trocr_model_partial, datamodule=trocr_dm, ckpt_path=resume_ckpt_s2)

## TrOCR Evaluation - Strategy 2

Same as before, after training Strategy 2, we evaluate on the test set to compare the performance against the first strategy. We expect better accuracy, since the model is adapting to visual features of receipt-specific characteristics.
- We expect a lower validation loss, due to better calibrated predictions (around 0.2-0.5)
- Higher token accuracy, around 91-95%.
- Better handling of receipt-specific visual patterns like faded text, stamps, and table structures.

The evaluation process is the same as before, we load the best checkpoint and compute metrics on the 100-receipt test set.

In [None]:
# Evaluate Strategy 2 with its lowest-val_loss checkpoint
trocr_trainer_s2.test(
    model=trocr_model_partial,
    datamodule=trocr_dm,
    ckpt_path=trocr_checkpoint_s2.best_model_path,
)

## TrOCR Training - Strategy 3: Full Fine-tuning

For the last strategy we'll implement for TrOCR, we are going to unfreeze all layers, encoder and decoder. This provides maximum adaptation to the receipt domain, but requires careful training to avoid overfitting on our small dataset of 800 receipts. The following is what changed, compared to the other models:
- **Trainable parameters**: 334M, which is equivalent to all layers
- **Frozen parameters**: 0
- **Epochs**: 20, fewer epochs to prevent overfitting.
- **Training time**: slowest since all parameters update.

The trade-offs of using this strategy are:
- Potentially the best performance, since the model can fully adapt to receipt domain
- It learns receipt-specific characteristics at both low-level and high-level visual patterns
- Has some risk of overfitting, due to the small dataset and the large model
- It is slower in training because all layers compute gradients and update weights

As for the batch size, gradient accumulation, callbacks and other configuration parameters, they remain the same as in the previous strategies, while the learning rate is lowered further to `2e-5`. We expect the following results:
- **Val loss**: 0.15-0.4, which would be the best of all strategies if there's no overfitting.
- **Token accuracy**: 93-96%, which is the highest accuracy
- **Training time**: 7-10 hours with early stopping

In [None]:
# Keep the 3 best checkpoints by val_loss, plus last.ckpt for resuming
trocr_checkpoint_s3 = ModelCheckpoint(
    dirpath='./trocr_checkpoints/strategy3_full',
    filename='trocr-full-{epoch:02d}-{val_loss:.4f}',
    monitor='val_loss',
    mode='min',
    save_top_k=3,
    save_last=True,
    verbose=True,
)

# Validation runs once per epoch, so patience=5 means 5 epochs without improvement
trocr_early_stop_s3 = EarlyStopping(
    monitor='val_loss',
    patience=5,
    mode='min',
    verbose=True,
)

trocr_csv_logger_s3 = CSVLogger(save_dir='./trocr_logs', name='strategy3_full')

trocr_trainer_s3 = L.Trainer(
    # Full fine-tuning overfits most easily, so it gets the smallest epoch budget
    max_epochs=20,
    # Accumulate gradients up to the documented effective batch of 32 samples
    accumulate_grad_batches=max(1, 32 // trocr_dm.batch_size),
    gradient_clip_val=1.0,
    # Mixed precision requires CUDA; fall back to full precision on CPU
    precision='16-mixed' if torch.cuda.is_available() else '32-true',
    callbacks=[trocr_checkpoint_s3, trocr_early_stop_s3],
    logger=trocr_csv_logger_s3,
    accelerator='gpu' if torch.cuda.is_available() else 'cpu',
    devices=1 if torch.cuda.is_available() else 'auto',
    log_every_n_steps=10,
)

print("TrOCR Strategy 3 Trainer configured")
print(f"    - Checkpoints: {trocr_checkpoint_s3.dirpath}")

checkpoint_files = glob.glob('./trocr_checkpoints/strategy3_full/*.ckpt')
if checkpoint_files:
    print(f"    - Found {len(checkpoint_files)} existing checkpoint(s)")
else:
    print(f"    - No existing checkpoints found")

In [None]:
# Continue from last.ckpt when a previous run left one behind; otherwise train fresh
checkpoint_path = './trocr_checkpoints/strategy3_full/last.ckpt'
resume_ckpt_s3 = checkpoint_path if os.path.exists(checkpoint_path) else None
if resume_ckpt_s3 is None:
    print("Starting training from scratch")
else:
    print(f"Resuming training from {checkpoint_path}")
trocr_trainer_s3.fit(trocr_model_full, datamodule=trocr_dm, ckpt_path=resume_ckpt_s3)

## TrOCR Evaluation - Strategy 3

Same as before, after training Strategy 3, we'll evaluate it on the test set to determine if full fine-tuning provides better performance or if the model overfits on the 800 receipts it got during training.
- In the best case, we expect a val loss of approximately 0.15-0.4, token accuracy of 93%-96%.
- If it overfits, the performance should be similar or worse than Strategy 2, since the model memorized training data.

The evaluation process is the same. First, we load the best checkpoint and compute metrics on the 100-receipt test set. The results will then guide our final recommendation, which we'll compare against the best Donut model tested in the next steps.

In [None]:
# Evaluate Strategy 3 with its lowest-val_loss checkpoint
trocr_trainer_s3.test(
    model=trocr_model_full,
    datamodule=trocr_dm,
    ckpt_path=trocr_checkpoint_s3.best_model_path,
)

In [None]:
try:
    logs_s1 = pd.read_csv('./trocr_logs/strategy1_frozen/version_0/metrics.csv')
    logs_s2 = pd.read_csv('./trocr_logs/strategy2_partial/version_0/metrics.csv')
    logs_s3 = pd.read_csv('./trocr_logs/strategy3_full/version_0/metrics.csv')
    
    fig, axes = plt.subplots(1, 2, figsize=(15, 5))
    
    axes[0].plot(logs_s1['epoch'], logs_s1['val_loss'], label='Strategy 1: Frozen', marker='o')
    axes[0].plot(logs_s2['epoch'], logs_s2['val_loss'], label='Strategy 2: Partial', marker='s')
    axes[0].plot(logs_s3['epoch'], logs_s3['val_loss'], label='Strategy 3: Full', marker='^')
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Validation Loss')
    axes[0].set_title('TrOCR: Validation Loss Comparison')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)
    
    axes[1].plot(logs_s1['epoch'], logs_s1['val_cer'], label='Strategy 1: Frozen', marker='o')
    axes[1].plot(logs_s2['epoch'], logs_s2['val_cer'], label='Strategy 2: Partial', marker='s')
    axes[1].plot(logs_s3['epoch'], logs_s3['val_cer'], label='Strategy 3: Full', marker='^')
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('Character Error Rate (CER)')
    axes[1].set_title('TrOCR: CER Comparison')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    print("TrOCR Strategy Comparison - Final Metrics")
    print(f"Strategy 1 (Frozen) - Best Val Loss: {logs_s1['val_loss'].min():.4f}, Best CER: {logs_s1['val_cer'].min():.4f}")
    print(f"Strategy 2 (Partial)    - Best Val Loss: {logs_s2['val_loss'].min():.4f}, Best CER: {logs_s2['val_cer'].min():.4f}")
    print(f"Strategy 3 (Full)   - Best Val Loss: {logs_s3['val_loss'].min():.4f}, Best CER: {logs_s3['val_cer'].min():.4f}")
    
except FileNotFoundError:
    print("Training logs not found. Please train the models first.")