In [1]:
import json
import ipywidgets as widgets
from IPython.display import display, Image
import os
from Qwen2VL_uigraph.model.processing_qwen2_vl import Qwen2VLProcessor
from qwen_vl_utils import process_vision_info
import ast  # To safely evaluate JSON-like strings
from PIL import Image, ImageDraw
import io

In [2]:
# Function to draw click point on image
def draw_click_point(img_path, click_x, click_y, bbox, pred=False):
    """
    img_path: str
    click_x, click_y : ralative coordinate (0-1)
    bbox, list : [x_low, y_low, x_high, y_high] (0-1)
    pred : model prediction | ground truth
        if ground truth, no output visualized image
    """
    if os.path.exists(img_path):
        img = Image.open(img_path)
        w, h = img.size  # Get image dimensions
        
        # Convert relative to absolute coordinates
        abs_x = int(click_x * w)
        abs_y = int(click_y * h)
        
        # Draw the dot
        draw = ImageDraw.Draw(img)
        dot_radius = 10  # Adjust dot size if needed
        draw.ellipse((abs_x - dot_radius, abs_y - dot_radius, abs_x + dot_radius, abs_y + dot_radius), fill="red")

        # Draw the bounding box (if exists)
        if bbox: # [0.278, 0.64, 0.528, 0.688]
            bbox_x_top_left = round(bbox[0] * w)
            bbox_y_top_left = round(bbox[1] * h)
            bbox_x_bot_right = round(bbox[2] * w)
            bbox_y_bot_right = round(bbox[3] * h)
            
            bbox_coords = [(bbox_x_top_left, bbox_y_top_left), (bbox_x_bot_right, bbox_y_bot_right)]
            draw.rectangle(bbox_coords, outline="blue", width=3)  # Blue bbox

        # Save the modified image temporarily
        if not pred:
            temp_img_path = "./visualize_imgs/image.png"
        else:
            temp_img_path = "./visualize_imgs/pred_image.png"

        img.save(temp_img_path)
        # Return the new image path
        return temp_img_path
    else:
        return None

In [3]:
# Load JSON file
naive_json_path = "/home/syc/intern/wanshan/Thesis_result/ScreenSpot/UIGraph/all/screenspot_qwen2vl-7b_dropratio-0.0_web.json"

with open(naive_json_path, "r", encoding="utf-8") as f:
    naive_data = json.load(f)



In [32]:
uigraph_json_path = "/home/syc/intern/wanshan/Thesis_result/ScreenSpot/Prune_layer_2/screenspot_qwen2vl-7b_uigraph_dropratio-0.2_web-prune-layer_2.json"
with open(uigraph_json_path, "r", encoding="utf-8") as f:
    uigraph_data = json.load(f)
# Filter instances where Ele_match is False


In [33]:
uigraph_data[0]

{'img_path': '/data/data1/syc/intern/wanshan/datasets/ScreenSpot/screenspot_imgs/mobile_1ca5b944-293a-46a1-af95-eb35bc8a0b2a.png',
 'text': 'check the weather',
 'bbox': [0.09449152542372881,
  0.0475609756097561,
  0.34915254237288135,
  0.4091463414634146],
 'pred': [0.19, 0.1],
 'matched': True,
 'response': '{"action_type": 4, "click_point": (0.19,0.10)}\n',
 'type': 'icon',
 'source': 'ios'}

In [34]:
naive_data[0]['bbox']


[0.09449152542372881,
 0.0475609756097561,
 0.34915254237288135,
 0.4091463414634146]

In [35]:
mismatched_idxs  = []
for sample_idx, sample in enumerate(naive_data):
    if 'matched' in sample and 'matched' in uigraph_data[sample_idx]:
        if sample['matched'] == True and uigraph_data[sample_idx]['matched'] == False and 'web' in sample['img_path']:
            mismatched_idxs.append(sample_idx)
len(mismatched_idxs)

26

In [36]:
# for idx in mismatched_idxs:
#     uigraph_data[idx]['pred'][1] += 0.03

In [37]:
for sample_idx in mismatched_idxs:
    if naive_data[sample_idx]['img_path'] != uigraph_data[sample_idx]['img_path']:
        print('mismatched image_path', naive_data[sample_idx]['img_path'])

In [38]:
model_path = "/data/data1/syc/intern/wanshan/models/Qwen2-VL-2B-Instruct"
# model_path = "/data/data1/syc/intern/wanshan/models/showlab/ShowUI-2B_edited"

# min_pixel = 1344*28*28
# max_pixel = 1680*28*28
# 1. Screenshot -> Graph
uigraph_train = True  # Enable ui graph during training
uigraph_test = True  # Enable ui graph during inference
uigraph_diff = 1  # Pixel difference used for constructing ui graph
uigraph_rand = False  # Enable random graph construction
# 2. Graph -> Mask
uimask_pre = True  # Prebuild patch selection mask in the preprocessor (not in model layers) for efficiency
uimask_ratio = 0.4  # Specify the percentage of patch tokens to skip per component
uimask_rand = False  # Enable random token selection instead of uniform selection


processor = Qwen2VLProcessor.from_pretrained(
    model_path,
    # min_pixels= min_pixel,
    # max_pixels = max_pixel,
    uigraph_train=uigraph_train,
    uigraph_test=uigraph_test,
    uigraph_diff=uigraph_diff,
    uigraph_rand=uigraph_rand,
    uimask_pre=True,
    uimask_ratio=uimask_ratio,
    uimask_rand=uimask_rand,
)

In [39]:
def load_visualize(image_path):
    messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "image": image_path,
                # "min_pixels": min_pixel,
                # "max_pixels": max_pixel,
            },
            {"type": "text", "text": "Describe this image."},
        ],
    }
    ]
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, _ = process_vision_info(messages)

    inputs = processor(
        text=[text],
        images=image_inputs,
        padding=True,
        return_tensors="pt",
        vis_dir="./visualize_imgs" # this folder to save visualization 
    )
    with open("./visualize_imgs/demo.png", "rb") as f:
        return f.read()

In [40]:
"""
{'img_path': '/data/data1/syc/intern/wanshan/datasets/ScreenSpot/screenspot_imgs/mobile_1ca5b944-293a-46a1-af95-eb35bc8a0b2a.png',
 'text': 'check the weather',
 'bbox': [0.09449152542372881,
  0.0475609756097561,
  0.34915254237288135,
  0.4091463414634146],
 'pred': [0.2, 0.17],
 'matched': True,
 'response': '{"action_type": 4, "click_point": (0.20,0.17)}\n',
 'type': 'icon',
 'source': 'ios'}
"""
# Initialize index
sample_idx = 0

# Function to update display
def update_display():
    global sample_idx
    
    if sample_idx < len(mismatched_idxs):
        idx = mismatched_idxs[sample_idx]
        naive_sample = naive_data[idx]
        uigraph_sample = uigraph_data[idx]
        # episode_ref = ground_truth_bbox[episode_index]

        sentence_label.value = f"**Naive Action(s):** {naive_sample}"
        sentece_uigraph_label.value = f"**UI Graph Action(s):** {uigraph_sample}"
        # Extract click coordinates
        try:
            click_x, click_y = naive_sample.get("pred", (0, 0))
            uigraph_click_x, uigraph_click_y = uigraph_sample.get("pred", (0, 0))
            # uigraph_click_y += 0.03
        except Exception as e:
            click_x, click_y = 0, 0  # Default to top-left corner on error

        # get bbox from groundtruth
        bbox = naive_sample['bbox']
        
        # Load and update image
        img_path = naive_sample["img_path"]
        
        _ = load_visualize(img_path)
        naive_image_path = draw_click_point(img_path, click_x, click_y, bbox)
        uigraph_image_path = draw_click_point("./visualize_imgs/demo.png", uigraph_click_x, uigraph_click_y, bbox, pred=True)
        if os.path.exists(img_path):
            with open(naive_image_path, "rb") as f:
                image_widget.value = f.read()
            with open(uigraph_image_path, "rb") as f:
                image_pred.value = f.read()
            # image_patch.value = load_visualize(img_path)
        else:
            sentence_label.value += f"\n(Error: Image not found at {img_path})"
    else:
        sentence_label.value = ""
        image_widget.value = b""

# Next button function
def next_step(_):
    global sample_idx
    
    # Move to the next item in the navigation sequence
    if sample_idx < len(mismatched_idxs) - 1:
        sample_idx += 1
        update_display()


    update_display()

# Widgets
# instruction_label = widgets.HTML()
sentence_label = widgets.HTML()
sentece_uigraph_label = widgets.HTML()
image_widget = widgets.Image(format='png', width=600)  # Set Image Size
image_pred = widgets.Image(format='png', width=600)
# image_patch = widgets.Image(format='png', width=600)  # Set Image Size

# Layout to show images side by side
image_box = widgets.HBox([image_widget, image_pred])  # Side-by-side


next_button = widgets.Button(description="Next Step")
next_button.on_click(next_step)

# Initial display
update_display()

# Layout
# display(sentence_label, sentece_uigraph_label,  image_box, image_patch, next_button)
display(sentence_label, sentece_uigraph_label,  image_box, next_button)


HTML(value='**Naive Action(s):** {\'img_path\': \'/data/data1/syc/intern/wanshan/datasets/ScreenSpot/screenspo…

HTML(value='**UI Graph Action(s):** {\'img_path\': \'/data/data1/syc/intern/wanshan/datasets/ScreenSpot/screen…

HBox(children=(Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\n\x00\x00\x00\x05\xa0\x08\x06\x00\x00…

Button(description='Next Step', style=ButtonStyle())