<a href="https://colab.research.google.com/github/ivelin/donut_ui_refexp/blob/main/Donut_UI_RefExp_Gradio_Inference_Playground.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Clone the Hugging Face Space (app code and requirements.txt used below);
# skip when the folder already exists so the cell can be re-run safely.
![ -d ui-refexp ] || git clone https://huggingface.co/spaces/ivelin/ui-refexp/


Cloning into 'ui-refexp'...
remote: Enumerating objects: 172, done.[K
remote: Counting objects: 100% (172/172), done.[K
remote: Compressing objects: 100% (170/170), done.[K
remote: Total 172 (delta 101), reused 0 (delta 0), pack-reused 0[K
Receiving objects: 100% (172/172), 464.20 KiB | 3.36 MiB/s, done.
Resolving deltas: 100% (101/101), done.


In [2]:
# Install the Space's dependencies into the running kernel's environment.
# `%pip` targets the kernel's interpreter, unlike a bare `pip3` in a subshell.
%pip install -r ui-refexp/requirements.txt

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/huggingface/transformers.git (from -r requirements.txt (line 2))
  Cloning https://github.com/huggingface/transformers.git to /tmp/pip-req-build-amoc1c9p
  Running command git clone --filter=blob:none --quiet https://github.com/huggingface/transformers.git /tmp/pip-req-build-amoc1c9p
  Resolved https://github.com/huggingface/transformers.git to commit 1eda4a410298d57156d44bfc39a6001a72554412
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting sentencepiece
  Downloading sentencepiece-0.1.97-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m28.0 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloadi

In [3]:
# Stay on the Gradio 3 line this notebook was last run with (3.16.2, see the
# recorded output); `%pip` installs into the kernel's own environment.
%pip install "gradio>=3.16.2,<4"


Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting gradio
  Downloading gradio-3.16.2-py3-none-any.whl (14.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m14.2/14.2 MB[0m [31m83.7 MB/s[0m eta [36m0:00:00[0m
Collecting fastapi
  Downloading fastapi-0.89.1-py3-none-any.whl (55 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m55.8/55.8 KB[0m [31m6.2 MB/s[0m eta [36m0:00:00[0m
Collecting aiofiles
  Downloading aiofiles-22.1.0-py3-none-any.whl (14 kB)
Collecting httpx
  Downloading httpx-0.23.3-py3-none-any.whl (71 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m71.5/71.5 KB[0m [31m8.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting pycryptodome
  Downloading pycryptodome-3.16.0-cp35-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (2.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.3/2.3 MB[0m [

In [5]:
import re
import gradio as gr
from PIL import Image, ImageDraw
import math
import torch
import html
from transformers import DonutProcessor, VisionEncoderDecoderModel

pretrained_repo_name = "ivelin/donut-refexp-combined-v1"
print(f"Loading model checkpoint: {pretrained_repo_name}")

# The processor (used below for image preprocessing and tokenization) and the
# encoder-decoder weights are both loaded from the same fine-tuned checkpoint.
processor, model = (
    DonutProcessor.from_pretrained(pretrained_repo_name),
    VisionEncoderDecoderModel.from_pretrained(pretrained_repo_name),
)

# Prefer the GPU when one is visible to torch; otherwise run on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
model.to(device)


def coerce_coord(value, default: float) -> float:
    """Convert one predicted bounding-box coordinate to a finite float.

    The decoder emits coordinates as free text, so a value may be missing
    (None), non-numeric text, a nested structure (list/dict), or a non-finite
    number such as "nan". All of those fall back to `default`.
    """
    try:
        coord = float(value)
    except (TypeError, ValueError):
        return default
    # math.floor() on nan/inf raises later, so reject them here
    return coord if math.isfinite(coord) else default


def parse_bbox(raw_bbox: dict) -> dict:
    """Normalize a predicted `target_bounding_box` dict to float coordinates.

    Missing or invalid coordinates default to the full image extent
    (xmin=ymin=0, xmax=ymax=1).
    """
    return {
        "xmin": coerce_coord(raw_bbox.get("xmin"), 0.0),
        "ymin": coerce_coord(raw_bbox.get("ymin"), 0.0),
        "xmax": coerce_coord(raw_bbox.get("xmax"), 1.0),
        "ymax": coerce_coord(raw_bbox.get("ymax"), 1.0),
    }


def bbox_to_pixel_box(bbox: dict, width: int, height: int) -> tuple:
    """Scale relative bbox coordinates to integer pixels as (x0, y0, x1, y1).

    Each axis is sorted so x0 <= x1 and y0 <= y1 even when the model predicts
    swapped corners. Recent Pillow versions raise ValueError in
    ImageDraw.rectangle when the second corner precedes the first.
    """
    x0, x1 = sorted((math.floor(width * bbox["xmin"]),
                     math.floor(width * bbox["xmax"])))
    y0, y1 = sorted((math.floor(height * bbox["ymin"]),
                     math.floor(height * bbox["ymax"])))
    return x0, y0, x1, y1


def process_refexp(image: "Image.Image", prompt: str):
    """Locate the UI element described by `prompt` and outline it on `image`.

    Runs the Donut RefExp model on the screenshot and parses the predicted
    `target_bounding_box` from the decoder output. Its coordinates are
    fractions of the image size. The box is then drawn on a copy of the image.

    Args:
        image: PIL screenshot to search (may be None if nothing was uploaded).
        prompt: referring expression. Only the first 80 characters are used,
            lowercased.

    Returns:
        (annotated_image, bbox). `bbox` holds float xmin/ymin/xmax/ymax plus
        the post-processed "decoder output sequence". When the model predicts
        no structured bounding box, the unmodified image is returned with an
        all-zero bbox.

    Raises:
        ValueError: if no image was provided.
    """
    if image is None:
        raise ValueError("Please provide an image.")

    print(f"(image, prompt): {image}, {prompt}")

    # trim prompt to 80 characters and normalize to lowercase
    prompt = (prompt or "")[:80].lower()

    # prepare encoder inputs
    pixel_values = processor(image, return_tensors="pt").pixel_values

    # prepare decoder inputs: the task prompt is passed as the decoder prefix,
    # so generation continues right after <s_target_bounding_box>
    task_prompt = "<s_refexp><s_prompt>{user_input}</s_prompt><s_target_bounding_box>"
    prompt = task_prompt.replace("{user_input}", prompt)
    decoder_input_ids = processor.tokenizer(
        prompt, add_special_tokens=False, return_tensors="pt").input_ids

    # generate answer (single beam; the unknown token is banned from output)
    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    # postprocess
    sequence = processor.batch_decode(outputs.sequences)[0]
    print(fr"predicted decoder sequence: {html.escape(sequence)}")
    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(
        processor.tokenizer.pad_token, "")
    # remove first task start token
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()
    print(
        fr"predicted decoder sequence before token2json: {html.escape(sequence)}")
    seqjson = processor.token2json(sequence)

    # safeguard in case the predicted sequence has no structured
    # target_bounding_box (missing, or bare text with no coordinate fields)
    raw_bbox = seqjson.get("target_bounding_box")
    if not isinstance(raw_bbox, dict):
        print(
            f"token2bbox seq has no predicted target_bounding_box, seq:{sequence}")
        bbox = {"xmin": 0, "ymin": 0, "xmax": 0, "ymax": 0,
                "decoder output sequence": sequence}
        # the interface declares two outputs (image, json); return both
        return image, bbox

    print(f"predicted bounding box with text coordinates: {raw_bbox}")
    # replace str with float coords, tolerating missing/invalid values
    bbox = parse_bbox(raw_bbox)
    bbox["decoder output sequence"] = sequence
    print(f"predicted bounding box with float coordinates: {bbox}")

    print(f"image object: {image}")
    print(f"image size: {image.size}")
    width, height = image.size
    print(f"image width, height: {width, height}")
    print(f"processed prompt: {prompt}")

    # convert relative coordinates to pixel values
    xmin, ymin, xmax, ymax = bbox_to_pixel_box(bbox, width, height)
    print(
        f"to image pixel values: xmin, ymin, xmax, ymax: {xmin, ymin, xmax, ymax}")

    shape = [(xmin, ymin), (xmax, ymax)]

    # draw bbox rectangle on a copy so the caller's image is left untouched
    annotated = image.copy()
    draw = ImageDraw.Draw(annotated)
    draw.rectangle(shape, outline="green", width=5)
    draw.rectangle(shape, outline="white", width=2)

    return annotated, bbox


title = "Demo: Donut 🍩 for UI RefExp (by GuardianUI)"
description = "Gradio Demo for Donut RefExp task, an instance of `VisionEncoderDecoderModel` fine-tuned on [UIBert RefExp](https://huggingface.co/datasets/ivelin/ui_refexp_saved) Dataset (UI Referring Expression). To use it, simply upload your image and type a prompt and click 'submit', or click one of the examples to load them. See the model training <a href='https://colab.research.google.com/github/ivelin/donut_ui_refexp/blob/main/Fine_tune_Donut_on_UI_RefExp.ipynb' target='_parent'>Colab Notebook</a> for this space. Read more at the links below."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2111.15664' target='_blank'>Donut: OCR-free Document Understanding Transformer</a> | <a href='https://github.com/clovaai/donut' target='_blank'>Github Repo</a></p>"

# The example screenshots are expected in the Space repo cloned at the top of
# this notebook (TODO confirm the files are present there). `!cd ui-refexp/`
# only changes directory for that single shell command, so the notebook's
# working directory is the parent folder and each path needs this prefix.
EXAMPLES_DIR = "ui-refexp"
examples = [[f"{EXAMPLES_DIR}/{image_file}", example_prompt]
            for image_file, example_prompt in [
                ("example_1.jpg", "select the setting icon from top right corner"),
                ("example_1.jpg", "click on down arrow beside the entertainment"),
                ("example_1.jpg", "select the down arrow button beside lifestyle"),
                ("example_1.jpg", "click on the image beside the option traffic"),
                ("example_2.jpg", "enter the text field next to the name"),
                ("example_2.jpg", "click on green color button"),
                ("example_2.jpg", "click on text which is beside call now"),
                ("example_2.jpg", "click on more button"),
                ("example_3.jpg", "select the third row first image"),
                ("example_3.jpg", "click the tick mark on the first image"),
                ("example_3.jpg", "select the ninth image"),
                ("example_3.jpg", "select the add icon"),
                ("example_3.jpg", "click the first image"),
                ("val-image-1.jpg", "select calendar option"),
                ("val-image-1.jpg", "select photos&videos option"),
                ("val-image-2.jpg", "click on change store"),
                ("val-image-2.jpg", "click on shop menu at the bottom"),
                ("val-image-3.jpg", "click on image above short meow"),
                ("val-image-3.jpg", "go to cat sounds"),
            ]]

demo = gr.Interface(fn=process_refexp,
                    inputs=[gr.Image(type="pil"), "text"],
                    outputs=[gr.Image(type="pil"), "json"],
                    title=title,
                    description=description,
                    article=article,
                    examples=examples,
                    # run an example only when it is clicked; precomputing
                    # (caching) every example at startup takes too long
                    cache_examples=False
                    )

demo.launch(share=True)

Loading model checkpoint: ivelin/donut-refexp-combined-v1
Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
Running on public URL: https://57ea5661-fa43-409d.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades (NEW!), check out Spaces: https://huggingface.co/spaces


