In [1]:
from transformers import DonutProcessor
from datasets import load_dataset, Dataset
from PIL import Image
from tqdm import tqdm
import json
import os
import pickle

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the pretrained Donut processor and keep a handle on its tokenizer;
# later cells add tokens to this tokenizer and save it.
processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base")
tokenizer = processor.tokenizer

Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.


In [3]:
# Load the annotation pickle into `pkl`; later cells iterate over `pkl.keys()`
# and use each key as an image file stem.
# SECURITY NOTE: pickle.load can execute arbitrary code while unpickling —
# only load this file if it comes from a trusted source.
with open("./dataset_trainticket/real_1920_trimmed.pkl", 'rb') as f:
    pkl = pickle.load(f)

In [4]:
    # Every "<s_key>" / "</s_key>" tag that json2token has registered with the tokenizer.
    added_tokens = set()
    def json2token(obj, update_special_tokens_for_json_key, sort_json_key: bool = True):
        """
        Convert an ordered JSON object into a token sequence.

        Mappings become ``<s_key>value</s_key>`` for each key, lists are joined
        with ``<sep/>``, and any other value is converted with ``str``.

        Args:
            obj: A mapping, a list, or a leaf value. Nested structures are
                handled recursively.
            update_special_tokens_for_json_key: When true, each key's opening
                and closing tags are added to the module-level ``tokenizer``
                and recorded in ``added_tokens``.
            sort_json_key: When true (default), keys are emitted in reverse
                sorted order; otherwise in the mapping's own iteration order.

        Returns:
            The token string. A mapping whose only key is ``"text_sequence"``
            returns that value unchanged.
        """
        # isinstance (not type(...) ==) so dict subclasses such as OrderedDict are
        # treated as mappings instead of being stringified as leaf values.
        if isinstance(obj, dict):
            if len(obj) == 1 and "text_sequence" in obj:
                return obj["text_sequence"]
            keys = sorted(obj.keys(), reverse=True) if sort_json_key else obj.keys()
            pieces = []
            for k in keys:
                open_tag, close_tag = fr"<s_{k}>", fr"</s_{k}>"
                if update_special_tokens_for_json_key:
                    # Register the tags before recursing, as the value may contain nested keys.
                    tokenizer.add_tokens([open_tag, close_tag])
                    added_tokens.add(open_tag)
                    added_tokens.add(close_tag)
                pieces.append(
                    open_tag
                    + json2token(obj[k], update_special_tokens_for_json_key, sort_json_key)
                    + close_tag
                )
            return "".join(pieces)
        if isinstance(obj, list):
            return r"<sep/>".join(
                [json2token(item, update_special_tokens_for_json_key, sort_json_key) for item in obj]
            )
        obj = str(obj)
        # NOTE(review): added_tokens only ever holds "<s_key>"/"</s_key>" strings, so this
        # branch fires only when a leaf value is literally one of those tags — confirm
        # whether the intended check was on f"<{obj}/>" instead.
        if obj in added_tokens:
            obj = f"<{obj}/>"  # for categorical special tokens
        return obj

In [13]:
# Write one JSONL record per annotated image into the train and/or test file,
# depending on which image folder contains "<key>.jpg".
# The files are opened in a `with` block so they are closed even if the
# assertion or the tokenizer raises part-way through the loop.
with open("./ticket/ticket_dataset/ticket_test.jsonl", 'w') as f_test, \
        open("./ticket/ticket_dataset/ticket_train.jsonl", 'w') as f_train:
    for key in tqdm(pkl.keys()):
        line = {}
        line["task"] = "ticket"
        line["ground_truth"] = json.dumps(pkl[key])
        # True -> the JSON keys of this record are registered as tokenizer tokens.
        gt_tokens = json2token(pkl[key], True)
        line["labels"] = tokenizer(gt_tokens + "</s>", add_special_tokens=False).input_ids

        # Build each candidate path once and check it once.
        test_path = f"./ticket/ticket_images/test/{key}.jpg"
        train_path = f"./ticket/ticket_images/train/{key}.jpg"
        in_test = os.path.exists(test_path)
        in_train = os.path.exists(train_path)
        if not (in_test or in_train):
            print(f"Don't exist: {key}")
        assert in_test or in_train, f"Don't exist: {key}"

        if in_test:
            # Test records carry only the start token as input.
            line["input_ids"] = tokenizer("<s>", add_special_tokens=False).input_ids
            line["image_path"] = test_path
            f_test.write(json.dumps(line) + "\n")
        if in_train:
            # Train records carry the start token followed by the full target sequence.
            line["input_ids"] = tokenizer("<s>" + gt_tokens, add_special_tokens=False).input_ids
            line["image_path"] = train_path
            f_train.write(json.dumps(line) + "\n")

100%|██████████| 1918/1918 [00:08<00:00, 235.17it/s]


In [14]:
# Save the tokenizer, including the tokens added by json2token, to the
# "ticket_tokenizer" directory.
tokenizer.save_pretrained("ticket_tokenizer")

('ticket_tokenizer/tokenizer_config.json',
 'ticket_tokenizer/special_tokens_map.json',
 'ticket_tokenizer/sentencepiece.bpe.model',
 'ticket_tokenizer/added_tokens.json',
 'ticket_tokenizer/tokenizer.json')

In [15]:
# Read the two JSONL files back in as the "train" and "test" splits.
ds = load_dataset(
    "json",
    data_files=dict(
        train="./ticket/ticket_dataset/ticket_train.jsonl",
        test="./ticket/ticket_dataset/ticket_test.jsonl",
    ),
)

Downloading data files: 100%|██████████| 2/2 [00:00<00:00, 2761.23it/s]
Extracting data files: 100%|██████████| 2/2 [00:00<00:00, 688.55it/s]
Generating train split: 1520 examples [00:00, 77036.89 examples/s]
Generating test split: 398 examples [00:00, 45743.92 examples/s]


In [17]:
# Publish the dataset to the Hugging Face Hub.
# SECURITY: the access token must never be hardcoded in the notebook. It is read
# from the HF_TOKEN environment variable; when that is unset, None is passed and
# the library falls back to the locally saved login token.
# The token that was previously committed here should be revoked.
ds.push_to_hub("ticket_donut_multitask", token=os.environ.get("HF_TOKEN"))

Creating parquet from Arrow format: 100%|██████████| 2/2 [00:00<00:00, 189.09ba/s]
Pushing dataset shards to the dataset hub: 100%|██████████| 1/1 [00:01<00:00,  1.01s/it]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 364.44ba/s]
Pushing dataset shards to the dataset hub: 100%|██████████| 1/1 [00:03<00:00,  3.87s/it]
Downloading metadata: 100%|██████████| 703/703 [00:00<00:00, 3.00MB/s]


In [13]:
# Sanity check: print every annotation key that has no image in either the
# test or the train folder. The unused per-key `line` dict and tokenizer call
# were dead work and have been removed.
for key in tqdm(pkl.keys()):
    candidates = (
        f"./ticket_images/test/{key}.jpg",
        f"./ticket_images/train/{key}.jpg",
    )
    if not any(os.path.exists(path) for path in candidates):
        print(f"Don't exist: {key}")

100%|██████████| 1920/1920 [00:00<00:00, 13066.05it/s]

Don't exist: IMG_20180514_134508
Don't exist: IMG_20180514_132629





In [19]:
# Load the dataset back from the Hub. Presumably this is the copy pushed
# earlier in this notebook (the split sizes in the output match) — confirm.
ds_new = load_dataset("zyxleo/ticket_donut_multitask")

Downloading readme: 100%|██████████| 703/703 [00:00<00:00, 2.32MB/s]
Downloading data: 100%|██████████| 245k/245k [00:01<00:00, 175kB/s]
Downloading data: 100%|██████████| 52.0k/52.0k [00:00<00:00, 86.3kB/s]
Downloading data files: 100%|██████████| 2/2 [00:02<00:00,  1.01s/it]
Extracting data files: 100%|██████████| 2/2 [00:00<00:00, 797.70it/s]
Generating train split: 100%|██████████| 1520/1520 [00:00<00:00, 154146.43 examples/s]
Generating test split: 100%|██████████| 398/398 [00:00<00:00, 95015.82 examples/s]


In [21]:
# Load the pretrained Donut checkpoint as a VisionEncoderDecoderModel.
# NOTE(review): this import sits mid-notebook; consider moving it to the
# import cell at the top so dependencies are visible in one place.
from transformers import VisionEncoderDecoderModel
model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base")

In [None]:
# NOTE(review): this cell has no execution count and generate() is called
# without any inputs — presumably a placeholder. Confirm what should be passed
# before running it.
model.generate()