# TrOCR — Quick Test Notebook

This notebook lets you:
- Load a **local** TrOCR checkpoint (e.g., `./runs/trocr/epoch_1`)
- Run a **single-image** prediction
- Run a **mini evaluation** (CER/WER) on a CSV split

> **Run this notebook from *inside* your `TrOCR/` folder.**

In [12]:
# ==== 0) Paths you can edit ====
import os

# --- Inputs: everything below is read by the later cells ---
CKPT_PATH = "./runs/trocr/epoch_1"   # folder written by the trainer (processor + model files)
IMG_PATH  = "../data/image_splits/validation_set_splits/V06/page_1/line_20.png"  # image for the single-image demo
VAL_CSV   = "./val_tiny_exist.csv"   # needs the columns: image_path, transcription

# --- Decoding: forwarded to model.generate() ---
DECODE_MAX_NEW_TOKENS = 64
DECODE_NUM_BEAMS = 5
DECODE_DO_SAMPLE = False

# --- Optional environment flag ---
# setdefault only fills the variable in when it is not already set.
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

'false'

In [13]:
# ==== 1) Imports & device ====
import os, torch
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

def pick_device():
    """Choose the torch device to run on.

    Preference order is CUDA, then Apple MPS, then CPU.

    Returns:
        torch.device: the first backend in that order that reports itself available.
    """
    # getattr with a default tolerates torch builds where `torch.backends.mps` does not exist.
    mps_backend = getattr(torch.backends, "mps", None)

    if torch.cuda.is_available():
        chosen = "cuda"
    elif mps_backend and mps_backend.is_available():
        chosen = "mps"
    else:
        chosen = "cpu"
    return torch.device(chosen)


device = pick_device()
print("device:", device)

device: mps


In [14]:
# ==== 2) Load processor & model from LOCAL checkpoint ====
# Check the folder up front so a wrong path gives a short, readable error.
assert os.path.isdir(CKPT_PATH), f"Checkpoint dir not found: {CKPT_PATH}"

# Load locally to avoid Hub lookups
processor = TrOCRProcessor.from_pretrained(CKPT_PATH, use_fast=False, local_files_only=True)
model = VisionEncoderDecoderModel.from_pretrained(CKPT_PATH, local_files_only=True).to(device).eval()

# Guard config for stage1-like checkpoints: fill in the ids only when the saved config lacks them.
if getattr(model.config, "decoder_start_token_id", None) is None:
    model.config.decoder_start_token_id = processor.tokenizer.bos_token_id
if getattr(model.config, "pad_token_id", None) is None:
    model.config.pad_token_id = processor.tokenizer.pad_token_id

# Keep label length modest (matches memory-friendly training advice).
# This must be min(), not max(): max() could only ever raise the limit, and it left
# a very large or unset-sentinel value untouched, so nothing was actually capped.
LABEL_MAX_LENGTH = 128
try:
    processor.tokenizer.model_max_length = min(processor.tokenizer.model_max_length, LABEL_MAX_LENGTH)
except Exception as err:
    # Still best-effort (the notebook keeps running), but no longer silent.
    print(f"warning: could not cap tokenizer.model_max_length at {LABEL_MAX_LENGTH}: {err!r}")

print("Loaded processor & model from:", CKPT_PATH)

Loaded processor & model from: ./runs/trocr/epoch_1


## Single-image prediction

In [15]:
# ==== 3) Predict on one image ====
assert os.path.isfile(IMG_PATH), f"Image not found: {IMG_PATH}"

# Decoding options, taken from the constants in the config cell.
generate_kwargs = dict(
    max_new_tokens=DECODE_MAX_NEW_TOKENS,
    num_beams=DECODE_NUM_BEAMS,
    do_sample=DECODE_DO_SAMPLE,
)

# Preprocess: open the image, convert it to RGB, run the processor, move tensors to the device.
img = Image.open(IMG_PATH).convert("RGB")
inputs = processor(images=img, return_tensors="pt").to(device)

# Generate token ids without tracking gradients.
with torch.no_grad():
    ids = model.generate(**inputs, **generate_kwargs)

# A single image was passed in, so take the first decoded string.
decoded = processor.batch_decode(ids, skip_special_tokens=True)
pred = decoded[0]
print("PRED:", pred)

PRED: # 30


## Mini evaluation (CER/WER)

In [7]:
# ==== 4) Evaluate on a CSV (tiny val) ====

import pandas as pd
from jiwer import cer, wer

assert os.path.isfile(VAL_CSV), f"CSV not found: {VAL_CSV}"
# Read every cell as literal text. With pandas' default parsing, blank, "NA" and
# "N/A" cells become NaN (which str() then turned into the reference text "nan"),
# and numeric-looking transcriptions are reformatted ("007" -> 7, "1.50" -> 1.5).
df = pd.read_csv(VAL_CSV, dtype=str, keep_default_na=False)
assert "image_path" in df.columns and "transcription" in df.columns, "CSV must have columns: image_path, transcription"

gts, preds = [], []
n_missing_image = 0   # rows whose image file does not exist
n_missing_text = 0    # rows with no ground-truth transcription
for p, gt in zip(df["image_path"], df["transcription"]):
    if not gt.strip():
        # Nothing to score against; skip before paying for a generate() call.
        n_missing_text += 1
        continue

    # Relative paths in the CSV are resolved against the parent directory.
    full = p if os.path.isabs(p) else os.path.join("..", p)
    if not os.path.isfile(full):
        # skip missing rows
        n_missing_image += 1
        continue

    img = Image.open(full).convert("RGB")
    inputs = processor(images=img, return_tensors="pt").to(device)
    with torch.no_grad():
        ids = model.generate(
            **inputs,
            max_new_tokens=DECODE_MAX_NEW_TOKENS,
            num_beams=DECODE_NUM_BEAMS,
            do_sample=DECODE_DO_SAMPLE,
        )
    preds.append(processor.batch_decode(ids, skip_special_tokens=True)[0])
    gts.append(gt)

print("samples:", len(preds))
if n_missing_image or n_missing_text:
    # Say how much of the CSV was left out so the metrics are not over-read.
    print(f"skipped: {n_missing_image} missing image file(s), {n_missing_text} empty transcription(s)")
if len(preds):
    print("CER:", cer(gts, preds))
    print("WER:", wer(gts, preds))
else:
    print("No valid samples found (check VAL_CSV paths and that files exist)")

samples: 37
CER: 0.8782816229116945
WER: 1.0158730158730158


In [19]:
import os, pandas as pd
from PIL import Image
import torch
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from jiwer import cer, wer

# Evaluate the 16-sample overfit checkpoint on the CSV it was trained on.
CKPT = "./runs/overfit16_strict/epoch_15"        # change to your best epoch
CSV  = "./overfit16.csv"

# Same up-front path checks as the main load/eval cells.
assert os.path.isdir(CKPT), f"Checkpoint dir not found: {CKPT}"
assert os.path.isfile(CSV), f"CSV not found: {CSV}"

device = ("cuda" if torch.cuda.is_available()
          else "mps" if getattr(torch.backends,"mps",None) and torch.backends.mps.is_available()
          else "cpu")
proc  = TrOCRProcessor.from_pretrained(CKPT, use_fast=False, local_files_only=True)
model = VisionEncoderDecoderModel.from_pretrained(CKPT, local_files_only=True).to(device).eval()

# Fill in the special-token ids when the saved config lacks them, as the other loading cells do.
if getattr(model.config, "decoder_start_token_id", None) is None:
    model.config.decoder_start_token_id = proc.tokenizer.bos_token_id
if getattr(model.config, "pad_token_id", None) is None:
    model.config.pad_token_id = proc.tokenizer.pad_token_id

# Read every cell as literal text: default parsing turns blank/"NA"/"N/A" cells into
# NaN (then the reference "nan" via str()) and reformats numeric-looking transcriptions.
df = pd.read_csv(CSV, dtype=str, keep_default_na=False)
assert "image_path" in df.columns and "transcription" in df.columns, "CSV must have columns: image_path, transcription"

preds, gts = [], []
for p, gt in zip(df["image_path"], df["transcription"]):
    if not gt.strip():
        continue  # no ground truth to score against
    # CSV is relative to repo root; absolute paths are used as-is.
    full = p if os.path.isabs(p) else os.path.join("..", p)
    if not os.path.isfile(full):
        continue
    img = Image.open(full).convert("RGB")
    inp = proc(images=img, return_tensors="pt").to(device)
    with torch.no_grad():
        ids = model.generate(
            **inp,
            num_beams=8, do_sample=False, max_new_tokens=64,
            no_repeat_ngram_size=3, repetition_penalty=1.2,
            length_penalty=0.8, early_stopping=True,
        )
    preds.append(proc.batch_decode(ids, skip_special_tokens=True)[0])
    gts.append(gt)

print("samples:", len(preds))
if preds:
    print("CER:", cer(gts, preds))
    print("WER:", wer(gts, preds))
else:
    # Do not hand empty lists to the metric functions.
    print("No valid samples found (check CSV paths and that files exist)")
list(zip(gts, preds))[:5]  # peek first 5


samples: 16
CER: 0.8065693430656934
WER: 0.9615384615384616


[('Multivitamin', ': tablet #'),
 ('mefenamic acid 500mg/capsule', ': tablet x'),
 ('S:', '#'),
 ('Label: 1 tablet after breakfast', ': tablet x'),
 ('Telmisartan #200', 'Bioene')]

In [20]:
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from PIL import Image
import torch, os, re

# Single-image check against the 1-sample overfit checkpoint.
ckpt = "./runs/overfit1_strict/epoch_60"  # your checkpoint
img  = "../data/image_splits/validation_set_splits/V06/page_1/line_7.png"  # your test image

# Check both paths before the (slow) model load, as the main cells do.
assert os.path.isdir(ckpt), f"Checkpoint dir not found: {ckpt}"
assert os.path.isfile(img), f"Image not found: {img}"

device = ("cuda" if torch.cuda.is_available()
          else "mps" if getattr(torch.backends,"mps",None) and torch.backends.mps.is_available()
          else "cpu")

proc  = TrOCRProcessor.from_pretrained(ckpt, use_fast=False, local_files_only=True)
model = VisionEncoderDecoderModel.from_pretrained(ckpt, local_files_only=True).to(device).eval()

# Make sure config has proper ids (safety)
if getattr(model.config, "decoder_start_token_id", None) is None:
    model.config.decoder_start_token_id = proc.tokenizer.bos_token_id
if getattr(model.config, "pad_token_id", None) is None:
    model.config.pad_token_id = proc.tokenizer.pad_token_id
eos_id = proc.tokenizer.eos_token_id

im  = Image.open(img).convert("RGB")
inp = proc(images=im, return_tensors="pt").to(device)

# Use beam search + anti-repetition; no gradients are needed for inference,
# matching how the other prediction cells call generate().
with torch.no_grad():
    ids = model.generate(
        **inp,
        num_beams=8,                 # beam search; length_penalty below applies to beams
        do_sample=False,
        max_new_tokens=16,           # short since GT is "Daily"
        no_repeat_ngram_size=3,      # block short repeats
        repetition_penalty=1.3,      # discourage token reuse
        length_penalty=0.8,          # used only with beams
        eos_token_id=eos_id,         # token id that ends a hypothesis
        pad_token_id=model.config.pad_token_id,
    )

pred = proc.batch_decode(ids, skip_special_tokens=True)[0].strip()

# Optional tiny post-process: collapse immediate word duplicates ("Daily Daily" -> "Daily").
# This also collapses repeats that are genuinely in the text, so it is cosmetic only.
pred = re.sub(r"\b(\w+)(\s+\1\b)+", r"\1", pred)

print("PRED:", pred)


PRED: Daily


In [None]:
%%bash
# Standalone evaluation in a fresh Python process.
# The %%bash magic is required: the heredoc below is shell syntax, which a Python
# code cell cannot parse.
# The script's paths (TrOCR/..., BASE=".") are relative to the repo root, while this
# notebook runs from inside TrOCR/, so step up one directory first.
cd ..
python - <<'PY'
import os, torch, pandas as pd
from PIL import Image
from jiwer import cer, wer
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

CKPT="TrOCR/runs/hw_refine_lr3e6/epoch_3"
CSV ="TrOCR/val.csv"; BASE="."

proc  = TrOCRProcessor.from_pretrained(CKPT, use_fast=False, local_files_only=True)
model = VisionEncoderDecoderModel.from_pretrained(CKPT, local_files_only=True).eval()

preds,gts=[],[]
# Read every cell as literal text: default parsing turns blank/"NA"/"N/A" cells into
# NaN (then the reference "nan" via str()) and reformats numeric-looking transcriptions.
df = pd.read_csv(CSV, dtype=str, keep_default_na=False)
for p, gt in zip(df["image_path"], df["transcription"]):
    if not gt.strip():
        continue  # no ground truth to score against
    p = p if os.path.isabs(p) else os.path.join(BASE,p)
    if not os.path.isfile(p):
        continue
    im = Image.open(p).convert("RGB")
    inp = proc(images=im, return_tensors="pt")
    ids = model.generate(**inp, num_beams=10, do_sample=False, max_new_tokens=64,
                         no_repeat_ngram_size=4, repetition_penalty=1.5, length_penalty=0.8)
    preds.append(proc.batch_decode(ids, skip_special_tokens=True)[0])
    gts.append(gt)
print("samples:",len(preds))
if preds:
    print("CER/WER:", cer(gts,preds), wer(gts,preds))
else:
    # Do not hand empty lists to the metric functions.
    print("No valid samples found (check CSV paths and that files exist)")
# Show up to the first 20 ground-truth / prediction pairs.
for i,(gt,pr) in enumerate(list(zip(gts,preds))[:20]):
    print(f"\n[{i}] GT: {gt}\n     PR: {pr}")
PY