In [1]:
# Mount Google Drive so the competition data under MyDrive is visible to this VM.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [5]:
!pip install -q transformers accelerate pillow pandas jiwer


In [13]:
import os
import shutil

import pandas as pd
from PIL import Image
import torch
from torch.utils.data import Dataset
from transformers import (
    TrOCRProcessor,
    VisionEncoderDecoderModel,
    TrainingArguments,
    Trainer
)
from transformers.trainer_utils import get_last_checkpoint

# ===============================
# PATH CONFIG
# ===============================
# Dataset lives on the mounted Google Drive; relative image paths in the
# CSVs are resolved against IMAGES_DIR.
BASE_DIR = "/content/drive/MyDrive/handwriting_competition"
IMAGES_DIR = os.path.join(BASE_DIR, "images")

TRAIN_CSV = os.path.join(BASE_DIR, "train.csv")
VAL_CSV   = os.path.join(BASE_DIR, "val.csv")

MODEL_NAME = "microsoft/trocr-base-handwritten"
# Not under the Drive mount, so checkpoints here presumably do not survive a
# runtime reset — export anything worth keeping to Drive.
OUTPUT_DIR = "/content/trocr_answer_ocr"

# ===============================
# TRAIN CONFIG
# ===============================
BATCH_SIZE = 8
EPOCHS = 5
MAX_TEXT_LENGTH = 128  # label length in tokens; longer transcriptions are truncated

# Fall back to CPU so the notebook still runs (slowly) without a GPU instead
# of failing later on `.to(device)`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


Using device: cuda


In [14]:
class OCRDataset(Dataset):
    """Image/transcription pairs read from a CSV, prepared for TrOCR fine-tuning.

    The CSV must contain an ``image_path`` column (absolute, or relative to
    ``images_dir``) and a ``text`` column with the target transcription.
    Both columns are read as literal strings: numeric-looking answers such as
    "007" or "0.50" keep their exact form, and cells like "", "NA" or "null"
    are kept verbatim instead of becoming NaN (which would stringify to "nan").

    Each item is a dict with:
        pixel_values: the processor's image tensor (batch dim squeezed away).
        labels: token ids padded/truncated to ``max_length``, with padding
            replaced by -100 so it is ignored by the loss.

    Args:
        csv_file: Path or file-like object accepted by ``pd.read_csv``.
        processor: TrOCR-style processor — callable on an image and exposing
            a ``tokenizer`` attribute.
        images_dir: Base directory for relative image paths; defaults to the
            notebook-level ``IMAGES_DIR``.
        max_length: Label length in tokens; defaults to ``MAX_TEXT_LENGTH``.
    """

    def __init__(self, csv_file, processor, images_dir=None, max_length=None):
        # dtype=str + keep_default_na=False: no numeric inference, no NaN parsing.
        self.data = pd.read_csv(
            csv_file,
            dtype={"image_path": str, "text": str},
            keep_default_na=False,
        )
        self.processor = processor
        self.images_dir = IMAGES_DIR if images_dir is None else images_dir
        self.max_length = MAX_TEXT_LENGTH if max_length is None else max_length

    def __len__(self):
        return len(self.data)

    def _resolve_image_path(self, raw_path):
        """Return ``raw_path`` unchanged if absolute, else joined onto ``images_dir``."""
        image_path = str(raw_path)
        if os.path.isabs(image_path):
            return image_path
        return os.path.join(self.images_dir, image_path)

    def _encode_labels(self, text):
        """Tokenize ``text`` to ``max_length`` label ids, padding replaced by -100."""
        tokenizer = self.processor.tokenizer
        input_ids = tokenizer(
            text,
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt",
        ).input_ids.squeeze(0)
        return input_ids.masked_fill(input_ids == tokenizer.pad_token_id, -100)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        image_path = self._resolve_image_path(row["image_path"])

        # The context manager closes the underlying file; convert() returns a
        # fully loaded copy, so the image stays usable after the block.
        with Image.open(image_path) as img:
            image = img.convert("RGB")

        pixel_values = self.processor(
            image, return_tensors="pt"
        ).pixel_values.squeeze(0)

        return {
            "pixel_values": pixel_values,
            "labels": self._encode_labels(str(row["text"])),
        }


In [20]:
# Load the pretrained TrOCR processor (image preprocessing + tokenizer) and model.
# NOTE: the Trainer keeps a reference to the `model` object that existed when it
# was built. Re-running this cell after the Trainer cell does NOT change what
# trainer.train() optimizes — re-run the Trainer cell afterwards as well.
processor = TrOCRProcessor.from_pretrained(MODEL_NAME)

model = VisionEncoderDecoderModel.from_pretrained(MODEL_NAME).to(device)

# ===============================
# FREEZE VISION ENCODER
# ===============================
FREEZE_ENCODER = True  # set False to fine-tune the encoder as well

for param in model.encoder.parameters():
    param.requires_grad = not FREEZE_ENCODER

# Report what will actually be trained instead of asserting it in a fixed message.
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
n_total = sum(p.numel() for p in model.parameters())
print(f"Encoder frozen: {FREEZE_ENCODER} | trainable params: {n_trainable:,} / {n_total:,}")

# Special-token ids VisionEncoderDecoderModel uses when shifting `labels`
# right to build decoder inputs; vocab_size mirrors the decoder's vocabulary.
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.vocab_size = model.config.decoder.vocab_size


Some weights of VisionEncoderDecoderModel were not initialized from the model checkpoint at microsoft/trocr-base-handwritten and are newly initialized: ['encoder.pooler.dense.bias', 'encoder.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


✅ Vision encoder frozen. Training decoder only.


In [16]:
# Build the train/validation datasets. The dataset only touches the
# `image_path`/`text` columns inside __getitem__, so validate the schema here
# to fail fast rather than in the first training step.
REQUIRED_COLUMNS = {"image_path", "text"}

train_dataset = OCRDataset(TRAIN_CSV, processor)
val_dataset   = OCRDataset(VAL_CSV, processor)

for split_name, split in (("train", train_dataset), ("val", val_dataset)):
    missing = REQUIRED_COLUMNS - set(split.data.columns)
    assert not missing, f"{split_name} CSV is missing columns: {sorted(missing)}"
    assert len(split) > 0, f"{split_name} CSV has no rows"

print("Train samples:", len(train_dataset))
print("Val samples:", len(val_dataset))


Train samples: 5511
Val samples: 518


In [17]:
# Evaluate and checkpoint once per epoch. fp16 mixed precision needs CUDA,
# so it is enabled only when a GPU is present.
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    eval_strategy="epoch",     # newer name (transformers >= 4.41); older releases call it `evaluation_strategy`
    save_strategy="epoch",     # must match eval_strategy for load_best_model_at_end
    num_train_epochs=EPOCHS,
    fp16=torch.cuda.is_available(),
    logging_steps=100,
    save_total_limit=2,
    # Track the lowest validation loss: the best checkpoint is always retained
    # despite save_total_limit, and its weights are reloaded when training ends.
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    report_to="none",
    remove_unused_columns=False  # skip Trainer's signature-based filtering of dataset items
)


In [18]:
# Wire the model, training arguments and both dataset splits into a Trainer.
trainer_inputs = {
    "model": model,
    "args": training_args,
    "train_dataset": train_dataset,
    "eval_dataset": val_dataset,
}
trainer = Trainer(**trainer_inputs)


In [21]:
# Resume from the newest checkpoint in output_dir if there is one, otherwise
# train from scratch. (resume_from_checkpoint=True raises a ValueError when
# output_dir contains no checkpoint, e.g. on a fresh runtime.)
last_checkpoint = (
    get_last_checkpoint(training_args.output_dir)
    if os.path.isdir(training_args.output_dir)
    else None
)
print("Resuming from:", last_checkpoint or "<none - training from scratch>")

train_result = trainer.train(resume_from_checkpoint=last_checkpoint)
train_result


There were missing keys in the checkpoint model loaded: ['decoder.output_projection.weight'].


Epoch,Training Loss,Validation Loss
3,0.6351,1.02723
4,0.2978,0.829738
5,0.1085,0.773998


TrainOutput(global_step=3445, training_loss=0.23456071734601425, metrics={'train_runtime': 3031.5299, 'train_samples_per_second': 9.089, 'train_steps_per_second': 1.136, 'total_flos': 2.0654916127061705e+19, 'train_loss': 0.23456071734601425, 'epoch': 5.0})

In [24]:
# Export the trained checkpoint to Drive. Prefer the best checkpoint tracked by
# the Trainer (only set when metric_for_best_model is configured); otherwise
# fall back to the newest checkpoint instead of a hard-coded step number.
# Destination name kept for compatibility; the copied checkpoint may not be epoch 5.
EXPORT_DIR = "/content/drive/MyDrive/Custom_ocr/trocr_best_epoch5"

best_checkpoint = trainer.state.best_model_checkpoint
if not (best_checkpoint and os.path.isdir(best_checkpoint)):
    best_checkpoint = get_last_checkpoint(training_args.output_dir)
if best_checkpoint is None:
    raise FileNotFoundError(f"No checkpoint found in {training_args.output_dir}")

# copytree creates parent dirs; dirs_exist_ok avoids `cp -r` nesting the
# checkpoint inside EXPORT_DIR when the cell is re-run.
shutil.copytree(best_checkpoint, EXPORT_DIR, dirs_exist_ok=True)
# The Trainer was not given the processor, so checkpoints lack its files;
# save them alongside the weights so the export can be loaded on its own.
processor.save_pretrained(EXPORT_DIR)
print(f"Copied {best_checkpoint} -> {EXPORT_DIR}")
