# Donut Model Training Notebook
This notebook shows how to fine-tune a pre-trained Donut (Document Understanding Transformer) model to extract structured fields from invoices, using your own labeled dataset of invoice images and JSON annotations.

In [None]:
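# Install required packages (uncomment if needed). `%pip` installs into the running kernel's
# environment; pin versions (e.g. `transformers==X.Y.Z`) for reproducible runs.
# %pip install transformers datasets torch torchvision seqeval accelerate matplotlib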

In [None]:
import json
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
from datasets import Dataset, DatasetDict, load_dataset
from PIL import Image
from transformers import (
    DonutProcessor,
    Trainer,
    TrainingArguments,
    VisionEncoderDecoderModel,
    default_data_collator,
)

## 1. Load and Prepare the Dataset
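This notebook assumes the training images are in `../data/invoices-donut/train` and their JSON labels are in `../data/invoices-donut/donut_json/train`. Each label file must share its image's base name (e.g. `invoice_001.png` → `invoice_001.json`); images without a matching JSON file are skipped.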

In [None]:
# Collect (image, JSON label) pairs: every image in `image_dir` is paired with the
# JSON file of the same base name in `json_dir`.
image_dir = '../data/invoices-donut/train'
json_dir = '../data/invoices-donut/donut_json/train'
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

# Compare extensions case-insensitively so files such as `scan_01.JPG` are not silently skipped.
image_files = sorted(
    name for name in os.listdir(image_dir)
    if name.lower().endswith(IMAGE_EXTENSIONS)
)

data = []
missing_labels = []  # images that have no matching JSON file
for img_file in image_files:
    img_path = os.path.join(image_dir, img_file)
    json_path = os.path.join(json_dir, os.path.splitext(img_file)[0] + '.json')
    if not os.path.exists(json_path):
        missing_labels.append(img_file)
        continue
    with open(json_path, 'r', encoding='utf-8') as label_file:
        label = json.load(label_file)
    data.append({'image_path': img_path, 'label': label})

print(f'Loaded {len(data)} image-label pairs.')
if missing_labels:
    # Report unlabeled images instead of dropping them silently.
    print(f'Skipped {len(missing_labels)} image(s) without a matching JSON file, e.g. {missing_labels[:5]}')
# Fail here with a clear message rather than later in sampling/training.
assert data, f'No image-label pairs found in {image_dir!r} / {json_dir!r}'

## 2. Visualize a Sample
Let's visualize a random sample from the dataset.

In [None]:
# Show one labeled example. A dedicated, seeded RNG makes the displayed sample
# reproducible across runs without touching the global `random` state.
SAMPLE_SEED = 42
if not data:
    raise ValueError('No image-label pairs loaded; check the data directories above.')
sample = random.Random(SAMPLE_SEED).choice(data)

# Close the file handle promptly. Convert to RGB, as the preprocessing step does,
# so grayscale/palette images are not rendered with a false-color colormap.
with Image.open(sample['image_path']) as image_file:
    img = image_file.convert('RGB')

fig, ax = plt.subplots(figsize=(8, 10))
ax.imshow(img)
ax.axis('off')
ax.set_title(f"Sample invoice image: {os.path.basename(sample['image_path'])}")
plt.show()
print('Label:', json.dumps(sample['label'], indent=2, ensure_ascii=False))

## 3. Load Donut Processor and Model
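Load a pre-trained Donut checkpoint and its processor from the Hugging Face Hub, then fine-tune it on the invoice data. The processor handles both image preprocessing and tokenization of the labels.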

In [None]:
# Checkpoint to start fine-tuning from (processor and model must come from the same one).
# NOTE(review): this checkpoint was fine-tuned for DocVQA; `naver-clova-ix/donut-base`
# is the more common starting point for a new extraction task. Confirm the choice.
PRETRAINED_CHECKPOINT = 'naver-clova-ix/donut-base-finetuned-docvqa'

processor = DonutProcessor.from_pretrained(PRETRAINED_CHECKPOINT)
model = VisionEncoderDecoderModel.from_pretrained(PRETRAINED_CHECKPOINT)

# Align the model's decoding config with the processor's tokenizer:
# decoding starts from the tokenizer's CLS token, and pad_token_id marks padding.
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.vocab_size = model.config.decoder.vocab_size

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Assign the result instead of leaving `model.to(device)` as the cell's last
# expression, which would print the entire model architecture as cell output.
model = model.to(device)

## 4. Prepare Dataset for Training
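Donut expects each training example as image tensors (`pixel_values`) plus the target text as token IDs (`labels`). Here each image is processed by the Donut processor, and each JSON label is serialized to a string and tokenized to a fixed length of 512 tokens.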

In [None]:
LABEL_MAX_LENGTH = 512


def preprocess(example, max_length=LABEL_MAX_LENGTH):
    """Convert one {image_path, label} record into Donut training inputs.

    Args:
        example: mapping with 'image_path' (path to an image file) and 'label'
            (either the parsed JSON label or its JSON-serialized string).
        max_length: fixed length of the returned label sequence.

    Returns:
        dict with 'pixel_values' (processor output for the RGB image) and
        'labels' (token ids of the serialized label followed by EOS, right-padded
        with -100 up to `max_length`).
    """
    # Close the image file promptly; the model is fed RGB input.
    with Image.open(example['image_path']) as image_file:
        image = image_file.convert('RGB')
    pixel_values = processor(images=image, return_tensors='pt').pixel_values[0]

    label = example['label']
    label_str = label if isinstance(label, str) else json.dumps(label, ensure_ascii=False)

    tokenizer = processor.tokenizer
    # The decoder is fed its own start token (model.config.decoder_start_token_id), so the
    # target carries no leading BOS/CLS. One slot is reserved so EOS survives truncation.
    token_ids = tokenizer(
        label_str,
        add_special_tokens=False,
        truncation=True,
        max_length=max_length - 1,
    ).input_ids
    token_ids.append(tokenizer.eos_token_id)

    # Pad with -100 (the default `ignore_index` of torch's CrossEntropyLoss) rather than
    # pad_token_id, so padding positions are excluded from the training loss.
    labels = torch.tensor(token_ids + [-100] * (max_length - len(token_ids)), dtype=torch.long)
    return {'pixel_values': pixel_values, 'labels': labels}


# Store each label as a JSON string. Given nested dicts, Arrow would infer one struct
# schema for all labels: keys missing from an example would be filled with None (and
# leak into the serialized target), and conflicting value types would fail outright.
dataset = Dataset.from_list([
    {'image_path': record['image_path'], 'label': json.dumps(record['label'], ensure_ascii=False)}
    for record in data
])
# NOTE(review): map() materializes pixel_values for every image in the dataset cache;
# for large datasets consider `set_transform` to preprocess lazily instead.
dataset = dataset.map(preprocess)
dataset.set_format(type='torch', columns=['pixel_values', 'labels'])

## 5. Train the Model
Set up training arguments and start training.

In [None]:
# Mixed precision (fp16) requires a GPU; requesting it on a CPU-only machine makes
# TrainingArguments raise, so enable it only when CUDA is available.
use_fp16 = torch.cuda.is_available()

training_args = TrainingArguments(
    output_dir='./donut-finetuned-invoice',
    per_device_train_batch_size=2,
    num_train_epochs=5,
    save_steps=100,
    save_total_limit=2,
    logging_steps=10,
    learning_rate=5e-5,
    fp16=use_fp16,
    report_to='none',
)

# Every example already holds fixed-size `pixel_values` and `labels` tensors, so
# batches only need stacking. Passing the tokenizer with `data_collator=None` would
# make Trainer fall back to DataCollatorWithPadding, whose `tokenizer.pad` rejects
# batches without `input_ids`. The processor (including the tokenizer) is saved
# separately after training.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=default_data_collator,
)

trainer.train()

## 6. Save the Fine-tuned Model
Save your model and processor for later inference.

In [None]:
# Write the fine-tuned model first, then its processor, into one directory so both
# can later be reloaded together with `from_pretrained`.
save_dir = './donut-finetuned-invoice'
for component in (model, processor):
    component.save_pretrained(save_dir)