In [1]:
import html
import json
import os

import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont
import ipywidgets as widgets
from IPython.display import display, clear_output

from model import Qwen2VLProcessor
from qwen_vl_utils import process_vision_info  # adjust this import as needed


In [2]:

# --- Configuration ---
# Fall back to CPU so the notebook still runs where CUDA is unavailable.
device      = 'cuda' if torch.cuda.is_available() else 'cpu'
model_path  = "/home/syc/intern/wanshan/Qwen2VL-Resampler-Finetune/output/resampler_7b_retain_ratio_1"
# min_pixels  = 1344 * 28 * 28   # disabled: only max_pixels is passed to the processor below
# Keep in sync with the "max_pixels_1680" setting encoded in the result file names below.
max_pixels  = 1680 * 28 * 28
processor   = Qwen2VLProcessor.from_pretrained(model_path, max_pixels=max_pixels)
# Root for per-instance visualization folders (show_instance writes inst_XXX/ under it).
vis_base    = "./visualize_imgs"
os.makedirs(vis_base, exist_ok=True)

json_path = "/home/syc/intern/wanshan/llm/Qwen2VL_sim/screenspot_sim_qwen2vl-7b_max_pixels_1680-prune_layer-0-retain_ratio-0.9418-web.json"
# Load the SIM results with their precomputed masks (the "select_mask" field).
# JSON is UTF-8 by spec; don't depend on the platform's default encoding.
with open(json_path, "r", encoding="utf-8") as f:
    data = json.load(f)

# Viewer state: current position in `data`, and the widget images are rendered into.
index = 0
out = widgets.Output()


In [3]:
# Fields show_instance reads from every SIM record: fail fast on a malformed
# results file instead of crashing half-way through the viewer.
SIM_REQUIRED_KEYS = {"img_path", "text", "bbox", "pred", "select_mask"}
sim_bad_records = [i for i, rec in enumerate(data) if not SIM_REQUIRED_KEYS.issubset(rec)]
assert not sim_bad_records, (
    f"{len(sim_bad_records)} SIM records lack one of {sorted(SIM_REQUIRED_KEYS)}; "
    f"first indices: {sim_bad_records[:10]}"
)
data[0].keys()

dict_keys(['img_path', 'text', 'bbox', 'pred', 'matched', 'response', 'type', 'source', 'select_mask'])

In [4]:
naive_pred_json_path = "/home/syc/intern/wanshan/Qwen2-VL/agent_tasks/ScreenSpot/sim_prunelayer_0-04-25/screenspot_sim_qwen2vl-7b_max_pixels_1680-prune_layer-0-retain_ratio-1.0-web.json"

# "Naive" run predictions (retain_ratio-1.0 per the file name); records carry no select_mask.
with open(naive_pred_json_path, "r", encoding="utf-8") as f:
    naive_data = json.load(f)

# show_instance pairs naive_data[i] with data[i] purely by position, so both
# files must list the same samples in the same order. Compare image basenames
# (not full paths) in case the two runs saw the dataset under different mounts.
assert len(naive_data) == len(data), (
    f"length mismatch: naive={len(naive_data)} vs sim={len(data)}"
)
naive_misaligned = [
    i
    for i, (sim_rec, naive_rec) in enumerate(zip(data, naive_data))
    if (os.path.basename(sim_rec["img_path"]), sim_rec["text"])
    != (os.path.basename(naive_rec["img_path"]), naive_rec["text"])
]
assert not naive_misaligned, (
    f"{len(naive_misaligned)} naive/SIM records differ in image/text; "
    f"first indices: {naive_misaligned[:10]}"
)
naive_data[0]

{'img_path': '/data/data1/syc/intern/wanshan/datasets/ScreenSpot/screenspot_imgs/web_213f816e-8e80-4d13-970d-1347bbc7a2a8.png',
 'text': 'create a new project',
 'bbox': [0.906640625, 0.08958333333333333, 0.987890625, 0.13819444444444445],
 'pred': [0.95, 0.12],
 'matched': True,
 'response': '{"action_type": 4, "click_point": (0.95,0.12)}\n',
 'type': 'text',
 'source': 'gitlab'}

In [5]:
def _load_banner_font(size=25):
    """Return Noto Sans SC at `size`, falling back to Pillow's built-in font."""
    try:
        return ImageFont.truetype(
            "/usr/share/fonts/truetype/noto/NotoSansSC-Regular.otf", size=size
        )
    except OSError:  # font file missing/unreadable (IOError is an alias of OSError)
        pass
    try:
        # load_default(size=...) only exists in Pillow >= 10.1.
        return ImageFont.load_default(size=size)
    except TypeError:
        return ImageFont.load_default()


def _annotate_demo(img, instruction, bbox, naive_pred, sim_pred, font):
    """Return a new image: `img` below a white instruction banner, with overlays.

    Draws the ground-truth box `bbox` = (x0, y0, x1, y1) in green, the naive
    click `naive_pred` = (x, y) in blue and the SIM click `sim_pred` in red.
    All coordinates are fractions of the image width/height.
    """
    w, h = img.size

    # textbbox() can report a non-zero origin (side bearing / ascender offset);
    # keep it so the glyphs can be placed exactly below.
    tx0, ty0, tx1, ty1 = ImageDraw.Draw(img).textbbox((0, 0), instruction, font=font)
    text_w, text_h = tx1 - tx0, ty1 - ty0

    top_margin, bottom_margin = 5, 15
    banner_h = top_margin + text_h + bottom_margin

    canvas = Image.new("RGB", (w, h + banner_h), "white")
    canvas.paste(img, (0, banner_h))
    draw = ImageDraw.Draw(canvas)

    # Shift by the bbox origin so the visible text is horizontally centred and
    # starts exactly at top_margin, leaving the full bottom margin intact.
    draw.text(
        ((w - text_w) / 2 - tx0, top_margin - ty0), instruction, font=font, fill="black"
    )

    def to_abs(rx, ry):
        # relative [0, 1] image coords -> absolute pixel coords on the canvas
        return rx * w, banner_h + ry * h

    bx0, by0, bx1, by1 = bbox
    draw.rectangle([*to_abs(bx0, by0), *to_abs(bx1, by1)], outline="green", width=3)

    r = 8  # click-marker radius in pixels
    for (px, py), color in ((naive_pred, "blue"), (sim_pred, "red")):
        cx, cy = to_abs(px, py)
        draw.ellipse([cx - r, cy - r, cx + r, cy + r], fill=color)
    return canvas


def show_instance(idx):
    """Render instance `idx`, comparing naive vs. SIM predictions, into `out`.

    Runs `processor` with the SIM record's `select_mask` and `vis_dir` so that it
    writes `<vis_base>/inst_<idx>/demo.png`. It then adds an instruction banner,
    the ground-truth bbox (green), the naive click (blue) and the SIM click (red).
    The annotated image and its target path are stored in the globals
    `last_annotated_img` / `last_save_path` for the Save button.

    Args:
        idx: position in `data`; `naive_data[idx]` is assumed to be the same sample.
    """
    global last_annotated_img, last_save_path
    # Clear Save state first: if this render fails part-way, Save must not
    # silently write the previous instance's image.
    last_annotated_img, last_save_path = None, None

    inst = data[idx]
    naive_inst = naive_data[idx]
    inst_no_selectmask = {k: v for k, v in inst.items() if k != "select_mask"}

    # widgets.HTML renders HTML, not Markdown, and the records contain free text:
    # use real tags and escape the payload.
    label = (
        f"<b>Naive:</b> {html.escape(str(naive_inst))}<br>"
        f"<b>SIM:</b> {html.escape(str(inst_no_selectmask))}"
    )
    # The two result files are paired by position only; flag a mismatch rather
    # than silently overlaying predictions for a different sample.
    same_sample = (
        os.path.basename(naive_inst.get("img_path", "")) == os.path.basename(inst["img_path"])
        and naive_inst.get("text") == inst["text"]
    )
    if not same_sample:
        label = (
            "<b style='color:red;'>Warning: naive and SIM records differ "
            "(image/text).</b><br>" + label
        )
    sentence_label.value = label

    # 1) Run the processor for its side effect of writing vis_dir/demo.png.
    messages = [{
        "role": "user",
        "content": [
            {"type": "image", "image": inst["img_path"]},
            {"type": "text",  "text": inst["text"]},
        ],
    }]
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    mask_np = np.array(inst["select_mask"], dtype=bool)

    vis_dir = os.path.join(vis_base, f"inst_{idx:03d}")
    os.makedirs(vis_dir, exist_ok=True)
    img_path = os.path.join(vis_dir, "demo.png")
    # vis_base is reused across runs and result files: drop any demo.png from an
    # earlier run so the existence check below really means "fresh output".
    if os.path.exists(img_path):
        os.remove(img_path)
    # The returned batch is unused, so it is not moved to `device`.
    processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
        select_mask=mask_np,
        vis_dir=vis_dir,
    )

    # 2) Load demo.png, add the banner and overlays, and display the result.
    with out:
        clear_output(wait=True)
        if not os.path.exists(img_path):
            display(widgets.HTML(f"<b>Error:</b> no output in {vis_dir}"))
            return

        # Context manager closes the file handle (needed to re-delete it later).
        with Image.open(img_path) as src:
            img = src.convert("RGB")

        instruction = inst["text"]
        new_img = _annotate_demo(
            img, instruction, inst["bbox"], naive_inst["pred"], inst["pred"],
            _load_banner_font(),
        )

        last_annotated_img = new_img
        last_save_path = os.path.join(vis_dir, "demo_annotated.png")

        display(new_img)
        display(widgets.HTML(
            f"<b>Instance {idx+1}/{len(data)}:</b> {html.escape(instruction)}"
        ))

def on_next(_):
    """'Next' button callback: step forward one instance (stays on the last) and re-render."""
    global index
    last_idx = len(data) - 1
    if index + 1 <= last_idx:
        index = index + 1
    show_instance(index)

def on_prev(_):
    """'Previous' button callback: step back one instance (stays on the first) and re-render."""
    global index
    at_first = index <= 0
    if not at_first:
        index = index - 1
    show_instance(index)

def on_save(_):
    """'Save' button callback: write the last annotated image to its target path.

    Reads the globals `last_annotated_img` / `last_save_path` set by
    show_instance. They may not exist yet (or may be None) when no instance has
    rendered successfully; in that case a notice is shown instead of raising
    NameError or doing nothing silently.
    """
    img = globals().get("last_annotated_img")
    save_path = globals().get("last_save_path")
    with out:
        if img is None or not save_path:
            display(widgets.HTML(
                "<span style='color:orange;'><b>Nothing to save:</b> "
                "no annotated image has been rendered.</span>"
            ))
            return
        img.save(save_path)
        display(widgets.HTML(f"<span style='color:green;'><b>Saved:</b> {save_path}</span>"))

# --- Viewer UI: record details, navigation/save buttons, render area ---

# Save state read by on_save; show_instance replaces it after each render.
# Initialised here so Save works (as a no-op) before anything has rendered.
last_annotated_img = None
last_save_path = None

# Shows the raw naive and SIM records of the current instance.
sentence_label = widgets.HTML()

prev_btn = widgets.Button(description='Previous')
next_btn = widgets.Button(description='Next')
save_btn = widgets.Button(description='Save')
for button, handler in ((prev_btn, on_prev), (next_btn, on_next), (save_btn, on_save)):
    button.on_click(handler)

display(sentence_label, widgets.HBox([prev_btn, next_btn, save_btn]))
display(out)

# Render the instance at the current `index` (0 on a fresh kernel).
show_instance(index)

HTML(value='')

HBox(children=(Button(description='Previous', style=ButtonStyle()), Button(description='Next', style=ButtonSty…

Output()

In [None]:
# The bare line "16 22 34 35" was a SyntaxError that broke Restart & Run All;
# kept as data instead.
# NOTE(review): unclear whether these are 0-based `index` values or the
# 1-based "Instance k/N" numbers shown in the UI, and why they were noted — confirm.
NOTED_INSTANCE_NUMBERS = [16, 22, 34, 35]