In [None]:
import pdb;
import os
import re
import time
import torch
from transformers import BitsAndBytesConfig, Qwen2VLForConditionalGeneration

import PIL.Image as Image

# Turn off two of PyTorch's scaled-dot-product-attention backends for this
# process: the memory-efficient one and the flash one (both set to False).
# NOTE(review): the reason is not stated in this notebook — presumably these
# kernels are unavailable or misbehave on the target device; confirm.
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_flash_sdp(False)

In [None]:
from qwen_vl_utils import process_vision_info
from model.showui.processing_showui import ShowUIProcessor
from model.showui.modeling_showui import ShowUIForConditionalGeneration

# Not referenced anywhere else in this cell.
num_forward = 1

# Pixel budget for an input image, written as a multiple of 28*28.
min_pixels = 256 * 28 * 28
max_pixels = 1344 * 28 * 28

# Token-selection settings for the language model:
#   lm_skip_ratio - ratio of visual tokens to skip
#   lm_skip_layer - "[first,last,flag]" spec; "[1,28,1]" applies UI-guided token
#                   selection from layer 1 through layer 28 (28 is the last
#                   layer of Qwen2-VL)
# NOTE(review): skip_rand is not used in this cell, while ui_mask_rand below is
# hard-coded to False — confirm whether the two were meant to be linked.
skip_rand = False
lm_skip_ratio = 0.5
lm_skip_layer = "[1,28,1]"

# ShowUI preprocessor options:
#   ui_mask_pre   - prebuild the patch-selection indices in the preprocessor
#                   (more efficient than doing it inside the model layers)
#   ui_mask_ratio - percentage of patch tokens to select
#   ui_mask_rand  - random patch selection instead of uniform selection
processor = ShowUIProcessor.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct",
    min_pixels=min_pixels,
    max_pixels=max_pixels,
    ui_mask_pre=True,
    ui_mask_rand=False,
    ui_mask_ratio=lm_skip_ratio,
)

In [None]:
# Layer count used as the default length of the per-layer flag list.
qwen_layer_num = 28


def parse_layer_type(str_ranges, L=qwen_layer_num, default=0):
    """Expand "[start,end,value]" range specs into a per-layer flag list.

    Every "[start,end,value]" triple found in ``str_ranges`` assigns ``value``
    to layers ``start`` through ``end`` (1-based, inclusive). Triples are
    applied in order, so a later triple overrides an earlier one where they
    overlap. By convention 0 means "no layer token selection" and 1 means
    "with layer token selection".

    Args:
        str_ranges: String containing zero or more "[start,end,value]" triples
            of non-negative integers; whitespace around the numbers is allowed.
        L: Number of layers, i.e. the length of the returned list.
        default: Value for layers not covered by any triple.

    Returns:
        A list of exactly ``L`` entries.
    """
    result = [default] * L
    matches = re.findall(r'\[\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\]', str_ranges)
    for start, end, value in matches:
        # Convert to 0-based indices and clamp to the valid layer range.
        # Without the lower clamp a start of 0 becomes index -1, and the slice
        # assignment below would insert elements and grow the list beyond L.
        first = max(int(start) - 1, 0)
        last = min(int(end) - 1, L - 1)
        if first > last:
            # Empty or fully out-of-range span: nothing to set.
            continue
        result[first:last + 1] = [int(value)] * (last - first + 1)
    return result

# Expand the layer-range spec into one flag per decoder layer.
# A list that already has one entry per layer is the result of a previous run
# of this cell: stringifying and re-parsing it would match no
# "[start,end,value]" triple and silently reset every layer to 0, so such a
# list is kept as-is. This keeps the cell safe to re-run.
if not (isinstance(lm_skip_layer, list) and len(lm_skip_layer) == qwen_layer_num):
    if isinstance(lm_skip_layer, list):
        lm_skip_layer = str(lm_skip_layer)  # convert to string, e.g. [1, 28, 1] -> "[1, 28, 1]"
    lm_skip_layer = parse_layer_type(lm_skip_layer, qwen_layer_num)

# bitsandbytes quantization presets (8-bit, and 4-bit NF4 with double
# quantization and bfloat16 compute). They are only defined here; the
# from_pretrained call below does not pass either of them.
quantization = BitsAndBytesConfig(load_in_8bit=True)
nf4_config = BitsAndBytesConfig(
   load_in_4bit=True,
   bnb_4bit_quant_type="nf4",
   bnb_4bit_use_double_quant=True,
   bnb_4bit_compute_dtype=torch.bfloat16
)

# Pick the best available backend (cuda > mps > cpu) instead of hard-coding
# "mps", so the notebook also runs on machines without MPS and the model lands
# on the same device the inference cell selects for its inputs.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Load the weights in bfloat16 and place the whole model on `device`
# (the "" key of device_map addresses the root module).
model = ShowUIForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct",
    torch_dtype=torch.bfloat16,
    device_map={"": device},
    lm_skip_ratio=lm_skip_ratio,
    lm_skip_layer=lm_skip_layer,
)

In [None]:
# Screenshot to run on ('examples/chrome.png' is an alternative sample).
img_url = 'examples/kim.png'
# Directory handed to the processor for the UI-graph visualisation image.
vis_dir = 'examples'

# Single-turn conversation: one text instruction followed by one image entry
# that carries its own pixel budget.
messages = [{
    "role": "user",
    "content": [
        {"type": "text", "text": "Please describe this image."},
        {"type": "image", "image": img_url, "min_pixels": min_pixels, "max_pixels": max_pixels},
    ],
}]

# Render the conversation into a prompt string (no tokenisation yet) with the
# generation prompt appended.
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

# Pull the image / video entries out of the messages.
image_inputs, video_inputs = process_vision_info(messages)

# Build the padded PyTorch-tensor batch with UI-graph construction enabled.
inputs = processor(
    text=[text], images=image_inputs, videos=video_inputs,
    padding=True, return_tensors="pt",
    ui_graph=True, ui_graph_threshold=1,
    ui_graph_vis_dir=vis_dir,  # if provided, will be used to save the visualized img
)

# Move the inputs to the device the model's weights actually live on.
# Re-deriving the device here (cuda > mps > cpu) independently of how the model
# was loaded can disagree with the model's placement and make generate() fail
# with a device mismatch.
device = model.device
inputs = inputs.to(device)

# do_sample=True: the generated text varies from run to run.
generated_ids = model.generate(**inputs, max_new_tokens=4096, do_sample=True)

# Drop the prompt tokens so only the newly generated continuation is decoded.
generated_ids_trimmed = [
    out_ids[len(in_ids):]
    for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]

# Decode with the processor's default options and keep the single sample.
output_text = processor.batch_decode(generated_ids_trimmed)[0]

# (a) Report the first dimension of 'patch_assign' as the screenshot patch
# count, then show the original screenshot.
print(f"(a) Screenshot patch number: {inputs['patch_assign'].shape[0]}")
display(Image.open(img_url))

# (b) Report the first entry of 'patch_assign_len' as the UI component count,
# then show the UI-graph visualisation.
# Assumes the processor call wrote demo.png into vis_dir — TODO confirm.
print(f"(b) By applying UI-graph, UI Component number: {inputs['patch_assign_len'][0]}")
display(Image.open(f'{vis_dir}/demo.png'))

# (c) Print the decoded model output together with the skip ratio in use.
print(f"(c) Model output with skip ratio: {lm_skip_ratio}")
print(output_text)