<a href="https://colab.research.google.com/github/SOHAM-3T/Medical-Prescription-Analyzer-/blob/main/Fine_Tuning_Model_Gamma.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# IMPORTANT: SOME KAGGLE DATA SOURCES ARE PRIVATE
# RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES.
# NOTE(review): presumably kagglehub.login() asks for Kaggle credentials so
# that private datasets can be downloaded afterwards -- confirm against the
# kagglehub documentation.
import kagglehub
kagglehub.login()


In [None]:
# IMPORTANT: RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES,
# THEN FEEL FREE TO DELETE THIS CELL.
# NOTE: THIS NOTEBOOK ENVIRONMENT DIFFERS FROM KAGGLE'S PYTHON
# ENVIRONMENT SO THERE MAY BE MISSING LIBRARIES USED BY YOUR
# NOTEBOOK.

# Download both Kaggle datasets and keep the values returned by kagglehub.
# NOTE(review): these two variables are not referenced by any later cell in
# this notebook; the configuration cell uses hard-coded "/kaggle/input/..."
# paths instead. Verify those paths match the download locations when running
# outside Kaggle.
soham3ripathy_prescription_data_set_path = kagglehub.dataset_download('soham3ripathy/prescription-data-set')
soham3ripathy_ground_truth_seed_path = kagglehub.dataset_download('soham3ripathy/ground-truth-seed')

print('Data source import complete.')


In [None]:
# ==============================================================================
# Step 1: Install a Known Stable Environment (Definitive Method)
# ==============================================================================
print("Installing a specific, known-stable set of libraries for a clean environment...")

# --- A single, unified install command ---
# Installing everything in one pip invocation lets the resolver pick versions that are
# compatible with each other. The core libraries (torch, torchvision, torchaudio,
# transformers, datasets, accelerate, bitsandbytes, peft) are pinned; evaluate, jiwer,
# sentencepiece and pillow are left unpinned.
!pip install \
    torch==2.1.0 \
    torchvision==0.16.0 \
    torchaudio==2.1.0 \
    transformers==4.35.2 \
    datasets==2.15.0 \
    accelerate==0.25.0 \
    bitsandbytes==0.41.2 \
    peft==0.7.1 \
    evaluate \
    jiwer \
    sentencepiece \
    pillow

import torch
from transformers import DonutProcessor, VisionEncoderDecoderModel, VisionEncoderDecoderConfig
from transformers import Trainer, TrainingArguments
from datasets import Dataset
from PIL import Image
import os
import json
import numpy as np
import re
import evaluate
import gc

print("Installation and imports complete. Environment is now stable.")


In [None]:
# ==============================================================================
# Step 2: Configuration and Paths
# ==============================================================================

# --- IMPORTANT: UPDATE THESE PATHS ---
# Folder containing the prescription images (p1.jpg, etc.).
IMAGE_DIR = "/kaggle/input/prescription-data-set"
# Ground-truth text file (ground_truth.txt). Each line is read as
# "<image filename>,<transcription>" by the dataset-loading step.
GROUND_TRUTH_FILE = "/kaggle/input/ground-truth-seed/ground_truth.txt"

# Directory where the fine-tuned model and processor are saved.
OUTPUT_MODEL_DIR = "/kaggle/working/donut-finetuned-prescription-ocr"

In [None]:
# ==============================================================================
# Step 3: Load and Prepare the Dataset
# ==============================================================================

def load_dataset_from_file(image_dir, gt_file_path):
    """Build a ``datasets.Dataset`` of image paths and their transcriptions.

    Every non-blank line of the ground-truth file must look like
    ``<filename>,<text>``. Only the first comma is treated as the separator,
    so the transcription itself may contain commas. Surrounding whitespace is
    removed from both the filename and the text.

    Args:
        image_dir: Directory that contains the image files named in the
            ground-truth file.
        gt_file_path: Path to the UTF-8 encoded ground-truth text file.

    Returns:
        A ``Dataset`` with two string columns, ``image_path`` and
        ``ground_truth``, holding one row per line whose image file exists.
        Lines without a comma, with an empty filename, or whose image is
        missing are skipped with a printed warning.

    Raises:
        OSError: If the ground-truth file cannot be opened.
        UnicodeDecodeError: If the ground-truth file is not valid UTF-8.
    """
    print(f"Loading dataset from {gt_file_path}...")
    dataset_dict = {"image_path": [], "ground_truth": []}

    # An explicit encoding keeps decoding independent of the platform locale;
    # "utf-8-sig" also drops a leading BOM that would otherwise become part of
    # the first filename.
    with open(gt_file_path, 'r', encoding="utf-8-sig") as f:
        for line_number, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            if not line:
                continue  # Blank lines carry no sample.

            filename, separator, text = line.partition(',')
            filename = filename.strip()
            text = text.strip()
            if not separator or not filename:
                print(f"Warning: Malformed line {line_number} in {gt_file_path} will be skipped: {line!r}")
                continue

            image_path = os.path.join(image_dir, filename)
            # isfile (not exists) so that a directory is never accepted as an image.
            if os.path.isfile(image_path):
                dataset_dict["image_path"].append(image_path)
                dataset_dict["ground_truth"].append(text)
            else:
                print(f"Warning: Image file not found and will be skipped: {image_path}")

    print(f"Loaded {len(dataset_dict['image_path'])} images.")
    return Dataset.from_dict(dataset_dict)

full_dataset = load_dataset_from_file(IMAGE_DIR, GROUND_TRUTH_FILE)

# With at most one sample there is nothing to hold out, so the same dataset
# serves for both training and evaluation. Otherwise 10% is held out using a
# fixed seed.
if len(full_dataset) <= 1:
    train_dataset = eval_dataset = full_dataset
else:
    dataset_split = full_dataset.train_test_split(test_size=0.1, seed=42)
    train_dataset, eval_dataset = dataset_split["train"], dataset_split["test"]

print(f"Training set size: {len(train_dataset)}")
print(f"Evaluation set size: {len(eval_dataset)}")


In [None]:
# ==============================================================================
# Step 4: Load Model and Processor
# ==============================================================================

print("Loading pre-trained Donut model and processor...")
model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base")
processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base")

# Task markers that wrap every target transcription.
task_start_token, task_end_token = "<s_transcription>", "</s_transcription>"

# Register the markers with the tokenizer, then grow the decoder's embedding
# table to the new tokenizer length.
processor.tokenizer.add_special_tokens(
    {"additional_special_tokens": [task_start_token, task_end_token]}
)
model.decoder.resize_token_embeddings(len(processor.tokenizer))

In [None]:
# ==============================================================================
# Step 5: Preprocess Data for the Model
# ==============================================================================

class DonutDataset(torch.utils.data.Dataset):
    """Torch dataset that turns (image path, text) rows into Donut inputs.

    Each item is a dict with:
      * ``pixel_values``: the processor's image tensor with the batch
        dimension removed.
      * ``labels``: token ids of ``task_start_token + text + task_end_token``
        padded/truncated to ``max_length``, with padding positions set to
        ``-100``.

    The module-level names ``task_start_token``, ``task_end_token`` and
    ``Image`` must be defined before items are fetched.

    Args:
        dataset: Indexable collection whose items provide ``image_path`` and
            ``ground_truth`` keys.
        processor: Object that is callable on an image (returning an object
            with ``pixel_values``) and exposes a ``tokenizer`` attribute.
        max_length: Fixed length of the returned label sequence.
        split: Name of the split; stored on the instance but not used by this
            class.
    """

    def __init__(self, dataset, processor, max_length=384, split="train"):
        self.dataset = dataset
        self.processor = processor
        self.max_length = max_length
        self.split = split

    def __len__(self):
        """Return the number of rows in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Load, preprocess and tokenize the sample at position ``idx``."""
        item = self.dataset[idx]

        # The context manager releases the file handle; convert() returns a
        # separate, fully loaded RGB image that stays valid afterwards.
        with Image.open(item['image_path']) as raw_image:
            image = raw_image.convert("RGB")

        # Aggressive resize to avoid OOM.
        # NOTE(review): the processor may resize the image again to its own
        # configured size -- confirm that this pre-resize really reduces the
        # size of pixel_values rather than only discarding detail.
        image = image.resize((800, 600))

        # return_tensors="pt" yields a torch.Tensor; drop the batch dimension.
        pixel_values = self.processor(image, return_tensors="pt").pixel_values.squeeze(0)

        target_sequence = f"{task_start_token}{item['ground_truth']}{task_end_token}"

        # Use the processor given to this instance, not a module-level global,
        # so tokenization and padding masking always agree.
        tokenizer = self.processor.tokenizer
        tokenized_output = tokenizer(
            target_sequence,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        # squeeze(0) removes only the batch dimension, so the result stays
        # one-dimensional even when max_length is 1.
        labels = tokenized_output.input_ids.squeeze(0)
        # Padding positions are set to -100; compute_metrics maps -100 back
        # to the pad id before decoding.
        labels[labels == tokenizer.pad_token_id] = -100

        return {"pixel_values": pixel_values, "labels": labels}

# Wrap both splits so the Trainer receives ready-made tensors per sample.
processed_train_dataset = DonutDataset(dataset=train_dataset, processor=processor)
processed_eval_dataset = DonutDataset(
    dataset=eval_dataset,
    processor=processor,
    split="validation",
)

In [None]:
# ==============================================================================
# Step 6: Define Evaluation Metrics + Custom Collator
# ==============================================================================

print("Setting up evaluation metrics (Word Error Rate)...")
# Metric object used by compute_metrics below via wer.compute(...).
wer = evaluate.load("wer")

def compute_metrics(eval_pred):
    """Compute the word error rate between decoded predictions and labels.

    Args:
        eval_pred: Pair ``(logits, labels)``. ``logits`` may be a tuple, in
            which case only its first element is used.

    Returns:
        A dict with a single key ``"wer"`` holding the value returned by
        ``wer.compute``.
    """
    predictions, label_ids = eval_pred
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    # Take the highest-scoring entry along the last axis at every position.
    predicted_ids = np.argmax(predictions, axis=-1)

    tokenizer = processor.tokenizer
    # -100 marks ignored positions; swap it for the pad id so decoding works.
    label_ids = np.where(label_ids != -100, label_ids, tokenizer.pad_token_id)

    decoded_predictions = tokenizer.batch_decode(predicted_ids, skip_special_tokens=True)
    decoded_references = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    score = wer.compute(predictions=decoded_predictions, references=decoded_references)
    return {"wer": score}

# 👇 Custom collator to handle torch tensors safely
class DonutCollator:
    """Collate per-sample tensors into a batch by stacking them.

    Args:
        processor: Stored on the instance; not used while collating.
    """

    def __init__(self, processor):
        self.processor = processor

    def __call__(self, features):
        """Stack ``pixel_values`` and ``labels`` of all features.

        Args:
            features: Sequence of dicts, each with ``pixel_values`` and
                ``labels`` tensors of matching shapes across samples.

        Returns:
            Dict with the stacked ``pixel_values`` and ``labels`` tensors.
        """
        batch = {}
        for key in ("pixel_values", "labels"):
            batch[key] = torch.stack([sample[key] for sample in features])
        return batch



In [None]:
# ==============================================================================
# Step 7: Configure Generation Tokens, Train, and Save
# ==============================================================================

# Copy the tokenizer's pad id into the model config and make the decoder start
# from the id of the task start token.
model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.decoder_start_token_id = processor.tokenizer.convert_tokens_to_ids(task_start_token)

# Batch size 1 per device combined with 8 gradient-accumulation steps.
# Evaluation runs once per epoch; mixed precision is enabled only when CUDA is
# available.
training_args = TrainingArguments(
    output_dir=OUTPUT_MODEL_DIR,
    num_train_epochs=10,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=8,
    gradient_checkpointing=True,
    evaluation_strategy="epoch",
    save_strategy="no",          # Do not save intermediate checkpoints at all.
    logging_steps=50,
    report_to="none",
    fp16=torch.cuda.is_available(),
    # NOTE(review): presumably needed because the datasets are plain torch
    # datasets returning ready-made dicts -- confirm.
    remove_unused_columns=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=processed_train_dataset,
    eval_dataset=processed_eval_dataset,
    tokenizer=processor.tokenizer,
    data_collator=DonutCollator(processor),   # Stacks pixel_values and labels tensors.
    compute_metrics=compute_metrics,
)

# Free unreferenced Python objects and cached CUDA memory before training.
gc.collect()
torch.cuda.empty_cache()

print("\nStarting model fine-tuning...")
trainer.train()
print("Fine-tuning complete!")

# Checkpointing is disabled above, so the model and processor are written
# explicitly once training has finished.
print(f"Saving fine-tuned model to {OUTPUT_MODEL_DIR}")
model.save_pretrained(OUTPUT_MODEL_DIR)
processor.save_pretrained(OUTPUT_MODEL_DIR)
print("Process complete.")