-
Notifications
You must be signed in to change notification settings - Fork 2
/
plot_gr_trace.py
90 lines (77 loc) · 3.02 KB
/
plot_gr_trace.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import copy
import os
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from config import CONFIG
from evaluator.task_trace import DatasetHelper, TaskTrace
from evaluator.utils.visualization import plot_episode
# Module-level dataset accessor used by plot_single_trace and plot_all, built
# from the episode-metadata path and the ground-truth dataset path in CONFIG.
# Process-pool workers reach it as this global rather than as an argument
# (under the "spawn" start method each worker re-creates it on import).
helper = DatasetHelper(
    CONFIG.EPI_METADATA_PATH,
    CONFIG.GR_DATASET_PATH,
)
def plot_single_trace(episode: str, output_file: str):
    """Render the ground-truth trace of one episode and save it to a file.

    Loads the task description, category and ground-truth ``TaskTrace`` for
    ``episode`` via the module-level ``helper``, builds one per-step dict in
    the layout ``plot_episode`` consumes, plots all steps, saves the current
    pyplot figure to ``output_file`` and then closes that figure.

    Args:
        episode: Episode identifier understood by ``helper``.
        output_file: Destination path handed to ``plt.savefig``.
    """
    task_description: str = helper.get_task_description_by_episode(episode)
    trace: TaskTrace = helper.load_groundtruth_trace_by_episode(episode)
    # The category belongs to the episode, so look it up once, not per step.
    category = helper.get_category_by_episode(episode).value.upper()

    ui_infos_for_plot = []
    for step_id, ui_state in enumerate(trace):
        # `with` releases the screenshot file; convert() returns a separate
        # in-memory image, so `img` stays usable after the block exits.
        with Image.open(ui_state.screenshot_path) as raw_img:
            img = raw_img.convert("RGB")
        action = ui_state.action
        # Each dict is built fresh per step, so no defensive deepcopy is
        # needed (the original deep-copied the image array and ui_state).
        ui_infos_for_plot.append(
            {
                "image": np.array(img),
                "episode_id": episode,
                "category": category,
                "step_id": step_id,
                "goal": task_description,
                "result_action": [action.action_type, action.typed_text],
                "result_touch_yx": action.touch_point_yx,
                "result_lift_yx": action.lift_point_yx,
                "image_height": img.height,
                "image_width": img.width,
                "image_channels": 3,  # guaranteed by convert("RGB")
                "ui_positions": None,
                "ui_text": None,
                "ui_type": None,
                # passing the whole ui_state for a quick implementation
                "ui_state": ui_state,
                # Essential states are only attached to ground-truth states.
                "essential_states": (
                    ui_state.essential_state
                    if ui_state.state_type == "groundtruth"
                    else None
                ),
            }
        )

    plot_episode(
        ui_infos_for_plot,
        show_essential_states=True,
        show_annotations=False,
        show_actions=True,
    )
    # plt.savefig writes pyplot's current figure. Pyplot keeps figures alive
    # until they are explicitly closed. Presumably plot_episode creates a new
    # figure per call (TODO confirm), so without the close, memory grows in
    # pool workers that plot many episodes.
    fig = plt.gcf()
    try:
        plt.savefig(output_file, pad_inches=0, bbox_inches="tight")
    finally:
        plt.close(fig)
def plot_all(output_path: str = "gr_trace_plots/", max_workers: int = 60):
    """Plot the ground-truth trace of every episode in a process pool.

    Writes one image per episode to ``output_path``, named
    ``{category}_{episode}.png``, by running ``plot_single_trace`` in a
    ``ProcessPoolExecutor``. Failures are reported on stderr without
    stopping the remaining episodes.

    Args:
        output_path: Output directory; created (including parents) if missing.
        max_workers: Number of worker processes in the pool.

    Returns:
        A list of ``(episode, exception)`` pairs for episodes that failed to
        plot; empty when every episode succeeded.
    """
    import sys
    from concurrent.futures import ProcessPoolExecutor, as_completed

    # exist_ok avoids the race between an exists() check and mkdir().
    os.makedirs(output_path, exist_ok=True)

    failures = []
    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        future_to_episode = {}
        for epi in helper.get_all_episodes():
            category = helper.get_category_by_episode(epi).value
            future = executor.submit(
                plot_single_trace,
                episode=epi,
                output_file=os.path.join(output_path, f"{category}_{epi}.png"),
            )
            future_to_episode[future] = epi
        # A worker's exception is stored on its Future and only surfaces via
        # result(). Without this loop, failed episodes would be silently lost.
        for future in as_completed(future_to_episode):
            epi = future_to_episode[future]
            try:
                future.result()
            except Exception as err:  # report, then keep plotting the rest
                failures.append((epi, err))
                print(f"failed to plot episode {epi}: {err!r}", file=sys.stderr)
    return failures
if __name__ == "__main__":
    import sys

    # If plot_all reports failures (a non-empty return value), exit with a
    # non-zero status so batch runs notice. A None or empty result falls
    # through to a normal exit with status 0.
    if plot_all():
        sys.exit(1)