# Setup

In [None]:
# Install required packages
!pip install transformers==4.36.0
!pip install datasets
!pip install pillow
!pip install opencv-python
!pip install torch torchvision
!pip install evaluate
!pip install jiwer
!pip install accelerate -U




# Import Libraries

In [None]:
import os
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import (
    TrOCRProcessor,
    VisionEncoderDecoderModel,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    default_data_collator
)
from datasets import load_metric
from PIL import Image
import pandas as pd
from sklearn.model_selection import KFold
import json
from tqdm import tqdm

# Report the runtime environment: which PyTorch build is installed and
# whether a CUDA device is visible for training.
print("PyTorch version: {}".format(torch.__version__))
print("CUDA available: {}".format(torch.cuda.is_available()))


# Image Processing Functions

In [None]:
class PrescriptionImagePreprocessor:
    """
    Preprocessing pipeline for prescription images following the methodology
    from Ponnuru et al. (2024) as described in the thesis.

    Args:
        target_size: ``dsize`` passed to ``cv2.resize``, i.e. (width, height).
        block_size: Neighbourhood size for adaptive thresholding. OpenCV
            requires an odd integer greater than 1 (thesis value: 61).
        threshold_c: Constant subtracted from the Gaussian-weighted local
            mean during adaptive thresholding (thesis value: 11).
    """

    def __init__(self, target_size=(384, 384), block_size=61, threshold_c=11):
        # Validate up front instead of failing inside a DataLoader worker.
        if block_size <= 1 or block_size % 2 == 0:
            raise ValueError(
                f"block_size must be an odd integer greater than 1, got {block_size}"
            )
        self.target_size = target_size
        self.block_size = block_size
        self.threshold_c = threshold_c

    def preprocess_image(self, image_path):
        """
        Load an image from disk and apply the preprocessing steps:
        1. Resize to ``target_size`` with INTER_LINEAR interpolation
        2. Convert to grayscale
        3. Adaptive thresholding (ADAPTIVE_THRESH_GAUSSIAN_C, THRESH_BINARY)
           with ``block_size`` and ``threshold_c`` (defaults 61 and 11)

        Args:
            image_path: Path to the image file (str or path-like).

        Returns:
            A 3-channel RGB ``PIL.Image``. The binarised grayscale result is
            replicated across channels because the TrOCR image processor
            expects RGB input.

        Raises:
            ValueError: If OpenCV cannot read the file (cv2.imread returns
                None rather than raising).
        """
        # str() so path-like objects are accepted as well as plain strings.
        img = cv2.imread(str(image_path))
        if img is None:
            raise ValueError(f"Cannot read image: {image_path}")

        # Step 1: Resize with INTER_LINEAR interpolation
        img_resized = cv2.resize(img, self.target_size, interpolation=cv2.INTER_LINEAR)

        # Step 2: Convert to grayscale (cv2.imread loads BGR)
        img_gray = cv2.cvtColor(img_resized, cv2.COLOR_BGR2GRAY)

        # Step 3: Adaptive thresholding
        img_thresh = cv2.adaptiveThreshold(
            img_gray,
            255,
            cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
            cv2.THRESH_BINARY,
            blockSize=self.block_size,
            C=self.threshold_c,
        )

        # A single-channel ("L") image is not accepted by the TrOCR processor;
        # hand back RGB so every caller gets processor-ready input.
        return Image.fromarray(img_thresh).convert("RGB")

    def augment_image(self, image):
        """
        Apply random data augmentation (after Ali et al., 2024).

        Each transform is applied independently with probability 0.5:
        - Brightness scaling by a factor drawn from [0.8, 1.2)
        - Contrast scaling by alpha drawn from [0.9, 1.1)
        - Translation by up to +/-20 pixels on each axis (edge pixels
          replicated into the uncovered area)
        - Additive Gaussian noise (mean 0, std 5)

        Minor shearing, elastic transformation and cropping with padding,
        also listed in the methodology, are NOT implemented here.

        Uses NumPy's global RNG, so results depend on its current state.

        Args:
            image: A ``PIL.Image`` (e.g. the output of ``preprocess_image``).

        Returns:
            A new ``PIL.Image`` with the same shape/channel count as the input.
        """
        img_array = np.array(image)

        # Random brightness adjustment
        if np.random.random() > 0.5:
            brightness_factor = np.random.uniform(0.8, 1.2)
            img_array = np.clip(img_array * brightness_factor, 0, 255).astype(np.uint8)

        # Contrast scaling
        if np.random.random() > 0.5:
            alpha = np.random.uniform(0.9, 1.1)
            img_array = cv2.convertScaleAbs(img_array, alpha=alpha, beta=0)

        # Translation (slight shift). randint's upper bound is exclusive, so 21
        # yields the symmetric range [-20, 20].
        if np.random.random() > 0.5:
            tx, ty = np.random.randint(-20, 21, 2)
            M = np.float32([[1, 0, tx], [0, 1, ty]])
            height, width = img_array.shape[:2]
            # warpAffine's default constant border is 0 (black), which would
            # paint black bars onto the white page background of the
            # binarised image; replicate the edge pixels instead.
            img_array = cv2.warpAffine(
                img_array, M, (width, height), borderMode=cv2.BORDER_REPLICATE
            )

        # Gaussian noise: one value per pixel, shared by all channels, so a
        # gray image stored as RGB stays gray instead of picking up a tint.
        if np.random.random() > 0.5:
            noise = np.random.normal(0, 5, img_array.shape[:2])
            if img_array.ndim == 3:
                noise = noise[..., np.newaxis]
            img_array = np.clip(img_array + noise, 0, 255).astype(np.uint8)

        return Image.fromarray(img_array)

# Default preprocessor (thesis parameters) for use in the cells below.
preprocessor = PrescriptionImagePreprocessor()



# Dataset


In [None]:
class PrescriptionDataset(Dataset):
    """
    Custom PyTorch Dataset for prescription images and transcriptions.

    Expected data format:
    - images/: folder containing prescription images
    - annotations.csv: CSV with columns ['image_id', 'transcription'], where
      'image_id' is the image's file name relative to ``image_dir``

    Each item is a dict with:
    - "pixel_values": image tensor produced by the TrOCR processor
    - "labels": int64 tensor of ``max_target_length`` token ids, with padding
      positions replaced by -100 so the loss ignores them

    Args:
        data_df: DataFrame with 'image_id' and 'transcription' columns.
        image_dir: Directory the 'image_id' file names are resolved against.
        processor: TrOCR processor (image processor + tokenizer).
        preprocessor: Object exposing ``preprocess_image(path)`` and
            ``augment_image(image)``.
        augment: Apply ``preprocessor.augment_image`` (use for training only).
        max_target_length: Token length labels are padded/truncated to.
    """

    def __init__(self, data_df, image_dir, processor, preprocessor, augment=False,
                 max_target_length=128):
        self.data_df = data_df
        self.image_dir = image_dir
        self.processor = processor
        self.preprocessor = preprocessor
        self.augment = augment
        self.max_target_length = max_target_length

    def __len__(self):
        return len(self.data_df)

    def __getitem__(self, idx):
        row = self.data_df.iloc[idx]
        # str() so numeric ids parsed by pandas still form a valid path.
        image_path = os.path.join(self.image_dir, str(row['image_id']))

        transcription = row['transcription']
        if pd.isna(transcription):
            raise ValueError(f"Missing transcription for image: {row['image_id']}")
        # Purely numeric transcriptions (e.g. a dosage) are parsed by pandas
        # as numbers, which the tokenizer does not accept as text input.
        transcription = str(transcription)

        # Preprocess image, then augment during training only
        image = self.preprocessor.preprocess_image(image_path)
        if self.augment:
            image = self.preprocessor.augment_image(image)

        # The TrOCR image processor expects 3-channel input, while the
        # thresholded image may be single-channel ("L"); convert defensively.
        # squeeze(0) drops only the batch dimension added by return_tensors="pt".
        pixel_values = self.processor(
            image.convert("RGB"), return_tensors="pt"
        ).pixel_values.squeeze(0)

        labels = self.processor.tokenizer(
            transcription,
            padding="max_length",
            max_length=self.max_target_length,
            truncation=True
        ).input_ids

        # Replace padding token id with -100 so it is ignored by the loss
        pad_token_id = self.processor.tokenizer.pad_token_id
        labels = [token if token != pad_token_id else -100 for token in labels]

        return {
            "pixel_values": pixel_values,
            "labels": torch.tensor(labels)
        }

print("Dataset class created")


# Data Preparation and 5-Fold Cross-Validation Setup

In [None]:
def prepare_data_splits(annotations_csv, n_folds=5, test_size=0.2, random_seed=42,
                        test_csv_path="test_set.csv"):
    """
    Implements the two-phase data splitting strategy:
    1. Hold out ``test_size`` (default 20%) as the final test set
    2. Use the remainder for ``n_folds``-fold (default 5) cross-validation

    This follows the methodology in Section 3.5 of the thesis.

    Args:
        annotations_csv: Path to the CSV with 'image_id' and 'transcription'.
        n_folds: Number of cross-validation folds.
        test_size: Fraction (float) or absolute count (int) of samples held
            out, as accepted by ``sklearn.model_selection.train_test_split``.
        random_seed: Seed for both the hold-out split and the fold shuffle.
        test_csv_path: Where to write the held-out test set; pass None (or an
            empty string) to skip writing it.

    Returns:
        (train_val_df, test_df, kfold). ``train_val_df`` keeps its original
        index; ``kfold.split`` yields positional indices (use ``iloc``).

    Raises:
        ValueError: If fewer train/val samples remain than ``n_folds``.
    """
    from sklearn.model_selection import train_test_split

    df = pd.read_csv(annotations_csv)

    # First split: train/val vs. held-out test
    train_val_df, test_df = train_test_split(
        df,
        test_size=test_size,
        random_state=random_seed
    )

    # KFold would only fail later, inside the training loop; check now.
    if len(train_val_df) < n_folds:
        raise ValueError(
            f"Need at least {n_folds} train/val samples for {n_folds}-fold CV, "
            f"got {len(train_val_df)}"
        )

    # Setup k-fold cross-validation on the train_val set
    kfold = KFold(n_splits=n_folds, shuffle=True, random_state=random_seed)

    if test_csv_path:
        test_df.to_csv(test_csv_path, index=False)

    # Report the actual proportions rather than assuming an 80/20 split.
    total = len(df)
    print("📊 Data Split Summary:")
    print(f"   Total samples: {total}")
    print(f"   Train/Val samples ({len(train_val_df) / total:.0%}): {len(train_val_df)}")
    print(f"   Test samples ({len(test_df) / total:.0%}): {len(test_df)}")
    print(f"   Cross-validation folds: {n_folds}")

    return train_val_df, test_df, kfold

# Example usage (you'll need to provide your actual CSV file)
# train_val_df, test_df, kfold = prepare_data_splits('annotations.csv')
print("Data splitting function ready")


# Model Initialization

In [None]:
def initialize_trocr_model(model_name="microsoft/trocr-base-handwritten"):
    """
    Initialize a TrOCR model and its processor.

    Default model: microsoft/trocr-base-handwritten (pre-trained), following
    Table 3.1 in the thesis methodology.

    Args:
        model_name: Hugging Face Hub id or local directory of a TrOCR
            checkpoint (e.g. the ``final_model`` directory written by
            ``train_final_model``, which contains both model and processor).

    Returns:
        (model, processor) tuple.
    """
    # Processor bundles the image processor and the tokenizer
    processor = TrOCRProcessor.from_pretrained(model_name)
    model = VisionEncoderDecoderModel.from_pretrained(model_name)

    # Special tokens the encoder-decoder needs for training and generation
    model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
    model.config.pad_token_id = processor.tokenizer.pad_token_id
    model.config.vocab_size = model.config.decoder.vocab_size

    # Gradient checkpointing saves memory; the KV cache is disabled alongside it
    model.config.use_cache = False
    model.gradient_checkpointing_enable()

    print(f"✅ Model loaded: {model_name}")
    print(f"   Encoder: {model.config.encoder.model_type}")
    print(f"   Decoder: {model.config.decoder.model_type}")
    print(f"   Vocab size: {model.config.vocab_size}")

    return model, processor

# Initialize model and processor (downloads weights on first run). This global
# `processor` is also the one compute_metrics uses for decoding.
model, processor = initialize_trocr_model()


# Evaluation Metrics (CER & WER)

In [None]:
from jiwer import wer, cer

def compute_metrics(pred):
    """
    Compute Character Error Rate (CER) and Word Error Rate (WER).

    CER: character-level edit (Levenshtein) distance relative to the reference
    WER: the same measure at word level

    As described in Section 3.6.1 of the thesis.

    Args:
        pred: Evaluation output from the Trainer, with ``predictions``
            (generated token ids, since ``predict_with_generate=True``) and
            ``label_ids``.

    Returns:
        dict with "cer" and "wer" (lower is better).

    Note:
        Decoding uses the module-level ``processor``.
    """
    pad_token_id = processor.tokenizer.pad_token_id

    pred_ids = pred.predictions
    # Guard against predictions delivered as a tuple (ids first).
    if isinstance(pred_ids, tuple):
        pred_ids = pred_ids[0]

    # -100 marks positions ignored by the loss (label padding), and the
    # Trainer may also use it to pad predictions when concatenating batches.
    # It is not a valid token id, so map it to the pad token before decoding.
    # np.where builds new arrays instead of mutating the Trainer's own arrays.
    label_ids = np.where(pred.label_ids != -100, pred.label_ids, pad_token_id)
    pred_ids = np.where(pred_ids != -100, pred_ids, pad_token_id)

    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.batch_decode(label_ids, skip_special_tokens=True)

    # jiwer's argument order is (reference, hypothesis)
    cer_score = cer(label_str, pred_str)
    wer_score = wer(label_str, pred_str)

    return {
        "cer": cer_score,
        "wer": wer_score
    }

print("Evaluation metrics (CER & WER) configured")


# Training Configuration

In [None]:
def get_training_args(output_dir, num_train_epochs=15, generation_max_length=128):
    """
    Training hyperparameters from Table 3.1 of the thesis:
    - Optimizer: AdamW
    - Learning rate: 2e-5
    - Batch size: 16
    - Weight decay: 0.01
    - Epochs: 10-15 with early stopping (the stopping itself requires an
      ``EarlyStoppingCallback`` on the Trainer; these arguments only track
      and reload the best checkpoint by validation CER)
    - Scheduler: Linear with warm-up
    - Gradient clipping: 1.0

    Args:
        output_dir: Directory for checkpoints; logs go to ``{output_dir}/logs``.
        num_train_epochs: Maximum number of training epochs.
        generation_max_length: Maximum generated length during evaluation.
            Defaults to 128, matching the label truncation length in
            ``PrescriptionDataset`` and the inference ``max_length``.

    Returns:
        A configured ``Seq2SeqTrainingArguments``.
    """
    training_args = Seq2SeqTrainingArguments(
        output_dir=output_dir,

        # Training hyperparameters from Table 3.1
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        learning_rate=2e-5,
        num_train_epochs=num_train_epochs,
        weight_decay=0.01,
        max_grad_norm=1.0,  # Gradient clipping

        # AdamW optimizer
        optim="adamw_torch",

        # Learning rate scheduler: Linear with warm-up
        lr_scheduler_type="linear",
        warmup_steps=500,

        # Best-checkpoint tracking (used by early stopping)
        load_best_model_at_end=True,
        metric_for_best_model="cer",
        greater_is_better=False,  # Lower CER is better

        # Evaluation and saving; the step schedules must line up for
        # load_best_model_at_end
        evaluation_strategy="steps",
        eval_steps=500,
        save_strategy="steps",
        save_steps=500,
        # Keeps the 3 most recent checkpoints; with load_best_model_at_end the
        # best checkpoint is retained in addition.
        save_total_limit=3,

        # Logging
        logging_steps=100,
        logging_dir=f"{output_dir}/logs",

        # Evaluate on generated text so CER/WER can be computed
        predict_with_generate=True,
        # Without this, evaluation-time generation falls back to the model's
        # generation-config max_length (20 unless the checkpoint overrides
        # it), truncating longer transcriptions and inflating CER/WER.
        generation_max_length=generation_max_length,

        # Other settings
        fp16=torch.cuda.is_available(),  # Use mixed precision if GPU available
        dataloader_num_workers=2,
        remove_unused_columns=False,

        # For reproducibility
        seed=42,
    )

    return training_args

print("Training arguments configured with thesis hyperparameters")


# Training Loop with 5-Fold Cross-Validation

In [None]:
def train_with_cross_validation(
    train_val_df,
    kfold,
    image_dir,
    processor,
    preprocessor,
    base_output_dir="./trocr_cv_results"
):
    """
    Train TrOCR using k-fold (default 5) cross-validation as described in
    Section 3.5.

    For each fold:
    - k-1 folds for training (with augmentation)
    - 1 fold for validation (no augmentation)
    - A fresh pre-trained model, so folds do not share weights
    - Track CER and WER metrics

    Args:
        train_val_df: DataFrame of non-test samples ('image_id',
            'transcription').
        kfold: Splitter exposing ``split(df)`` and ``n_splits`` (e.g. the
            ``KFold`` returned by ``prepare_data_splits``).
        image_dir: Directory containing the images.
        processor: TrOCR processor used to build the datasets and saved next
            to each fold's model. ``compute_metrics`` decodes with the
            module-level ``processor``, not this argument.
        preprocessor: ``PrescriptionImagePreprocessor`` instance.
        base_output_dir: Root directory; fold k writes to ``fold_k/``.

    Returns:
        List of per-fold dicts with keys 'fold', 'train_loss', 'eval_cer' and
        'eval_wer'. The same table is written to
        ``{base_output_dir}/cv_results.csv``.
    """
    import gc

    from transformers import EarlyStoppingCallback

    os.makedirs(base_output_dir, exist_ok=True)
    fold_results = []

    for fold_idx, (train_indices, val_indices) in enumerate(kfold.split(train_val_df)):
        fold_num = fold_idx + 1
        print(f"\n{'='*60}")
        print(f"📂 FOLD {fold_num}/{kfold.n_splits}")
        print(f"{'='*60}\n")

        # kfold.split yields positional indices, hence iloc
        train_fold_df = train_val_df.iloc[train_indices].reset_index(drop=True)
        val_fold_df = train_val_df.iloc[val_indices].reset_index(drop=True)

        print(f"Training samples: {len(train_fold_df)}")
        print(f"Validation samples: {len(val_fold_df)}")

        train_dataset = PrescriptionDataset(
            train_fold_df, image_dir, processor, preprocessor, augment=True
        )
        val_dataset = PrescriptionDataset(
            val_fold_df, image_dir, processor, preprocessor, augment=False
        )

        # Fresh pre-trained model for this fold
        model, _ = initialize_trocr_model()

        output_dir = f"{base_output_dir}/fold_{fold_num}"
        training_args = get_training_args(output_dir)

        # Stop if validation CER has not improved for 3 consecutive evaluations
        early_stopping = EarlyStoppingCallback(
            early_stopping_patience=3,
            early_stopping_threshold=0.0
        )

        trainer = Seq2SeqTrainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            data_collator=default_data_collator,
            compute_metrics=compute_metrics,
            callbacks=[early_stopping]
        )

        print(f"\n🚀 Starting training for Fold {fold_num}...\n")
        train_result = trainer.train()

        # With load_best_model_at_end=True this evaluates the best checkpoint
        print(f"\n📊 Evaluating Fold {fold_num}...\n")
        eval_result = trainer.evaluate()

        fold_results.append({
            'fold': fold_num,
            'train_loss': train_result.training_loss,
            'eval_cer': eval_result['eval_cer'],
            'eval_wer': eval_result['eval_wer']
        })

        # Save the processor next to the weights so the directory can be
        # reloaded on its own with from_pretrained.
        best_model_dir = f"{output_dir}/best_model"
        trainer.save_model(best_model_dir)
        processor.save_pretrained(best_model_dir)

        print(f"\n✅ Fold {fold_num} completed!")
        print(f"   CER: {eval_result['eval_cer']:.4f}")
        print(f"   WER: {eval_result['eval_wer']:.4f}")

        # Release this fold's model, optimizer state and trainer before the
        # next fold loads another model; otherwise the previous model can still
        # be resident on the GPU while the next one is moved there.
        del trainer, model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # Average metrics across folds
    avg_cer = np.mean([r['eval_cer'] for r in fold_results])
    avg_wer = np.mean([r['eval_wer'] for r in fold_results])

    print(f"\n{'='*60}")
    print("📈 CROSS-VALIDATION SUMMARY")
    print(f"{'='*60}")
    print(f"Average CER: {avg_cer:.4f}")
    print(f"Average WER: {avg_wer:.4f}")

    results_df = pd.DataFrame(fold_results)
    results_df.to_csv(f"{base_output_dir}/cv_results.csv", index=False)

    return fold_results

print("Cross-validation training function set")


# Final Model Training and Testing

In [None]:
def train_final_model(train_val_df, image_dir, processor, preprocessor, output_dir="./trocr_final"):
    """
    After cross-validation, train the final model on the entire train/val set.
    The returned trainer is then used to evaluate on the held-out test set.

    Args:
        train_val_df: All non-test samples ('image_id', 'transcription').
        image_dir: Directory containing the images.
        processor: TrOCR processor; saved alongside the final model.
        preprocessor: ``PrescriptionImagePreprocessor`` instance.
        output_dir: Root directory; the model and processor are saved to
            ``{output_dir}/final_model``.

    Returns:
        The ``Seq2SeqTrainer`` holding the trained model.
    """
    import dataclasses

    print("\n🎯 Training final model on complete train/val set...\n")

    # Training dataset (with augmentation)
    train_dataset = PrescriptionDataset(
        train_val_df, image_dir, processor, preprocessor, augment=True
    )

    model, _ = initialize_trocr_model()

    # get_training_args() enables periodic evaluation and best-model reloading,
    # both of which need an eval_dataset. There is none here (all non-test data
    # is used for training), so the Trainer would raise once it tries to
    # evaluate. Disable both for this run; dataclasses.replace builds a new
    # arguments object and re-runs its validation.
    base_args = get_training_args(output_dir, num_train_epochs=15)
    training_args = dataclasses.replace(
        base_args,
        evaluation_strategy="no",
        load_best_model_at_end=False,
    )

    # compute_metrics is kept for the later test-set evaluation
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=default_data_collator,
        compute_metrics=compute_metrics
    )

    trainer.train()

    # Save final model together with its processor
    trainer.save_model(f"{output_dir}/final_model")
    processor.save_pretrained(f"{output_dir}/final_model")

    print("✅ Final model trained and saved!")

    return trainer

def evaluate_on_test_set(trainer, test_df, image_dir, processor, preprocessor):
    """
    Evaluate the final model on the held-out test set (20% by default).

    Args:
        trainer: Trainer returned by ``train_final_model``; its
            ``compute_metrics`` provides CER/WER.
        test_df: Held-out samples ('image_id', 'transcription').
        image_dir: Directory containing the images.
        processor: TrOCR processor used to build the dataset.
        preprocessor: ``PrescriptionImagePreprocessor`` instance.

    Returns:
        Metrics dict from ``trainer.evaluate``; keys carry the default
        "eval_" prefix (e.g. 'eval_cer', 'eval_wer').
    """
    print("\n🧪 Evaluating on held-out test set...\n")

    # No augmentation at test time
    test_dataset = PrescriptionDataset(
        test_df, image_dir, processor, preprocessor, augment=False
    )

    test_results = trainer.evaluate(test_dataset)

    print("\n📊 TEST SET RESULTS:")
    print(f"   CER: {test_results['eval_cer']:.4f}")
    print(f"   WER: {test_results['eval_wer']:.4f}")

    return test_results

print("Final training and testing functions set")


# Inference Function

In [None]:
def predict_prescription(image_path, model, processor, preprocessor, max_length=128):
    """
    Perform inference on a single prescription image.

    Args:
        image_path: Path to prescription image
        model: Trained TrOCR model
        processor: TrOCR processor
        preprocessor: Image preprocessor exposing ``preprocess_image(path)``
        max_length: Maximum number of generated tokens (128 matches the label
            length used in training)

    Returns:
        Transcribed text string

    The model's train/eval mode is restored before returning.
    """
    image = preprocessor.preprocess_image(image_path)

    # The TrOCR image processor expects 3-channel input, while the thresholded
    # image may be single-channel ("L"); convert defensively.
    pixel_values = processor(image.convert("RGB"), return_tensors="pt").pixel_values

    # Move to same device as model
    device = next(model.parameters()).device
    pixel_values = pixel_values.to(device)

    # Generate in eval mode, then restore the caller's mode so calling this
    # during training does not silently leave dropout disabled.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            generated_ids = model.generate(pixel_values, max_length=max_length)
    finally:
        model.train(was_training)

    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

# Example usage:
# transcription = predict_prescription("path/to/prescription.jpg", model, processor, preprocessor)
# print(f"Transcription: {transcription}")

print("Inference function set")
