# Donut DocVQA Implementation

This notebook implements Donut for Document Visual Question Answering (DocVQA) using the fine-tuned model.

## Features:
- Uses `donut-base-finetuned-docvqa` model
- Direct VQA approach: Image + Question → Answer
- Per-sample error handling (one failing document does not abort the run) and Exact Match / token-level F1 evaluation

## Setup and Dependencies

In [1]:
import csv
import gc
import json
import os
import re
import string
import time
import warnings
from collections import Counter
from pathlib import Path

import numpy as np
import pandas as pd
from tqdm import tqdm

# Silence library warnings before the heavy imports below (same position as before).
warnings.filterwarnings("ignore")

from PIL import Image
from transformers import pipeline, DonutProcessor, VisionEncoderDecoderModel
import torch
import matplotlib.pyplot as plt

## Configuration

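**DocVQA sample set configuration.** Run only the configuration cell for the dataset you want to evaluate. Both configuration cells define the same variables (`DATA_DIR`, `IMAGE_DIR`, `METADATA_FILE`, `OUTPUT_CSV`), so whichever one ran last wins.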

In [13]:
# Dataset configuration: DocVQA sample set
DATA_DIR = 'docvqa_samples_300'
IMAGE_DIR = os.path.join(DATA_DIR, "images")
METADATA_FILE = os.path.join(DATA_DIR, "metadata.json")
# Build the path with os.path.join. The former literal "results\Donut..."
# relied on the invalid escape "\D" and only acted as a directory separator
# on Windows; elsewhere it produced a file literally named "results\Donut...".
OUTPUT_CSV = os.path.join("results", "Donut_finetuned_results.csv")

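**NewDataset configuration.** Run this cell *instead of* the DocVQA configuration cell above, never both. The two cells set the same variables, and the last one executed wins.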

In [15]:
# Configuration: custom "NewDataset" evaluation set
DATA_DIR = "NewDataset"
IMAGE_DIR, METADATA_FILE = (
    os.path.join(DATA_DIR, "images"),
    os.path.join(DATA_DIR, "metadata.json"),
)
OUTPUT_CSV = "results_NEWDATA/OCR_DONUT_RESULTS_NEWDATASET.csv"

In [4]:
# Model / runtime configuration
MODEL_NAME = "naver-clova-ix/donut-base-finetuned-docvqa"
# NOTE(review): BATCH_SIZE is not referenced by the processing loop below,
# which handles one sample at a time.
BATCH_SIZE = 1
# Release cached CUDA memory and run the garbage collector every N samples.
CLEAR_CACHE_EVERY = 50
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [5]:
# Translation table that deletes ASCII punctuation; built once, reused per call.
_PUNCT_TABLE = str.maketrans("", "", string.punctuation)


def normalize(text):
    """Normalize an answer string for EM/F1 comparison.

    Lower-cases, removes ASCII punctuation and collapses any run of whitespace
    into a single space. ``None`` becomes ``""``. Non-string values (e.g.
    numeric answers loaded from JSON) are converted with ``str`` first.
    """
    if text is None:
        return ""
    text = str(text).lower().translate(_PUNCT_TABLE)
    # split()/join also collapses internal whitespace, so "a  b" matches "a b".
    return " ".join(text.split())


def _as_answer_list(ground_truths):
    """Return ground truths as a list; a bare string becomes a one-item list."""
    if ground_truths is None:
        return []
    if isinstance(ground_truths, str):
        # Iterating a str directly would compare against single characters.
        return [ground_truths]
    return list(ground_truths)


def exact_match(pred, ground_truths):
    """Return True if ``pred`` equals any ground-truth answer after ``normalize``.

    Args:
        pred: predicted answer.
        ground_truths: iterable of reference answers, or a single string.

    Returns:
        bool; False when there are no ground-truth answers.
    """
    pred_norm = normalize(pred)
    return any(pred_norm == normalize(gt) for gt in _as_answer_list(ground_truths))


def f1_score(pred, ground_truths):
    """Best token-level F1 between ``pred`` and any ground-truth answer.

    Token overlap is counted with multiplicity (SQuAD-style), so repeated
    tokens are not collapsed. When either side normalizes to zero tokens, the
    score is 1.0 if both are empty and 0.0 otherwise.

    Args:
        pred: predicted answer.
        ground_truths: iterable of reference answers, or a single string.

    Returns:
        float in [0, 1]; 0.0 when there are no ground-truth answers.
    """
    pred_tokens = normalize(pred).split()
    pred_counts = Counter(pred_tokens)

    def score(gt):
        gt_tokens = normalize(gt).split()
        if not pred_tokens or not gt_tokens:
            return float(pred_tokens == gt_tokens)
        num_same = sum((pred_counts & Counter(gt_tokens)).values())
        if num_same == 0:
            return 0.0
        precision = num_same / len(pred_tokens)
        recall = num_same / len(gt_tokens)
        return 2 * precision * recall / (precision + recall)

    return max((score(gt) for gt in _as_answer_list(ground_truths)), default=0.0)

## Donut DocVQA Model

In [6]:
class DonutDocVQA:
    """Document visual question answering with a Donut DocVQA checkpoint.

    Tries to build the Hugging Face ``document-question-answering`` pipeline.
    If that raises, it falls back to driving ``DonutProcessor`` and
    ``VisionEncoderDecoderModel`` directly with greedy decoding.
    ``self.method`` records which path is in use ("pipeline" or "manual").
    """

    def __init__(self, model_name=MODEL_NAME):
        use_cuda = torch.cuda.is_available()
        try:
            self.pipeline = pipeline(
                task="document-question-answering",
                model=model_name,
                device=0 if use_cuda else -1,
                torch_dtype=torch.float16 if use_cuda else torch.float32,
            )
            self.method = "pipeline"
        except Exception as exc:
            # Report why the pipeline could not be built. warnings are filtered
            # globally in this notebook, so use print rather than warnings.warn.
            print(f"[DonutDocVQA] pipeline init failed ({exc!r}); using manual generate()")
            self.processor = DonutProcessor.from_pretrained(model_name)
            self.model = VisionEncoderDecoderModel.from_pretrained(model_name)
            self.model.to(device)
            self.model.eval()
            self.method = "manual"

    def answer_question(self, image_path, question):
        """Answer ``question`` about the document image at ``image_path``.

        Returns:
            The predicted answer string. Returns ``""`` if the model produced
            nothing or any step raised. The error is printed, not re-raised,
            so one bad sample does not abort a batch run.
        """
        try:
            # The context manager closes the file handle; convert() returns a
            # new image that does not depend on it.
            with Image.open(image_path) as img:
                image = img.convert("RGB")

            if self.method == "pipeline":
                return self._answer_with_pipeline(image, question)
            return self._answer_with_generate(image, question)
        except Exception as exc:
            print(f"[DonutDocVQA] failed on {image_path!r}: {exc!r}")
            return ""

    def _answer_with_pipeline(self, image, question):
        """Run the HF pipeline and return its top answer, or "" if none."""
        result = self.pipeline(image=image, question=question)
        # Normally a list of dicts; tolerate a bare dict as well.
        if isinstance(result, dict):
            result = [result]
        if not result:
            return ""
        return result[0].get("answer") or ""

    def _answer_with_generate(self, image, question):
        """Greedy-decode an answer with the raw Donut model (fallback path)."""
        tokenizer = self.processor.tokenizer
        task_prompt = f"<s_docvqa><s_question>{question}</s_question><s_answer>"

        pixel_values = self.processor(image, return_tensors="pt").pixel_values.to(device)
        decoder_input_ids = tokenizer(
            task_prompt,
            add_special_tokens=False,
            return_tensors="pt",
        ).input_ids.to(device)

        with torch.no_grad():
            outputs = self.model.generate(
                pixel_values,
                decoder_input_ids=decoder_input_ids,
                max_length=self.model.decoder.config.max_position_embeddings,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                use_cache=True,
                bad_words_ids=[[tokenizer.unk_token_id]],
                return_dict_in_generate=True,
                do_sample=False,
                num_beams=1,
            )

        sequence = self.processor.batch_decode(outputs.sequences)[0]
        sequence = sequence.replace(tokenizer.eos_token, "").replace(tokenizer.pad_token, "")
        return self._extract_answer(sequence, task_prompt)

    @staticmethod
    def _extract_answer(sequence, task_prompt=""):
        """Extract the answer text from a decoded Donut sequence.

        Uses the span between ``<s_answer>`` and ``</s_answer>``. If the
        closing tag is missing (e.g. generation stopped at ``max_length``),
        everything after ``<s_answer>`` is used. If no ``<s_answer>`` tag is
        present, the task prompt is removed instead. Any leftover Donut field
        tags (``<s_...>`` / ``</s_...>``) are stripped so prompt markup does
        not leak into the prediction.
        """
        open_tag, close_tag = "<s_answer>", "</s_answer>"
        start = sequence.find(open_tag)
        if start != -1:
            start += len(open_tag)
            end = sequence.find(close_tag, start)
            answer = sequence[start:] if end == -1 else sequence[start:end]
        else:
            answer = sequence.replace(task_prompt, "") if task_prompt else sequence
        return re.sub(r"</?s_[^<>]*>", "", answer).strip()

## Document Processing Function

In [7]:
def process_document_with_donut_docvqa(sample, image_dir, donut_vqa):
    """Run Donut DocVQA on one metadata sample and score the prediction.

    Args:
        sample: metadata record with ``id``, ``image_filename``, ``question``
            and ``answers`` keys. ``answers`` may be a list or a single string.
        image_dir: directory containing the sample images.
        donut_vqa: object exposing ``answer_question(image_path, question)``.

    Returns:
        A dict whose keys match the results-CSV columns. Returns ``None`` if
        the sample is malformed or scoring raised. An empty prediction is NOT
        dropped: it is kept and scored as a miss, so averages computed from
        the CSV are not inflated by silently skipping failures.
    """
    try:
        doc_id = sample["id"]
        image_filename = sample["image_filename"]
        question = sample["question"]
        answers = sample["answers"]
        # A bare string would otherwise be iterated character by character.
        if isinstance(answers, str):
            answers = [answers]
        ground_truth = [str(gt) for gt in answers]
    except (KeyError, TypeError) as exc:
        print(f"[process] malformed sample {sample!r:.80}: {exc!r}")
        return None

    image_path = os.path.join(image_dir, image_filename)

    try:
        # answer_question returns "" on failure; score that as a miss.
        predicted_answer = donut_vqa.answer_question(image_path, question) or ""
        em = exact_match(predicted_answer, ground_truth)
        f1_val = f1_score(predicted_answer, ground_truth)
    except Exception as exc:
        print(f"[process] sample {doc_id!r} failed: {exc!r}")
        return None

    return {
        "id": doc_id,
        "image_filename": image_filename,
        "question": question,
        "ground_truth": " | ".join(ground_truth),
        "extracted_content": predicted_answer,
        "predicted_answer": predicted_answer,
        "exact_match": em,
        "f1_score": round(f1_val, 2),
    }

## Data Loading and Processing

In [8]:
# Load metadata; the processing loop iterates it as a list of sample dicts.
with open(METADATA_FILE, "r", encoding="utf-8") as f:
    docvqa_metadata = json.load(f)
print(f"Loaded {len(docvqa_metadata)} samples from {METADATA_FILE}")

# Create the directory OUTPUT_CSV is written to. Derive it from the path
# instead of hard-coding "results", so the NewDataset configuration
# ("results_NEWDATA/...") also works. dirname() is "" for a bare filename.
output_dir = os.path.dirname(OUTPUT_CSV)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

# Initialize Donut DocVQA model
donut_vqa = DonutDocVQA(MODEL_NAME)

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
Device set to use cuda:0


## Main Processing Pipeline

In [9]:
# Process all documents, writing each scored row to OUTPUT_CSV as it is produced.
all_results = []
processed_count = 0
skipped_indices = []  # positions in docvqa_metadata that produced no row

fieldnames = ["id", "image_filename", "question", "ground_truth",
              "extracted_content", "predicted_answer", "exact_match", "f1_score"]

with open(OUTPUT_CSV, "w", newline="", encoding="utf-8") as csvfile:
    writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
    writer.writeheader()

    for i, sample in enumerate(tqdm(docvqa_metadata, desc="Processing documents")):
        try:
            result = process_document_with_donut_docvqa(sample, IMAGE_DIR, donut_vqa)

            if result is None:
                skipped_indices.append(i)
            else:
                writer.writerow(result)
                all_results.append(result)
                processed_count += 1

            # Memory management: periodically release cached GPU memory and run GC.
            if (i + 1) % CLEAR_CACHE_EVERY == 0:
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                gc.collect()

        except KeyboardInterrupt:
            tqdm.write(f"Interrupted at sample {i}; rows written so far are kept in {OUTPUT_CSV}")
            break
        except Exception as exc:
            # Last-resort guard so one sample cannot stop the run -- but report it.
            tqdm.write(f"Sample {i} raised {exc!r}; skipping")
            skipped_indices.append(i)
            continue

print(f"Wrote {processed_count} rows to {OUTPUT_CSV}; skipped {len(skipped_indices)} samples")
if skipped_indices:
    print("Skipped sample indices (first 20):", skipped_indices[:20])

Processing documents:   0%|          | 0/10 [00:00<?, ?it/s]Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
`sdpa` attention does not support `output_attentions=True` or `head_mask`. Please set your attention to `eager` if you want any of these features.
Processing documents: 100%|██████████| 10/10 [00:10<00:00,  1.06s/it]


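## Results Analysis

> **Note:** the saved outputs in this notebook were produced out of order. The configuration cells (In [13], In [15]) ran after the processing cell (In [9]), so the scores printed below may not correspond to the documents processed above. Re-run the notebook top to bottom, with only one configuration cell, before citing these numbers.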

In [16]:
# Load results from OUTPUT_CSV and compute average F1 and EM scores.
results_df = pd.read_csv(OUTPUT_CSV)
n_scored = len(results_df)
n_total = len(docvqa_metadata)
print(f"Scored rows in {OUTPUT_CSV}: {n_scored} (metadata samples: {n_total})")

if n_scored > n_total:
    # Both configuration cells define OUTPUT_CSV; more rows than samples means
    # this CSV was most likely produced with a different configuration.
    print("WARNING: OUTPUT_CSV has more rows than the loaded metadata - check which configuration produced it.")

if n_scored == 0:
    avg_f1 = avg_em = float("nan")
    print("No scored rows to average.")
else:
    # exact_match round-trips through CSV as True/False; cast both columns to
    # float so the means are plain fractions.
    f1_values = results_df["f1_score"].astype(float)
    em_values = results_df["exact_match"].astype(float)
    avg_f1 = f1_values.mean()
    avg_em = em_values.mean()

    print(f"Average F1 Score: {avg_f1:.3f}")
    print(f"Average Exact Match (EM) Score: {avg_em:.3f}")

    if n_scored < n_total:
        # Samples that produced no row are absent from the CSV; counting them
        # as 0 gives the score over the full evaluation set.
        print(f"Average F1 over all {n_total} samples (missing = 0): {f1_values.sum() / n_total:.3f}")
        print(f"Average EM over all {n_total} samples (missing = 0): {em_values.sum() / n_total:.3f}")

Average F1 Score: 0.100
Average Exact Match (EM) Score: 0.100
