In [None]:
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

# TrOCR is a decoder model, so its weights are loaded through the
# VisionEncoderDecoderModel wrapper
model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-base-printed')

# Processor published with the same pre-trained checkpoint
processor = TrOCRProcessor.from_pretrained(
    'microsoft/trocr-base-printed'
)

In [None]:
import io
import json
import os
import time
import urllib.request

import cv2
import numpy as np
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset


# NOTE(review): within the visible file `pixel_values` is first assigned near
# the end, after this call. Run top to bottom, the next line raises NameError
# unless the name is defined elsewhere. Confirm the intended execution order
# (the image preprocessing step has to run before generation).
# Generate output token ids from the preprocessed image tensor.
generated_ids = model.generate(pixel_values)
# Decode the ids back to text without special tokens; keep the first sequence.
generated_text = processor.batch_decode(
    generated_ids, skip_special_tokens=True)[0]

# Directory that is scanned for .jpg images and their .txt transcriptions.
root_dir = 'IAM/image'

class HWDataset(Dataset):
    """Map-style dataset pairing text images with tokenized transcriptions.

    Every item is a dict with two tensors:

    * ``pixel_values`` - the image after ``processor`` preprocessing, with the
      leading batch axis removed.
    * ``labels`` - token ids of the transcription, padded/truncated to
      ``max_target_length``; padding positions are replaced by -100 so the
      loss function ignores them.

    Args:
        root_dir: Directory that contains the image files.
        df: DataFrame with a ``file_name`` and a ``text`` column, one row per
            sample. Its index may be arbitrary (e.g. the result of
            ``df.sample``); rows are addressed by position.
        processor: Callable that preprocesses an image (``return_tensors="pt"``)
            and exposes a ``tokenizer`` attribute.
        max_target_length: Fixed length of the returned label sequence.
    """

    def __init__(self, root_dir, df, processor, max_target_length=128):
        self.root_dir = root_dir
        self.df = df
        self.processor = processor
        self.max_target_length = max_target_length

    def __len__(self):
        """Return the number of samples (rows of the DataFrame)."""
        return len(self.df)

    def _sample_source(self, idx):
        """Return ``(image_path, text)`` for the row at position ``idx``."""
        # Positional lookup: frames produced by sample()/drop() keep their
        # original row labels, so a label lookup with 0..len-1 would raise
        # KeyError or silently return a different row.
        row = self.df.iloc[idx]
        # os.path.join inserts the separator that plain string concatenation
        # loses when root_dir has no trailing slash.
        return os.path.join(self.root_dir, row['file_name']), row['text']

    def __getitem__(self, idx):
        """Load, preprocess and encode the sample at position ``idx``.

        Returns:
            dict with ``pixel_values`` and ``labels`` tensors.
        """
        image_path, text = self._sample_source(idx)

        # prepare image (i.e. resize + normalize)
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        pixel_values = self.processor(image, return_tensors="pt").pixel_values

        # add labels (input_ids) by encoding the text; truncation keeps every
        # label sequence at exactly max_target_length so samples can be batched
        labels = self.processor.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_target_length,
        ).input_ids

        # important: make sure that PAD tokens are ignored by the loss function
        pad_id = self.processor.tokenizer.pad_token_id
        labels = [-100 if token == pad_id else token for token in labels]

        return {
            # drop only the batch axis added by return_tensors="pt"
            "pixel_values": pixel_values.squeeze(0),
            "labels": torch.tensor(labels),
        }

    def __iter__(self):
        """Yield every encoded sample in positional order."""
        for idx in range(len(self)):
            yield self[idx]

# Test specific example: download one sample image.
# The standard-library opener is used because `requests` is not imported here.
url = 'https://fki.tic.heia-fr.ch/static/img/a01-122-02-00.jpg'
with urllib.request.urlopen(url, timeout=30) as response:
    # Buffer the payload in memory so PIL receives a seekable file object.
    image = Image.open(io.BytesIO(response.read())).convert("RGB")

# Extract all .jpg files and create a DataFrame
file_names = []
texts = []  # transcription of each image, in the same order as file_names
# Sort the listing: os.listdir returns entries in arbitrary order, and the
# seeded split further down is only reproducible for a stable row order.
for file in sorted(os.listdir(root_dir)):
    if not file.endswith(".jpg"):
        continue
    # Assume text file has the same name but with .txt extension
    text_file = os.path.splitext(file)[0] + ".txt"
    # Explicit encoding so the result does not depend on the platform locale
    # (assumes the transcriptions are UTF-8/ASCII - TODO confirm).
    with open(os.path.join(root_dir, text_file), 'r', encoding='utf-8') as f:
        text = f.read().strip()
    # Append to both lists together, only after the transcription was read,
    # so the two columns can never get out of step.
    file_names.append(file)
    texts.append(text)

# Create DataFrame
df = pd.DataFrame({'file_name': file_names, 'text': texts})

# Partition the rows: 80% train, and the remaining 20% halved into test and
# validation (fixed seed for both draws)
train_df = df.sample(frac=0.8, random_state=42)
# Rows not chosen for training; narrowed to the validation half just below
valid_df = df.drop(index=train_df.index)
test_df = valid_df.sample(frac=0.5, random_state=42)
valid_df = valid_df.drop(index=test_df.index)

# Create datasets
# The splits keep the shuffled row labels of `df`. Renumber each one 0..len-1
# (on a copy) so that integer positions and row labels coincide when the
# dataset is indexed with idx in range(len(dataset)).
train_dataset = HWDataset(
    root_dir=root_dir, df=train_df.reset_index(drop=True), processor=processor)
test_dataset = HWDataset(
    root_dir=root_dir, df=test_df.reset_index(drop=True), processor=processor)
eval_dataset = HWDataset(
    root_dir=root_dir, df=valid_df.reset_index(drop=True), processor=processor)

# Example of processing all images into pixels for the model
def convert_to_pixels():
    """Walk the training dataset once, unpacking every encoded sample.

    The unpacked tensors are only bound to local names; nothing is collected
    or returned, so the function returns None.
    """
    for encoded in train_dataset:
        pixel_values, labels = encoded['pixel_values'], encoded['labels']


# Preprocess the example image into the tensor input expected by the model
pixel_values = processor(images=image, return_tensors="pt")["pixel_values"]
    

# Optionally, use the datasets for training and evaluation