In [1]:
# Instruction text that is prepended to every annotation request.
# It tells the model to translate the diff tree of an accessibility tree into
# two sections: a [Rationale] list and a [Next State] list.
# The wording is part of the prompt sent at runtime, so edit it deliberately.
SYSTEM_PROMPT="""You are an expert web dynamics annotator. You will receive:

1. Current Observation: Accessibility tree before the action.
2. Current Action: User action.
3. Diff Tree: Actual ADDED, UPDATED, or DELETED changes in the accessibility tree.

Your Goal:
Translate the technical `Key Changes (Diff Tree)` into a natural language [Rationale] and [Next State] prediction.

Guidelines for Response:
1.  [Rationale]:
   Convert each Diff Tree change into natural-language descriptions of what was added, removed, or updated. 
   Do NOT infer user intent or website logic. 
   Do NOT describe anything except what the diff explicitly shows.


2.  State Description ([Next State]):
    Then, generate a description of the next web state based on the changed parts you mentioned.
    * High-Level & Functional: Do not say too much trivial details. Instead, give a high-level and functional description in detail after the action.
    * Focus on Changes: Focus solely on describing the changes that will happen.
    * Abstract & Generic: Avoid mentioning specific dynamic content (like specific prices, exact product titles) unless they are generic UI labels. For example, instead of saying "The price updated to $25.00", just say "The price value is updated".

Keep both sections concise. Avoid long paragraphs or speculation.

Response Format:
[Rationale]
Key changes in the accessibility tree based on [description of the action] would include:
1. [Logic for change 1]
2. [Logic for change 2]
...

[Next State] The expected effect is that:
1. [Detail of the next state 1]
2. [Detail of the next state 2]
3. [Detail of the next state 3]
...

"""

In [2]:
from datasets import load_dataset

# Load the "train" split of the source dataset; rows are later read by integer
# index (ds1[i]) and passed to the annotation function.
ds1 = load_dataset("LangAGI-Lab/world_model_for_wa_tao_dataset", split="train")

In [3]:
import os
from getpass import getpass

import google.generativeai as genai

# SECURITY: an API key was previously hard-coded in this cell. Any key that has
# been saved in a notebook must be treated as leaked: revoke it in the Google
# console and issue a new one.
# The key is now taken from the environment (GOOGLE_API_KEY or GEMINI_API_KEY);
# if neither is set, the user is prompted without echoing the input.
_gemini_api_key = (
    os.environ.get("GOOGLE_API_KEY")
    or os.environ.get("GEMINI_API_KEY")
    or getpass("Gemini API key: ")
)
genai.configure(api_key=_gemini_api_key)


In [4]:
import json
import time
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed

# JSON Lines file that annotation records are appended to, one JSON object per line.
output_file = "./sft_annotations_gemini.jsonl"

def annotate_single_item(i, item, max_retries=3, retry_delay=5,
                         model_name='models/gemini-2.5-flash'):
    """Annotate a single dataset row with the Gemini model.

    Builds a prompt from the row's observation, action and diff lists, sends
    it (prefixed with ``SYSTEM_PROMPT``) to the model and retries when the
    call raises.

    Args:
        i: Index of the row; echoed back as ``original_data_index``.
        item: Mapping read with ``.get`` for the keys ``observation``,
            ``current_action``, ``new_items``, ``updated_items`` and
            ``deleted_items``. Missing keys fall back to ``''`` / ``[]``.
        max_retries: Maximum number of attempts against the API.
        retry_delay: Base wait in seconds before a retry; the wait doubles
            after every failed attempt (exponential backoff).
        model_name: Model identifier passed to ``genai.GenerativeModel``.

    Returns:
        A dict holding the inputs and the generated annotation, or ``None``
        when every attempt failed (the last error is printed).
    """
    current_obs = item.get('observation', '')
    current_act = item.get('current_action', '')

    new_items_list = item.get('new_items', [])
    updated_items_list = item.get('updated_items', [])
    deleted_items_list = item.get('deleted_items', [])

    diff_tree = f"ADDED: {new_items_list}\nUPDATED: {updated_items_list}\nDELETED: {deleted_items_list}"

    # The indentation inside this literal is part of the prompt text.
    user_input_content = f"""
    Input:
    Current Observation: {current_obs}
    Current Action: {current_act}
    Key Changes (Diff Tree): {diff_tree}
    """

    # Created lazily inside the try block so a construction failure is handled
    # like any other error, but reused across retries instead of rebuilt.
    model = None
    for attempt in range(max_retries):
        try:
            if model is None:
                model = genai.GenerativeModel(model_name)
            response = model.generate_content(
                f"{SYSTEM_PROMPT}\n\n{user_input_content}"
            )
            # Reading .text stays inside the try so a failure there is retried too.
            annotation = response.text

            return {
                "original_data_index": i,
                "current_observation": current_obs,
                "current_action": current_act,
                "diff_tree_raw": diff_tree,
                "generated_annotation": annotation
            }

        except Exception as e:
            if attempt == max_retries - 1:
                print(f"\nError on item {i}: {e}")
                return None
            # Back off progressively so concurrent workers do not retry in lockstep
            # against a rate-limited API.
            time.sleep(retry_delay * (2 ** attempt))

    # Only reached when max_retries <= 0 (no attempt was made).
    return None

# Process the rows in parallel on a thread pool and append each successful
# result to the JSON Lines output file.
START_INDEX = 7903  # first dataset row handled by this run
max_workers = 10  # adjust to the API rate limit; this bounds the concurrency

# Collect the indices already present in the output file so that re-running
# this cell resumes the job instead of appending duplicate records.
done_indices = set()
try:
    with open(output_file, 'r', encoding='utf-8') as existing:
        for raw_line in existing:
            raw_line = raw_line.strip()
            if not raw_line:
                continue
            try:
                done_indices.add(json.loads(raw_line)["original_data_index"])
            except (json.JSONDecodeError, KeyError, TypeError):
                # Ignore malformed / truncated lines left by an interrupted run.
                continue
except FileNotFoundError:
    pass  # first run: nothing to resume from

pending_indices = [i for i in range(START_INDEX, len(ds1)) if i not in done_indices]
failed_indices = []

with open(output_file, 'a', encoding='utf-8') as f:
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Submit every pending row; the executor queues them for the workers.
        futures = {
            executor.submit(annotate_single_item, i, ds1[i]): i
            for i in pending_indices
        }
        # Write results as they complete, flushing so an interruption loses little.
        for future in tqdm(as_completed(futures), total=len(futures), desc="Annotating"):
            try:
                result = future.result()
            except Exception as e:
                # One unexpected failure must not abort the whole run.
                print(f"\nUnhandled error on item {futures[future]}: {e}")
                result = None
            if result:
                f.write(json.dumps(result, ensure_ascii=False) + "\n")
                f.flush()
            else:
                failed_indices.append(futures[future])

if failed_indices:
    print(f"{len(failed_indices)} items were not annotated; first few: {sorted(failed_indices)[:20]}")

Annotating: 100%|██████████| 6841/6841 [1:47:17<00:00,  1.06it/s]  


In [1]:
import pandas as pd
# Read the JSON Lines annotations (one record per line) into a DataFrame and
# export them as CSV.
df = pd.read_json(
    "./sft_annotations_gemini.jsonl",
    lines=True,
)
df.to_csv(
    "./sft_annotations_gemini.csv",
    encoding='utf-8-sig',  # written with a BOM
    index=False,           # do not emit the row index as a column
)

In [None]:
import pandas as pd

# Load the SFT training table and measure, per row, the combined character
# count of the prompt and response columns.
df = pd.read_parquet('sft_train.parquet')

df['length'] = df['prompt'].str.len().add(df['response'].str.len())

# Print the header label followed by the summary statistics of the lengths.
print("长度统计:", df['length'].describe(), sep="\n")

: 

In [None]:
# Combined character count of prompt and response for each row.
# NOTE(review): this repeats the assignment already made right after the
# parquet file is loaded; the cell looks redundant and can probably be removed.
df['length'] = df['prompt'].str.len() + df['response'].str.len()