In [5]:
from datasets import Dataset
import json
from PIL import Image
import os
def load_jsonl_dataset(jsonl_path, limit=None):
    """Load a JSON Lines file into a Hugging Face ``datasets.Dataset``.

    Each non-blank line must hold one JSON object. Blank or whitespace-only
    lines (e.g. stray empty lines at the end of the file) are skipped instead
    of crashing ``json.loads``.

    Args:
        jsonl_path: path to the ``.jsonl`` file, read as UTF-8.
        limit: if given, keep only the first ``limit`` records (handy for
            quick debugging runs). ``None`` keeps every record.

    Returns:
        A ``Dataset`` built with ``Dataset.from_list`` from the parsed records.
    """
    with open(jsonl_path, "r", encoding="utf-8") as f:
        data = [json.loads(line) for line in f if line.strip()]
    if limit is not None:
        data = data[:limit]
    return Dataset.from_list(data)

def load_json_dataset(json_path):
    """Decode a UTF-8 JSON file and hand back the resulting Python object.

    No conversion to a ``Dataset`` happens here; the caller receives exactly
    what ``json.load`` produces (for the annotation files used below, a list
    of per-sample dicts).
    """
    with open(json_path, encoding="utf-8") as fp:
        return json.load(fp)


from PIL import Image, ImageDraw, ImageFont

def draw_bboxes(
    img,
    bboxes: list,
    labels: list | None = None,
    colors: list | None = None,
    save_path: str | None = None,
    line_width: int = 3,
    font_path: str | None = None,
    font_size: int = 16,
):
    """Draw bounding boxes (and optional text labels) onto ``img`` in place.

    Args:
        img: PIL image to draw on. It is modified in place and also returned.
        bboxes: list of ``(xmin, ymin, xmax, ymax)`` boxes in pixel coordinates.
        labels: optional text for each box (e.g. class name or confidence).
            Non-string labels are converted with ``str()``. Boxes beyond
            ``len(labels)`` are drawn without a label.
        colors: optional per-box colors (any Pillow color string or RGB tuple).
            If shorter than ``bboxes`` the colors are cycled; if None or empty,
            a built-in palette is cycled.
        save_path: if given, the image is saved there and the path is printed;
            otherwise nothing is written to disk.
        line_width: box outline width in pixels.
        font_path: TrueType font file for labels. Without it Pillow's default
            bitmap font is used, which may not support Chinese text.
        font_size: label font size (only applied when ``font_path`` is given).

    Returns:
        The same ``img`` object, with the boxes drawn on it.
    """
    draw = ImageDraw.Draw(img)

    # Cycle through the given colors (or a default palette) so a short or
    # empty color list never causes an IndexError.
    default_palette = ["red", "lime", "blue", "yellow", "cyan", "magenta", "orange"]
    palette = list(colors) if colors else default_palette

    # Fonts are only needed when there is at least one label to render.
    font = None
    if labels:
        if font_path:
            font = ImageFont.truetype(font_path, font_size)
        else:
            # Some Pillow versions' default font lacks CJK glyphs; pass
            # font_path explicitly for Chinese labels.
            font = ImageFont.load_default()

    for i, (xmin, ymin, xmax, ymax) in enumerate(bboxes):
        color = palette[i % len(palette)]
        draw.rectangle([xmin, ymin, xmax, ymax], outline=color, width=line_width)

        if font is None or i >= len(labels):
            continue

        text = str(labels[i])
        text_width = draw.textlength(text, font=font)
        _, text_top, _, text_bottom = font.getbbox(text)
        text_height = text_bottom - text_top

        # Put the label just above the box; if that would run past the top
        # edge of the image, place it just inside the box instead.
        label_top = ymin - text_height - 4
        if label_top < 0:
            label_top = ymin

        # Filled background behind the label text.
        draw.rectangle(
            [xmin, label_top, xmin + text_width + 6, label_top + text_height + 4],
            fill=color,
        )
        draw.text((xmin + 3, label_top + 2), text, fill="black", font=font)

    if save_path:
        img.save(save_path)
        print(f"结果已保存到: {save_path}")
    return img

def draw_line(
    img,
    pt1: tuple[int, int],
    pt2: tuple[int, int],
    color: str | tuple[int, int, int] = "red",
    width: int = 3,
    save_path: str | None = None,
    canvas_size: tuple[int, int] | None = None,
    bg_color: str | tuple[int, int, int] = "white",
) -> Image.Image:
    """Draw a straight line segment between two points.

    Args:
        img: PIL image to draw on; it is modified in place. If ``None``, a new
            RGB canvas of ``canvas_size`` filled with ``bg_color`` is created.
        pt1, pt2: ``(x, y)`` endpoints of the segment, in pixels.
        color: line color; any Pillow color name or ``(R, G, B)`` tuple.
        width: line width in pixels.
        save_path: if given, the image is saved there and the path is printed;
            otherwise the image is only returned.
        canvas_size: ``(width, height)`` of the blank canvas; required when
            ``img`` is ``None``, ignored otherwise.
        bg_color: background color of a newly created canvas.

    Returns:
        The image that was drawn on (``img`` itself, or the new canvas).

    Raises:
        ValueError: if ``img`` is ``None`` and ``canvas_size`` is not given.
    """
    if img is None:
        if canvas_size is None:
            raise ValueError("canvas_size is required when img is None")
        img = Image.new("RGB", canvas_size, bg_color)

    draw = ImageDraw.Draw(img)
    draw.line([pt1, pt2], fill=color, width=width)

    if save_path:
        img.save(save_path)
        print(f"已保存到 {save_path}")

    return img


def abstract_visual_token_single_input_images_preprocess_function(sample, dataset_root):
    """Convert one raw sample into a chat-style ``conversations`` list.

    The dataset format is selected by a substring of ``dataset_root``
    ("PixelReasoner", "CoM", then "CoF", checked in that order). The result
    starts with a user turn holding the input image and the question,
    followed by one assistant turn per reasoning step. Each assistant turn has
    a text item and, for steps that yield a new view (see the per-dataset
    helpers), an image item.

    Intermediate images are also saved under ``./debug_output/`` for manual
    inspection, and progress is printed.

    Args:
        sample: one raw record from the dataset's annotation file.
        dataset_root: dataset directory; image paths in ``sample`` are
            resolved relative to it.

    Returns:
        The conversation list, or ``None`` when the sample is unusable (a CoM
        grounding step without boxes, or a CoF sample with fewer than two
        images).

    Raises:
        ValueError: if ``dataset_root`` matches none of the supported datasets.
    """
    if "PixelReasoner" in dataset_root:
        return _preprocess_pixelreasoner(sample, dataset_root)
    if "CoM" in dataset_root:
        return _preprocess_com(sample, dataset_root)
    if "CoF" in dataset_root:
        return _preprocess_cof(sample, dataset_root)
    raise ValueError(
        f"Unsupported dataset_root (expected PixelReasoner, CoM or CoF): {dataset_root}"
    )


def _user_turn(image, question):
    """Build the opening user turn holding the input image and question text."""
    return {
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": question},
        ],
    }


def _ensure_debug_dir(path):
    """Create a debug-output directory if missing so debug saves cannot crash."""
    os.makedirs(path, exist_ok=True)


def _preprocess_pixelreasoner(sample, dataset_root):
    """PixelReasoner: "crop" steps carry a bbox normalized to [0, 1].

    A cropped view is attached only to crop steps whose response text
    contains ``<abs_vis_token>``; every step becomes an assistant turn.
    """
    _ensure_debug_dir("./debug_output")
    qid = sample["qid"]
    image = Image.open(os.path.join(dataset_root, sample["image"])).convert("RGB")
    width, height = image.size
    image.save(f"./debug_output/step_image_{qid}.jpg")

    # Built only after ``image`` exists (previously it was referenced first).
    conversations = [_user_turn(image, sample["question"])]
    print(sample["question"])

    for idx, step in enumerate(sample["response_steps"]):
        step_content = [{"type": "text", "text": step["response_str"]}]
        print(step["response_str"])
        manipulation = step["manipulation"]
        if (
            manipulation
            and manipulation["type"] == "crop"
            and "<abs_vis_token>" in step["response_str"]
        ):
            bbox_norm = manipulation["parameters"]
            print(bbox_norm)
            box = (
                int(bbox_norm[0] * width),
                int(bbox_norm[1] * height),
                int(bbox_norm[2] * width),
                int(bbox_norm[3] * height),
            )
            cropped = image.crop(box)
            step_content.append({"type": "image", "image": cropped})
            cropped.save(f"./debug_output/step_image_{qid}_{idx}.jpg")
        # Every step is kept (as in the CoM branch), not only crop steps.
        conversations.append({"role": "assistant", "content": step_content})
    return conversations


def _preprocess_com(sample, dataset_root):
    """CoM: steps may crop/zoom in, draw grounding boxes, or draw guide lines.

    Manipulations are applied cumulatively to a running "current" image.
    Only crop_and_zoomin attaches the new view to the step; grounding/line
    results are saved to the debug directory but not attached.
    NOTE(review): confirm grounding/line steps should stay text-only.
    """
    _ensure_debug_dir("./debug_output/CoM")
    qid = sample["qid"]
    image = Image.open(os.path.join(dataset_root, sample["image"])).convert("RGB")
    conversations = [_user_turn(image, sample["question"])]

    print(qid, "[Question]", sample["question"], "[Answer]", sample["answer"])
    for i, step in enumerate(sample["response_steps"]):
        step_content = [{"type": "text", "text": step["response_str"]}]
        print("[Response]", step["response_str"])
        manipulation = step["manipulation"]
        if manipulation:
            kind = manipulation["type"]
            print(i, kind)
            if kind == "crop_and_zoomin":
                x_min, y_min, x_max, y_max = manipulation["parameters"]
                image = image.crop((x_min, y_min, x_max, y_max))
                step_content.append({"type": "image", "image": image})

            elif kind == "grounding":
                bboxes = manipulation["parameters"]
                if not bboxes:
                    return None
                print(bboxes)
                if not isinstance(bboxes[0], list):
                    bboxes = [bboxes]
                # draw_bboxes mutates its input; draw on a copy so images
                # already stored in earlier turns stay untouched.
                image = draw_bboxes(image.copy(), bboxes=bboxes)

            elif kind == "line":
                # Copy so the caller's sample is not modified in place.
                pts = list(manipulation["parameters"])
                print(pts)
                # Snap to an axis-aligned segment along the dominant direction
                # (abs() so reversed endpoints are handled too).
                if abs(pts[2] - pts[0]) < abs(pts[3] - pts[1]):
                    pts[2] = pts[0]  # mostly vertical -> exactly vertical
                else:
                    pts[3] = pts[1]  # mostly horizontal -> exactly horizontal
                image = draw_line(image.copy(), (pts[0], pts[1]), (pts[2], pts[3]))

            image.save(f"./debug_output/CoM/step_image_{qid}_{i}.jpg")
        conversations.append({"role": "assistant", "content": step_content})
    return conversations


def _get_bbox_and_strip_tool_call(text):
    """Split a CoF assistant message into ``(bbox, text before the tool call)``.

    Returns ``(None, text)`` when the message lacks a complete
    ``<tool_call>...</tool_call>`` pair.
    """
    start = text.find("<tool_call>")
    end = text.find("</tool_call>")
    if start == -1 or end == -1:
        return None, text
    tool_call = json.loads(text[start + len("<tool_call>"):end])
    return tool_call["arguments"]["bbox_2d"], text[:start]


def _preprocess_cof(sample, dataset_root):
    """CoF: assistant tool calls are answered by pre-rendered zoom-in images.

    ``sample["images"][0]`` is the input image; each subsequent image is
    attached, in order, to the assistant step that issued a tool call.
    """
    images = sample["images"]
    if len(images) < 2:
        return None

    _ensure_debug_dir("./debug_output/CoF")
    image_dir = os.path.join(dataset_root, "images")
    input_image_path = images[0]
    qid = int(input_image_path.split('/')[0])
    messages = sample["messages"]
    input_image = Image.open(os.path.join(image_dir, input_image_path)).convert("RGB")
    question = messages[1]["content"]
    input_image.save(f"./debug_output/CoF/step_image_{qid}_0.jpg")
    print(qid, "[Question]", question)
    conversations = [_user_turn(input_image, question)]

    img_id = 1
    for idx, step in enumerate(messages[2:], start=2):
        if step["role"] == "user":
            continue

        bbox, resp_wo_tool = _get_bbox_and_strip_tool_call(step["content"])
        print("[Response]", resp_wo_tool)
        step_content = [{"type": "text", "text": resp_wo_tool}]
        # Guard against samples with fewer images than tool calls.
        if bbox is not None and img_id < len(images):
            image = Image.open(os.path.join(image_dir, images[img_id])).convert("RGB")
            image.save(f"./debug_output/CoF/step_image_{qid}_{idx}.jpg")
            step_content.append({"type": "image", "image": image})
            img_id += 1
        # Previously step_content was built but never added to the output.
        conversations.append({"role": "assistant", "content": step_content})
    return conversations

data_name = "CoM"
data_mapping = {
    "PixelReasoner": {
        "data_path": "/data1/qxwang/datasets/multimodal/PixelReasoner-SFT-Data/processed_data.json",
        "dataset_root": "/data1/qxwang/datasets/multimodal/PixelReasoner-SFT-Data"
    },
    "CoM": {
        "data_path": "/data1/qxwang/datasets/multimodal/CoMDataset/com_math_processed.jsonl",
        "dataset_root": "/data1/qxwang/datasets/multimodal/CoMDataset"
    },
    "CoF": {
        "data_path": "/data1/qxwang/datasets/multimodal/CoF-SFT-Data-5.4k/cof_sft_data.json",
        "dataset_root": "/data1/qxwang/datasets/multimodal/CoF-SFT-Data-5.4k"
    }
}

data_path = data_mapping[data_name]["data_path"]
dataset_root = data_mapping[data_name]["dataset_root"]
train_dataset = load_json_dataset(data_path)

preprocess_function = abstract_visual_token_single_input_images_preprocess_function
train_dataset = [x for x in [preprocess_function(sample, dataset_root) for sample in train_dataset[0:20]] if x is not None]



0 [Question] What is the lowest accuracy reported in the whole chart? [Answer] 1
[Response] First, find the location of the bush-troop in the image. <abs_vis_token></abs_vis_token>
0 grounding
[[231, 154, 262, 218]]
[Response] Then, draw a line segment from the leftmost side of the bush-troop. <abs_vis_token></abs_vis_token>
1 line
[235, 220, 240, 337]
[Response] Then, draw segments from the rightmost side of the bush-troop. <abs_vis_token></abs_vis_token>
2 line
[258, 220, 264, 340]
[Response] Your output: Then, locate the position of the intersection points of the Line on the image with the horizontal axis. <abs_vis_token></abs_vis_token>
3 grounding
[[226, 297, 248, 322], [250, 297, 270, 322]]
[Response] Output: It can be known that the length of the "bush-troop" is 1.
[Response] So I can understand that the answer is 1.
1 [Question] Move the ruler to measure the length of the nail to the nearest inch. The nail is about (_) inches long. [Answer] 2
[Response] Draw the leftmost vertic