In [None]:
# ✅ Suppress tokenizer parallelism warning
# The variable is assigned before the remaining libraries are imported below,
# so it is already present in the process environment when they load.
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# ✅ Import libraries
import torch
from datasets import load_dataset, Image as DatasetsImage
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from PIL import Image
from jiwer import cer, wer, compute_measures
from tqdm import tqdm

# ✅ Load model and processor
# The processor and the model weights are both read from the same
# checkpoint directory.
model_path = "./trocr-iam-finetuned"  # change if needed
processor = TrOCRProcessor.from_pretrained(model_path)
model = VisionEncoderDecoderModel.from_pretrained(model_path)

# Prefer the GPU when PyTorch reports one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Move the weights to the chosen device, then switch to evaluation mode.
model.to(device)
model.eval()

# ✅ Load IAM dataset
dataset = load_dataset("gagan3012/IAM")
# Cast the "image" column to the datasets Image feature; the evaluation loop
# calls .convert("RGB") on each sample's image.
dataset = dataset.cast_column("image", DatasetsImage())

# ✅ Use test split or fallback to part of train, then take at most 50 samples
max_samples = 50  # upper bound on the number of evaluated samples

if "test" in dataset:
    # The test split is used in its stored order (no shuffling).
    eval_split = dataset["test"]
else:
    print("⚠️ No test split found. Using part of train split.")
    # Fixed seed keeps the fallback subset reproducible across runs.
    eval_split = dataset["train"].shuffle(seed=123)

# Clamp to the split size: select() rejects indices past the end of the
# split, so a split with fewer than `max_samples` rows would otherwise fail.
test_data = eval_split.select(range(min(max_samples, len(eval_split))))

# ✅ Run inference on limited samples
# Predictions and reference transcriptions, kept index-aligned.
preds = []
refs = []

for sample in tqdm(test_data, desc="Evaluating"):
    # Convert to RGB before handing the image to the processor.
    rgb_image = sample["image"].convert("RGB")
    encoding = processor(images=rgb_image, return_tensors="pt")
    pixel_values = encoding.pixel_values.to(device)

    # Gradients are not needed for generation.
    with torch.no_grad():
        generated_ids = model.generate(pixel_values, max_length=128)

    # One image is passed per call; take the first decoded string.
    decoded = processor.batch_decode(generated_ids, skip_special_tokens=True)
    preds.append(decoded[0].strip())
    refs.append(sample["text"].strip())

# ✅ Compute evaluation metrics
# References are passed first and predictions second in every call.
measures = compute_measures(refs, preds)
wer_score = wer(refs, preds)
cer_score = cer(refs, preds)

# Fraction of samples whose prediction equals its reference exactly
# (both sides were stripped of surrounding whitespace when collected).
num_exact = sum(ref == hyp for ref, hyp in zip(refs, preds))
exact_match = num_exact / len(refs)

# ✅ Display metrics
# Report the number of samples actually evaluated instead of a fixed figure,
# so the header stays correct if the evaluation subset size changes.
num_evaluated = len(refs)
print(f"\n📊 OCR Evaluation Metrics (on {num_evaluated} samples)")
print(f"🔹 Character Error Rate (CER): {cer_score:.4f}")
print(f"🔹 Word Error Rate (WER):      {wer_score:.4f}")
# exact_match is a fraction in [0, 1]; shown as a percentage.
print(f"🔹 Exact Match Accuracy:       {exact_match*100:.2f}%")
# Word-level edit counts taken from the jiwer measures dictionary.
print(f"🔹 Insertions:                 {measures['insertions']}")
print(f"🔹 Deletions:                  {measures['deletions']}")
print(f"🔹 Substitutions:              {measures['substitutions']}")


Config of the encoder: <class 'transformers.models.vit.modeling_vit.ViTModel'> is overwritten by shared encoder config: ViTConfig {
  "attention_probs_dropout_prob": 0.0,
  "encoder_stride": 16,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 768,
  "image_size": 384,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "model_type": "vit",
  "num_attention_heads": 12,
  "num_channels": 3,
  "num_hidden_layers": 12,
  "patch_size": 16,
  "pooler_act": "tanh",
  "pooler_output_size": 768,
  "qkv_bias": false,
  "torch_dtype": "float32",
  "transformers_version": "4.51.3"
}

Config of the decoder: <class 'transformers.models.trocr.modeling_trocr.TrOCRForCausalLM'> is overwritten by shared decoder config: TrOCRConfig {
  "activation_dropout": 0.0,
  "activation_function": "gelu",
  ...
  "use_learned_position_embeddings": true,
  "vocab_size": 50265
}

Output is truncated. View as a scrollable element or open in a text editor. Adjust cell output settings...