In [1]:
import re

def judge_node_same(obs_nodes_info_a, obs_nodes_info_b, threshold=3):
    """
    Decide whether two page states are the same by counting unmatched nodes.

    Both arguments are dicts keyed by node id whose values are dicts such as::

        '11765': {'backend_id': 19269,
                  'union_bound': [0.0, 0.0, 10.0, 10.0],
                  'text': "[11765] RootWebArea 'Postmill' focused: True"}

    A node of ``obs_nodes_info_a`` is matched when some node of
    ``obs_nodes_info_b`` has an equal ``union_bound`` and an equal ``text``
    once the leading ``[id]`` tag has been removed. The dict keys and
    ``backend_id`` are not compared.

    The check is one-directional: only nodes of ``obs_nodes_info_a`` that are
    missing from ``obs_nodes_info_b`` are counted; extra nodes in
    ``obs_nodes_info_b`` do not add to the difference.

    Args:
        obs_nodes_info_a: node dict of the first state.
        obs_nodes_info_b: node dict of the second state.
        threshold: largest number of unmatched nodes still treated as "same".

    Returns:
        True when the number of unmatched nodes is less than or equal to
        ``threshold``, False otherwise.
    """

    def _signature(node_info):
        # Identity of a node for comparison purposes: (bounds, text without id).
        bound = node_info['union_bound']
        if isinstance(bound, (list, tuple)):
            # Lists are unhashable; a tuple compares element-wise.
            bound = tuple(bound)
        # Strip only the leading "[id]" tag (keeping any indentation). Removing
        # every bracketed span would also erase brackets that belong to the
        # node's own name and make distinct nodes look identical.
        text = re.sub(r'^(\s*)\[.*?\]', r'\1', node_info['text'], count=1)
        return (bound, text)

    b_signatures = {_signature(node_info) for node_info in obs_nodes_info_b.values()}

    diff_count = sum(
        1
        for node_info in obs_nodes_info_a.values()
        if _signature(node_info) not in b_signatures
    )
    return diff_count <= threshold

In [2]:
def a11y_to_components(a11y):
    """
    Split an accessibility-tree string into a list of component strings.

    Each line is stripped of surrounding whitespace and blank lines are
    dropped. When a line begins with a numeric ``[id]`` tag, the tag is removed
    and the stripped remainder becomes the component; any other line is kept
    whole.
    """
    id_tag = re.compile(r'\[(\d+)\]')

    def _without_id(entry):
        # Only a tag at the very start of the line is removed.
        found = id_tag.match(entry)
        if found is None:
            return entry
        return entry[found.end():].strip()

    stripped_lines = (raw.strip() for raw in a11y.split("\n"))
    return [_without_id(entry) for entry in stripped_lines if entry != ""]


In [3]:
def judge_node_same_from_a11y(a11y_a, a11y_b, threshold=3):
    """
    Decide whether two page states are the same by comparing a11y components.

    ``a11y_a`` and ``a11y_b`` are accessibility-tree strings, for example::

        [5207] heading 'relationship_advice — relationship_advice'
            [5209] link 'relationship_advice — relationship_advice'
        [5215] button 'Subscribe No subscribers'
            [5571] generic 'No subscribers'
        [5216] StaticText '5,721 submissions'

    Both strings are turned into component lists with ``a11y_to_components``
    (leading ``[id]`` removed). A component of ``a11y_a`` is matched when an
    identical component exists in ``a11y_b``.

    The check is one-directional: only components of ``a11y_a`` that are
    missing from ``a11y_b`` are counted, and a component repeated in
    ``a11y_a`` is counted once per occurrence.

    Args:
        a11y_a: accessibility-tree string of the first state.
        a11y_b: accessibility-tree string of the second state.
        threshold: largest number of unmatched components still treated as
            "same".

    Returns:
        True when the number of unmatched components is less than or equal to
        ``threshold``, False otherwise.
    """
    # A set gives constant-time membership tests; scanning a list for every
    # component of a11y_a made the comparison quadratic in the page size.
    b_components = set(a11y_to_components(a11y_b))

    diff_count = sum(
        1
        for a_component in a11y_to_components(a11y_a)
        if a_component not in b_components
    )
    return diff_count <= threshold

In [4]:
from rank_bm25 import BM25Okapi
def bm25_retrieval(query, corpus, top_n=3):
    """
    Return the positions of the ``top_n`` corpus documents ranked best by BM25.

    Args:
        query: query string; tokenised by splitting on single spaces.
        corpus: list of document strings, tokenised the same way.
        top_n: maximum number of positions to return.

    Returns:
        A list of integer indices into ``corpus``, highest score first. It is
        shorter than ``top_n`` when the corpus is smaller, and empty when the
        corpus is empty.
    """
    if not corpus:
        # Nothing to rank. NOTE(review): building BM25Okapi on an empty corpus
        # is expected to fail (average document length over zero documents).
        return []

    # Local import so the function does not depend on another cell importing it.
    import numpy as np

    tokenized_corpus = [doc.split(" ") for doc in corpus]
    bm25 = BM25Okapi(tokenized_corpus)
    scores = bm25.get_scores(query.split(" "))

    # Rank positions straight from the scores. Mapping the returned documents
    # back with corpus.index(doc) yields the first occurrence only, so
    # duplicate documents produced repeated indices, and each lookup rescanned
    # the whole corpus.
    return np.argsort(scores)[::-1][:top_n].tolist()

In [None]:
import os

# # /home/zjusst/qms/webarena/result_stage_1_explore/trajs 底下的所有文件名
# file_names = os.listdir("/home/zjusst/qms/webarena/result_stage_1_explore/trajs")
# file_names = [file_name for file_name in file_names if file_name.endswith(".json")]
# # file_names.sort() 自定义排序，按照 _ 分割，取后面的数字排序
# file_names.sort(key=lambda x: int(x.split("_")[-1].split(".")[0]))
# len(file_names)


104

In [None]:
from tqdm import tqdm
import json

# unique_page_ids = []  # (file_name, traj_idx, 0|1代表前还是后, real_url) 
# unique_page_a11ys = []
# #unique_page_nodes = []
# triplets = []  # (page_id, action, page_id)

# for file_name in tqdm(file_names):
#     with open(os.path.join("/home/zjusst/qms/webarena/result_stage_1_explore/trajs", file_name), "r") as f:
#         trajs = json.load(f)

#     for traj_idx, traj in enumerate(trajs):
#         if len(unique_page_ids) == 0:
#             unique_page_ids.append((file_name, traj_idx, 0, traj['url_real_before']))
#             unique_page_a11ys.append(traj["a11y_before"])
#             #unique_page_nodes.append(traj["state_before"]['text']['obs_nodes_info'])

#         prev_page_id = -1
#         # 从 unique_page_a11ys 检索出最接近的三个
#         top_n_docs_index = bm25_retrieval(traj["a11y_before"], unique_page_a11ys, 3)
#         for doc_idx in top_n_docs_index:
#             # if judge_node_same(traj["state_before"]['text']['obs_nodes_info'], unique_page_nodes[doc_idx], 3):
#             #     prev_page_id = unique_page_ids[doc_idx]
#             #     break
#             if judge_node_same_from_a11y(traj["a11y_before"], unique_page_a11ys[doc_idx], 3):
#                 prev_page_id = unique_page_ids[doc_idx]
#                 break
#         if prev_page_id == -1:
#             unique_page_ids.append((file_name, traj_idx, 0, traj['url_real_before']))
#             unique_page_a11ys.append(traj["a11y_before"])
#             #unique_page_nodes.append(traj["state_before"]['text']['obs_nodes_info'])
#             prev_page_id = unique_page_ids[-1]

#         action_str = traj["action_str"]

#         after_page_id = -1
#         top_n_docs_index = bm25_retrieval(traj["a11y_after"], unique_page_a11ys, 3)
#         for doc_idx in top_n_docs_index:
#             # if judge_node_same(traj["state_after"]['text']['obs_nodes_info'], unique_page_nodes[doc_idx], 3):
#             #     after_page_id = unique_page_ids[doc_idx]
#             #     break
#             if judge_node_same_from_a11y(traj["a11y_after"], unique_page_a11ys[doc_idx], 3):
#                 after_page_id = unique_page_ids[doc_idx]
#                 break
#         if after_page_id == -1:
#             unique_page_ids.append((file_name, traj_idx, 1, traj['url_real_after']))
#             unique_page_a11ys.append(traj["a11y_after"])
#             #unique_page_nodes.append(traj["state_after"]['text']['obs_nodes_info'])
#             after_page_id = unique_page_ids[-1]
            
#         triplets.append((prev_page_id, action_str, after_page_id))

# print(len(triplets))


100%|██████████| 104/104 [00:32<00:00,  3.20it/s]

1025





In [None]:
import csv

# with open("/home/zjusst/qms/webarena/result_stage_1_explore/flitered_triplets.csv", "w") as f:
#     writer = csv.writer(f)
#     writer.writerow(["prev_page_id", "action_str", "after_page_id"])
#     writer.writerows(triplets)


In [12]:
# Reload the (prev_page_id, action_str, after_page_id) triplets from
# /home/zjusst/qms/webarena/result_stage_1_explore/flitered_triplets.csv.
# Every field comes back as a string, so the page ids are tuple reprs.
# newline="" lets the csv module handle line endings itself, which keeps
# quoted fields containing line breaks intact.
with open("/home/zjusst/qms/webarena/result_stage_1_explore/flitered_triplets.csv", "r", newline="") as f:
    reader = csv.reader(f)
    # Skip the header row; the default avoids StopIteration on an empty file.
    next(reader, None)
    triplets = list(reader)
len(triplets)

1025

In [18]:
# Inspect one destination page id: after the CSV round-trip it is the string repr of a tuple, not a tuple.
triplets[3][2]

"('812_1.json', 1, 1, 'http://reddit.com/wiki')"

In [None]:
# Tally how many triplets reference each page id in column 2.
# NOTE(review): column 2 is the destination page (after_page_id), so despite
# the "out_degree" name this counts how often a page is arrived at.
out_degree_dict = {}
for triplet in triplets:
    out_degree_dict[triplet[2]] = out_degree_dict.get(triplet[2], 0) + 1

# (page_id, count) pairs ordered from the lowest count to the highest.
out_degree_list = sorted(out_degree_dict.items(), key=lambda pair: pair[1])


("('812_2.json', 8, 1, 'http://reddit.com/forums')", 1)

In [27]:
import ast

# Keep only pages under http://reddit.com/ whose URL contains more than three
# "/" characters, i.e. anything deeper than http://reddit.com/<segment>.
flitered_out_degree_list = []
for item in out_degree_list:
    # item[0] is the string repr of a (file_name, traj_idx, pos, real_url)
    # tuple read back from the CSV. literal_eval parses that literal without
    # executing arbitrary code, which eval() would do for a tampered file.
    real_url = ast.literal_eval(item[0])[3]
    if real_url.count("/") > 3 and real_url.startswith("http://reddit.com/"):
        flitered_out_degree_list.append(item)

len(flitered_out_degree_list)

66

In [None]:
import random
import time

def random_select_url_and_a11y(flitered_out_degree_list,
                               trajs_dir="/home/zjusst/qms/webarena/result_stage_1_explore/trajs"):
    """
    Pick a random page record and load its URL and accessibility tree.

    Args:
        flitered_out_degree_list: non-empty sequence of ``(page_id, count)``
            pairs, where ``page_id`` is the string repr of a
            ``(file_name, traj_index, pos, real_url)`` tuple.
        trajs_dir: directory holding the trajectory JSON files named by
            ``file_name``.

    Returns:
        ``(real_url, a11y)`` where ``a11y`` is the step's ``"a11y_before"``
        when ``pos == 0`` and its ``"a11y_after"`` otherwise.

    Raises:
        ValueError: if ``flitered_out_degree_list`` is empty.
    """
    # Local imports so the function does not depend on another cell's imports.
    import ast
    import os

    if not flitered_out_degree_list:
        raise ValueError("flitered_out_degree_list is empty; nothing to sample")

    # The global RNG is deliberately not reseeded here: it is already seeded
    # from system entropy, and reseeding on every call would silently discard
    # any seed the caller set for reproducibility.
    record = random.choice(flitered_out_degree_list)

    # literal_eval parses the tuple repr without executing code from the file.
    file_name, traj_index, pos, real_url = ast.literal_eval(record[0])

    with open(os.path.join(trajs_dir, file_name), "r") as f:
        trajs = json.load(f)
    traj = trajs[traj_index]

    # pos marks which side of the step the page was recorded on.
    a11y = traj["a11y_before"] if pos == 0 else traj["a11y_after"]
    return real_url, a11y
# random_select_url_and_a11y(flitered_out_degree_list)

In [54]:
# Load the persona descriptions: one JSON object per line, keeping only its
# 'persona' field.
personas = []
# JSON text is UTF-8; stating it avoids depending on the platform's default encoding.
with open("/home/zjusst/qms/persona-hub/data/elite_personas_10000.jsonl", "r", encoding="utf-8") as f:
    for line in f:
        # A blank line is not valid JSON and would abort the whole load.
        if not line.strip():
            continue
        persona = json.loads(line)
        personas.append(persona['persona'])

In [None]:
import json
import random


def random_persona(personas):
    """
    Return one element of ``personas`` chosen uniformly at random.

    The global RNG is not reseeded: it is already seeded from system entropy,
    and reseeding on every call would discard any seed set by the caller for
    reproducibility.

    Args:
        personas: non-empty sequence of persona descriptions.

    Returns:
        The selected element.

    Raises:
        ValueError: if ``personas`` is empty.
    """
    if len(personas) == 0:
        raise ValueError("personas is empty; nothing to sample")
    return random.choice(personas)

# Draw one persona as a quick check; the cell output shows the sampled description.
random_persona(personas)

'A geologist who is passionate about glacial studies and has a deep understanding of the geology of Antarctica. They are knowledgeable about the history and current state of glacial movements in the area, and have extensive experience in mapping and analyzing glacial ice sheets. They are also skilled in using satellite imagery and other geospatial data to study glacial processes and their impact on the environment. Additionally, they have a keen interest in the history and culture of the people who have lived and worked in Antarctica, and have a deep appreciation for the natural beauty and challenges of the region.'

In [None]:
# Sample one filtered page, then display its URL with the first 100 characters of its accessibility tree.
real_url, a11y = random_select_url_and_a11y(flitered_out_degree_list)
real_url, a11y[:100]

('http://reddit.com/featured/hot',
 "Tab 0 (current): Postmill\n\n[9738] RootWebArea 'Postmill' focused: True\n\t[9768] HeaderAsNonLandmark '")