In [1]:
import os
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from transformers import DonutProcessor, VisionEncoderDecoderModel, Seq2SeqTrainer, Seq2SeqTrainingArguments
from warnings import filterwarnings
from transformers import VisionEncoderDecoderModel, AutoTokenizer
filterwarnings('ignore')
from tqdm import tqdm
import re
import pandas as pd

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")
device = "cuda" if torch.cuda.is_available() else "cpu"
model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa").to(device)

In [3]:
def parse_csv_to_dicts(csv_file):
    df = pd.read_csv(csv_file, header=None, skiprows=1)
    
    tqdm.pandas()
    df[0] = df[0].progress_apply(lambda x: x.split('/')[-1])
    pattern = re.compile(r'^\d*\.?\d+\s\w+$')
    df = df[df[3].apply(lambda x: bool(pattern.match(x)))]

    x_dict = pd.Series(df[2].values, index=df[0]).to_dict()
    y_dict = pd.Series(df[3].values, index=df[0]).to_dict()

    return x_dict, y_dict



In [4]:
x_dict, y_dict = parse_csv_to_dicts('/home/arjun/Desktop/Github/AmazonML-Hackathon/dataset/train.csv')

100%|██████████| 263859/263859 [00:00<00:00, 1414574.20it/s]


In [5]:
training_args = Seq2SeqTrainingArguments(
    # predict_with_generate=True,
    evaluation_strategy="steps",
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    logging_steps=100,
    save_steps=500,
    eval_steps=500,
    save_total_limit=3,
    num_train_epochs=3,
    output_dir="./donut-finetuned-docvqa",
    fp16=True,
)

In [6]:
entity_unit_map = {
    'width': {'centimetre', 'foot', 'inch', 'metre', 'millimetre', 'yard'},
    'depth': {'centimetre', 'foot', 'inch', 'metre', 'millimetre', 'yard'},
    'height': {'centimetre', 'foot', 'inch', 'metre', 'millimetre', 'yard'},
    'item_weight': {'gram',
        'kilogram',
        'microgram',
        'milligram',
        'ounce',
        'pound',
        'ton'},
    'maximum_weight_recommendation': {'gram',
        'kilogram',
        'microgram',
        'milligram',
        'ounce',
        'pound',
        'ton'},
    'voltage': {'kilovolt', 'millivolt', 'volt'},
    'wattage': {'kilowatt', 'watt'},
    'item_volume': {'centilitre',
        'cubic foot',
        'cubic inch',
        'cup',
        'decilitre',
        'fluid ounce',
        'gallon',
        'imperial gallon',
        'litre',
        'microlitre',
        'millilitre',
        'pint',
        'quart'}
}

In [7]:
print(entity_unit_map)

{'width': {'inch', 'foot', 'yard', 'metre', 'millimetre', 'centimetre'}, 'depth': {'inch', 'foot', 'yard', 'metre', 'millimetre', 'centimetre'}, 'height': {'inch', 'foot', 'yard', 'metre', 'millimetre', 'centimetre'}, 'item_weight': {'kilogram', 'gram', 'ounce', 'pound', 'microgram', 'milligram', 'ton'}, 'maximum_weight_recommendation': {'kilogram', 'gram', 'ounce', 'pound', 'microgram', 'milligram', 'ton'}, 'voltage': {'millivolt', 'kilovolt', 'volt'}, 'wattage': {'kilowatt', 'watt'}, 'item_volume': {'imperial gallon', 'cup', 'litre', 'microlitre', 'cubic foot', 'centilitre', 'millilitre', 'quart', 'decilitre', 'fluid ounce', 'pint', 'cubic inch', 'gallon'}}


In [8]:
import torch
from PIL import Image
from torch.utils.data import Dataset
import os

class ImageDataset(Dataset):
    def __init__(self, image_dir, x_dict, y_dict, processor):
        self.image_dir = image_dir
        self.processor = processor
        self.images = list(x_dict.keys())
        self.questions = list(x_dict.values())
        self.answers = list(y_dict.values())
        self.pre_finetune_text = 'Given the image, what is the'
        self.image_files = os.listdir(image_dir)
        # print(len(self.images), len(self.questions), len(self.answers), len(self.image_files))
        assert type(self.answers) == list, "Answer should be a list of strings"
        
    def __len__(self):
        return len(self.questions)

    def __getitem__(self, idx):
        img_name = os.path.join(self.image_dir, self.image_files[idx])
        image = Image.open(img_name).convert("RGB")

        if self.questions[idx] in entity_unit_map:
            potential_units = entity_unit_map[self.questions[idx]]
            question = f"{self.pre_finetune_text} {self.questions[idx]} of the item in {potential_units}: "
            # print(question)
        else:
        
            question = f"{self.pre_finetune_text} {self.questions[idx]}: "
        answer = self.answers[idx]
        
        encoding = self.processor(images=image, text=question, return_tensors="pt")
        encoding["labels"] = self.processor.tokenizer(answer, return_tensors="pt").input_ids
        
        for k,v in encoding.items():
            encoding[k] = v.squeeze()
        
        return encoding

def collate_fn(batch):
    batched_data = {
        'pixel_values': [],
        'labels': []
    }

    max_label_length = max(item['labels'].size(0) for item in batch)

    for item in batch:
        batched_data['pixel_values'].append(item['pixel_values'])
        labels = item['labels']
        padded_labels = torch.full((max_label_length,), -100, dtype=torch.long)
        padded_labels[:labels.size(0)] = labels
        batched_data['labels'].append(padded_labels)

    batched_data['pixel_values'] = torch.stack(batched_data['pixel_values'])
    batched_data['labels'] = torch.stack(batched_data['labels'])
    return batched_data

dataset = ImageDataset("/home/arjun/Desktop/Github/AmazonML-Hackathon/images/train", x_dict, y_dict, processor)

In [9]:
tokenizer = AutoTokenizer.from_pretrained('naver-clova-ix/donut-base-finetuned-docvqa')
model.config.decoder_start_token_id = tokenizer.cls_token_id 
model.config.pad_token_id = 0  

In [10]:
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=collate_fn,
    tokenizer=tokenizer,
)

In [11]:
trainer.train()

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33marjun_g_ravi[0m. Use [1m`wandb login --relogin`[0m to force relogin


  0%|          | 100/750639 [01:12<145:22:21,  1.43it/s]

{'loss': 3.1091, 'grad_norm': 18.409988403320312, 'learning_rate': 4.999360544815817e-05, 'epoch': 0.0}


  0%|          | 200/750639 [02:22<145:56:49,  1.43it/s]

{'loss': 2.0914, 'grad_norm': 16.87972640991211, 'learning_rate': 4.9986944456656266e-05, 'epoch': 0.0}


  0%|          | 300/750639 [03:32<146:34:02,  1.42it/s]

{'loss': 1.8067, 'grad_norm': 15.8847074508667, 'learning_rate': 4.9980350075069376e-05, 'epoch': 0.0}


  0%|          | 400/750639 [04:44<146:52:22,  1.42it/s]

{'loss': 1.7033, 'grad_norm': 15.84870433807373, 'learning_rate': 4.9973755693482485e-05, 'epoch': 0.0}


  0%|          | 500/750639 [05:54<146:26:29,  1.42it/s]

{'loss': 1.7444, 'grad_norm': 9.888818740844727, 'learning_rate': 4.996709470198058e-05, 'epoch': 0.0}


ValueError: Trainer: evaluation requires an eval_dataset.

In [None]:
model.save_pretrained("./donut-finetuned-docvqa")
tokenizer.save_pretrained("./donut-finetuned-docvqa")