# TrOCR OCR Notebook – V3

This is the **Final TrOCR notebook** with:
- Google Drive mounting (data and outputs persist after the Colab session ends)
- Correct handling of the provided CSV schema
- 5‑Fold Cross‑Validation
- Augmentation, newline token handling, and logging



## 1. Mount Google Drive

In [None]:
from google.colab import drive
drive.mount('/content/drive/')

Mounted at /content/drive/


In [None]:
import os
os.listdir('/content/drive/MyDrive/ocr_data')

['training_data.csv', 'images']

In [None]:
import pandas as pd, os

df = pd.read_csv('/content/drive/MyDrive/ocr_data/training_data.csv')
missing = [f for f in df['file_name']
           if not os.path.exists(f'/content/drive/MyDrive/ocr_data/images/{f}')]

print("Missing:", missing)

Missing: []


## 2. Install Dependencies

In [None]:
!pip install -q transformers accelerate evaluate albumentations opencv-python pandas scikit-learn

In [None]:
!pip install jiwer

Collecting jiwer
  Downloading jiwer-4.0.0-py3-none-any.whl.metadata (3.3 kB)
Collecting rapidfuzz>=3.9.7 (from jiwer)
  Downloading rapidfuzz-3.14.3-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (12 kB)
Downloading jiwer-4.0.0-py3-none-any.whl (23 kB)
Downloading rapidfuzz-3.14.3-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (3.2 MB)
Installing collected packages: rapidfuzz, jiwer
Successfully installed jiwer-4.0.0 rapidfuzz-3.14.3


## 3. Imports & Reproducibility

In [None]:
import torch, os, random
import numpy as np
import pandas as pd
import cv2
from PIL import Image
import albumentations as A
import evaluate
from sklearn.model_selection import KFold
from transformers import (
    TrOCRProcessor,
    VisionEncoderDecoderModel,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    EarlyStoppingCallback
)
from torch.utils.data import Dataset

SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)


## 4. Dataset Configuration

In [None]:
# UPDATE THESE PATHS ONLY
BASE_DIR = '/content/drive/MyDrive/ocr_data'
CSV_PATH = f'{BASE_DIR}/training_data.csv'
IMAGE_DIR = f'{BASE_DIR}/images'

df = pd.read_csv(CSV_PATH)

# Map CSV → model inputs
image_paths = [os.path.join(IMAGE_DIR, f) for f in df['file_name']]
labels = df['transcription_text'].tolist()

print('Total samples:', len(image_paths))

Total samples: 691


## 5. Preprocessing & Augmentation

In [None]:
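# Toggle for experiments. NOTE: the flag is global, so when enabled the
# augmentation is applied to validation images as well as training images.
USE_AUGMENTATION = False

# NOTE: some Albumentations releases warn that `var_limit` is not a valid
# GaussNoise argument; check the installed version if the noise looks wrong.
augment = A.Compose([
    A.GaussNoise(var_limit=(10, 50), p=0.5),
    A.ElasticTransform(alpha=1, sigma=50, p=0.3),
    A.Rotate(limit=5, border_mode=cv2.BORDER_CONSTANT, p=0.5)
])

def preprocess_image(path):
    img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if img is None:  # cv2.imread returns None instead of raising on a bad path
        raise FileNotFoundError(f'Could not read image: {path}')
    # Fixed square resize; this does not preserve the original aspect ratio
    img = cv2.resize(img, (384, 384))
    if USE_AUGMENTATION:
        img = augment(image=img)['image']
    # Binarize: threshold = Gaussian-weighted mean of a 61x61 neighborhood minus 11
    img = cv2.adaptiveThreshold(
        img, 255,
        cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY, 61, 11
    )
    # The TrOCR image processor expects 3-channel input
    return Image.fromarray(img).convert('RGB')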



## 6. Dataset Class

In [None]:
class OCRDataset(Dataset):
    def __init__(self, paths, texts, processor):
        self.paths = paths
        self.texts = texts
        self.processor = processor

    def __len__(self):
        return len(self.paths)

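    def __getitem__(self, idx):
        image = preprocess_image(self.paths[idx])
        text = self.texts[idx]
        enc = self.processor(images=image, text=text,
                             return_tensors='pt', padding='max_length')
        labels = enc.labels.squeeze(0)
        # Replace padding with -100 so the loss ignores those positions
        labels[labels == self.processor.tokenizer.pad_token_id] = -100
        return {
            'pixel_values': enc.pixel_values.squeeze(0),
            'labels': labels
        }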
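
With `padding='max_length'`, most positions in each label sequence are padding. A common practice when fine-tuning sequence-to-sequence models in PyTorch is to replace those positions with `-100`, the default `ignore_index` of the cross-entropy loss, so that padding does not contribute to the loss. The sketch below shows the idea on plain Python lists; `mask_padding` and the toy token ids are hypothetical and not part of the notebook.

```python
def mask_padding(label_ids, pad_token_id, ignore_index=-100):
    """Return a copy of label_ids with padding positions set to ignore_index."""
    return [ignore_index if tok == pad_token_id else tok for tok in label_ids]


# Toy sequence with made-up ids: 0 = start, 2 = end, 1 = padding
padded = [0, 713, 16, 2, 1, 1, 1]
masked = mask_padding(padded, pad_token_id=1)
print(masked)  # [0, 713, 16, 2, -100, -100, -100]
```

This is also why `compute_metrics` has to map `-100` back to the pad token id before decoding.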


## 7. Model & Tokenizer (Newline Support)

In [None]:
processor = TrOCRProcessor.from_pretrained('microsoft/trocr-base-handwritten')
tokenizer = processor.tokenizer

# Register '\n' as its own token so line breaks in transcriptions are kept
if '\n' not in tokenizer.get_vocab():
    tokenizer.add_tokens(['\n'])

model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-base-handwritten')
# Grow the decoder embedding matrix to cover any newly added token
model.decoder.resize_token_embeddings(len(tokenizer))

# Token ids used for padding, stopping, and starting generation
model.config.pad_token_id = tokenizer.pad_token_id
model.config.eos_token_id = tokenizer.sep_token_id
model.config.decoder_start_token_id = tokenizer.cls_token_id


The image processor of type `ViTImageProcessor` is now loaded as a fast processor by default, even if the model checkpoint was saved with a slow processor. This is a breaking change and may produce slightly different outputs. To continue using the slow processor, instantiate this class with `use_fast=False`. 


VisionEncoderDecoderModel LOAD REPORT from: microsoft/trocr-base-handwritten
Key                         | Status  | 
----------------------------+---------+-
encoder.pooler.dense.weight | MISSING | 
encoder.pooler.dense.bias   | MISSING | 

Notes:
- MISSING	:those params were newly initialized because missing from the checkpoint. Consider training on your downstream task.


The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


## 8. Metrics (CER / WER)

In [None]:
import evaluate
import numpy as np

# 1. Load the evaluation tools (prevents NameError)
cer_metric = evaluate.load("cer")
wer_metric = evaluate.load("wer")

def compute_metrics(pred):
    labels_ids = pred.label_ids
    pred_ids = pred.predictions

    # 2. Fix the OverflowError: Replace -100 with the pad_token_id
    # This allows the tokenizer to decode the labels without crashing
    labels_ids[labels_ids == -100] = tokenizer.pad_token_id

    # Some versions of the trainer return pred_ids as a tuple or with -100s
    if isinstance(pred_ids, tuple):
        pred_ids = pred_ids[0]
    pred_ids[pred_ids == -100] = tokenizer.pad_token_id

    # 3. Decode the IDs into strings
    # We use tokenizer here because it handles the added tokens (like \n) correctly
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)

    # 4. Compute character error rate (CER) and word error rate (WER)
    cer = cer_metric.compute(predictions=pred_str, references=label_str)
    wer = wer_metric.compute(predictions=pred_str, references=label_str)

    return {"cer": cer, "wer": wer}
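
To make the two metrics concrete: CER is the character-level edit distance between prediction and reference divided by the number of reference characters, and WER is the same ratio computed over whitespace-separated words. The following is a minimal pure-Python sketch of that definition, aggregated over all pairs. It is an illustration only, not the implementation used by `evaluate`, and the helper names are hypothetical.

```python
def edit_distance(ref, hyp):
    """Levenshtein distance between two sequences (strings or lists of words)."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, 1):
        curr = [i]
        for j, h in enumerate(hyp, 1):
            cost = 0 if r == h else 1
            curr.append(min(prev[j] + 1,          # deletion
                            curr[j - 1] + 1,      # insertion
                            prev[j - 1] + cost))  # substitution or match
        prev = curr
    return prev[-1]


def char_error_rate(references, predictions):
    """Total character edits divided by total reference characters."""
    errors = sum(edit_distance(r, p) for r, p in zip(references, predictions))
    return errors / sum(len(r) for r in references)


def word_error_rate(references, predictions):
    """Total word edits divided by total reference words."""
    errors = sum(edit_distance(r.split(), p.split())
                 for r, p in zip(references, predictions))
    return errors / sum(len(r.split()) for r in references)


refs, preds = ["hello world"], ["hallo world"]
print(round(char_error_rate(refs, preds), 4))  # 0.0909 (1 edit / 11 characters)
print(word_error_rate(refs, preds))            # 0.5 (1 wrong word / 2 words)
```

Because insertions count as errors, both rates can exceed 1.0 when predictions are much longer than the references.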

## 9. 5‑Fold Cross‑Validation (Colab‑Safe)

In [None]:
import os
import shutil
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer, EarlyStoppingCallback
from sklearn.model_selection import KFold

# MANUAL CONTROLLER
# Change this to 1, 2, 3, 4, or 5 to run a specific fold
TARGET_FOLD = 5


# 1. Setup Logging and KFold
kf = KFold(n_splits=5, shuffle=True, random_state=SEED)
LOG_DIR = f'{BASE_DIR}/logs'
os.makedirs(LOG_DIR, exist_ok=True)

for fold, (train_idx, val_idx) in enumerate(kf.split(image_paths), 1):
    # Only run the fold we are currently targeting
    if fold != TARGET_FOLD:
        continue

    print(f'===== STARTING MANUAL RUN: Fold {fold} =====')

    # 2. Prepare Fold Datasets
    train_ds = OCRDataset([image_paths[i] for i in train_idx],
                           [labels[i] for i in train_idx], processor)
    val_ds = OCRDataset([image_paths[i] for i in val_idx],
                         [labels[i] for i in val_idx], processor)

    # 3. Training Arguments (Optimized for T4 and Thesis reporting)
    training_args = Seq2SeqTrainingArguments(
        output_dir=f'{BASE_DIR}/logs/fold_{fold}',
        predict_with_generate=True,
        eval_strategy="epoch",
        save_strategy="epoch",

        # Memory optimizations
        per_device_train_batch_size=2,
        per_device_eval_batch_size=2,
        gradient_accumulation_steps=4,
        fp16=True,
        optim="adamw_torch_fused",

        learning_rate=5e-5,
        num_train_epochs=8,
        weight_decay=0.01,
        load_best_model_at_end=True,
        metric_for_best_model="cer",
        greater_is_better=False,
        save_total_limit=1,
        report_to="none"
    )

    # 4. Initialize Trainer. NOTE: `model` is the global instance from Section 7,
    # so re-run that cell before each fold to start from the pretrained weights.
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        compute_metrics=compute_metrics,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=2)]
    )

    # 5. Execute Training
    trainer.train()

    # 6. Final Evaluation and EXPLICIT SAVE
    metrics = trainer.evaluate()

    # This ensures the best model weights are saved cleanly outside of a checkpoint folder
    trainer.save_model(f'{BASE_DIR}/logs/fold_{fold}/final_model')

    # Save Results
    with open(f'{LOG_DIR}/ocr_fold_{fold}_metrics.txt', 'w') as f:
        for k, v in metrics.items():
            f.write(f'{k}: {v}\n')

    # AUTOMATIC CHECKPOINT CLEANUP
    # This loop finds all "checkpoint-XXX" folders and deletes them to save Drive space
    print(f" Starting cleanup for Fold {fold}...")
    fold_dir = f'{BASE_DIR}/logs/fold_{fold}'
    for item in os.listdir(fold_dir):
        item_path = os.path.join(fold_dir, item)
        if os.path.isdir(item_path) and item.startswith("checkpoint"):
            print(f"Removing bulky checkpoint: {item}")
            shutil.rmtree(item_path)

    print(f"Fold {fold} Completed. Final Model and Metrics saved. Checkpoints cleared.")
    break

===== STARTING MANUAL RUN: Fold 5 =====


| Epoch | Training Loss | Validation Loss | CER | WER |
|---|---|---|---|---|
| 1 | No log | 0.166823 | 0.73767 | 0.893939 |
| 2 | No log | 0.139321 | 0.831152 | 0.97619 |
| 3 | No log | 0.126901 | 0.735783 | 0.907648 |
| 4 | No log | 0.122557 | 0.734524 | 0.91342 |
| 5 | No log | 0.120448 | 0.728234 | 0.894661 |
| 6 | No log | 0.118947 | 0.712632 | 0.853535 |
| 7 | No log | 0.120623 | 0.699925 | 0.849928 |
| 8 | 1.143754 | 0.120086 | 0.694011 | 0.83189 |
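
The run above completed all 8 epochs even though `EarlyStoppingCallback(early_stopping_patience=2)` was active. A simplified simulation of patience-based early stopping on the logged CER values shows why: CER got worse only once (epoch 2) and then improved on the best value at every later epoch, so the patience counter never reached 2. This is a sketch of the general idea, assuming a lower-is-better metric and a zero improvement threshold; `early_stop_epoch` is a hypothetical helper, not the callback's actual code.

```python
def early_stop_epoch(scores, patience=2):
    """Return the 1-based epoch at which training would stop, or None.

    Assumes lower scores are better and any strict improvement counts.
    """
    best = None
    bad_epochs = 0
    for epoch, score in enumerate(scores, 1):
        if best is None or score < best:
            best = score
            bad_epochs = 0          # improvement resets the counter
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                return epoch
    return None


# CER per epoch, copied from the Fold 5 log
cer_by_epoch = [0.73767, 0.831152, 0.735783, 0.734524,
                0.728234, 0.712632, 0.699925, 0.694011]
print(early_stop_epoch(cer_by_epoch))  # None: training is never stopped early
```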


There were missing keys in the checkpoint model loaded: ['decoder.output_projection.weight'].


 Starting cleanup for Fold 5...
Removing bulky checkpoint: checkpoint-560
Fold 5 Completed. Final Model and Metrics saved. Checkpoints cleared.
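
The checkpoint name `checkpoint-560` can be reproduced from the configuration. The arithmetic below assumes a single GPU, that `KFold` gives the first `n_samples % n_splits` folds one extra sample (as scikit-learn documents), and that optimizer steps per epoch are rounded up; the rounding has differed between `transformers` versions, so treat that part as an assumption. `fold_sizes` is a hypothetical helper.

```python
import math


def fold_sizes(n_samples, n_splits):
    """Validation-fold sizes: the first (n_samples % n_splits) folds get one extra."""
    base, extra = divmod(n_samples, n_splits)
    return [base + 1 if i < extra else base for i in range(n_splits)]


n_samples, n_splits = 691, 5       # dataset size and KFold setting from the notebook
batch_size, grad_accum, epochs = 2, 4, 8

val_size = fold_sizes(n_samples, n_splits)[4]               # fold 5
train_size = n_samples - val_size
batches_per_epoch = math.ceil(train_size / batch_size)
steps_per_epoch = math.ceil(batches_per_epoch / grad_accum)  # optimizer steps
total_steps = steps_per_epoch * epochs

print(val_size, train_size, batches_per_epoch, steps_per_epoch, total_steps)
# 138 553 277 70 560
```

The effective batch size is 2 × 4 = 8 images per optimizer step. With the default `logging_steps` of 500, the first training-loss entry is written at step 500, which falls inside epoch 8 (steps 491 to 560). That is consistent with the "No log" entries for epochs 1 to 7 in the table.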


## How to Run
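1. First run: set `USE_AUGMENTATION = False` and select a single fold with `TARGET_FOLD`.
2. Second run: set `USE_AUGMENTATION = True`, again with a single fold selected.
3. Optional: run the remaining folds one per session by changing `TARGET_FOLD`.

Re-run the Section 7 cell before each fold so training starts from the pretrained weights.

Metrics and each fold's final model are saved to Google Drive under `logs/`; intermediate checkpoints are deleted once a fold finishes.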
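
Once several folds have been run, the per-fold metrics files can be combined into a mean and standard deviation. The sketch below assumes the `key: value` format written by the cross-validation cell and that the CER key is `eval_cer` (the `eval_` prefix that `trainer.evaluate()` normally adds). `read_metrics` and `summarize` are hypothetical helpers, and the numbers in the demo are placeholders, not results from this notebook.

```python
import statistics
import tempfile
from pathlib import Path


def read_metrics(path):
    """Parse a 'key: value' metrics file into a dict of floats."""
    metrics = {}
    for line in Path(path).read_text().splitlines():
        key, sep, value = line.partition(': ')
        if not sep:
            continue
        try:
            metrics[key] = float(value)
        except ValueError:
            pass  # skip non-numeric entries
    return metrics


def summarize(log_dir, key='eval_cer'):
    """Return (mean, sample standard deviation) of `key` across fold files."""
    files = sorted(Path(log_dir).glob('ocr_fold_*_metrics.txt'))
    values = [read_metrics(f)[key] for f in files]
    std = statistics.stdev(values) if len(values) > 1 else 0.0
    return statistics.mean(values), std


# Demo with placeholder values in a temporary directory
with tempfile.TemporaryDirectory() as tmp:
    for fold, cer in enumerate([0.5, 0.25], 1):
        Path(tmp, f'ocr_fold_{fold}_metrics.txt').write_text(
            f'eval_loss: 0.1\neval_cer: {cer}\n')
    demo_mean, demo_std = summarize(tmp)

print(demo_mean)  # 0.375
```

In the notebook, `summarize(LOG_DIR)` would be the equivalent call once all fold files exist.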