# Load Model

In [None]:
import ast
import re

import torch
from PIL import Image, ImageDraw
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, BitsAndBytesConfig
from qwen_vl_utils import process_vision_info

class ShowUI:
    """Wrapper around a Qwen2-VL based ShowUI checkpoint for UI element grounding.

    Given a screenshot and a text description of a UI element, ``invoke`` asks
    the model for a click location and converts the model's relative
    ``[x, y]`` answer (the prompt requests values scaled 0-1) into pixel
    coordinates of the original screenshot.

    Call ``load_model`` once before ``invoke``.
    """

    # Grounding instruction prepended to every query (sent as the first text part).
    _SYSTEM_PROMPT = (
        "Based on the screenshot of the page, I give a text description and you give its corresponding location. "
        "The coordinate represents a clickable location [x, y] for an element, which is a relative coordinate on the screenshot, scaled from 0 to 1."
    )

    # A number such as "1", "0.5", ".5", "-3e-2"; used by the regex fallback in _parse_click.
    _NUMBER_PATTERN = r"[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?"

    def __init__(self, model_path: str, min_pixels: int = 256 * 28 * 28,
                 max_pixels: int = 1344 * 28 * 28):
        """Store configuration only; no weights are loaded here.

        Args:
            model_path: Path or id passed to ``from_pretrained`` for both the
                model and the processor.
            min_pixels: Lower resize bound, passed as ``min_pixels`` in the
                image message and as ``size["shortest_edge"]`` to the processor.
            max_pixels: Upper resize bound, passed as ``max_pixels`` in the
                image message and as ``size["longest_edge"]`` to the processor.
        """
        self.model_path = model_path
        self.model = None
        self.processor = None
        # 4-bit NF4 bitsandbytes config; only used when load_model(quantize=True).
        self.nf4_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        self.min_pixels = min_pixels
        self.max_pixels = max_pixels

    def load_model(self, quantize: bool = False):
        """Load the model weights and the processor from ``self.model_path``.

        Args:
            quantize: If True, load the model with ``self.nf4_config`` (4-bit
                NF4) instead of plain bfloat16 weights.
        """
        print(f"Loading model from {self.model_path}...")

        model_kwargs = {"torch_dtype": torch.bfloat16, "device_map": "auto"}
        if quantize:
            model_kwargs["quantization_config"] = self.nf4_config
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            self.model_path, **model_kwargs
        )

        print("Model loaded successfully.")

        print("Loading processor...")

        self.processor = AutoProcessor.from_pretrained(
            self.model_path,
            # NOTE(review): the Qwen2-VL image processor is expected to read these
            # keys as pixel-count bounds — confirm for the installed transformers.
            size={"shortest_edge": self.min_pixels, "longest_edge": self.max_pixels},
            use_fast=True,
        )

        print("Processor loaded successfully.")

    @staticmethod
    def _parse_click(output_text: str) -> "tuple[float, float]":
        """Extract the first ``(x, y)`` pair from the model's decoded answer.

        Accepts a bare Python literal such as ``[0.5, 0.3]`` or ``(0.5, 0.3)``
        (extra elements are ignored, keeping the first two). It also accepts
        text with such a bracketed pair embedded in it.

        Raises:
            ValueError: If no numeric pair can be found.
        """
        text = output_text.strip()
        try:
            # literal_eval only evaluates literals, so untrusted model output is safe here.
            parsed = ast.literal_eval(text)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            parsed = None

        if isinstance(parsed, (list, tuple)) and len(parsed) >= 2:
            first_two = parsed[:2]
            if all(isinstance(v, (int, float)) and not isinstance(v, bool) for v in first_two):
                return float(first_two[0]), float(first_two[1])

        # Fallback: the model may wrap the pair in extra words or punctuation.
        num = ShowUI._NUMBER_PATTERN
        match = re.search(rf"\[\s*({num})\s*,\s*({num})\s*[\],]", text)
        if match is None:
            raise ValueError(
                f"Could not parse click coordinates from model output: {output_text!r}"
            )
        return float(match.group(1)), float(match.group(2))

    def invoke(self, img_url: str, query: str, max_new_tokens: int = 128):
        """Ask the model where to click for ``query`` on the screenshot at ``img_url``.

        Args:
            img_url: Path of the screenshot; opened with ``PIL.Image.open`` and
                also handed to ``process_vision_info``.
            query: Text description of the target UI element.
            max_new_tokens: Generation budget for the model's answer.

        Returns:
            ``(x, y, image)``: the click point in pixel coordinates of the
            original screenshot, and the loaded PIL image.

        Raises:
            RuntimeError: If ``load_model`` has not been called yet.
            ValueError: If no coordinate pair can be parsed from the output.
        """
        if self.model is None or self.processor is None:
            raise RuntimeError("Model is not loaded; call load_model() first.")

        # Image.open is lazy and keeps the file open; load the pixels now and
        # let the context manager release the handle.
        with Image.open(img_url) as image:
            image.load()

        print(f"Image loaded from {img_url}, size: {image.size}")
        print(f"Query: {query}")

        print("Processing messages for model input...")

        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": self._SYSTEM_PROMPT},
                    {"type": "image", "image": img_url, "min_pixels": self.min_pixels, "max_pixels": self.max_pixels},
                    {"type": "text", "text": query},
                ],
            }
        ]

        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True,
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        # Follow the model's placement instead of hard-coding "cuda".
        inputs = inputs.to(self.model.device)
        print("Inputs prepared for model generation.")
        generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
        # Strip the prompt tokens so only the newly generated answer is decoded.
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        print("Model generation completed.")
        print("Decoding generated IDs to text...")
        output_text = self.processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]
        print(f"Output text: {output_text}")
        print("Decoding completed.")

        rel_x, rel_y = self._parse_click(output_text)
        # The prompt asks for coordinates relative to the screenshot, so scale
        # by the original image size.
        x, y = rel_x * image.width, rel_y * image.height

        return x, y, image

    def draw_point(self, image, x, y, radius=2, show: bool = True):
        """Draw a filled red dot centred at ``(x, y)`` on ``image`` (in place).

        Args:
            image: PIL image to draw on; it is modified in place.
            x, y: Centre of the dot in pixel coordinates.
            radius: Dot radius in pixels.
            show: If True, open the image with ``image.show()``.

        Returns:
            The same ``image`` object, so callers can save it.
        """
        print(f"Drawing point at ({x}, {y}) with radius {radius} on the image.")
        draw = ImageDraw.Draw(image)
        draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill='red', outline='red')
        if show:
            image.show()
        return image


In [None]:
import torch

# Quick sanity check: does PyTorch see a CUDA device?
has_cuda = torch.cuda.is_available()
print(has_cuda)

In [None]:
# Local checkpoint directory handed to ShowUI / from_pretrained.
model_path = "D:/Project/showui-2b"

showui = ShowUI(model_path=model_path)
# Loads both the model weights and the processor; this is the slow step.
showui.load_model()

# GUI Navigation
## Ground a query on a screenshot (the system prompt is built inside `ShowUI.invoke`).

In [None]:
# Screenshot to search and the element description ("重置" means "Reset").
img_url = "D:/Project/my_dataset/unlabel_images/image0.png"
query = "重置"

# invoke returns pixel coordinates plus the loaded image for visualisation.
x, y, image = showui.invoke(img_url=img_url, query=query)
print("Click coordinates: ({}, {})".format(x, y))
showui.draw_point(image=image, x=x, y=y)