In [17]:
from datasets import load_dataset

# Load the invoice dataset (images + ground-truth JSON strings) from a local folder.
# NOTE(review): absolute, machine-specific path - consider a configurable data directory.
dataset = load_dataset(r"D:\sparrow-main\sparrow-data\docs\models\donut\data")
# A bare `dataset` expression in the middle of a cell is never rendered, so print it explicitly.
print(dataset)
example = dataset['train'][0]
image = example['image']
# let's make the image a bit smaller when visualizing
# (width/height are only needed by the optional display call below)
width, height = image.size
#display(image.resize((int(width*0.3), int(height*0.3))))
# let's load the corresponding JSON dictionary (as string representation)
ground_truth = example['ground_truth']
print(ground_truth)

Resolving data files:   0%|          | 0/24 [00:00<?, ?it/s]

{"gt_parse": {"header": {"invoice_no": "081553517", "invoice_date": "2/2/2023", "salesOrderNumber": "71555", "poNumber": "21129-0101.1"}, "items": [{"unitPrice": "0.63", "description": "ARMORED CABLE BX 12-2 COILS", "quantity": "10,000"}, {"unitPrice": "1.74", "description": "ARMORED CABLE MC 10-2", "quantity": "500"}, {"unitPrice": "2.66488", "quantity": "10", "description": "ALUM.LUG DOUI JBLE 1/0 10/100-PK"}, {"unitPrice": "0.29", "description": "10STRANDED BLACK", "quantity": "1,000"}, {"unitPrice": "0.29", "description": "10STRANDED GREEN", "quantity": "500"}, {"unitPrice": "0.29", "description": "10STRANDED RED", "quantity": "1,000"}, {"unitPrice": "0.29", "description": "10STRANDED WHITE", "quantity": "1,000"}, {"unitPrice": "0.94063", "description": "2\" KNOCKOUT PLUG 10/100PK", "quantity": "6"}, {"unitPrice": "0.125", "quantity": "10", "description": "1\"X3/4\" RED WASH 100/1000-PK"}, {"unitPrice": "0.28288", "quantity": "10", "description": "1-1/2X3/4REDWASH50/500-PK"}, {"uni

In [18]:
import json
from ast import literal_eval  # no longer used in this cell; kept in case other cells rely on it

# The ground truth is a JSON document (DonutDataset parses it with json.loads as well),
# so use the JSON parser: literal_eval would reject JSON literals such as true/false/null.
gt_parse = json.loads(ground_truth)['gt_parse']
# A bare expression in the middle of a cell is never rendered, so show the result explicitly.
print("gt_parse sections:", list(gt_parse))
from transformers import VisionEncoderDecoderConfig

# Input resolution and decoder sequence budget; both are notebook-level globals
# that later cells read (dataset construction, processor setup, generation).
image_size = [1275, 1650]
max_length = 768

# update image_size of the encoder
# during pre-training, a larger image size was used
# NOTE(review): absolute, user-specific checkpoint path; the same literal is repeated
# where the processor and the model are loaded.
config = VisionEncoderDecoderConfig.from_pretrained(r"C:\Users\Deepesh Alwani\Desktop\invoices-donut-model-v1")
config.encoder.image_size = image_size # (height, width)
# update max_length of the decoder (for generation)
config.decoder.max_length = max_length

In [19]:
# TODO we should actually update max_position_embeddings and interpolate the pre-trained ones:
# https://github.com/clovaai/donut/blob/0acc65a85d140852b8d9928565f0f6b2d98dc088/donut/model.py#L602
from transformers import DonutProcessor, VisionEncoderDecoderModel

# Load the processor and the model from the same local checkpoint directory.
# The model is instantiated with the `config` object adjusted above
# (encoder image_size and decoder max_length).
processor = DonutProcessor.from_pretrained(r"C:\Users\Deepesh Alwani\Desktop\invoices-donut-model-v1")
model = VisionEncoderDecoderModel.from_pretrained(r"C:\Users\Deepesh Alwani\Desktop\invoices-donut-model-v1", config=config)
import json
import random
from typing import Any, List, Tuple

import torch
from torch.utils.data import Dataset

# Notebook-level registry of special tokens: DonutDataset.add_tokens appends to it
# and DonutDataset.json2token consults it for categorical "<value/>" tokens.
added_tokens = []

In [20]:
class DonutDataset(Dataset):
    """
    PyTorch Dataset for Donut. This class takes a HuggingFace Dataset as input.

    Each row consists of an image (png/jpg/jpeg) and gt data (json/jsonl/txt),
    and it will be converted into pixel_values (vectorized image) and labels (input_ids of the tokenized string).

    Note:
        The class reads the notebook-level globals ``processor`` (tokenizer and image
        processor), ``model`` (its decoder embeddings are resized when tokens are added)
        and ``added_tokens`` (registry of special tokens). They must exist before an
        instance is created.

    Args:
        dataset_name_or_path: name of dataset (available at huggingface.co/datasets) or the path containing image files and metadata.jsonl
        max_length: the max number of tokens for the target sequences
        split: whether to load "train", "validation" or "test" split
        ignore_id: ignore_index for torch.nn.CrossEntropyLoss
        task_start_token: the special token to be fed to the decoder to conduct the target task
        prompt_end_token: the special token at the end of the sequences (defaults to task_start_token)
        sort_json_key: whether or not to sort the JSON keys
    """

    def __init__(
        self,
        dataset_name_or_path: str,
        max_length: int,
        split: str = "train",
        ignore_id: int = -100,
        task_start_token: str = "<s>",
        prompt_end_token: str = None,
        sort_json_key: bool = True,
    ):
        super().__init__()

        self.max_length = max_length
        self.split = split
        self.ignore_id = ignore_id
        self.task_start_token = task_start_token
        self.prompt_end_token = prompt_end_token if prompt_end_token else task_start_token
        self.sort_json_key = sort_json_key

        self.dataset = load_dataset(dataset_name_or_path, split=self.split)
        self.dataset_length = len(self.dataset)

        # Pre-compute, for every sample, the list of candidate target strings.
        self.gt_token_sequences = []
        for sample in self.dataset:
            ground_truth = json.loads(sample["ground_truth"])
            if "gt_parses" in ground_truth:  # when multiple ground truths are available, e.g., docvqa
                assert isinstance(ground_truth["gt_parses"], list)
                gt_jsons = ground_truth["gt_parses"]
            else:
                assert "gt_parse" in ground_truth and isinstance(ground_truth["gt_parse"], dict)
                gt_jsons = [ground_truth["gt_parse"]]

            self.gt_token_sequences.append(
                [
                    self.json2token(
                        gt_json,
                        # only the training split is allowed to register new key tokens
                        update_special_tokens_for_json_key=self.split == "train",
                        sort_json_key=self.sort_json_key,
                    )
                    + processor.tokenizer.eos_token
                    for gt_json in gt_jsons  # load json from list of json
                ]
            )

        self.add_tokens([self.task_start_token, self.prompt_end_token])
        self.prompt_end_token_id = processor.tokenizer.convert_tokens_to_ids(self.prompt_end_token)

    def json2token(self, obj: Any, update_special_tokens_for_json_key: bool = True, sort_json_key: bool = True):
        """
        Convert an ordered JSON object into a token sequence.

        Dicts become ``<s_key>...</s_key>`` spans (keys in reverse-sorted order when
        ``sort_json_key`` is true), lists are joined with ``<sep/>``, and any other value
        is stringified. A dict whose only key is ``"text_sequence"`` is returned as that
        value unchanged.

        Args:
            obj: the (nested) JSON value to serialise
            update_special_tokens_for_json_key: register ``<s_key>``/``</s_key>`` as tokenizer tokens
            sort_json_key: emit dict keys in reverse-sorted order instead of insertion order
        """
        # isinstance (rather than an exact type comparison) so dict/list subclasses
        # such as OrderedDict are serialised structurally instead of being stringified.
        if isinstance(obj, dict):
            if len(obj) == 1 and "text_sequence" in obj:
                return obj["text_sequence"]
            keys = sorted(obj.keys(), reverse=True) if sort_json_key else obj.keys()
            pieces = []
            for k in keys:
                if update_special_tokens_for_json_key:
                    self.add_tokens([fr"<s_{k}>", fr"</s_{k}>"])
                pieces.append(
                    fr"<s_{k}>"
                    + self.json2token(obj[k], update_special_tokens_for_json_key, sort_json_key)
                    + fr"</s_{k}>"
                )
            return "".join(pieces)
        if isinstance(obj, list):
            return r"<sep/>".join(
                [self.json2token(item, update_special_tokens_for_json_key, sort_json_key) for item in obj]
            )
        obj = str(obj)
        if f"<{obj}/>" in added_tokens:
            obj = f"<{obj}/>"  # for categorical special tokens
        return obj

    def add_tokens(self, list_of_tokens: List[str]):
        """
        Add special tokens to tokenizer and resize the token embeddings of the decoder.

        Tokens are recorded in the global ``added_tokens`` registry only once, even if
        the same token appears several times in ``list_of_tokens`` or across calls.
        """
        newly_added_num = processor.tokenizer.add_tokens(list_of_tokens)
        if newly_added_num > 0:
            model.decoder.resize_token_embeddings(len(processor.tokenizer))
            for token in list_of_tokens:
                if token not in added_tokens:
                    added_tokens.append(token)

    def __len__(self) -> int:
        """Number of samples in the loaded split."""
        return self.dataset_length

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, str]:
        """
        Preprocess the image of sample ``idx`` and tokenize one of its target sequences.

        Returns:
            pixel_values : preprocessed image tensor (leading batch dimension removed)
            labels : tokenized target with pad positions replaced by ``ignore_id``
            target_sequence : the target string that was tokenized
        """
        sample = self.dataset[idx]

        # inputs
        pixel_values = processor(sample["image"], random_padding=self.split == "train", return_tensors="pt").pixel_values
        # drop only the batch dimension (a bare squeeze() would also remove any other size-1 dim)
        pixel_values = pixel_values.squeeze(0)

        # targets
        target_sequence = random.choice(self.gt_token_sequences[idx])  # can be more than one, e.g., DocVQA Task 1
        input_ids = processor.tokenizer(
            target_sequence,
            add_special_tokens=False,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )["input_ids"].squeeze(0)

        labels = input_ids.clone()
        labels[labels == processor.tokenizer.pad_token_id] = self.ignore_id  # model doesn't need to predict pad token
        # labels[: torch.nonzero(labels == self.prompt_end_token_id).sum() + 1] = self.ignore_id  # model doesn't need to predict prompt (for VQA)
        return pixel_values, labels, target_sequence

In [14]:
# we update some settings which differ from pretraining; namely the size of the images + no rotation required
# source: https://github.com/clovaai/donut/blob/master/config/train_cord.yaml
processor.image_processor.size = image_size[::-1] # should be (width, height)
processor.image_processor.do_align_long_axis = False

# Building the training split registers the JSON-key special tokens on the tokenizer
# (and resizes the decoder embeddings), so it must run before the validation split.
train_dataset = DonutDataset(r"D:\sparrow-main\sparrow-data\docs\models\donut\data\img", max_length=max_length,
                             split="train", task_start_token="<s_cord-v2>", prompt_end_token="<s_cord-v2>",
                             sort_json_key=False, # cord dataset is preprocessed, so no need for this
                             )

val_dataset = DonutDataset(r"D:\sparrow-main\sparrow-data\docs\models\donut\data\img", max_length=max_length,
                             split="validation", task_start_token="<s_cord-v2>", prompt_end_token="<s_cord-v2>",
                             sort_json_key=False, # cord dataset is preprocessed, so no need for this
                             )
# expressions in the middle of a cell are not rendered, so print them explicitly
print(len(added_tokens))
print(added_tokens)
# the vocab size attribute stays constants (might be a bit unintuitive - but doesn't include special tokens)
print("Original number of tokens:", processor.tokenizer.vocab_size)
print("Number of tokens after adding special tokens:", len(processor.tokenizer))
print(processor.decode([57521]))
pixel_values, labels, target_sequence = train_dataset[0]
print(pixel_values.shape)
# let's print the labels (the first 30 token ID's)
# (loop variable is not called `id`, which would shadow the builtin for the rest of the notebook)
for token_id in labels.tolist()[:30]:
    if token_id != train_dataset.ignore_id:
        print(processor.decode([token_id]))
    else:
        print(token_id)
# let's check the corresponding target sequence, as a string
print(target_sequence)
model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.decoder_start_token_id = processor.tokenizer.convert_tokens_to_ids(['<s_cord-v2>'])[0]
# sanity check
print("Pad token ID:", processor.decode([model.config.pad_token_id]))
print("Decoder start token ID:", processor.decode([model.config.decoder_start_token_id]))

Resolving data files:   0%|          | 0/24 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/24 [00:00<?, ?it/s]

['<s_salesOrderNumber>', '</s_salesOrderNumber>', '<s_poNumber>', '</s_poNumber>', '<s_unitPrice>', '</s_unitPrice>', '<s_description>', '</s_description>', '<s_quantity>', '</s_quantity>', '<s_S.O.NO>', '</s_S.O.NO>', '<s_P.O. NO>', '</s_P.O. NO>', '<s_price_one_item>', '</s_price_one_item>']
Original number of tokens: 57522
Number of tokens after adding special tokens: 57582
torch.Size([3, 1275, 1650])
<s_header>
<s_invoice_no>
08
15
5
35
17
</s_invoice_no>
<s_invoice_date>
2
/2
/20
23
</s_invoice_date>
<s_salesOrderNumber>
71
555
</s_salesOrderNumber>
<s_poNumber>
2
11
29
-
010
1.1
</s_poNumber>
</s_header>
<s_items>
<s_unitPrice>
0
<s_header><s_invoice_no>081553517</s_invoice_no><s_invoice_date>2/2/2023</s_invoice_date><s_salesOrderNumber>71555</s_salesOrderNumber><s_poNumber>21129-0101.1</s_poNumber></s_header><s_items><s_unitPrice>0.63</s_unitPrice><s_description>ARMORED CABLE BX 12-2 COILS</s_description><s_quantity>10,000</s_quantity><sep/><s_unitPrice>1.74</s_unitPrice><s_desc

In [15]:
from torch.utils.data import DataLoader

# feel free to increase the batch size if you have a lot of memory
# I'm fine-tuning on Colab and given the large image size, batch size > 1 is not feasible
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=0)
val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=0)
print(train_dataloader)
batch = next(iter(train_dataloader))
pixel_values, labels, target_sequences = batch
print(pixel_values.shape)
# Inspect the first 30 label ids of the first sample in the batch.
# Indexing with [0] (instead of squeeze()) keeps this working for any batch size, and the
# loop variable is not called `id`, which would shadow the builtin for the rest of the notebook.
for token_id in labels[0].tolist()[:30]:
    if token_id != train_dataset.ignore_id:
        print(processor.decode([token_id]))
    else:
        print(token_id)
print(len(train_dataset))
print(len(val_dataset))
# let's check the first validation batch
batch = next(iter(val_dataloader))
pixel_values, labels, target_sequences = batch
print(pixel_values.shape)
print(target_sequences[0])
from pathlib import Path
import re
from nltk import edit_distance
import numpy as np
import math

from torch.nn.utils.rnn import pad_sequence
from torch.optim.lr_scheduler import LambdaLR

import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_only


class DonutModelPLModule(pl.LightningModule):
    """
    LightningModule wrapper used to fine-tune the Donut model.

    Args:
        config: mapping of hyper-parameters; this class reads ``"lr"`` and ``"verbose"`` via ``.get``
        processor: processor whose tokenizer is used to decode generated sequences
        model: the encoder-decoder model being trained

    Note:
        ``validation_step`` reads the notebook-level ``max_length`` and the
        ``train_dataloader`` / ``val_dataloader`` methods return the notebook-level
        dataloaders of the same names, so those globals must exist.
    """

    def __init__(self, config, processor, model):
        super().__init__()
        self.config = config
        self.processor = processor
        self.model = model

    def training_step(self, batch, batch_idx):
        """Run a forward pass with labels and return (and log) the model's loss."""
        pixel_values, labels, _ = batch

        outputs = self.model(pixel_values, labels=labels)
        loss = outputs.loss
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx, dataset_idx=0):
        """
        Generate a sequence per image and score it against the target string.

        Returns:
            list of normalised edit distances (edit distance divided by the longer of
            prediction and answer), one per sample in the batch
        """
        pixel_values, labels, answers = batch
        batch_size = pixel_values.shape[0]
        # we feed the prompt to the model
        decoder_input_ids = torch.full((batch_size, 1), self.model.config.decoder_start_token_id, device=self.device)

        outputs = self.model.generate(pixel_values,
                                   decoder_input_ids=decoder_input_ids,
                                   max_length=max_length,
                                   early_stopping=True,
                                   pad_token_id=self.processor.tokenizer.pad_token_id,
                                   eos_token_id=self.processor.tokenizer.eos_token_id,
                                   use_cache=True,
                                   num_beams=1,
                                   bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
                                   return_dict_in_generate=True,)

        predictions = []
        for seq in self.processor.tokenizer.batch_decode(outputs.sequences):
            seq = seq.replace(self.processor.tokenizer.eos_token, "").replace(self.processor.tokenizer.pad_token, "")
            seq = re.sub(r"<.*?>", "", seq, count=1).strip()  # remove first task start token
            predictions.append(seq)

        scores = []
        for pred, answer in zip(predictions, answers):
            pred = re.sub(r"(?:(?<=>) | (?=</s_))", "", pred)
            # NOT NEEDED ANYMORE
            # answer = re.sub(r"<.*?>", "", answer, count=1)
            answer = answer.replace(self.processor.tokenizer.eos_token, "")
            # Floor the denominator at 1: when prediction and answer are both empty the
            # edit distance is 0 and the score should be 0.0 rather than a ZeroDivisionError.
            longest = max(len(pred), len(answer), 1)
            scores.append(edit_distance(pred, answer) / longest)

            # only the first sample of the batch is echoed in verbose mode
            if self.config.get("verbose", False) and len(scores) == 1:
                print(f"Prediction: {pred}")
                print(f"    Answer: {answer}")
                print(f" Normed ED: {scores[0]}")

        self.log("val_edit_distance", np.mean(scores))

        return scores

    def configure_optimizers(self):
        """Create an Adam optimizer over all parameters with the configured learning rate."""
        # you could also add a learning rate scheduler if you want
        optimizer = torch.optim.Adam(self.parameters(), lr=self.config.get("lr"))

        return optimizer

    def train_dataloader(self):
        """Return the notebook-level training dataloader."""
        return train_dataloader

    def val_dataloader(self):
        """Return the notebook-level validation dataloader."""
        return val_dataloader
# Training hyper-parameters consumed by the Lightning module and the Trainer.
# Note: this rebinds the name `config`, which earlier cells used for the
# VisionEncoderDecoderConfig object.
config = dict(
    max_epochs=12,
    val_check_interval=0.2,  # how many times we want to validate during an epoch
    check_val_every_n_epoch=1,
    gradient_clip_val=1.0,
    num_training_samples_per_epoch=40,
    lr=3e-5,
    train_batch_sizes=[1],
    val_batch_sizes=[1],
    # seed=2022,
    num_nodes=1,
    warmup_steps=12,  # 800/8*30/10, 10%
    result_path=r"E:\result",
    verbose=True,
)


<torch.utils.data.dataloader.DataLoader object at 0x0000028D3DF0E090>
torch.Size([1, 3, 1275, 1650])
<s_header>
<s_invoice_no>
08
15
59
25
2
</s_invoice_no>
<s_invoice_date>
10
/18
/20
23
</s_invoice_date>
<s_S.O.NO>
77
30
7
</s_S.O.NO>
<s_P.O. NO>
23
133
-01
65
.
3
</s_P.O. NO>
</s_header>
<s_items>
<s_item_qty>
23
5
torch.Size([1, 3, 1275, 1650])
<s_header><s_invoice_no>081558399</s_invoice_no><s_invoice_date>9/11/2023</s_invoice_date><s_salesOrderNumber>76352</s_salesOrderNumber><s_poNumber>22131-0310.1</s_poNumber></s_header><s_items><s_quantity>100</s_quantity><s_description>4SQBX2-1/8DP 1/2-3/4TKO BMP</s_description><sep/><s_quantity>20</s_quantity><s_description>7/8"HOLESAW</s_description><sep/><s_quantity>20</s_quantity><s_description>6SOL-16SOL SPLITBOLT 100/1000PK</s_description><sep/><s_quantity>4</s_quantity><s_description>3"LOCKNUTSTEEL50-PK</s_description><sep/><s_quantity>2</s_quantity><s_description>3"PLASTIC BUSHING25-PK</s_description><sep/><s_quantity>2</s_quantity><

In [16]:
# Wrap processor + model in the Lightning module. At this point `config` is the
# hyper-parameter dict (the module calls `config.get(...)`), not the model config object.
model_module = DonutModelPLModule(config, processor, model)
import os
import torch
# NOTE(review): presumably meant to free cached GPU memory left over from earlier cells before training.
torch.cuda.empty_cache()
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import Callback, EarlyStopping
import pytorch_lightning as pl

# Replace WandbLogger with TensorBoardLogger
# Logs are written under the relative "logs" directory, in a run named "Donut-demo-run-cord".
tensorboard_logger = TensorBoardLogger("logs", name="Donut-demo-run-cord")

class SaveFinalModelCallback(Callback):
    """Callback that writes a trainer checkpoint to a fixed local path when training ends."""

    def on_train_end(self, trainer, pl_module):
        """Announce the end of training, then save the final checkpoint."""
        print("Training done, saving the final model")
        # The destination is a hard-coded local path.
        trainer.save_checkpoint(r"E:\model")

# Stop training when the logged "val_edit_distance" has not decreased for 3 validation runs.
early_stop_callback = EarlyStopping(monitor="val_edit_distance", patience=3, verbose=False, mode="min")

trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    max_epochs=config.get("max_epochs"),
    val_check_interval=config.get("val_check_interval"),
    check_val_every_n_epoch=config.get("check_val_every_n_epoch"),
    gradient_clip_val=config.get("gradient_clip_val"),
    precision=16,  # we'll use mixed precision
    num_sanity_val_steps=0,
    logger=tensorboard_logger,  # Use TensorBoardLogger instead of WandbLogger
    # early_stop_callback was previously created but never registered, so it had no effect
    callbacks=[SaveFinalModelCallback(), early_stop_callback],
)

trainer.fit(model_module)


Using 16bit None Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: logs\Donut-demo-run-cord
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                      | Params
----------------------------------------------------
0 | model | VisionEncoderDecoderModel | 201 M 
----------------------------------------------------
201 M     Trainable params
0         Non-trainable params
201 M     Total params
403.821   Total estimated model params size (MB)
  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
