In [1]:
import os
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from transformers import DonutProcessor, VisionEncoderDecoderModel, Seq2SeqTrainer, Seq2SeqTrainingArguments
from warnings import filterwarnings
from transformers import VisionEncoderDecoderModel, AutoTokenizer
filterwarnings('ignore')
from tqdm import tqdm
import re
import pandas as pd
from sklearn.model_selection import train_test_split

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")
device = "cuda" if torch.cuda.is_available() else "cpu"
model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa").to(device)

In [3]:
def parse_csv_to_dicts(csv_file):
    df = pd.read_csv(csv_file, header=None, skiprows=1)
    
    tqdm.pandas()
    df[0] = df[0].progress_apply(lambda x: x.split('/')[-1])
    pattern = re.compile(r'^\d*\.?\d+\s\w+$')
    df = df[df[3].apply(lambda x: bool(pattern.match(x)))]

    x_dict = pd.Series(df[2].values, index=df[0]).to_dict()
    y_dict = pd.Series(df[3].values, index=df[0]).to_dict()

    return x_dict, y_dict

In [4]:
x_dict, y_dict = parse_csv_to_dicts('/home/arjun/Desktop/Github/AmazonML-Hackathon/dataset/train.csv')

100%|██████████| 263859/263859 [00:00<00:00, 1462966.45it/s]


In [5]:
x_train, x_eval, y_train, y_eval = train_test_split(
    list(x_dict.items()), list(y_dict.items()), test_size=0.0001, random_state=42
)

x_train = dict(x_train)
x_eval = dict(x_eval)
y_train = dict(y_train)
y_eval = dict(y_eval)

In [6]:
training_args = Seq2SeqTrainingArguments(
    # predict_with_generate=True,
    evaluation_strategy="steps",
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    logging_steps=100,
    save_steps=500,
    eval_steps=500,
    save_total_limit=3,
    num_train_epochs=3,
    output_dir="./donut-finetuned-docvqa",
    fp16=True,
)

In [7]:
entity_unit_map = {
    'width': {'centimetre', 'foot', 'inch', 'metre', 'millimetre', 'yard'},
    'depth': {'centimetre', 'foot', 'inch', 'metre', 'millimetre', 'yard'},
    'height': {'centimetre', 'foot', 'inch', 'metre', 'millimetre', 'yard'},
    'item_weight': {'gram',
        'kilogram',
        'microgram',
        'milligram',
        'ounce',
        'pound',
        'ton'},
    'maximum_weight_recommendation': {'gram',
        'kilogram',
        'microgram',
        'milligram',
        'ounce',
        'pound',
        'ton'},
    'voltage': {'kilovolt', 'millivolt', 'volt'},
    'wattage': {'kilowatt', 'watt'},
    'item_volume': {'centilitre',
        'cubic foot',
        'cubic inch',
        'cup',
        'decilitre',
        'fluid ounce',
        'gallon',
        'imperial gallon',
        'litre',
        'microlitre',
        'millilitre',
        'pint',
        'quart'}
}

In [8]:
print(entity_unit_map)

{'width': {'foot', 'yard', 'inch', 'metre', 'millimetre', 'centimetre'}, 'depth': {'foot', 'yard', 'inch', 'metre', 'millimetre', 'centimetre'}, 'height': {'foot', 'yard', 'inch', 'metre', 'millimetre', 'centimetre'}, 'item_weight': {'gram', 'ounce', 'milligram', 'kilogram', 'pound', 'microgram', 'ton'}, 'maximum_weight_recommendation': {'gram', 'ounce', 'milligram', 'kilogram', 'pound', 'microgram', 'ton'}, 'voltage': {'millivolt', 'kilovolt', 'volt'}, 'wattage': {'kilowatt', 'watt'}, 'item_volume': {'imperial gallon', 'cubic inch', 'cubic foot', 'gallon', 'fluid ounce', 'quart', 'litre', 'centilitre', 'cup', 'decilitre', 'millilitre', 'pint', 'microlitre'}}


In [9]:
import torch
from PIL import Image
from torch.utils.data import Dataset
import os

class ImageDataset(Dataset):
    def __init__(self, image_dir, x_dict, y_dict, processor):
        self.image_dir = image_dir
        self.processor = processor
        self.images = list(x_dict.keys())
        self.questions = list(x_dict.values())
        self.answers = list(y_dict.values())
        self.pre_finetune_text = 'Given the image, what is the'
        self.image_files = os.listdir(image_dir)
        # print(len(self.images), len(self.questions), len(self.answers), len(self.image_files))
        assert type(self.answers) == list, "Answer should be a list of strings"
        
    def __len__(self):
        return len(self.questions)

    def __getitem__(self, idx):
        img_name = os.path.join(self.image_dir, self.image_files[idx])
        image = Image.open(img_name).convert("RGB")

        if self.questions[idx] in entity_unit_map:
            potential_units = entity_unit_map[self.questions[idx]]
            question = f"{self.pre_finetune_text} {self.questions[idx]} of the item in {potential_units}: "
            # print(question)
        else:
        
            question = f"{self.pre_finetune_text} {self.questions[idx]}: "
        answer = self.answers[idx]
        
        encoding = self.processor(images=image, text=question, return_tensors="pt")
        encoding["labels"] = self.processor.tokenizer(answer, return_tensors="pt").input_ids
        
        for k,v in encoding.items():
            encoding[k] = v.squeeze()
        
        return encoding

def collate_fn(batch):
    batched_data = {
        'pixel_values': [],
        'labels': []
    }

    max_label_length = max(item['labels'].size(0) for item in batch)

    for item in batch:
        batched_data['pixel_values'].append(item['pixel_values'])
        labels = item['labels']
        padded_labels = torch.full((max_label_length,), -100, dtype=torch.long)
        padded_labels[:labels.size(0)] = labels
        batched_data['labels'].append(padded_labels)

    batched_data['pixel_values'] = torch.stack(batched_data['pixel_values'])
    batched_data['labels'] = torch.stack(batched_data['labels'])
    return batched_data

train_dataset = ImageDataset("/home/arjun/Desktop/Github/AmazonML-Hackathon/images/train", x_train, y_train, processor)
eval_dataset = ImageDataset("/home/arjun/Desktop/Github/AmazonML-Hackathon/images/train", x_eval, y_eval, processor)

In [10]:
tokenizer = AutoTokenizer.from_pretrained('naver-clova-ix/donut-base-finetuned-docvqa')
model.config.decoder_start_token_id = tokenizer.cls_token_id 
model.config.pad_token_id = 0  

In [11]:
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=collate_fn,
    tokenizer=tokenizer,
)

In [12]:
trainer.train()

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33marjun_g_ravi[0m ([33meurekabotics[0m). Use [1m`wandb login --relogin`[0m to force relogin


  0%|          | 100/750561 [01:14<149:18:02,  1.40it/s]

{'loss': 3.3601, 'grad_norm': 11.826188087463379, 'learning_rate': 4.999360478362186e-05, 'epoch': 0.0}


  0%|          | 200/750561 [02:27<154:10:51,  1.35it/s]

{'loss': 1.9963, 'grad_norm': 13.74999713897705, 'learning_rate': 4.998694309989461e-05, 'epoch': 0.0}


  0%|          | 300/750561 [03:41<159:07:33,  1.31it/s]

{'loss': 1.92, 'grad_norm': 6.132299423217773, 'learning_rate': 4.998034803300465e-05, 'epoch': 0.0}


  0%|          | 400/750561 [04:55<151:47:40,  1.37it/s]

{'loss': 1.8564, 'grad_norm': 9.67308235168457, 'learning_rate': 4.997375296611468e-05, 'epoch': 0.0}


  0%|          | 500/750561 [06:08<153:26:58,  1.36it/s]

{'loss': 1.8887, 'grad_norm': 17.32257080078125, 'learning_rate': 4.996709128238745e-05, 'epoch': 0.0}


                                                        
  0%|          | 500/750561 [06:19<153:26:58,  1.36it/s]

{'eval_loss': 1.6655714511871338, 'eval_runtime': 10.7758, 'eval_samples_per_second': 2.413, 'eval_steps_per_second': 2.413, 'epoch': 0.0}


  0%|          | 600/750561 [07:37<153:44:52,  1.35it/s] 

{'loss': 1.6202, 'grad_norm': 6.625514030456543, 'learning_rate': 4.99604295986602e-05, 'epoch': 0.0}


  0%|          | 700/750561 [08:52<155:35:20,  1.34it/s]

{'loss': 1.6686, 'grad_norm': 9.270157814025879, 'learning_rate': 4.9953767914932964e-05, 'epoch': 0.0}


  0%|          | 800/750561 [10:06<160:45:40,  1.30it/s]

{'loss': 3.445, 'grad_norm': 131.4827117919922, 'learning_rate': 4.994737269855482e-05, 'epoch': 0.0}


  0%|          | 900/750561 [11:20<145:15:02,  1.43it/s]

{'loss': 1.77, 'grad_norm': 7.6247334480285645, 'learning_rate': 4.994071101482758e-05, 'epoch': 0.0}


  0%|          | 1000/750561 [12:35<148:05:50,  1.41it/s]

{'loss': 1.6498, 'grad_norm': 8.36882209777832, 'learning_rate': 4.993404933110034e-05, 'epoch': 0.0}


                                                         
  0%|          | 1000/750561 [12:46<148:05:50,  1.41it/s]

{'eval_loss': 1.6843738555908203, 'eval_runtime': 11.0567, 'eval_samples_per_second': 2.352, 'eval_steps_per_second': 2.352, 'epoch': 0.0}


  0%|          | 1100/750561 [14:06<159:34:45,  1.30it/s] 

{'loss': 1.7375, 'grad_norm': 8.247044563293457, 'learning_rate': 4.99273876473731e-05, 'epoch': 0.0}


  0%|          | 1200/750561 [15:20<155:31:44,  1.34it/s]

{'loss': 1.5927, 'grad_norm': 11.792450904846191, 'learning_rate': 4.992072596364586e-05, 'epoch': 0.0}


  0%|          | 1300/750561 [16:34<144:20:27,  1.44it/s]

{'loss': 1.5697, 'grad_norm': 8.97677230834961, 'learning_rate': 4.9914064279918624e-05, 'epoch': 0.01}


  0%|          | 1400/750561 [17:46<148:27:51,  1.40it/s]

{'loss': 1.6713, 'grad_norm': 6.667993068695068, 'learning_rate': 4.9907402596191385e-05, 'epoch': 0.01}


  0%|          | 1500/750561 [18:56<140:53:56,  1.48it/s]

{'loss': 1.5305, 'grad_norm': 7.80579137802124, 'learning_rate': 4.9900740912464146e-05, 'epoch': 0.01}


                                                         
  0%|          | 1500/750561 [19:07<140:53:56,  1.48it/s]

{'eval_loss': 1.6055642366409302, 'eval_runtime': 10.4183, 'eval_samples_per_second': 2.496, 'eval_steps_per_second': 2.496, 'epoch': 0.01}


  0%|          | 1600/750561 [20:22<146:51:25,  1.42it/s] 

{'loss': 1.649, 'grad_norm': 9.100310325622559, 'learning_rate': 4.989407922873691e-05, 'epoch': 0.01}


  0%|          | 1700/750561 [21:37<163:41:38,  1.27it/s]

{'loss': 1.619, 'grad_norm': 9.109334945678711, 'learning_rate': 4.988741754500967e-05, 'epoch': 0.01}


  0%|          | 1800/750561 [22:51<153:32:38,  1.35it/s]

{'loss': 1.575, 'grad_norm': 9.66927433013916, 'learning_rate': 4.988075586128243e-05, 'epoch': 0.01}


  0%|          | 1900/750561 [24:05<148:14:23,  1.40it/s]

{'loss': 1.7305, 'grad_norm': 84.30386352539062, 'learning_rate': 4.987409417755519e-05, 'epoch': 0.01}


  0%|          | 2000/750561 [25:17<163:23:26,  1.27it/s]

{'loss': 1.5882, 'grad_norm': 6.087809085845947, 'learning_rate': 4.986743249382795e-05, 'epoch': 0.01}


                                                         
  0%|          | 2000/750561 [25:29<163:23:26,  1.27it/s]

{'eval_loss': 1.5176868438720703, 'eval_runtime': 11.9988, 'eval_samples_per_second': 2.167, 'eval_steps_per_second': 2.167, 'epoch': 0.01}


  0%|          | 2100/750561 [26:45<158:17:12,  1.31it/s] 

{'loss': 1.5136, 'grad_norm': 7.309168338775635, 'learning_rate': 4.986077081010072e-05, 'epoch': 0.01}


  0%|          | 2200/750561 [27:59<167:14:30,  1.24it/s]

{'loss': 1.4749, 'grad_norm': 8.09124755859375, 'learning_rate': 4.985410912637348e-05, 'epoch': 0.01}


  0%|          | 2300/750561 [29:11<153:07:59,  1.36it/s]

{'loss': 1.5207, 'grad_norm': 6.988049030303955, 'learning_rate': 4.984744744264624e-05, 'epoch': 0.01}


  0%|          | 2400/750561 [30:22<150:16:53,  1.38it/s]

{'loss': 1.5557, 'grad_norm': 5.203564643859863, 'learning_rate': 4.9840785758918994e-05, 'epoch': 0.01}


  0%|          | 2500/750561 [31:33<174:11:06,  1.19it/s]

{'loss': 1.5396, 'grad_norm': 6.701388359069824, 'learning_rate': 4.9834124075191755e-05, 'epoch': 0.01}


                                                         
  0%|          | 2500/750561 [31:45<174:11:06,  1.19it/s]

{'eval_loss': 1.4935129880905151, 'eval_runtime': 12.2533, 'eval_samples_per_second': 2.122, 'eval_steps_per_second': 2.122, 'epoch': 0.01}


  0%|          | 2600/750561 [33:01<151:14:21,  1.37it/s] 

{'loss': 1.6067, 'grad_norm': 8.155777931213379, 'learning_rate': 4.9827462391464516e-05, 'epoch': 0.01}


  0%|          | 2700/750561 [34:16<158:45:39,  1.31it/s]

{'loss': 1.5721, 'grad_norm': 13.794951438903809, 'learning_rate': 4.9820800707737284e-05, 'epoch': 0.01}


  0%|          | 2800/750561 [35:29<142:56:32,  1.45it/s]

{'loss': 1.5099, 'grad_norm': 7.515554904937744, 'learning_rate': 4.9814139024010045e-05, 'epoch': 0.01}


  0%|          | 2900/750561 [36:44<160:23:35,  1.29it/s]

{'loss': 1.4786, 'grad_norm': 5.730985164642334, 'learning_rate': 4.9807477340282806e-05, 'epoch': 0.01}


  0%|          | 3000/750561 [37:57<160:54:56,  1.29it/s]

{'loss': 1.5178, 'grad_norm': 8.818404197692871, 'learning_rate': 4.9800815656555567e-05, 'epoch': 0.01}


                                                         
  0%|          | 3000/750561 [38:07<160:54:56,  1.29it/s]

{'eval_loss': 1.5559213161468506, 'eval_runtime': 10.0426, 'eval_samples_per_second': 2.589, 'eval_steps_per_second': 2.589, 'epoch': 0.01}


  0%|          | 3100/750561 [39:18<142:02:19,  1.46it/s]

{'loss': 1.6272, 'grad_norm': 8.947408676147461, 'learning_rate': 4.979415397282833e-05, 'epoch': 0.01}


  0%|          | 3200/750561 [40:24<141:48:47,  1.46it/s]

{'loss': 1.6136, 'grad_norm': 8.180472373962402, 'learning_rate': 4.978749228910109e-05, 'epoch': 0.01}


  0%|          | 3300/750561 [41:32<132:47:19,  1.56it/s]

{'loss': 1.5483, 'grad_norm': 9.026142120361328, 'learning_rate': 4.978083060537385e-05, 'epoch': 0.01}


  0%|          | 3400/750561 [42:39<144:26:21,  1.44it/s]

{'loss': 1.5186, 'grad_norm': 7.984751224517822, 'learning_rate': 4.977416892164661e-05, 'epoch': 0.01}


  0%|          | 3500/750561 [43:47<134:43:52,  1.54it/s]

{'loss': 1.5433, 'grad_norm': 8.25605583190918, 'learning_rate': 4.976750723791937e-05, 'epoch': 0.01}


                                                         
  0%|          | 3500/750561 [43:57<134:43:52,  1.54it/s]

{'eval_loss': 1.5462640523910522, 'eval_runtime': 10.0551, 'eval_samples_per_second': 2.586, 'eval_steps_per_second': 2.586, 'epoch': 0.01}


  0%|          | 3600/750561 [45:09<150:57:31,  1.37it/s]

{'loss': 1.5541, 'grad_norm': 14.87681770324707, 'learning_rate': 4.976084555419213e-05, 'epoch': 0.01}


  0%|          | 3700/750561 [46:16<148:18:29,  1.40it/s]

{'loss': 1.4725, 'grad_norm': 5.854771137237549, 'learning_rate': 4.975418387046489e-05, 'epoch': 0.01}


  1%|          | 3800/750561 [47:25<148:05:08,  1.40it/s]

{'loss': 1.4853, 'grad_norm': 9.463830947875977, 'learning_rate': 4.9747522186737654e-05, 'epoch': 0.02}


  1%|          | 3900/750561 [48:31<134:39:41,  1.54it/s]

{'loss': 1.5216, 'grad_norm': 8.069664001464844, 'learning_rate': 4.9740860503010415e-05, 'epoch': 0.02}


  1%|          | 4000/750561 [49:38<137:03:55,  1.51it/s]

{'loss': 1.6455, 'grad_norm': 8.018741607666016, 'learning_rate': 4.9734198819283176e-05, 'epoch': 0.02}


                                                         
  1%|          | 4000/750561 [49:48<137:03:55,  1.51it/s]

{'eval_loss': 1.4911293983459473, 'eval_runtime': 10.0216, 'eval_samples_per_second': 2.594, 'eval_steps_per_second': 2.594, 'epoch': 0.02}


  1%|          | 4100/750561 [50:58<141:51:54,  1.46it/s]

{'loss': 1.5053, 'grad_norm': 7.489960193634033, 'learning_rate': 4.972753713555594e-05, 'epoch': 0.02}


  1%|          | 4200/750561 [52:06<144:21:52,  1.44it/s]

{'loss': 1.4422, 'grad_norm': 4.454456806182861, 'learning_rate': 4.9720875451828704e-05, 'epoch': 0.02}


  1%|          | 4300/750561 [53:17<147:20:56,  1.41it/s]

{'loss': 1.5236, 'grad_norm': 5.500746726989746, 'learning_rate': 4.9714213768101465e-05, 'epoch': 0.02}


  1%|          | 4400/750561 [54:26<138:51:48,  1.49it/s]

{'loss': 1.5925, 'grad_norm': 6.354903697967529, 'learning_rate': 4.9707552084374226e-05, 'epoch': 0.02}


  1%|          | 4500/750561 [55:36<147:48:43,  1.40it/s]

{'loss': 1.4992, 'grad_norm': 12.739602088928223, 'learning_rate': 4.970089040064699e-05, 'epoch': 0.02}


                                                         
  1%|          | 4500/750561 [55:46<147:48:43,  1.40it/s]

{'eval_loss': 1.5079067945480347, 'eval_runtime': 10.0055, 'eval_samples_per_second': 2.599, 'eval_steps_per_second': 2.599, 'epoch': 0.02}


  1%|          | 4600/750561 [57:01<148:27:39,  1.40it/s]

{'loss': 1.4891, 'grad_norm': 8.998235702514648, 'learning_rate': 4.969422871691974e-05, 'epoch': 0.02}


  1%|          | 4700/750561 [58:13<152:43:26,  1.36it/s]

{'loss': 1.5837, 'grad_norm': 6.95258092880249, 'learning_rate': 4.968756703319251e-05, 'epoch': 0.02}


  1%|          | 4800/750561 [59:21<135:36:44,  1.53it/s]

{'loss': 1.5311, 'grad_norm': 6.912909507751465, 'learning_rate': 4.968090534946527e-05, 'epoch': 0.02}


  1%|          | 4900/750561 [1:00:28<133:26:47,  1.55it/s]

{'loss': 1.4372, 'grad_norm': 6.582512378692627, 'learning_rate': 4.967424366573803e-05, 'epoch': 0.02}


  1%|          | 5000/750561 [1:01:34<137:25:06,  1.51it/s]

{'loss': 1.4711, 'grad_norm': 13.263248443603516, 'learning_rate': 4.966758198201079e-05, 'epoch': 0.02}


                                                           
  1%|          | 5000/750561 [1:01:44<137:25:06,  1.51it/s]

{'eval_loss': 1.420148253440857, 'eval_runtime': 10.0126, 'eval_samples_per_second': 2.597, 'eval_steps_per_second': 2.597, 'epoch': 0.02}


  1%|          | 5100/750561 [1:02:52<133:25:08,  1.55it/s]

{'loss': 1.4526, 'grad_norm': 4.4212446212768555, 'learning_rate': 4.966092029828355e-05, 'epoch': 0.02}


  1%|          | 5200/750561 [1:03:57<135:26:57,  1.53it/s]

{'loss': 1.5478, 'grad_norm': 14.285238265991211, 'learning_rate': 4.9654258614556313e-05, 'epoch': 0.02}


  1%|          | 5300/750561 [1:05:03<137:08:32,  1.51it/s]

{'loss': 1.7197, 'grad_norm': 5.066572666168213, 'learning_rate': 4.9647596930829074e-05, 'epoch': 0.02}


  1%|          | 5400/750561 [1:06:13<149:50:16,  1.38it/s]

{'loss': 1.6542, 'grad_norm': 10.727250099182129, 'learning_rate': 4.9640935247101835e-05, 'epoch': 0.02}


  1%|          | 5500/750561 [1:07:23<154:58:55,  1.34it/s]

{'loss': 1.4911, 'grad_norm': 4.645290374755859, 'learning_rate': 4.96342735633746e-05, 'epoch': 0.02}


                                                           
  1%|          | 5500/750561 [1:07:35<154:58:55,  1.34it/s]

{'eval_loss': 1.395321011543274, 'eval_runtime': 12.2684, 'eval_samples_per_second': 2.119, 'eval_steps_per_second': 2.119, 'epoch': 0.02}


  1%|          | 5600/750561 [1:08:50<148:59:07,  1.39it/s] 

{'loss': 1.5605, 'grad_norm': 3.880685567855835, 'learning_rate': 4.9627611879647364e-05, 'epoch': 0.02}


  1%|          | 5617/750561 [1:09:03<151:49:03,  1.36it/s]

OSError: image file is truncated (1 bytes not processed)

In [None]:
model.save_pretrained("./donut-finetuned-docvqa")
tokenizer.save_pretrained("./donut-finetuned-docvqa")