In [1]:
from transformers import VisionEncoderDecoderConfig

image_size = [1280, 960]
max_length = 768

# update image_size of the encoder
# during pre-training, a larger image size was used
config = VisionEncoderDecoderConfig.from_pretrained("naver-clova-ix/donut-base")
config.encoder.image_size = image_size # (height, width)
# update max_length of the decoder (for generation)
config.decoder.max_length = max_length
# TODO we should actually update max_position_embeddings and interpolate the pre-trained ones:
# https://github.com/clovaai/donut/blob/0acc65a85d140852b8d9928565f0f6b2d98dc088/donut/model.py#L602

In [2]:
from transformers import DonutProcessor, VisionEncoderDecoderModel

processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base")
model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base", config=config)



In [3]:
import json
import random
from typing import Any, List, Tuple

from pathlib import Path
from PIL import Image

import torch
from torch.utils.data import Dataset

added_tokens = []

class DocDataset(Dataset):
    def __init__(
        self,
        dataset_path: str,
        max_length: int,
        split: str = "train",
        ignore_id: int = -100,
        task_start_token: str = "",
        prompt_end_token: str = None,
        sort_json_key: bool = True,
    ):
        super().__init__()

        self.max_length = max_length
        self.split = split
        self.ignore_id = ignore_id
        self.task_start_token = task_start_token
        self.prompt_end_token = prompt_end_token if prompt_end_token else task_start_token
        self.sort_json_key = sort_json_key

        self.img_root = f"{dataset_path}/images/"
        self.parse_root = f"{dataset_path}/parse_seq_data/"
        
        self.image_files = [f for f in Path(self.img_root).iterdir() if f.is_file() and not f.name.startswith('.')]
        self.parse_files = [f for f in Path(self.parse_root).iterdir() if f.is_file() and not f.name.startswith('.')]

        self.gt_token_sequences = []

        for sample in self.parse_files:
            if sample.is_file() and sample.suffix == ".json":
                with open(sample, 'r') as file:
                    gt_json = json.load(file)
            
            self.gt_token_sequences.append([
                self.json2token(
                    gt_json,
                    update_special_tokens_for_json_key=self.split == "train",
                    sort_json_key=self.sort_json_key
                )
                + processor.tokenizer.eos_token
            ])

        self.add_tokens([self.task_start_token, self.prompt_end_token])
        self.prompt_end_token_id = processor.tokenizer.convert_tokens_to_ids(self.prompt_end_token)

    def json2token(self, obj: Any, update_special_tokens_for_json_key: bool = True, sort_json_key: bool = True):
        if type(obj) == dict:
            if len(obj) == 1 and "text_sequence" in obj:
                return obj["text_sequence"]
            else:
                output = ""
                if sort_json_key:
                    keys = sorted(obj.keys(), reverse=True)
                else:
                    keys = obj.keys()
                for k in keys:
                    if update_special_tokens_for_json_key:
                        self.add_tokens([fr"<{k}>", fr"<{k}/>"])
                    output += (
                        fr"<{k}>"
                        + self.json2token(obj[k], update_special_tokens_for_json_key, sort_json_key)
                        + fr"<{k}/>"
                    )
                return output
        elif type(obj) == list:
            return r"".join(
                [self.json2token(item, update_special_tokens_for_json_key, sort_json_key) for item in obj]
            )
        else:
            obj = str(obj)
            if f"<{obj}/>" in added_tokens:
                obj = f"<{obj}/>"  # for categorical special tokens
            return obj
        
    def add_tokens(self, list_of_tokens: List[str]):
        """
        Add special tokens to tokenizer and resize the token embeddings of the decoder
        """
        newly_added_num = processor.tokenizer.add_tokens(list_of_tokens)
        if newly_added_num > 0:
            model.decoder.resize_token_embeddings(len(processor.tokenizer))
            added_tokens.extend(list_of_tokens)
            
    def __len__(self) -> int:
        return len(self.parse_files)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        image_path = self.image_files[idx]
        image = Image.open(image_path)

        pixel_values = processor(image, random_padding=self.split == "train", return_tensors="pt").pixel_values
        pixel_values = pixel_values.squeeze()

        target_sequence = random.choice(self.gt_token_sequences[idx])

        input_ids = processor.tokenizer(
            target_sequence,
            add_special_tokens=False,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )["input_ids"].squeeze(0)

        labels = input_ids.clone()

        labels[labels == processor.tokenizer.pad_token_id] = self.ignore_id

        return pixel_values, labels, target_sequence

In [4]:
processor.image_processor.size = image_size[::-1] # should be (width, height)
processor.image_processor.do_align_long_axis = False

train_dataset = DocDataset(dataset_path="dataset/train", max_length=max_length, split="train", task_start_token="<s_parse>", prompt_end_token="",
                             sort_json_key=False)
val_dataset = DocDataset(dataset_path="dataset/val", max_length=max_length, split="validation", task_start_token="<s_parse>", prompt_end_token="",
                             sort_json_key=False)

In [5]:
len(added_tokens)

8

In [6]:
print(added_tokens)

['<license_num>', '<license_num/>', '<name>', '<name/>', '<dob>', '<dob/>', '<s_parse>', '<s_parse>']


In [7]:
print("Original number of tokens:", processor.tokenizer.vocab_size)
print("Number of tokens after adding special tokens:", len(processor.tokenizer))

Original number of tokens: 57522
Number of tokens after adding special tokens: 57532


In [8]:
processor.decode([57528])

'<name/>'

In [9]:
pixel_values, labels, target_sequence = train_dataset[0]
print(pixel_values.shape)
print(labels.shape)
print(target_sequence)

torch.Size([3, 1280, 960])
torch.Size([768])
<license_num>01-08-00244454<license_num/><name>ANIL TAMANG<name/><dob>16-07-1980<dob/></s>


In [10]:
model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.decoder_start_token_id = processor.tokenizer.convert_tokens_to_ids(['<s_parse>'])[0]

In [11]:
model.config.decoder_start_token_id

57531

In [12]:
print("Pad token ID:", processor.decode([model.config.pad_token_id]))
print("Decoder start token ID:", processor.decode([model.config.decoder_start_token_id]))

Pad token ID: <pad>
Decoder start token ID: <s_parse>


In [13]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [14]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=4)
val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=4)

In [15]:
batch = next(iter(train_dataloader))
pixel_values, labels, target_sequences = batch
print(pixel_values.shape)

torch.Size([1, 3, 1280, 960])


In [16]:
for id in labels.squeeze().tolist()[:30]:
  if id != -100:
    print(processor.decode([id]))
  else:
    print(id)

<license_num>
03
-0
6-0
03
542
34
<license_num/>
<name>
K
IRAN
L
AMA
<name/>
<dob>
10
-11
-
1993
<dob/>
</s>
-100
-100
-100
-100
-100
-100
-100
-100
-100


In [17]:
print(len(train_dataset))
print(len(val_dataset))

6
1


In [18]:
batch = next(iter(val_dataloader))
pixel_values, labels, target_sequences = batch
print(pixel_values.shape)

torch.Size([1, 3, 1280, 960])


In [19]:
print(target_sequences[0])

<license_num>01-06-00037702<license_num/><name>PRAJIN GHIMIRE<name/><dob>25-06-1992<dob/></s>


In [20]:
!pip install -q pytorch-lightning wandb

[0m

In [21]:
from pathlib import Path
import re
from nltk import edit_distance
import numpy as np
import math

from torch.nn.utils.rnn import pad_sequence
from torch.optim.lr_scheduler import LambdaLR

import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_only

class DonutModelPLModule(pl.LightningModule):
    def __init__(self, config, processor, model):
        super().__init__()
        self.config = config
        self.processor = processor
        self.model = model

    def training_step(self, batch, batch_idx):
        pixel_values, labels, _ = batch
        
        outputs = self.model(pixel_values, labels=labels)
        loss = outputs.loss
        self.log("train_loss", loss)
        return loss
    
    def validation_step(self, batch, batch_idx, dataset_idx=0):
        pixel_values, labels, answers = batch
        batch_size = pixel_values.shape[0]
        # we feed the prompt to the model
        decoder_input_ids = torch.full((batch_size, 1), self.model.config.decoder_start_token_id, device=self.device)
        
        outputs = self.model.generate(pixel_values,
                                   decoder_input_ids=decoder_input_ids,
                                   max_length=max_length,
                                   early_stopping=True,
                                   pad_token_id=self.processor.tokenizer.pad_token_id,
                                   eos_token_id=self.processor.tokenizer.eos_token_id,
                                   use_cache=True,
                                   num_beams=1,
                                   bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
                                   return_dict_in_generate=True,)
    
        predictions = []
        for seq in self.processor.tokenizer.batch_decode(outputs.sequences):
            seq = seq.replace(self.processor.tokenizer.eos_token, "").replace(self.processor.tokenizer.pad_token, "")
            seq = re.sub(r"<.*?>", "", seq, count=1).strip()  # remove first task start token
            predictions.append(seq)

        scores = []
        for seq in self.processor.tokenizer.batch_decode(outputs.sequences):
            seq = seq.replace(self.processor.tokenizer.eos_token, "").replace(self.processor.tokenizer.pad_token, "")
            seq = re.sub(r"<.*?>", "", seq, count=1).strip()  # remove first task start token
            predictions.append(seq)

        scores = []
        for pred, answer in zip(predictions, answers):
            pred = re.sub(r"(?<=\>) | (?=\=)", "", answer, count=1).replace(self.processor.tokenizer.eos_token, "")
            answer = answer.replace(self.processor.tokenizer.eos_token, "")
            scores.append(edit_distance(pred, answer) / max(len(pred), len(answer)))

            if self.config.get("verbose", False) and len(scores) == 1:
                print(f"Prediction: {pred}")
                print(f"    Answer: {answer}")
                print(f" Normed ED: {scores[0]}")

        self.log("val_edit_distance", np.mean(scores))
        
        return scores
    
    def configure_optimizers(self):
        # you could also add a learning rate scheduler if you want
        optimizer = torch.optim.Adam(self.parameters(), lr=self.config.get("lr"))
    
        return optimizer

    def train_dataloader(self):
        return train_dataloader

    def val_dataloader(self):
        return val_dataloader

In [22]:
config = {"max_epochs":3,
          "val_check_interval":1, # how many times we want to validate during an epoch
          "check_val_every_n_epoch":1,
          "gradient_clip_val":1.0,
          "num_training_samples_per_epoch": 6,
          "lr":3e-5,
          "train_batch_sizes": [1],
          "val_batch_sizes": [1],
          # "seed":2022,
          "num_nodes": 1,
          "warmup_steps": 300, # 800/8*30/10, 10%
          "result_path": "./result",
          "verbose": True,
          }

model_module = DonutModelPLModule(config, processor, model)

In [23]:
from pytorch_lightning.callbacks import EarlyStopping

early_stop_callback = EarlyStopping(monitor="val_edit_distance", patience=3, verbose=False, mode="min")

trainer = pl.Trainer(
        accelerator="gpu",
        devices=1,
        max_epochs=config.get("max_epochs"),
        val_check_interval=config.get("val_check_interval"),
        check_val_every_n_epoch=config.get("check_val_every_n_epoch"),
        gradient_clip_val=config.get("gradient_clip_val"),
        precision=16, # we'll use mixed precision
        num_sanity_val_steps=0,
        enable_checkpointing=True
)

trainer.fit(model_module)

/usr/local/lib/python3.9/dist-packages/lightning_fabric/connector.py:558: `precision=16` is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(val_check_interval=1)` was configured so validation will run after every batch.
You are using a CUDA device ('NVIDIA RTX A6000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                      | Params
----------------------------------------------------
0 | model | V

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]



Prediction: <license_num>01-06-00037702<license_num/><name>PRAJIN GHIMIRE<name/><dob>25-06-1992<dob/>
    Answer: <license_num>01-06-00037702<license_num/><name>PRAJIN GHIMIRE<name/><dob>25-06-1992<dob/>
 Normed ED: 0.0


/usr/local/lib/python3.9/dist-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 1. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


Validation: |          | 0/? [00:00<?, ?it/s]

Prediction: <license_num>01-06-00037702<license_num/><name>PRAJIN GHIMIRE<name/><dob>25-06-1992<dob/>
    Answer: <license_num>01-06-00037702<license_num/><name>PRAJIN GHIMIRE<name/><dob>25-06-1992<dob/>
 Normed ED: 0.0


Validation: |          | 0/? [00:00<?, ?it/s]

Prediction: <license_num>01-06-00037702<license_num/><name>PRAJIN GHIMIRE<name/><dob>25-06-1992<dob/>
    Answer: <license_num>01-06-00037702<license_num/><name>PRAJIN GHIMIRE<name/><dob>25-06-1992<dob/>
 Normed ED: 0.0


Validation: |          | 0/? [00:00<?, ?it/s]

Prediction: <license_num>01-06-00037702<license_num/><name>PRAJIN GHIMIRE<name/><dob>25-06-1992<dob/>
    Answer: <license_num>01-06-00037702<license_num/><name>PRAJIN GHIMIRE<name/><dob>25-06-1992<dob/>
 Normed ED: 0.0


Validation: |          | 0/? [00:00<?, ?it/s]

Prediction: <license_num>01-06-00037702<license_num/><name>PRAJIN GHIMIRE<name/><dob>25-06-1992<dob/>
    Answer: <license_num>01-06-00037702<license_num/><name>PRAJIN GHIMIRE<name/><dob>25-06-1992<dob/>
 Normed ED: 0.0


Validation: |          | 0/? [00:00<?, ?it/s]

Prediction: <license_num>01-06-00037702<license_num/><name>PRAJIN GHIMIRE<name/><dob>25-06-1992<dob/>
    Answer: <license_num>01-06-00037702<license_num/><name>PRAJIN GHIMIRE<name/><dob>25-06-1992<dob/>
 Normed ED: 0.0


Validation: |          | 0/? [00:00<?, ?it/s]

Prediction: <license_num>01-06-00037702<license_num/><name>PRAJIN GHIMIRE<name/><dob>25-06-1992<dob/>
    Answer: <license_num>01-06-00037702<license_num/><name>PRAJIN GHIMIRE<name/><dob>25-06-1992<dob/>
 Normed ED: 0.0


Validation: |          | 0/? [00:00<?, ?it/s]

Prediction: <license_num>01-06-00037702<license_num/><name>PRAJIN GHIMIRE<name/><dob>25-06-1992<dob/>
    Answer: <license_num>01-06-00037702<license_num/><name>PRAJIN GHIMIRE<name/><dob>25-06-1992<dob/>
 Normed ED: 0.0


Validation: |          | 0/? [00:00<?, ?it/s]

Prediction: <license_num>01-06-00037702<license_num/><name>PRAJIN GHIMIRE<name/><dob>25-06-1992<dob/>
    Answer: <license_num>01-06-00037702<license_num/><name>PRAJIN GHIMIRE<name/><dob>25-06-1992<dob/>
 Normed ED: 0.0


Validation: |          | 0/? [00:00<?, ?it/s]

Prediction: <license_num>01-06-00037702<license_num/><name>PRAJIN GHIMIRE<name/><dob>25-06-1992<dob/>
    Answer: <license_num>01-06-00037702<license_num/><name>PRAJIN GHIMIRE<name/><dob>25-06-1992<dob/>
 Normed ED: 0.0


Validation: |          | 0/? [00:00<?, ?it/s]

Prediction: <license_num>01-06-00037702<license_num/><name>PRAJIN GHIMIRE<name/><dob>25-06-1992<dob/>
    Answer: <license_num>01-06-00037702<license_num/><name>PRAJIN GHIMIRE<name/><dob>25-06-1992<dob/>
 Normed ED: 0.0


Validation: |          | 0/? [00:00<?, ?it/s]

Prediction: <license_num>01-06-00037702<license_num/><name>PRAJIN GHIMIRE<name/><dob>25-06-1992<dob/>
    Answer: <license_num>01-06-00037702<license_num/><name>PRAJIN GHIMIRE<name/><dob>25-06-1992<dob/>
 Normed ED: 0.0


Validation: |          | 0/? [00:00<?, ?it/s]

Prediction: <license_num>01-06-00037702<license_num/><name>PRAJIN GHIMIRE<name/><dob>25-06-1992<dob/>
    Answer: <license_num>01-06-00037702<license_num/><name>PRAJIN GHIMIRE<name/><dob>25-06-1992<dob/>
 Normed ED: 0.0


Validation: |          | 0/? [00:00<?, ?it/s]

Prediction: <license_num>01-06-00037702<license_num/><name>PRAJIN GHIMIRE<name/><dob>25-06-1992<dob/>
    Answer: <license_num>01-06-00037702<license_num/><name>PRAJIN GHIMIRE<name/><dob>25-06-1992<dob/>
 Normed ED: 0.0


Validation: |          | 0/? [00:00<?, ?it/s]

Prediction: <license_num>01-06-00037702<license_num/><name>PRAJIN GHIMIRE<name/><dob>25-06-1992<dob/>
    Answer: <license_num>01-06-00037702<license_num/><name>PRAJIN GHIMIRE<name/><dob>25-06-1992<dob/>
 Normed ED: 0.0


Validation: |          | 0/? [00:00<?, ?it/s]

Prediction: <license_num>01-06-00037702<license_num/><name>PRAJIN GHIMIRE<name/><dob>25-06-1992<dob/>
    Answer: <license_num>01-06-00037702<license_num/><name>PRAJIN GHIMIRE<name/><dob>25-06-1992<dob/>
 Normed ED: 0.0


Validation: |          | 0/? [00:00<?, ?it/s]

Prediction: <license_num>01-06-00037702<license_num/><name>PRAJIN GHIMIRE<name/><dob>25-06-1992<dob/>
    Answer: <license_num>01-06-00037702<license_num/><name>PRAJIN GHIMIRE<name/><dob>25-06-1992<dob/>
 Normed ED: 0.0


Validation: |          | 0/? [00:00<?, ?it/s]

Prediction: <license_num>01-06-00037702<license_num/><name>PRAJIN GHIMIRE<name/><dob>25-06-1992<dob/>
    Answer: <license_num>01-06-00037702<license_num/><name>PRAJIN GHIMIRE<name/><dob>25-06-1992<dob/>
 Normed ED: 0.0


`Trainer.fit` stopped: `max_epochs=3` reached.
