# TrOCR Fine-Tuning for Prescription Words (English + Arabic)

In [1]:
!pip install transformers torchvision pandas openpyxl pillow accelerate jiwer



In [20]:
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset
from PIL import Image, ImageEnhance, ImageOps
from transformers import (
    TrOCRProcessor,
    VisionEncoderDecoderModel,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer
)
import os
from jiwer import wer, cer

## 1. Load and Prepare Dataset

In [3]:
# Load your Excel file
df = pd.read_excel("/kaggle/input/data-fine-tune/DataPreparation.xlsx")  # Change to your file path
print(f"Loaded {len(df)} samples")
df.head()

Loaded 1300 samples


Unnamed: 0,image,word
0,Dental_prescription_605 (1)_crop_0.jpg,اللزوم
1,Dental_prescription_605 (1)_crop_1.jpg,باراسيتامول
2,Dental_prescription_605 (1)_crop_2.jpg,ميترونيدازول
3,Dental_prescription_605 (1)_crop_3.jpg,الغذاء
4,Dental_prescription_605 (1)_crop_4.jpg,عند


In [5]:
# Configuration
IMAGE_FOLDER = "/kaggle/input/cropped-images/PrescriptionImagesData"  # Change this
TEST_SIZE = 0.2
BATCH_SIZE = 8
EPOCHS = 10
MODEL_NAME = "microsoft/trocr-base-handwritten"  # or "microsoft/trocr-base-stage1"

In [6]:
def preprocess_image(image):
    # Convert to grayscale
    image = image.convert('L')  # 'L' mode for grayscale
    # Enhance contrast
    enhancer = ImageEnhance.Contrast(image)
    image = enhancer.enhance(2.0)  # Increase contrast by factor of 2
    # Resize to a fixed height (e.g., 64px) while maintaining aspect ratio
    fixed_height = 64
    width_percent = (fixed_height / float(image.size[1]))
    new_width = int((float(image.size[0]) * float(width_percent)))
    image = image.resize((new_width, fixed_height), Image.Resampling.LANCZOS)
    # Convert back to RGB (required for TrOCR processor)
    image = image.convert('RGB')
    return image
    #########################################################################
    # gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    # img_blur = cv2.GaussianBlur(gray, (5, 5), 0)
    # img_thresh = cv2.adaptiveThreshold(img_blur, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 101, 20)
    # kernel = np.ones((2, 2), np.uint8)
    # img_morph = cv2.morphologyEx(img_thresh, cv2.MORPH_CLOSE, kernel, iterations=1)
    # clahe = cv2.createCLAHE(clipLimit=5.0, tileGridSize=(8, 8))
    # img_clahe = clahe.apply(img_morph)
    # return img_clahe

In [14]:
# Create PyTorch Dataset
class PrescriptionDataset(Dataset):
    def __init__(self, df, processor, image_folder):
        self.df = df
        self.processor = processor
        self.image_folder = image_folder

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # Get image name and label
        img_name = self.df.iloc[idx]['image']
        text_label = str(self.df.iloc[idx]['word'])
        
        # Load image
        image_path = os.path.join(self.image_folder, img_name)
        image = Image.open(image_path).convert('RGB')
        image = preprocess_image(image)  # Apply preprocessing

        # Process image and text
        pixel_values = self.processor(image, return_tensors="pt").pixel_values.squeeze()
        labels = self.processor.tokenizer(
            text_label,
            return_tensors="pt",
            padding="max_length",
            max_length=64,
            truncation=True
        ).input_ids.squeeze()

        return {
            "pixel_values": pixel_values,
            "labels": labels
        }

# Initialize model

In [33]:
# Initialize processor
processor = TrOCRProcessor.from_pretrained(MODEL_NAME)

# Split data
train_df = df.sample(frac=1-TEST_SIZE, random_state=42)
test_df = df.drop(train_df.index)

# Create datasets
train_dataset = PrescriptionDataset(train_df, processor, IMAGE_FOLDER)
eval_dataset = PrescriptionDataset(test_df, processor, IMAGE_FOLDER)

print(f"Train samples: {len(train_dataset)}, Eval samples: {len(eval_dataset)}")

Train samples: 1040, Eval samples: 260


In [16]:
model = VisionEncoderDecoderModel.from_pretrained(MODEL_NAME)
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id

Config of the encoder: <class 'transformers.models.vit.modeling_vit.ViTModel'> is overwritten by shared encoder config: ViTConfig {
  "attention_probs_dropout_prob": 0.0,
  "encoder_stride": 16,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 768,
  "image_size": 384,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "model_type": "vit",
  "num_attention_heads": 12,
  "num_channels": 3,
  "num_hidden_layers": 12,
  "patch_size": 16,
  "qkv_bias": false,
  "transformers_version": "4.47.0"
}

Config of the decoder: <class 'transformers.models.trocr.modeling_trocr.TrOCRForCausalLM'> is overwritten by shared decoder config: TrOCRConfig {
  "activation_dropout": 0.0,
  "activation_function": "gelu",
  "add_cross_attention": true,
  "attention_dropout": 0.0,
  "bos_token_id": 0,
  "classifier_dropout": 0.0,
  "cross_attention_hidden_size": 768,
  "d_model": 1024,
  "decoder_attention_heads": 16,
  "decoder_ffn_dim": 4096,
  "decoder

# Training

In [38]:
import torch, gc
def cleanup():
    gc.collect()
    torch.cuda.empty_cache()
    !rm -rf /kaggle/working/*
cleanup()

In [39]:
def compute_metrics(pred):
    labels_ids = pred.label_ids
    pred_ids = pred.predictions
    
    # Safety checks
    vocab_size = len(processor.tokenizer)
    pred_ids = np.clip(pred_ids, 0, vocab_size - 1)
    
    # Decoding
    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    labels_ids[labels_ids == -100] = processor.tokenizer.pad_token_id
    label_str = processor.batch_decode(labels_ids, skip_special_tokens=True)

    cleanup()
    # Metric calculation
    return {
        "cer": cer(label_str, pred_str),
        "wer": wer(label_str, pred_str),
    }

In [40]:
# training_args = Seq2SeqTrainingArguments(
#     output_dir="./",
#     per_device_train_batch_size=BATCH_SIZE,
#     per_device_eval_batch_size=BATCH_SIZE,
#     evaluation_strategy="epoch",
#     logging_strategy="epoch",
#     save_strategy="no",
#     learning_rate=4e-5,
#     num_train_epochs=EPOCHS,
#     warmup_ratio=0.1,
#     weight_decay=0.01,
#     fp16=True if torch.cuda.is_available() else False,
#     report_to="none",
#     save_total_limit=0,
#     load_best_model_at_end=False,
#     gradient_accumulation_steps=2,
#     dataloader_pin_memory=False,
#     dataloader_num_workers=2,
#     generation_max_length=64,
#     generation_num_beams=1,
#     predict_with_generate=True,
# )
training_args = Seq2SeqTrainingArguments(
    output_dir="./",  # Use root directory to minimize writes
    per_device_train_batch_size=4,  # Reduced from 8 to save memory
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=2,  # Compensate for smaller batch size
    
    # Evaluation and logging
    evaluation_strategy="epoch",
    logging_strategy="steps",  # More frequent logging than epoch
    logging_steps=50,          # Log every 50 steps
    
    # Disable all saving to conserve space
    save_strategy="no",        # No checkpoint saving
    save_total_limit=0,        # Keep zero checkpoints
    load_best_model_at_end=False,
    
    # Training hyperparameters
    learning_rate=4e-5,
    num_train_epochs=10,
    warmup_ratio=0.1,
    weight_decay=0.01,
    
    # Memory optimizations
    fp16=True,  # Enable mixed precision if GPU supports it
    fp16_full_eval=True,
    dataloader_pin_memory=False,  # Reduce memory pressure
    dataloader_num_workers=2,     # Optimal for Kaggle
    
    # Generation config
    predict_with_generate=True,
    generation_max_length=64,
    generation_num_beams=1,       # Faster than multi-beam
    
    # Disable unnecessary features
    report_to="none",             # No external logging
    no_cuda=False,                # Ensure GPU is used
)

# Initialize trainer with fixed compute_metrics
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,
)

# Run training
train_result = trainer.train()



Epoch,Training Loss,Validation Loss,Cer,Wer
1,0.0212,0.062586,0.28637,0.354478
2,0.0659,0.0619,0.250983,0.33209
3,0.0369,0.059539,0.20118,0.279851
4,0.0263,0.047922,0.163172,0.246269
5,0.0165,0.047791,0.171691,0.268657
6,0.011,0.046907,0.12844,0.190299
7,0.008,0.041189,0.125164,0.182836
8,0.0017,0.037429,0.114024,0.186567
9,0.0014,0.036898,0.104849,0.175373
10,0.0004,0.036984,0.110747,0.175373


Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Tr

In [41]:
trainer.save_model("final_model")
!tar -czvf model.tar.gz final_model  # Compress
!rm -rf final_model  # Remove uncompressed version

final_model/
final_model/generation_config.json
final_model/training_args.bin
final_model/model.safetensors
final_model/config.json


# Test

In [47]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Initialize counters
total_samples = len(test_df)
correct_exact = 0  # Exact match (whole word)
correct_chars = 0   # Character-level accuracy
total_chars = 0

# Test loop
for i in range(total_samples):
    # Load and preprocess image (same as during training)
    test_image_path = os.path.join(IMAGE_FOLDER, test_df.iloc[i]['image'])
    test_image = Image.open(test_image_path).convert('RGB')
    test_image = preprocess_image(test_image)  # Apply preprocessing
    
    # Generate prediction
    with torch.no_grad():
        pixel_values = processor(test_image, return_tensors="pt").pixel_values.to(device)
        generated_ids = model.generate(pixel_values)
        predicted_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    
    # Get ground truth
    true_text = str(test_df.iloc[i]['word'])
    
    # Calculate exact match
    if predicted_text.strip() == true_text.strip():
        correct_exact += 1
    
    # Calculate character-level accuracy
    min_len = min(len(predicted_text), len(true_text))
    matched_chars = sum(1 for a, b in zip(predicted_text, true_text) if a == b)
    correct_chars += matched_chars
    total_chars += max(len(predicted_text), len(true_text))

# Compute metrics
exact_accuracy = (correct_exact / total_samples) * 100
char_accuracy = (correct_chars / total_chars) * 100

# Print summary
print("\n=== Test Results ===")
print(f"Exact Match Accuracy: {exact_accuracy:.2f}% ({correct_exact}/{total_samples})")
print(f"Character-Level Accuracy: {char_accuracy:.2f}%")
print("\nSample Predictions (First 5):")
for i in range(min(5, total_samples)):
    test_image_path = os.path.join(IMAGE_FOLDER, test_df.iloc[i]['image'])
    true_text = str(test_df.iloc[i]['word'])
    with torch.no_grad():
        pixel_values = processor(Image.open(test_image_path).convert('RGB'), return_tensors="pt").pixel_values.to(device)
        predicted_text = processor.batch_decode(model.generate(pixel_values), skip_special_tokens=True)[0]
    print(f"Image: {test_df.iloc[i]['image']}")
    print(f"Predicted: '{predicted_text}' | Actual: '{true_text}'")
    print("-----")


=== Test Results ===
Exact Match Accuracy: 83.46% (217/260)
Character-Level Accuracy: 83.91%

Sample Predictions (First 5):
Image: Dental_prescription_605 (1)_crop_1.jpg
Predicted: 'باراسيتامول' | Actual: 'باراسيتامول'
-----
Image: Dental_prescription_605 (1)_crop_8.jpg
Predicted: 'بعد' | Actual: 'بعد'
-----
Image: Dental_prescription_619 (1)_crop_4.jpg
Predicted: 'ساعات' | Actual: 'ساعات'
-----
Image: Dental_prescription_619 (1)_crop_5.jpg
Predicted: 'قبل' | Actual: 'قبل'
-----
Image: Dental_prescription_619 (1)_crop_7.jpg
Predicted: 'كل' | Actual: 'كل'
-----
