In [1]:
%pip install transformers pillow torch torchvision datasets scikit-learn matplotlib tqdm

Collecting datasets
  Downloading datasets-3.2.0-py3-none-any.whl.metadata (20 kB)
Collecting scikit-learn
  Downloading scikit_learn-1.6.0-cp311-cp311-win_amd64.whl.metadata (15 kB)
Collecting matplotlib
  Downloading matplotlib-3.10.0-cp311-cp311-win_amd64.whl.metadata (11 kB)
Collecting pyarrow>=15.0.0 (from datasets)
  Downloading pyarrow-18.1.0-cp311-cp311-win_amd64.whl.metadata (3.4 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Using cached dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting pandas (from datasets)
  Using cached pandas-2.2.3-cp311-cp311-win_amd64.whl.metadata (19 kB)
Collecting xxhash (from datasets)
  Using cached xxhash-3.5.0-cp311-cp311-win_amd64.whl.metadata (13 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Using cached multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec (from torch)
  Using cached fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Collecting aiohttp (from datasets)
  Downloading aiohttp-3.11.11-cp311-

In [None]:
import os

import pandas as pd
from PIL import Image
from datasets import Dataset
from torch.utils.data import Dataset as TorchDataset
from torchvision.transforms import Compose, ToTensor, Normalize
from transformers import (
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    TrOCRProcessor,
    VisionEncoderDecoderModel,
    default_data_collator,
)

In [None]:
class OCRDataset(TorchDataset):
    """Image/text pairs prepared for fine-tuning a TrOCR-style model.

    Each row of ``dataframe`` must provide an ``image_path`` column (joined
    onto ``image_dir``) and a ``text`` column holding the target string.

    Args:
        dataframe: pandas DataFrame with ``image_path`` and ``text`` columns.
        processor: object exposing ``image_processor`` and ``tokenizer``
            (e.g. a ``TrOCRProcessor``).
        image_dir: directory the ``image_path`` values are relative to.
        max_target_length: length every label sequence is padded/truncated to.
    """

    def __init__(self, dataframe, processor, image_dir, max_target_length=128):
        self.dataframe = dataframe
        self.processor = processor
        self.image_dir = image_dir
        self.max_target_length = max_target_length
        # Not used by __getitem__ (the processor does the resizing and
        # normalisation); kept so existing code reading `.transform` still works.
        self.transform = Compose([ToTensor(), Normalize(mean=[0.5], std=[0.5])])

    def __len__(self):
        """Return the number of rows (samples) in the backing DataFrame."""
        return len(self.dataframe)

    def _encode_labels(self, text):
        """Tokenize ``text`` into a fixed-length label tensor with padding masked.

        Padding positions are set to -100, the index the cross-entropy loss
        ignores, so the model is not trained to predict pad tokens.
        """
        tokenizer = self.processor.tokenizer
        labels = tokenizer(
            text,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_target_length,
        ).input_ids[0]
        labels[labels == tokenizer.pad_token_id] = -100
        return labels

    def __getitem__(self, idx):
        """Return ``{"pixel_values": Tensor, "labels": Tensor}`` for row ``idx``."""
        row = self.dataframe.iloc[idx]  # single positional lookup per sample
        image_path = os.path.join(self.image_dir, row["image_path"])

        # The context manager closes the file handle once the pixels are loaded.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        pixel_values = self.processor.image_processor(image, return_tensors="pt").pixel_values[0]

        return {"pixel_values": pixel_values, "labels": self._encode_labels(row["text"])}

In [None]:
def load_data(processor, train_csv, val_csv, image_dir):
    """Build the training and validation datasets from two CSV files.

    Args:
        processor: passed through to each ``OCRDataset``.
        train_csv: path of the CSV describing the training samples.
        val_csv: path of the CSV describing the validation samples.
        image_dir: directory handed to each ``OCRDataset`` for resolving images.

    Returns:
        A ``(train_dataset, val_dataset)`` tuple of ``OCRDataset`` objects.
    """
    # Read both CSVs before constructing anything, so a missing file fails first.
    frames = [pd.read_csv(csv_path) for csv_path in (train_csv, val_csv)]
    splits = tuple(OCRDataset(frame, processor, image_dir) for frame in frames)
    return splits

In [None]:
def fine_tune_model(processor, model, train_dataset, val_dataset, output_dir, training_args):
    """Fine-tune ``model`` and save it together with the full processor.

    Args:
        processor: processor whose tokenizer is registered with the trainer and
            which is saved to ``output_dir`` after training.
        model: the model to train.
        train_dataset: dataset yielding ``pixel_values`` / ``labels`` dicts.
        val_dataset: dataset used for evaluation.
        output_dir: directory receiving the final model and processor files.
        training_args: ``Seq2SeqTrainingArguments`` controlling the run.
    """
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        tokenizer=processor.tokenizer,
        # Explicit collator: when a tokenizer is supplied and no collator is
        # given, the Trainer falls back to DataCollatorWithPadding, which calls
        # tokenizer.pad() and rejects features that lack `input_ids`. The
        # samples here are already fixed-size tensors, so plain stacking is
        # what is needed.
        data_collator=default_data_collator,
    )
    trainer.train()
    trainer.save_model(output_dir)
    # save_model only writes the model and the tokenizer; saving the processor
    # adds the image-processor config so TrOCRProcessor.from_pretrained(output_dir)
    # can reload it later.
    processor.save_pretrained(output_dir)

In [None]:
# Base checkpoint that both the processor and the model are initialised from.
BASE_CHECKPOINT = "microsoft/trocr-base-handwritten"

processor = TrOCRProcessor.from_pretrained(BASE_CHECKPOINT)
model = VisionEncoderDecoderModel.from_pretrained(BASE_CHECKPOINT)

In [None]:
# Folder of images: passed to load_data for the training/validation samples and
# later scanned by process_images_in_folder at inference time.
image_dir = "./input/dates/"
# CSV files read by load_data for the two splits.
train_csv = "./training_data/train.csv"
val_csv = "./training_data/validation.csv"
# Used both as the trainer's output_dir and as the location the fine-tuned
# model is reloaded from.
fine_tuned_model_path = "./trained_model/"
# Text file that receives one "<filename>: <recognized text>" line per image.
output_text_file = "./output/trained_model_results.txt"
# Build the datasets with the base-checkpoint processor loaded above.
train_dataset, val_dataset = load_data(processor, train_csv, val_csv, image_dir)

In [None]:
# Training hyper-parameters, grouped by concern and unpacked into the
# arguments object in one place.
training_config = dict(
    # Checkpointing: save every epoch into the model directory, keep at most 3.
    output_dir=fine_tuned_model_path,
    save_strategy="epoch",
    save_total_limit=3,
    # Optimisation.
    learning_rate=5e-5,
    num_train_epochs=5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    # Evaluation: run once per epoch, using generate() for predictions.
    evaluation_strategy="epoch",
    predict_with_generate=True,
    # Logging: emit a record every 10 steps.
    logging_dir="./logs",
    logging_strategy="steps",
    logging_steps=10,
)
training_args = Seq2SeqTrainingArguments(**training_config)

In [None]:
# Run training; the final model is written to fine_tuned_model_path, the same
# directory the trainer uses as its output_dir.
fine_tune_model(processor, model, train_dataset, val_dataset, fine_tuned_model_path, training_args)
print("Fine-tuning complete. Model saved!")

In [None]:
# Reload the fine-tuned weights from disk rather than reusing the in-memory `model`.
fine_tuned_model = VisionEncoderDecoderModel.from_pretrained(fine_tuned_model_path)
# Rebinds `processor`, replacing the base-checkpoint processor loaded earlier.
# NOTE(review): this needs the processor files (tokenizer and image-processor
# config) to be present in fine_tuned_model_path — confirm the training step
# saves the full processor there.
processor = TrOCRProcessor.from_pretrained(fine_tuned_model_path)

In [None]:
def extract_text_with_fine_tuned_model(image_path, processor, model):
    """Run OCR on a single image and return the decoded text.

    Args:
        image_path: path of the image file to read.
        processor: processor exposing ``image_processor`` and ``batch_decode``.
        model: model exposing ``generate`` and ``device``.

    Returns:
        The recognized text, or a string of the form
        ``"Error processing <image_path>: <message>"`` if any step fails.
        Failures are reported in-band so a batch run is not aborted by one bad file.
    """
    try:
        # The context manager closes the file handle once the pixels are loaded.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        pixel_values = processor.image_processor(image, return_tensors="pt").pixel_values
        # Inputs must be on the same device as the model weights (e.g. after .to("cuda")).
        generated_ids = model.generate(pixel_values.to(model.device))
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return generated_text
    except Exception as e:
        return f"Error processing {image_path}: {str(e)}"

In [None]:
def process_images_in_folder(folder_path, output_file, processor, model):
    """OCR every PNG/JPEG image in a folder and write the results to a text file.

    Args:
        folder_path: directory to scan (non-recursive) for ``.png``, ``.jpg``
            and ``.jpeg`` files, matched case-insensitively.
        output_file: path of the UTF-8 text file to write. Missing parent
            directories are created.
        processor: forwarded to ``extract_text_with_fine_tuned_model``.
        model: forwarded to ``extract_text_with_fine_tuned_model``.

    The output contains one ``"<filename>: <recognized text>"`` line per image,
    in sorted filename order.
    """
    # os.listdir returns entries in arbitrary order; sort for reproducible output.
    image_names = sorted(
        name
        for name in os.listdir(folder_path)
        if name.lower().endswith((".png", ".jpg", ".jpeg"))
    )

    results = []
    for filename in image_names:
        image_path = os.path.join(folder_path, filename)
        recognized_text = extract_text_with_fine_tuned_model(image_path, processor, model)
        results.append(f"{filename}: {recognized_text}")
        print(f"Processed {filename}")

    # Create the destination directory so open() does not fail on a fresh checkout.
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    with open(output_file, "w", encoding="utf-8") as f:
        f.write("\n".join(results))

    print(f"Results saved to {output_file}")

In [None]:
# Run the reloaded fine-tuned model over every image in image_dir and write the
# recognized text to output_text_file.
process_images_in_folder(image_dir, output_text_file, processor, fine_tuned_model)