In [11]:
%env CUDA_VISIBLE_DEVICES

'2'

In [12]:
import torch
from transformers import VisionEncoderDecoderModel, MBartForCausalLM, ViTModel, AutoTokenizer, ViTConfig, PretrainedConfig
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
from PIL import Image
from torchvision.transforms import ToTensor
from torch.nn.utils.rnn import pad_sequence
# Let's test donut's performance if we change the encoder to ViT, while still using old decoder

In [13]:
donut_base = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base")
decoder = donut_base.decoder
encoder_config = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
encoder_config.config.num_hidden_layers = 8
encoder_config.config.num_attention_heads = 16
encoder_config.config.hidden_size = decoder.config.hidden_size
encoder = ViTModel(encoder_config.config, add_pooling_layer=False)
model = VisionEncoderDecoderModel(encoder=encoder, decoder=decoder)

In [14]:
lab_dir = "./"

In [15]:
tokenizer = AutoTokenizer.from_pretrained(lab_dir + "cord/cord_tokenizer/")
model.decoder.resize_token_embeddings(len(tokenizer))
model.config.decoder_start_token_id = tokenizer.bos_token_id
model.config.pad_token_id = tokenizer.pad_token_id
model.config.eos_token_id = tokenizer.eos_token_id

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [16]:
dataset = load_dataset("zyxleo/cord_donut_multitask")

In [28]:
def collate(examples):
    pixel_values = [example[0] for example in examples]
    labels = [torch.tensor(example[1]) for example in examples]
    input_ids = [torch.tensor(example[2]) for example in examples]
    ground_truth = [example[3] for example in examples]
    pixel_values = torch.stack(pixel_values)
    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
    labels = pad_sequence(labels, batch_first=True, padding_value=-100)
    return pixel_values, labels, input_ids, ground_truth

In [29]:
class FtDataset(Dataset):
    def __init__(self, dataset, split, image_size):
        self.dataset = dataset[split]
        self.image_size = image_size[::-1]
    
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        image = Image.open(lab_dir + "cord/" + self.dataset[idx]["image_path"])
        image = image.resize(self.image_size).convert("RGB")
        image = ToTensor()(image)
        labels = self.dataset[idx]["labels"]
        input_ids = self.dataset[idx]["input_ids"]
        ground_truth = self.dataset[idx]["ground_truth"]
        return image, labels, input_ids, ground_truth

In [30]:
train_dataset = FtDataset(dataset, "train", (1256, 800))
val_dataset = FtDataset(dataset, "validation", (1256, 800))
test_dataset = FtDataset(dataset, "test", (1256, 800))

In [31]:
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, collate_fn=collate)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, collate_fn=collate)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, collate_fn=collate)

In [32]:
next(iter(val_loader))[0].shape

torch.Size([1, 3, 1256, 800])

In [33]:
import re
def token2json(tokens, is_inner_value=False):
    """
    Convert a (generated) token sequence into an ordered JSON format.
    """
    added_vocab = tokenizer.get_added_vocab()
    output = {}
    while tokens:
        start_token = re.search(r"<s_(.*?)>", tokens, re.IGNORECASE)
        if start_token is None:
            break
        key = start_token.group(1)
        key_escaped = re.escape(key)
        end_token = re.search(rf"</s_{key_escaped}>", tokens, re.IGNORECASE)
        start_token = start_token.group()
        if end_token is None:
            tokens = tokens.replace(start_token, "")
        else:
            end_token = end_token.group()
            start_token_escaped = re.escape(start_token)
            end_token_escaped = re.escape(end_token)
            content = re.search(f"{start_token_escaped}(.*?){end_token_escaped}", tokens, re.IGNORECASE)
            if content is not None:
                content = content.group(1).strip()
                if r"<s_" in content and r"</s_" in content:  # non-leaf node
                    value = token2json(content, is_inner_value=True)
                    if value:
                        if len(value) == 1:
                            value = value[0]
                        output[key] = value
                else:  # leaf nodes
                    output[key] = []
                    for leaf in content.split(r"<sep/>"):
                        leaf = leaf.strip()
                        if leaf in added_vocab and leaf[0] == "<" and leaf[-2:] == "/>":
                            leaf = leaf[1:-2]  # for categorical special tokens
                        output[key].append(leaf)
                    if len(output[key]) == 1:
                        output[key] = output[key][0]
            tokens = tokens[tokens.find(end_token) + len(end_token) :].strip()
            if tokens[:6] == r"<sep/>":  # non-leaf nodes
                return [output] + token2json(tokens[6:], is_inner_value=True)
    if len(output):
        return [output] if is_inner_value else output
    else:
        return [] if is_inner_value else {"text_sequence": tokens}

In [23]:
import pytorch_lightning as pl
import pdb
from torch.optim import AdamW
import json
import numpy as np
import gc
from transformers import DonutProcessor
from donut import JSONParseEvaluator



class PlModule(pl.LightningModule):
    def __init__(self, model, config):
        super().__init__()
        self.model = model
        self.config = config
        self.processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base")
        self.evaluator = JSONParseEvaluator()
    
    def training_step(self, batch, batch_idx):
        pixel_values, labels, input_ids, ground_truth = batch
        outputs = self.model(pixel_values=pixel_values, labels=labels, interpolate_pos_encoding=True)
        loss = outputs.loss
        pixel_values = labels = input_ids = ground_truth = None
        gc.collect()
        torch.cuda.empty_cache()
        self.log("train_loss", loss.item())
        return loss
    
    def validation_step(self, batch, batch_idx):
        # pdb.set_trace()
        pixel_values, labels, input_ids, ground_truth = batch
        outputs = self.model.generate(pixel_values=pixel_values,
                                      decoder_input_ids=input_ids,
                                      early_stopping=True,
                                      max_length=768,
                                      use_cache=True,
                                      num_beams=1,
                                      bad_words_ids=[[tokenizer.unk_token_id]],
                                      pad_token_id=tokenizer.pad_token_id,
                                      eos_token_id=tokenizer.eos_token_id,
                                      return_dict_in_generate=True,
                                      output_hidden_states=True,
                                      interpolate_pos_encoding=True)
        predictions = []
        for seq in tokenizer.batch_decode(outputs.sequences):
            seq = seq.replace(tokenizer.eos_token, "").replace(tokenizer.pad_token, "")
            seq = re.sub(r"<.*?>", "", seq, count=1).strip()  # remove first task start token
            seq = re.sub(r"(?:(?<=>) | (?=</s_))", "", seq)
            seq = self.processor.token2json(seq)
            predictions.append(seq)

        accs = []
        f1s = []
        for pred, answer in zip(predictions, ground_truth):
            # NOT NEEDED ANYMORE
            # answer = re.sub(r"<.*?>", "", answer, count=1)
            # scores.append(self.evaluator.(pred, answer) / max(len(pred), len(answer)))
            answer = json.loads(answer)
            accs.append(self.evaluator.cal_acc(pred, answer))
            f1s.append(self.evaluator.cal_f1([pred], [answer]))
            if len(accs) == 1:
                print(f"Prediction: {pred}")
                print(f"    Answer: {answer}")
                print(f"       Acc: {accs[-1]:.2f}")
                print(f"        F1: {f1s[-1]:.2f}")
                

        # self.log("val_edit_distance", np.mean(scores))
        self.log("val_acc", np.mean(accs))
        self.log("val_f1", np.mean(f1s))
        
        # return scores
        return accs, f1s
    
    def test_step(self, batch, batch_idx):
        pixel_values, labels, input_ids, ground_truth = batch
        outputs = self.model.generate(pixel_values=pixel_values,
                              decoder_input_ids=input_ids,
                              early_stopping=True,
                              max_length=768,
                              use_cache=True,
                              num_beams=1,
                              bad_words_ids=[[tokenizer.unk_token_id]],
                              pad_token_id=tokenizer.pad_token_id,
                              eos_token_id=tokenizer.eos_token_id,
                              return_dict_in_generate=True,
                              interpolate_pos_encoding=True)
        predictions = []
        for seq in tokenizer.batch_decode(outputs.sequences):
            seq = seq.replace(tokenizer.eos_token, "").replace(tokenizer.pad_token, "")
            seq = re.sub(r"<.*?>", "", seq, count=1).strip()  # remove first task start token
            seq = re.sub(r"(?:(?<=>) | (?=</s_))", "", seq)
            seq = self.processor.token2json(seq)
            predictions.append(seq)

        accs = []
        f1s = []
        for pred, answer in zip(predictions, ground_truth):
            # NOT NEEDED ANYMORE
            # answer = re.sub(r"<.*?>", "", answer, count=1)
            # scores.append(self.evaluator.(pred, answer) / max(len(pred), len(answer)))
            answer = json.loads(answer)
            accs.append(self.evaluator.cal_acc(pred, answer))
            f1s.append(self.evaluator.cal_f1([pred], [answer]))
            if len(accs) == 1:
                print(f"Prediction: {pred}")
                print(f"    Answer: {answer}")
                print(f"       Acc: {accs[-1]:.2f}")
                print(f"        F1: {f1s[-1]:.2f}")
                

        # self.log("val_edit_distance", np.mean(scores))
        self.log("val_acc", np.mean(accs))
        self.log("val_f1", np.mean(f1s))
        
        # return scores
        return accs, f1s
    
    def configure_optimizers(self):
        return AdamW(self.model.parameters(), lr=self.config.get("lr"))
    
    def train_dataloader(self):
        return train_loader
    
    def val_dataloader(self):
        return val_loader
    
    def test_dataloader(self):
        return test_loader

In [24]:
config = {"max_epochs":10,
          "val_check_interval":0.2, 
          "check_val_every_n_epoch":2,
          "gradient_clip_val":1.0,
          "num_training_samples_per_epoch": 800,
          "lr":3e-5,
          "train_batch_sizes": [1],
          "val_batch_sizes": [1],
          # "seed":2022,
          "num_nodes": 1,
          "warmup_steps": 300, # 800/8*30/10, 10%
          "result_path": "./result",
          "verbose": True,
          }

model_module = PlModule(model, config)

Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.


In [25]:
from pytorch_lightning.callbacks import EarlyStopping
early_stop_callback = EarlyStopping(monitor="val_acc", patience=2, verbose=False, mode="min")
trainer = pl.Trainer(
        # fast_dev_run=100,
        accelerator="gpu",
        devices=1,
        max_epochs=config.get("max_epochs"),
        val_check_interval=config.get("val_check_interval"),
        check_val_every_n_epoch=config.get("check_val_every_n_epoch"),
        gradient_clip_val=config.get("gradient_clip_val"),
        num_sanity_val_steps=0,
        precision=16, # we'll use mixed precision
        # callbacks=[PushToHubCallback(), early_stop_callback],
)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [9]:
trainer.fit(model_module)

NameError: name 'model_module' is not defined

In [34]:
trainer.fit(model_module, ckpt_path=lab_dir+"lightning_logs/version_2/checkpoints/epoch=3-step=3040.ckpt")


Restoring states from the checkpoint path at ./lightning_logs/version_2/checkpoints/epoch=3-step=3040.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [2]

  | Name  | Type                      | Params
----------------------------------------------------
0 | model | VisionEncoderDecoderModel | 212 M 
----------------------------------------------------
212 M     Trainable params
0         Non-trainable params
212 M     Total params
850.846   Total estimated model params size (MB)
Restored all states from the checkpoint at ./lightning_logs/version_2/checkpoints/epoch=3-step=3040.ckpt
/root/miniforge-pypy3/envs/envv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=47` in the `DataLoader` to improve performance.
/root/miniforge-pypy3/envs/envv/lib/python3.9/site-packages/pytorch_lightning/trainer/connect

Epoch 3:   0%|          | 0/800 [00:00<?, ?it/s] 

/root/miniforge-pypy3/envs/envv/lib/python3.9/site-packages/pytorch_lightning/loops/training_epoch_loop.py:154: You're resuming from a checkpoint that ended before the epoch ended. This can cause unreliable results if further training is done. Consider using an end-of-epoch checkpoint



Validation: |          | 0/? [00:00<?, ?it/s][A
Validation:   0%|          | 0/100 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/100 [00:00<?, ?it/s][A



Prediction: {'total': {'total_price': '22,000', 'changeprice': '0', 'cashprice': '22,000'}, 'sub_total': {'tax_price': '2,000', 'subtotal_price': '20,000'}, 'menu': {'price': '20,000', 'nm': 'CHOCO CHOCO', 'cnt': '1'}}
    Answer: {'menu': [{'nm': 'REAL GANACHE', 'cnt': '1', 'price': '16,500'}, {'nm': 'EGG TART', 'cnt': '1', 'price': '13,000'}, {'nm': 'PIZZA TOAST', 'cnt': '1', 'price': '16,000'}], 'total': {'total_price': '45,500', 'cashprice': '50,000', 'changeprice': '4,500'}}
       Acc: 0.25
        F1: 0.10

Validation DataLoader 0:   1%|          | 1/100 [00:00<01:18,  1.26it/s][A
Epoch 3:   0%|          | 0/800 [00:00<?, ?it/s, v_num=3]               [A

/root/miniforge-pypy3/envs/envv/lib/python3.9/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 1. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


Epoch 3: 100%|██████████| 800/800 [01:36<00:00,  8.31it/s, v_num=3]
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation:   0%|          | 0/100 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/100 [00:00<?, ?it/s][APrediction: {'total': {'total_price': '45,000', 'changeprice': '0', 'cashprice': '45,000'}, 'menu': {'price': '45,000', 'nm': 'CHOCO CROSSIANT', 'cnt': '1'}}
    Answer: {'menu': [{'nm': 'REAL GANACHE', 'cnt': '1', 'price': '16,500'}, {'nm': 'EGG TART', 'cnt': '1', 'price': '13,000'}, {'nm': 'PIZZA TOAST', 'cnt': '1', 'price': '16,000'}], 'total': {'total_price': '45,500', 'cashprice': '50,000', 'changeprice': '4,500'}}
       Acc: 0.29
        F1: 0.11

Validation DataLoader 0:   1%|          | 1/100 [00:00<00:37,  2.64it/s][APrediction: {'total': {'total_price': '45,000', 'changeprice': '5,000', 'cashprice': '50,000'}, 'menu': {'price': '45,000', 'nm': 'CHOCO CROSSIANT', 'cnt': '1'}}
    Answer: {'menu': {'nm': 'Kopi Susu Kolonel', 'cnt': '1', 'pri

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|██████████| 800/800 [19:49<00:00,  0.67it/s, v_num=3]


In [16]:
ckpt = torch.load("lab/lightning_logs/version_0/checkpoints/epoch=29-step=24000.ckpt")
state_dict = ckpt["state_dict"]
# remove leading "model." in every key's name
new_state_dict = {}
for key in state_dict:
    new_state_dict[key[6:]] = state_dict[key]
model.load_state_dict(new_state_dict)

<All keys matched successfully>

In [17]:
iterator = iter(train_loader)
test0 = next(iterator)
test1 = next(iterator)
pixel_values, labels, input_ids, ground_truth = test0
outputs0 = model.generate(pixel_values=pixel_values,
                              decoder_input_ids=input_ids,
                              early_stopping=True,
                              max_length=768,
                              use_cache=True,
                              num_beams=1,
                              bad_words_ids=[[tokenizer.unk_token_id]],
                              pad_token_id=tokenizer.pad_token_id,
                              eos_token_id=tokenizer.eos_token_id,
                              return_dict_in_generate=True,
                              interpolate_pos_encoding=True)
pixel_values, labels, input_ids, ground_truth = test1
outputs1 = model.generate(pixel_values=pixel_values,
                              decoder_input_ids=input_ids,
                              early_stopping=True,
                              max_length=768,
                              use_cache=True,
                              num_beams=1,
                              bad_words_ids=[[tokenizer.unk_token_id]],
                              pad_token_id=tokenizer.pad_token_id,
                              eos_token_id=tokenizer.eos_token_id,
                              return_dict_in_generate=True,
                              interpolate_pos_encoding=True)



In [18]:
print(token2json(tokenizer.decode(outputs0[0][0])))
print(token2json(tokenizer.decode(outputs1[0][0])))

{'total': {'total_price': '25,000', 'total_etc': '25,000', 'changeprice': '0'}, 'sub_total': {'subtotal_price': '25.000'}, 'menu': [{'unitprice': '12,500', 'price': '12,500', 'nm': 'Silky Green Tea', 'cnt': '1x'}, {'unitprice': '12,500', 'price': '12.500', 'nm': 'Silky Hazelnut', 'cnt': '1x'}]}
{'total': {'total_price': '45,000', 'menuqty_cnt': '1', 'changeprice': '5,000', 'cashprice': '50,000'}, 'sub_total': {'subtotal_price': '45,000'}, 'menu': {'price': '45,000', 'nm': 'Salted Egg Yolk Chicken', 'cnt': '1'}}


In [19]:
ground_truth

['{"menu": {"nm": "Salted Egg Yolk Chicken", "cnt": "1", "price": "45,000"}, "sub_total": {"subtotal_price": "45,000"}, "total": {"total_price": "45,000", "cashprice": "50,000", "changeprice": "5,000", "menuqty_cnt": "1"}}']