# Donut Model Training for Invoice Extraction (Colab Version)

This notebook is tuned for a Google Colab runtime with a 16 GB GPU (e.g. a T4 or P100).

## Setup Steps:
1. Upload your invoice dataset to Google Drive
2. Mount your Drive
3. Install dependencies
4. Train the model

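Expected folder structure in Drive (each image in `train/` or `valid/` needs a JSON label file with the same base name in `donut_json/train/` or `donut_json/valid/`; images without one are skipped):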
```
MyDrive/
  data/
    invoices-donut/
      train/
      valid/
      donut_json/
        train/
        valid/
```

In [None]:
# Show which GPU this Colab runtime was assigned (this cell only reports; it changes no settings).
!nvidia-smi

# Install required packages.
# transformers is pinned to 4.31.0; later cells use TrainingArguments names from that era
# (e.g. `evaluation_strategy`). Every other package is unpinned, so the versions you get
# depend on what the runtime already has and what pip resolves today.
# NOTE(review): consider pinning datasets/accelerate to versions known to work with
# transformers 4.31.0, and confirm seqeval is needed (nothing visible here imports it).
!pip install transformers==4.31.0 datasets torch torchvision seqeval accelerate

In [None]:
# Mount Google Drive at /content/drive so the dataset and model paths used below resolve.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

In [None]:
# Imports and GPU Setup
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import dataclasses
import json
from datasets import Dataset, DatasetDict
from transformers import (
    DonutProcessor, 
    VisionEncoderDecoderModel,
    TrainingArguments,
    default_data_collator,
    Trainer
)
import torch
from PIL import Image
import numpy as np
import random
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import gc  # For memory management

# Verify GPU: prefer CUDA (releasing any cached allocator blocks first), else fall back to CPU.
if torch.cuda.is_available():
    device = 'cuda'
    torch.cuda.empty_cache()
else:
    device = 'cpu'

print(f"Using device: {device}")
if device == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"Memory: {total_gb:.2f} GB")

In [None]:
# Data Loading Function with memory optimization
def load_split(split='train', chunk_size=100,
               base_dir='/content/drive/MyDrive/data/invoices-donut'):
    """Collect (image path, flattened label) records for one dataset split.

    Images (``.png``/``.jpg``/``.jpeg``, any case) are read from
    ``{base_dir}/{split}``. Each one is paired with
    ``{base_dir}/donut_json/{split}/<image base name>.json``. If that JSON object has a
    ``raw_response`` key, its string value is parsed as the actual label.
    Nested dict/list values are re-serialized to JSON strings and scalars are
    converted with ``str`` so every record has flat string columns.

    Args:
        split: Sub-folder name, e.g. ``'train'`` or ``'valid'``.
        chunk_size: Number of image files handled between progress messages
            and garbage-collection passes.
        base_dir: Dataset root (defaults to the Drive layout described at the
            top of the notebook).

    Returns:
        A list of dicts, each holding the label fields plus ``'image_path'``.
        Images whose JSON file is missing, unreadable, or not a JSON object
        are skipped (and counted in a printed summary).

    Raises:
        FileNotFoundError: If the image or JSON folder for ``split`` is missing.
        ValueError: If ``chunk_size`` is less than 1.
    """
    image_dir = os.path.join(base_dir, split)
    json_dir = os.path.join(base_dir, 'donut_json', split)

    # Fail loudly on a wrong path instead of silently creating empty folders
    # in Drive and returning zero examples.
    for folder in (image_dir, json_dir):
        if not os.path.isdir(folder):
            raise FileNotFoundError(f"Expected dataset folder not found: {folder}")
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

    image_files = sorted(
        name for name in os.listdir(image_dir)
        if name.lower().endswith(('.png', '.jpg', '.jpeg'))
    )
    total_files = len(image_files)
    num_chunks = (total_files + chunk_size - 1) // chunk_size
    data = []
    skipped = 0

    for chunk_idx, start in enumerate(range(0, total_files, chunk_size), start=1):
        chunk = image_files[start:start + chunk_size]
        print(f'Processing chunk {chunk_idx}/{num_chunks}')

        # The per-file loop must be nested inside the chunk loop; otherwise only
        # the final chunk would ever be loaded.
        for img_file in tqdm(chunk, desc=f'Loading chunk for {split}'):
            json_path = os.path.join(json_dir, os.path.splitext(img_file)[0] + '.json')
            if not os.path.exists(json_path):
                skipped += 1
                continue
            try:
                with open(json_path, 'r', encoding='utf-8') as f:
                    label = json.load(f)
                if isinstance(label, dict) and 'raw_response' in label:
                    label = json.loads(label['raw_response'])
            except Exception as e:
                print(f"Error processing {img_file}: {str(e)}")
                skipped += 1
                continue
            if not isinstance(label, dict):
                print(f"Error processing {img_file}: label is not a JSON object")
                skipped += 1
                continue

            flat_label = {
                key: json.dumps(value, ensure_ascii=False) if isinstance(value, (dict, list)) else str(value)
                for key, value in label.items()
            }
            # 'image_path' goes last so a label field with that name cannot overwrite it.
            data.append({**flat_label, 'image_path': os.path.join(image_dir, img_file)})

        # One collection pass per chunk (not per file) bounds memory without the overhead.
        gc.collect()

    if skipped:
        print(f'Skipped {skipped} of {total_files} images in {split} (missing or invalid JSON).')
    return data

# Load both splits in small chunks (50 files) to keep peak RAM low, running a
# garbage-collection pass after each split before the next one starts.
def _load_split_then_collect(split_name, banner):
    """Print ``banner``, load ``split_name`` with 50-file chunks, then run gc."""
    print(banner)
    records = load_split(split_name, chunk_size=50)
    gc.collect()
    return records


train_data = _load_split_then_collect('train', "Loading training data...")
val_data = _load_split_then_collect('valid', "\nLoading validation data...")

print(f'\nLoaded {len(train_data)} training and {len(val_data)} validation image-label pairs.')

In [None]:
# Initialize model and processor
model_name = "naver-clova-ix/donut-base"
processor = DonutProcessor.from_pretrained(model_name)
model = VisionEncoderDecoderModel.from_pretrained(model_name)

# Task tokens that wrap every target sequence built during preprocessing.
# Adding them to the vocabulary makes each one a single token instead of a run
# of sub-word pieces; the decoder's embedding matrix must grow to match.
task_tokens = ["<s_invoice>", "</s_invoice>"]
if processor.tokenizer.add_tokens(task_tokens) > 0:
    model.decoder.resize_token_embeddings(len(processor.tokenizer))

# When trained with `labels`, the model builds decoder inputs by shifting the
# labels right, which requires both of these ids in the model config.
model.config.pad_token_id = processor.tokenizer.pad_token_id
if model.config.decoder_start_token_id is None:
    model.config.decoder_start_token_id = processor.tokenizer.convert_tokens_to_ids(task_tokens[0])

# Move model to the selected device (GPU when available)
model.to(device)

In [None]:
# Dataset Preprocessing
def preprocess(example, processor):
    """Turn a batch of raw records into Donut model inputs.

    Intended for ``Dataset.map(..., batched=True)``. ``example`` maps each
    column name to a list of values: ``'image_path'`` holds image file paths
    and every other column holds one label field.

    For each row, the label fields are serialized to JSON and wrapped as
    ``<s_invoice>{...}</s_invoice>``. That string is tokenized, padded and
    truncated to 512 tokens, to form ``labels``; long labels can lose their
    closing tag to truncation.

    Args:
        example: Batched columns from the dataset.
        processor: A ``DonutProcessor`` (image processor plus ``.tokenizer``).

    Returns:
        Dict of per-row lists:
        - ``'pixel_values'``: one image tensor per row.
        - ``'labels'``: token-id tensors with padding replaced by -100 so the
          loss ignores it.
        - ``'target_sequence'``: the raw target strings.
    """
    # Local imports keep the function self-contained (e.g. if map() uses num_proc > 1).
    import json
    from PIL import Image

    task_prompt = "<s_invoice>"
    pad_token_id = processor.tokenizer.pad_token_id
    label_keys = [k for k in example.keys() if k != 'image_path']

    pixel_values, labels_list, target_sequences = [], [], []
    for i, image_path in enumerate(example['image_path']):
        # No manual resize: the processor resizes/pads to its own configured input
        # size, so squashing to 384x384 first only distorted the page and blurred
        # the text without shrinking the tensor. To cut resolution for memory,
        # change the processor's size setting instead (and presumably the model's
        # encoder image size — verify).
        with Image.open(image_path) as img:  # closes the file handle promptly
            image = img.convert('RGB')
        pixel_values.append(processor(image, return_tensors='pt').pixel_values[0])

        label_dict = {k: example[k][i] for k in label_keys}
        label_str = json.dumps(label_dict, ensure_ascii=False)
        target_sequence = f"{task_prompt}{label_str}</s_invoice>"
        target_sequences.append(target_sequence)

        labels = processor.tokenizer(
            target_sequence,
            padding='max_length',
            max_length=512,
            truncation=True,
            return_tensors='pt'
        ).input_ids[0]
        # Padding must not contribute to the loss; -100 is the ignored label index.
        if pad_token_id is not None:
            labels = labels.masked_fill(labels == pad_token_id, -100)
        labels_list.append(labels)

    return {
        'pixel_values': pixel_values,
        'labels': labels_list,
        'target_sequence': target_sequences
    }

# Create datasets: wrap each split's raw records in a Dataset, run `preprocess`
# over it (batches of 4, one process), and drop the raw columns afterwards.
def _prepare_split(records):
    """Build a Dataset from ``records`` and map ``preprocess`` over it."""
    raw_split = Dataset.from_list(records)
    return raw_split.map(
        preprocess,
        batched=True,
        batch_size=4,
        num_proc=1,  # Single process for stability
        remove_columns=raw_split.column_names,
        fn_kwargs={'processor': processor},
        keep_in_memory=False,
    )


train_dataset = _prepare_split(train_data)
val_dataset = _prepare_split(val_data)

dataset = DatasetDict(train=train_dataset, validation=val_dataset)

# Expose only the model inputs to the Trainer, as torch tensors.
dataset.set_format(type='torch', columns=['pixel_values', 'labels'])

In [None]:
# Define metrics computation
def compute_metrics(eval_pred):
    """Exact-match accuracy of decoded predictions against decoded labels.

    ``eval_pred.predictions`` may be token ids ``(batch, seq)`` or raw logits
    ``(batch, seq, vocab)``, optionally as the first element of a tuple (how a
    plain ``Trainer`` returns multi-output models). Logits are reduced with
    argmax, so this works with or without a ``preprocess_logits_for_metrics``
    hook. For forward-pass predictions this is a teacher-forced score, which
    is an optimistic proxy for free generation.

    Label positions equal to -100 are ignored: they become pad tokens (removed
    by ``skip_special_tokens``) in both labels and, when shapes align,
    predictions. Any -100 padding inside predictions is treated the same way.

    Returns:
        ``{"exact_match_accuracy": float}``; 0.0 when there are no examples.
    """
    predictions = eval_pred.predictions
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    predictions = np.asarray(predictions)
    if predictions.ndim == 3:
        predictions = predictions.argmax(axis=-1)
    labels = np.asarray(eval_pred.label_ids)

    if len(predictions) == 0:
        return {"exact_match_accuracy": 0.0}

    pad_id = processor.tokenizer.pad_token_id
    label_mask = labels != -100
    if predictions.shape == labels.shape:
        # Ignore whatever the model predicted where there is no real label.
        predictions = np.where(label_mask, predictions, pad_id)
    predictions = np.where(predictions < 0, pad_id, predictions)
    labels = np.where(label_mask, labels, pad_id)

    decoded_preds = processor.tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = processor.tokenizer.batch_decode(labels, skip_special_tokens=True)

    exact_matches = sum(
        pred.strip() == label.strip() for pred, label in zip(decoded_preds, decoded_labels)
    )
    return {"exact_match_accuracy": exact_matches / len(decoded_preds)}

In [None]:
# Training setup
output_dir = '/content/drive/MyDrive/models/donut-finetuned-invoice'
os.makedirs(output_dir, exist_ok=True)

# Some TrainingArguments names differ across transformers releases. Probe the
# installed class so this cell runs on the pinned 4.31.0 as well as newer versions.
_supported_args = {field.name for field in dataclasses.fields(TrainingArguments)}
_version_dependent_args = {
    # Newer releases renamed `evaluation_strategy` to `eval_strategy`.
    ('eval_strategy' if 'eval_strategy' in _supported_args else 'evaluation_strategy'): 'epoch',
}
if 'gradient_checkpointing_kwargs' in _supported_args:
    # Unknown to 4.31.0 (passing it raises TypeError); newer versions use it to
    # select non-reentrant activation checkpointing.
    _version_dependent_args['gradient_checkpointing_kwargs'] = {"use_reentrant": False}

training_args = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    num_train_epochs=3,
    save_strategy='epoch',           # checkpoint once per epoch (step-based save_steps would be ignored)
    save_total_limit=2,
    logging_steps=10,
    learning_rate=5e-5,
    fp16=torch.cuda.is_available(),  # mixed precision needs a GPU; fall back to fp32 on CPU
    gradient_checkpointing=True,     # trade compute for activation memory
    gradient_accumulation_steps=4,   # effective batch size of 8 per device
    load_best_model_at_end=True,     # requires matching eval/save strategies (both 'epoch')
    metric_for_best_model='exact_match_accuracy',
    warmup_ratio=0.1,
    report_to='none',
    **_version_dependent_args,
)


def _logits_to_token_ids(logits, labels):
    """Collapse each evaluation batch's logits to argmax token ids.

    Trainer calls this hook on every eval batch and accumulates what it
    returns. Returning (batch, seq) ids instead of (batch, seq, vocab) logits
    keeps that buffer small. The result is what `compute_metrics` receives as
    `predictions`.
    """
    if isinstance(logits, tuple):  # multi-output models return (logits, ...)
        logits = logits[0]
    return logits.argmax(dim=-1)


# Initialize trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset['train'],
    eval_dataset=dataset['validation'],
    tokenizer=processor.tokenizer,
    data_collator=default_data_collator,
    compute_metrics=compute_metrics,
    preprocess_logits_for_metrics=_logits_to_token_ids,
)

In [None]:
# Fine-tune. With load_best_model_at_end=True, the trainer then holds the
# best-scoring checkpoint's weights.
trainer.train()

# Write those weights and the matching processor into the same Drive folder.
trainer.save_model(output_dir)
processor.save_pretrained(output_dir)

## Training Complete!

The trained model and processor are saved in your Google Drive at: `/content/drive/MyDrive/models/donut-finetuned-invoice`

Key improvements over local training:
- Uses 16GB GPU efficiently
- Larger batch sizes
- Mixed precision training
- Gradient checkpointing for memory efficiency
- Auto-saves to Google Drive

Expected training time: roughly 2–3 hours (this varies with dataset size and the GPU Colab assigns).