In [1]:
%pip install transformers pillow torch torchvision datasets scikit-learn matplotlib tqdm

Collecting datasets
  Downloading datasets-3.2.0-py3-none-any.whl.metadata (20 kB)
Collecting scikit-learn
  Downloading scikit_learn-1.6.0-cp311-cp311-win_amd64.whl.metadata (15 kB)
Collecting matplotlib
  Downloading matplotlib-3.10.0-cp311-cp311-win_amd64.whl.metadata (11 kB)
Collecting pyarrow>=15.0.0 (from datasets)
  Downloading pyarrow-18.1.0-cp311-cp311-win_amd64.whl.metadata (3.4 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Using cached dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting pandas (from datasets)
  Using cached pandas-2.2.3-cp311-cp311-win_amd64.whl.metadata (19 kB)
Collecting xxhash (from datasets)
  Using cached xxhash-3.5.0-cp311-cp311-win_amd64.whl.metadata (13 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Using cached multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec (from torch)
  Using cached fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Collecting aiohttp (from datasets)
  Downloading aiohttp-3.11.11-cp311-

In [10]:
%pip install transformers[torch]

Collecting accelerate>=0.26.0 (from transformers[torch])
  Downloading accelerate-1.2.1-py3-none-any.whl.metadata (19 kB)
Downloading accelerate-1.2.1-py3-none-any.whl (336 kB)
Installing collected packages: accelerate
Successfully installed accelerate-1.2.1
Note: you may need to restart the kernel to use updated packages.


In [1]:
import os
import pandas as pd
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, Seq2SeqTrainer, Seq2SeqTrainingArguments
from torch.utils.data import Dataset as TorchDataset
from torchvision.transforms import Compose, ToTensor, Normalize
import torch
import logging

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class OCRDataset(TorchDataset):
    def __init__(self, dataframe, processor, image_dir):
        self.dataframe = dataframe
        self.processor = processor
        self.image_dir = image_dir
        self.transform = Compose([ToTensor(), Normalize(mean=[0.5], std=[0.5])])

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        image_path = os.path.join(self.image_dir, self.dataframe.iloc[idx]['image_path'])
        text = self.dataframe.iloc[idx]['text']

        image = Image.open(image_path).convert("RGB")
        pixel_values = self.processor.image_processor(image, return_tensors="pt").pixel_values[0]

        labels = self.processor.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=128).input_ids[0]
        return {"pixel_values": pixel_values, "labels": labels}

In [3]:
def load_data(processor, train_csv, val_csv, image_dir):
    train_df = pd.read_csv(train_csv)
    val_df = pd.read_csv(val_csv)

    train_dataset = OCRDataset(train_df, processor, image_dir)
    val_dataset = OCRDataset(val_df, processor, image_dir)
    return train_dataset, val_dataset

In [4]:
def custom_data_collator(features):
    # Stack `pixel_values` and `labels` tensors from the batch
    pixel_values = torch.stack([f["pixel_values"] for f in features])
    labels = torch.stack([f["labels"] for f in features])

    # Return a dictionary compatible with the VisionEncoderDecoderModel
    return {"pixel_values": pixel_values, "labels": labels}

In [5]:
class CustomSeq2SeqTrainer(Seq2SeqTrainer):
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        logging.info(f"Inputs received in compute_loss: {inputs.keys()}")
        logging.info(f"Unexpected kwargs: {kwargs}")
        inputs = {k: v for k, v in inputs.items() if k in ["pixel_values", "labels"]}
        return super().compute_loss(model, inputs, return_outputs=return_outputs)

In [6]:
def fine_tune_model(processor, model, train_dataset, val_dataset, output_dir, training_args):
    trainer = CustomSeq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        tokenizer=processor.tokenizer,
        data_collator=custom_data_collator,
    )
    trainer.train()
    trainer.save_model(output_dir)
    # Save the processor to the same directory
    processor.save_pretrained(output_dir)

In [7]:
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-printed")
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id

Config of the encoder: <class 'transformers.models.vit.modeling_vit.ViTModel'> is overwritten by shared encoder config: ViTConfig {
  "attention_probs_dropout_prob": 0.0,
  "encoder_stride": 16,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 768,
  "image_size": 384,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "model_type": "vit",
  "num_attention_heads": 12,
  "num_channels": 3,
  "num_hidden_layers": 12,
  "patch_size": 16,
  "qkv_bias": false,
  "transformers_version": "4.47.1"
}

Config of the decoder: <class 'transformers.models.trocr.modeling_trocr.TrOCRForCausalLM'> is overwritten by shared decoder config: TrOCRConfig {
  "activation_dropout": 0.0,
  "activation_function": "gelu",
  "add_cross_attention": true,
  "attention_dropout": 0.0,
  "bos_token_id": 0,
  "classifier_dropout": 0.0,
  "cross_attention_hidden_size": 768,
  "d_model": 1024,
  "decoder_attention_heads": 16,
  "decoder_ffn_dim": 4096,
  "decoder

In [8]:
image_dir = "./input/TSfinetuning/"
train_csv = "./training_data/train.csv"
val_csv = "./training_data/validation.csv"
fine_tuned_model_path = "./trained_model/"
output_text_file = "./output/trained_model_results.txt"

In [9]:
train_dataset, val_dataset = load_data(processor, train_csv, val_csv, image_dir)

In [10]:
training_args = Seq2SeqTrainingArguments(
    output_dir=fine_tuned_model_path,
    evaluation_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=5,
    save_strategy="no",
    logging_dir="./logs",
    logging_strategy="steps",
    logging_steps=10,
    save_total_limit=3,
    predict_with_generate=True,
)



In [11]:
fine_tune_model(processor, model, train_dataset, val_dataset, fine_tuned_model_path, training_args)
print("Fine-tuning complete. Model saved!")

  trainer = CustomSeq2SeqTrainer(
                                             
 20%|██        | 1/5 [00:09<00:31,  7.98s/it]

{'eval_loss': 14.617009162902832, 'eval_runtime': 1.5671, 'eval_samples_per_second': 0.638, 'eval_steps_per_second': 0.638, 'epoch': 1.0}


                                             
 40%|████      | 2/5 [00:16<00:22,  7.60s/it]

{'eval_loss': 1.7227602005004883, 'eval_runtime': 1.2778, 'eval_samples_per_second': 0.783, 'eval_steps_per_second': 0.783, 'epoch': 2.0}


                                             
 60%|██████    | 3/5 [00:23<00:14,  7.16s/it]

{'eval_loss': 0.531416654586792, 'eval_runtime': 1.2726, 'eval_samples_per_second': 0.786, 'eval_steps_per_second': 0.786, 'epoch': 3.0}


                                             
 80%|████████  | 4/5 [00:29<00:06,  6.98s/it]

{'eval_loss': 0.3350190818309784, 'eval_runtime': 1.2969, 'eval_samples_per_second': 0.771, 'eval_steps_per_second': 0.771, 'epoch': 4.0}


                                             
100%|██████████| 5/5 [00:36<00:00,  7.32s/it]


{'eval_loss': 0.3156398832798004, 'eval_runtime': 1.2651, 'eval_samples_per_second': 0.79, 'eval_steps_per_second': 0.79, 'epoch': 5.0}
{'train_runtime': 36.6325, 'train_samples_per_second': 0.136, 'train_steps_per_second': 0.136, 'train_loss': 6.360072326660156, 'epoch': 5.0}
Fine-tuning complete. Model saved!


In [12]:
fine_tuned_model = VisionEncoderDecoderModel.from_pretrained(fine_tuned_model_path)
processor = TrOCRProcessor.from_pretrained(fine_tuned_model_path)

Config of the encoder: <class 'transformers.models.vit.modeling_vit.ViTModel'> is overwritten by shared encoder config: ViTConfig {
  "attention_probs_dropout_prob": 0.0,
  "encoder_stride": 16,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 768,
  "image_size": 384,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "model_type": "vit",
  "num_attention_heads": 12,
  "num_channels": 3,
  "num_hidden_layers": 12,
  "patch_size": 16,
  "qkv_bias": false,
  "transformers_version": "4.47.1"
}

Config of the decoder: <class 'transformers.models.trocr.modeling_trocr.TrOCRForCausalLM'> is overwritten by shared decoder config: TrOCRConfig {
  "activation_dropout": 0.0,
  "activation_function": "gelu",
  "add_cross_attention": true,
  "attention_dropout": 0.0,
  "bos_token_id": 0,
  "classifier_dropout": 0.0,
  "cross_attention_hidden_size": 768,
  "d_model": 1024,
  "decoder_attention_heads": 16,
  "decoder_ffn_dim": 4096,
  "decoder

In [13]:
def extract_text_with_fine_tuned_model(image_path, processor, model):
    try:
        image = Image.open(image_path).convert("RGB")
        pixel_values = processor.image_processor(image, return_tensors="pt").pixel_values
        generated_ids = model.generate(pixel_values)
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return generated_text
    except Exception as e:
        return f"Error processing {image_path}: {str(e)}"

In [14]:
def process_images_in_folder(folder_path, output_file, processor, model):
    results = []
    for filename in os.listdir(folder_path):
        if filename.lower().endswith((".png", ".jpg", ".jpeg")):
            image_path = os.path.join(folder_path, filename)
            recognized_text = extract_text_with_fine_tuned_model(image_path, processor, model)
            results.append(f"{filename}: {recognized_text}")
            print(f"Processed {filename}")

    with open(output_file, "w", encoding="utf-8") as f:
        f.write("\n".join(results))

    print(f"Results saved to {output_file}")

In [15]:
image_dir = "./input/TSfinetuning/"

In [16]:
process_images_in_folder(image_dir, output_text_file, processor, fine_tuned_model)

Processed 3142640_box_1_1_0.png
Processed 3142640_box_7_2_0.png


KeyboardInterrupt: 