In [21]:
import os 
import torch
import pandas as pd 
from torch.utils.data import Dataset
from PIL import Image, ImageDraw, ImageFont
from transformers import (
    TrOCRProcessor,
    VisionEncoderDecoderModel,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    AutoTokenizer,
    default_data_collator,
)

In [22]:
def create_dummy_dataset(output_dir="data/dummy"):
    os.makedirs(output_dir, exist_ok=True)
    data = [
        ("img_1.jpg", "ក"), 
        ("img_2.jpg", "ខ"), 
        ("img_3.jpg", "កម្ពុជា"), 
        ("img_4.jpg", "សួស្តី"), 
        ("img_5.jpg", "ខ្ញុំរៀន")
    ]
    labels = []
    try:
        font = ImageFont.truetype("../assets/fonts/NotoSansKhmer-Regular.ttf", 60)
    except:
        print("WARNING: 'NotoSansKhmer-Regular.ttf' not found. Text will be squares.")
        font = ImageFont.load_default()
    for file_name, text in data:
        image = Image.new('RGB', (384, 384), color = (255, 255, 255))
        draw = ImageDraw.Draw(image)
        draw.text((50, 150), text, fill=(0, 0, 0), font=font)
        save_path = os.path.join(output_dir, file_name)
        image.save(save_path)
        labels.append({"file_name": file_name, "text": text})   
    df = pd.DataFrame(labels)
    df.to_csv(os.path.join(output_dir, "labels.csv"), index=False)
    print(f"Dummy dataset created at {output_dir}")
    return output_dir, df


In [23]:
class KhmerOCRDataset(Dataset):
    def __init__(self, root_dir,  df, processor, max_target_length=128):
        self.root_dir = root_dir
        self.df = df 
        self.processor = processor
        self.max_target_length = max_target_length
    
    def __len__(self):
        return len(self.df)
    def __getitem__(self, idx):
        file_name = self.df.iloc[idx]["file_name"]
        text = self.df.iloc[idx]["text"]
        image_path = os.path.join(self.root_dir, file_name)
        
        image = Image.open(image_path).convert("RGB")
        pixel_values = self.processor(image, return_tensors="pt").pixel_values

        labels = self.processor.tokenizer(
            text,
            padding="max_length",
            max_length=self.max_target_length,
        ).input_ids
        labels = [label if label != self.processor.tokenizer.pad_token_id else -100 for label in labels]

        return {
            "pixel_values": pixel_values.squeeze(),
            "labels": torch.tensor(labels),
        }

In [24]:
def main():
    data_dir, df = create_dummy_dataset()

    print("Loading processor...")
    feature_extractor_name = "microsoft/trocr-small-handwritten"
    tokenizer_name = "xlm-roberta-base"

    processor = TrOCRProcessor.from_pretrained(feature_extractor_name)
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    processor.tokenizer = tokenizer

    dataset = KhmerOCRDataset(root_dir = data_dir, df=df, processor=processor)

    print("Loading Model....")
    model = VisionEncoderDecoderModel.from_pretrained(feature_extractor_name)

    model.decoder.resize_token_embeddings(len(processor.tokenizer))
    model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
    model.config.pad_token_id = processor.tokenizer.pad_token_id
    model.config.vocab_size = len(processor.tokenizer)

    model.config.eos_token_id = processor.tokenizer.sep_token_id
    model.config.max_length = 64
    model.config.early_stopping = True
    model.config.no_repeat_ngram_size = 3
    model.config.length_penalty = 2.0
    model.config.num_beams = 4

    training_args = Seq2SeqTrainingArguments(
        output_dir = "./outputs",
        per_device_train_batch_size = 2,
        num_train_epochs = 5,
        predict_with_generate = True,
        logging_steps = 2,
        save_steps = 100,
        eval_strategy = "no",
        fp16= torch.cuda.is_available(),
        remove_unused_columns = False,
    )

    trainer = Seq2SeqTrainer(
        model = model,
        tokenizer = processor.feature_extractor,
        args = training_args,
        train_dataset = dataset,
        data_collator = default_data_collator,
    )
    
    print("Starting Training....")
    trainer.train()

    print("Training Finished! Saving Model...")
    model.save_pretrained("./khmer_trocr_model")
    processor.save_pretrained("./khmer_trocr_model")

    print("\n--- Running Inference Test ---")
    image = Image.open(os.path.join(data_dir, "img_3.jpg")).convert("RGB")
    pixel_values = processor(image, return_tensors = "pt").pixel_values.to(model.device)
    
    generated_ids = model.generate(pixel_values)
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    print(f"Original Text: កម្ពុជា")
    print(f"Predicted Text: {generated_text}")

if __name__ == "__main__":
    main()

Dummy dataset created at data/dummy
Loading processor...
Loading Model....


Some weights of VisionEncoderDecoderModel were not initialized from the model checkpoint at microsoft/trocr-small-handwritten and are newly initialized: ['encoder.pooler.dense.bias', 'encoder.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer = Seq2SeqTrainer(


Starting Training....


`loss_type=None` was set in the config but it is unrecognized. Using the default loss: `ForCausalLMLoss`.


Step,Training Loss
2,7.8242
4,7.7192
6,4.8514
8,5.9519
10,5.2501
12,5.2348
14,5.4393




Training Finished! Saving Model...

--- Running Inference Test ---
Original Text: កម្ពុជា
Predicted Text: [c
