
## ⚙️ 1. Environment Setup

We will install the required packages and load the models from Hugging Face.

Make sure to run this in an environment with GPU support for faster inference and training.


In [None]:
!pip install -q transformers accelerate datasets peft bitsandbytes sentencepiece
!pip install -q opencv-python pytesseract torchvision


## 🧪 2. Inference Demo using Donut and MiniCPM

We will run inference on an invoice image (`invoice_sample.png`). Please replace it with your real data as needed.


In [5]:

from transformers import DonutProcessor, VisionEncoderDecoderModel
from PIL import Image
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load model
processor = DonutProcessor.from_pretrained("to-be/donut-base-finetuned-invoices")
model = VisionEncoderDecoderModel.from_pretrained("to-be/donut-base-finetuned-invoices").to(device)

# Load image
image = Image.open("/Users/xiaotingzhou/Documents/Lectures/AI_OCR/data/converted_images/invoice_page_1.jpg").convert("RGB")

# Prepare input
pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)
task_prompt = "<s>Invoice Information:"

# Inference
outputs = model.generate(pixel_values, decoder_input_ids=processor.tokenizer(task_prompt, return_tensors="pt").input_ids.to(device), max_length=512)
result = processor.batch_decode(outputs, skip_special_tokens=True)[0]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


In [2]:
print("Extracted Data:", result)

Extracted Data: Invoice Information:23-01-30</s_DocumentDate><s_GrossAmount> 61388.00</s_GrossAmount><s_InvoiceNumber> 309824263008</s_InvoiceNumber><s_NetAmount1> 6138872.00</s_NetAmount1><s_TaxAmount1> 0.00</s_TaxAmount1>


# Change Prompt

- Consistent Format : Training and inference use the same task prompt format
- Simple Outputs : Instead of complex structured output, use simple "prompt: value" format
- Task-Specific Data : Create separate training samples for each extraction task
- Better Training : Improved batch size, learning rate, and epochs
- Proper Testing : Test each task prompt individually

In [3]:
def preprocess(example):
    try:
        # Load and process image
        image = Image.open(example["image"]).convert("RGB")
        processed = processor(image, return_tensors="pt")
        pixel_values = processed.pixel_values.squeeze(0)  # Remove batch dimension
        
        # FIXED: Use consistent task prompt format
        task_prompt = "<s>InvoiceNo:"  # This will be your consistent prompt
        ground_truth = example["ground_truth"]
        
        # FIXED: Create target that matches your desired output format
        # Instead of complex structured format, use simple key-value format
        target_text = f"{task_prompt} {ground_truth['InvoiceNo']}</s>"
        
        # Tokenize target
        tokenized = processor.tokenizer(target_text, 
                                      return_tensors="pt", 
                                      padding=False,
                                      truncation=True, 
                                      max_length=512)
        labels = tokenized.input_ids.squeeze(0)  # Remove batch dimension
        
        return {
            "pixel_values": pixel_values,
            "labels": labels
        }
    except Exception as e:
        print(f"Error processing example: {e}")
        return None

## donut-finetuned-task-specific

In [1]:
import json
import os

# Create data directory if it doesn't exist
os.makedirs('data/training', exist_ok=True)

# FIXED: Create separate training samples for different tasks
train_data = []
val_data = []

# Base invoice data
base_invoice = {
    "InvoiceNo": "Y 309824263008",
    "InvoiceDate": "2025年6月30日",
    "Currency": "USD",
    "Amount with Tax": "300",
    "Amount without Tax": "300",
    "Tax": "0"
}

# Create training samples for each field you want to extract
fields_to_extract = {
    "InvoiceNo": "<s>InvoiceNo:",
    "InvoiceDate": "<s>InvoiceDate:", 
    "Currency": "<s>Currency:",
    "Amount with Tax": "<s>Amount:",
    "Tax": "<s>Tax:"
}

# Generate training data for each task
for field_name, task_prompt in fields_to_extract.items():
    # Create training sample
    train_sample = {
        "image": "data/converted_images/invoice_page_1.jpg",
        "ground_truth": {field_name: base_invoice[field_name]},
        "task_prompt": task_prompt  # Add task prompt to data
    }
    train_data.append(train_sample)
    
    # Create validation sample
    val_sample = {
        "image": "data/converted_images/invoice_page_1.jpg",
        "ground_truth": {field_name: base_invoice[field_name]},
        "task_prompt": task_prompt
    }
    val_data.append(val_sample)

# Save the datasets
with open('data/training/train.json', 'w', encoding='utf-8') as f:
    json.dump(train_data, f, indent=2, ensure_ascii=False)

with open('data/training/val.json', 'w', encoding='utf-8') as f:
    json.dump(val_data, f, indent=2, ensure_ascii=False)

print(f"✅ Task-specific training data created!")
print(f"📁 Train data: {len(train_data)} samples")
print(f"📁 Validation data: {len(val_data)} samples")

✅ Task-specific training data created!
📁 Train data: 5 samples
📁 Validation data: 5 samples


In [5]:
def preprocess_with_task_prompt(example):
    try:
        # Load and process image
        image = Image.open(example["image"]).convert("RGB")
        processed = processor(image, return_tensors="pt")
        pixel_values = processed.pixel_values.squeeze(0)
        
        # Use the task prompt from the data
        task_prompt = example["task_prompt"]
        ground_truth = example["ground_truth"]
        
        # Get the field name and value
        field_name = list(ground_truth.keys())[0]
        field_value = ground_truth[field_name]
        
        # Create target text: task_prompt + field_value + end_token
        target_text = f"{task_prompt} {field_value}</s>"
        
        # Tokenize target
        tokenized = processor.tokenizer(target_text, 
                                      return_tensors="pt", 
                                      padding=False,
                                      truncation=True, 
                                      max_length=512)
        labels = tokenized.input_ids.squeeze(0)
        
        return {
            "pixel_values": pixel_values,
            "labels": labels,
            "task_prompt": task_prompt  # Keep for reference
        }
    except Exception as e:
        print(f"Error processing example: {e}")
        return None

In [6]:
def preprocess_with_task_prompt(example):
    try:
        # Load and process image
        image = Image.open(example["image"]).convert("RGB")
        processed = processor(image, return_tensors="pt")
        pixel_values = processed.pixel_values.squeeze(0)
        
        # Use the task prompt from the data
        task_prompt = example["task_prompt"]
        ground_truth = example["ground_truth"]
        
        # Get the field name and value
        field_name = list(ground_truth.keys())[0]
        field_value = ground_truth[field_name]
        
        # Create target text: task_prompt + field_value + end_token
        target_text = f"{task_prompt} {field_value}</s>"
        
        # Tokenize target
        tokenized = processor.tokenizer(target_text, 
                                      return_tensors="pt", 
                                      padding=False,
                                      truncation=True, 
                                      max_length=512)
        labels = tokenized.input_ids.squeeze(0)
        
        return {
            "pixel_values": pixel_values,
            "labels": labels,
            "task_prompt": task_prompt  # Keep for reference
        }
    except Exception as e:
        print(f"Error processing example: {e}")
        return None

In [1]:
from datasets import load_dataset
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Union
# Force CPU usage
# device = torch.device("cpu")

# # Move model to CPU
# model = model.to(device)
# 1. Define DonutDataCollator
@dataclass
class DonutDataCollator:
    """Custom data collator for Donut model that handles pixel_values and labels"""
    
    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        pixel_values = []
        labels = []
        
        for feature in features:
            # Ensure pixel_values is a tensor
            pv = feature["pixel_values"]
            if isinstance(pv, list):
                pv = torch.tensor(pv)
            elif not isinstance(pv, torch.Tensor):
                pv = torch.tensor(pv)
            pixel_values.append(pv)
            
            # Ensure labels is a tensor
            label = feature["labels"]
            if isinstance(label, list):
                label = torch.tensor(label)
            elif not isinstance(label, torch.Tensor):
                label = torch.tensor(label)
            labels.append(label)
        
        # Stack pixel_values
        pixel_values = torch.stack(pixel_values)
        
        # Pad labels to the same length
        max_length = max(len(label) for label in labels)
        padded_labels = []
        
        for label in labels:
            if len(label) < max_length:
                padded_label = torch.cat([
                    label,
                    torch.full((max_length - len(label),), -100, dtype=label.dtype)
                ])
            else:
                padded_label = label
            padded_labels.append(padded_label)
        
        labels = torch.stack(padded_labels)
        
        return {
            "pixel_values": pixel_values,
            "labels": labels
        }

# 2. Initialize data collator
data_collator = DonutDataCollator()

# 3. Load dataset (assuming you've already created the task-specific data)
dataset = load_dataset("json", data_files={
    "train": "data/training/train.json", 
    "validation": "data/training/val.json"
})

# 4. Apply preprocessing
dataset = dataset.map(preprocess_with_task_prompt, remove_columns=dataset["train"].column_names)
dataset = dataset.filter(lambda x: x is not None)

print(f"✅ Dataset processed: {len(dataset['train'])} train, {len(dataset['validation'])} val samples")

# 5. Training arguments
print("🔄 Switched to CPU training to avoid memory issues")

  from .autonotebook import tqdm as notebook_tqdm


NameError: name 'preprocess_with_task_prompt' is not defined

In [None]:
import gc
import torch
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

# Clear memory first
if torch.backends.mps.is_available():
    torch.mps.empty_cache()
gc.collect()

# Memory-optimized training arguments
training_args = Seq2SeqTrainingArguments(
    output_dir="./donut-finetuned-task-specific",
    per_device_train_batch_size=1,  # Minimal batch size
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=8,  # Maintain effective batch size
    num_train_epochs=5,  # Reduced epochs
    learning_rate=3e-5, (smaller)
    warmup_steps=25,
    logging_steps=5,
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,
    predict_with_generate=True,
    generation_max_length=64,  # Shorter sequences
    fp16=False,
    dataloader_pin_memory=False,
    remove_unused_columns=False,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    dataloader_num_workers=0,
    max_grad_norm=1.0,
)

# Initialize trainer with optimized settings
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    data_collator=data_collator,
    tokenizer=processor.tokenizer,
)

print("🚀 Starting memory-optimized training...")

  trainer = Seq2SeqTrainer(


🚀 Starting memory-optimized training...


In [2]:
# Clear memory before training
import gc
import torch

# Clear GPU memory
if torch.backends.mps.is_available():
    torch.mps.empty_cache()

# Force garbage collection
gc.collect()

print("🧹 Memory cleared before training")

# Then proceed with training
trainer.train()

🧹 Memory cleared before training


NameError: name 'trainer' is not defined

## donut-finetuned-task-specific

In [3]:
# Load your trained model - use checkpoint-4
finetuned_model = VisionEncoderDecoderModel.from_pretrained("./donut-finetuned-task-specific/checkpoint-4")
finetuned_model.to(device)
finetuned_model.eval()

NameError: name 'VisionEncoderDecoderModel' is not defined

In [4]:
# Test different task prompts
test_prompts = [
    "<s>InvoiceNo:",
    "<s>InvoiceDate:",
    "<s>Currency:",
    "<s>Amount:",
    "<s>Tax:"
]

image = Image.open("data/converted_images/invoice_page_1.jpg").convert("RGB")
pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)

for task_prompt in test_prompts:
    print(f"\n🧪 Testing prompt: {task_prompt}")
    
    # Generate with specific task prompt
    decoder_input_ids = processor.tokenizer(task_prompt, 
                                          add_special_tokens=False, 
                                          return_tensors="pt").input_ids.to(device)
    
    with torch.no_grad():
        outputs = finetuned_model.generate(
            pixel_values,
            decoder_input_ids=decoder_input_ids,
            max_length=128,
            num_beams=1,
            early_stopping=True,
        )
    
    result = processor.batch_decode(outputs, skip_special_tokens=True)[0]
    print(f"📄 Result: {result}")

NameError: name 'Image' is not defined

In [5]:
# Load your trained model - use checkpoint-5 (likely the latest)
finetuned_model = VisionEncoderDecoderModel.from_pretrained("./donut-finetuned-task-specific/checkpoint-5")
finetuned_model.to(device)
finetuned_model.eval()

NameError: name 'VisionEncoderDecoderModel' is not defined

In [6]:
# Test different task prompts
test_prompts = [
    "<s>InvoiceNo:",
    "<s>InvoiceDate:",
    "<s>Currency:",
    "<s>Amount:",
    "<s>Tax:"
]

image = Image.open("data/converted_images/invoice_page_1.jpg").convert("RGB")
pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)

for task_prompt in test_prompts:
    print(f"\n🧪 Testing prompt: {task_prompt}")
    
    # Generate with specific task prompt
    decoder_input_ids = processor.tokenizer(task_prompt, 
                                          add_special_tokens=False, 
                                          return_tensors="pt").input_ids.to(device)
    
    with torch.no_grad():
        outputs = finetuned_model.generate(
            pixel_values,
            decoder_input_ids=decoder_input_ids,
            max_length=128,
            num_beams=1,
            early_stopping=True,
        )
    
    result = processor.batch_decode(outputs, skip_special_tokens=True)[0]
    print(f"📄 Result: {result}")

NameError: name 'Image' is not defined

In [7]:
# Fix the MiniCPM model loading by adding proper imports and error handling
from transformers import AutoTokenizer, AutoModel
from typing import List  # Add this import to fix the NameError
import torch

try:
    model_name = "openbmb/MiniCPM-Llama3-V-2_5"
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    # Use AutoModel instead of AutoModelForCausalLM for vision-language models
    model = AutoModel.from_pretrained(model_name, trust_remote_code=True, torch_dtype=torch.float16)
    model = model.to(device)
    
    # For MiniCPM, we need to use the chat interface with images
    # This is a placeholder since we need actual image processing
    print("✅ MiniCPM model loaded successfully!")
    print("💡 Note: MiniCPM requires image input for proper inference.")
    
except Exception as e:
    print(f"❌ Error loading MiniCPM model: {e}")
    print("💡 Continuing with Donut model only for this demo.")

  from .autonotebook import tqdm as notebook_tqdm


❌ Error loading MiniCPM model: name 'List' is not defined
💡 Continuing with Donut model only for this demo.


In [8]:
# Fix the MiniCPM model loading by adding proper imports and error handling
from transformers import AutoTokenizer, AutoModel
from typing import List  # Add this import to fix the NameError
import torch

try:
    model_name = "openbmb/MiniCPM-Llama3-V-2_5"
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    # Use AutoModel instead of AutoModelForCausalLM for vision-language models
    model = AutoModel.from_pretrained(model_name, trust_remote_code=True, torch_dtype=torch.float16)
    model = model.to(device)
    
    # For MiniCPM, we need to use the chat interface with images
    # This is a placeholder since we need actual image processing
    print("✅ MiniCPM model loaded successfully!")
    print("💡 Note: MiniCPM requires image input for proper inference.")
    
except Exception as e:
    print(f"❌ Error loading MiniCPM model: {e}")
    print("💡 Continuing with Donut model only for this demo.")

❌ Error loading MiniCPM model: name 'List' is not defined
💡 Continuing with Donut model only for this demo.


In [9]:
# Create sample training and validation datasets
import json
import os

# Create data directory if it doesn't exist
os.makedirs('data/training', exist_ok=True)

# Sample training data
train_data = [
    {
        "image": "data/converted_images/invoice_page_1.jpg",
        "ground_truth": {
            "InvoiceNo": "Y 309824263008",
            "InvoiceDate": "2025年6月30日",
            "Currency": "USD",
            "Amount with Tax": "300",
            "Amount without Tax": "300",
            "Tax": "0"
        }
    },
    {
        "image": "data/converted_images/invoice_page_1.jpg",  # Using same image for demo
        "ground_truth": {
            "InvoiceNo": "Y 309824263008",
            "InvoiceDate": "2025年6月30日",
            "Currency": "USD",
            "Amount with Tax": "300",
            "Amount without Tax": "300",
            "Tax": "0"
        }
    }
]

# Sample validation data
val_data = [
    {
        "image": "data/converted_images/invoice_page_1.jpg",
        "ground_truth": {
            "InvoiceNo": "Y 309824263008",
            "InvoiceDate": "2025年6月30日",
            "Currency": "USD",
            "Amount with Tax": "300",
            "Amount without Tax": "300",
            "Tax": "0"
        }
    }
]

# Save the datasets
with open('data/training/train.json', 'w', encoding='utf-8') as f:
    json.dump(train_data, f, indent=2, ensure_ascii=False)

with open('data/training/val.json', 'w', encoding='utf-8') as f:
    json.dump(val_data, f, indent=2, ensure_ascii=False)

print("✅ Training and validation datasets created!")
print(f"📁 Train data: {len(train_data)} samples")
print(f"📁 Validation data: {len(val_data)} samples")

✅ Training and validation datasets created!
📁 Train data: 2 samples
📁 Validation data: 1 samples


In [10]:
from datasets import load_dataset
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

# Load dataset with correct file paths
dataset = load_dataset("json", data_files={
    "train": "data/training/train.json", 
    "validation": "data/training/val.json"
})
# Preprocessing
def preprocess(example):
    image = Image.open(example["image"]).convert("RGB")
    pixel_values = processor(image, return_tensors="pt").pixel_values[0]
    task_prompt = "<s>Invoice Information:"
    decoder_input_ids = processor.tokenizer(task_prompt + str(example["ground_truth"]), return_tensors="pt").input_ids[0]
    return {"pixel_values": pixel_values, "labels": decoder_input_ids}

dataset = dataset.map(preprocess)

# Training setup
training_args = Seq2SeqTrainingArguments(
    output_dir="./donut-finetuned-invoices",
    per_device_train_batch_size=1,
    num_train_epochs=5,
    logging_dir="./logs",
    save_total_limit=2,
    eval_strategy="epoch",  # Changed from evaluation_strategy to eval_strategy
    save_strategy="epoch",  # Also add save_strategy for consistency
    load_best_model_at_end=True,  # Optional: load best model at end
    metric_for_best_model="eval_loss",  # Optional: metric to determine best model
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
)

trainer.train()


Generating train split: 2 examples [00:00, 105.54 examples/s]
Generating validation split: 1 examples [00:00, 360.49 examples/s]
Map:   0%|          | 0/2 [00:00<?, ? examples/s]


NameError: name 'Image' is not defined

In [None]:
import torch
from datasets import load_dataset
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
from PIL import Image
import json
from dataclasses import dataclass
from typing import Any, Dict, List, Union

# 1. Fixed Custom Data Collator for Donut
@dataclass
class DonutDataCollator:
    """Custom data collator for Donut model that handles pixel_values and labels"""
    
    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        # Extract pixel_values and labels
        pixel_values = []
        labels = []
        
        for feature in features:
            # Ensure pixel_values is a tensor
            pv = feature["pixel_values"]
            if isinstance(pv, list):
                pv = torch.tensor(pv)
            elif not isinstance(pv, torch.Tensor):
                pv = torch.tensor(pv)
            pixel_values.append(pv)
            
            # Ensure labels is a tensor
            label = feature["labels"]
            if isinstance(label, list):
                label = torch.tensor(label)
            elif not isinstance(label, torch.Tensor):
                label = torch.tensor(label)
            labels.append(label)
        
        # Stack pixel_values
        pixel_values = torch.stack(pixel_values)
        
        # Pad labels to the same length
        max_length = max(len(label) for label in labels)
        padded_labels = []
        
        for label in labels:
            # Pad with -100 (ignored in loss computation)
            if len(label) < max_length:
                padded_label = torch.cat([
                    label,
                    torch.full((max_length - len(label),), -100, dtype=label.dtype)
                ])
            else:
                padded_label = label
            padded_labels.append(padded_label)
        
        labels = torch.stack(padded_labels)
        
        return {
            "pixel_values": pixel_values,
            "labels": labels
        }

# 2. Fixed preprocessing function
def preprocess(example):
    try:
        # Load and process image
        image = Image.open(example["image"]).convert("RGB")
        # Ensure we get a proper tensor
        processed = processor(image, return_tensors="pt")
        pixel_values = processed.pixel_values.squeeze(0)  # Remove batch dimension
        
        # Create proper target format for Donut
        ground_truth = example["ground_truth"]
        # Convert ground truth to Donut's expected format
        target_text = f"<s_InvoiceNumber>{ground_truth['InvoiceNo']}</s_InvoiceNumber><s_InvoiceDate>{ground_truth['InvoiceDate']}</s_InvoiceDate><s_Currency>{ground_truth['Currency']}</s_Currency><s_AmountWithTax>{ground_truth['Amount with Tax']}</s_AmountWithTax><s_AmountWithoutTax>{ground_truth['Amount without Tax']}</s_AmountWithoutTax><s_Tax>{ground_truth['Tax']}</s_Tax></s>"
        
        # Tokenize target
        tokenized = processor.tokenizer(target_text, 
                                      return_tensors="pt", 
                                      padding=False,
                                      truncation=True, 
                                      max_length=512)
        labels = tokenized.input_ids.squeeze(0)  # Remove batch dimension
        
        return {
            "pixel_values": pixel_values,
            "labels": labels
        }
    except Exception as e:
        print(f"Error processing example: {e}")
        return None

# 3. Load and preprocess dataset
dataset = load_dataset("json", data_files={
    "train": "data/training/train.json", 
    "validation": "data/training/val.json"
})

# Apply preprocessing and filter out None values
dataset = dataset.map(preprocess, remove_columns=dataset["train"].column_names)
dataset = dataset.filter(lambda x: x is not None)

print(f"✅ Dataset processed: {len(dataset['train'])} train, {len(dataset['validation'])} val samples")

# Debug: Check data types
print("🔍 Checking data types:")
for i, sample in enumerate(dataset["train"]):
    print(f"Sample {i}:")
    print(f"  pixel_values type: {type(sample['pixel_values'])}, shape: {sample['pixel_values'].shape if hasattr(sample['pixel_values'], 'shape') else 'N/A'}")
    print(f"  labels type: {type(sample['labels'])}, shape: {sample['labels'].shape if hasattr(sample['labels'], 'shape') else 'N/A'}")
    if i >= 1:  # Only check first 2 samples
        break

# 4. Training arguments
training_args = Seq2SeqTrainingArguments(
    output_dir="./donut-finetuned-invoices-v2",
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=2,
    num_train_epochs=3,  # Reduced for testing
    learning_rate=1e-5,
    logging_steps=1,
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,
    predict_with_generate=True,
    generation_max_length=512,
    fp16=False,
    dataloader_pin_memory=False,
    remove_unused_columns=False,
)

# 5. Initialize custom data collator
data_collator = DonutDataCollator()

# 6. Initialize trainer
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    data_collator=data_collator,
    tokenizer=processor.tokenizer,
)

print("🚀 Starting training with fixed data collator...")
trainer.train()

Map: 100%|██████████| 2/2 [00:00<00:00,  5.24 examples/s]
Map: 100%|██████████| 1/1 [00:00<00:00,  7.88 examples/s]
Filter: 100%|██████████| 2/2 [00:02<00:00,  1.32s/ examples]
Filter: 100%|██████████| 1/1 [00:01<00:00,  1.31s/ examples]


✅ Dataset processed: 2 train, 1 val samples
🔍 Checking data types:
Sample 0:
  pixel_values type: <class 'list'>, shape: N/A
  labels type: <class 'list'>, shape: N/A
Sample 1:
  pixel_values type: <class 'list'>, shape: N/A
  labels type: <class 'list'>, shape: N/A
🚀 Starting training with fixed data collator...


  trainer = Seq2SeqTrainer(


Epoch,Training Loss,Validation Loss
1,0.0,
2,0.0,
3,0.0,


TrainOutput(global_step=3, training_loss=0.0, metrics={'train_runtime': 76.07, 'train_samples_per_second': 0.079, 'train_steps_per_second': 0.039, 'total_flos': 3.75229398122496e+16, 'train_loss': 0.0, 'epoch': 3.0})

In [None]:
# Load the fine-tuned model for inference - CORRECTED VERSION
import torch
from transformers import VisionEncoderDecoderModel, DonutProcessor
from PIL import Image
import json
import re

# Use the original model name for the processor
original_model_name = "naver-clova-ix/donut-base-finetuned-docvqa"
model_path = "./donut-finetuned-invoices-v2/checkpoint-3"  # Your fine-tuned model
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

print(f"🔄 Loading fine-tuned model from {model_path}...")
print(f"🔄 Loading processor from original model: {original_model_name}...")

# Load fine-tuned model but original processor
finetuned_model = VisionEncoderDecoderModel.from_pretrained(model_path)
finetuned_processor = DonutProcessor.from_pretrained(original_model_name)  # Use original processor

finetuned_model.to(device)
finetuned_model.eval()
print("✅ Fine-tuned model and processor loaded successfully!")

# Function to extract invoice data using the fine-tuned model
def predict_invoice_data(image_path, model, processor):
    """Extract invoice data using the fine-tuned Donut model"""
    
    # Load and preprocess image
    image = Image.open(image_path).convert("RGB")
    pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)
    
    # Create task prompt matching your training format
    task_prompt = "<s_InvoiceNumber><s_InvoiceDate><s_Currency><s_AmountWithTax><s_AmountWithoutTax><s_Tax>"
    decoder_input_ids = processor.tokenizer(task_prompt, 
                                          add_special_tokens=False, 
                                          return_tensors="pt").input_ids.to(device)
    
    # Generate prediction
    with torch.no_grad():
        outputs = model.generate(
            pixel_values,
            decoder_input_ids=decoder_input_ids,
            max_length=512,
            early_stopping=True,
            pad_token_id=processor.tokenizer.pad_token_id,
            eos_token_id=processor.tokenizer.eos_token_id,
            use_cache=True,
            num_beams=1,
            bad_words_ids=[[processor.tokenizer.unk_token_id]],
            return_dict_in_generate=True,
        )
    
    # Decode the output
    sequence = processor.batch_decode(outputs.sequences)[0]
    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
    
    return sequence

# Function to parse the model output into structured data
def parse_donut_output(output_text):
    """Parse Donut model output into structured JSON"""
    
    result = {}
    
    # Define patterns for each field based on your training format
    patterns = {
        'InvoiceNo': r'<s_InvoiceNumber>(.*?)</s_InvoiceNumber>',
        'InvoiceDate': r'<s_InvoiceDate>(.*?)</s_InvoiceDate>',
        'Currency': r'<s_Currency>(.*?)</s_Currency>',
        'Amount with Tax': r'<s_AmountWithTax>(.*?)</s_AmountWithTax>',
        'Amount without Tax': r'<s_AmountWithoutTax>(.*?)</s_AmountWithoutTax>',
        'Tax': r'<s_Tax>(.*?)</s_Tax>'
    }
    
    # Extract each field
    for field, pattern in patterns.items():
        match = re.search(pattern, output_text)
        if match:
            result[field] = match.group(1).strip()
        else:
            result[field] = ""
    
    return result

# ADD THE MISSING FUNCTION HERE
def calculate_field_accuracy(predicted, ground_truth):
    """Calculate field-level accuracy"""
    correct = 0
    total = len(ground_truth)
    
    for field in ground_truth:
        pred_value = predicted.get(field, "").strip()
        gt_value = str(ground_truth[field]).strip()
        
        if pred_value == gt_value:
            correct += 1
            print(f"✅ {field}: MATCH")
        else:
            print(f"❌ {field}: MISMATCH - Predicted: '{pred_value}', Expected: '{gt_value}'")
    
    accuracy = correct / total
    return accuracy, correct, total

# Test the model on your training data
test_image_path = "data/converted_images/invoice_page_1.jpg"

print(f"🧪 Testing fine-tuned model on: {test_image_path}")
print("="*50)

# Get prediction from fine-tuned model
raw_prediction = predict_invoice_data(test_image_path, finetuned_model, finetuned_processor)
print(f"📄 Raw model output: {raw_prediction}")
print("="*50)

# Parse the prediction
parsed_prediction = parse_donut_output(raw_prediction)
print("🎯 Parsed prediction:")
for field, value in parsed_prediction.items():
    print(f"  {field}: {value}")

print("="*50)

# Load ground truth for comparison
with open('data/training/val.json', 'r', encoding='utf-8') as f:
    val_data = json.load(f)
    ground_truth = val_data[0]['ground_truth']  # First validation sample

print("🎯 Ground truth:")
for field, value in ground_truth.items():
    print(f"  {field}: {value}")

print("="*50)

# Calculate accuracy using the function
accuracy, correct, total = calculate_field_accuracy(parsed_prediction, ground_truth)
print(f"\n📊 **ACTUAL MODEL PERFORMANCE:**")
print(f"   Correct fields: {correct}/{total}")
print(f"   Accuracy: {accuracy*100:.2f}%")

🔄 Loading fine-tuned model from ./donut-finetuned-invoices-v2/checkpoint-3...
🔄 Loading processor from original model: naver-clova-ix/donut-base-finetuned-docvqa...
✅ Fine-tuned model and processor loaded successfully!
🧪 Testing fine-tuned model on: data/converted_images/invoice_page_1.jpg


The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


📄 Raw model output: <s_InvoiceNumber><s_InvoiceDate><s_Currency><s_AmountWithTax><s_AmountWithoutTax><s_Tax><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><

In [None]:
# Load the fine-tuned model for inference - CORRECTED VERSION
import torch
from transformers import VisionEncoderDecoderModel, DonutProcessor
from PIL import Image
import json
import re

# Use the original model name for the processor
original_model_name = "naver-clova-ix/donut-base-finetuned-docvqa"
model_path = "./donut-finetuned-invoices-v2/checkpoint-3"  # Your fine-tuned model
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

print(f"🔄 Loading fine-tuned model from {model_path}...")
print(f"🔄 Loading processor from original model: {original_model_name}...")

# Load fine-tuned model but original processor
finetuned_model = VisionEncoderDecoderModel.from_pretrained(model_path)
finetuned_processor = DonutProcessor.from_pretrained(original_model_name)  # Use original processor

finetuned_model.to(device)
finetuned_model.eval()
print("✅ Fine-tuned model and processor loaded successfully!")

# Function to extract invoice data using the fine-tuned model
def predict_invoice_data(image_path, model, processor):
    """Extract invoice data using the fine-tuned Donut model"""
    
    # Load and preprocess image
    image = Image.open(image_path).convert("RGB")
    pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)
    
    # Create task prompt matching your training format
    question = "What is PAYMENT TERM?"
    task_prompt = f"<s_docvqa><s_question>{question}</s_question><s_answer>"
    decoder_input_ids = processor.tokenizer(task_prompt, 
                                          add_special_tokens=False, 
                                          return_tensors="pt").input_ids.to(device)
    
    # Generate prediction
    with torch.no_grad():
        outputs = model.generate(
            pixel_values,
            decoder_input_ids=decoder_input_ids,
            max_length=512,
            early_stopping=True,
            pad_token_id=processor.tokenizer.pad_token_id,
            eos_token_id=processor.tokenizer.eos_token_id,
            use_cache=True,
            num_beams=1,
            bad_words_ids=[[processor.tokenizer.unk_token_id]],
            return_dict_in_generate=True,
        )
    
    # Decode the output
    # Decode the output
    sequence = processor.batch_decode(outputs.sequences)[0]
    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
    answer = sequence.split("<s_answer>")[1]
    print('answer = ', answer)
    return answer

# Function to parse the model output into structured data
def parse_donut_output(output_text):
    """Parse Donut model output into structured JSON"""
    
    result = {}
    
    # Define patterns for each field based on your training format
    patterns = {
        'InvoiceNo': r'<s_InvoiceNumber>(.*?)</s_InvoiceNumber>',
        'InvoiceDate': r'<s_InvoiceDate>(.*?)</s_InvoiceDate>',
        'Currency': r'<s_Currency>(.*?)</s_Currency>',
        'Amount with Tax': r'<s_AmountWithTax>(.*?)</s_AmountWithTax>',
        'Amount without Tax': r'<s_AmountWithoutTax>(.*?)</s_AmountWithoutTax>',
        'Tax': r'<s_Tax>(.*?)</s_Tax>'
    }
    
    # Extract each field
    for field, pattern in patterns.items():
        match = re.search(pattern, output_text)
        if match:
            result[field] = match.group(1).strip()
        else:
            result[field] = ""
    
    return result

# ADD THE MISSING FUNCTION HERE
def calculate_field_accuracy(predicted, ground_truth):
    """Calculate field-level accuracy"""
    correct = 0
    total = len(ground_truth)
    
    for field in ground_truth:
        pred_value = predicted.get(field, "").strip()
        gt_value = str(ground_truth[field]).strip()
        
        if pred_value == gt_value:
            correct += 1
            print(f"✅ {field}: MATCH")
        else:
            print(f"❌ {field}: MISMATCH - Predicted: '{pred_value}', Expected: '{gt_value}'")
    
    accuracy = correct / total
    return accuracy, correct, total

# Test the model on your training data
test_image_path = "data/converted_images/invoice_page_1.jpg"

print(f"🧪 Testing fine-tuned model on: {test_image_path}")
print("="*50)

# Get prediction from fine-tuned model
raw_prediction = predict_invoice_data(test_image_path, finetuned_model, finetuned_processor)
print(f"📄 Raw model output: {raw_prediction}")
print("="*50)

# Parse the prediction
parsed_prediction = parse_donut_output(raw_prediction)
print("🎯 Parsed prediction:")
for field, value in parsed_prediction.items():
    print(f"  {field}: {value}")

print("="*50)

# Load ground truth for comparison
with open('data/training/val.json', 'r', encoding='utf-8') as f:
    val_data = json.load(f)
    ground_truth = val_data[0]['ground_truth']  # First validation sample

print("🎯 Ground truth:")
for field, value in ground_truth.items():
    print(f"  {field}: {value}")

print("="*50)

# Calculate accuracy using the function
accuracy, correct, total = calculate_field_accuracy(parsed_prediction, ground_truth)
print(f"\n📊 **ACTUAL MODEL PERFORMANCE:**")
print(f"   Correct fields: {correct}/{total}")
print(f"   Accuracy: {accuracy*100:.2f}%")

🔄 Loading fine-tuned model from ./donut-finetuned-invoices-v2/checkpoint-3...
🔄 Loading processor from original model: naver-clova-ix/donut-base-finetuned-docvqa...
✅ Fine-tuned model and processor loaded successfully!
🧪 Testing fine-tuned model on: data/converted_images/invoice_page_1.jpg


The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


answer =  <s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s>