In [4]:
%pip install -q datasets jiwer

Note: you may need to restart the kernel to use updated packages.


In [7]:
import os
import torch
import pandas as pd
import requests
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from torch.utils.data import Dataset, DataLoader
from datasets import load_metric
from tqdm.notebook import tqdm

# Load pre-trained model checkpoint `VisionEncoderDecoderModel`
processor = TrOCRProcessor.from_pretrained('microsoft/trocr-base-printed')
model = VisionEncoderDecoderModel.from_pretrained(
    'microsoft/trocr-base-printed')

# Test specific example
url = 'https://fki.tic.heia-fr.ch/static/img/a01-122-02-00.jpg'
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")

pixel_values = processor(images=image, return_tensors="pt").pixel_values
generated_ids = model.generate(pixel_values)
generated_text = processor.batch_decode(
    generated_ids, skip_special_tokens=True)[0]
print("Generated text:", generated_text)

# Path to the IAM image data
root_dir = 'IAM/image/'

# Path to the gt_test.txt test information file
gt_file_path = os.path.join(root_dir, '../gt_test.txt')

# Character error rate for accuracy measurement
cer = load_metric("cer")

# Read gt_test.txt and create file_names and texts lists
file_names = []
texts = []

with open(gt_file_path, 'r') as f:
    for line in f:
        parts = line.strip().split('\t')
        if len(parts) == 2:
            file_names.append(parts[0])
            texts.append(parts[1])

# Check if any files are missing texts
assert len(file_names) == len(
    texts), "File names and texts lists must be of the same length"


Some weights of VisionEncoderDecoderModel were not initialized from the model checkpoint at microsoft/trocr-base-printed and are newly initialized: ['encoder.pooler.dense.bias', 'encoder.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Generated text: INDLUS THE


Some weights of VisionEncoderDecoderModel were not initialized from the model checkpoint at microsoft/trocr-base-handwritten and are newly initialized: ['encoder.pooler.dense.bias', 'encoder.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Running evaluation...


  0%|          | 0/37 [00:00<?, ?it/s]

KeyError: 0

In [None]:
# Define the dataset class
class IAMDataset(Dataset):
    def __init__(self, root_dir, df, processor, max_target_length=128):
        self.root_dir = root_dir
        self.df = df
        self.processor = processor
        self.max_target_length = max_target_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # get file name + text
        file_name = self.df['file_name'][idx]
        text = self.df['text'][idx]
        # prepare image (i.e. resize + normalize)
        image = Image.open(os.path.join(
            self.root_dir, file_name)).convert("RGB")
        pixel_values = self.processor(image, return_tensors="pt").pixel_values
        # add labels (input_ids) by encoding the text
        labels = self.processor.tokenizer(
            text, padding="max_length", max_length=self.max_target_length).input_ids
        # important: make sure that PAD tokens are ignored by the loss function
        labels = [label if label !=
                  self.processor.tokenizer.pad_token_id else -100 for label in labels]
        encoding = {"pixel_values": pixel_values.squeeze(),
                    "labels": torch.tensor(labels)}
        return encoding
    


In [None]:
# Create DataFrame
df = pd.DataFrame({'file_name': file_names, 'text': texts})

# Split DataFrame into train, test, and validation sets
train_df = df.sample(frac=0.8, random_state=42)
test_df = df.drop(train_df.index).sample(frac=0.5, random_state=42)
valid_df = df.drop(train_df.index).drop(test_df.index)

# Create datasets
train_dataset = IAMDataset(root_dir=root_dir, df=train_df, processor=processor)
test_dataset = IAMDataset(root_dir=root_dir, df=test_df, processor=processor)
eval_dataset = IAMDataset(root_dir=root_dir, df=valid_df, processor=processor)

# DataLoaders
test_dataloader = DataLoader(test_dataset, batch_size=8)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

print("Running evaluation...")

for batch in tqdm(test_dataloader):
    # predict using generate
    pixel_values = batch["pixel_values"].to(device)
    outputs = model.generate(pixel_values)

    # decode
    pred_str = processor.batch_decode(outputs, skip_special_tokens=True)
    labels = batch["labels"]
    labels[labels == -100] = processor.tokenizer.pad_token_id
    label_str = processor.batch_decode(labels, skip_special_tokens=True)

    # add batch to metric
    cer.add_batch(predictions=pred_str, references=label_str)

final_score = cer.compute()

print("Character error rate on test set:", final_score)