In [3]:
# Load model directly
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("autodl-tmp/gemma-2b-it", trust_remote_code=True)
# model = AutoModelForCausalLM.from_pretrained("autodl-tmp/Meta-Llama-3-8B-Instruct", trust_remote_code=True)

In [4]:
# zero-shot
instruction = ( 
    "You are a web navigation intelligence who interacts with webpage environments to achieve human user intent.\n"
    "You always generate the next ACTION based on the user's INTENT, current cleaned webpage HTML and ACTION_HISTORY sequence which recording the actions that have been performed.\n\n"
    "Given HTML and INTENT and ACTION_HISTORY, you should\n"
    "(1) Rely on your HTML code comprehension to analyze and understand what elements are on the current page.\n"
    "(2) Depend on your reasoning skills to parse the user's INTENT and infer the next action that should be taken in conjunction with the historical trajectory ACTION_HISTORY.\n"
    "(3) Select an element carefully from HTML code to interact with, thus bringing the goal closer to completion.\n\n"
    "Your output format should be strictly as follows\n"
    "Operation: ... (should be CLICK or TYPE)\n"
    "Value: ... (optional textual value for the operation TYPE)\n"
    "ID: ... (unique id number for the element to click or type into)\n\n"
    # "Thought: ... (A paragraph explaining why you chose this element to interact with, no more than 50 words)"
    "Now, begin!"
)

user_input_template = "INTENT:\n{intent}\n\nHTML:\n{html}\n\nACTION_HISTORY:\n{action_history}"
output_template = "Operation: {op}\nValue: {value}\nID: {id}"

len(tokenizer(instruction)["input_ids"])

201

In [5]:
import json
import datasets

import random
from tqdm import tqdm

import lxml
from lxml import etree
from dom_utils import prune_tree

train_dataset = []

for ID in range(11):
    with open(f"autodl-tmp/train_dataset/train/train_{ID}.json") as file:
        data = json.load(file)

    pos_candidate_na = 0
    total_dataset_num = 0
    large_token_num = 0
    
    for dat in tqdm(data):
        intent = dat["confirmed_task"] # + f"(domain {dat['subdomain']})"
        action_history_all = dat["action_reprs"]
        annotation_id = dat['annotation_id']
        
        for index, d in enumerate(dat["actions"]):
            cleaned_html = d["cleaned_html"]
            action_uid = d['action_uid']
            sample_id = f"{annotation_id}_{action_uid}"

            if len(d["pos_candidates"]) == 0:
                pos_candidate_na += 1
                continue 
                        
            gt = d["pos_candidates"][0]["backend_node_id"]
            random.shuffle(d["neg_candidates"])
            neg_candidates_pool = d["neg_candidates"][:20] # 10 or 50
            candidate_ids = [gt] + [c["backend_node_id"] for c in neg_candidates_pool]

            dom_tree = lxml.etree.fromstring(cleaned_html)
            dom_tree = prune_tree(dom_tree, candidate_ids)
            html = lxml.etree.tostring(dom_tree, pretty_print=True, method="html", encoding='unicode')
            html = html.replace("backend_node_id", "id")

            action_history = action_history_all[:index]
            token_num = len(tokenizer(html)["input_ids"]) + len(tokenizer(intent)["input_ids"]) + len(tokenizer(str(action_history))["input_ids"]) + 20
            
            if token_num < 4000: # 8000
                op = d["operation"]["op"]
                value = d["operation"]["value"]
                chosen_answer_ = output_template.format(op=op, value=value, id=d["pos_candidates"][0]["backend_node_id"])
                
                random.shuffle(neg_candidates_pool)
                rand_neg_candidates = neg_candidates_pool[:3]
                
                for c in rand_neg_candidates: # 1:3 proportion
                    if op != "CLICK" and random.uniform(0, 1) < 0.33: # 1/3 for type/select -> click
                        rejected_answer_ = output_template.format(op="CLICK", value="", id=c["backend_node_id"])
                    else:
                        rejected_answer_ = output_template.format(op=op, value=value, id=c["backend_node_id"])
                        
                    instruction_ = instruction
                    input_ = user_input_template.format(intent=intent, html=html, action_history=action_history)
                    output_ = [chosen_answer_, rejected_answer_]
                    
                    total_dataset_num += 1
                    train_dataset.append({
                        "instruction": instruction_,
                        "input": input_,
                        "output": output_
                    })
                    
            else:
                large_token_num += 1
                # print("too large token_num:", token_num)
            
    print(ID, pos_candidate_na, total_dataset_num)

print("too large token_num:", large_token_num)

with open("/root/data/mind2web_dpo_train_50_gemma.json", "w") as file:
    json.dump(train_dataset, file, indent=4)

100%|██████████| 100/100 [00:25<00:00,  3.99it/s]


0 31 1989


100%|██████████| 100/100 [00:25<00:00,  3.89it/s]


1 52 2106


100%|██████████| 100/100 [00:22<00:00,  4.44it/s]


2 35 1533


100%|██████████| 100/100 [00:26<00:00,  3.74it/s]


3 56 1929


100%|██████████| 100/100 [00:26<00:00,  3.84it/s]


4 32 1818


100%|██████████| 100/100 [00:26<00:00,  3.80it/s]


5 27 1875


100%|██████████| 100/100 [00:22<00:00,  4.42it/s]


6 49 1854


100%|██████████| 100/100 [00:24<00:00,  4.03it/s]


7 51 2073


100%|██████████| 100/100 [00:23<00:00,  4.24it/s]


8 47 2019


100%|██████████| 100/100 [00:26<00:00,  3.78it/s]


9 30 2007


100%|██████████| 9/9 [00:01<00:00,  7.84it/s]


10 3 129
too large token_num: 3


In [5]:
import pickle
with open("data/scores_all_data.pkl", 'rb') as file:
    scores = pickle.load(file)

candidate_scores = scores["scores"]
candidate_ranks = scores["ranks"]

In [7]:
import json
import datasets

import random
from tqdm import tqdm

import lxml
from lxml import etree
from dom_utils import prune_tree

test_dataset_names = ["test_website", "test_task", "test_domain"]
test_dataset_counts = [2, 3, 10]

for test_dataset_name, test_dataset_count in zip(test_dataset_names, test_dataset_counts):
    print(f"generating {test_dataset_name} dpo dataset")
    test_dataset = []
    
    for ID in range(test_dataset_count):
        with open(f"autodl-tmp/test_dataset/{test_dataset_name}/{test_dataset_name}_{ID}.json") as file:
            data = json.load(file)
    
        pos_candidate_na = 0
        total_dataset_num = 0
        large_token_num = 0
    
        for dat in tqdm(data):
            intent = dat["confirmed_task"] # + f"(domain {dat['subdomain']})"
            action_history_all = dat["action_reprs"]
            annotation_id = dat['annotation_id']
        
            for index, d in enumerate(dat["actions"]):
                cleaned_html = d["cleaned_html"]
                action_uid = d['action_uid']
                sample_id = f"{annotation_id}_{action_uid}"
    
                if len(d["pos_candidates"]) == 0:
                    pos_candidate_na += 1
                    continue

                candidate_ids = []
                for candidates in [d["pos_candidates"], d["neg_candidates"]]:
                    for candidate in candidates:
                        candidate_id = candidate["backend_node_id"]
                        rank = candidate_ranks[sample_id][candidate_id]
                        if rank <= 50: # 10 or 50
                            candidate_ids.append(candidate_id)

                dom_tree = lxml.etree.fromstring(cleaned_html)
                dom_tree = prune_tree(dom_tree, candidate_ids)
                html = lxml.etree.tostring(dom_tree, pretty_print=True, method="html", encoding='unicode')
                html = html.replace("backend_node_id", "id")
    
                action_history = action_history_all[:index]
                token_num = len(tokenizer(html)["input_ids"]) + len(tokenizer(intent)["input_ids"]) + len(tokenizer(str(action_history))["input_ids"]) + 20
                
                if token_num < 5000:
                    op = d["operation"]["op"]
                    value = d["operation"]["value"]
                    chosen_answer_ = output_template.format(op=op, value=value, id=d["pos_candidates"][0]["backend_node_id"])
                    
                    instruction_ = instruction
                    input_ = user_input_template.format(intent=intent, html=html, action_history=action_history)
                    output_ = chosen_answer_
                    
                    total_dataset_num += 1
                    test_dataset.append({
                        "instruction": instruction_,
                        "input": input_,
                        "output": output_
                    })
                    
                else:
                    large_token_num += 1
    
        print(ID, pos_candidate_na, total_dataset_num)

    print("too large token_num:", large_token_num)
        
    with open(f"data/mind2web_dpo_{test_dataset_name}_ranked_50.json", "w") as file:
        json.dump(test_dataset, file, indent=4)

generating test_website dpo dataset


100%|██████████| 100/100 [00:48<00:00,  2.07it/s]


0 38 730


100%|██████████| 77/77 [00:35<00:00,  2.16it/s]


1 21 484
too large token_num: 49
generating test_task dpo dataset


100%|██████████| 100/100 [00:59<00:00,  1.67it/s]


0 69 735


100%|██████████| 100/100 [00:50<00:00,  1.98it/s]


1 34 709


100%|██████████| 52/52 [00:24<00:00,  2.12it/s]


2 16 417
too large token_num: 32
generating test_domain dpo dataset


100%|██████████| 100/100 [00:30<00:00,  3.27it/s]


0 42 598


100%|██████████| 100/100 [00:30<00:00,  3.29it/s]


1 25 632


100%|██████████| 100/100 [00:34<00:00,  2.92it/s]


2 34 528


100%|██████████| 100/100 [00:35<00:00,  2.83it/s]


3 51 671


100%|██████████| 100/100 [00:28<00:00,  3.54it/s]


4 34 556


100%|██████████| 100/100 [00:23<00:00,  4.31it/s]


5 35 524


100%|██████████| 100/100 [00:25<00:00,  3.88it/s]


6 30 570


100%|██████████| 100/100 [00:29<00:00,  3.43it/s]


7 42 586


100%|██████████| 100/100 [00:27<00:00,  3.61it/s]


8 26 645


100%|██████████| 12/12 [00:03<00:00,  3.38it/s]


9 2 64
too large token_num: 1
