# TrOCR Fine-Tuning for Prescription Words (English + Arabic)

**Dataset Structure:**
- Excel file with columns: `image_name` | `word_label`
- Folder containing all cropped word images

In [1]:
!pip install transformers torchvision pandas openpyxl pillow accelerate



In [2]:
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset
from PIL import Image
from transformers import (
    TrOCRProcessor,
    VisionEncoderDecoderModel,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer
)
import os

## 1. Load and Prepare Dataset

In [3]:
# Load your Excel file
df = pd.read_excel("/kaggle/input/excel-file/mohammed file.xlsx")  # Change to your file path
print(f"Loaded {len(df)} samples")
df.head()

Loaded 292 samples


Unnamed: 0,Cropped Image Name,word
0,Dental_prescription_605 (1)_crop_0.jpg,اللزوم
1,Dental_prescription_605 (1)_crop_1.jpg,باراسيتامول
2,Dental_prescription_605 (1)_crop_2.jpg,ميترونيدازول
3,Dental_prescription_605 (1)_crop_3.jpg,الغذاء
4,Dental_prescription_605 (1)_crop_4.jpg,عند


In [4]:
# Configuration
IMAGE_FOLDER = "/kaggle/input/cropped-images"  # Change this
TEST_SIZE = 0.1  # 10% for validation
BATCH_SIZE = 8
EPOCHS = 10
MODEL_NAME = "microsoft/trocr-base-handwritten"  # or "microsoft/trocr-base-stage1"

In [5]:
# Create PyTorch Dataset
class PrescriptionDataset(Dataset):
    def __init__(self, df, processor, image_folder):
        self.df = df
        self.processor = processor
        self.image_folder = image_folder

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # Get image name and label
        img_name = self.df.iloc[idx]['Cropped Image Name']
        text_label = str(self.df.iloc[idx]['word'])
        
        # Load image
        image_path = os.path.join(self.image_folder, img_name)
        image = Image.open(image_path).convert('RGB')

        # Process image and text
        pixel_values = self.processor(image, return_tensors="pt").pixel_values.squeeze()
        labels = self.processor.tokenizer(
            text_label,
            return_tensors="pt",
            padding="max_length",
            max_length=64,
            truncation=True
        ).input_ids.squeeze()

        return {
            "pixel_values": pixel_values,
            "labels": labels
        }

In [6]:
# Initialize processor
processor = TrOCRProcessor.from_pretrained(MODEL_NAME)

# Split data
train_df = df.sample(frac=1-TEST_SIZE, random_state=42)
test_df = df.drop(train_df.index)

# Create datasets
train_dataset = PrescriptionDataset(train_df, processor, IMAGE_FOLDER)
eval_dataset = PrescriptionDataset(test_df, processor, IMAGE_FOLDER)

print(f"Train samples: {len(train_dataset)}, Eval samples: {len(eval_dataset)}")

preprocessor_config.json:   0%|          | 0.00/224 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.12k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/772 [00:00<?, ?B/s]

Train samples: 263, Eval samples: 29


# Initialize model

In [7]:
model = VisionEncoderDecoderModel.from_pretrained(MODEL_NAME)
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id

config.json:   0%|          | 0.00/4.17k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.33G [00:00<?, ?B/s]

Config of the encoder: <class 'transformers.models.vit.modeling_vit.ViTModel'> is overwritten by shared encoder config: ViTConfig {
  "attention_probs_dropout_prob": 0.0,
  "encoder_stride": 16,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 768,
  "image_size": 384,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "model_type": "vit",
  "num_attention_heads": 12,
  "num_channels": 3,
  "num_hidden_layers": 12,
  "patch_size": 16,
  "qkv_bias": false,
  "transformers_version": "4.47.0"
}

Config of the decoder: <class 'transformers.models.trocr.modeling_trocr.TrOCRForCausalLM'> is overwritten by shared decoder config: TrOCRConfig {
  "activation_dropout": 0.0,
  "activation_function": "gelu",
  "add_cross_attention": true,
  "attention_dropout": 0.0,
  "bos_token_id": 0,
  "classifier_dropout": 0.0,
  "cross_attention_hidden_size": 768,
  "d_model": 1024,
  "decoder_attention_heads": 16,
  "decoder_ffn_dim": 4096,
  "decoder

generation_config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

# Training

In [15]:
training_args = Seq2SeqTrainingArguments(
    output_dir="./",  # Use root directory to avoid nested folders
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    evaluation_strategy="steps",  # More frequent, smaller evaluations
    eval_steps=200,  # Evaluate every 200 steps
    logging_strategy="steps",
    logging_steps=50,
    learning_rate=4e-5,
    num_train_epochs=EPOCHS,
    warmup_ratio=0.1,
    weight_decay=0.01,
    fp16=True if torch.cuda.is_available() else False,
    report_to="none",
    # Disable all saving to conserve space
    save_strategy="no",
    save_total_limit=0,
    load_best_model_at_end=False,
    # Memory/performance optimizations
    gradient_accumulation_steps=2,
    fp16_full_eval=True,
    generation_max_length=64,
    generation_num_beams=1,
    # Kaggle-specific optimizations
    dataloader_pin_memory=False,  # Reduces memory usage
    dataloader_num_workers=2,  # Optimal for Kaggle
)



In [16]:
def compute_metrics(pred):
    labels_ids = pred.label_ids
    pred_ids = pred.predictions

    # Clip predictions to valid token ID range
    pred_ids = np.clip(pred_ids, 0, len(processor.tokenizer) - 1)
    
    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    labels_ids[labels_ids == -100] = processor.tokenizer.pad_token_id
    label_str = processor.batch_decode(labels_ids, skip_special_tokens=True)

    # Calculate character error rate (CER)
    cer = 0
    for pred, label in zip(pred_str, label_str):
        # Simple CER calculation
        cer += sum(1 for a, b in zip(pred, label) if a != b) / max(len(pred), len(label))
    cer /= len(pred_str)

    return {"cer": cer}

In [17]:
# Create trainer
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,
)

# Run training
train_result = trainer.train()

Step,Training Loss,Validation Loss


# Test

In [37]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
n = 28
# 2. Load test image
for i in range(n):
    test_image_path = os.path.join(IMAGE_FOLDER, test_df.iloc[i]['Cropped Image Name'])
    test_image = Image.open(test_image_path).convert('RGB')
    # 3. Preprocess with device awareness
    with torch.no_grad():
        # Move inputs to same device as model
        pixel_values = processor(test_image, return_tensors="pt").pixel_values.to(device)
        # Generate predictions
        generated_ids = model.generate(pixel_values)
        # Decode results
        predicted_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    # 4. Display results
    print("\n=== Prediction Test ===")
    print(f"Image: {test_df.iloc[i]['Cropped Image Name']}")
    print(f"Predicted: {predicted_text}")
    print(f"Actual: {test_df.iloc[i]['word']}")
    # print(f"Match: {predicted_text == test_df.iloc[0]['word']}")


=== Prediction Test ===
Image: Dental_prescription_619 (1)_crop_11.jpg
Predicted: 8
Actual: 8

=== Prediction Test ===
Image: Dental_prescription_607 (1)_crop_0.jpg
Predicted: كل
Actual: كل

=== Prediction Test ===
Image: Dental_prescription_482_crop_4.jpg
Predicted: 6
Actual: 6

=== Prediction Test ===
Image: Dental_prescription_543_crop_1.jpg
Predicted: القطار
Actual: باراسيتامول

=== Prediction Test ===
Image: Dental_prescription_543_crop_5.jpg
Predicted: ساعات
Actual: الفطار

=== Prediction Test ===
Image: Dental_prescription_543_crop_9.jpg
Predicted: الفطار
Actual: الفطار

=== Prediction Test ===
Image: Dental_prescription_566_crop_5.jpg
Predicted: قبل
Actual: قبل

=== Prediction Test ===
Image: Dental_prescription_557_crop_10.jpg
Predicted: سيبروفلوكساسين
Actual: سيبروفلوكساسين

=== Prediction Test ===
Image: Dental_prescription_619_crop_9.jpg
Predicted: 6
Actual: 6

=== Prediction Test ===
Image: Dental_prescription_621_crop_0.jpg
Predicted: الفطار
Actual: الفطار

=== Predictio