In [1]:
# 2️⃣ Import librari2ess
from transformers import DonutProcessor, VisionEncoderDecoderModel
from PIL import Image
from pdf2image import convert_from_path
import torch
from datasets import load_dataset

import os

from datasets import Dataset
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
import json
from torch.utils.data import Dataset

In [2]:
os.getcwd()

'/root/donut_ocr'

In [2]:
from dataclasses import dataclass
from typing import Any, Dict, List
import torch

@dataclass
class DonutDataCollator:
    """
    Custom DataCollator for Donut.
    It just stacks pixel_values and labels. No token padding!
    """
    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        pixel_values = torch.stack([f["pixel_values"] for f in features])
        labels = torch.stack([f["labels"] for f in features])

        return {
            "pixel_values": pixel_values,
            "labels": labels
        }


In [3]:
class DonutFormDataset(Dataset):
    """
    A custom PyTorch Dataset for fine-tuning Donut on key-value form extraction tasks.
    Expects:
      - JSONL file with {"file_name": ..., "ground_truth": {"gt_parse": {...}}}
      - images/ folder with matching image files
      - Donut processor (DonutProcessor)
    """

    def __init__(self, jsonl_path, images_dir, processor):
        self.samples = []
        self.images_dir = images_dir
        self.processor = processor

        # Load JSON lines
        with open(jsonl_path, 'r') as f:
            for line in f:
                self.samples.append(json.loads(line))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        item = self.samples[idx]

        # Load image
        image_path = f"{self.images_dir}/{item['file_name']}"
        image = Image.open(image_path).convert("RGB")

        # Prepare target text prompt
        prompt = "<s_docvqa><s_answer>"
        target_text = json.dumps(item["ground_truth"]["gt_parse"], ensure_ascii=False)

        # Tokenize target
        target = prompt + target_text

        # Encode image
        pixel_values = self.processor.image_processor(image, return_tensors="pt").pixel_values.squeeze()

        labels = self.processor.tokenizer(
            target,
            add_special_tokens=False,
            max_length=512,  # Use a safe max length for JSON text
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        ).input_ids.squeeze()


        # Replace pad tokens with -100 for ignored loss
        labels[labels == self.processor.tokenizer.pad_token_id] = -100

        return {
            "pixel_values": pixel_values,
            "labels": labels
        }


In [4]:
# 3️⃣ Load pre-trained Donut model (DocVQA version)
processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")
model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")

model.config.decoder_start_token_id = processor.tokenizer.convert_tokens_to_ids("<s>")
model.config.pad_token_id = processor.tokenizer.pad_token_id


# Put model on GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)
model.to(device)

#device = "cpu" 
#model.to(device)



Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


cuda


VisionEncoderDecoderModel(
  (encoder): DonutSwinModel(
    (embeddings): DonutSwinEmbeddings(
      (patch_embeddings): DonutSwinPatchEmbeddings(
        (projection): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
      )
      (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): DonutSwinEncoder(
      (layers): ModuleList(
        (0): DonutSwinStage(
          (blocks): ModuleList(
            (0): DonutSwinLayer(
              (layernorm_before): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
              (attention): DonutSwinAttention(
                (self): DonutSwinSelfAttention(
                  (query): Linear(in_features=128, out_features=128, bias=True)
                  (key): Linear(in_features=128, out_features=128, bias=True)
                  (value): Linear(in_features=128, out_features=128, bias=True)
                  (dropout): Dropout(p=0.0, inplace=False)
                )

In [14]:
# 4️⃣ Convert PDF to image example
pages = convert_from_path("example_forms/Example Clinical Notes 2.pdf")
image = pages[0]  # First page only
image.save("page1.png")

In [15]:
# 5️⃣ Zero-shot inference on your form
question = "<s_docvqa><s_question>What are the patient's problems?</s_question><s_answer>"
inputs = processor(image, question, return_tensors="pt").to(device)
outputs = model.generate(**inputs)
decoded = processor.batch_decode(outputs, skip_special_tokens=True)
print("Prediction:", decoded)

Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Prediction: ["What are the patient's problems? instability of left patellofemoral joint - onset: 09/19"]


In [5]:
train_dataset = DonutFormDataset(
    jsonl_path="Dataset/metadata.jsonl",
    images_dir="Dataset/images",
    processor=processor
)

# Check sample
print(train_dataset[0].keys())  # Should output: dict_keys(['pixel_values', 'labels'])


dict_keys(['pixel_values', 'labels'])


In [6]:
# 6️⃣ Fine-tuning: prepare dataset (pseudo-code)
# Create a list of dicts: [{"image": ..., "gt_parse": {...}}, ...]
# For each, encode prompt + answer, use Trainer API

# Example training loop setup (pseudo-code)


# Create your custom dataset
# dataset = Dataset.from_dict({"image": [...], "labels": [...]})

# Define training arguments 
training_args = Seq2SeqTrainingArguments(
    output_dir="./donut_finetune",
    per_device_train_batch_size=1,
    num_train_epochs=3,
    fp16=True,
    remove_unused_columns=False  # ✅ Needed for custom dataset
)


trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    processing_class=processor.tokenizer,
    data_collator=DonutDataCollator()  # ✅ This disables input_ids padding
)


trainer.train()



Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1553: indexSelectLargeIndex: block: [464,0,0], thread: [96,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1553: indexSelectLargeIndex: block: [464,0,0], thread: [97,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1553: indexSelectLargeIndex: block: [464,0,0], thread: [98,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1553: indexSelectLargeIndex: block: [464,0,0], thread: [99,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1553: indexSelectLargeIndex: block: [464,0,0], thread: [100,0,0] Assertion `srcI

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [None]:
# 7️⃣ Inference loop for new forms
def predict_form(image_path, question):
    image = Image.open(image_path)
    inputs = processor(image, question, return_tensors="pt").to(device)
    outputs = model.generate(**inputs)
    result = processor.batch_decode(outputs, skip_special_tokens=True)
    return result

# Example
# print(predict_form("page1.png", "<s_docvqa><s_question>Patient Name</s_question><s_answer>"))

# ✅ This is your starter flow!
# Replace "your_form.pdf" with your real form, create your JSON labels, and fine-tune if needed.