### CPU Only

In [24]:
import torch
from transformers import (
    DonutProcessor, 
    VisionEncoderDecoderModel,
    TrainingArguments,
    Trainer
)
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import json
import os
from typing import List, Dict, Any
import logging
from dataclasses import dataclass
from transformers.file_utils import PaddingStrategy
from transformers.tokenization_utils_base import PreTrainedTokenizerBase

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

@dataclass
class DataCollatorForDonut:
    """
    Data collator that will dynamically pad the inputs received, as well as the labels.
    """
    tokenizer: PreTrainedTokenizerBase
    padding: bool = True
    max_length: int = None
    pad_to_multiple_of: int = None
    
    def __call__(self, features):
        # Split inputs and labels since they have different lengths
        pixel_values = [feature["pixel_values"] for feature in features]
        decoder_input_ids = [feature["decoder_input_ids"] for feature in features]
        labels = [feature["labels"] for feature in features]
        
        # Stack pixel values
        pixel_values = torch.stack(pixel_values)
        
        # Pad decoder input ids and labels to the same length
        max_length = max(len(seq) for seq in decoder_input_ids + labels)
        if self.max_length is not None:
            max_length = min(max_length, self.max_length)
        
        # Pad decoder input ids
        decoder_input_ids_padded = []
        for seq in decoder_input_ids:
            if len(seq) < max_length:
                padding = [self.tokenizer.pad_token_id] * (max_length - len(seq))
                padded_seq = torch.cat([seq, torch.tensor(padding)])
            else:
                padded_seq = seq[:max_length]
            decoder_input_ids_padded.append(padded_seq)
        
        # Pad labels
        labels_padded = []
        for seq in labels:
            if len(seq) < max_length:
                padding = [-100] * (max_length - len(seq))  # Use -100 for padding in labels
                padded_seq = torch.cat([seq, torch.tensor(padding)])
            else:
                padded_seq = seq[:max_length]
            labels_padded.append(padded_seq)
        
        # Stack tensors
        decoder_input_ids = torch.stack(decoder_input_ids_padded)
        labels = torch.stack(labels_padded)
        
        return {
            "pixel_values": pixel_values,
            "decoder_input_ids": decoder_input_ids,
            "labels": labels
        }

class InsuranceClaimDataset(Dataset):
    """Dataset class for insurance claim documents"""
    
    def __init__(self, jsonl_file: str, images_dir: str, processor: DonutProcessor, max_length: int = 128):
        self.jsonl_file = jsonl_file
        self.images_dir = images_dir
        self.processor = processor
        self.max_length = max_length
        
        # Load your dataset annotations
        self.annotations = self.load_annotations()
        
    def load_annotations(self) -> List[Dict[str, Any]]:
        """Load annotations from JSONL file"""
        annotations = []
        
        # Load from JSONL file
        if os.path.exists(self.jsonl_file):
            with open(self.jsonl_file, 'r') as f:
                for line in f:
                    if line.strip():  # Skip empty lines
                        data = json.loads(line)
                        
                        # Build full image path
                        image_filename = data.get("image_path", data.get("image", data.get("file_name")))
                        full_image_path = os.path.join(self.images_dir, image_filename)
                        
                        # Handle different possible formats for ground truth
                        ground_truth = data.get("ground_truth", data.get("annotation", data.get("labels")))
                        
                        # Convert ground truth to string if it's a dict
                        if isinstance(ground_truth, dict):
                            ground_truth = json.dumps(ground_truth)
                        
                        annotations.append({
                            "image_path": full_image_path,
                            "ground_truth": ground_truth
                        })
        else:
            raise FileNotFoundError(f"JSONL file not found at {self.jsonl_file}")
        
        return annotations
    
    def __len__(self):
        return len(self.annotations)
    
    def __getitem__(self, idx):
        annotation = self.annotations[idx]
        
        # Load image
        try:
            image = Image.open(annotation["image_path"]).convert("RGB")
        except Exception as e:
            print(f"Error loading image {annotation['image_path']}: {e}")
            raise
        
        # Create task prompt for insurance claims
        task_prompt = "<s_docvqa><s_question>Extract key information from this insurance claim document</s_question><s_answer>"
        
        # Prepare target (ground truth)
        target_text = annotation["ground_truth"] + self.processor.tokenizer.eos_token
        
        # Process image only
        image_encoding = self.processor(
            images=image,
            return_tensors="pt"
        )
        
        # Create decoder input ids (task prompt)
        decoder_input_ids = self.processor.tokenizer(
            task_prompt,
            return_tensors="pt",
            padding=False,
            truncation=True,
            max_length=self.max_length,
            add_special_tokens=False
        )
        
        # Create labels (task prompt + target)
        full_target = task_prompt + target_text
        labels = self.processor.tokenizer(
            full_target,
            return_tensors="pt",
            padding=False,
            truncation=True,
            max_length=self.max_length,
            add_special_tokens=False
        )
        
        return {
            "pixel_values": image_encoding["pixel_values"].squeeze(0),
            "decoder_input_ids": decoder_input_ids["input_ids"].squeeze(0),
            "labels": labels["input_ids"].squeeze(0)
        }

class DonutTrainer(Trainer):
    """Custom trainer for Donut model with proper loss computation"""
    
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Custom loss computation with detailed debugging"""
        labels = inputs.pop("labels")
        
        # Forward pass
        outputs = model(**inputs)
        logits = outputs.logits
        
        # Debug prints (remove in production)
        print(f"Logits shape: {logits.shape}")
        print(f"Labels shape: {labels.shape}")
        
        # Ensure we have valid labels
        if labels is None:
            return outputs.loss if return_outputs else None
        
        # For causal language modeling, shift labels and logits
        if labels.dim() == 2:  # [batch_size, seq_len]
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            
            print(f"Shift logits shape: {shift_logits.shape}")
            print(f"Shift labels shape: {shift_labels.shape}")
            
            # Flatten for loss computation
            shift_logits = shift_logits.view(-1, shift_logits.size(-1))
            shift_labels = shift_labels.view(-1)
            
            print(f"Flattened logits shape: {shift_logits.shape}")
            print(f"Flattened labels shape: {shift_labels.shape}")
            
            # Compute loss
            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(shift_logits, shift_labels)
        else:
            # Fallback to standard loss computation
            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
        
        return (loss, outputs) if return_outputs else loss

def prepare_model_and_processor(model_name: str = "naver-clova-ix/donut-base-finetuned-docvqa"):
    """Load pre-trained model and processor"""
    
    processor = DonutProcessor.from_pretrained(model_name)
    model = VisionEncoderDecoderModel.from_pretrained(model_name)
    
    # Force model to CPU
    #model = model.to('cpu')
    model.to('cuda' if torch.cuda.is_available() else "cpu")
    
    # Configure model for fine-tuning
    model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
    model.config.pad_token_id = processor.tokenizer.pad_token_id
    model.config.eos_token_id = processor.tokenizer.eos_token_id
    
    # Enable gradient checkpointing for memory efficiency
    model.gradient_checkpointing_enable()
    
    return model, processor

def fine_tune_donut(
    jsonl_file: str,
    images_dir: str,
    output_dir: str,
    val_split: float = 0.2,
    num_epochs: int = 3,
    batch_size: int = 4,
    learning_rate: float = 5e-5
):
    """Main fine-tuning function"""
    
    # Load model and processor
    model, processor = prepare_model_and_processor()
    
    # Create full dataset first
    full_dataset = InsuranceClaimDataset(jsonl_file, images_dir, processor)
    
    logger.info(f"Total dataset size: {len(full_dataset)}")
    
    # Handle small datasets
    if len(full_dataset) < 5:
        logger.warning(f"Very small dataset ({len(full_dataset)} samples)! This may not be sufficient for meaningful training.")
        logger.warning("Consider:")
        logger.warning("1. Adding more training data")
        logger.warning("2. Using data augmentation")
        logger.warning("3. Setting val_split=0 to use all data for training")
        
        # Force no validation split for very small datasets
        if len(full_dataset) <= 3:
            val_split = 0
            logger.info("Forcing val_split=0 due to extremely small dataset")
    
    # Split into train and validation
    if val_split > 0 and len(full_dataset) > 3:
        from sklearn.model_selection import train_test_split
        
        train_indices, val_indices = train_test_split(
            range(len(full_dataset)),
            test_size=val_split,
            random_state=42
        )
        
        train_dataset = torch.utils.data.Subset(full_dataset, train_indices)
        val_dataset = torch.utils.data.Subset(full_dataset, val_indices)
        
        logger.info(f"Training samples: {len(train_dataset)}")
        logger.info(f"Validation samples: {len(val_dataset)}")
    else:
        train_dataset = full_dataset
        val_dataset = None
        logger.info(f"Training samples: {len(train_dataset)} (no validation split)")
    
    # Adjust batch size for small datasets
    actual_batch_size = min(batch_size, len(train_dataset))
    if actual_batch_size < batch_size:
        logger.warning(f"Reducing batch size from {batch_size} to {actual_batch_size} due to small dataset")
    
    # Training arguments - CPU optimized
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=num_epochs,
        per_device_train_batch_size=actual_batch_size,
        per_device_eval_batch_size=actual_batch_size,
        warmup_steps=min(100, len(train_dataset) // actual_batch_size),  # Adjust warmup for small datasets
        weight_decay=0.01,
        logging_dir=f"{output_dir}/logs",
        logging_steps=max(1, len(train_dataset) // actual_batch_size),  # Log every step for small datasets
        save_strategy="epoch",  # Save every epoch instead of steps for small datasets
        eval_strategy="epoch" if val_dataset is not None else "no",
        load_best_model_at_end=val_dataset is not None,
        dataloader_pin_memory=False,
        gradient_checkpointing=True,
        learning_rate=learning_rate,
        lr_scheduler_type="cosine",
        remove_unused_columns=False,
        push_to_hub=False,
        report_to=None,
        no_cuda=False,  # Force CPU usage
        fp16=False,  # Disable mixed precision for CPU
        dataloader_num_workers=0,  # Avoid multiprocessing issues
        gradient_accumulation_steps=max(1, 4 // actual_batch_size),  # Simulate larger batch sizes
        save_total_limit=2,  # Keep only 2 checkpoints to save space
    )
    
    # Add validation settings if we have validation data
    if val_dataset is not None:
        training_args.metric_for_best_model = "eval_loss"
        training_args.greater_is_better = False
    
    # Create data collator with proper max_length
    data_collator = DataCollatorForDonut(
        tokenizer=processor.tokenizer,
        padding=True,
        max_length=128
    )
    
    # Use the custom trainer with improved loss computation
    trainer = DonutTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        tokenizer=processor.tokenizer,
        data_collator=data_collator,
    )
    
    # Debug: Print a sample from the dataset
    if len(train_dataset) > 0:
        sample = train_dataset[0]
        print(f"Sample pixel_values shape: {sample['pixel_values'].shape}")
        print(f"Sample decoder_input_ids shape: {sample['decoder_input_ids'].shape}")
        print(f"Sample labels shape: {sample['labels'].shape}")
    
    # Start training
    logger.info("Starting fine-tuning...")
    trainer.train()
    
    # Save final model
    trainer.save_model()
    processor.save_pretrained(output_dir)
    
    logger.info(f"Fine-tuning completed. Model saved to {output_dir}")
    
    # Additional recommendations for small datasets
    if len(full_dataset) < 10:
        logger.info("RECOMMENDATIONS for small dataset:")
        logger.info("1. Consider data augmentation (rotation, scaling, noise)")
        logger.info("2. Use a pre-trained model that's already good at similar tasks")
        logger.info("3. Try few-shot learning approaches")
        logger.info("4. Collect more training data if possible")

def inference_example(model_path: str, image_path: str):
    """Example inference function - CPU only"""
    
    # Load fine-tuned model
    processor = DonutProcessor.from_pretrained(model_path)
    model = VisionEncoderDecoderModel.from_pretrained(model_path)
    
    # Force model to CPU and set to evaluation mode
    #model = model.to('cpu')
    device = 'cuda' if torch.cuda.is_available() else "cpu"
    model.to(device)
    model.eval()
    
    # Load and process image
    image = Image.open(image_path).convert("RGB")
    
    # Create task prompt
    task_prompt = "<s_docvqa><s_question>Extract key information from this insurance claim document</s_question><s_answer>"
    
    # Process input - ensure tensors are on CPU
    pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)
    decoder_input_ids = processor.tokenizer(
        task_prompt, 
        add_special_tokens=False, 
        return_tensors="pt"
    ).input_ids.to(device)
    
    # Generate output
    with torch.no_grad():
        outputs = model.generate(
            pixel_values,
            decoder_input_ids=decoder_input_ids,
            max_length=model.decoder.config.max_position_embeddings,
            early_stopping=True,
            pad_token_id=processor.tokenizer.pad_token_id,
            eos_token_id=processor.tokenizer.eos_token_id,
            use_cache=True,
            num_beams=1,
            bad_words_ids=[[processor.tokenizer.unk_token_id]],
            return_dict_in_generate=True,
        )
    
    # Decode output
    sequence = processor.batch_decode(outputs.sequences)[0]
    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
    
    # Extract the answer part
    if "<s_answer>" in sequence:
        sequence = sequence.split("<s_answer>")[1]
    
    return sequence.strip()

    # Example usage - CPU training with small dataset handling
    if __name__ == "__main__":
        # Configuration for CPU training
        JSONL_FILE = "Dataset/metadata.jsonl"
        IMAGES_DIR = "Dataset/images/"
        OUTPUT_DIR = "./donut-insurance-finetuned"

        print("Starting CPU-only training...")
        print("Warning: Training on CPU will be significantly slower than GPU!")
        print("Note: Small datasets (< 10 samples) may not provide meaningful results.")
        
        # Fine-tune the model with CPU-optimized settings for small datasets
        fine_tune_donut(
            jsonl_file=JSONL_FILE,
            images_dir=IMAGES_DIR,
            output_dir=OUTPUT_DIR,
            val_split=0.0,  # No validation split for small datasets
            num_epochs=10,  # More epochs might help with small datasets
            batch_size=1,   # Small batch size for CPU and small dataset
            learning_rate=1e-5  # Lower learning rate for stability
        )
        
        # Example inference
        # result = inference_example(OUTPUT_DIR, "path/to/test/image.jpg")
        # print(f"Extracted information: {result}")

In [20]:
JSONL_FILE = "Dataset/metadata.jsonl"
IMAGES_DIR = "Dataset/images/"
OUTPUT_DIR = "./donut-insurance-finetuned"

print("Starting CPU-only training...")
print("Warning: Training on CPU will be significantly slower than GPU!")
print("Consider using smaller batch sizes and fewer epochs for CPU training.")

# Fine-tune the model with CPU-optimized settings
fine_tune_donut(
    jsonl_file=JSONL_FILE,
    images_dir=IMAGES_DIR,
    output_dir=OUTPUT_DIR,
    val_split=0.2,  # 20% for validation, set to 0 for no validation
    num_epochs=2,   # Reduced epochs for CPU training
    batch_size=1,   # Smaller batch size for CPU
    learning_rate=5e-5
)


Starting CPU-only training...
Consider using smaller batch sizes and fewer epochs for CPU training.


INFO:__main__:Total dataset size: 8
INFO:__main__:Training samples: 6
INFO:__main__:Validation samples: 2
  trainer = DonutTrainer(
INFO:__main__:Starting fine-tuning...


Sample pixel_values shape: torch.Size([3, 2560, 1920])
Sample decoder_input_ids shape: torch.Size([13])
Sample labels shape: torch.Size([128])
Logits shape: torch.Size([1, 128, 57532])
Labels shape: torch.Size([1, 128])
Shift logits shape: torch.Size([1, 127, 57532])
Shift labels shape: torch.Size([1, 127])
Flattened logits shape: torch.Size([127, 57532])
Flattened labels shape: torch.Size([127])
Logits shape: torch.Size([1, 128, 57532])
Labels shape: torch.Size([1, 128])
Shift logits shape: torch.Size([1, 127, 57532])
Shift labels shape: torch.Size([1, 127])
Flattened logits shape: torch.Size([127, 57532])
Flattened labels shape: torch.Size([127])
Logits shape: torch.Size([1, 128, 57532])
Labels shape: torch.Size([1, 128])
Shift logits shape: torch.Size([1, 127, 57532])
Shift labels shape: torch.Size([1, 127])
Flattened logits shape: torch.Size([127, 57532])
Flattened labels shape: torch.Size([127])
Logits shape: torch.Size([1, 128, 57532])
Labels shape: torch.Size([1, 128])
Shift log

Epoch,Training Loss,Validation Loss
1,No log,20.391476
2,No log,20.391476


Logits shape: torch.Size([1, 128, 57532])
Labels shape: torch.Size([1, 128])
Shift logits shape: torch.Size([1, 127, 57532])
Shift labels shape: torch.Size([1, 127])
Flattened logits shape: torch.Size([127, 57532])
Flattened labels shape: torch.Size([127])
Logits shape: torch.Size([1, 128, 57532])
Labels shape: torch.Size([1, 128])
Shift logits shape: torch.Size([1, 127, 57532])
Shift labels shape: torch.Size([1, 127])
Flattened logits shape: torch.Size([127, 57532])
Flattened labels shape: torch.Size([127])
Logits shape: torch.Size([1, 128, 57532])
Labels shape: torch.Size([1, 128])
Shift logits shape: torch.Size([1, 127, 57532])
Shift labels shape: torch.Size([1, 127])
Flattened logits shape: torch.Size([127, 57532])
Flattened labels shape: torch.Size([127])
Logits shape: torch.Size([1, 128, 57532])
Labels shape: torch.Size([1, 128])
Shift logits shape: torch.Size([1, 127, 57532])
Shift labels shape: torch.Size([1, 127])
Flattened logits shape: torch.Size([127, 57532])
Flattened labe

There were missing keys in the checkpoint model loaded: ['decoder.lm_head.weight'].
INFO:__main__:Fine-tuning completed. Model saved to ./donut-insurance-finetuned
INFO:__main__:RECOMMENDATIONS for small dataset:
INFO:__main__:1. Consider data augmentation (rotation, scaling, noise)
INFO:__main__:2. Use a pre-trained model that's already good at similar tasks
INFO:__main__:3. Try few-shot learning approaches
INFO:__main__:4. Collect more training data if possible


In [28]:
result = inference_example(OUTPUT_DIR, "Dataset/images/test1.png")
print(f"Extracted information: {result}")

The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Extracted information: scanning and measurement instructions</s_answer>


In [22]:
OUTPUT_DIR

'./donut-insurance-finetuned'

In [27]:
type(result)

str