In [None]:
from PIL import Image, ImageColor, ImageDraw
from absl import logging
from copy import deepcopy

import json
import math
import matplotlib.pyplot as plt

from server.utils import base64_to_image

In [None]:
# Configure matplotlib so that CJK text (e.g. the Chinese figure title used
# below) can be rendered: prefer sans-serif fonts that ship CJK glyphs, and
# draw minus signs with the ASCII hyphen-minus instead of U+2212.
plt.rcParams.update(
    {
        "font.family": "sans-serif",
        "font.sans-serif": ["WenQuanYi Zen Hei", "Noto Sans CJK SC"],
        "axes.unicode_minus": False,
    }
)

In [None]:
# Show absl INFO-level messages (used below to report the task and step count).
logging.set_verbosity("info")

In [None]:
def mark_action(image, action, color="green"):
    """Draw a semi-transparent marker for an agent action on a screenshot.

    Point actions (``click``, ``double_tap``, ``long_press``, and
    ``input_text`` when it carries ``x``/``y``) are drawn as a filled circle.
    ``scroll`` and ``swipe`` are drawn as an arrow from the start point to the
    end point of the gesture. Other action types are left unmarked.

    When a scroll or swipe has no explicit coordinates, the gesture is
    reconstructed the same way as in
    ``android_world/android_world/env/actuation.py``:

    * scroll: starts at the screen centre and ends at
      ``down`` -> top edge, ``up`` -> bottom edge,
      ``right`` -> left edge, ``left`` -> right edge.
      The end point is inset by the arrow-head size so the head stays visible.
    * swipe: ``down`` top -> bottom edge, ``up`` bottom -> top edge,
      ``left`` left -> right edge, ``right`` right -> left edge,
      all through the screen centre line. Explicit ``touch_xy`` / ``lift_xy``
      override this.

    Args:
        image: PIL image of the screenshot. Its ``size`` is taken as the screen
            size; action coordinates are assumed to be in the same pixel space
            (TODO confirm against the recorder).
        action: dict with an ``"action_type"`` key plus type-specific keys
            (``x``/``y``, ``direction``, ``touch_xy``/``lift_xy``). It is only
            read, never mutated. A non-dict value is logged and not marked.
        color: a colour name/hex string accepted by ``PIL.ImageColor.getrgb``
            or an RGB/RGBA tuple. RGB colours get alpha 128. Unrecognised
            values fall back to translucent green.

    Returns:
        A new RGBA image with the marker alpha-composited on top. Unknown
        scroll/swipe directions and point actions missing coordinates are
        logged as warnings and produce no marker instead of raising.
    """

    def _to_rgba(value):
        """Resolve a colour string or RGB(A) sequence to an RGBA tuple."""
        fallback = (0, 255, 0, 128)
        if isinstance(value, str):
            try:
                value = ImageColor.getrgb(value)
            except ValueError:
                return fallback
        if isinstance(value, (tuple, list)) and len(value) in (3, 4):
            r, g, b = (int(c) for c in value[:3])
            # Keep an explicit alpha (e.g. "#00ff0080"); otherwise use 128.
            alpha = int(value[3]) if len(value) == 4 else 128
            return (r, g, b, alpha)
        return fallback

    def _plot_point(draw, point, radius, fill):
        """Draw a filled circle of ``radius`` centred on ``point``."""
        x, y = point
        draw.ellipse([(x - radius, y - radius), (x + radius, y + radius)], fill=fill)

    def _plot_arrow(draw, start, end, fill, arrow_size=10, width=2):
        """Draw a line from ``start`` to ``end`` with a triangular head at ``end``."""
        draw.line([start, end], fill=fill, width=width)

        # Image y grows downwards, hence the negated angle and the "+ sin"
        # terms below: the two head corners sit arrow_size back from the tip,
        # rotated +/-30 degrees from the shaft.
        theta = -math.atan2(end[1] - start[1], end[0] - start[0])
        half_angle = math.radians(30)
        corners = [
            (
                end[0] - arrow_size * math.cos(theta + sign * half_angle),
                end[1] + arrow_size * math.sin(theta + sign * half_angle),
            )
            for sign in (1, -1)
        ]
        draw.polygon([end, *corners], fill=fill)

    w, h = image.size
    radius = min(w, h) * 0.05
    arrow_size = min(w, h) * 0.05
    # int() truncates to 0 on images narrower than 50 px; keep lines visible.
    line_width = max(1, int(w * 0.02))
    fill = _to_rgba(color)

    overlay = Image.new("RGBA", image.size, (255, 255, 255, 0))
    overlay_draw = ImageDraw.Draw(overlay)

    if not isinstance(action, dict):
        logging.warning("Cannot mark non-dict action: %r", action)
        action = {}
    action_type = action.get("action_type")
    arrow = None  # (start, end) of the gesture for scroll / swipe

    if action_type in ("click", "double_tap", "long_press", "input_text"):
        x, y = action.get("x"), action.get("y")
        if x is not None and y is not None:
            _plot_point(overlay_draw, (x, y), radius, fill)
        elif action_type != "input_text":
            # Coordinates are optional for input_text but expected for taps.
            logging.warning("%s action has no x/y: %r", action_type, action)
    elif action_type == "scroll":
        cx, cy = w // 2, h // 2
        scroll_ends = {
            "down": (cx, arrow_size),
            "up": (cx, h - arrow_size),
            "right": (arrow_size, cy),
            "left": (w - arrow_size, cy),
        }
        direction = action.get("direction")
        if direction in scroll_ends:
            arrow = ((cx, cy), scroll_ends[direction])
        else:
            logging.warning("Unknown scroll direction: %r", direction)
    elif action_type == "swipe":
        start, end = action.get("touch_xy"), action.get("lift_xy")
        if start is not None and end is not None:
            arrow = (start, end)
        else:
            mid_x, mid_y = 0.5 * w, 0.5 * h
            swipe_paths = {
                "down": ((mid_x, 0), (mid_x, h)),
                "up": ((mid_x, h), (mid_x, 0)),
                "left": ((0, mid_y), (w, mid_y)),
                "right": ((w, mid_y), (0, mid_y)),
            }
            direction = action.get("direction")
            arrow = swipe_paths.get(direction)
            if arrow is None:
                logging.warning("Unknown swipe direction: %r", direction)
    # Any other action type has no screen location and is intentionally unmarked.

    if arrow is not None:
        # JSON-decoded coordinates arrive as lists; PIL expects (x, y) tuples.
        start, end = (tuple(p) for p in arrow)
        _plot_arrow(
            overlay_draw,
            start,
            end,
            fill=fill,
            arrow_size=arrow_size,
            width=line_width,
        )

    return Image.alpha_composite(image.convert("RGBA"), overlay)


def process_image(image_base64, action=None, width=240):
    """Decode a base64 screenshot, optionally mark an action on it, and scale it.

    Args:
        image_base64: base64-encoded screenshot, decoded with
            ``base64_to_image``.
        action: optional action dict forwarded to ``mark_action``; ``None``
            skips the overlay.
        width: target width in pixels. The height is scaled by the same factor
            (truncated to an int) to keep the aspect ratio.

    Returns:
        The resized PIL image, resampled with Lanczos filtering.
    """
    decoded = base64_to_image(image_base64=image_base64)
    marked = decoded if action is None else mark_action(image=decoded, action=action)

    orig_w, orig_h = marked.size
    new_h = int(width / orig_w * orig_h)
    return marked.resize(size=(width, new_h), resample=Image.Resampling.LANCZOS)

In [None]:
def plot_trajectory(trajectory, title=None, n_cols=4):
    """Plot every step of a trajectory as a grid of annotated screenshots.

    Each step's screenshot is decoded and resized with ``process_image``. When
    the step has a parseable action, it is marked on the screenshot. The raw
    action value is used as the subplot title, and unused grid cells are hidden.

    Args:
        trajectory: sequence of step dicts with an ``"observation"`` (base64
            screenshot) and an ``"action"`` (JSON string or dict; ``None`` or
            empty for steps without an action).
        title: optional figure-level title.
        n_cols: number of subplot columns (default 4).
    """

    def _parse_action(raw):
        """Return the step's action as a dict, or None if there is nothing to mark."""
        if not raw:
            return None
        if isinstance(raw, dict):
            return raw
        try:
            parsed = json.loads(raw)
        except (TypeError, ValueError) as e:  # JSONDecodeError is a ValueError
            logging.warning("Could not parse action %r: %s", raw, e)
            return None
        return parsed if isinstance(parsed, dict) else None

    num_steps = len(trajectory)
    if num_steps == 0:
        # plt.subplots(0, n) raises, so bail out early.
        logging.warning("Empty trajectory; nothing to plot.")
        return

    n_rows = (num_steps + n_cols - 1) // n_cols  # ceil without floats
    # squeeze=False always yields a 2-D array, so .ravel() works for any grid.
    fig, axs = plt.subplots(
        n_rows, n_cols, figsize=(4 * n_cols, 6 * n_rows), squeeze=False
    )
    axes = axs.ravel()

    for step, ax in zip(trajectory, axes):
        raw_action = step["action"]
        image = process_image(step["observation"], action=_parse_action(raw_action))
        ax.imshow(image)
        ax.set_title(raw_action or "")
        ax.axis("off")

    # Hide the empty cells that pad the last row.
    for ax in axes[num_steps:]:
        ax.axis("off")

    if title:
        fig.suptitle(title, fontsize=16)

    fig.tight_layout()
    plt.show()

In [None]:
# Trajectory record to visualise; replace the placeholder with a file name
# from ../records/. Decode explicitly as UTF-8: the task text may be
# non-ASCII (e.g. Chinese), and the platform default encoding is not always
# UTF-8.
record_file = "xxx.json"
with open(f"../records/{record_file}", mode="r", encoding="utf-8") as f:
    record = json.load(f)

In [None]:
# Pull out the instruction and the recorded steps, then report their size.
task, trajectory = record["task"], record["trajectory"]
num_steps = len(trajectory)

logging.info(f"task: {task}\n# of steps: {num_steps}")

In [None]:
# Figure title reads "User instruction: <task>".
plot_trajectory(trajectory, title=f"用户指令：{task}")

In [None]:
# for step in trajectory:
#     response = step["response"]
#     action = step["action"]
#     logging.info(f"response: {response}")
#     logging.info(f"action: {action}")
#     screenshot = process_image(step["observation"], action=json.loads(action))
#     screenshot.show()