In [1]:
import os
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from transformers import DonutProcessor, VisionEncoderDecoderModel, Seq2SeqTrainer, Seq2SeqTrainingArguments
from warnings import filterwarnings
filterwarnings('ignore')
from transformers import VisionEncoderDecoderModel, AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")
device = "cuda" if torch.cuda.is_available() else "cpu"
model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa").to(device)

In [3]:
x_dict = {
    "219gbqQt+ML.jpg": "height",
    "218vf17tHkL.jpg": "weight",
    "21-VzxP3BDL.jpg": "item_volume",
    "217V+UhIrHL.jpg": "length",
    "11j0F4QOiFL.jpg": "height",
    "211sXYcOHcL.jpg": "height",
    "218zo3iJ2IL.jpg": "length",
    "213VIsNlvzL.jpg": "height",
    "21+quvMwZSL.jpg": "weight",
    "217+y-mckBL.jpg": "weight",
    "211EIgVhPEL.jpg": "voltage",
    "218tBdpDGPS.jpg": "length",
    "21-V2Kx5BVL.jpg": "length"
}

y_dict = {
    "219gbqQt+ML.jpg": "12 cm",
    "218vf17tHkL.jpg": "250 mg",
    "21-VzxP3BDL.jpg": "200 ml",
    "217V+UhIrHL.jpg": "5 cm",
    "11j0F4QOiFL.jpg": "2.75 inches",
    "211sXYcOHcL.jpg": "8 cm",
    "218zo3iJ2IL.jpg": "44.2 cm",
    "213VIsNlvzL.jpg": "11 cm",
    "21+quvMwZSL.jpg": "1.6 lbs",
    "217+y-mckBL.jpg": "400 mg",
    "211EIgVhPEL.jpg": "3.7 V",
    "218tBdpDGPS.jpg": "104.5 inches",
    "21-V2Kx5BVL.jpg": "80 inches"
}


In [4]:
training_args = Seq2SeqTrainingArguments(
    # predict_with_generate=True,
    evaluation_strategy="steps",
    per_device_train_batch_size=1,
    per_device_eval_batch_size=2,
    # logging_steps=100,
    save_steps=500,
    eval_steps=500,
    save_total_limit=3,
    num_train_epochs=3,
    output_dir="./donut-finetuned-docvqa",
    fp16=True,
)

In [5]:
import torch
from PIL import Image
from torch.utils.data import Dataset
import os

class ImageDataset(Dataset):
    def __init__(self, image_dir, x_dict, y_dict, processor):
        self.image_dir = image_dir
        self.processor = processor
        self.images = list(x_dict.keys())
        self.questions = list(x_dict.values())
        self.answers = list(y_dict.values())
        self.pre_finetune_text = 'Given the image, what is the'
        self.image_files = os.listdir(image_dir)
        print(len(self.images), len(self.questions), len(self.answers), len(self.image_files))
        assert type(self.answers) == list, "Answer should be a list of strings"
        
    def __len__(self):
        return len(self.questions)

    def __getitem__(self, idx):
        img_name = os.path.join(self.image_dir, self.image_files[idx])
        image = Image.open(img_name).convert("RGB")
        
        question = f"{self.pre_finetune_text} {self.questions[idx]}?"
        answer = self.answers[idx]
        
        # Prepare the inputs for the Donut model
        encoding = self.processor(images=image, text=question, return_tensors="pt")
        
        # Add the answer as the target text
        encoding["labels"] = self.processor.tokenizer(answer, return_tensors="pt").input_ids
        
        # Remove batch dimension
        for k,v in encoding.items():
            encoding[k] = v.squeeze()
        
        return encoding

def collate_fn(batch):
    print('started collating')
    
    # Initialize dictionaries to store batched data
    batched_data = {
        'pixel_values': [],
        'labels': []
    }

    max_label_length = max(item['labels'].size(0) for item in batch)

    # Collect data from each item in the batch
    for item in batch:
        batched_data['pixel_values'].append(item['pixel_values'])
        
        # Pad labels to max length in batch
        labels = item['labels']
        padded_labels = torch.full((max_label_length,), -100, dtype=torch.long)
        padded_labels[:labels.size(0)] = labels
        batched_data['labels'].append(padded_labels)

    # Stack tensors
    batched_data['pixel_values'] = torch.stack(batched_data['pixel_values'])
    batched_data['labels'] = torch.stack(batched_data['labels'])

    return batched_data

# Create dataset
dataset = ImageDataset("/home/arjun/Desktop/Github/AmazonML-Hackathon/images/test", x_dict, y_dict, processor)

# Define trainer
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    tokenizer=processor.tokenizer,
    data_collator=collate_fn,
)

13 13 13 13


In [6]:
tokenizer = AutoTokenizer.from_pretrained('naver-clova-ix/donut-base-finetuned-docvqa')
model.config.decoder_start_token_id = tokenizer.cls_token_id  # or another appropriate token ID

In [7]:
# Define trainer
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    # tokenizer=processor.tokenizer,
    data_collator=collate_fn,
    tokenizer=tokenizer,
)

In [8]:
trainer.train()

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33marjun_g_ravi[0m. Use [1m`wandb login --relogin`[0m to force relogin


  0%|          | 0/39 [00:00<?, ?it/s]

started collating
started collating


ValueError: Make sure to set the pad_token_id attribute of the model's configuration.