In [None]:
# For handling images and data
import os
import json
from PIL import Image
import numpy as np


# For working with Hugging Face transformers
from transformers import DonutProcessor, VisionEncoderDecoderModel

# For managing and splitting data
from sklearn.model_selection import train_test_split

# For PyTorch (if needed for model fine-tuning)
import torch
from torch.utils.data import DataLoader, Dataset

# Additional utilities
import matplotlib.pyplot as plt
from tqdm import tqdm

In [None]:
!pip install datasets

from datasets import load_dataset

from datasets import load_from_disk

# Load the dataset
dataset = load_dataset('naver-clova-ix/cord-v2')

# Save the dataset locally
# dataset.save_to_disk('./cord_v2')

# dataset = load_from_disk('./cord_v2')

In [None]:
# Show the available splits together with their columns and row counts.
print(dataset)

# Keep a reference to the raw training split and look at its first record.
train_dataset = dataset['train']
print(train_dataset[0])

# From here on, use either one of the two training cells below.

In [None]:
from PIL import Image
import torch
from torch.utils.data import DataLoader
from transformers import DonutProcessor, VisionEncoderDecoderModel, Seq2SeqTrainer, Seq2SeqTrainingArguments

# Token budget per target sequence, used for padding and truncation of the labels.
max_length = 512  # tune to the typical annotation length of your data / model

# The DonutProcessor bundles the image processor and the tokenizer of the CORD-v2 checkpoint.
processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")

# Define the preprocessing function
def preprocess_data(example, donut_processor=None, target_max_length=None):
    """Turn CORD-v2 record(s) into Donut training features.

    Works for a single record (``dataset.map(preprocess_data)``) and for a batch of records
    (``dataset.map(preprocess_data, batched=True)``), where every column value is a list.

    Args:
        example: mapping with an ``'image'`` entry (image or list of images accepted by the
            processor) and a ``'ground_truth'`` entry (target string or list of strings).
        donut_processor: processor to use; defaults to the notebook-level ``processor``.
        target_max_length: padding/truncation length for the targets; defaults to the
            notebook-level ``max_length``.

    Returns:
        Batched input: ``{"pixel_values": [...], "input_ids": [...]}`` holding only the records
        that could be processed. Failed records are logged and dropped; a batched ``map`` may
        return fewer rows because all original columns are removed.
        Single input: ``{"pixel_values": tensor, "input_ids": tensor}``, or ``None`` if the
        record could not be processed.

        ``input_ids`` are used as the decoder *labels* by the collator, so padding positions
        are set to -100 (ignored by the cross-entropy loss) instead of the pad token id.
    """
    donut_processor = processor if donut_processor is None else donut_processor
    target_max_length = max_length if target_max_length is None else target_max_length
    pad_token_id = donut_processor.tokenizer.pad_token_id

    # In batched mode every column arrives as a list; normalise the single-record case to that.
    is_batched = isinstance(example['ground_truth'], list)
    images = example['image'] if is_batched else [example['image']]
    texts = example['ground_truth'] if is_batched else [example['ground_truth']]

    pixel_rows, label_rows = [], []
    for index, (image, text) in enumerate(zip(images, texts)):
        try:
            pixel_values = donut_processor(image, return_tensors="pt").pixel_values[0]
            # NOTE(review): the raw ground_truth string is used as the target text as-is;
            # confirm this is the intended target format for the model.
            label_ids = donut_processor.tokenizer(
                text,
                add_special_tokens=True,
                return_tensors="pt",
                padding='max_length',
                truncation=True,
                max_length=target_max_length,
            ).input_ids[0]
        except Exception as e:
            # Best effort: report and skip. Only the position is printed, because printing the
            # record would dump the image object and the whole annotation.
            print(f"Error processing sample {index} of batch: {e}")
            continue

        if label_ids.numel() == 0:
            continue

        pixel_rows.append(pixel_values)
        label_rows.append(label_ids.masked_fill(label_ids == pad_token_id, -100))

    if is_batched:
        return {"pixel_values": pixel_rows, "input_ids": label_rows}
    if not pixel_rows:
        return None
    return {"pixel_values": pixel_rows[0], "input_ids": label_rows[0]}

# Encode both splits up front with batched map calls. The original columns are removed, so
# every row only carries the features produced by preprocess_data.
# NOTE(review): this stores each image's pixel tensor in the datasets cache, which can get very
# large at Donut's input resolution; on-the-fly processing (e.g. with_transform) may be preferable.
def _encode_split(split_name):
    raw_split = dataset[split_name]
    return raw_split.map(preprocess_data, batched=True, batch_size=8, remove_columns=raw_split.column_names)


train_dataset = _encode_split('train')
validation_dataset = _encode_split('validation')

# Serve the two feature columns as torch tensors.
for _encoded_split in (train_dataset, validation_dataset):
    _encoded_split.set_format(type='torch', columns=['pixel_values', 'input_ids'])

# Small batches limit GPU memory use.
# NOTE(review): these DataLoaders are not handed to the Trainer below (it builds its own from
# train_dataset / validation_dataset); they are only useful for manual inspection.
batch_size = 2
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)
validation_dataloader = DataLoader(validation_dataset, batch_size=batch_size, pin_memory=True)

# Start from the Donut checkpoint that was already fine-tuned on CORD-v2.
model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")

# Define a custom data collator
class CustomDataCollator:
    """Batch ``{"pixel_values", "input_ids"}`` features into ``{"pixel_values", "labels"}``.

    Entries that are falsy or miss either key are dropped first. ``input_ids`` become the
    decoder labels. Sequences of different lengths are right-padded with -100 (ignored by the
    loss); equal-length sequences are batched exactly as ``torch.stack`` would batch them.

    Raises:
        ValueError: if no usable entry is left in the batch.
    """

    def __call__(self, batch):
        usable = [item for item in batch if item and 'pixel_values' in item and 'input_ids' in item]

        if not usable:
            # Typical cause: the Trainer pruned the `input_ids` column because the model's
            # forward() has no parameter of that name.
            raise ValueError(
                "Batch is empty after filtering: no item has both 'pixel_values' and 'input_ids' "
                "(is remove_unused_columns=False set in the training arguments?)"
            )

        pixel_values = torch.stack([item["pixel_values"] for item in usable])
        labels = torch.nn.utils.rnn.pad_sequence(
            [torch.as_tensor(item["input_ids"]).reshape(-1) for item in usable],
            batch_first=True,
            padding_value=-100,
        )

        return {"pixel_values": pixel_values, "labels": labels}


data_collator = CustomDataCollator()

# Define the training arguments
training_args = Seq2SeqTrainingArguments(
    output_dir="./donut_model",
    num_train_epochs=5,
    per_device_train_batch_size=batch_size,  # Use the same batch size as defined earlier
    per_device_eval_batch_size=batch_size,
    gradient_accumulation_steps=4,  # Effective batch per device = batch_size * 4
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_dir="./logs",
    logging_steps=10,
    save_total_limit=3,  # Keep only the last 3 checkpoints
    # Mixed precision needs a CUDA device; fall back to fp32 instead of failing elsewhere.
    fp16=torch.cuda.is_available(),
    dataloader_pin_memory=torch.cuda.is_available(),
    # The collator reads the `input_ids` column, but VisionEncoderDecoderModel.forward has no
    # `input_ids` parameter. The Trainer's default column pruning would drop it, leaving every
    # batch empty, so all dataset columns are kept.
    remove_unused_columns=False,
)

# Release cached GPU memory held over from earlier cells before training
torch.cuda.empty_cache()

# Initialize the Trainer
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=validation_dataset,
    data_collator=data_collator,
    tokenizer=processor.tokenizer,
)

# Start the training process
trainer.train()

# Save the final model together with the whole processor (tokenizer + image processor), so the
# directory can be reloaded with DonutProcessor.from_pretrained / VisionEncoderDecoderModel.from_pretrained.
trainer.save_model("./donut_model_final")
processor.save_pretrained("./donut_model_final")

In [None]:
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import VisionEncoderDecoderModel, Seq2SeqTrainer, Seq2SeqTrainingArguments, DonutProcessor
import pandas as pd
from PIL import Image

# Custom Dataset Class
class CustomDataset(Dataset):
    """Map-style dataset over a DataFrame of (image, target text) rows for Donut fine-tuning.

    Column 0 of ``df`` holds either an image file path or an already-loaded image object
    (anything with a ``convert`` method, e.g. a PIL image). Column 1 holds the target text.

    Args:
        df: pandas DataFrame laid out as described above.
        processor: Donut processor (image processor + tokenizer).
        max_length: pad/truncate every target to this many tokens so items can be stacked into
            batches; ``None`` keeps the raw variable-length encoding.
    """

    def __init__(self, df, processor, max_length=512):
        self.df = df
        self.processor = processor
        self.max_length = max_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``{"pixel_values": (C, H, W) tensor, "input_ids": (L,) label tensor}`` for row ``idx``."""
        image_source = self.df.iloc[idx, 0]
        text = self.df.iloc[idx, 1]

        if hasattr(image_source, "convert"):
            image = image_source.convert("RGB")
        else:
            # The context manager releases the file handle once the pixels are copied.
            with Image.open(image_source) as opened:
                image = opened.convert("RGB")

        pixel_values = self.processor(image, return_tensors="pt").pixel_values[0]

        # NOTE(review): without special tokens the target carries no EOS token, so the model is
        # never shown where to stop generating; confirm whether an EOS should be appended.
        encode_kwargs = {"add_special_tokens": False, "return_tensors": "pt"}
        if self.max_length is not None:
            encode_kwargs.update(padding="max_length", truncation=True, max_length=self.max_length)
        input_ids = self.processor.tokenizer(text, **encode_kwargs).input_ids[0]

        # These ids are used as decoder labels: padding becomes -100 so the loss ignores it.
        pad_token_id = self.processor.tokenizer.pad_token_id
        if pad_token_id is not None:
            input_ids = input_ids.masked_fill(input_ids == pad_token_id, -100)

        return {
            "pixel_values": pixel_values,
            "input_ids": input_ids,
        }

# Custom Data Collator to filter empty batches
class CustomDataCollator:
    """Batch ``{"pixel_values", "input_ids"}`` items into ``{"pixel_values", "labels"}``.

    Items that are falsy or miss either key are dropped. If nothing is left, ``None`` is
    returned (the CustomTrainer in this cell checks for ``None`` inputs). ``input_ids`` become
    the decoder labels. Sequences of different lengths are right-padded with -100 (ignored by
    the loss) instead of failing in ``torch.stack``; equal-length ones are batched unchanged.
    """

    def __call__(self, batch):
        usable = [item for item in batch if item and 'pixel_values' in item and 'input_ids' in item]

        if not usable:
            return None

        pixel_values = torch.stack([item["pixel_values"] for item in usable])
        labels = torch.nn.utils.rnn.pad_sequence(
            [torch.as_tensor(item["input_ids"]).reshape(-1) for item in usable],
            batch_first=True,
            padding_value=-100,
        )

        return {"pixel_values": pixel_values, "labels": labels}

import os

# Load your dataset.
# CustomDataset expects a DataFrame with column 0 = image file path and column 1 = target text.
# The Hugging Face DatasetDict stores images rather than paths, so the training images are
# written to disk once and existing files are reused on re-runs.
# `dataset` is rebound to a CustomDataset below; if this cell is re-run, reload the DatasetDict.
cord_v2 = dataset if isinstance(dataset, dict) else load_dataset('naver-clova-ix/cord-v2')

image_dir = os.path.join(".", "cord_v2_images", "train")
os.makedirs(image_dir, exist_ok=True)

rows = []
for sample_index, sample in enumerate(cord_v2['train']):
    image_path = os.path.join(image_dir, f"{sample_index:05d}.png")
    if not os.path.exists(image_path):
        sample['image'].save(image_path)
    rows.append((image_path, sample['ground_truth']))
df = pd.DataFrame(rows, columns=["image_path", "text"])

# Load the Donut processor
processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base")

train_batch_size = 2

# Initialize the custom dataset
dataset = CustomDataset(df, processor)

# Create a shuffled DataLoader with the custom data collator
data_collator = CustomDataCollator()
data_loader = DataLoader(dataset, batch_size=train_batch_size, shuffle=True, collate_fn=data_collator)

# Load the Donut model
model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base")

# When only labels are passed, VisionEncoderDecoderModel derives the decoder inputs by shifting
# the labels right, which needs both ids on the config. Fill them in if the checkpoint leaves them unset.
if model.config.pad_token_id is None:
    model.config.pad_token_id = processor.tokenizer.pad_token_id
if model.config.decoder_start_token_id is None:
    # NOTE(review): Donut fine-tuning recipes usually use a task-specific start token here;
    # the tokenizer's BOS token is a generic fallback.
    model.config.decoder_start_token_id = processor.tokenizer.bos_token_id

# Define training arguments
# NOTE: same output/log directories as the first training cell; running both overwrites them.
training_args = Seq2SeqTrainingArguments(
    output_dir="./donut_model",
    per_device_train_batch_size=train_batch_size,
    num_train_epochs=3,
    logging_dir="./logs",
    logging_steps=10,
)

# Define a custom Trainer class that skips None batches
class CustomTrainer(Seq2SeqTrainer):
    """Seq2SeqTrainer that trains from a plain DataLoader and tolerates ``None`` batches.

    With the default ``remove_unused_columns=True``, the stock ``get_train_dataloader`` wraps
    the collator of a non-``datasets`` training set. The wrapper drops keys that
    ``model.forward`` does not accept. For VisionEncoderDecoderModel that strips ``input_ids``
    before CustomDataCollator sees it. Building the DataLoader here avoids that.

    NOTE(review): newer Trainer versions inspect each batch (e.g. for ``"labels"``) before
    calling ``training_step``, so a ``None`` batch may still fail there. CustomDataset never
    yields items that the collator filters out, so that path should not trigger in this cell.
    """

    def get_train_dataloader(self):
        """Return a shuffled DataLoader over ``self.train_dataset`` batched by ``self.data_collator``."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.args.train_batch_size,
            shuffle=True,
            collate_fn=self.data_collator,
        )

    def training_step(self, model, inputs, *args, **kwargs):
        """Run one training step, or report a zero loss for an empty (``None``) batch.

        Extra arguments (e.g. ``num_items_in_batch`` in newer transformers releases) are
        forwarded unchanged to ``Seq2SeqTrainer.training_step``. That method handles device
        placement, mixed precision, gradient-accumulation scaling and the backward pass.
        """
        if inputs is None:
            # Returning None would break the Trainer's loss accumulation; a zero tensor keeps it valid.
            return torch.tensor(0.0, device=self.args.device)

        return super().training_step(model, inputs, *args, **kwargs)

# Initialize the trainer
trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=data_collator,
    tokenizer=processor.tokenizer,
)

# Start the training process
trainer.train()

# Save the fine-tuned weights together with the processor (tokenizer + image processor), so the
# directory can be reloaded with DonutProcessor / VisionEncoderDecoderModel.from_pretrained.
# NOTE: same directory as the first training cell; running both overwrites its output.
trainer.save_model("./donut_model_final")
processor.save_pretrained("./donut_model_final")