# Simple MCTS
This notebook investigates using MCTS for a single static website where multiple actions have to be done to achieve an answer.

The task is to order a macbook with certain configurations. This website was chosen over others because it does not change upon selecting certain elements, which substantially simplifies testing as a website can simply be cached.

The start domain is [this page](https://www.apple.com/shop/buy-mac/macbook-pro/14-inch-space-gray-apple-m3-chip-with-8-core-cpu-and-10-core-gpu-8gb-memory-512gb) with the prompt "order a macbook pro 14 with 24 gb, 2 tb, fast charging and all available software"

Command to run this:
```
python run_demo.py --task_name openended --model_name openai/gpt-4o-mini --start_url https://www.apple.com/shop/buy-mac/macbook-pro/14-inch-space-gray-apple-m3-chip-with-8-core-cpu-and-10-core-gpu-8gb-memory-512gb
```

## Setup and load cached website

In [477]:
website = "https://www.apple.com/shop/buy-mac/macbook-pro/14-inch-space-gray-apple-m3-chip-with-8-core-cpu-and-10-core-gpu-8gb-memory-512gb"

human_prompt = "order a macbook pro 14 with 24 gb ram, 2 tb, fast charging and all available software"

In [1]:
# ideal actions are:
ideal_actions = [
    "click(1007)",  # Select 24GB unified memory
    "click(1038)",  # Select 2TB SSD storage
    "click(1061)",  # Select 96W USB-C Power Adapter
    "click(1112)",  # Select Final Cut Pro software
    "click(1135)",  # Select Logic Pro software
    "click(1209)",  # Add to bag
]
updated_ideal_actions = [
    'click(1008)',
    'click(1039)',
    'click(1062)',
    'click(1113)',
    'click(1136)',
    'click(1209)']

# ideal_actions_str = "\n".join(ideal_actions)

ideal_actions = updated_ideal_actions

In [2]:
txt_file = "../output_example_2.txt"
with open(txt_file, 'r') as file:
    lines = file.readlines()

system_messages = []
prompts = []
actions = []

current_section = None

for line in lines:
    if line.startswith("System Message:"):
        current_section = "System Message"
    elif line.startswith("Prompt:"):
        current_section = "Prompt"
    elif line.startswith("Action:"):
        current_section = "Action"
    else:
        if current_section == "System Message":
            system_messages.append(line)
        elif current_section == "Prompt":
            prompts.append(line)
        elif current_section == "Action":
            actions.append(line)

system_prompt = system_messages[0].split("content='")[-1].strip()
base_prompt = prompts[0].split("content=\'")[-1].strip()
# ideal_actions = actions[:6]

In [3]:
import sys
sys.path.append("../demo_agent")
from agents.legacy.dynamic_prompting import Think, Memory, ActionSpace, Flags

flags=Flags(
    use_html=True,
    use_ax_tree=True,
    use_thinking=True,  # "Enable the agent with a memory (scratchpad)."
    use_error_logs=True,  # "Prompt the agent with the error logs."
    use_memory=False,  # "Enables the agent with a memory (scratchpad)."
    use_history=True,
    use_diff=False,  # "Prompt the agent with the difference between the current and past observation."
    use_past_error_logs=True,  # "Prompt the agent with the past error logs."
    use_action_history=True,  # "Prompt the agent with the action history."
    multi_actions=True,
    action_space="bid",
    use_abstract_example=True,  # "Prompt the agent with an abstract example."
    use_concrete_example=True,  # "Prompt the agent with a concrete example."
    use_screenshot=False,
    enable_chat=True,
    demo_mode="default",
)

think = Think(visible=lambda: flags.use_thinking)
memory = Memory(visible=lambda: flags.use_memory)
action_space = ActionSpace(flags)

def parser(text_answer):
    ans_dict = {}
    try:
        ans_dict.update(think._parse_answer(text_answer))
        ans_dict.update(memory._parse_answer(text_answer))
        ans_dict.update(action_space._parse_answer(text_answer))
    except Exception as e:
        ans_dict['action'] = None
        ans_dict['think'] = None

    return ans_dict, True, ""

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
from langchain_openai import ChatOpenAI
from langchain.schema import HumanMessage, SystemMessage


with open("../openai_key.txt", "r") as file:
    api_key = file.read().strip()

model = ChatOpenAI(
            model_name="gpt-4o-mini",
            temperature=0.1,
            max_tokens=2_000,
            api_key=api_key
        ).bind(logprobs=True)

## Preliminaries: clean html and build prompts

In [5]:
import re
from html.parser import HTMLParser

class HTMLCleaner(HTMLParser):
    def __init__(self):
        super().__init__()
        self.output = []
        self.tag_stack = []
        self.current_content = []
        self.interactive_elements = {
            'a', 'button', 'input', 'select', 'textarea', 'label', 'fieldset',
            'legend', 'datalist', 'output', 'option', 'optgroup'
        }

    def is_interactive(self, tag, attrs):
        if tag.lower() in self.interactive_elements:
            return True
        return any(attr[0] == 'onclick' for attr in attrs)

    def handle_starttag(self, tag, attrs):
        if tag.lower() == 'img':
            return
        
        is_interactive = self.is_interactive(tag, attrs)
        bid_attr = next((attr for attr in attrs if attr[0] == 'bid'), None)
        
        if bid_attr and not is_interactive:
            bid_attr = None

        self.tag_stack.append((tag, bid_attr, len(self.output)))
        self.current_content.append([])

    def handle_endtag(self, tag):
        if tag.lower() == 'img':
            return

        if self.tag_stack and self.tag_stack[-1][0] == tag:
            start_tag, bid_attr, start_index = self.tag_stack.pop()
            content = ''.join(self.current_content.pop()).strip()

            if content:
                if bid_attr:
                    self.output.insert(start_index, f'<{start_tag} bid="{bid_attr[1]}">')
                else:
                    self.output.insert(start_index, f'<{start_tag}>')
                self.output.append(content)
                self.output.append(f'</{tag}>')

            if self.current_content:
                self.current_content[-1].extend(self.output[start_index:])
                del self.output[start_index:]

    def handle_data(self, data):
        normalized_data = re.sub(r'(\\n|\n|\r)+', '', data)
        normalized_data = re.sub(r'\s+', ' ', normalized_data)
        if self.current_content:
            self.current_content[-1].append(normalized_data)
        else:
            self.output.append(normalized_data)

def clean_html(html_content):
    html_content = html_content.replace('\\n', '\n')
    cleaner = HTMLCleaner()
    cleaner.feed(html_content)
    return ''.join(cleaner.output).strip()

In [6]:
html = base_prompt.split("# ")[4]
c_html = clean_html(html)
len(html), len(c_html)

(288727, 32333)

In [7]:
simple_action_space = """Action space:\\n\\1 type of actions are available.\\n\\nclick(bid: int)\\n    Description: Click an element.\\n    Examples:\\n        click(\\\'151\\\')\\n\\n    Multiple actions can be provided at once, but will be executed sequentially without any feedback from the page.\\nExample:\\nfill(\\\'a12\\\', \\\'example with "quotes"\\\')\\nclick(\\\'a51\\\')\\nclick(\\\'48\\\', button=\\\'middle\\\', modifiers=[\\\'Shift\\\'])\\n\\n"""

def build_action_prompt(base_prompt, actions, action_thoughts, thoughts):
    base_splits = base_prompt.split("# ")
    # change html obs
    html = base_splits[4][7:]
    base_splits[4] = "HTML:" + clean_html(html)

    hist_split = base_splits[6]
    hist_instruction = hist_split[:-4]
    hist_end = hist_split[-4:]
    act_thought_list = [a + " #" + t for a, t in zip(actions, action_thoughts)]
    new_hist = hist_instruction + " Actions: [" + ", ".join(act_thought_list) + "]; Thoughts [" + ", ".join(thoughts) + "]" + hist_end
    base_splits[6] = new_hist
    base_splits[7] = simple_action_space
    new_prompt = "# ".join(base_splits)
    new_prompt += " # Final Instruction: Given that the last actions are: " + ", ".join(actions) + ", what would you do next? Do not pick an action that you already tried."
    return new_prompt

In [486]:
# ground truth
actions = []
action_thoughts = []
thoughts = []

new_prompt = build_action_prompt(base_prompt, actions, action_thoughts, thoughts)
chat_messages = [
    SystemMessage(content=system_prompt),
    HumanMessage(content=new_prompt+"Think about what to do and then predict all actions at once to complete the task."),
]
out = model.invoke(chat_messages)
ans_dict = parser(out.content)

In [487]:
print(f"Predicted:\n{ans_dict[0]['action']}")
print(f"\nIdeal:\n{ideal_actions}")

Predicted:
fill('memory_aos_phantom_z1c8_065_cg1l_3', '')  # Select 24GB RAM
fill('hard_drivesolid_state_drive_aos_phantom_z1c8_065_cg1p_3', '')  # Select 2TB SSD Storage
click('sw_final_cut_pro_z1c8_065_cg37_2')  # Select Final Cut Pro
click('sw_logic_pro_z1c8_065_cg39_2')  # Select Logic Pro
click('1208')  # Click add to cart button

Ideal:
click('1007') # Select 24GB unified memory
click('1038') # Select 2TB SSD storage
click('1061') # Select 96W USB-C Power Adapter
click('1112') # Select Final Cut Pro software
click('1135') # Select Logic Pro software
click('1209') # Add to bag


In [8]:
from typing import List
from langchain_core.pydantic_v1 import BaseModel, Field

class Action(BaseModel):
    think: str = Field(description="The goal of the next step to eventually accomplish the task")
    action: str = Field(description="The single action to accomplish the task. Should be click(<int>)")

class Plan(BaseModel):
    think: str = Field(description="The overall goal and thoughts to accomplish the task")
    plan: List[Action] = Field(description="Possible actions to take next to get closer to the goal. Should have length of at least 7")

structured_llm = model.with_structured_output(Plan, include_raw=True)
greedy_llm = model.with_structured_output(Action, include_raw=True)

In [9]:
def expand_predict(action_prompt):
    chat_messages = [
        SystemMessage(content=system_prompt),
        HumanMessage(content=action_prompt),
    ]
    answer = structured_llm.invoke(chat_messages)
    return answer

def greedy_predict(action_prompt):
    chat_messages = [
        SystemMessage(content=system_prompt),
        HumanMessage(content=action_prompt),
    ]
    answer = greedy_llm.invoke(chat_messages)
    return answer

In [468]:
actions = []
action_thoughts = []
thoughts = []

new_prompt = build_action_prompt(base_prompt, actions, action_thoughts, thoughts)

expand_answer = expand_predict(new_prompt)
greedy_answer = greedy_predict(new_prompt)

print([p.action for p in expand_answer.plan])
print(greedy_answer.action)

## Build MCTS

In [964]:
class Node:
    def __init__(self, actions, action_thoughts, thoughts, parent=None, depth=0):
        self.actions = actions
        self.action_thoughts = action_thoughts
        self.thoughts = thoughts
        self.parent = parent
        self.children = []
        self.visits = 1  # add 1 to avoid division by 0 and thus never have inf value
        self.value = 0.01
        self.depth = depth

In [957]:
# define tree helper functions

def print_tree(node, indent=""):
    if len(node.actions) > 0:
        # print(f"{indent}Thoughts: {node.thoughts[-1]}")
        # print(f"{indent}Action Thoughts: {node.action_thoughts[-1]}")
        print(f"{indent}Action: {node.actions[-1]}")
        print(f"{indent}Value: {node.value}")
    print(f"{indent}Children: {len(node.children)}")
    for child in node.children:
        print_tree(child, indent + "  ")

# create all possible trajectories of actions from tree
def get_trajectories(node):
    if not node.children:
        return [[node.actions]]
    
    trajectories = []
    for child in node.children:
        child_trajectories = get_trajectories(child)
        for trajectory in child_trajectories:
            trajectories.append([node.actions] + trajectory)
    
    trajectories = [t[-1] for t in trajectories]
    return trajectories

In [908]:
import copy
import math
import random

def select(node: Node, alpha: float = 1.0) -> Node:
    while node.children:
        if len(node.children) < 1:
            return node
        node = ucb_select(node, alpha)
    return node

def ucb_select(node: Node, alpha: float = 1.0) -> Node:
    scores = [ucb_score(c, node.visits, alpha) for c in node.children]
    max_score = max(scores)
    max_children = [c for c, s in zip(node.children, scores) if s == max_score]
    return random.choice(max_children)

def ucb_score(node: Node, total_visits: int, alpha: float = 1.0) -> float:
    if node.visits == 0:
        return float('inf')
    return (node.value / node.visits) + alpha * math.sqrt(2 * math.log(total_visits) / node.visits)

def expand(node: Node) -> Node:
    # this should be taking an action and getting to a new state
    answer = expand_predict(build_action_prompt(base_prompt, root.actions, root.action_thoughts, root.thoughts))
    parsed_answer = answer['parsed']
    new_thought = parsed_answer.think
    for action in parsed_answer.plan:
        new_actions = node.actions + [action.action]
        new_action_thoughts = node.action_thoughts + [action.think]
        new_thoughts = node.thoughts + [new_thought]
        child = Node(copy.deepcopy(new_actions),
                     copy.deepcopy(new_action_thoughts),
                     copy.deepcopy(new_thoughts),
                     parent=node,
                     depth=node.depth+1)
        node.children.append(child)

def backpropagate(node: Node, reward: float):
    while node:
        node.visits += 1
        node.value += reward
        node = node.parent

def verify_success(actions):
    if actions == ideal_actions:
        return True
    return False

def best_child(node: Node) -> Node:
    max_visits = max(c.visits for c in node.children)
    candidates = [c for c in node.children if c.visits == max_visits]
    return max(candidates, key=lambda c: c.value)

0.01

In [924]:
def get_t_value(trajectory):
    # FIXME this is a hacky way to handle the fact that the bids are not assigned nicely
    for i, t in enumerate(trajectory):
        if t == 'click(1006)':
            trajectory[i] = 'click(1007)'
        if t == 'click(1039)':
            trajectory[i] = 'click(1038)'

    correct_actions = 0.
    for a, b in zip(trajectory, ideal_actions):
        if a == b:
            correct_actions += 1
        else:
            break

    score = correct_actions / len(ideal_actions)

    return score

t = ['click(1007)', 'click(1039)', 'click(1061)', 'click(1112)\nclick(1136)']
get_t_value(ideal_actions), get_t_value(t)

(1.0, 0.5)

In [961]:
# greedy baseline
max_depth = 6

actions = []
action_thoughts = []
thoughts = []

root = Node(actions, action_thoughts, thoughts)

for _ in range(max_depth):
    out = greedy_predict(build_action_prompt(base_prompt, root.actions, root.action_thoughts, root.thoughts))
    answer = out['parsed']
    root.action_thoughts += [answer.think]
    root.actions += [answer.action]
print(root.actions, get_t_value(root.actions))

['click(1007)', 'click(1038)', 'click(1038)', 'click(1038)', 'click(1061)', 'click(1062)'] 0.3333333333333333


In [971]:
from tqdm import tqdm

# get initial options
max_depth = 6
max_iters = 200
alpha = 0.7

actions = []
action_thoughts = []
thoughts = []

root = Node(actions, action_thoughts, thoughts)

In [972]:
for iter_idx in tqdm(range(max_iters)):
    # selection
    # - TODO figure out how to do this proportional to logprobs
    selected_node = select(root, alpha)

    # expansion + simulation
    remaining_depth = max_depth - selected_node.depth

    value = get_t_value(selected_node.actions)

    if remaining_depth > 0:
        for _ in range(remaining_depth):
            expand(selected_node)
            selected_node = select(selected_node)
            value = get_t_value(selected_node.actions)
            if value == 1.0:
                break

    # backpropagation
    backpropagate(selected_node, value)

100%|██████████| 200/200 [43:11<00:00, 12.96s/it]


In [973]:
best_c = copy.deepcopy(root)
while best_c.children:
    print([c.actions[-1] for c in best_c.children])
    print([c.value for c in best_c.children])
    print([c.visits for c in best_c.children])
    best_c = best_child(best_c)
print(f"\nBest trajectory:{best_c.actions}\nScore: {get_t_value(best_c.actions)}")

['click(1007)', 'click(1038)', 'click(1061)', 'click(1112)', 'click(1209)', 'click(1204)', 'click(1204)']
[15.509999999999993, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01]
[72, 23, 22, 23, 22, 23, 22]
['click(1007)', 'click(1038)', 'click(1209)', 'click(1100)', 'click(1118)', 'click(1141)', 'click(1204)']
[1.5100000000000002, 6.676666666666665, 1.5100000000000002, 1.5100000000000002, 1.5100000000000002, 1.3433333333333335, 1.5100000000000002]
[10, 19, 10, 10, 10, 9, 10]
['click(1007)', 'click(1038)', 'click(1061)', 'click(1100)', 'click(1112)', 'click(1135)', 'click(1209)']
[0.6766666666666666, 0.6766666666666666, 2.01, 0.6766666666666666, 0.6766666666666666, 1.01, 1.01]
[3, 3, 5, 3, 3, 4, 4]
['click(1007)', 'click(1038)', 'click(1061)', 'click(1112)', 'click(1209)', 'click(1204)', 'click(1204)']
[0.51, 1.01, 0.51, 0.01, 0.01, 0.01, 0.01]
[2, 3, 2, 1, 1, 1, 1]
['click(1006)', 'click(1039)', 'click(1061)', 'click(1072)', 'click(1112)', 'click(1135)', 'click(1209)']
[0.01, 0.01, 0.01, 0.01, 0.01,