In [1]:
import os
import json
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, AdamW
from tqdm import tqdm
import editdistance


In [2]:
TRAIN_DIR = r"D:\Projects\Final_project\Dataset\train_data\train"
TEST_DIR = r"D:\Projects\Final_project\Dataset\test_data\test"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


Using device: cpu


In [3]:
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")

model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.eos_token_id = processor.tokenizer.eos_token_id
model.config.vocab_size = model.config.decoder.vocab_size

model.to(device)


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.48, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
Config of the encoder: <class 'transformers.models.vit.modeling_vit.ViTModel'> is overwritten by shared encoder config: ViTConfig {
  "attention_probs_dropout_prob": 0.0,
  "encoder_stride": 16,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 768,
  "image_size": 384,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "model_type": "vit",
  "num_attention_heads": 12,
  "num_channels": 3,
  "num_hidden_layers": 12,
  "patch_size": 16,
  "qkv_bias": false,
  "torch_dtype": "float32",
  "transformers_version": "4.50.0.dev0"
}

Config of the decoder: <class 'transformers.models.trocr.modeling_trocr.TrOCRForCausalL

VisionEncoderDecoderModel(
  (encoder): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTSdpaAttention(
            (attention): ViTSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=False)
              (key): Linear(in_features=768, out_features=768, bias=False)
              (value): Linear(in_features=768, out_features=768, bias=False)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linea

In [4]:
def load_image_text_pairs(folder):
    data = []
    for file in os.listdir(folder):
        if file.endswith(".json"):
            json_path = os.path.join(folder, file)
            image_path = json_path.replace(".json", ".jpg")

            if os.path.exists(image_path):
                with open(json_path, 'r', encoding='utf-8') as f:
                    label_data = json.load(f)

                    if isinstance(label_data, list):
                        full_text = " ".join([item.get("text", "") for item in label_data])
                    elif isinstance(label_data, dict):
                        full_text = label_data.get("text", "")
                    else:
                        full_text = ""

                    data.append((image_path, full_text.strip()))
    return data


In [5]:
train_data = load_image_text_pairs(TRAIN_DIR)[:50]  # Use first 50 for speed
test_data = load_image_text_pairs(TEST_DIR)[:10]


In [6]:
class GNHKDataset(Dataset):
    def __init__(self, data, processor):
        self.data = data
        self.processor = processor

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        image_path, text = self.data[idx]
        image = Image.open(image_path).convert("RGB")
        encoding = self.processor(images=image, text=text, return_tensors="pt", padding="max_length", truncation=True)
        return {
            "pixel_values": encoding["pixel_values"].squeeze(),
            "labels": encoding["labels"].squeeze()
        }


In [7]:
train_dataset = GNHKDataset(train_data, processor)
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)

optimizer = AdamW(model.parameters(), lr=5e-5)

model.train()
for epoch in range(1):  # Just 1 epoch for testing
    total_loss = 0
    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
        pixel_values = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)

        outputs = model(pixel_values=pixel_values, labels=labels)
        loss = outputs.loss
        total_loss += loss.item()

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    print(f"Epoch {epoch+1} avg loss: {total_loss / len(train_loader):.4f}")


Epoch 1: 100%|█████████████████████████████████████████████████████████████████████████| 25/25 [30:49<00:00, 73.98s/it]

Epoch 1 avg loss: 2.7697





In [9]:
# Inference with correct batch shape
model.eval()
predictions, references = [], []

with torch.no_grad():
    for batch in test_loader:
        pixel_values = batch["pixel_values"].to(device)  # No need to unsqueeze here
        labels = batch["labels"].to(device)

        # Generate the predictions from the model
        generated_ids = model.generate(pixel_values)

        # Decode the predictions and labels
        pred_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        label_text = processor.batch_decode(labels, skip_special_tokens=True)[0]

        predictions.append(pred_text)
        references.append(label_text)

# Calculate Character Error Rate (CER)
cer_score = calculate_cer(predictions, references)
print(f"Character Error Rate (CER): {cer_score:.4f}")


Character Error Rate (CER): 1.0000
