diff --git a/visual-tree-search-backend/README.md b/visual-tree-search-backend/README.md index 63961cb..9ddb92d 100644 --- a/visual-tree-search-backend/README.md +++ b/visual-tree-search-backend/README.md @@ -82,4 +82,12 @@ python run_demo_treesearch_async.py \ --goal "search running shoes, click on the first result" \ --iterations 3 \ --max_depth 3 +``` + +## 7. Add LATS agent +* test run_demo_treesearch_async.py +* test web socket +``` +uvicorn app.main:app --host 0.0.0.0 --port 3000 +python test/test-tree-search-ws-lats.py ``` \ No newline at end of file diff --git a/visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/__init__.py b/visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/lats_agent.py b/visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/lats_agent.py new file mode 100644 index 0000000..8fe68db --- /dev/null +++ b/visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/lats_agent.py @@ -0,0 +1,776 @@ +"""Language-based Action Tree Search (LATS) Agent implementation.""" + +import time +from typing import Any, Optional, Tuple, List +import os +from openai import OpenAI +from datetime import datetime +import aiohttp +from dotenv import load_dotenv +load_dotenv() + +from .lats_node import LATSNode, Observation +from ...core_async.config import AgentConfig + +from ...webagent_utils_async.action.highlevel import HighLevelActionSet +from ...webagent_utils_async.utils.playwright_manager import AsyncPlaywrightManager, setup_playwright +from .tree_vis import RED, better_print, print_trajectory, collect_all_nodes, GREEN, RESET, print_entire_tree +from .trajectory_score import create_llm_prompt, score_trajectory_with_openai +from ...replay_async import generate_feedback, playwright_step_execution, locate_element_from_action +from ...webagent_utils_async.browser_env.observation import extract_page_info, observe_features +from ...webagent_utils_async.action.prompt_functions import generate_actions_with_observation +from ...webagent_utils_async.evaluation.feedback import generate_feedback_with_screenshot +from ...webagent_utils_async.utils.utils import urls_to_images + + +from ...webagent_utils_async.utils.utils import parse_function_args, locate_element +from ...evaluation_async.evaluators import goal_finished_evaluator +from ...webagent_utils_async.action.prompt_functions import extract_top_actions +from ...webagent_utils_async.browser_env.observation import extract_page_info +from .lats_node import LATSNode +from .tree_vis import better_print, print_trajectory, collect_all_nodes, GREEN, RESET, print_entire_tree +from .trajectory_score import create_llm_prompt, score_trajectory_with_openai +from ...webagent_utils_async.action.utils import execute_action +from ...webagent_utils_async.action.prompt_functions import extract_top_actions, is_goal_finished +from ...webagent_utils_async.browser_env.observation import extract_page_info +from ...webagent_utils_async.evaluation.feedback import capture_post_action_feedback + +openai_client = OpenAI() + +class LATSAgent: + """ + Language-based Action Tree Search Agent implementation. + + This agent uses MCTS-like tree search to find optimal action sequences for web navigation tasks. + + Attributes: + starting_url (str): The initial URL to start from + model_name (str): Name of the language model to use + goal (str): The goal state to achieve + playwright_manager (PlaywrightManager): Manager for browser automation + num_simulations (int): Number of simulations to run + exploration_weight (float): Exploration vs exploitation trade-off parameter + """ + + def __init__( + self, + starting_url: str, + messages: list[dict[str, Any]], + goal: str, + images: list, + playwright_manager: AsyncPlaywrightManager, + config: AgentConfig, + ): + """Initialize the LATS Agent.""" + # no action grounding model, just one step to geneate both action natural language description and action at the same time + self.starting_url = starting_url + self.goal = goal + self.image_urls = images + self.images = urls_to_images(self.image_urls) + + self.messages = messages + if len(images) == 0: + self.messages.append({"role": "user", "content": f"The goal is: {self.goal}"}) + else: + self.messages.append({"role": "user", "content": f"The goal is: {self.goal}"}) + + self.playwright_manager = playwright_manager + + self.config = config + + # set bid, only click, fill, hoover, drag and draw + self.agent_type = ["bid"] + self.action_set = HighLevelActionSet( + subsets=self.agent_type, strict=False, multiaction=False, demo_mode="default" + ) + self.root_node = LATSNode( + natural_language_description=None, + action=None, + prob=None, + element=None, + goal=self.goal, + parent=None + ) + self.goal_finished = False + self.result_node = None + self.reset_url = os.environ["ACCOUNT_RESET_URL"] + + async def run(self, websocket=None) -> list[LATSNode]: + """ + Run the LATS search and return the best path found. + + Args: + websocket: Optional WebSocket connection for sending updates + + Returns: + list[LATSNode]: Best path from root to terminal node + """ + if websocket: + await websocket.send_json({ + "type": "search_status", + "status": "started", + "message": "Starting LATS search", + "timestamp": datetime.utcnow().isoformat() + }) + + best_node = await self.lats_search(websocket) + print_trajectory(best_node) + + if websocket: + await websocket.send_json({ + "type": "search_complete", + "status": "success" if best_node.reward == 1 else "partial_success", + "score": best_node.reward, + "path": best_node.get_trajectory(), + "timestamp": datetime.utcnow().isoformat() + }) + + return best_node.get_trajectory() + + async def lats_search(self, websocket=None) -> LATSNode: + """ + Perform the main LATS search algorithm. + + Args: + websocket: Optional WebSocket connection for sending updates + + Returns: + LATSNode: Best terminal node found + """ + print(f"") + print(f"{GREEN}START SEARCH{RESET}") + + terminal_nodes = [] + + for i in range(self.config.iterations): + if websocket: + await websocket.send_json({ + "type": "iteration_start", + "iteration": i + 1, + "timestamp": datetime.utcnow().isoformat() + }) + + print(f"") + print(f"") + print(f"Iteration {i + 1}...") + + # Step 1: Selection with websocket update + if websocket: + await websocket.send_json({ + "type": "step_start", + "step": "selection", + "iteration": i + 1, + "timestamp": datetime.utcnow().isoformat() + }) + + node = self.select_node(self.root_node) + + if node is None: + print("All paths lead to terminal nodes with reward 0. Ending search.") + break + + print(f"{GREEN}Tree:{RESET}") + better_print(node=self.root_node, selected_node=node) + print(f"") + + # Step 2: Expansion with websocket update + if websocket: + await websocket.send_json({ + "type": "step_start", + "step": "expansion", + "iteration": i + 1, + "timestamp": datetime.utcnow().isoformat() + }) + + await self.expand_node(node, websocket) + + while node is not None and node.is_terminal and not self.goal_finished: + print(f"Depth limit node found at iteration {i + 1}, reselecting...") + node = self.select_node(self.root_node) + if node is not None: + await self.expand_node(node, websocket) + + if node is None: + # all the nodes are terminal, stop the search + print(f"{RED}All nodes are terminal, stopping search{RESET}") + break + + if self.goal_finished: + print(f"{RED}Goal finished, stopping search{RESET}") + break + + print(f"{GREEN}Tree:{RESET}") + better_print(self.root_node) + print(f"") + + # Step 3: Evaluation + print(f"") + print(f"{GREEN}Step 3: evaluation{RESET}") + await self.evaluate_node(node) + + print(f"{GREEN}Tree:{RESET}") + better_print(self.root_node) + print(f"") + + # Step 4: Simulation + print(f"{GREEN}Step 4: simulation{RESET}") + # # Find the child with the highest value + ## always = 1 + reward, terminal_node = await self.simulate(max(node.children, key=lambda child: child.value), max_depth=self.config.max_depth, num_simulations=1) + terminal_nodes.append(terminal_node) + + if reward == 1: + return terminal_node + + # Step 5: Backpropagation + print(f"{GREEN}Step 5: backpropagation{RESET}") + self.backpropagate(terminal_node, reward) + print(f"{GREEN}Tree:{RESET}") + better_print(self.root_node) + print(f"") + + # Send tree update after each iteration + if websocket: + tree_data = self._get_tree_data() + await websocket.send_json({ + "type": "tree_update", + "tree": tree_data, + "timestamp": datetime.utcnow().isoformat() + }) + + # Find best node + all_nodes_list = collect_all_nodes(self.root_node) + all_nodes_list.extend(terminal_nodes) + + ## temp change: if reward is the same, choose the deeper node + best_child = max(all_nodes_list, key=lambda x: (x.reward, x.depth)) + + if best_child.reward == 1: + print("Successful trajectory found") + else: + print("Unsuccessful trajectory found") + await self.playwright_manager.close() + + return best_child if best_child is not None else self.root_node + + def select_node(self, node: LATSNode) -> Optional[LATSNode]: + """ + Select a node for expansion using UCT. + + Args: + node: Root node to start selection from + + Returns: + Optional[LATSNode]: Selected node or None if all paths exhausted + """ + if node.is_terminal: + return None + return node.get_best_leaf() + + async def expand_node(self, node: LATSNode, websocket=None) -> None: + """ + Expand a node by generating its children. + + Args: + node: Node to expand + websocket: Optional WebSocket connection for sending updates + """ + if websocket: + await websocket.send_json({ + "type": "node_expanding", + "node_id": id(node), + "timestamp": datetime.utcnow().isoformat() + }) + + children = await self.generate_children(node, websocket) + + for child in children: + node.add_child(child) + if websocket: + await websocket.send_json({ + "type": "node_created", + "node_id": id(child), + "parent_id": id(node), + "action": child.action, + "description": child.natural_language_description, + "timestamp": datetime.utcnow().isoformat() + }) + + if children and children[0].goal_finish_feedback.is_done: + self.set_goal_finished(children[0]) + if websocket: + await websocket.send_json({ + "type": "goal_finished", + "node_id": id(children[0]), + "timestamp": datetime.utcnow().isoformat() + }) + return + + node.check_terminal() + + async def evaluate_node(self, node: LATSNode) -> None: + """ + Evaluate a node using LLM scoring. + + Args: + node: Node to evaluate + + Returns: + float: Evaluation score + """ + scores = [] + print(f"{GREEN}-- total {len(node.children)} children to evaluate:{RESET}") + for i, child in enumerate(node.children): + print(f"{GREEN}--- evaluating child {i+1}...{RESET}") + if child.is_terminal: + score = 0 + else: + trajectory = child.get_trajectory() + prompt = create_llm_prompt(trajectory, self.goal) + result = score_trajectory_with_openai(prompt, openai_client, self.config.evaluation_model, child.observation.image) + score = result["overall_score"] + scores.append(score) + + for child, score in zip(node.children, scores): + child.value = score + child.reward = score + + async def simulate(self, node: LATSNode, max_depth: int = 2, num_simulations=1) -> tuple[float, LATSNode]: + """ + Perform a rollout simulation from a node. + + Args: + node: Starting node for rollout + max_depth: Maximum depth to simulate to + + Returns: + tuple[float, LATSNode]: (Score of the rollout, Terminal node reached) + """ + depth = node.depth + print("print the trajectory") + print_trajectory(node) + print("print the entire tree") + print_entire_tree(self.root_node) + return await self.rollout(node, max_depth=max_depth) + + async def send_completion_request(self, plan, depth, node, trajectory=[]): + print("print the trajectory") + print_trajectory(node) + print("print the entire tree") + print_entire_tree(self.root_node) + + if depth >= self.config.max_depth: + return trajectory, node + + context = await self.playwright_manager.get_context() + page = await self.playwright_manager.get_page() + # Extract page information + time.sleep(3) + page_info = await extract_page_info(page, fullpage=True, log_folder=self.config.log_folder) + updated_actions = await extract_top_actions( + trajectory, self.goal, self.images, page_info, self.action_set, openai_client, + features=["axtree"], elements_filter="som", branching_factor=self.config.branching_factor, + log_folder=self.config.log_folder, fullpage=True, + action_generation_model=self.config.action_generation_model, + action_grounding_model=self.config.action_grounding_model + ) + next_action = updated_actions[0] + retry_count = self.config.retry_count if hasattr(self.config, 'retry_count') else 1 # Default retries if not set + + for attempt in range(retry_count): + try: + # Convert action to Python code + code, function_calls = self.action_set.to_python_code(next_action["action"]) + + # Locate element + if len(function_calls) == 1: + for function_name, function_args in function_calls: + extracted_number = parse_function_args(function_args) + element = await locate_element(page, extracted_number) + next_action["element"] = element + + # Execute action + await execute_action(next_action, self.action_set, page, context, self.goal, page_info['interactive_elements'], + self.config.log_folder) + feedback = await capture_post_action_feedback(page, next_action, self.goal, self.config.log_folder) + trajectory.append({'action': next_action['action'], 'feedback': feedback}) + action_str = next_action["action"] + + print(f"The action is: {action_str} - The action result is: {feedback}") + + # Check if goal is finished + messages = [{"role": "system", "content": "The goal is {}, Is the overall goal finished?".format(self.goal)}] + for item in trajectory: + action = item['action'] + feedback = item['feedback'] + messages.append({"role": "user", "content": 'action is: {}'.format(action)}) + messages.append({"role": "user", "content": 'action feedback is: {}'.format(feedback)}) + + goal_finished = await is_goal_finished(messages, openai_client) + + new_node = LATSNode( + natural_language_description=next_action["natural_language_description"], + action=next_action["action"], + prob=next_action["prob"], + element=next_action["element"], + goal=node.goal, + parent=node + ) + + if goal_finished: + return trajectory, new_node + + return await self.send_completion_request(plan, depth + 1, new_node, trajectory) + + except Exception as e: + print(f"Attempt {attempt + 1} failed with error: {e}") + if attempt + 1 == retry_count: + print("Max retries reached. Skipping this step and retrying the whole request.") + # Retry the entire request from the same state + return await self.send_completion_request(plan, depth, node, trajectory) + + # If all retries and retries of retries fail, return the current trajectory and node + return trajectory, node + + + async def rollout(self, node: LATSNode, max_depth: int = 2)-> tuple[float, LATSNode]: + # Reset browser state + await self._reset_browser() + path = self.get_path_to_root(node) + + print("execute path") + # Execute path + + messages = [] + trajectory = [] + + for n in path[1:]: # Skip root node + success = await playwright_step_execution( + n, + self.goal, + self.playwright_manager, + is_replay=False, + log_folder=self.config.log_folder + ) + if not success: + return 0, n + if not n.feedback: + n.feedback = await generate_feedback( + self.goal, + n.natural_language_description, + self.playwright_manager, + ) + trajectory.append({ + "action": n.action, + "feedback": n.feedback + }) + ## call the prompt agent + print("current depth: ", len(path) - 1) + print("max depth: ", self.config.max_depth) + trajectory, node = await self.send_completion_request(self.goal, len(path) - 1, node=n, trajectory=trajectory) + print("print the trajectory") + print_trajectory(node) + print("print the entire tree") + print_entire_tree(self.root_node) + + page = await self.playwright_manager.get_page() + page_info = await extract_page_info(page, self.config.fullpage, self.config.log_folder) + + messages = [{"role": "user", "content": f"Action is: {n.action}"} for n in path[1:]] + goal_finished, confidence_score = goal_finished_evaluator( + messages, + openai_client, + self.goal, + page_info['screenshot'] + ) + print("evaluating") + + score = confidence_score if goal_finished else 0 + + return score, node + + def backpropagate(self, node: LATSNode, value: float) -> None: + """ + Backpropagate values through the tree. + + Args: + node: Current node to start backpropagation from + value: Value to propagate upwards + """ + while node: + node.visits += 1 + node.value = (node.value * (node.visits - 1) + value) / node.visits + node = node.parent + + async def _reset_browser(self, websocket=None) -> Optional[str]: + """Reset the browser to initial state and return the live browser URL if available.""" + await self.playwright_manager.close() + + ## reset account using api-based account reset + if self.config.account_reset: + if websocket: + await websocket.send_json({ + "type": "account_reset", + "status": "started", + "timestamp": datetime.utcnow().isoformat() + }) + + try: + # Use aiohttp instead of curl + async with aiohttp.ClientSession() as session: + headers = {'Connection': 'close'} # Similar to curl -N + async with session.get(self.reset_url, headers=headers) as response: + if response.status == 200: + data = await response.json() + print(f"Account reset successful: {data}") + if websocket: + await websocket.send_json({ + "type": "account_reset", + "status": "success", + "data": data, + "timestamp": datetime.utcnow().isoformat() + }) + else: + error_msg = f"Account reset failed with status {response.status}" + print(error_msg) + if websocket: + await websocket.send_json({ + "type": "account_reset", + "status": "failed", + "reason": error_msg, + "timestamp": datetime.utcnow().isoformat() + }) + + except Exception as e: + print(f"Error during account reset: {e}") + if websocket: + await websocket.send_json({ + "type": "account_reset", + "status": "failed", + "reason": str(e), + "timestamp": datetime.utcnow().isoformat() + }) + + try: + # Create new playwright manager + self.playwright_manager = await setup_playwright( + storage_state=self.config.storage_state, + headless=self.config.headless, + mode=self.config.browser_mode + ) + page = await self.playwright_manager.get_page() + live_browser_url = None + if self.config.browser_mode == "browserbase": + live_browser_url = await self.playwright_manager.get_live_browser_url() + session_id = await self.playwright_manager.get_session_id() + else: + session_id = None + live_browser_url = None + await page.goto(self.starting_url, wait_until="networkidle") + + # Send success message if websocket is provided + if websocket: + if self.config.storage_state: + await websocket.send_json({ + "type": "browser_setup", + "status": "success", + "message": f"Browser successfully initialized with storage state file: {self.config.storage_state}", + "live_browser_url": live_browser_url, + "session_id": session_id, + "timestamp": datetime.utcnow().isoformat() + }) + else: + await websocket.send_json({ + "type": "browser_setup", + "status": "success", + "message": "Browser successfully initialized", + "live_browser_url": live_browser_url, + "session_id": session_id, + "timestamp": datetime.utcnow().isoformat() + }) + + return live_browser_url, session_id + except Exception as e: + print(f"Error setting up browser: {e}") + if websocket: + await websocket.send_json({ + "type": "browser_setup", + "status": "failed", + "reason": str(e), + "timestamp": datetime.utcnow().isoformat() + }) + return None, None + + async def observe(self) -> None: + page = await self.playwright_manager.get_page() + page_info = await extract_page_info(page, self.config.fullpage, self.config.log_folder) + feature_text = await observe_features( + page_info, + features=self.config.features, + elements_filter=self.config.elements_filter, + log_folder=self.config.log_folder, + fullpage=self.config.fullpage + ) + screenshot = page_info['screenshot_som'] + observation = Observation( + text=feature_text, + image=screenshot, + ) + return observation + + async def execute_action_trajectory(self, action_trajectory: list[dict]) -> None: + if not action_trajectory: + return True + + await self._reset_browser() + print("taking action trajectory") + for action_data in action_trajectory: + print("action_data") + print(action_data) + + # Convert action_data dict to LATSNode + temp_node = LATSNode( + natural_language_description=action_data["natural_language_description"], + action=action_data["action"], + prob=0, + element=action_data["element"], + goal=self.goal, + parent=None # No parent needed for temporary node + ) + + success = await playwright_step_execution( + temp_node, # Pass the node instead of raw action_data + self.goal, + self.playwright_manager, + is_replay=False, + log_folder=self.config.log_folder + ) + + if not success: + return False + return True + + async def generate_candidate_actions(self, node: LATSNode) -> list[dict]: + trajectory = node.get_trajectory() + action_trajectory = node.get_action_trajectory() + await self.execute_action_trajectory(action_trajectory) + observation = await self.observe() + # only root node has no observation at this point + if node.observation is None: + node.observation = observation + actions = await generate_actions_with_observation( + trajectory, + self.goal, + self.images, + openai_client=openai_client, + action_set=self.action_set, + feature_text=observation.text, + screenshot=observation.image, + branching_factor=self.config.branching_factor, + log_folder=self.config.log_folder, + action_generation_model=self.config.action_generation_model, + ) + + page = await self.playwright_manager.get_page() + valid_actions = [] + for action_data in actions: + if action_data["action"] == "FINISH": + continue + + is_bid_action, element_data = await locate_element_from_action(page, action_data["action"]) + if is_bid_action and not element_data: + continue + + action_data['element'] = element_data + valid_actions.append(action_data) + return valid_actions + + async def generate_children(self, node: LATSNode, websocket=None) -> list[LATSNode]: + print(f"{GREEN}-- generating candidate actions...{RESET}") + + children = [] + + action_trajectory = node.get_action_trajectory() + candidate_actions = await self.generate_candidate_actions(node) + print(f"{GREEN}-- generated {len(candidate_actions)} actions{RESET}") + for action_data in candidate_actions: + print(f"{GREEN}--- {action_data['action']}{RESET}") + print(f"{GREEN}--- {action_data['natural_language_description']}{RESET}") + + print(f"") + print(f"{GREEN}-- executing candidate trajectories{RESET}") + for i, action_data in enumerate(candidate_actions): + + candidate_action_trajectory = action_trajectory + [action_data] + print(f"{GREEN}--- trajectory {i+1}:{RESET}") + for action in candidate_action_trajectory: + print(f"{GREEN}---- {action['action']}{RESET}") + print(f"{GREEN}---- {action['natural_language_description']}{RESET}") + executed_successfully = await self.execute_action_trajectory(candidate_action_trajectory) + if not executed_successfully: + # not executed successfully, give up this candidate + print(f"{RED}--- failed to execute action trajectory{RESET}") + continue + + observation = await self.observe() + print(f"{GREEN}--- generate feedback...{RESET}") + feedback = await generate_feedback_with_screenshot( + self.goal, + action_data["natural_language_description"], + observation.image, + model=self.config.feedback_model, + ) + print(f"feedback: is_done: {feedback.is_done}, explanation: {feedback.explanation}") + + child = LATSNode( + natural_language_description=action_data["natural_language_description"], + action=action_data["action"], + prob=action_data["prob"], + element=action_data["element"], + goal=node.goal, + ) + child.observation = observation + child.goal_finish_feedback = feedback + if feedback.is_done: + # the goal is finished, stop the search + return [child] + + children.append(child) + + if node.depth + 1 >= self.config.max_depth: + child.is_terminal = True + + return children + + def set_goal_finished(self, node: LATSNode) -> None: + self.goal_finished = True + self.result_node = node + + def get_path_to_root(self, node: LATSNode) -> List[LATSNode]: + path = [] + current = node + while current: + path.append(current) + current = current.parent + return list(reversed(path)) + + def _get_tree_data(self): + """Get tree data in a format suitable for visualization""" + nodes = collect_all_nodes(self.root_node) + tree_data = [] + + for node in nodes: + node_data = { + "id": id(node), + "parent_id": id(node.parent) if node.parent else None, + "action": node.action if node.action else "ROOT", + "description": node.natural_language_description, + "depth": node.depth, + "is_terminal": node.is_terminal, + "value": node.value, + "visits": node.visits, + "reward": node.reward + } + tree_data.append(node_data) + + return tree_data diff --git a/visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/lats_node.py b/visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/lats_node.py new file mode 100644 index 0000000..911255f --- /dev/null +++ b/visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/lats_node.py @@ -0,0 +1,207 @@ +import numpy as np +from dataclasses import dataclass +from typing import Optional +from pydantic import BaseModel +import base64 +from ...webagent_utils_async.evaluation.feedback import Feedback + +@dataclass +class Element: + """Represents a DOM element with its properties.""" + text: str + tag: str + id: str + title: str + ariaLabel: str + name: str + value: str + placeholder: str + class_name: str # Changed from 'class' as it's a reserved keyword + role: str + unique_selector: str + selector_uniqueness_validated: bool + +class Observation(BaseModel): + text: str + image: Optional[bytes] = None + image_base64: Optional[str] = None + + def get_base64_image(self): + if self.image_base64 is None: + self.image_base64 = base64.b64encode(self.image).decode('utf-8') + return self.image_base64 + +class LATSNode: + """ + A node class for Language-based Action Tree Search (LATS). + + This class implements a tree structure for MCTS-like search algorithms, + specifically designed for language-based action planning in UI interactions. + + Attributes: + natural_language_description (str): Human-readable description of the action + action (str): The actual action to be executed + prob (float): Probability or confidence score for this action + element (Element): DOM element associated with this action + goal (str): The target goal state + parent (Optional[LATSNode]): Parent node in the tree + children (list[LATSNode]): Child nodes in the tree + visits (int): Number of times this node has been visited + value (float): Accumulated value/score of this node + depth (int): Depth of this node in the tree + is_terminal (bool): Whether this node is a terminal state + reward (float): Reward received at this node + exhausted (bool): Whether all children have been explored + em (float): Exact match score for evaluation + """ + + def __init__( + self, + natural_language_description: str, + action: str, + prob: float, + element: dict, # Using dict instead of Element for backward compatibility + goal: str, + parent: Optional['LATSNode'] = None + ) -> None: + """ + Initialize a new LATSNode. + + Args: + natural_language_description: Human-readable description of the action + action: The actual action to be executed + prob: Probability or confidence score for this action + element: DOM element associated with this action + goal: The target goal state + parent: Parent node in the tree, if any + """ + self.natural_language_description = natural_language_description + self.action = action + self.prob = prob + self.element = element + self.feedback = '' + self.goal_finish_feedback: Optional[Feedback] = None + self.parent = parent + self.goal = goal + self.children: list[LATSNode] = [] + self.visits = 0 + self.value = 0.0 + self.depth = 0 if parent is None else parent.depth + 1 + self.is_terminal = False + self.reward = 0.0 + self.exhausted = False # If all children are terminal + self.em = 0.0 # Exact match, evaluation metric + self.observation: Optional[Observation] = None + + def uct(self) -> float: + """ + Calculate the UCT (Upper Confidence Bound for Trees) value for this node. + + Returns: + float: The UCT value for this node. If the node has never been visited, + returns the node's current value. + """ + if self.visits == 0: + return self.value + return self.value / self.visits + np.sqrt(2 * np.log(self.parent.visits) / self.visits) + + def get_best_leaf(self) -> 'LATSNode': + unfinished_children = [c for c in self.children if not c.is_terminal] + if not unfinished_children: + return self + + best_child = max(unfinished_children, key=lambda x: x.uct()) + return best_child.get_best_leaf() + + def get_action_trajectory(self) -> list[dict]: + trajectory = [] + node = self + # exclude the root node + while node.parent is not None: + trajectory.append({ + "action": node.action, + "natural_language_description": node.natural_language_description, + "element": node.element + }) + node = node.parent + return trajectory[::-1] + + def get_trajectory(self) -> list[dict]: + trajectory = [] + node = self + # exclude the root node + while node.parent is not None: + trajectory.append({ + "natural_language_description": node.natural_language_description, + "action": node.action + }) + node = node.parent + return trajectory[::-1] + + def add_child(self, child: 'LATSNode') -> None: + self.children.append(child) + child.parent = self + child.depth = self.depth + 1 + + def check_terminal(self) -> bool: + if not self.children or all(child.is_terminal for child in self.children): + self.is_terminal = True + if self.parent: + self.parent.check_terminal() + + def __str__(self) -> str: + """ + Get a string representation of the node. + + Returns: + str: A string describing the node's key attributes + """ + return (f"Node(depth={self.depth}, value={self.value:.2f}, " + f"visits={self.visits}, action={self.action}, " + f"feedback={self.feedback})") + + def to_dict(self) -> dict: + """ + Convert the node and its subtree to a dictionary representation. + + Returns: + dict: A dictionary containing all node attributes and recursive + representations of parent and children nodes + """ + return { + 'state': self.state, + 'question': self.question, + 'parent': self.parent.to_dict() if self.parent else None, + 'children': [child.to_dict() for child in self.children], + 'visits': self.visits, + 'value': self.value, + 'depth': self.depth, + 'is_terminal': self.is_terminal, + 'reward': self.reward, + 'em': self.em, + } + + @property + def state(self) -> dict: + """ + Get the current state representation of the node. + + Returns: + dict: A dictionary containing the node's state information + """ + return { + 'natural_language_description': self.natural_language_description, + 'action': self.action, + 'prob': self.prob, + 'element': self.element + } + + @property + def question(self) -> str: + """ + Get the goal/question associated with this node. + + Returns: + str: The goal or question string + """ + return self.goal \ No newline at end of file diff --git a/visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/mcts_agent.py b/visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/mcts_agent.py new file mode 100644 index 0000000..9f1734c --- /dev/null +++ b/visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/mcts_agent.py @@ -0,0 +1,68 @@ +import logging +import time +from typing import Any, Dict, List, Optional +from collections import deque +from datetime import datetime +import os +import json +import subprocess + +from openai import OpenAI +from dotenv import load_dotenv +load_dotenv() +import aiohttp + +from ...core_async.config import AgentConfig + +from ...webagent_utils_async.action.highlevel import HighLevelActionSet +from ...webagent_utils_async.utils.playwright_manager import AsyncPlaywrightManager, setup_playwright +from ...webagent_utils_async.utils.utils import parse_function_args, locate_element +from ...evaluation_async.evaluators import goal_finished_evaluator +from ...replay_async import generate_feedback, playwright_step_execution +from ...webagent_utils_async.action.prompt_functions import extract_top_actions +from ...webagent_utils_async.browser_env.observation import extract_page_info +from .lats_node import LATSNode +from .tree_vis import better_print, print_trajectory, collect_all_nodes, GREEN, RESET, print_entire_tree +from .trajectory_score import create_llm_prompt, score_trajectory_with_openai +from ...webagent_utils_async.utils.utils import urls_to_images + +logger = logging.getLogger(__name__) +openai_client = OpenAI() + +class MCTSAgent: + def __init__( + self, + starting_url: str, + messages: list[dict[str, Any]], + goal: str, + images: list, + playwright_manager: AsyncPlaywrightManager, + config: AgentConfig, + ): + self.starting_url = starting_url + self.goal = goal + self.image_urls = images + self.images = urls_to_images(self.image_urls) + self.messages = messages + self.messages.append({"role": "user", "content": f"The goal is: {self.goal}"}) + + self.playwright_manager = playwright_manager + + self.config = config + + self.agent_type = ["bid", "nav", "file", "select_option"] + self.action_set = HighLevelActionSet( + subsets=self.agent_type, strict=False, multiaction=True, demo_mode="default" + ) + self.root_node = LATSNode( + natural_language_description=None, + action=None, + prob=None, + element=None, + goal=self.goal, + parent=None + ) + self.reset_url = os.environ["ACCOUNT_RESET_URL"] + + async def run(self, websocket=None) -> List[Dict[str, Any]]: + pass \ No newline at end of file diff --git a/visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/simple_search_agent.py b/visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/simple_search_agent.py new file mode 100644 index 0000000..fc41795 --- /dev/null +++ b/visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/simple_search_agent.py @@ -0,0 +1,1149 @@ +import logging +import time +from typing import Any, Dict, List, Optional +from collections import deque +from datetime import datetime +import os +import json +import subprocess + +from openai import OpenAI +from dotenv import load_dotenv +load_dotenv() +import aiohttp + +from ...core_async.config import AgentConfig + +from ...webagent_utils_async.action.highlevel import HighLevelActionSet +from ...webagent_utils_async.utils.playwright_manager import AsyncPlaywrightManager, setup_playwright +from ...webagent_utils_async.utils.utils import parse_function_args, locate_element +from ...evaluation_async.evaluators import goal_finished_evaluator +from ...replay_async import generate_feedback, playwright_step_execution +from ...webagent_utils_async.action.prompt_functions import extract_top_actions +from ...webagent_utils_async.browser_env.observation import extract_page_info +from .lats_node import LATSNode +from .tree_vis import better_print, print_trajectory, collect_all_nodes, GREEN, RESET, print_entire_tree +from .trajectory_score import create_llm_prompt, score_trajectory_with_openai +from ...webagent_utils_async.utils.utils import urls_to_images + +logger = logging.getLogger(__name__) +openai_client = OpenAI() + +class SimpleSearchAgent: + def __init__( + self, + starting_url: str, + messages: list[dict[str, Any]], + goal: str, + images: list, + playwright_manager: AsyncPlaywrightManager, + config: AgentConfig, + ): + self.starting_url = starting_url + self.goal = goal + self.image_urls = images + self.images = urls_to_images(self.image_urls) + self.messages = messages + self.messages.append({"role": "user", "content": f"The goal is: {self.goal}"}) + + self.playwright_manager = playwright_manager + + self.config = config + + self.agent_type = ["bid", "nav", "file", "select_option"] + self.action_set = HighLevelActionSet( + subsets=self.agent_type, strict=False, multiaction=True, demo_mode="default" + ) + self.root_node = LATSNode( + natural_language_description=None, + action=None, + prob=None, + element=None, + goal=self.goal, + parent=None + ) + self.reset_url = os.environ["ACCOUNT_RESET_URL"] + + + async def run(self, websocket=None) -> List[Dict[str, Any]]: + """ + Run the search algorithm based on configuration. + + Args: + websocket: Optional WebSocket connection to send updates to + + Returns: + List[Dict[str, Any]]: List of actions in the best path found + + Raises: + ValueError: If the search algorithm is not supported + """ + algorithm = self.config.search_algorithm.lower() + + if algorithm == "bfs": + logger.info("Starting BFS algorithm") + if websocket: + return await self.bfs_with_websocket(websocket) + else: + return await self.bfs() + elif algorithm == "dfs": + logger.info("Starting DFS algorithm") + if websocket: + return await self.dfs_with_websocket(websocket) + else: + return await self.dfs() + else: + error_msg = f"Unsupported algorithm: {algorithm}" + logger.error(error_msg) + if websocket: + await websocket.send_json({ + "type": "error", + "message": error_msg, + "timestamp": datetime.utcnow().isoformat() + }) + raise ValueError(error_msg) + + async def _reset_browser(self, websocket=None) -> Optional[str]: + """Reset the browser to initial state and return the live browser URL if available.""" + await self.playwright_manager.close() + + ## reset account using api-based account reset + if self.config.account_reset: + if websocket: + await websocket.send_json({ + "type": "account_reset", + "status": "started", + "timestamp": datetime.utcnow().isoformat() + }) + + try: + # Use aiohttp instead of curl + async with aiohttp.ClientSession() as session: + headers = {'Connection': 'close'} # Similar to curl -N + async with session.get(self.reset_url, headers=headers) as response: + if response.status == 200: + data = await response.json() + print(f"Account reset successful: {data}") + if websocket: + await websocket.send_json({ + "type": "account_reset", + "status": "success", + "data": data, + "timestamp": datetime.utcnow().isoformat() + }) + else: + error_msg = f"Account reset failed with status {response.status}" + print(error_msg) + if websocket: + await websocket.send_json({ + "type": "account_reset", + "status": "failed", + "reason": error_msg, + "timestamp": datetime.utcnow().isoformat() + }) + + except Exception as e: + print(f"Error during account reset: {e}") + if websocket: + await websocket.send_json({ + "type": "account_reset", + "status": "failed", + "reason": str(e), + "timestamp": datetime.utcnow().isoformat() + }) + + try: + # Create new playwright manager + self.playwright_manager = await setup_playwright( + storage_state=self.config.storage_state, + headless=self.config.headless, + mode=self.config.browser_mode + ) + page = await self.playwright_manager.get_page() + live_browser_url = None + if self.config.browser_mode == "browserbase": + live_browser_url = await self.playwright_manager.get_live_browser_url() + session_id = await self.playwright_manager.get_session_id() + else: + session_id = None + live_browser_url = None + await page.goto(self.starting_url, wait_until="networkidle") + + # Send success message if websocket is provided + if websocket: + if self.config.storage_state: + await websocket.send_json({ + "type": "browser_setup", + "status": "success", + "message": f"Browser successfully initialized with storage state file: {self.config.storage_state}", + "live_browser_url": live_browser_url, + "session_id": session_id, + "timestamp": datetime.utcnow().isoformat() + }) + else: + await websocket.send_json({ + "type": "browser_setup", + "status": "success", + "message": "Browser successfully initialized", + "live_browser_url": live_browser_url, + "session_id": session_id, + "timestamp": datetime.utcnow().isoformat() + }) + + return live_browser_url, session_id + except Exception as e: + print(f"Error setting up browser: {e}") + if websocket: + await websocket.send_json({ + "type": "browser_setup", + "status": "failed", + "reason": str(e), + "timestamp": datetime.utcnow().isoformat() + }) + return None, None + + async def expand(self, node: LATSNode, websocket=None) -> None: + """ + Expand a node by generating its children. + + Args: + node: Node to expand + websocket: Optional WebSocket connection to send updates to + """ + children_state = await self.generate_children(node, websocket) + for child_state in children_state: + child = LATSNode( + natural_language_description=child_state["natural_language_description"], + action=child_state["action"], + prob=child_state["prob"], + element=child_state["element"], + goal=node.goal, + parent=node + ) + node.children.append(child) + + # Send child creation update if websocket is provided + if websocket: + await websocket.send_json({ + "type": "node_created", + "node_id": id(child), + "parent_id": id(node), + "action": child.action, + "description": child.natural_language_description, + "timestamp": datetime.utcnow().isoformat() + }) + + async def generate_children(self, node: LATSNode, websocket=None) -> list[dict]: + """ + Generate child nodes for a given node. + + Args: + node: Parent node to generate children for + websocket: Optional WebSocket connection to send updates to + + Returns: + list[dict]: List of child state dictionaries + """ + # Reset browser and get live URL + live_browser_url, session_id = await self._reset_browser(websocket) + path = self.get_path_to_root(node) + + # Execute path + for n in path[1:]: # Skip root node + if websocket: + await websocket.send_json({ + "type": "replaying_action", + "node_id": id(n), + "action": n.action, + "timestamp": datetime.utcnow().isoformat() + }) + + success = await playwright_step_execution( + n, + self.goal, + self.playwright_manager, + is_replay=False, + log_folder=self.config.log_folder + ) + if not success: + n.is_terminal = True + if websocket: + await websocket.send_json({ + "type": "replay_failed", + "node_id": id(n), + "timestamp": datetime.utcnow().isoformat() + }) + return [] + + if not n.feedback: + n.feedback = await generate_feedback( + self.goal, + n.natural_language_description, + self.playwright_manager, + ) + if websocket: + await websocket.send_json({ + "type": "feedback_generated", + "node_id": id(n), + "feedback": n.feedback, + "timestamp": datetime.utcnow().isoformat() + }) + + time.sleep(3) + page = await self.playwright_manager.get_page() + page_info = await extract_page_info(page, self.config.fullpage, self.config.log_folder) + + messages = [{"role": "user", "content": f"Action is: {n.action}"} for n in path[1:]] + + if websocket: + await websocket.send_json({ + "type": "generating_actions", + "node_id": id(node), + "timestamp": datetime.utcnow().isoformat() + }) + + next_actions = await extract_top_actions( + [{"natural_language_description": n.natural_language_description, "action": n.action, "feedback": n.feedback} for n in path[1:]], + self.goal, + self.images, + page_info, + self.action_set, + openai_client, + features=self.config.features, + elements_filter=self.config.elements_filter, + branching_factor=self.config.branching_factor, + log_folder=self.config.log_folder, + fullpage=self.config.fullpage, + action_generation_model=self.config.action_generation_model, + action_grounding_model=self.config.action_grounding_model + ) + + children = [] + for action in next_actions: + if action["action"] == "FINISH": + if action["prob"] > 0.2: + node.is_terminal = True + if websocket: + await websocket.send_json({ + "type": "node_terminal", + "node_id": id(node), + "reason": "finish_action", + "timestamp": datetime.utcnow().isoformat() + }) + return [] + continue + + page = await self.playwright_manager.get_page() + code, function_calls = self.action_set.to_python_code(action["action"]) + + if len(function_calls) == 1: + try: + for function_name, function_args in function_calls: + extracted_number = parse_function_args(function_args) + element = await locate_element(page, extracted_number) + action["element"] = element + except Exception as e: + action["element"] = None + if websocket: + await websocket.send_json({ + "type": "element_location_failed", + "action": action["action"], + "error": str(e), + "timestamp": datetime.utcnow().isoformat() + }) + children.append(action) + + if not children: + node.is_terminal = True + if websocket: + await websocket.send_json({ + "type": "node_terminal", + "node_id": id(node), + "reason": "no_valid_actions", + "timestamp": datetime.utcnow().isoformat() + }) + + return children + + def get_path_to_root(self, node: LATSNode) -> List[LATSNode]: + path = [] + current = node + while current: + path.append(current) + current = current.parent + return list(reversed(path)) + + async def bfs(self) -> List[Dict[str, Any]]: + """ + Performs breadth-first search starting from the root node. + Skips nodes that are marked as terminal. + + Returns: + List[Dict[str, Any]]: List of actions in the best path found + """ + queue = deque([self.root_node]) + queue_set = {self.root_node} # Track nodes in queue + best_score = float('-inf') + best_path = None + visited = set() # Track visited nodes to avoid cycles + current_level = 0 # Track current level for BFS + + try: + while queue: + # Process all nodes at current level + level_size = len(queue) + current_level += 1 + level_nodes = [] # Store nodes at current level for later processing + + # First, expand all nodes at current level + for _ in range(level_size): + current_node = queue.popleft() + queue_set.remove(current_node) # Remove from queue tracking + + # Skip if we've already visited this node + if current_node in visited: + continue + + visited.add(current_node) + + # Skip terminal nodes + if current_node.is_terminal: + logger.info(f"Node {id(current_node)} is terminal") + continue + + # Expand current node if it hasn't been expanded yet and hasn't reached max_depth + if not current_node.children and current_node.depth < self.config.max_depth: + try: + await self.expand(current_node) + except Exception as e: + error_msg = f"Error expanding node {id(current_node)}: {str(e)}" + logger.error(error_msg) + current_node.is_terminal = True + continue + + # Store node for later processing + level_nodes.append(current_node) + + # Add non-terminal children to queue for next level if they haven't reached max_depth + for child in current_node.children: + if not child.is_terminal and child not in visited and child not in queue_set and child.depth < self.config.max_depth: + queue.append(child) + queue_set.add(child) # Add to queue tracking + + # Now process all nodes at current level + for current_node in level_nodes: + print("print the trajectory") + print_trajectory(current_node) + print("print the entire tree") + print_entire_tree(self.root_node) + + # Get the path from root to this node + path = self.get_path_to_root(current_node) + + # Create trajectory for scoring + trajectory = [] + for node in path[1:]: # Skip root node + trajectory.append({ + "natural_language_description": node.natural_language_description, + "action": node.action, + "feedback": node.feedback + }) + + try: + # Score the trajectory + prompt = create_llm_prompt(trajectory, self.goal) + result = score_trajectory_with_openai(prompt, openai_client, model=self.config.evaluation_model) + score = result["overall_score"] + except Exception as e: + error_msg = f"Error scoring node {id(current_node)}: {str(e)}" + logger.error(error_msg) + score = float('-inf') + + # Update best path if this score is better + if score > best_score: + best_score = score + best_path = path + + logger.info(f"Node {id(current_node)} score: {score}") + + # If we've found a satisfactory solution, return it + if score >= 0.75: + logger.info(f"Found satisfactory solution with score {score}") + return [{"action": node.action} for node in path[1:]] + + # If we've exhausted all nodes and haven't found a perfect solution, + # return the best path we found + if best_path: + logger.info(f"Returning best path found with score {best_score}") + return [{"action": node.action} for node in best_path[1:]] + + # If no path was found at all + logger.warning("No valid path found") + return [] + + except Exception as e: + error_msg = f"Error in BFS search: {str(e)}" + logger.error(error_msg) + if best_path: + logger.info(f"Returning best path found before error with score {best_score}") + return [{"action": node.action} for node in best_path[1:]] + return [] + + async def dfs(self) -> List[Dict[str, Any]]: + """ + Performs depth-first search starting from the root node. + Skips nodes that are marked as terminal. + + Returns: + List[Dict[str, Any]]: List of actions in the best path found + """ + stack = [self.root_node] + stack_set = {self.root_node} # Track nodes in stack + best_score = float('-inf') + best_path = None + visited = set() # Track visited nodes to avoid cycles + current_path = [] # Track current path for DFS + + try: + while stack: + current_node = stack[-1] # Peek at the top node without removing it + + # Skip if we've already visited this node + if current_node in visited: + stack.pop() + stack_set.remove(current_node) + if current_path: + current_path.pop() # Remove from current path + continue + + visited.add(current_node) + current_path.append(current_node) # Add to current path + + # Skip terminal nodes + if current_node.is_terminal: + logger.info(f"Node {id(current_node)} is terminal") + stack.pop() + stack_set.remove(current_node) + current_path.pop() # Remove from current path + continue + + # Expand current node if it hasn't been expanded yet and hasn't reached max_depth + if not current_node.children and current_node.depth < self.config.max_depth: + try: + await self.expand(current_node) + except Exception as e: + error_msg = f"Error expanding node {id(current_node)}: {str(e)}" + logger.error(error_msg) + current_node.is_terminal = True + stack.pop() + stack_set.remove(current_node) + current_path.pop() # Remove from current path + continue + + print("print the trajectory") + print_trajectory(current_node) + print("print the entire tree") + print_entire_tree(self.root_node) + + # Get the path from root to this node + path = self.get_path_to_root(current_node) + + # Create trajectory for scoring + trajectory = [] + for node in path[1:]: # Skip root node + trajectory.append({ + "natural_language_description": node.natural_language_description, + "action": node.action, + "feedback": node.feedback + }) + + try: + # Score the trajectory + prompt = create_llm_prompt(trajectory, self.goal) + result = score_trajectory_with_openai(prompt, openai_client, model=self.config.evaluation_model) + score = result["overall_score"] + except Exception as e: + error_msg = f"Error scoring node {id(current_node)}: {str(e)}" + logger.error(error_msg) + score = float('-inf') + + # Update best path if this score is better + if score > best_score: + best_score = score + best_path = path + + logger.info(f"Node {id(current_node)} score: {score}") + + # If we've found a satisfactory solution, return it + if score >= 0.75: + logger.info(f"Found satisfactory solution with score {score}") + return [{"action": node.action} for node in path[1:]] + + # Add non-terminal children to stack in reverse order if they haven't reached max_depth + has_unvisited_children = False + for child in reversed(current_node.children): + if not child.is_terminal and child not in visited and child not in stack_set and child.depth < self.config.max_depth: + stack.append(child) + stack_set.add(child) # Add to stack tracking + has_unvisited_children = True + break # Only add one child at a time for DFS + + # If no unvisited children, remove current node from stack + if not has_unvisited_children: + stack.pop() + stack_set.remove(current_node) + current_path.pop() # Remove from current path + + # If we've exhausted all nodes and haven't found a perfect solution, + # return the best path we found + if best_path: + logger.info(f"Returning best path found with score {best_score}") + return [{"action": node.action} for node in best_path[1:]] + + # If no path was found at all + logger.warning("No valid path found") + return [] + + except Exception as e: + error_msg = f"Error in DFS search: {str(e)}" + logger.error(error_msg) + if best_path: + logger.info(f"Returning best path found before error with score {best_score}") + return [{"action": node.action} for node in best_path[1:]] + return [] + + async def bfs_with_websocket(self, websocket=None) -> List[Dict[str, Any]]: + """ + Performs breadth-first search starting from the root node with WebSocket updates. + Skips nodes that are marked as terminal. + + Args: + websocket: Optional WebSocket connection to send updates to + + Returns: + List[Dict[str, Any]]: List of actions in the best path found + """ + queue = deque([self.root_node]) + queue_set = {self.root_node} # Track nodes in queue + best_score = float('-inf') + best_path = None + visited = set() # Track visited nodes to avoid cycles + current_level = 0 # Track current level for BFS + + try: + # Get the live browser URL during initial setup + live_browser_url, session_id = await self._reset_browser(websocket) + + # Send initial status if websocket is provided + if websocket: + await websocket.send_json({ + "type": "search_status", + "status": "started", + "message": "BFS search started", + "timestamp": datetime.utcnow().isoformat(), + "live_browser_url": live_browser_url, + "session_id": session_id + }) + + while queue: + # Process all nodes at current level + level_size = len(queue) + current_level += 1 + level_nodes = [] # Store nodes at current level for later processing + + if websocket: + await websocket.send_json({ + "type": "level_start", + "level": current_level, + "nodes_in_level": level_size, + "timestamp": datetime.utcnow().isoformat() + }) + + # First, expand all nodes at current level + for _ in range(level_size): + current_node = queue.popleft() + queue_set.remove(current_node) # Remove from queue tracking + + # Skip if we've already visited this node + if current_node in visited: + if websocket: + await websocket.send_json({ + "type": "node_skipped", + "node_id": id(current_node), + "reason": "already_visited", + "timestamp": datetime.utcnow().isoformat() + }) + continue + + visited.add(current_node) + + # Skip terminal nodes + if current_node.is_terminal: + if websocket: + await websocket.send_json({ + "type": "node_terminal", + "node_id": id(current_node), + "reason": "terminal_node", + "timestamp": datetime.utcnow().isoformat() + }) + continue + + # Expand current node if it hasn't been expanded yet and hasn't reached max_depth + if not current_node.children and current_node.depth < self.config.max_depth: + if websocket: + await websocket.send_json({ + "type": "node_expanding", + "node_id": id(current_node), + "timestamp": datetime.utcnow().isoformat() + }) + + try: + await self.expand(current_node, websocket) + except Exception as e: + error_msg = f"Error expanding node {id(current_node)}: {str(e)}" + logger.error(error_msg) + current_node.is_terminal = True + if websocket: + await websocket.send_json({ + "type": "node_error", + "node_id": id(current_node), + "error": error_msg, + "timestamp": datetime.utcnow().isoformat() + }) + continue + + # Send tree update after expansion + if websocket: + tree_data = self._get_tree_data() + await websocket.send_json({ + "type": "tree_update", + "tree": tree_data, + "timestamp": datetime.utcnow().isoformat() + }) + + # Store node for later processing + level_nodes.append(current_node) + + # Add non-terminal children to queue for next level if they haven't reached max_depth + for child in current_node.children: + if not child.is_terminal and child not in visited and child not in queue_set and child.depth < self.config.max_depth: + queue.append(child) + queue_set.add(child) # Add to queue tracking + + # Send queue update if websocket is provided + if websocket: + await websocket.send_json({ + "type": "node_queued", + "node_id": id(child), + "parent_id": id(current_node), + "timestamp": datetime.utcnow().isoformat() + }) + + # Now process all nodes at current level + for current_node in level_nodes: + # Send node processing update if websocket is provided + if websocket: + await websocket.send_json({ + "type": "node_processing", + "node_id": id(current_node), + "depth": current_node.depth, + "timestamp": datetime.utcnow().isoformat() + }) + + print("print the trajectory") + print_trajectory(current_node) + print("print the entire tree") + print_entire_tree(self.root_node) + + # Get the path from root to this node + path = self.get_path_to_root(current_node) + + # Create trajectory for scoring + trajectory = [] + for node in path[1:]: # Skip root node + trajectory.append({ + "natural_language_description": node.natural_language_description, + "action": node.action, + "feedback": node.feedback + }) + + try: + # Score the trajectory + prompt = create_llm_prompt(trajectory, self.goal) + result = score_trajectory_with_openai(prompt, openai_client, model=self.config.evaluation_model) + score = result["overall_score"] + except Exception as e: + error_msg = f"Error scoring node {id(current_node)}: {str(e)}" + logger.error(error_msg) + score = float('-inf') + if websocket: + await websocket.send_json({ + "type": "node_error", + "node_id": id(current_node), + "error": error_msg, + "timestamp": datetime.utcnow().isoformat() + }) + + # Send score update if websocket is provided + if websocket: + await websocket.send_json({ + "type": "node_scored", + "node_id": id(current_node), + "score": score, + "timestamp": datetime.utcnow().isoformat() + }) + + # Update best path if this score is better + if score > best_score: + best_score = score + best_path = path + + # Send best path update if websocket is provided + if websocket: + await websocket.send_json({ + "type": "best_path_update", + "score": best_score, + "path": [{"id": id(node), "action": node.action} for node in best_path[1:]], + "timestamp": datetime.utcnow().isoformat() + }) + + logger.info(f"Node {id(current_node)} score: {score}") + + # If we've found a satisfactory solution, return it + if score >= 0.75: + logger.info(f"Found satisfactory solution with score {score}") + + # Send completion update if websocket is provided + if websocket: + await websocket.send_json({ + "type": "search_complete", + "status": "success", + "score": score, + "path": [{"id": id(node), "action": node.action} for node in path[1:]], + "timestamp": datetime.utcnow().isoformat() + }) + + return [{"action": node.action} for node in path[1:]] + + if websocket: + await websocket.send_json({ + "type": "level_complete", + "level": current_level, + "timestamp": datetime.utcnow().isoformat() + }) + + # If we've exhausted all nodes and haven't found a perfect solution, + # return the best path we found + if best_path: + logger.info(f"Returning best path found with score {best_score}") + + # Send completion update if websocket is provided + if websocket: + await websocket.send_json({ + "type": "search_complete", + "status": "partial_success", + "score": best_score, + "path": [{"id": id(node), "action": node.action} for node in best_path[1:]], + "timestamp": datetime.utcnow().isoformat() + }) + + return [{"action": node.action} for node in best_path[1:]] + + # If no path was found at all + logger.warning("No valid path found") + + # Send failure update if websocket is provided + if websocket: + await websocket.send_json({ + "type": "search_complete", + "status": "failure", + "message": "No valid path found", + "timestamp": datetime.utcnow().isoformat() + }) + + return [] + + except Exception as e: + error_msg = f"Error in BFS search: {str(e)}" + logger.error(error_msg) + if websocket: + await websocket.send_json({ + "type": "search_error", + "error": error_msg, + "timestamp": datetime.utcnow().isoformat() + }) + if best_path: + logger.info(f"Returning best path found before error with score {best_score}") + return [{"action": node.action} for node in best_path[1:]] + return [] + + async def dfs_with_websocket(self, websocket=None) -> List[Dict[str, Any]]: + """ + Performs depth-first search starting from the root node with WebSocket updates. + Skips nodes that are marked as terminal. + + Args: + websocket: Optional WebSocket connection to send updates to + + Returns: + List[Dict[str, Any]]: List of actions in the best path found + """ + stack = [self.root_node] + stack_set = {self.root_node} # Track nodes in stack + best_score = float('-inf') + best_path = None + visited = set() # Track visited nodes to avoid cycles + current_path = [] # Track current path for DFS + + try: + # Get the live browser URL during initial setup + live_browser_url, session_id = await self._reset_browser(websocket) + + # Send initial status if websocket is provided + if websocket: + await websocket.send_json({ + "type": "search_status", + "status": "started", + "message": "DFS search started", + "timestamp": datetime.utcnow().isoformat(), + "live_browser_url": live_browser_url, + "session_id": session_id + }) + + while stack: + current_node = stack[-1] # Peek at the top node without removing it + + # Skip if we've already visited this node + if current_node in visited: + stack.pop() + stack_set.remove(current_node) + if current_path: + current_path.pop() # Remove from current path + if websocket: + await websocket.send_json({ + "type": "node_backtrack", + "node_id": id(current_node), + "reason": "already_visited", + "timestamp": datetime.utcnow().isoformat() + }) + continue + + visited.add(current_node) + current_path.append(current_node) # Add to current path + + # Skip terminal nodes + if current_node.is_terminal: + logger.info(f"Node {id(current_node)} is terminal") + stack.pop() + stack_set.remove(current_node) + current_path.pop() # Remove from current path + if websocket: + await websocket.send_json({ + "type": "node_backtrack", + "node_id": id(current_node), + "reason": "terminal_node", + "timestamp": datetime.utcnow().isoformat() + }) + continue + + # Expand current node if it hasn't been expanded yet and hasn't reached max_depth + if not current_node.children and current_node.depth < self.config.max_depth: + if websocket: + await websocket.send_json({ + "type": "node_expanding", + "node_id": id(current_node), + "timestamp": datetime.utcnow().isoformat() + }) + + try: + await self.expand(current_node, websocket) + except Exception as e: + error_msg = f"Error expanding node {id(current_node)}: {str(e)}" + logger.error(error_msg) + current_node.is_terminal = True + stack.pop() + stack_set.remove(current_node) + current_path.pop() # Remove from current path + if websocket: + await websocket.send_json({ + "type": "node_backtrack", + "node_id": id(current_node), + "reason": "expansion_error", + "error": error_msg, + "timestamp": datetime.utcnow().isoformat() + }) + continue + + # Send tree update after expansion + if websocket: + tree_data = self._get_tree_data() + await websocket.send_json({ + "type": "tree_update", + "tree": tree_data, + "timestamp": datetime.utcnow().isoformat() + }) + + # Get the path from root to this node + path = self.get_path_to_root(current_node) + + # Create trajectory for scoring + trajectory = [] + for node in path[1:]: # Skip root node + trajectory.append({ + "natural_language_description": node.natural_language_description, + "action": node.action, + "feedback": node.feedback + }) + + try: + # Score the trajectory + prompt = create_llm_prompt(trajectory, self.goal) + result = score_trajectory_with_openai(prompt, openai_client, model=self.config.evaluation_model) + score = result["overall_score"] + except Exception as e: + error_msg = f"Error scoring node {id(current_node)}: {str(e)}" + logger.error(error_msg) + score = float('-inf') + if websocket: + await websocket.send_json({ + "type": "node_error", + "node_id": id(current_node), + "error": error_msg, + "timestamp": datetime.utcnow().isoformat() + }) + + # Send score update if websocket is provided + if websocket: + await websocket.send_json({ + "type": "node_scored", + "node_id": id(current_node), + "score": score, + "timestamp": datetime.utcnow().isoformat() + }) + + # Update best path if this score is better + if score > best_score: + best_score = score + best_path = path + + # Send best path update if websocket is provided + if websocket: + await websocket.send_json({ + "type": "best_path_update", + "score": best_score, + "path": [{"id": id(node), "action": node.action} for node in best_path[1:]], + "timestamp": datetime.utcnow().isoformat() + }) + + logger.info(f"Node {id(current_node)} score: {score}") + + # If we've found a satisfactory solution, return it + if score >= 0.75: + logger.info(f"Found satisfactory solution with score {score}") + + # Send completion update if websocket is provided + if websocket: + await websocket.send_json({ + "type": "search_complete", + "status": "success", + "score": score, + "path": [{"id": id(node), "action": node.action} for node in path[1:]], + "timestamp": datetime.utcnow().isoformat() + }) + + return [{"action": node.action} for node in path[1:]] + + # Add non-terminal children to stack in reverse order + has_unvisited_children = False + for child in reversed(current_node.children): + if not child.is_terminal and child not in visited and child not in stack_set: + stack.append(child) + stack_set.add(child) # Add to stack tracking + has_unvisited_children = True + + # Send stack update if websocket is provided + if websocket: + await websocket.send_json({ + "type": "node_stacked", + "node_id": id(child), + "parent_id": id(current_node), + "timestamp": datetime.utcnow().isoformat() + }) + break # Only add one child at a time for DFS + + # If no unvisited children, remove current node from stack + if not has_unvisited_children: + stack.pop() + stack_set.remove(current_node) + current_path.pop() # Remove from current path + if websocket: + await websocket.send_json({ + "type": "node_backtrack", + "node_id": id(current_node), + "reason": "no_unvisited_children", + "timestamp": datetime.utcnow().isoformat() + }) + + # If we've exhausted all nodes and haven't found a perfect solution, + # return the best path we found + if best_path: + logger.info(f"Returning best path found with score {best_score}") + + # Send completion update if websocket is provided + if websocket: + await websocket.send_json({ + "type": "search_complete", + "status": "partial_success", + "score": best_score, + "path": [{"id": id(node), "action": node.action} for node in best_path[1:]], + "timestamp": datetime.utcnow().isoformat() + }) + + return [{"action": node.action} for node in best_path[1:]] + + # If no path was found at all + logger.warning("No valid path found") + + # Send failure update if websocket is provided + if websocket: + await websocket.send_json({ + "type": "search_complete", + "status": "failure", + "message": "No valid path found", + "timestamp": datetime.utcnow().isoformat() + }) + + return [] + + except Exception as e: + error_msg = f"Error in DFS search: {str(e)}" + logger.error(error_msg) + if websocket: + await websocket.send_json({ + "type": "search_error", + "error": error_msg, + "timestamp": datetime.utcnow().isoformat() + }) + if best_path: + logger.info(f"Returning best path found before error with score {best_score}") + return [{"action": node.action} for node in best_path[1:]] + return [] + + def _get_tree_data(self): + """Get tree data in a format suitable for visualization""" + nodes = collect_all_nodes(self.root_node) + tree_data = [] + + for node in nodes: + node_data = { + "id": id(node), + "parent_id": id(node.parent) if node.parent else None, + "action": node.action if node.action else "ROOT", + "description": node.natural_language_description, + "depth": node.depth, + "is_terminal": node.is_terminal + } + tree_data.append(node_data) + + return tree_data + diff --git a/visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/trajectory_score.py b/visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/trajectory_score.py new file mode 100644 index 0000000..1bbe9af --- /dev/null +++ b/visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/trajectory_score.py @@ -0,0 +1,209 @@ +"""Module for scoring and evaluating action trajectories using LLMs.""" + +import base64 +import json +import datetime +from typing import Any, Optional, List, Dict, TypedDict +from openai import OpenAI + +class TrajectoryMetrics(TypedDict): + """Structured metrics for trajectory evaluation.""" + overall_score: float + efficiency_score: float + accuracy_score: float + robustness_score: float + detailed_explanation: str + improvement_suggestions: List[str] + key_achievements: List[str] + potential_issues: List[str] + metadata: Dict[str, Any] + +SYSTEM_PROMPT = \ +"""You are an expert web task completion evaluator. Your task is to provide a comprehensive evaluation of web task completion +by analyzing the trajectory against the desired goal. Consider multiple aspects of the task execution and provide detailed feedback. + +Analyze the provided trajectory and screenshot of the web page, return a JSON response with: +1. overall_score (float 0-10): Overall task completion score +2. efficiency_score (float 0-10): How well the task was completed (minimal steps, optimal path) +3. accuracy_score (float 0-10): How precisely the actions were executed +4. robustness_score (float 0-10): How well the solution handles edge cases +5. detailed_explanation (string): Comprehensive analysis of the execution +6. improvement_suggestions (list of strings): Specific ways to improve the solution +7. key_achievements (list of strings): Important milestones reached +8. potential_issues (list of strings): Areas that could be problematic + +Example format: +{ + "overall_score": 8.5, + "efficiency_score": 9.0, + "accuracy_score": 8.0, + "robustness_score": 7.5, + "detailed_explanation": "The trajectory effectively achieves the goal with minimal steps...", + "improvement_suggestions": ["Could have used more efficient selectors", "Consider adding error handling"], + "key_achievements": ["Successfully logged in", "Found target element"], + "potential_issues": ["No timeout handling", "Assumes specific page layout"] +} +""" + +USER_PROMPT_TEMPLATE = \ +"""Goal: {goal} + +Trajectory: +{trajectory_str} + +Current Page State: +{page_state} + +Please provide a comprehensive evaluation of the task completion.""" + +def format_trajectory_step(step: Dict[str, Any], index: int) -> str: + """Format a single trajectory step with detailed information.""" + return f"""Step {index}: + Action: {step['action']} + Description: {step['natural_language_description']} + Target: {step.get('target', 'N/A')} + Status: {step.get('status', 'completed')} + Output: {step.get('output', 'N/A')}""" + +def create_llm_prompt( + trajectory: List[Dict[str, Any]], + goal: str, + page_state: Optional[Dict[str, Any]] = None +) -> str: + """ + Creates a prompt for LLM scoring and processes trajectory information. + + Args: + trajectory: List of dictionaries containing action and description + goal: The goal of the trajectory + page_state: Optional dictionary containing current page state information + + Returns: + str: Formatted prompt string + """ + # Format trajectory steps with more detail + trajectory_str = "\n\n".join( + format_trajectory_step(step, i+1) + for i, step in enumerate(trajectory) + ) + + # Format page state if available + page_state_str = "No page state information available" + if page_state: + page_state_str = json.dumps(page_state, indent=2) + + prompt = USER_PROMPT_TEMPLATE.format( + goal=goal, + trajectory_str=trajectory_str, + page_state=page_state_str + ) + return prompt + +def validate_evaluation(evaluation: Dict[str, Any]) -> bool: + """Validate the evaluation output has all required fields and correct types.""" + required_fields = { + 'overall_score': (int, float), + 'efficiency_score': (int, float), + 'accuracy_score': (int, float), + 'robustness_score': (int, float), + 'detailed_explanation': str, + 'improvement_suggestions': list, + 'key_achievements': list, + 'potential_issues': list + } + + for field, expected_type in required_fields.items(): + if field not in evaluation: + return False + if not isinstance(evaluation[field], expected_type): + return False + if isinstance(evaluation[field], (int, float)): + if not 0 <= evaluation[field] <= 10: + return False + + return True + +def normalize_scores(evaluation: Dict[str, Any]) -> Dict[str, Any]: + """Normalize all scores to be between 0 and 1.""" + score_fields = ['overall_score', 'efficiency_score', 'accuracy_score', 'robustness_score'] + for field in score_fields: + if field in evaluation: + evaluation[field] = evaluation[field] / 10.0 + return evaluation + +def score_trajectory_with_openai( + prompt: str, + openai_client: OpenAI, + model: str = "gpt-4o", + screenshot: Optional[bytes] = None +) -> Dict[str, Any]: + """ + Uses OpenAI to score the trajectory based on the provided prompt. + + Args: + prompt: The prompt to send to OpenAI + openai_client: OpenAI client instance + model: OpenAI model to use + screenshot: Screenshot of the current page + + Returns: + dict: Parsed response containing comprehensive evaluation + """ + system_message = SYSTEM_PROMPT + + try: + content = [ + {"type": "text", "text": prompt}, + ] + if screenshot is not None: + base64_image = base64.b64encode(screenshot).decode('utf-8') + content.append({ + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{base64_image}", + "detail": "high" + } + }) + + response = openai_client.chat.completions.create( + model=model, + messages=[ + {"role": "system", "content": system_message}, + {"role": "user", "content": content} + ], + response_format={"type": "json_object"} + ) + + evaluation = json.loads(response.choices[0].message.content) + + # Validate evaluation + if not validate_evaluation(evaluation): + raise ValueError("Invalid evaluation format") + + # Normalize scores + evaluation = normalize_scores(evaluation) + + # Add metadata + evaluation["metadata"] = { + "model_used": model, + "timestamp": datetime.datetime.now().isoformat(), + "has_screenshot": screenshot is not None + } + + return evaluation + + except Exception as e: + return { + "overall_score": 0.0, + "efficiency_score": 0.0, + "accuracy_score": 0.0, + "robustness_score": 0.0, + "detailed_explanation": f"Error occurred during evaluation: {str(e)}", + "improvement_suggestions": ["Check API connection and try again"], + "key_achievements": [], + "potential_issues": ["Evaluation failed"], + "metadata": { + "error": str(e), + "timestamp": datetime.datetime.now().isoformat() + } + } \ No newline at end of file diff --git a/visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/tree_vis.py b/visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/tree_vis.py new file mode 100644 index 0000000..48f667d --- /dev/null +++ b/visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/tree_vis.py @@ -0,0 +1,128 @@ +"""Utilities for visualizing LATS tree structures.""" + +from typing import Optional +from .lats_node import LATSNode + +# ANSI color codes +GREEN = '\033[92m' +RED = '\033[91m' +RESET = '\033[0m' + +def collect_all_nodes(node: LATSNode) -> list[LATSNode]: + """ + Recursively collect all nodes starting from the given node. + + Args: + node: The root node to start collection from + + Returns: + list[LATSNode]: List of all nodes in the tree + """ + nodes = [node] + for child in node.children: + nodes.extend(collect_all_nodes(child)) + return nodes + +def better_print(node: LATSNode, level: int = 0, selected_node: Optional[LATSNode] = None) -> None: + """ + Print tree structure recursively with indentation, showing node statistics. + + Args: + node: The node to print + level: Current indentation level (default=0) + selected_node: The currently selected node to highlight + """ + indent = " " * level + + action = node.action if node.action is not None else 'None' + if isinstance(action, str): + action = action.replace('\n', '') + + visits = f"visits: {node.visits}" + value = f"value: {node.value:.3f}" if hasattr(node, 'value') else "value: N/A" + reward = f"reward: {node.reward:.3f}" if hasattr(node, 'reward') else "reward: N/A" + stats = f"[{visits}, {value}, {reward}]" + + if node == selected_node: + print(f"{indent}├── Level {level}: {GREEN}{action}{RESET} {stats} ← Selected") + else: + print(f"{indent}├── Level {level}: {action} {stats}") + + for child in node.children: + better_print(child, level + 1, selected_node) + +def print_trajectory(terminal_node: LATSNode) -> None: + """ + Print the single path from a terminal node to the root. + + Args: + terminal_node: The leaf node to start the trajectory from + """ + path = [] + current = terminal_node + while current is not None: + path.append(current) + current = current.parent + + for level, node in enumerate(reversed(path)): + indent = " " * level + action = node.action + + visits = f"visits: {node.visits}" + value = f"value: {node.value:.3f}" if hasattr(node, 'value') else "value: N/A" + reward = f"reward: {node.reward:.3f}" if hasattr(node, 'reward') else "reward: N/A" + is_terminal = f"terminal: {node.is_terminal}" + feedback = f"feedback: {node.feedback if node.feedback else 'N/A'}" + stats = f"[{visits}, {value}, {reward}, {is_terminal}, {feedback}]" + + indicator = "" + if node == terminal_node: + indicator = "← Terminal" + elif not hasattr(node, 'parent') or node.parent is None: + indicator = "(Root)" + + print(f"{indent}├── Level {level}: {GREEN}{action}{RESET} {stats} {indicator}") + +def print_entire_tree(root: LATSNode) -> None: + """ + Print the entire tree structure starting from the root node. + + Args: + root: The root node of the tree to print + """ + def _print_subtree(node: LATSNode, level: int, prefix: str, is_last: bool) -> None: + # Prepare the current line's prefix + current_prefix = prefix + ("└── " if is_last else "├── ") + + # Prepare node statistics + action = node.action + visits = f"visits: {node.visits}" + value = f"value: {node.value:.3f}" if hasattr(node, 'value') else "value: N/A" + reward = f"reward: {node.reward:.3f}" if hasattr(node, 'reward') else "reward: N/A" + is_terminal = f"terminal: {node.is_terminal}" + feedback = f"feedback: {node.feedback if node.feedback else 'N/A'}" + stats = f"[{visits}, {value}, {reward}, {is_terminal}, {feedback}]" + + # Add indicator for root or terminal nodes + indicator = "" + if not node.children: + indicator = "← Terminal" + elif level == 0: + indicator = "(Root)" + + # Print the current node + print(f"{current_prefix}Level {level}: {GREEN}{action}{RESET} {stats} {indicator}") + + # Prepare the prefix for children + child_prefix = prefix + (" " if is_last else "│ ") + + # Sort children by some criteria (e.g., visits) if desired + children = sorted(node.children, key=lambda x: x.visits, reverse=True) if node.children else [] + + # Recursively print all children + for i, child in enumerate(children): + is_last_child = (i == len(children) - 1) + _print_subtree(child, level + 1, child_prefix, is_last_child) + + # Start the recursive printing from the root + _print_subtree(root, 0, "", True) \ No newline at end of file diff --git a/visual-tree-search-backend/app/api/lwats/agents_async/SimpleSearchAgents/lats_agent.py b/visual-tree-search-backend/app/api/lwats/agents_async/SimpleSearchAgents/lats_agent.py index cdf6298..8fe68db 100644 --- a/visual-tree-search-backend/app/api/lwats/agents_async/SimpleSearchAgents/lats_agent.py +++ b/visual-tree-search-backend/app/api/lwats/agents_async/SimpleSearchAgents/lats_agent.py @@ -1,35 +1,57 @@ -import logging +"""Language-based Action Tree Search (LATS) Agent implementation.""" + import time -from typing import Any, Dict, List, Optional -from collections import deque -from datetime import datetime +from typing import Any, Optional, Tuple, List import os -import json -import subprocess - from openai import OpenAI +from datetime import datetime +import aiohttp from dotenv import load_dotenv load_dotenv() -import aiohttp +from .lats_node import LATSNode, Observation from ...core_async.config import AgentConfig from ...webagent_utils_async.action.highlevel import HighLevelActionSet from ...webagent_utils_async.utils.playwright_manager import AsyncPlaywrightManager, setup_playwright +from .tree_vis import RED, better_print, print_trajectory, collect_all_nodes, GREEN, RESET, print_entire_tree +from .trajectory_score import create_llm_prompt, score_trajectory_with_openai +from ...replay_async import generate_feedback, playwright_step_execution, locate_element_from_action +from ...webagent_utils_async.browser_env.observation import extract_page_info, observe_features +from ...webagent_utils_async.action.prompt_functions import generate_actions_with_observation +from ...webagent_utils_async.evaluation.feedback import generate_feedback_with_screenshot +from ...webagent_utils_async.utils.utils import urls_to_images + + from ...webagent_utils_async.utils.utils import parse_function_args, locate_element from ...evaluation_async.evaluators import goal_finished_evaluator -from ...replay_async import generate_feedback, playwright_step_execution from ...webagent_utils_async.action.prompt_functions import extract_top_actions from ...webagent_utils_async.browser_env.observation import extract_page_info from .lats_node import LATSNode from .tree_vis import better_print, print_trajectory, collect_all_nodes, GREEN, RESET, print_entire_tree from .trajectory_score import create_llm_prompt, score_trajectory_with_openai -from ...webagent_utils_async.utils.utils import urls_to_images +from ...webagent_utils_async.action.utils import execute_action +from ...webagent_utils_async.action.prompt_functions import extract_top_actions, is_goal_finished +from ...webagent_utils_async.browser_env.observation import extract_page_info +from ...webagent_utils_async.evaluation.feedback import capture_post_action_feedback -logger = logging.getLogger(__name__) openai_client = OpenAI() class LATSAgent: + """ + Language-based Action Tree Search Agent implementation. + + This agent uses MCTS-like tree search to find optimal action sequences for web navigation tasks. + + Attributes: + starting_url (str): The initial URL to start from + model_name (str): Name of the language model to use + goal (str): The goal state to achieve + playwright_manager (PlaywrightManager): Manager for browser automation + num_simulations (int): Number of simulations to run + exploration_weight (float): Exploration vs exploitation trade-off parameter + """ + def __init__( self, starting_url: str, @@ -39,20 +61,27 @@ def __init__( playwright_manager: AsyncPlaywrightManager, config: AgentConfig, ): + """Initialize the LATS Agent.""" + # no action grounding model, just one step to geneate both action natural language description and action at the same time self.starting_url = starting_url self.goal = goal self.image_urls = images self.images = urls_to_images(self.image_urls) + self.messages = messages - self.messages.append({"role": "user", "content": f"The goal is: {self.goal}"}) + if len(images) == 0: + self.messages.append({"role": "user", "content": f"The goal is: {self.goal}"}) + else: + self.messages.append({"role": "user", "content": f"The goal is: {self.goal}"}) self.playwright_manager = playwright_manager self.config = config - self.agent_type = ["bid", "nav", "file", "select_option"] + # set bid, only click, fill, hoover, drag and draw + self.agent_type = ["bid"] self.action_set = HighLevelActionSet( - subsets=self.agent_type, strict=False, multiaction=True, demo_mode="default" + subsets=self.agent_type, strict=False, multiaction=False, demo_mode="default" ) self.root_node = LATSNode( natural_language_description=None, @@ -62,7 +91,686 @@ def __init__( goal=self.goal, parent=None ) + self.goal_finished = False + self.result_node = None self.reset_url = os.environ["ACCOUNT_RESET_URL"] - async def run(self, websocket=None) -> List[Dict[str, Any]]: - pass + async def run(self, websocket=None) -> list[LATSNode]: + """ + Run the LATS search and return the best path found. + + Args: + websocket: Optional WebSocket connection for sending updates + + Returns: + list[LATSNode]: Best path from root to terminal node + """ + if websocket: + await websocket.send_json({ + "type": "search_status", + "status": "started", + "message": "Starting LATS search", + "timestamp": datetime.utcnow().isoformat() + }) + + best_node = await self.lats_search(websocket) + print_trajectory(best_node) + + if websocket: + await websocket.send_json({ + "type": "search_complete", + "status": "success" if best_node.reward == 1 else "partial_success", + "score": best_node.reward, + "path": best_node.get_trajectory(), + "timestamp": datetime.utcnow().isoformat() + }) + + return best_node.get_trajectory() + + async def lats_search(self, websocket=None) -> LATSNode: + """ + Perform the main LATS search algorithm. + + Args: + websocket: Optional WebSocket connection for sending updates + + Returns: + LATSNode: Best terminal node found + """ + print(f"") + print(f"{GREEN}START SEARCH{RESET}") + + terminal_nodes = [] + + for i in range(self.config.iterations): + if websocket: + await websocket.send_json({ + "type": "iteration_start", + "iteration": i + 1, + "timestamp": datetime.utcnow().isoformat() + }) + + print(f"") + print(f"") + print(f"Iteration {i + 1}...") + + # Step 1: Selection with websocket update + if websocket: + await websocket.send_json({ + "type": "step_start", + "step": "selection", + "iteration": i + 1, + "timestamp": datetime.utcnow().isoformat() + }) + + node = self.select_node(self.root_node) + + if node is None: + print("All paths lead to terminal nodes with reward 0. Ending search.") + break + + print(f"{GREEN}Tree:{RESET}") + better_print(node=self.root_node, selected_node=node) + print(f"") + + # Step 2: Expansion with websocket update + if websocket: + await websocket.send_json({ + "type": "step_start", + "step": "expansion", + "iteration": i + 1, + "timestamp": datetime.utcnow().isoformat() + }) + + await self.expand_node(node, websocket) + + while node is not None and node.is_terminal and not self.goal_finished: + print(f"Depth limit node found at iteration {i + 1}, reselecting...") + node = self.select_node(self.root_node) + if node is not None: + await self.expand_node(node, websocket) + + if node is None: + # all the nodes are terminal, stop the search + print(f"{RED}All nodes are terminal, stopping search{RESET}") + break + + if self.goal_finished: + print(f"{RED}Goal finished, stopping search{RESET}") + break + + print(f"{GREEN}Tree:{RESET}") + better_print(self.root_node) + print(f"") + + # Step 3: Evaluation + print(f"") + print(f"{GREEN}Step 3: evaluation{RESET}") + await self.evaluate_node(node) + + print(f"{GREEN}Tree:{RESET}") + better_print(self.root_node) + print(f"") + + # Step 4: Simulation + print(f"{GREEN}Step 4: simulation{RESET}") + # # Find the child with the highest value + ## always = 1 + reward, terminal_node = await self.simulate(max(node.children, key=lambda child: child.value), max_depth=self.config.max_depth, num_simulations=1) + terminal_nodes.append(terminal_node) + + if reward == 1: + return terminal_node + + # Step 5: Backpropagation + print(f"{GREEN}Step 5: backpropagation{RESET}") + self.backpropagate(terminal_node, reward) + print(f"{GREEN}Tree:{RESET}") + better_print(self.root_node) + print(f"") + + # Send tree update after each iteration + if websocket: + tree_data = self._get_tree_data() + await websocket.send_json({ + "type": "tree_update", + "tree": tree_data, + "timestamp": datetime.utcnow().isoformat() + }) + + # Find best node + all_nodes_list = collect_all_nodes(self.root_node) + all_nodes_list.extend(terminal_nodes) + + ## temp change: if reward is the same, choose the deeper node + best_child = max(all_nodes_list, key=lambda x: (x.reward, x.depth)) + + if best_child.reward == 1: + print("Successful trajectory found") + else: + print("Unsuccessful trajectory found") + await self.playwright_manager.close() + + return best_child if best_child is not None else self.root_node + + def select_node(self, node: LATSNode) -> Optional[LATSNode]: + """ + Select a node for expansion using UCT. + + Args: + node: Root node to start selection from + + Returns: + Optional[LATSNode]: Selected node or None if all paths exhausted + """ + if node.is_terminal: + return None + return node.get_best_leaf() + + async def expand_node(self, node: LATSNode, websocket=None) -> None: + """ + Expand a node by generating its children. + + Args: + node: Node to expand + websocket: Optional WebSocket connection for sending updates + """ + if websocket: + await websocket.send_json({ + "type": "node_expanding", + "node_id": id(node), + "timestamp": datetime.utcnow().isoformat() + }) + + children = await self.generate_children(node, websocket) + + for child in children: + node.add_child(child) + if websocket: + await websocket.send_json({ + "type": "node_created", + "node_id": id(child), + "parent_id": id(node), + "action": child.action, + "description": child.natural_language_description, + "timestamp": datetime.utcnow().isoformat() + }) + + if children and children[0].goal_finish_feedback.is_done: + self.set_goal_finished(children[0]) + if websocket: + await websocket.send_json({ + "type": "goal_finished", + "node_id": id(children[0]), + "timestamp": datetime.utcnow().isoformat() + }) + return + + node.check_terminal() + + async def evaluate_node(self, node: LATSNode) -> None: + """ + Evaluate a node using LLM scoring. + + Args: + node: Node to evaluate + + Returns: + float: Evaluation score + """ + scores = [] + print(f"{GREEN}-- total {len(node.children)} children to evaluate:{RESET}") + for i, child in enumerate(node.children): + print(f"{GREEN}--- evaluating child {i+1}...{RESET}") + if child.is_terminal: + score = 0 + else: + trajectory = child.get_trajectory() + prompt = create_llm_prompt(trajectory, self.goal) + result = score_trajectory_with_openai(prompt, openai_client, self.config.evaluation_model, child.observation.image) + score = result["overall_score"] + scores.append(score) + + for child, score in zip(node.children, scores): + child.value = score + child.reward = score + + async def simulate(self, node: LATSNode, max_depth: int = 2, num_simulations=1) -> tuple[float, LATSNode]: + """ + Perform a rollout simulation from a node. + + Args: + node: Starting node for rollout + max_depth: Maximum depth to simulate to + + Returns: + tuple[float, LATSNode]: (Score of the rollout, Terminal node reached) + """ + depth = node.depth + print("print the trajectory") + print_trajectory(node) + print("print the entire tree") + print_entire_tree(self.root_node) + return await self.rollout(node, max_depth=max_depth) + + async def send_completion_request(self, plan, depth, node, trajectory=[]): + print("print the trajectory") + print_trajectory(node) + print("print the entire tree") + print_entire_tree(self.root_node) + + if depth >= self.config.max_depth: + return trajectory, node + + context = await self.playwright_manager.get_context() + page = await self.playwright_manager.get_page() + # Extract page information + time.sleep(3) + page_info = await extract_page_info(page, fullpage=True, log_folder=self.config.log_folder) + updated_actions = await extract_top_actions( + trajectory, self.goal, self.images, page_info, self.action_set, openai_client, + features=["axtree"], elements_filter="som", branching_factor=self.config.branching_factor, + log_folder=self.config.log_folder, fullpage=True, + action_generation_model=self.config.action_generation_model, + action_grounding_model=self.config.action_grounding_model + ) + next_action = updated_actions[0] + retry_count = self.config.retry_count if hasattr(self.config, 'retry_count') else 1 # Default retries if not set + + for attempt in range(retry_count): + try: + # Convert action to Python code + code, function_calls = self.action_set.to_python_code(next_action["action"]) + + # Locate element + if len(function_calls) == 1: + for function_name, function_args in function_calls: + extracted_number = parse_function_args(function_args) + element = await locate_element(page, extracted_number) + next_action["element"] = element + + # Execute action + await execute_action(next_action, self.action_set, page, context, self.goal, page_info['interactive_elements'], + self.config.log_folder) + feedback = await capture_post_action_feedback(page, next_action, self.goal, self.config.log_folder) + trajectory.append({'action': next_action['action'], 'feedback': feedback}) + action_str = next_action["action"] + + print(f"The action is: {action_str} - The action result is: {feedback}") + + # Check if goal is finished + messages = [{"role": "system", "content": "The goal is {}, Is the overall goal finished?".format(self.goal)}] + for item in trajectory: + action = item['action'] + feedback = item['feedback'] + messages.append({"role": "user", "content": 'action is: {}'.format(action)}) + messages.append({"role": "user", "content": 'action feedback is: {}'.format(feedback)}) + + goal_finished = await is_goal_finished(messages, openai_client) + + new_node = LATSNode( + natural_language_description=next_action["natural_language_description"], + action=next_action["action"], + prob=next_action["prob"], + element=next_action["element"], + goal=node.goal, + parent=node + ) + + if goal_finished: + return trajectory, new_node + + return await self.send_completion_request(plan, depth + 1, new_node, trajectory) + + except Exception as e: + print(f"Attempt {attempt + 1} failed with error: {e}") + if attempt + 1 == retry_count: + print("Max retries reached. Skipping this step and retrying the whole request.") + # Retry the entire request from the same state + return await self.send_completion_request(plan, depth, node, trajectory) + + # If all retries and retries of retries fail, return the current trajectory and node + return trajectory, node + + + async def rollout(self, node: LATSNode, max_depth: int = 2)-> tuple[float, LATSNode]: + # Reset browser state + await self._reset_browser() + path = self.get_path_to_root(node) + + print("execute path") + # Execute path + + messages = [] + trajectory = [] + + for n in path[1:]: # Skip root node + success = await playwright_step_execution( + n, + self.goal, + self.playwright_manager, + is_replay=False, + log_folder=self.config.log_folder + ) + if not success: + return 0, n + if not n.feedback: + n.feedback = await generate_feedback( + self.goal, + n.natural_language_description, + self.playwright_manager, + ) + trajectory.append({ + "action": n.action, + "feedback": n.feedback + }) + ## call the prompt agent + print("current depth: ", len(path) - 1) + print("max depth: ", self.config.max_depth) + trajectory, node = await self.send_completion_request(self.goal, len(path) - 1, node=n, trajectory=trajectory) + print("print the trajectory") + print_trajectory(node) + print("print the entire tree") + print_entire_tree(self.root_node) + + page = await self.playwright_manager.get_page() + page_info = await extract_page_info(page, self.config.fullpage, self.config.log_folder) + + messages = [{"role": "user", "content": f"Action is: {n.action}"} for n in path[1:]] + goal_finished, confidence_score = goal_finished_evaluator( + messages, + openai_client, + self.goal, + page_info['screenshot'] + ) + print("evaluating") + + score = confidence_score if goal_finished else 0 + + return score, node + + def backpropagate(self, node: LATSNode, value: float) -> None: + """ + Backpropagate values through the tree. + + Args: + node: Current node to start backpropagation from + value: Value to propagate upwards + """ + while node: + node.visits += 1 + node.value = (node.value * (node.visits - 1) + value) / node.visits + node = node.parent + + async def _reset_browser(self, websocket=None) -> Optional[str]: + """Reset the browser to initial state and return the live browser URL if available.""" + await self.playwright_manager.close() + + ## reset account using api-based account reset + if self.config.account_reset: + if websocket: + await websocket.send_json({ + "type": "account_reset", + "status": "started", + "timestamp": datetime.utcnow().isoformat() + }) + + try: + # Use aiohttp instead of curl + async with aiohttp.ClientSession() as session: + headers = {'Connection': 'close'} # Similar to curl -N + async with session.get(self.reset_url, headers=headers) as response: + if response.status == 200: + data = await response.json() + print(f"Account reset successful: {data}") + if websocket: + await websocket.send_json({ + "type": "account_reset", + "status": "success", + "data": data, + "timestamp": datetime.utcnow().isoformat() + }) + else: + error_msg = f"Account reset failed with status {response.status}" + print(error_msg) + if websocket: + await websocket.send_json({ + "type": "account_reset", + "status": "failed", + "reason": error_msg, + "timestamp": datetime.utcnow().isoformat() + }) + + except Exception as e: + print(f"Error during account reset: {e}") + if websocket: + await websocket.send_json({ + "type": "account_reset", + "status": "failed", + "reason": str(e), + "timestamp": datetime.utcnow().isoformat() + }) + + try: + # Create new playwright manager + self.playwright_manager = await setup_playwright( + storage_state=self.config.storage_state, + headless=self.config.headless, + mode=self.config.browser_mode + ) + page = await self.playwright_manager.get_page() + live_browser_url = None + if self.config.browser_mode == "browserbase": + live_browser_url = await self.playwright_manager.get_live_browser_url() + session_id = await self.playwright_manager.get_session_id() + else: + session_id = None + live_browser_url = None + await page.goto(self.starting_url, wait_until="networkidle") + + # Send success message if websocket is provided + if websocket: + if self.config.storage_state: + await websocket.send_json({ + "type": "browser_setup", + "status": "success", + "message": f"Browser successfully initialized with storage state file: {self.config.storage_state}", + "live_browser_url": live_browser_url, + "session_id": session_id, + "timestamp": datetime.utcnow().isoformat() + }) + else: + await websocket.send_json({ + "type": "browser_setup", + "status": "success", + "message": "Browser successfully initialized", + "live_browser_url": live_browser_url, + "session_id": session_id, + "timestamp": datetime.utcnow().isoformat() + }) + + return live_browser_url, session_id + except Exception as e: + print(f"Error setting up browser: {e}") + if websocket: + await websocket.send_json({ + "type": "browser_setup", + "status": "failed", + "reason": str(e), + "timestamp": datetime.utcnow().isoformat() + }) + return None, None + + async def observe(self) -> None: + page = await self.playwright_manager.get_page() + page_info = await extract_page_info(page, self.config.fullpage, self.config.log_folder) + feature_text = await observe_features( + page_info, + features=self.config.features, + elements_filter=self.config.elements_filter, + log_folder=self.config.log_folder, + fullpage=self.config.fullpage + ) + screenshot = page_info['screenshot_som'] + observation = Observation( + text=feature_text, + image=screenshot, + ) + return observation + + async def execute_action_trajectory(self, action_trajectory: list[dict]) -> None: + if not action_trajectory: + return True + + await self._reset_browser() + print("taking action trajectory") + for action_data in action_trajectory: + print("action_data") + print(action_data) + + # Convert action_data dict to LATSNode + temp_node = LATSNode( + natural_language_description=action_data["natural_language_description"], + action=action_data["action"], + prob=0, + element=action_data["element"], + goal=self.goal, + parent=None # No parent needed for temporary node + ) + + success = await playwright_step_execution( + temp_node, # Pass the node instead of raw action_data + self.goal, + self.playwright_manager, + is_replay=False, + log_folder=self.config.log_folder + ) + + if not success: + return False + return True + + async def generate_candidate_actions(self, node: LATSNode) -> list[dict]: + trajectory = node.get_trajectory() + action_trajectory = node.get_action_trajectory() + await self.execute_action_trajectory(action_trajectory) + observation = await self.observe() + # only root node has no observation at this point + if node.observation is None: + node.observation = observation + actions = await generate_actions_with_observation( + trajectory, + self.goal, + self.images, + openai_client=openai_client, + action_set=self.action_set, + feature_text=observation.text, + screenshot=observation.image, + branching_factor=self.config.branching_factor, + log_folder=self.config.log_folder, + action_generation_model=self.config.action_generation_model, + ) + + page = await self.playwright_manager.get_page() + valid_actions = [] + for action_data in actions: + if action_data["action"] == "FINISH": + continue + + is_bid_action, element_data = await locate_element_from_action(page, action_data["action"]) + if is_bid_action and not element_data: + continue + + action_data['element'] = element_data + valid_actions.append(action_data) + return valid_actions + + async def generate_children(self, node: LATSNode, websocket=None) -> list[LATSNode]: + print(f"{GREEN}-- generating candidate actions...{RESET}") + + children = [] + + action_trajectory = node.get_action_trajectory() + candidate_actions = await self.generate_candidate_actions(node) + print(f"{GREEN}-- generated {len(candidate_actions)} actions{RESET}") + for action_data in candidate_actions: + print(f"{GREEN}--- {action_data['action']}{RESET}") + print(f"{GREEN}--- {action_data['natural_language_description']}{RESET}") + + print(f"") + print(f"{GREEN}-- executing candidate trajectories{RESET}") + for i, action_data in enumerate(candidate_actions): + + candidate_action_trajectory = action_trajectory + [action_data] + print(f"{GREEN}--- trajectory {i+1}:{RESET}") + for action in candidate_action_trajectory: + print(f"{GREEN}---- {action['action']}{RESET}") + print(f"{GREEN}---- {action['natural_language_description']}{RESET}") + executed_successfully = await self.execute_action_trajectory(candidate_action_trajectory) + if not executed_successfully: + # not executed successfully, give up this candidate + print(f"{RED}--- failed to execute action trajectory{RESET}") + continue + + observation = await self.observe() + print(f"{GREEN}--- generate feedback...{RESET}") + feedback = await generate_feedback_with_screenshot( + self.goal, + action_data["natural_language_description"], + observation.image, + model=self.config.feedback_model, + ) + print(f"feedback: is_done: {feedback.is_done}, explanation: {feedback.explanation}") + + child = LATSNode( + natural_language_description=action_data["natural_language_description"], + action=action_data["action"], + prob=action_data["prob"], + element=action_data["element"], + goal=node.goal, + ) + child.observation = observation + child.goal_finish_feedback = feedback + if feedback.is_done: + # the goal is finished, stop the search + return [child] + + children.append(child) + + if node.depth + 1 >= self.config.max_depth: + child.is_terminal = True + + return children + + def set_goal_finished(self, node: LATSNode) -> None: + self.goal_finished = True + self.result_node = node + + def get_path_to_root(self, node: LATSNode) -> List[LATSNode]: + path = [] + current = node + while current: + path.append(current) + current = current.parent + return list(reversed(path)) + + def _get_tree_data(self): + """Get tree data in a format suitable for visualization""" + nodes = collect_all_nodes(self.root_node) + tree_data = [] + + for node in nodes: + node_data = { + "id": id(node), + "parent_id": id(node.parent) if node.parent else None, + "action": node.action if node.action else "ROOT", + "description": node.natural_language_description, + "depth": node.depth, + "is_terminal": node.is_terminal, + "value": node.value, + "visits": node.visits, + "reward": node.reward + } + tree_data.append(node_data) + + return tree_data diff --git a/visual-tree-search-backend/app/api/lwats/core_async/agent_factory.py b/visual-tree-search-backend/app/api/lwats/core_async/agent_factory.py index 2136196..2c3fa12 100644 --- a/visual-tree-search-backend/app/api/lwats/core_async/agent_factory.py +++ b/visual-tree-search-backend/app/api/lwats/core_async/agent_factory.py @@ -8,6 +8,9 @@ from ..agents_async.SimpleSearchAgents.simple_search_agent import SimpleSearchAgent from ..agents_async.SimpleSearchAgents.lats_agent import LATSAgent from ..agents_async.SimpleSearchAgents.mcts_agent import MCTSAgent +from ..agents_async.SearchAgents.simple_search_agent import SimpleSearchAgent as NewSimpleSearchAgent +from ..agents_async.SearchAgents.lats_agent import LATSAgent as NewLATSAgent +from ..agents_async.SearchAgents.mcts_agent import MCTSAgent as NewMCTSAgent from ..webagent_utils_async.utils.utils import setup_logger from ..webagent_utils_async.utils.playwright_manager import setup_playwright @@ -106,4 +109,73 @@ async def setup_search_agent( error_message = f"Unsupported agent type: {agent_type}. Please use 'FunctionCallingAgent', 'HighLevelPlanningAgent', 'ContextAwarePlanningAgent', 'PromptAgent' or 'PromptSearchAgent' ." logger.error(error_message) return {"error": error_message} + return agent, playwright_manager + + +async def new_setup_search_agent( + agent_type, + starting_url, + goal, + images, + agent_config: AgentConfig +): + logger = setup_logger() + + file_path = os.path.join(agent_config.log_folder, 'flow', 'steps.json') + os.makedirs(os.path.dirname(file_path), exist_ok=True) + with open(file_path, 'w') as file: + file.write(goal + '\n') + file.write(starting_url + '\n') + + playwright_manager = await setup_playwright( + headless=agent_config.headless, + mode=agent_config.browser_mode, + storage_state=agent_config.storage_state + ) + # storage_state='state.json', headless=False, mode="chromium" + + page = await playwright_manager.get_page() + await page.goto(starting_url) + # Maximize the window on macOS + # await page.set_viewport_size({"width": 1440, "height": 900}) + + messages = [{ + "role": "system", + "content": SEARCH_AGENT_SYSTEM_PROMPT, + }] + + if agent_type == "SimpleSearchAgent": + print("SimpleSearchAgent") + agent = NewSimpleSearchAgent( + starting_url=starting_url, + messages=messages, + goal=goal, + images = images, + playwright_manager=playwright_manager, + config=agent_config, + ) + elif agent_type == "LATSAgent": + print("LATSAgent") + agent = NewLATSAgent( + starting_url=starting_url, + messages=messages, + goal=goal, + images = images, + playwright_manager=playwright_manager, + config=agent_config, + ) + elif agent_type == "MCTSAgent": + print("MCTSAgent") + agent = NewMCTSAgent( + starting_url=starting_url, + messages=messages, + goal=goal, + images = images, + playwright_manager=playwright_manager, + config=agent_config, + ) + else: + error_message = f"Unsupported agent type: {agent_type}. Please use 'FunctionCallingAgent', 'HighLevelPlanningAgent', 'ContextAwarePlanningAgent', 'PromptAgent' or 'PromptSearchAgent' ." + logger.error(error_message) + return {"error": error_message} return agent, playwright_manager \ No newline at end of file diff --git a/visual-tree-search-backend/app/api/lwats/replay.py b/visual-tree-search-backend/app/api/lwats/replay.py deleted file mode 100644 index 36ed441..0000000 --- a/visual-tree-search-backend/app/api/lwats/replay.py +++ /dev/null @@ -1,507 +0,0 @@ -import sys -import os -import base64 - -from .webagent_utils_sync.action.action_parser import parse_action - -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from .webagent_utils_sync.browser_env.observation import ( - _pre_extract, - _post_extract, - extract_dom_snapshot, - extract_dom_extra_properties, - extract_merged_axtree, - extract_focused_element_bid, -) -from .webagent_utils_sync.browser_env.extract_elements import extract_interactive_elements -from openai import OpenAI -import os -import re -import json -from .webagent_utils_sync.utils.utils import encode_image, locate_element -from dotenv import load_dotenv -_ = load_dotenv() -from elevenlabs.client import ElevenLabs -from elevenlabs import play - -# Initialize the Eleven Labs client -elevenlabs_client = ElevenLabs(api_key=os.getenv("ELEVEN_API_KEY")) -openai_client = OpenAI(api_key=os.getenv("OPENAI_API_KEY")) -import argparse -from .webagent_utils_sync.action.highlevel import HighLevelActionSet -from .webagent_utils_sync.utils.playwright_manager import PlaywrightManager -from .webagent_utils_sync.action.base import execute_python_code_safely -import time -import logging -from .webagent_utils_sync.browser_env.obs import flatten_axtree_to_str, flatten_dom_to_str - -logger = logging.getLogger(__name__) -from .webagent_utils_sync.utils.utils import setup_logger -import os -import time -import re -import logging - - -def read_steps_json(file_path): - goal = None - starting_url = None - steps = [] - - # Ensure the file exists - if not os.path.exists(file_path): - print(f"File not found: {file_path}") - return goal, starting_url, steps - - with open(file_path, 'r') as file: - for i, line in enumerate(file): - if i == 0: - # First line is the starting_url (plain string) - goal = line.strip() - if i == 1: - # First line is the starting_url (plain string) - starting_url = line.strip() - else: - try: - # Subsequent lines are JSON objects - step = json.loads(line.strip()) - steps.append(step) - except json.JSONDecodeError as e: - print(f"Error decoding JSON on line {i + 1}: {line}") - print(f"Error message: {str(e)}") - - return goal, starting_url, steps - - -# def find_matching_element(interactive_elements, target): -# for element in interactive_elements: -# if (element.get('text', '').lower() == target.get('text', '').lower() and -# element.get('tag') == target.get('tag') and -# target.get('id') == element.get('id')): -# return element -# return None - - -# def find_match(interactive_elements, key, value): -# for element in interactive_elements: -# if element.get(key, '') == value: -# return element -# return None - - -# def replace_number(text, new_number): -# # Find the first number in the string and replace it -# return re.sub(r'\d+', str(new_number), text) - - -# Node(depth=1, value=0.00, visits=0, action=fill('274', 'running shoes'), feedback=) search running shoes, click on the first result False log -# page: -# node: Node(depth=1, value=0.00, visits=0, action=fill('274', 'running shoes'), feedback=) -# node.element: {'text': '', 'type': 'text', 'tag': 'input', 'id': 'search', 'name': 'q', 'value': '', 'placeholder': 'Search entire store here...', 'class': 'input-text', 'role': 'combobox', 'unique_selector': '#search', 'selector_uniqueness_validated': True} -# selector: #search -# element: selector='#search'> -# {'text': '', 'type': 'text', 'tag': 'input', 'id': 'search', 'name': 'q', 'value': '', 'placeholder': 'Search entire store here...', 'class': 'input-text', 'role': 'combobox', 'unique_selector': '#search', 'selector_uniqueness_validated': True} -# element count before: 1 -# element count after: 1 - - -def playwright_step_execution(node, goal, playwright_manager, is_replay, log_folder): - logger = logging.getLogger(__name__) - context = playwright_manager.get_context() - page = playwright_manager.get_page() - url = page.url - - print(node, goal, playwright_manager, is_replay, log_folder) - print(f"page: {page}") - print(f"node: {node}") - print(f"node.element: {node.element}") - step_data = node.element - selector = step_data.get("unique_selector", "body") - element = page.locator(selector) - print(f"selector: {selector}") - print(f"element: {element}") - print(step_data) - - logger.info(f"==> Attempting action with selector: {selector}") - logger.info(f"Node data: {node}") - logger.info(f"Current page URL: {url}") - - action = node.action - - print(f"element count before: {element.count()}") - - # - # 1) Wait until attached & visible - # - try: - logger.info(f"Waiting for element to be attached: {selector}") - element.wait_for(state="attached", timeout=10_000) - logger.info(f"Waiting for element to be visible: {selector}") - element.wait_for(state="visible", timeout=10_000) - except Exception as e: - logger.error(f"Error while waiting for element to become visible: {e}") - debug_screenshot_path = os.path.join(log_folder, 'screenshots', 'error_wait.png') - page.screenshot(path=debug_screenshot_path) - logger.error(f"Saved debug screenshot: {debug_screenshot_path}") - return False - - print(f"element count after: {element.count()}") - - # - # 2) Execute action - # - try: - action_name, args, kwargs = parse_action(action) - execute_action(page, element, action_name, args, kwargs) - return True - except Exception as e: - logger.error(f"Error occurred during execution: {e}") - debug_screenshot_path = os.path.join(log_folder, 'screenshots', 'error_action.png') - page.screenshot(path=debug_screenshot_path) - logger.error(f"Saved debug screenshot: {debug_screenshot_path}") - return False - -BID_ACTIONS = [ - "fill", - "check", - "uncheck", - "select_option", - "click", - "dblclick", - "hover", - "press", - "focus", - "clear", - # "drag_and_drop", - "upload_file", -] - - - -def locate_element_from_action(page, action): - action_name, args, kwargs = parse_action(action) - is_bid_action = action_name in BID_ACTIONS - if is_bid_action: - element_data = locate_element(page, args[0]) - else: - element_data = None - return is_bid_action, element_data - -def step_execution(action_data, playwright_manager, log_folder): - logger = logging.getLogger(__name__) - page = playwright_manager.get_page() - url = page.url - - action = action_data["action"] - action_name, args, kwargs = parse_action(action) - - if action_data['element'] is not None: - selector = action_data['element'].get("unique_selector", "body") - element = page.locator(selector) - else: - selector = "no element" - element = None - - logger.info(f"==> Exectute Action: {action}") - logger.info(f"Current page URL: {url}") - logger.info(f"Selector: {selector}") - - # - # 1) Wait until attached & visible - # - if element is not None: - try: - logger.info(f"Waiting for element to be attached: {selector}") - element.wait_for(state="attached", timeout=10_000) - logger.info(f"Waiting for element to be visible: {selector}") - element.wait_for(state="visible", timeout=10_000) - except Exception as e: - logger.error(f"Error while waiting for element to become visible: {e}") - debug_screenshot_path = os.path.join(log_folder, 'screenshots', 'error_wait.png') - page.screenshot(path=debug_screenshot_path) - logger.error(f"Saved debug screenshot: {debug_screenshot_path}") - return False - - # - # 2) Execute action - # - try: - execute_action(page, element, action_name, args, kwargs) - return True - except Exception as e: - logger.error(f"Error occurred during execution: {e}") - debug_screenshot_path = os.path.join(log_folder, 'screenshots', 'error_action.png') - page.screenshot(path=debug_screenshot_path) - logger.error(f"Saved debug screenshot: {debug_screenshot_path}") - return False - - -def execute_action(page, element, action_name, args, kwargs): - # TODO: add timeout - match action_name: - case "noop": - page.wait_for_timeout(args[0]) - case "fill": - element.fill(args[1]) - case "check": - element.check() - case "uncheck": - element.uncheck() - case "select_option": - element.select_option(args[1]) - case "click": - element.click(**kwargs) - case "dblclick": - element.dblclick(**kwargs) - case "hover": - element.hover() - case "press": - element.press(args[1]) - case "focus": - element.focus() - case "clear": - element.clear() - # case "drag_and_drop": - # drag_and_drop(args[0], args[1]) - case "scroll": - page.mouse.wheel(args[0], args[1]) - case "mouse_move": - page.mouse.move(args[0], args[1]) - case "mouse_up": - page.mouse.up(args[0], args[1], **kwargs) - case "mouse_down": - page.mouse.down(args[0], args[1], **kwargs) - case "mouse_click": - page.mouse.click(args[0], args[1], **kwargs) - case "mouse_dblclick": - page.mouse.dblclick(args[0], args[1], **kwargs) - case "mouse_drag_and_drop": - page.mouse.move(args[0], args[1]) - page.mouse.down() - page.mouse.move(args[2], args[3]) - page.mouse.up() - case "keyboard_press": - page.keyboard.press(args[0]) - case "keyboard_up": - page.keyboard.up(args[0]) - case "keyboard_down": - page.keyboard.down(args[0]) - case "keyboard_type": - page.keyboard.type(args[0]) - case "keyboard_insert_text": - page.keyboard.insert_text(args[0]) - case "goto": - page.goto(args[0]) - case "go_back": - page.go_back() - case "go_forward": - page.go_forward() - case "new_tab": - page = page.context.new_page() - # trigger the callback that sets this page as active in browsergym - page.locate("html").dispatch_event("pageshow") - case "tab_close": - context = page.context - page.close() - if context.pages: - page = context.pages[-1] - else: - page = context.new_page() - page.locate("html").dispatch_event("pageshow") - case "tab_focus": - page = page.context.pages[args[0]] - page.locate("html").dispatch_event("pageshow") - case "upload_file": - with page.expect_file_chooser() as fc_info: - element.click() - - file_chooser = fc_info.value - file_chooser.set_files(args[1]) - case "mouse_upload_file": - with page.expect_file_chooser() as fc_info: - page.mouse.click(args[0], args[1]) - - file_chooser = fc_info.value - file_chooser.set_files(args[2]) - case _: - raise ValueError(f"Unknown action: {action_name}") - - - -def generate_feedback(goal, action_description, playwright_manager, model="gpt-4o"): - page = playwright_manager.get_page() - - page.wait_for_timeout(3000) - - screenshot_bytes = page.screenshot() - base64_image = base64.b64encode(screenshot_bytes).decode('utf-8') - - system_prompt = f""" - You are a helpful assitant. Given a goal, a screenshot of the current web page and a description of the action taken, provide a natural language description of current page state related to the goal. - """ - - # Build and send prompt to OpenAI - prompt = f""" - # Goal: - {goal} - - # Action description: - {action_description} - - Please provide a natural language description of current page state. It must be related to the goal. - """ - response = openai_client.chat.completions.create( - model=model, - messages=[ - {"role": "system", "content": system_prompt}, - { - "role": "user", - "content": [ - {"type": "text", "text": prompt}, - { - "type": "image_url", - "image_url": { - "url": f"data:image/jpeg;base64,{base64_image}", - "detail": "high" - } - } - ], - }, - ], - ) - - feedback = response.choices[0].message.content - logger.info(f"Feedback from OpenAI: {feedback}") - - return feedback - - -# def browsergym_step_execution(step, goal, playwright_manager, is_replay, log_folder): -# time.sleep(5) -# context = playwright_manager.get_context() -# page = playwright_manager.get_page() -# action_set = HighLevelActionSet( -# subsets=["bid", "nav"], -# strict=False, -# multiaction=True, -# demo_mode="off" -# ) - -# _pre_extract(page) -# dom = extract_dom_snapshot(page) -# axtree = extract_merged_axtree(page) -# focused_element_bid = extract_focused_element_bid(page) -# extra_properties = extract_dom_extra_properties(dom) -# interactive_elements = extract_interactive_elements(page) -# _post_extract(page) -# url = page.url -# try: -# import pdb; pdb.set_trace() -# if step['element'] != None: -# # debug finding matching element -# element = find_matching_element(interactive_elements, step['element']) -# if element: -# action = replace_number(step["action"], element['bid']) -# else: -# action = step["action"] -# else: -# action = step["action"] -# except: -# print(step) -# import pdb; pdb.set_trace() -# code, function_calls = action_set.to_python_code(action) -# logger.info("Executing action script") -# # Execute code in the main thread -# code_file_path = execute_python_code_safely( -# code, -# page, -# context, -# log_folder, -# send_message_to_user=None, -# report_infeasible_instructions=None -# ) - -# # if is_replay: -# task_description = goal -# page = playwright_manager.get_page() -# # screenshot_path_post = os.path.join(log_folder, 'screenshots', 'screenshot_post.png') -# time.sleep(3) -# # page.screenshot(path=screenshot_path_post) -# # base64_image = encode_image(screenshot_path_post) - -# screenshot_bytes = page.screenshot() -# # Encode the bytes to base64 -# base64_image = base64.b64encode(screenshot_bytes).decode('utf-8') -# prompt = f""" -# After we take action {action}, a screenshot was captured. - -# # Screenshot description: -# The image provided is a screenshot of the application state after the action was performed. - -# # The original goal: -# {task_description} - -# Based on the screenshot and the updated Accessibility Tree, is the goal finished now? Provide an answer and explanation, referring to visual elements from the screenshot if relevant. -# """ - -# # Query OpenAI model -# response = openai_client.chat.completions.create( -# model="gpt-4o", -# messages=[ -# {"role": "user", -# "content": [ -# {"type": "text", "text": prompt}, -# {"type": "image_url", -# "image_url": { -# "url": f"data:image/jpeg;base64,{base64_image}", -# "detail": "high" -# } -# } -# ] -# }, -# ], -# ) - -# feedback = response.choices[0].message.content -# return feedback - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument('--log_folder', type=str, default='log', help='Path to the log folder') - args = parser.parse_args() - - log_folder = args.log_folder - logger = setup_logger(log_folder, log_file="replay_log.txt") - # Example usage - playwright_manager = PlaywrightManager(storage_state=None, video_dir=os.path.join(args.log_folder, 'videos')) - browser = playwright_manager.get_browser() - context = playwright_manager.get_context() - page = playwright_manager.get_page() - playwright_manager.playwright.selectors.set_test_id_attribute('data-unique-test-id') - - file_path = os.path.join(log_folder, 'flow', 'steps.json') - goal, starting_url, steps = read_steps_json(file_path) - page.goto(starting_url) - page.set_viewport_size({"width": 1440, "height": 900}) - messages = [{"role": "system", - "content": "You are a smart web search agent to perform search and click task, upload files for customers"}] - for i, step in enumerate(steps, 1): - print(f"Step {i}:") - print(json.dumps(step)) - task_description = step["task_description"] - action, feedback = take_action(step, playwright_manager, True, log_folder) - content = "The task_description is: {}, the action is: {} and the feedback is: {}".format(task_description, - action, feedback) - messages.append({"role": "assistant", "content": content}) - messages.append({"role": "user", "content": "summarize the status of the task, be concise"}) - response = openai_client.chat.completions.create(model="gpt-4o", messages=messages) - summary = response.choices[0].message.content - playwright_manager.close() - print(summary) - audio = elevenlabs_client.generate( - text=summary, - voice="Rachel", - model="eleven_multilingual_v2" - ) - play(audio) diff --git a/visual-tree-search-backend/app/api/lwats/webagent_utils_async/browser_env/observation.py b/visual-tree-search-backend/app/api/lwats/webagent_utils_async/browser_env/observation.py index 0d0e519..9855590 100644 --- a/visual-tree-search-backend/app/api/lwats/webagent_utils_async/browser_env/observation.py +++ b/visual-tree-search-backend/app/api/lwats/webagent_utils_async/browser_env/observation.py @@ -17,6 +17,8 @@ import base64 import asyncio from datetime import datetime +from .obs import flatten_axtree_to_str, flatten_dom_to_str +from .extract_elements import flatten_interactive_elements_to_str MARK_FRAMES_MAX_TRIES = 3 @@ -480,4 +482,48 @@ async def extract_focused_element_bid(page: playwright.async_api.Page): else: frame = None - return focused_bid \ No newline at end of file + return focused_bid + + +ACCESSIBILITY_FEATURE_TEMPLATE = \ +"""# Current Accessibility Tree: +{axtree_str} +""" + +INTERACTIVE_ELEMENTS_TEMPLATE = \ +"""# Interactive elements: +{interactive_elements_str} +""" + +DOM_FEATURE_TEMPLATE = \ +"""# Current DOM: +{dom_str} +""" + +async def observe_features(page_info, features, elements_filter, log_folder, fullpage=True): + filter_som_only = False if fullpage else elements_filter == "som" + filter_visible_only = elements_filter == "visibility" + + feature_texts = [] + if "axtree" in features: + axtree_str = flatten_axtree_to_str(page_info.get('axtree', ''), extra_properties=page_info['extra_properties'], filter_som_only=filter_som_only, filter_visible_only=filter_visible_only) + feature_texts.append(ACCESSIBILITY_FEATURE_TEMPLATE.format(axtree_str=axtree_str)) + + if "interactive_elements" in features: + interactive_elements_str = flatten_interactive_elements_to_str(page_info.get('interactive_elements', '')) + feature_texts.append(INTERACTIVE_ELEMENTS_TEMPLATE.format(interactive_elements_str=interactive_elements_str)) + + if "dom" in features: + dom_str = flatten_dom_to_str(page_info.get('dom', ''), extra_properties=page_info['extra_properties'], filter_som_only=filter_som_only, filter_visible_only=filter_visible_only) + feature_texts.append(DOM_FEATURE_TEMPLATE.format(dom_str=dom_str)) + + feature_text = "\n".join(feature_texts) + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + filename = f"feature_{timestamp}.txt" + file_path = os.path.join(log_folder, 'prompt', filename) + os.makedirs(os.path.dirname(file_path), exist_ok=True) + with open(file_path, 'w', encoding='utf8') as file: + file.write(feature_text) + + return feature_text \ No newline at end of file diff --git a/visual-tree-search-backend/app/api/routes/new_tree_search.py b/visual-tree-search-backend/app/api/routes/new_tree_search.py new file mode 100644 index 0000000..fda4bef --- /dev/null +++ b/visual-tree-search-backend/app/api/routes/new_tree_search.py @@ -0,0 +1,234 @@ +import asyncio +from typing import List, Optional +from fastapi import APIRouter, BackgroundTasks, HTTPException +import json +import os +import threading +import multiprocessing +from datetime import datetime +import logging + +import argparse +from dotenv import load_dotenv +import json +import logging + +from ..lwats.core_async.config import AgentConfig, add_agent_config_arguments, filter_valid_config_args +load_dotenv() +from ..lwats.core_async.agent_factory import new_setup_search_agent + +def run_tree_search(args): + # Log the arguments to help debug + logging.info(f"Running tree search with args: {args.__dict__}") + + # Ensure starting_url is set correctly + if not hasattr(args, 'starting_url') or not args.starting_url: + logging.error("starting_url is not set or is empty") + return {"error": "starting_url is required"} + + logging.info(f"Using starting URL: {args.starting_url}") + + agent_config = AgentConfig(**filter_valid_config_args(args.__dict__)) + agent, playwright_manager = new_setup_search_agent( + agent_type=args.agent_type, + starting_url=args.starting_url, + goal=args.goal, + images=args.images, + agent_config=agent_config + ) + print(agent_config) + + # Run the search + results = agent.run() + + # Close the playwright_manager when done + playwright_manager.close() + + return results +from ..lwats.core_async.config import AgentConfig, filter_valid_config_args + +router = APIRouter() + +# Store results of tree search runs +search_results = {} +# Store process objects +search_processes = {} + +def run_search_in_process(search_id: str, args_dict): + """Run the tree search in a separate process""" + try: + # Create an args object similar to what argparse would create + class Args: + pass + + args = Args() + for key, value in args_dict.items(): + setattr(args, key, value) + + # Update status to running + search_results[search_id]["status"] = "running" + + # Debug: Print current working directory and storage_state path + logging.info(f"Current working directory: {os.getcwd()}") + logging.info(f"Storage state path: {args.storage_state}") + logging.info(f"Storage state exists: {os.path.exists(args.storage_state)}") + logging.info(f"Starting URL: {args.starting_url}") # Log the starting URL + + # Run the search + results = run_tree_search(args) + + # Update results + search_results[search_id]["results"] = results + search_results[search_id]["status"] = "completed" + search_results[search_id]["completed_at"] = datetime.utcnow().isoformat() + + except Exception as e: + logging.error(f"Error in search process: {str(e)}") + search_results[search_id]["status"] = "failed" + search_results[search_id]["error"] = str(e) + +@router.post("/run") +async def start_tree_search( + background_tasks: BackgroundTasks, + agent_type: str = "SimpleSearchAgent", + starting_url: str = "http://128.105.145.205:7770/", + goal: str = "search running shoes, click on the first result", + images: Optional[str] = None, + search_algorithm: str = "bfs", + headless: bool = True, + browser_mode: str = "chromium", + storage_state: str = "shopping.json", + action_generation_model: str = "gpt-4o-mini", + evaluation_model: str = "gpt-4o", + branching_factor: int = 5, + max_depth: int = 3, + iterations: int = 3 +): + """Start a tree search with the given parameters""" + # Create a unique ID for this search + search_id = f"search_{datetime.utcnow().strftime('%Y%m%d_%H%M%S_%f')}" + + # Parse images + image_list = [img.strip() for img in images.split(',')] if images else [] + + # Debug: Print all possible locations for the file + logging.info(f"Current working directory: {os.getcwd()}") + possible_locations = [ + os.path.join(os.getcwd(), storage_state), + os.path.join(os.path.dirname(os.getcwd()), storage_state), + os.path.join(os.path.dirname(os.path.dirname(os.getcwd())), storage_state), + os.path.abspath(storage_state) + ] + + for loc in possible_locations: + logging.info(f"Checking location: {loc}, exists: {os.path.exists(loc)}") + + # Try to find the file in various locations + storage_state_path = None + for loc in possible_locations: + if os.path.exists(loc): + storage_state_path = loc + break + + if storage_state_path: + logging.info(f"Found storage_state at: {storage_state_path}") + else: + logging.warning(f"Could not find storage_state file '{storage_state}' in any expected location") + # Create an empty storage state file as a fallback + storage_state_path = os.path.join(os.getcwd(), "empty_storage.json") + with open(storage_state_path, 'w') as f: + f.write("{}") + logging.info(f"Created empty storage state file at {storage_state_path}") + + # Log the starting URL to verify it's being set correctly + logging.info(f"Setting starting URL to: {starting_url}") + + # Create args dictionary + args_dict = { + "agent_type": agent_type, + "starting_url": starting_url, + "goal": goal, + "images": image_list, + "search_algorithm": search_algorithm, + "headless": headless, + "browser_mode": browser_mode, + "storage_state": storage_state_path, # Use the found path + "action_generation_model": action_generation_model, + "evaluation_model": evaluation_model, + "branching_factor": branching_factor, + "max_depth": max_depth, + "iterations": iterations + } + + # Initialize the results entry + search_results[search_id] = { + "id": search_id, + "status": "pending", + "created_at": datetime.utcnow().isoformat(), + "config": args_dict + } + + # Start the search in a separate process + process = threading.Thread( + target=run_search_in_process, + args=(search_id, args_dict), + daemon=True + ) + search_processes[search_id] = process + process.start() + + return { + "search_id": search_id, + "status": "pending", + "message": "Tree search started in the background" + } + +@router.get("/status/{search_id}") +async def get_search_status(search_id: str): + """Get the status of a tree search""" + if search_id not in search_results: + raise HTTPException(status_code=404, detail="Search ID not found") + + # Check if process is still alive + if search_id in search_processes: + process = search_processes[search_id] + if process.is_alive(): + search_results[search_id]["status"] = "running" + elif search_results[search_id]["status"] == "pending": + # Process ended but status wasn't updated + search_results[search_id]["status"] = "failed" + search_results[search_id]["error"] = "Process terminated unexpectedly" + + return search_results[search_id] + +@router.get("/list") +async def list_searches(): + """List all tree searches""" + return { + "searches": [ + { + "id": search_id, + "status": search_results[search_id]["status"], + "created_at": search_results[search_id]["created_at"], + "completed_at": search_results[search_id].get("completed_at") + } + for search_id in search_results + ] + } + +@router.post("/cancel/{search_id}") +async def cancel_search(search_id: str): + """Cancel a running search""" + if search_id not in search_results: + raise HTTPException(status_code=404, detail="Search ID not found") + + if search_id in search_processes: + process = search_processes[search_id] + if process.is_alive(): + # We can't directly terminate a thread, but we can mark it as cancelled + search_results[search_id]["status"] = "cancelled" + return {"message": f"Search {search_id} has been marked for cancellation"} + else: + return {"message": f"Search {search_id} is not running"} + + return {"message": f"Search {search_id} process not found"} \ No newline at end of file diff --git a/visual-tree-search-backend/app/api/routes/new_tree_search_websocket.py b/visual-tree-search-backend/app/api/routes/new_tree_search_websocket.py new file mode 100644 index 0000000..5277676 --- /dev/null +++ b/visual-tree-search-backend/app/api/routes/new_tree_search_websocket.py @@ -0,0 +1,218 @@ +import asyncio +import json +from datetime import datetime +from typing import Dict, Any, List, Set +import logging +from collections import deque + +# Configure basic logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) + +from fastapi import APIRouter, WebSocket, WebSocketDisconnect + +# Import necessary components for the search agent +from ..lwats.webagent_utils_async.utils.playwright_manager import setup_playwright +from ..lwats.core_async.config import AgentConfig +from ..lwats.core_async.agent_factory import new_setup_search_agent +from ..lwats.agents_async.SimpleSearchAgents.tree_vis import collect_all_nodes +from ..lwats.agents_async.SimpleSearchAgents.trajectory_score import create_llm_prompt, score_trajectory_with_openai + +router = APIRouter() + +# Track active WebSocket connections +active_connections: Dict[str, WebSocket] = {} + +# This is the function that will be called from main.py +async def new_tree_search_websocket_endpoint(websocket: WebSocket): + """WebSocket endpoint for tree search visualization and control""" + await websocket.accept() + connection_id = str(id(websocket)) + active_connections[connection_id] = websocket + + logging.info(f"WebSocket connection established with ID: {connection_id}") + + try: + # Send initial connection confirmation + await websocket.send_json({ + "type": "connection_established", + "connection_id": connection_id, + "timestamp": datetime.utcnow().isoformat() + }) + + # Listen for messages from the client + while True: + data = await websocket.receive_text() + message = json.loads(data) + + # Handle different message types + if message["type"] == "ping": + await websocket.send_json({ + "type": "pong", + "timestamp": datetime.utcnow().isoformat() + }) + + elif message["type"] == "start_search": + # Start the search process + await handle_search_request(websocket, message) + + except WebSocketDisconnect: + logging.info(f"WebSocket disconnected with ID: {connection_id}") + except Exception as e: + logging.error(f"Error in WebSocket connection: {e}") + finally: + # Clean up connection + if connection_id in active_connections: + del active_connections[connection_id] + +async def handle_search_request(websocket: WebSocket, message: Dict[str, Any]): + """Handle a search request from the client""" + try: + # Extract parameters from the message + agent_type = message.get("agent_type", "SimpleSearchAgent") + starting_url = message.get("starting_url", "http://128.105.145.205:7770/") + goal = message.get("goal", "search running shoes, click on the first result") + search_algorithm = message.get("search_algorithm", "bfs") + max_depth = message.get("max_depth", 3) + storage_state = message.get("storage_state", "app/api/shopping.json") + + # Send status update + await websocket.send_json({ + "type": "status_update", + "status": "initializing", + "message": "Initializing search agent", + "timestamp": datetime.utcnow().isoformat() + }) + + # Create agent configuration + config = AgentConfig( + search_algorithm=search_algorithm, + max_depth=max_depth, + storage_state=storage_state, + headless=False + ) + + # Send status update + await websocket.send_json({ + "type": "status_update", + "status": "setting_up", + "message": "Setting up playwright browser", + "timestamp": datetime.utcnow().isoformat() + }) + + # Setup playwright and agent + agent, playwright_manager = await new_setup_search_agent( + agent_type=agent_type, + starting_url=starting_url, + goal=goal, + images=[], # No initial images + agent_config=config + ) + + # Send status update + await websocket.send_json({ + "type": "status_update", + "status": "running", + "message": "Search agent initialized, starting search", + "timestamp": datetime.utcnow().isoformat() + }) + + # Run search with WebSocket updates + if search_algorithm.lower() == "bfs": + # Use the agent's built-in WebSocket-enabled BFS method + await agent.bfs_with_websocket(websocket) + elif search_algorithm.lower() == "dfs": + # Use the agent's built-in WebSocket-enabled DFS method + await agent.dfs_with_websocket(websocket) + elif search_algorithm.lower() == "lats": + await agent.run(websocket) + else: + await websocket.send_json({ + "type": "error", + "message": f"Unsupported algorithm: {search_algorithm}", + "timestamp": datetime.utcnow().isoformat() + }) + + # Clean up + await playwright_manager.close() + + except Exception as e: + logging.error(f"Error handling search request: {e}") + await websocket.send_json({ + "type": "error", + "message": f"Error during search: {str(e)}", + "timestamp": datetime.utcnow().isoformat() + }) + +async def send_tree_update(websocket: WebSocket, root_node): + """Send a tree update to the client""" + try: + # Collect all nodes in the tree + nodes = collect_all_nodes(root_node) + + # Convert nodes to a format suitable for visualization + tree_data = [] + for node in nodes: + node_data = { + "id": id(node), + "parent_id": id(node.parent) if node.parent else None, + "action": node.action if node.action else "ROOT", + "description": node.natural_language_description, + "depth": node.depth, + "is_terminal": node.is_terminal, + "score": getattr(node, "value", 0) / getattr(node, "visits", 1) if hasattr(node, "visits") and node.visits > 0 else 0 + } + tree_data.append(node_data) + + await websocket.send_json({ + "type": "tree_update", + "nodes": tree_data, + "timestamp": datetime.utcnow().isoformat() + }) + except Exception as e: + logging.error(f"Error sending tree update: {e}") + +async def send_trajectory_update(websocket: WebSocket, node, status: str): + """Send a trajectory update to the client""" + try: + # Get path from root to this node + path = [] + current = node + while current: + path.append(current) + current = current.parent + path = list(reversed(path)) + + # Convert path to a format suitable for visualization + trajectory_data = [] + for i, node in enumerate(path): + if i == 0: # Skip root node in display + continue + + node_data = { + "id": id(node), + "action": node.action, + "description": node.natural_language_description, + "feedback": node.feedback if hasattr(node, "feedback") else None, + "depth": node.depth + } + trajectory_data.append(node_data) + + await websocket.send_json({ + "type": f"trajectory_{status}", # trajectory_start or trajectory_complete + "trajectory": trajectory_data, + "timestamp": datetime.utcnow().isoformat() + }) + except Exception as e: + logging.error(f"Error sending trajectory update: {e}") + +# Add a route for testing WebSocket status via HTTP +@router.get("/status") +async def tree_search_websocket_status(): + """Get Tree Search WebSocket connection status""" + return { + "active_connections": len(active_connections), + "status": "running" + } \ No newline at end of file diff --git a/visual-tree-search-backend/app/api/routes/tree_search_websocket.py b/visual-tree-search-backend/app/api/routes/tree_search_websocket.py index bd154d8..144950f 100644 --- a/visual-tree-search-backend/app/api/routes/tree_search_websocket.py +++ b/visual-tree-search-backend/app/api/routes/tree_search_websocket.py @@ -126,6 +126,8 @@ async def handle_search_request(websocket: WebSocket, message: Dict[str, Any]): elif search_algorithm.lower() == "dfs": # Use the agent's built-in WebSocket-enabled DFS method await agent.dfs_with_websocket(websocket) + elif search_algorithm.lower() == "lats": + await agent.run(websocket) else: await websocket.send_json({ "type": "error", diff --git a/visual-tree-search-backend/app/api/shopping.json b/visual-tree-search-backend/app/api/shopping.json index c652f2d..4de8ead 100644 --- a/visual-tree-search-backend/app/api/shopping.json +++ b/visual-tree-search-backend/app/api/shopping.json @@ -24,7 +24,7 @@ "value": "", "domain": "128.105.145.205", "path": "/", - "expires": 1775110735, + "expires": 1775272527, "httpOnly": false, "secure": false, "sameSite": "Strict" @@ -81,20 +81,20 @@ }, { "name": "private_content_version", - "value": "843052e37c1703f047c79850de356543", + "value": "c514ad01bc9816ee30ab1f02510aa34a", "domain": "128.105.145.205", "path": "/", - "expires": 1778134733.861287, + "expires": 1778296522.505323, "httpOnly": false, "secure": false, "sameSite": "Lax" }, { "name": "PHPSESSID", - "value": "e60e354f4484dbe490e863a6e3c5b4b9", + "value": "007c737fca0eb4173ab5362c2d9c8b09", "domain": "128.105.145.205", "path": "/", - "expires": 1775110737.116277, + "expires": 1775272526.468528, "httpOnly": true, "secure": false, "sameSite": "Lax" @@ -104,17 +104,17 @@ "value": "9bf9a599123e6402b85cde67144717a08b817412", "domain": "128.105.145.205", "path": "/", - "expires": 1775110737.116485, + "expires": 1775272526.468754, "httpOnly": true, "secure": false, "sameSite": "Lax" }, { "name": "form_key", - "value": "M0iOyk8VTBOyKC1q", + "value": "Hsr3n5ycGkOPfr1K", "domain": "128.105.145.205", "path": "/", - "expires": 1775110737.116428, + "expires": 1775272526.468692, "httpOnly": false, "secure": false, "sameSite": "Lax" @@ -124,17 +124,17 @@ "value": "true", "domain": "128.105.145.205", "path": "/", - "expires": 1775110735, + "expires": 1775272527, "httpOnly": false, "secure": false, "sameSite": "Lax" }, { "name": "section_data_ids", - "value": "{%22messages%22:1743574735%2C%22customer%22:1743574735%2C%22compare-products%22:1743574735%2C%22last-ordered-items%22:1743574735%2C%22cart%22:1743574735%2C%22directory-data%22:1743574735%2C%22captcha%22:1743574735%2C%22instant-purchase%22:1743574735%2C%22loggedAsCustomer%22:1743574735%2C%22persistent%22:1743574735%2C%22review%22:1743574735%2C%22wishlist%22:1743574735%2C%22recently_viewed_product%22:1743574735%2C%22recently_compared_product%22:1743574735%2C%22product_data_storage%22:1743574735%2C%22paypal-billing-agreement%22:1743574735}", + "value": "{%22messages%22:1743736525%2C%22customer%22:1743736525%2C%22compare-products%22:1743736525%2C%22last-ordered-items%22:1743736525%2C%22cart%22:1743736525%2C%22directory-data%22:1743736525%2C%22captcha%22:1743736525%2C%22instant-purchase%22:1743736525%2C%22loggedAsCustomer%22:1743736525%2C%22persistent%22:1743736525%2C%22review%22:1743736525%2C%22wishlist%22:1743736525%2C%22recently_viewed_product%22:1743736525%2C%22recently_compared_product%22:1743736525%2C%22product_data_storage%22:1743736525%2C%22paypal-billing-agreement%22:1743736525}", "domain": "128.105.145.205", "path": "/", - "expires": 1775110735, + "expires": 1775272524, "httpOnly": false, "secure": false, "sameSite": "Lax" diff --git a/visual-tree-search-backend/app/main.py b/visual-tree-search-backend/app/main.py index 40eaad5..4917d33 100644 --- a/visual-tree-search-backend/app/main.py +++ b/visual-tree-search-backend/app/main.py @@ -54,6 +54,7 @@ async def root(): from app.api.routes.websocket import websocket_endpoint from app.api.routes.tree_websocket import tree_websocket_endpoint from app.api.routes.tree_search_websocket import tree_search_websocket_endpoint +from app.api.routes.new_tree_search_websocket import new_tree_search_websocket_endpoint # Register the WebSocket endpoints @app.websocket("/ws") async def websocket_route(websocket: WebSocket): @@ -67,6 +68,10 @@ async def tree_websocket_route(websocket: WebSocket): async def tree_search_websocket_route(websocket: WebSocket): await tree_search_websocket_endpoint(websocket) +@app.websocket("/new-tree-search-ws") +async def new_tree_search_websocket_route(websocket: WebSocket): + await new_tree_search_websocket_endpoint(websocket) + if __name__ == "__main__": port = int(os.getenv("PORT", 3000)) uvicorn.run("app.main:app", host="0.0.0.0", port=port, reload=True) \ No newline at end of file diff --git a/visual-tree-search-backend/test/test-tree-search-ws-lats.py b/visual-tree-search-backend/test/test-tree-search-ws-lats.py new file mode 100644 index 0000000..ea0291d --- /dev/null +++ b/visual-tree-search-backend/test/test-tree-search-ws-lats.py @@ -0,0 +1,160 @@ +import asyncio +import json +import websockets +import argparse +import logging +from datetime import datetime + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + +# Default values +DEFAULT_WS_URL = "ws://localhost:3000/new-tree-search-ws" +DEFAULT_STARTING_URL = "http://128.105.145.205:7770/" +DEFAULT_GOAL = "search running shoes, click on the first result" + +async def connect_and_test_search( + ws_url: str, + starting_url: str, + goal: str, + search_algorithm: str = "bfs", + max_depth: int = 3 +): + """ + Connect to the WebSocket endpoint and test the tree search functionality. + + Args: + ws_url: WebSocket URL to connect to + starting_url: URL to start the search from + goal: Goal to achieve + search_algorithm: Search algorithm to use (bfs or dfs) + max_depth: Maximum depth for the search tree + """ + logger.info(f"Connecting to WebSocket at {ws_url}") + + async with websockets.connect(ws_url) as websocket: + logger.info("Connected to WebSocket") + + # Wait for connection established message + response = await websocket.recv() + data = json.loads(response) + if data.get("type") == "connection_established": + logger.info(f"Connection established with ID: {data.get('connection_id')}") + + # Send search request + request = { + "type": "start_search", + "agent_type": "LATSAgent", + "starting_url": starting_url, + "goal": goal, + "search_algorithm": search_algorithm, + "max_depth": max_depth + } + + logger.info(f"Sending search request: {request}") + await websocket.send(json.dumps(request)) + + # Process responses + while True: + try: + response = await websocket.recv() + data = json.loads(response) + + # Log the message type and some key information + msg_type = data.get("type", "unknown") + + if msg_type == "status_update": + logger.info(f"Status update: {data.get('status')} - {data.get('message')}") + + elif msg_type == "node_update": + node_id = data.get("node_id") + status = data.get("status") + logger.info(f"Node update: {node_id} - {status}") + + # If node was scored, log the score + if status == "scored": + logger.info(f"Node score: {data.get('score')}") + + elif msg_type == "tree_update": + logger.info(f"Tree update received with {len(data.get('nodes', []))} nodes") + + elif msg_type == "best_path_update": + logger.info(f"Best path update: score={data.get('score')}, path length={len(data.get('path', []))}") + + elif msg_type == "search_complete": + status = data.get("status") + score = data.get("score", "N/A") + path_length = len(data.get("path", [])) + + logger.info(f"Search complete: {status}, score={score}, path length={path_length}") + logger.info("Path actions:") + + for i, node in enumerate(data.get("path", [])): + logger.info(f" {i+1}. {node.get('action')}") + + # Exit the loop when search is complete + break + + elif msg_type == "error": + logger.error(f"Error: {data.get('message')}") + break + + else: + logger.info(f"Received message of type {msg_type}") + logger.info(f"Message: {data}") + + except websockets.exceptions.ConnectionClosed: + logger.warning("WebSocket connection closed") + break + except Exception as e: + logger.error(f"Error processing message: {e}") + break + + logger.info("Test completed") + +def parse_arguments(): + """Parse command line arguments""" + parser = argparse.ArgumentParser(description="Test the tree search WebSocket functionality") + + parser.add_argument("--ws-url", type=str, default=DEFAULT_WS_URL, + help=f"WebSocket URL (default: {DEFAULT_WS_URL})") + + parser.add_argument("--starting-url", type=str, default=DEFAULT_STARTING_URL, + help=f"Starting URL for the search (default: {DEFAULT_STARTING_URL})") + + parser.add_argument("--goal", type=str, default=DEFAULT_GOAL, + help=f"Goal to achieve (default: {DEFAULT_GOAL})") + + parser.add_argument("--algorithm", type=str, choices=["bfs", "dfs", "lats"], default="lats", + help="Search algorithm to use (default: lats)") + + parser.add_argument("--max-depth", type=int, default=3, + help="Maximum depth for the search tree (default: 3)") + + return parser.parse_args() + +async def main(): + """Main entry point""" + args = parse_arguments() + + logger.info("Starting tree search WebSocket test") + logger.info(f"WebSocket URL: {args.ws_url}") + logger.info(f"Starting URL: {args.starting_url}") + logger.info(f"Goal: {args.goal}") + logger.info(f"Algorithm: {args.algorithm}") + logger.info(f"Max depth: {args.max_depth}") + + await connect_and_test_search( + ws_url=args.ws_url, + starting_url=args.starting_url, + goal=args.goal, + search_algorithm=args.algorithm, + max_depth=args.max_depth + ) + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/visual-tree-search-backend/test/test-tree-search-ws-simple.py b/visual-tree-search-backend/test/test-tree-search-ws-simple.py new file mode 100644 index 0000000..a4fbf28 --- /dev/null +++ b/visual-tree-search-backend/test/test-tree-search-ws-simple.py @@ -0,0 +1,159 @@ +import asyncio +import json +import websockets +import argparse +import logging +from datetime import datetime + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + +# Default values +DEFAULT_WS_URL = "ws://localhost:3000/new-tree-search-ws" +DEFAULT_STARTING_URL = "http://128.105.145.205:7770/" +DEFAULT_GOAL = "search running shoes, click on the first result" + +async def connect_and_test_search( + ws_url: str, + starting_url: str, + goal: str, + search_algorithm: str = "bfs", + max_depth: int = 3 +): + """ + Connect to the WebSocket endpoint and test the tree search functionality. + + Args: + ws_url: WebSocket URL to connect to + starting_url: URL to start the search from + goal: Goal to achieve + search_algorithm: Search algorithm to use (bfs or dfs) + max_depth: Maximum depth for the search tree + """ + logger.info(f"Connecting to WebSocket at {ws_url}") + + async with websockets.connect(ws_url) as websocket: + logger.info("Connected to WebSocket") + + # Wait for connection established message + response = await websocket.recv() + data = json.loads(response) + if data.get("type") == "connection_established": + logger.info(f"Connection established with ID: {data.get('connection_id')}") + + # Send search request + request = { + "type": "start_search", + "agent_type": "SimpleSearchAgent", + "starting_url": starting_url, + "goal": goal, + "search_algorithm": search_algorithm, + "max_depth": max_depth + } + + logger.info(f"Sending search request: {request}") + await websocket.send(json.dumps(request)) + + # Process responses + while True: + try: + response = await websocket.recv() + data = json.loads(response) + + # Log the message type and some key information + msg_type = data.get("type", "unknown") + + if msg_type == "status_update": + logger.info(f"Status update: {data.get('status')} - {data.get('message')}") + + elif msg_type == "node_update": + node_id = data.get("node_id") + status = data.get("status") + logger.info(f"Node update: {node_id} - {status}") + + # If node was scored, log the score + if status == "scored": + logger.info(f"Node score: {data.get('score')}") + + elif msg_type == "tree_update": + logger.info(f"Tree update received with {len(data.get('nodes', []))} nodes") + + elif msg_type == "best_path_update": + logger.info(f"Best path update: score={data.get('score')}, path length={len(data.get('path', []))}") + + elif msg_type == "search_complete": + status = data.get("status") + score = data.get("score", "N/A") + path_length = len(data.get("path", [])) + + logger.info(f"Search complete: {status}, score={score}, path length={path_length}") + logger.info("Path actions:") + + for i, node in enumerate(data.get("path", [])): + logger.info(f" {i+1}. {node.get('action')}") + + # Exit the loop when search is complete + break + + elif msg_type == "error": + logger.error(f"Error: {data.get('message')}") + break + + else: + logger.info(f"Received message of type {msg_type}") + + except websockets.exceptions.ConnectionClosed: + logger.warning("WebSocket connection closed") + break + except Exception as e: + logger.error(f"Error processing message: {e}") + break + + logger.info("Test completed") + +def parse_arguments(): + """Parse command line arguments""" + parser = argparse.ArgumentParser(description="Test the tree search WebSocket functionality") + + parser.add_argument("--ws-url", type=str, default=DEFAULT_WS_URL, + help=f"WebSocket URL (default: {DEFAULT_WS_URL})") + + parser.add_argument("--starting-url", type=str, default=DEFAULT_STARTING_URL, + help=f"Starting URL for the search (default: {DEFAULT_STARTING_URL})") + + parser.add_argument("--goal", type=str, default=DEFAULT_GOAL, + help=f"Goal to achieve (default: {DEFAULT_GOAL})") + + parser.add_argument("--algorithm", type=str, choices=["bfs", "dfs"], default="bfs", + help="Search algorithm to use (default: bfs)") + + parser.add_argument("--max-depth", type=int, default=3, + help="Maximum depth for the search tree (default: 3)") + + return parser.parse_args() + +async def main(): + """Main entry point""" + args = parse_arguments() + + logger.info("Starting tree search WebSocket test") + logger.info(f"WebSocket URL: {args.ws_url}") + logger.info(f"Starting URL: {args.starting_url}") + logger.info(f"Goal: {args.goal}") + logger.info(f"Algorithm: {args.algorithm}") + logger.info(f"Max depth: {args.max_depth}") + + await connect_and_test_search( + ws_url=args.ws_url, + starting_url=args.starting_url, + goal=args.goal, + search_algorithm=args.algorithm, + max_depth=args.max_depth + ) + +if __name__ == "__main__": + asyncio.run(main())