In [None]:
# !pip install paddlepaddle-gpu -i https://pypi.tuna.tsinghua.edu.cn/simple
!pip install paddlepaddle -i https://pypi.tuna.tsinghua.edu.cn/simple
!pip install "paddleocr>=2.0.1"
!pip install pybind11
!pip install fastwer
!pip install transformers --quiet
!pip install Levenshtein --quiet

In [3]:
import os
import shutil
from PIL import Image

import torch
from io import BytesIO
from transformers import VisionEncoderDecoderModel, VisionEncoderDecoderConfig, DonutProcessor
import json
from tqdm.notebook import tqdm

In [None]:
# Fetch the dataset. git clones into ./SROIE-datasetv2, the sroie_root_dir set in the next cell;
# the test images are later read from its test/img folder.
# NOTE: re-running this cell once the clone exists only makes git report that the destination exists.
!git clone https://github.com/Losyash/SROIE-datasetv2.git

In [4]:
# Root of the dataset cloned above (git clone's default target directory name).
sroie_root_dir = './SROIE-datasetv2/'

In [None]:
# Create the output folders for the test predictions: OCR box files and entity JSON files.
# makedirs also creates the parent 'test_predicted' folder, and exist_ok=True keeps this cell
# idempotent (os.mkdir raised FileExistsError on every re-run).
for prediction_sub_dir in ('test_predicted/box', 'test_predicted/entities'):
    os.makedirs(os.path.join(sroie_root_dir, prediction_sub_dir), exist_ok=True)

In [6]:
# Input images, plus the two prediction output folders (entity JSON files and OCR box files).
sroie_test_img_dir = os.path.join(sroie_root_dir, 'test/img')
sroie_test_pred_ent_dir, sroie_test_pred_box_dir = (
    os.path.join(sroie_root_dir, 'test_predicted/' + leaf_dir)
    for leaf_dir in ('entities', 'box')
)

In [7]:
# Text boxes: PaddleOCR detection + recognition with English models.
# use_angle_cls=True loads the text-angle classifier that make_box_file enables via cls=True.
from paddleocr import PaddleOCR

ocr = PaddleOCR(lang='en', use_angle_cls=True)

download https://paddleocr.bj.bcebos.com/PP-OCRv3/english/en_PP-OCRv3_det_infer.tar to /root/.paddleocr/whl/det/en/en_PP-OCRv3_det_infer/en_PP-OCRv3_det_infer.tar


100%|██████████| 4.00M/4.00M [00:04<00:00, 827kiB/s] 


download https://paddleocr.bj.bcebos.com/PP-OCRv3/english/en_PP-OCRv3_rec_infer.tar to /root/.paddleocr/whl/rec/en/en_PP-OCRv3_rec_infer/en_PP-OCRv3_rec_infer.tar


100%|██████████| 9.96M/9.96M [00:18<00:00, 546kiB/s] 


download https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar to /root/.paddleocr/whl/cls/ch_ppocr_mobile_v2.0_cls_infer/ch_ppocr_mobile_v2.0_cls_infer.tar


100%|██████████| 2.19M/2.19M [00:14<00:00, 154kiB/s]

[2023/01/20 11:33:20] ppocr DEBUG: Namespace(alpha=1.0, benchmark=False, beta=1.0, cls_batch_num=6, cls_image_shape='3, 48, 192', cls_model_dir='/root/.paddleocr/whl/cls/ch_ppocr_mobile_v2.0_cls_infer', cls_thresh=0.9, cpu_threads=10, crop_res_save_dir='./output', det=True, det_algorithm='DB', det_box_type='quad', det_db_box_thresh=0.6, det_db_score_mode='fast', det_db_thresh=0.3, det_db_unclip_ratio=1.5, det_east_cover_thresh=0.1, det_east_nms_thresh=0.2, det_east_score_thresh=0.8, det_limit_side_len=960, det_limit_type='max', det_model_dir='/root/.paddleocr/whl/det/en/en_PP-OCRv3_det_infer', det_pse_box_thresh=0.85, det_pse_min_area=16, det_pse_scale=1, det_pse_thresh=0, det_sast_nms_thresh=0.2, det_sast_score_thresh=0.5, draw_img_save_dir='./inference_results', drop_score=0.5, e2e_algorithm='PGNet', e2e_char_dict_path='./ppocr/utils/ic15_dict.txt', e2e_limit_side_len=768, e2e_limit_type='max', e2e_model_dir=None, e2e_pgnet_mode='fast', e2e_pgnet_score_thresh=0.5, e2e_pgnet_valid_set




In [None]:
# Donut checkpoint used for entity extraction. Processor and model must come from the same
# checkpoint, so the id is defined once.
DONUT_CHECKPOINT = "unstructuredio/donut-base-sroie"

processor = DonutProcessor.from_pretrained(DONUT_CHECKPOINT)
pretrained_model = VisionEncoderDecoderModel.from_pretrained(DONUT_CHECKPOINT)

# Use the GPU when available; make_entities_file moves its inputs to this same device.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Trailing ';' suppresses the very long model repr that .to() would otherwise display.
pretrained_model.to(device);

In [20]:
def _format_box_line(box, text):
    """Format one OCR detection as ``x,y,x,y,...,TEXT``.

    Args:
        box: sequence of points; the first two values of each point (x, y) are
            truncated to int.
        text: recognized text, written upper-cased (presumably to match the upper-case
            SROIE annotations -- confirm against the evaluation script).

    Returns:
        The comma-joined line without a trailing newline.
    """
    coords = [str(int(value)) for point in box for value in (point[0], point[1])]
    return ','.join(coords + [str(text).upper()])


def make_box_file(file_name):
    """Run PaddleOCR on one test image and write its text boxes to ``<stem>-1.txt``.

    The output file lives in ``sroie_test_pred_box_dir`` and holds one line per detection,
    in the format produced by ``_format_box_line``. Images without any detection produce
    an empty file.

    Args:
        file_name: image file name inside ``sroie_test_img_dir``.
    """
    result = ocr.ocr(os.path.join(sroie_test_img_dir, file_name), cls=True)
    # splitext keeps dotted stems intact ('a.b.jpg' -> 'a.b'), unlike split('.')[0].
    stem = os.path.splitext(file_name)[0]
    out_path = os.path.join(sroie_test_pred_box_dir, stem + '-1.txt')

    # Recognized text may be non-ASCII, so do not rely on the platform default encoding.
    with open(out_path, 'w', encoding='utf-8') as out_file:
        # The result is a list with one entry per page; PaddleOCR >= 2.6 reports a page
        # with no detections as None, which must be skipped rather than iterated.
        for page in result or []:
            if not page:
                continue
            for detection in page:
                box, recognition = detection[0], detection[1]
                out_file.write(_format_box_line(box, recognition[0]) + '\n')

In [22]:
def make_entities_file(file_name):
    """Run the Donut model on one test image and save the extracted fields as JSON.

    The result is written to ``<sroie_test_pred_ent_dir>/<stem>-2.txt``. It always contains
    the keys ``address``, ``company``, ``date`` and ``total`` (``''`` when the model did not
    produce that field). Any other keys returned by ``processor.token2json`` are kept too,
    for example its ``text_sequence`` fallback when no field tags were found. Keys are
    written in sorted order.

    Args:
        file_name: image file name inside ``sroie_test_img_dir``.
    """
    path_to_test_img = os.path.join(sroie_test_img_dir, file_name)
    # The context manager releases the file handle once the RGB copy exists.
    with Image.open(path_to_test_img) as img:
        image = img.convert("RGB")

    # Decoding starts from the bare start token.
    task_prompt = "<s>"
    pixel_values = processor(image, return_tensors="pt").pixel_values
    decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids

    # Greedy decoding (num_beams=1); the unk token is banned from the output.
    outputs = pretrained_model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=pretrained_model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    sequence = processor.batch_decode(outputs.sequences)[0]
    # Strip EOS/padding and the prompt token before parsing (as in the Hugging Face Donut
    # example) so they cannot leak into token2json's raw-text fallback.
    for special_token in (processor.tokenizer.eos_token, processor.tokenizer.pad_token):
        if special_token:
            sequence = sequence.replace(special_token, "")
    sequence = sequence.strip()
    if sequence.startswith(task_prompt):
        sequence = sequence[len(task_prompt):].strip()

    prediction = processor.token2json(sequence)
    # token2json returns a list when top-level groups are <sep/>-separated; keep the first group.
    if isinstance(prediction, list):
        prediction = next((group for group in prediction if isinstance(group, dict)), {})

    # Fill missing fields *before* sorting so every output file has the same key order.
    for key in ("address", "company", "date", "total"):
        prediction.setdefault(key, "")
    prediction = dict(sorted(prediction.items()))

    stem = os.path.splitext(file_name)[0]
    path_to_test_pred_file = os.path.join(sroie_test_pred_ent_dir, stem + "-2.txt")
    with open(path_to_test_pred_file, "w", encoding="utf-8") as fp:
        json.dump(prediction, fp)


In [None]:
# Only feed image files to the models: os.listdir also returns hidden/auxiliary files
# (which crash OCR/PIL) and has no guaranteed order, so filter and sort for reproducible runs.
# Extend IMAGE_EXTENSIONS if the dataset ships other image formats.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')
test_image_names = sorted(
    name for name in os.listdir(sroie_test_img_dir)
    if name.lower().endswith(IMAGE_EXTENSIONS)
)

for file_name in tqdm(test_image_names, desc='predicting'):
    make_box_file(file_name)
    make_entities_file(file_name)

print('Done')