In [1]:
import json
import os
import numpy as np
import torch
from PIL import Image, ImageDraw
import ipywidgets as widgets
from IPython.display import display, clear_output
from model import Qwen2VLProcessor
from qwen_vl_utils import process_vision_info  # adjust this import as needed


In [2]:

# --- Configuration ---
device      = 'cuda'
model_path  = "/home/syc/intern/wanshan/Qwen2VL-Resampler-Finetune/output/resampler_7b_retain_ratio_1"
min_pixels  = 1344 * 28 * 28
max_pixels  = 1680 * 28 * 28
processor   = Qwen2VLProcessor.from_pretrained(model_path, min_pixels=min_pixels, max_pixels=max_pixels)
vis_base    = "./visualize_imgs"
os.makedirs(vis_base, exist_ok=True)

# Load your precomputed masks JSON
with open("./data_with_masks-retain_ratio_0.8.json", "r") as f:
    data = json.load(f)

# State
index = 0
out = widgets.Output()


In [3]:
data[0]

{'img_path': '/data/data1/syc/intern/wanshan/datasets/ScreenSpot/screenspot_imgs/mobile_1ca5b944-293a-46a1-af95-eb35bc8a0b2a.png',
 'text': 'check the weather',
 'bbox': [0.09449152542372881,
  0.0475609756097561,
  0.34915254237288135,
  0.4091463414634146],
 'pred': [0.22, 0.1],
 'matched': True,
 'response': '{"action_type": 4, "click_point": (0.22,0.10)}\n',
 'type': 'icon',
 'source': 'ios',
 'select_mask': [True,
  True,
  True,
  True,
  True,
  True,
  False,
  False,
  False,
  False,
  False,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  False,
  False,
  False,
  False,
  True,
  True,
  False,
  False,
  False,
  False,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  False,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  False,
  True,
  True,
  True,
  True,
  True,
  True,
  True,
  False,
  True,
  False,
  True,
  True,
  True,
  True,


In [5]:
def show_instance(idx):
    global last_annotated_img, last_save_path
    inst = data[idx]
    # Prepare messages
    messages = [{
        "role": "user",
        "content": [
            {
                "type": "image",
                "image": inst["img_path"],
            },
            {"type": "text", "text": inst["text"]},
        ],
    }]
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    
    # Retrieve and format the mask
    mask_np = np.array(inst["select_mask"], dtype=bool)
    
    # Create per-instance directory
    vis_dir = os.path.join(vis_base, f"inst_{idx:03d}")
    os.makedirs(vis_dir, exist_ok=True)
    
    # Generate demo.png with select_mask
    processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
        select_mask=mask_np,
        vis_dir=vis_dir
    ).to(device)
    
    # Load and annotate the generated image
    img_path = os.path.join(vis_dir, "demo.png")
    with out:
        clear_output(wait=True)
        if os.path.exists(img_path):
            img = Image.open(img_path)
            draw = ImageDraw.Draw(img)
            w, h = img.size
            
            # Draw bounding box in green
            x1, y1, x2, y2 = inst['bbox']
            draw.rectangle([x1*w, y1*h, x2*w, y2*h], outline="green", width=3)
            
            # Draw predicted point as a red circle
            px, py = inst['pred']
            r = 5
            draw.ellipse([(px*w - r, py*h - r), (px*w + r, py*h + r)], fill="red")
            
            # Update state for saving
            last_annotated_img = img
            last_save_path = os.path.join(vis_dir, "demo_annotated.png")
            
            # Display annotated image and info
            display(img)
            display(widgets.HTML(f"<b>Instance {idx+1}/{len(data)}:</b> {inst['text']}"))
        else:
            display(widgets.HTML(f"<b>Error:</b> No output found in {vis_dir}"))

def on_next(_):
    global index
    if index < len(data) - 1:
        index += 1
    show_instance(index)

def on_prev(_):
    global index
    if index > 0:
        index -= 1
    show_instance(index)

def on_save(_):
    if last_annotated_img is not None and last_save_path:
        last_annotated_img.save(last_save_path)
        with out:
            display(widgets.HTML(f"<span style='color:green;'><b>Saved:</b> {last_save_path}</span>"))

# Create navigation and save buttons
prev_btn = widgets.Button(description='Previous')
next_btn = widgets.Button(description='Next')
save_btn = widgets.Button(description='Save')

prev_btn.on_click(on_prev)
next_btn.on_click(on_next)
save_btn.on_click(on_save)

# Display UI
display(widgets.HBox([prev_btn, next_btn, save_btn]))
display(out)

# Show the first instance
show_instance(index)

HBox(children=(Button(description='Previous', style=ButtonStyle()), Button(description='Next', style=ButtonSty…

Output()