In [1]:
pip install transformers xmltodict

Note: you may need to restart the kernel to use updated packages.


In [2]:
# Load model directly
from transformers import AutoTokenizer, AutoModelForImageTextToText

# Both objects come from the same Hub checkpoint; tokenizer is loaded first, then the model.
tokenizer, model = (
    auto_cls.from_pretrained("mychen76/invoice-and-receipts_donut_v1")
    for auto_cls in (AutoTokenizer, AutoModelForImageTextToText)
)

Config of the encoder: <class 'transformers.models.donut.modeling_donut_swin.DonutSwinModel'> is overwritten by shared encoder config: DonutSwinConfig {
  "attention_probs_dropout_prob": 0.0,
  "depths": [
    2,
    2,
    14,
    2
  ],
  "drop_path_rate": 0.1,
  "embed_dim": 128,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 1024,
  "image_size": [
    1280,
    960
  ],
  "initializer_range": 0.02,
  "layer_norm_eps": 1e-05,
  "mlp_ratio": 4.0,
  "model_type": "donut-swin",
  "num_channels": 3,
  "num_heads": [
    4,
    8,
    16,
    32
  ],
  "num_layers": 4,
  "patch_size": 4,
  "path_norm": true,
  "qkv_bias": true,
  "transformers_version": "4.46.2",
  "use_absolute_embeddings": false,
  "window_size": 10
}

Config of the decoder: <class 'transformers.models.mbart.modeling_mbart.MBartForCausalLM'> is overwritten by shared decoder config: MBartConfig {
  "activation_dropout": 0.0,
  "activation_function": "gelu",
  "add_cross_attention": true,
  "add_

In [3]:
import torch
from transformers import DonutProcessor, VisionEncoderDecoderModel

model_id = "mychen76/invoice-and-receipts_donut_v1"


def load_model(model_id=model_id):
    """Load the Donut processor and vision encoder-decoder model for a checkpoint.

    Args:
        model_id: Checkpoint identifier passed to both `from_pretrained` calls.
            Defaults to the module-level `model_id` as it was when this
            function was defined.

    Returns:
        A `(processor, model)` tuple.
    """
    return (
        DonutProcessor.from_pretrained(model_id),
        VisionEncoderDecoderModel.from_pretrained(model_id),
    )

# Instantiate the processor/model pair that the inference cell below uses.
# NOTE: this rebinds `model`, replacing the AutoModelForImageTextToText
# instance assigned in the earlier "Load model directly" cell.
processor, model = load_model(model_id)

Config of the encoder: <class 'transformers.models.donut.modeling_donut_swin.DonutSwinModel'> is overwritten by shared encoder config: DonutSwinConfig {
  "attention_probs_dropout_prob": 0.0,
  "depths": [
    2,
    2,
    14,
    2
  ],
  "drop_path_rate": 0.1,
  "embed_dim": 128,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 1024,
  "image_size": [
    1280,
    960
  ],
  "initializer_range": 0.02,
  "layer_norm_eps": 1e-05,
  "mlp_ratio": 4.0,
  "model_type": "donut-swin",
  "num_channels": 3,
  "num_heads": [
    4,
    8,
    16,
    32
  ],
  "num_layers": 4,
  "patch_size": 4,
  "path_norm": true,
  "qkv_bias": true,
  "transformers_version": "4.46.2",
  "use_absolute_embeddings": false,
  "window_size": 10
}

Config of the decoder: <class 'transformers.models.mbart.modeling_mbart.MBartForCausalLM'> is overwritten by shared decoder config: MBartConfig {
  "activation_dropout": 0.0,
  "activation_function": "gelu",
  "add_cross_attention": true,
  "add_

In [13]:
import json
import os

from PIL import Image

def generateTextInImage(processor,model,input_image,task_prompt="<s_receipt>"):
    """Run generation on a single image and return the raw `generate` output.

    Args:
        processor: Callable on an image (its result must expose
            `.pixel_values`) and exposing a `.tokenizer`; in this notebook a
            `DonutProcessor`.
        model: Object exposing `.to`, `.generate` and
            `.decoder.config.max_position_embeddings`; in this notebook a
            `VisionEncoderDecoderModel`.
        input_image: Image passed to `processor` unchanged.
        task_prompt: Task start token used to seed the decoder input.

    Returns:
        Whatever `model.generate` returns when called with
        `return_dict_in_generate=True` and `output_scores=True`.
    """
    pixel_values = processor(input_image, return_tensors="pt").pixel_values
    print ("input pixel_values: ",pixel_values.shape)
    # Use the caller-supplied prompt. It was previously overwritten with a
    # hard-coded "<s_receipt>", which made the `task_prompt` argument a no-op.
    decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt")["input_ids"]
    # Run on the GPU when one is available; the model and both inputs must share a device.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    outputs = model.generate(pixel_values.to(device),
                               decoder_input_ids=decoder_input_ids.to(device),
                               max_length=model.decoder.config.max_position_embeddings,
                               early_stopping=True,
                               pad_token_id=processor.tokenizer.pad_token_id,
                               eos_token_id=processor.tokenizer.eos_token_id,
                               use_cache=True,
                               num_beams=1,
                               # Forbid the unknown token from being generated.
                               bad_words_ids=[[processor.tokenizer.unk_token_id]],
                               return_dict_in_generate=True,
                               output_scores=True,)
    return outputs

def generateOutputXML(processor,model, input_image, task_start="<s_receipt>",task_end="</s_receipt>"):
    """Generate the tagged output string for an image, minus special tokens.

    Args:
        processor: Processor providing `batch_decode` and a `tokenizer` with
            `eos_token` / `pad_token`.
        model: Model forwarded to `generateTextInImage`.
        input_image: Image forwarded to `generateTextInImage`.
        task_start: Forwarded to `generateTextInImage` as `task_prompt`.
        task_end: Accepted but not used by this function.

    Returns:
        The first decoded sequence with EOS/PAD tokens removed, the first
        `<...>` tag stripped, and surrounding whitespace trimmed.
    """
    import re

    generated = generateTextInImage(processor, model, input_image, task_prompt=task_start)
    decoded = processor.batch_decode(generated.sequences)[0]
    print(decoded)

    special_tokens = (processor.tokenizer.eos_token, processor.tokenizer.pad_token)
    for special_token in special_tokens:
        decoded = decoded.replace(special_token, "")

    # Only the first tag is dropped (count=1): that is the leading task start token.
    return re.sub(r"<.*?>", "", decoded, count=1).strip()

def convertOutputToJson(processor, xml):
    """Convert an already-generated tagged output string into a JSON-like object.

    Args:
        processor: Object exposing `token2json`; in this notebook a
            `DonutProcessor`.
        xml: Tagged sequence string, e.g. the return value of
            `generateOutputXML`.

    Returns:
        Whatever `processor.token2json(xml)` returns.
    """
    # Parse the string the caller passed in. The previous version ignored `xml`
    # and called generateOutputXML() with no arguments, which always raised
    # TypeError.
    result=processor.token2json(xml)
    print(":vampire:",result)
    return result

def generateOutputJson(processor,model, input_image, task_start="<s_receipt>",task_end="</s_receipt>"):
    """Generate the tagged output for an image and parse it with `token2json`.

    Args:
        processor: Processor forwarded to `generateOutputXML` and used for
            `token2json`.
        model: Model forwarded to `generateOutputXML`.
        input_image: Image forwarded to `generateOutputXML`.
        task_start: Forwarded to `generateOutputXML`.
        task_end: Forwarded to `generateOutputXML`.

    Returns:
        The value returned by `processor.token2json` for the generated string.
    """
    parsed = processor.token2json(
        generateOutputXML(
            processor,
            model,
            input_image,
            task_start=task_start,
            task_end=task_end,
        )
    )
    print(":vampire:", parsed)
    return parsed

IMAGE_PATH = "./cropped_receipt/cropped_image.png"
OUTPUT_PATH = "./outputs/hf_donut.json"

# Make sure the output directory exists *before* running inference, so a
# missing folder cannot discard the result at the final write step.
output_dir = os.path.dirname(OUTPUT_PATH)
if output_dir:  # dirname is "" when OUTPUT_PATH is a bare filename
    os.makedirs(output_dir, exist_ok=True)

input_image = Image.open(IMAGE_PATH)

## generate json
invoice1_json=generateOutputJson(processor,model,input_image)
print(invoice1_json)
with open(OUTPUT_PATH, 'w') as file:
    json.dump(invoice1_json, file, indent=4)


# ## generate xml
# xml=generateOutputXML(processor,model,input_image)
# print(xml)

input pixel_values:  torch.Size([1, 3, 1280, 960])




<s_receipt><s_store_name> Walmart</s_store_name><s_store_addr></s_store_addr><s_telephone> (970)259-8755</s_telephone><s_date> 01/15/17 01/15/17</s_date><s_time> 14:26:03</s_time><s_subtotal> 36.35</s_subtotal><s_tax> 1.07</s_tax><s_total> 38.68</s_total><s_ignore></s_ignore><s_tips></s_tips><s_line_items><s_item_key></s_item_key><s_item_name> BEVERAGE</s_item_name><s_item_value> 2.00</s_item_value><s_item_quantity> 1</s_item_quantity><sep/><s_item_key></s_item_key><s_item_name> OSCRAMPCOM</s_item_name><s_item_value> 1.14</s_item_value><s_item_quantity> 1</s_item_quantity><sep/><s_item_key></s_item_key><s_item_name> STRUBYCO</s_item_name><s_item_value> 1.06</s_item_value><s_item_quantity> 1</s_item_quantity><sep/><s_item_key></s_item_key><s_item_name> CAMPRITO</s_item_name><s_item_value> 2.97</s_item_value><s_item_quantity> 1</s_item_quantity><sep/><s_item_key></s_item_key><s_item_name> KFTSINGLES</s_item_name><s_item_value> 1.67</s_item_value><s_item_quantity> 1</s_item_quantity><sep/