diff --git a/visual-tree-search-backend/README.md b/visual-tree-search-backend/README.md index 9ddb92d..3cb4402 100644 --- a/visual-tree-search-backend/README.md +++ b/visual-tree-search-backend/README.md @@ -90,4 +90,12 @@ python run_demo_treesearch_async.py \ ``` uvicorn app.main:app --host 0.0.0.0 --port 3000 python test/test-tree-search-ws-lats.py +``` + +## 7. Add MCTS agent +* test run_demo_treesearch_async.py +* test web socket +``` +uvicorn app.main:app --host 0.0.0.0 --port 3000 +python test/test-tree-search-ws-mcts.py ``` \ No newline at end of file diff --git a/visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/mcts_agent.py b/visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/mcts_agent.py index 9f1734c..f1cb9ea 100644 --- a/visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/mcts_agent.py +++ b/visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/mcts_agent.py @@ -27,6 +27,7 @@ from ...webagent_utils_async.utils.utils import urls_to_images logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) openai_client = OpenAI() class MCTSAgent: @@ -65,4 +66,943 @@ def __init__( self.reset_url = os.environ["ACCOUNT_RESET_URL"] async def run(self, websocket=None) -> List[Dict[str, Any]]: - pass \ No newline at end of file + """ + Run the MCTS algorithm based on configuration. + + Args: + websocket: Optional WebSocket connection to send updates to + + Returns: + List[Dict[str, Any]]: List of actions in the best path found + """ + logger.info("Starting Reflective MCTS algorithm") + if websocket: + return await self.rmcts_with_websocket(websocket) + else: + return await self.rmcts() + + async def rmcts(self) -> List[Dict[str, Any]]: + """ + Performs Monte Carlo Tree Search starting from the root node. + Uses GPT-4 for node selection and reflection-based backpropagation. + + Returns: + List[Dict[str, Any]]: List of actions in the best path found + """ + best_score = float('-inf') + best_path = None + visited = set() # Track visited nodes to avoid cycles + max_iterations = self.config.iterations # Use configured number of iterations + + try: + # Initial browser setup + live_browser_url, session_id = await self._reset_browser() + + for iteration in range(max_iterations): + logger.info(f"\n{'='*50}") + logger.info(f"RMCTS Iteration {iteration + 1}/{max_iterations}") + logger.info(f"{'='*50}\n") + + # Selection: Use GPT-4 to select a promising path + current_node = self.root_node + path = [current_node] + selection_depth = 0 + + while current_node.children and not current_node.is_terminal: + logger.info(f"\nSelection Step {selection_depth + 1}:") + logger.info(f"Current node action: {current_node.action}") + logger.info(f"Number of children: {len(current_node.children)}") + + # Get trajectory for GPT-4 to evaluate + trajectory = [] + for node in path[1:]: # Skip root node + trajectory.append({ + "natural_language_description": node.natural_language_description, + "action": node.action, + "feedback": node.feedback + }) + + # Create prompt for GPT-4 to select next node + prompt = f"""Given the current trajectory and goal, select the most promising child node to explore next. + Consider the overall progress, efficiency, and likelihood of success. 
+ + Goal: {self.goal} + + Current Trajectory: + {json.dumps(trajectory, indent=2)} + + Available Children: + {json.dumps([{ + 'action': child.action, + 'description': child.natural_language_description, + 'visits': child.visits, + 'value': child.value + } for child in current_node.children], indent=2)} + + Return a JSON response with: + {{ + "selected_child_index": int, # Index of the selected child + "explanation": str # Brief explanation of the selection + }}""" + + try: + response = openai_client.chat.completions.create( + model=self.config.evaluation_model, + messages=[ + {"role": "system", "content": "You are an expert at selecting promising paths in a search tree."}, + {"role": "user", "content": prompt} + ], + response_format={"type": "json_object"} + ) + + selection = json.loads(response.choices[0].message.content) + selected_index = selection["selected_child_index"] + + if 0 <= selected_index < len(current_node.children): + current_node = current_node.children[selected_index] + path.append(current_node) + logger.info(f"Selected child {selected_index + 1}: {current_node.action}") + logger.info(f"Selection explanation: {selection['explanation']}") + else: + logger.warning(f"Invalid child index {selected_index}, breaking selection") + break + + except Exception as e: + logger.error(f"Error in node selection: {str(e)}") + break + + selection_depth += 1 + + # Expansion: Expand the selected node if possible + if not current_node.is_terminal and current_node.depth < self.config.max_depth: + logger.info(f"\nExpansion Step:") + logger.info(f"Expanding node: {current_node.action}") + + try: + await self.expand(current_node) + logger.info(f"Successfully expanded node with {len(current_node.children)} children") + except Exception as e: + logger.error(f"Error expanding node: {str(e)}") + current_node.is_terminal = True + # Expansion Step: Expand the selected node if possible + if not current_node.is_terminal and current_node.depth < self.config.max_depth: + logger.info(f"\nExpansion Step:") + logger.info(f"Expanding node: {current_node.action}") + + expansion_success = await self.expand(current_node, None) + if not expansion_success: + # No children were generated; backtrack if possible. + if len(path) > 1: + logger.info("Backtracking due to expansion failure (no children generated).") + path.pop() # Remove the current dead-end node. + current_node = path[-1] # Set current_node to its parent. 
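+                            # The failed child was marked terminal by expand(), so this
+                            # iteration falls through to the simulation step and scores
+                            # the shortened path that now ends at the parent node.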
+ else: + logger.warning("Expansion failed at root; no further backtracking possible.") + break + else: + logger.info(f"Successfully expanded node with {len(current_node.children)} children") + + # Simulation: Evaluate the current path + logger.info(f"\nSimulation Step:") + logger.info(f"Evaluating path of length {len(path) - 1}") + + try: + trajectory = [] + for node in path[1:]: # Skip root node + trajectory.append({ + "natural_language_description": node.natural_language_description, + "action": node.action, + "feedback": node.feedback + }) + + # Score the trajectory + prompt = create_llm_prompt(trajectory, self.goal) + result = score_trajectory_with_openai(prompt, openai_client, model=self.config.evaluation_model) + score = result["overall_score"] + + logger.info(f"Simulation Results:") + logger.info(f"Overall Score: {score:.3f}") + logger.info(f"Efficiency Score: {result['efficiency_score']:.3f}") + logger.info(f"Accuracy Score: {result['accuracy_score']:.3f}") + logger.info(f"Robustness Score: {result['robustness_score']:.3f}") + + # Update best path if this score is better + if score > best_score: + best_score = score + best_path = path + logger.info(f"\nNew best path found!") + logger.info(f"Previous best score: {best_score:.3f}") + logger.info(f"New best score: {score:.3f}") + + # Reflection-based backpropagation + if score < 0.75: # If the path is not satisfactory + logger.info(f"\nReflection Step (Score {score:.3f} < 0.75):") + + # Generate reflection prompt + reflection_prompt = f"""Analyze the current trajectory and suggest improvements. + + Goal: {self.goal} + + Current Trajectory: + {json.dumps(trajectory, indent=2)} + + Score: {score} + + Return a JSON response with: + {{ + "backtrack_to_step": int, # Which step to backtrack to (0-based index) + "reason": str, # Why backtrack to this step + "suggested_improvements": [str] # List of suggested improvements + }}""" + + try: + reflection = openai_client.chat.completions.create( + model=self.config.evaluation_model, + messages=[ + {"role": "system", "content": "You are an expert at analyzing and improving search trajectories."}, + {"role": "user", "content": reflection_prompt} + ], + response_format={"type": "json_object"} + ) + + reflection_result = json.loads(reflection.choices[0].message.content) + backtrack_step = reflection_result["backtrack_to_step"] + + # Backtrack to the suggested step + if 0 <= backtrack_step < len(path): + current_node = path[backtrack_step] + # Remove nodes after the backtrack point + while len(path) > backtrack_step + 1: + path.pop() + logger.info(f"Backtracking to step {backtrack_step}") + logger.info(f"Reason: {reflection_result['reason']}") + logger.info("Suggested improvements:") + for improvement in reflection_result["suggested_improvements"]: + logger.info(f"- {improvement}") + + except Exception as e: + logger.error(f"Error in reflection: {str(e)}") + + # If we've found a satisfactory solution, return it + if score >= 0.75: + logger.info(f"\nFound satisfactory solution with score {score:.3f}") + return [{"action": node.action} for node in path[1:]] + + except Exception as e: + logger.error(f"Error in simulation: {str(e)}") + continue + + # Update node statistics + logger.info(f"\nBackpropagation Step:") + for node in path: + old_value = node.value + node.visits += 1 + node.value = (node.value * (node.visits - 1) + score) / node.visits + logger.info(f"Node {node.action}:") + logger.info(f" Visits: {node.visits}") + logger.info(f" Value: {old_value:.3f} -> {node.value:.3f}") + + # If we've 
exhausted all iterations and haven't found a perfect solution, + # return the best path we found + if best_path and len(best_path) > 1: + logger.info(f"\nSearch complete. Returning best path found with score {best_score:.3f}") + return [{"action": node.action} for node in best_path[1:]] + + # If no valid path was found or path was just the root, return a default action + logger.warning("\nNo valid path found, returning fallback action") + return [{"action": "refresh()", "description": "Fallback action - no valid path found"}] + + except Exception as e: + error_msg = f"Error in RMCTS search: {str(e)}" + logger.error(error_msg) + + if best_path: + logger.info(f"\nReturning best path found before error with score {best_score:.3f}") + return [{"action": node.action} for node in best_path[1:]] + return [] + + async def rmcts_with_websocket(self, websocket) -> List[Dict[str, Any]]: + """ + Performs Monte Carlo Tree Search starting from the root node with WebSocket updates. + Uses GPT-4 for node selection and reflection-based backpropagation. + + Args: + websocket: WebSocket connection to send updates to + + Returns: + List[Dict[str, Any]]: List of actions in the best path found + """ + best_score = float('-inf') + best_path = None + visited = set() # Track visited nodes to avoid cycles + max_iterations = self.config.iterations # Use configured number of iterations + + try: + # Initial browser setup + live_browser_url, session_id = await self._reset_browser(websocket) + + for iteration in range(max_iterations): + logger.info(f"\n{'='*50}") + logger.info(f"RMCTS Iteration {iteration + 1}/{max_iterations}") + logger.info(f"{'='*50}\n") + + # Send iteration update if websocket is provided + await websocket.send_json({ + "type": "rmcts_iteration", + "iteration": iteration + 1, + "max_iterations": max_iterations, + "timestamp": datetime.utcnow().isoformat() + }) + + # Selection: Use GPT-4 to select a promising path + current_node = self.root_node + path = [current_node] + selection_depth = 0 + + while current_node.children and not current_node.is_terminal: + logger.info(f"\nSelection Step {selection_depth + 1}:") + logger.info(f"Current node action: {current_node.action}") + logger.info(f"Number of children: {len(current_node.children)}") + + # Get trajectory for GPT-4 to evaluate + trajectory = [] + for node in path[1:]: # Skip root node + trajectory.append({ + "natural_language_description": node.natural_language_description, + "action": node.action, + "feedback": node.feedback + }) + + # Create prompt for GPT-4 to select next node + prompt = f"""Given the current trajectory and goal, select the most promising child node to explore next. + Consider the overall progress, efficiency, and likelihood of success. 
+ + Goal: {self.goal} + + Current Trajectory: + {json.dumps(trajectory, indent=2)} + + Available Children: + {json.dumps([{ + 'action': child.action, + 'description': child.natural_language_description, + 'visits': child.visits, + 'value': child.value + } for child in current_node.children], indent=2)} + + Return a JSON response with: + {{ + "selected_child_index": int, # Index of the selected child + "explanation": str # Brief explanation of the selection + }}""" + + try: + response = openai_client.chat.completions.create( + model=self.config.evaluation_model, + messages=[ + {"role": "system", "content": "You are an expert at selecting promising paths in a search tree."}, + {"role": "user", "content": prompt} + ], + response_format={"type": "json_object"} + ) + + selection = json.loads(response.choices[0].message.content) + selected_index = selection["selected_child_index"] + + if 0 <= selected_index < len(current_node.children): + current_node = current_node.children[selected_index] + path.append(current_node) + logger.info(f"Selected child {selected_index + 1}: {current_node.action}") + logger.info(f"Selection explanation: {selection['explanation']}") + + # Send selection update if websocket is provided + await websocket.send_json({ + "type": "node_selected", + "node_id": id(current_node), + "explanation": selection["explanation"], + "timestamp": datetime.utcnow().isoformat() + }) + else: + logger.warning(f"Invalid child index {selected_index}, breaking selection") + break + + except Exception as e: + logger.error(f"Error in node selection: {str(e)}") + await websocket.send_json({ + "type": "selection_error", + "error": str(e), + "timestamp": datetime.utcnow().isoformat() + }) + break + + selection_depth += 1 + + # Expansion: Expand the selected node if possible + if not current_node.is_terminal and current_node.depth < self.config.max_depth: + logger.info(f"\nExpansion Step:") + logger.info(f"Expanding node: {current_node.action}") + + await websocket.send_json({ + "type": "node_expanding", + "node_id": id(current_node), + "timestamp": datetime.utcnow().isoformat() + }) + + try: + await self.expand(current_node, websocket) + logger.info(f"Successfully expanded node with {len(current_node.children)} children") + except Exception as e: + logger.error(f"Error expanding node: {str(e)}") + current_node.is_terminal = True + await websocket.send_json({ + "type": "expansion_error", + "node_id": id(current_node), + "error": str(e), + "timestamp": datetime.utcnow().isoformat() + }) + + # Simulation: Evaluate the current path + logger.info(f"\nSimulation Step:") + logger.info(f"Evaluating path of length {len(path) - 1}") + + await websocket.send_json({ + "type": "simulation_start", + "path_length": len(path) - 1, + "timestamp": datetime.utcnow().isoformat() + }) + + try: + trajectory = [] + for node in path[1:]: # Skip root node + trajectory.append({ + "natural_language_description": node.natural_language_description, + "action": node.action, + "feedback": node.feedback + }) + + # Score the trajectory + prompt = create_llm_prompt(trajectory, self.goal) + result = score_trajectory_with_openai(prompt, openai_client, model=self.config.evaluation_model) + score = result["overall_score"] + + logger.info(f"Simulation Results:") + logger.info(f"Overall Score: {score:.3f}") + logger.info(f"Efficiency Score: {result['efficiency_score']:.3f}") + logger.info(f"Accuracy Score: {result['accuracy_score']:.3f}") + logger.info(f"Robustness Score: {result['robustness_score']:.3f}") + + # Send simulation results 
if websocket is provided + await websocket.send_json({ + "type": "simulation_results", + "score": score, + "efficiency_score": result["efficiency_score"], + "accuracy_score": result["accuracy_score"], + "robustness_score": result["robustness_score"], + "timestamp": datetime.utcnow().isoformat() + }) + + # Update best path if this score is better + if score > best_score: + best_score = score + best_path = path + logger.info(f"\nNew best path found!") + logger.info(f"Previous best score: {best_score:.3f}") + logger.info(f"New best score: {score:.3f}") + + # Send best path update if websocket is provided + await websocket.send_json({ + "type": "best_path_update", + "score": best_score, + "path": [{"id": id(node), "action": node.action} for node in best_path[1:]], + "timestamp": datetime.utcnow().isoformat() + }) + + # Reflection-based backpropagation + if score < 0.75: # If the path is not satisfactory + logger.info(f"\nReflection Step (Score {score:.3f} < 0.75):") + + await websocket.send_json({ + "type": "reflection_start", + "score": score, + "timestamp": datetime.utcnow().isoformat() + }) + + # Generate reflection prompt + reflection_prompt = f"""Analyze the current trajectory and suggest improvements. + + Goal: {self.goal} + + Current Trajectory: + {json.dumps(trajectory, indent=2)} + + Score: {score} + + Return a JSON response with: + {{ + "backtrack_to_step": int, # Which step to backtrack to (0-based index) + "reason": str, # Why backtrack to this step + "suggested_improvements": [str] # List of suggested improvements + }}""" + + try: + reflection = openai_client.chat.completions.create( + model=self.config.evaluation_model, + messages=[ + {"role": "system", "content": "You are an expert at analyzing and improving search trajectories."}, + {"role": "user", "content": reflection_prompt} + ], + response_format={"type": "json_object"} + ) + + reflection_result = json.loads(reflection.choices[0].message.content) + backtrack_step = reflection_result["backtrack_to_step"] + + # Backtrack to the suggested step + if 0 <= backtrack_step < len(path): + current_node = path[backtrack_step] + # Remove nodes after the backtrack point + while len(path) > backtrack_step + 1: + path.pop() + logger.info(f"Backtracking to step {backtrack_step}") + logger.info(f"Reason: {reflection_result['reason']}") + logger.info("Suggested improvements:") + for improvement in reflection_result["suggested_improvements"]: + logger.info(f"- {improvement}") + + # Send backtracking update if websocket is provided + await websocket.send_json({ + "type": "backtracking", + "step": backtrack_step, + "reason": reflection_result["reason"], + "suggested_improvements": reflection_result["suggested_improvements"], + "timestamp": datetime.utcnow().isoformat() + }) + + except Exception as e: + logger.error(f"Error in reflection: {str(e)}") + await websocket.send_json({ + "type": "reflection_error", + "error": str(e), + "timestamp": datetime.utcnow().isoformat() + }) + + # If we've found a satisfactory solution, return it + if score >= 0.75: + logger.info(f"\nFound satisfactory solution with score {score:.3f}") + + # Send completion update if websocket is provided + await websocket.send_json({ + "type": "search_complete", + "status": "success", + "score": score, + "path": [{"id": id(node), "action": node.action} for node in path[1:]], + "timestamp": datetime.utcnow().isoformat() + }) + + return [{"action": node.action} for node in path[1:]] + + except Exception as e: + logger.error(f"Error in simulation: {str(e)}") + await 
websocket.send_json({ + "type": "simulation_error", + "error": str(e), + "timestamp": datetime.utcnow().isoformat() + }) + continue + + # Update node statistics + logger.info(f"\nBackpropagation Step:") + for node in path: + old_value = node.value + node.visits += 1 + node.value = (node.value * (node.visits - 1) + score) / node.visits + logger.info(f"Node {node.action}:") + logger.info(f" Visits: {node.visits}") + logger.info(f" Value: {old_value:.3f} -> {node.value:.3f}") + + # Send backpropagation update if websocket is provided + await websocket.send_json({ + "type": "backpropagation_complete", + "updated_nodes": [{"id": id(node), "visits": node.visits, "value": node.value} for node in path], + "timestamp": datetime.utcnow().isoformat() + }) + + # If we've exhausted all iterations and haven't found a perfect solution, + # return the best path we found + if best_path and len(best_path) > 1: + logger.info(f"\nSearch complete. Returning best path found with score {best_score:.3f}") + + # Send completion update if websocket is provided + await websocket.send_json({ + "type": "search_complete", + "status": "partial_success", + "score": best_score, + "path": [{"id": id(node), "action": node.action} for node in best_path[1:]], + "timestamp": datetime.utcnow().isoformat() + }) + + return [{"action": node.action} for node in best_path[1:]] + + # If no path was found at all + logger.warning("\nNo valid path found") + + # Send failure update if websocket is provided + await websocket.send_json({ + "type": "search_complete", + "status": "failure", + "message": "No valid path found", + "timestamp": datetime.utcnow().isoformat() + }) + + # If no valid path was found or path was just the root, return a default action + logger.warning("\nNo valid path found, returning fallback action") + return [{"action": "refresh()", "description": "Fallback action - no valid path found"}] + + except Exception as e: + error_msg = f"Error in RMCTS search: {str(e)}" + logger.error(error_msg) + + # Send error update if websocket is provided + await websocket.send_json({ + "type": "search_error", + "error": error_msg, + "timestamp": datetime.utcnow().isoformat() + }) + + if best_path: + logger.info(f"\nReturning best path found before error with score {best_score:.3f}") + return [{"action": node.action} for node in best_path[1:]] + return [] + + async def _reset_browser(self, websocket=None) -> Optional[tuple]: + """Reset the browser to initial state and return the live browser URL if available.""" + await self.playwright_manager.close() + + ## reset account using api-based account reset + if self.config.account_reset: + if websocket: + await websocket.send_json({ + "type": "account_reset", + "status": "started", + "timestamp": datetime.utcnow().isoformat() + }) + + try: + # Use aiohttp instead of curl + async with aiohttp.ClientSession() as session: + headers = {'Connection': 'close'} # Similar to curl -N + async with session.get(self.reset_url, headers=headers) as response: + if response.status == 200: + data = await response.json() + print(f"Account reset successful: {data}") + if websocket: + await websocket.send_json({ + "type": "account_reset", + "status": "success", + "data": data, + "timestamp": datetime.utcnow().isoformat() + }) + else: + error_msg = f"Account reset failed with status {response.status}" + print(error_msg) + if websocket: + await websocket.send_json({ + "type": "account_reset", + "status": "failed", + "reason": error_msg, + "timestamp": datetime.utcnow().isoformat() + }) + + except Exception as 
e: + print(f"Error during account reset: {e}") + if websocket: + await websocket.send_json({ + "type": "account_reset", + "status": "failed", + "reason": str(e), + "timestamp": datetime.utcnow().isoformat() + }) + + try: + # Create new playwright manager + self.playwright_manager = await setup_playwright( + storage_state=self.config.storage_state, + headless=self.config.headless, + mode=self.config.browser_mode + ) + page = await self.playwright_manager.get_page() + live_browser_url = None + if self.config.browser_mode == "browserbase": + live_browser_url = await self.playwright_manager.get_live_browser_url() + session_id = await self.playwright_manager.get_session_id() + else: + session_id = None + live_browser_url = None + await page.goto(self.starting_url, wait_until="networkidle") + + # Send success message if websocket is provided + if websocket: + if self.config.storage_state: + await websocket.send_json({ + "type": "browser_setup", + "status": "success", + "message": f"Browser successfully initialized with storage state file: {self.config.storage_state}", + "live_browser_url": live_browser_url, + "session_id": session_id, + "timestamp": datetime.utcnow().isoformat() + }) + else: + await websocket.send_json({ + "type": "browser_setup", + "status": "success", + "message": "Browser successfully initialized", + "live_browser_url": live_browser_url, + "session_id": session_id, + "timestamp": datetime.utcnow().isoformat() + }) + + return live_browser_url, session_id + except Exception as e: + print(f"Error setting up browser: {e}") + if websocket: + await websocket.send_json({ + "type": "browser_setup", + "status": "failed", + "reason": str(e), + "timestamp": datetime.utcnow().isoformat() + }) + return None, None + + async def expand(self, node: LATSNode, websocket=None) -> bool: + """ + Expand a node by generating its children. If no children are generated, + mark the node as terminal and return False to trigger backtracking. + + Args: + node: Node to expand. + websocket: Optional WebSocket connection to send updates. + + Returns: + bool: True if expansion succeeded (children generated), False otherwise. + """ + try: + children_state = await self.generate_children(node, websocket) + except Exception as e: + logger.error(f"Exception during generation of children for node {node.action}: {e}") + children_state = [] + + if not children_state: + logger.warning("No children generated. Marking node as terminal and triggering backtracking.") + node.is_terminal = True + return False # Indicate that expansion did not generate children. + + for child_state in children_state: + try: + child = LATSNode( + natural_language_description=child_state.get("natural_language_description", ""), + action=child_state.get("action", ""), + prob=child_state.get("prob", 0.0), + element=child_state.get("element", None), + goal=node.goal, + parent=node + ) + node.children.append(child) + + if websocket: + await websocket.send_json({ + "type": "node_created", + "node_id": id(child), + "parent_id": id(node), + "action": child.action, + "description": child.natural_language_description, + "timestamp": datetime.utcnow().isoformat() + }) + except Exception as e: + logger.error(f"Error creating child node from state {child_state}: {e}") + return True # Expansion succeeded (children were generated). + + async def generate_children(self, node: LATSNode, websocket=None) -> list[dict]: + """ + Generate child nodes for a given node. 
+ + Args: + node: Parent node to generate children for + websocket: Optional WebSocket connection to send updates to + + Returns: + list[dict]: List of child state dictionaries + """ + # Reset browser and get live URL + live_browser_url, session_id = await self._reset_browser(websocket) + path = self.get_path_to_root(node) + logger.info(f"######### Generating children for path with {len(path)} nodes") + # Execute path + for n in path[1:]: # Skip root node + if websocket: + await websocket.send_json({ + "type": "replaying_action", + "node_id": id(n), + "action": n.action, + "timestamp": datetime.utcnow().isoformat() + }) + try: + success = await playwright_step_execution( + n, + self.goal, + self.playwright_manager, + is_replay=False, + log_folder=self.config.log_folder + ) + logger.info(f"#########Success: {success}") + + if not success: + logger.warning(f"Action execution failed: {n.action}") + n.is_terminal = True + if websocket: + await websocket.send_json({ + "type": "replay_failed", + "node_id": id(n), + "timestamp": datetime.utcnow().isoformat() + }) + return [{ + "natural_language_description": "Recover from failed action", + "action": "refresh()", + "prob": 0.1, + "element": None + }] + except Exception as e: + logger.error(f"Error executing action {n.action}: {str(e)}") + # Provide fallback actions instead of bubbling up the exception + return [{ + "natural_language_description": "Recover from action error", + "action": "refresh()", + "prob": 0.1, + "element": None + }] + + + if not n.feedback: + n.feedback = await generate_feedback( + self.goal, + n.natural_language_description, + self.playwright_manager, + ) + if websocket: + await websocket.send_json({ + "type": "feedback_generated", + "node_id": id(n), + "feedback": n.feedback, + "timestamp": datetime.utcnow().isoformat() + }) + + time.sleep(3) + page = await self.playwright_manager.get_page() + page_info = await extract_page_info(page, self.config.fullpage, self.config.log_folder) + + messages = [{"role": "user", "content": f"Action is: {n.action}"} for n in path[1:]] + + if websocket: + await websocket.send_json({ + "type": "generating_actions", + "node_id": id(node), + "timestamp": datetime.utcnow().isoformat() + }) + + next_actions = await extract_top_actions( + [{"natural_language_description": n.natural_language_description, "action": n.action, "feedback": n.feedback} for n in path[1:]], + self.goal, + self.images, + page_info, + self.action_set, + openai_client, + features=self.config.features, + elements_filter=self.config.elements_filter, + branching_factor=self.config.branching_factor, + log_folder=self.config.log_folder, + fullpage=self.config.fullpage, + action_generation_model=self.config.action_generation_model, + action_grounding_model=self.config.action_grounding_model + ) + + children = [] + for action in next_actions: + if action["action"] == "FINISH": + logger.info(f"Found FINISH action with probability: {action['prob']}") + if action["prob"] > 0.99: + node.is_terminal = True + if websocket: + await websocket.send_json({ + "type": "node_terminal", + "node_id": id(node), + "reason": "finish_action", + "timestamp": datetime.utcnow().isoformat() + }) + continue + # return [] + continue + + page = await self.playwright_manager.get_page() + code, function_calls = self.action_set.to_python_code(action["action"]) + + if len(function_calls) == 1: + try: + for function_name, function_args in function_calls: + extracted_number = parse_function_args(function_args) + element = await locate_element(page, 
extracted_number) + action["element"] = element + except Exception as e: + logger.warning(f"Element location failed for action: {action['action']}, error: {str(e)}") + action["element"] = None + children.append(action) + if websocket: + await websocket.send_json({ + "type": "element_location_failed", + "action": action["action"], + "error": str(e), + "timestamp": datetime.utcnow().isoformat() + }) + children.append(action) + + if not children: + # node.is_terminal = True + # if websocket: + # await websocket.send_json({ + # "type": "node_terminal", + # "node_id": id(node), + # "reason": "no_valid_actions", + # "timestamp": datetime.utcnow().isoformat() + # }) + # logger.warning("No children generated") + logger.warning("No viable children, creating fallback exploration actions") + + # # If empty list would terminate search, create a "fallback" child + children.extend([ + { + "natural_language_description": "Navigate back to try a different approach", + "action": "navigate_backward()", + "prob": 0.15, + "element": None + }, + { + "natural_language_description": "Try refreshing the page", + "action": "refresh()", + "prob": 0.1, + "element": None + }, + { + "natural_language_description": "Try clicking on a random element", + "action": "click('random')", + "prob": 0.05, + "element": None + } + ]) + print(f"****** Generated children: {children}") + return children + + def get_path_to_root(self, node: LATSNode) -> List[LATSNode]: + path = [] + current = node + while current: + path.append(current) + current = current.parent + return list(reversed(path)) \ No newline at end of file diff --git a/visual-tree-search-backend/app/api/lwats/agents_async/SimpleSearchAgents/mcts_agent.py b/visual-tree-search-backend/app/api/lwats/agents_async/SimpleSearchAgents/mcts_agent.py index 9f1734c..f1cb9ea 100644 --- a/visual-tree-search-backend/app/api/lwats/agents_async/SimpleSearchAgents/mcts_agent.py +++ b/visual-tree-search-backend/app/api/lwats/agents_async/SimpleSearchAgents/mcts_agent.py @@ -27,6 +27,7 @@ from ...webagent_utils_async.utils.utils import urls_to_images logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) openai_client = OpenAI() class MCTSAgent: @@ -65,4 +66,943 @@ def __init__( self.reset_url = os.environ["ACCOUNT_RESET_URL"] async def run(self, websocket=None) -> List[Dict[str, Any]]: - pass \ No newline at end of file + """ + Run the MCTS algorithm based on configuration. + + Args: + websocket: Optional WebSocket connection to send updates to + + Returns: + List[Dict[str, Any]]: List of actions in the best path found + """ + logger.info("Starting Reflective MCTS algorithm") + if websocket: + return await self.rmcts_with_websocket(websocket) + else: + return await self.rmcts() + + async def rmcts(self) -> List[Dict[str, Any]]: + """ + Performs Monte Carlo Tree Search starting from the root node. + Uses GPT-4 for node selection and reflection-based backpropagation. 
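+        Each iteration performs LLM-guided selection, node expansion, trajectory scoring
+        with the configured evaluation model, and reflection-based backtracking whenever
+        the overall score falls below the 0.75 acceptance threshold.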
+ + Returns: + List[Dict[str, Any]]: List of actions in the best path found + """ + best_score = float('-inf') + best_path = None + visited = set() # Track visited nodes to avoid cycles + max_iterations = self.config.iterations # Use configured number of iterations + + try: + # Initial browser setup + live_browser_url, session_id = await self._reset_browser() + + for iteration in range(max_iterations): + logger.info(f"\n{'='*50}") + logger.info(f"RMCTS Iteration {iteration + 1}/{max_iterations}") + logger.info(f"{'='*50}\n") + + # Selection: Use GPT-4 to select a promising path + current_node = self.root_node + path = [current_node] + selection_depth = 0 + + while current_node.children and not current_node.is_terminal: + logger.info(f"\nSelection Step {selection_depth + 1}:") + logger.info(f"Current node action: {current_node.action}") + logger.info(f"Number of children: {len(current_node.children)}") + + # Get trajectory for GPT-4 to evaluate + trajectory = [] + for node in path[1:]: # Skip root node + trajectory.append({ + "natural_language_description": node.natural_language_description, + "action": node.action, + "feedback": node.feedback + }) + + # Create prompt for GPT-4 to select next node + prompt = f"""Given the current trajectory and goal, select the most promising child node to explore next. + Consider the overall progress, efficiency, and likelihood of success. + + Goal: {self.goal} + + Current Trajectory: + {json.dumps(trajectory, indent=2)} + + Available Children: + {json.dumps([{ + 'action': child.action, + 'description': child.natural_language_description, + 'visits': child.visits, + 'value': child.value + } for child in current_node.children], indent=2)} + + Return a JSON response with: + {{ + "selected_child_index": int, # Index of the selected child + "explanation": str # Brief explanation of the selection + }}""" + + try: + response = openai_client.chat.completions.create( + model=self.config.evaluation_model, + messages=[ + {"role": "system", "content": "You are an expert at selecting promising paths in a search tree."}, + {"role": "user", "content": prompt} + ], + response_format={"type": "json_object"} + ) + + selection = json.loads(response.choices[0].message.content) + selected_index = selection["selected_child_index"] + + if 0 <= selected_index < len(current_node.children): + current_node = current_node.children[selected_index] + path.append(current_node) + logger.info(f"Selected child {selected_index + 1}: {current_node.action}") + logger.info(f"Selection explanation: {selection['explanation']}") + else: + logger.warning(f"Invalid child index {selected_index}, breaking selection") + break + + except Exception as e: + logger.error(f"Error in node selection: {str(e)}") + break + + selection_depth += 1 + + # Expansion: Expand the selected node if possible + if not current_node.is_terminal and current_node.depth < self.config.max_depth: + logger.info(f"\nExpansion Step:") + logger.info(f"Expanding node: {current_node.action}") + + try: + await self.expand(current_node) + logger.info(f"Successfully expanded node with {len(current_node.children)} children") + except Exception as e: + logger.error(f"Error expanding node: {str(e)}") + current_node.is_terminal = True + # Expansion Step: Expand the selected node if possible + if not current_node.is_terminal and current_node.depth < self.config.max_depth: + logger.info(f"\nExpansion Step:") + logger.info(f"Expanding node: {current_node.action}") + + expansion_success = await self.expand(current_node, None) + if not 
expansion_success: + # No children were generated; backtrack if possible. + if len(path) > 1: + logger.info("Backtracking due to expansion failure (no children generated).") + path.pop() # Remove the current dead-end node. + current_node = path[-1] # Set current_node to its parent. + else: + logger.warning("Expansion failed at root; no further backtracking possible.") + break + else: + logger.info(f"Successfully expanded node with {len(current_node.children)} children") + + # Simulation: Evaluate the current path + logger.info(f"\nSimulation Step:") + logger.info(f"Evaluating path of length {len(path) - 1}") + + try: + trajectory = [] + for node in path[1:]: # Skip root node + trajectory.append({ + "natural_language_description": node.natural_language_description, + "action": node.action, + "feedback": node.feedback + }) + + # Score the trajectory + prompt = create_llm_prompt(trajectory, self.goal) + result = score_trajectory_with_openai(prompt, openai_client, model=self.config.evaluation_model) + score = result["overall_score"] + + logger.info(f"Simulation Results:") + logger.info(f"Overall Score: {score:.3f}") + logger.info(f"Efficiency Score: {result['efficiency_score']:.3f}") + logger.info(f"Accuracy Score: {result['accuracy_score']:.3f}") + logger.info(f"Robustness Score: {result['robustness_score']:.3f}") + + # Update best path if this score is better + if score > best_score: + best_score = score + best_path = path + logger.info(f"\nNew best path found!") + logger.info(f"Previous best score: {best_score:.3f}") + logger.info(f"New best score: {score:.3f}") + + # Reflection-based backpropagation + if score < 0.75: # If the path is not satisfactory + logger.info(f"\nReflection Step (Score {score:.3f} < 0.75):") + + # Generate reflection prompt + reflection_prompt = f"""Analyze the current trajectory and suggest improvements. 
+ + Goal: {self.goal} + + Current Trajectory: + {json.dumps(trajectory, indent=2)} + + Score: {score} + + Return a JSON response with: + {{ + "backtrack_to_step": int, # Which step to backtrack to (0-based index) + "reason": str, # Why backtrack to this step + "suggested_improvements": [str] # List of suggested improvements + }}""" + + try: + reflection = openai_client.chat.completions.create( + model=self.config.evaluation_model, + messages=[ + {"role": "system", "content": "You are an expert at analyzing and improving search trajectories."}, + {"role": "user", "content": reflection_prompt} + ], + response_format={"type": "json_object"} + ) + + reflection_result = json.loads(reflection.choices[0].message.content) + backtrack_step = reflection_result["backtrack_to_step"] + + # Backtrack to the suggested step + if 0 <= backtrack_step < len(path): + current_node = path[backtrack_step] + # Remove nodes after the backtrack point + while len(path) > backtrack_step + 1: + path.pop() + logger.info(f"Backtracking to step {backtrack_step}") + logger.info(f"Reason: {reflection_result['reason']}") + logger.info("Suggested improvements:") + for improvement in reflection_result["suggested_improvements"]: + logger.info(f"- {improvement}") + + except Exception as e: + logger.error(f"Error in reflection: {str(e)}") + + # If we've found a satisfactory solution, return it + if score >= 0.75: + logger.info(f"\nFound satisfactory solution with score {score:.3f}") + return [{"action": node.action} for node in path[1:]] + + except Exception as e: + logger.error(f"Error in simulation: {str(e)}") + continue + + # Update node statistics + logger.info(f"\nBackpropagation Step:") + for node in path: + old_value = node.value + node.visits += 1 + node.value = (node.value * (node.visits - 1) + score) / node.visits + logger.info(f"Node {node.action}:") + logger.info(f" Visits: {node.visits}") + logger.info(f" Value: {old_value:.3f} -> {node.value:.3f}") + + # If we've exhausted all iterations and haven't found a perfect solution, + # return the best path we found + if best_path and len(best_path) > 1: + logger.info(f"\nSearch complete. Returning best path found with score {best_score:.3f}") + return [{"action": node.action} for node in best_path[1:]] + + # If no valid path was found or path was just the root, return a default action + logger.warning("\nNo valid path found, returning fallback action") + return [{"action": "refresh()", "description": "Fallback action - no valid path found"}] + + except Exception as e: + error_msg = f"Error in RMCTS search: {str(e)}" + logger.error(error_msg) + + if best_path: + logger.info(f"\nReturning best path found before error with score {best_score:.3f}") + return [{"action": node.action} for node in best_path[1:]] + return [] + + async def rmcts_with_websocket(self, websocket) -> List[Dict[str, Any]]: + """ + Performs Monte Carlo Tree Search starting from the root node with WebSocket updates. + Uses GPT-4 for node selection and reflection-based backpropagation. 
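+        The search logic mirrors rmcts(); in addition, progress events such as
+        rmcts_iteration, node_selected, simulation_results, backtracking, and
+        search_complete are streamed to the client over the websocket.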
+ + Args: + websocket: WebSocket connection to send updates to + + Returns: + List[Dict[str, Any]]: List of actions in the best path found + """ + best_score = float('-inf') + best_path = None + visited = set() # Track visited nodes to avoid cycles + max_iterations = self.config.iterations # Use configured number of iterations + + try: + # Initial browser setup + live_browser_url, session_id = await self._reset_browser(websocket) + + for iteration in range(max_iterations): + logger.info(f"\n{'='*50}") + logger.info(f"RMCTS Iteration {iteration + 1}/{max_iterations}") + logger.info(f"{'='*50}\n") + + # Send iteration update if websocket is provided + await websocket.send_json({ + "type": "rmcts_iteration", + "iteration": iteration + 1, + "max_iterations": max_iterations, + "timestamp": datetime.utcnow().isoformat() + }) + + # Selection: Use GPT-4 to select a promising path + current_node = self.root_node + path = [current_node] + selection_depth = 0 + + while current_node.children and not current_node.is_terminal: + logger.info(f"\nSelection Step {selection_depth + 1}:") + logger.info(f"Current node action: {current_node.action}") + logger.info(f"Number of children: {len(current_node.children)}") + + # Get trajectory for GPT-4 to evaluate + trajectory = [] + for node in path[1:]: # Skip root node + trajectory.append({ + "natural_language_description": node.natural_language_description, + "action": node.action, + "feedback": node.feedback + }) + + # Create prompt for GPT-4 to select next node + prompt = f"""Given the current trajectory and goal, select the most promising child node to explore next. + Consider the overall progress, efficiency, and likelihood of success. + + Goal: {self.goal} + + Current Trajectory: + {json.dumps(trajectory, indent=2)} + + Available Children: + {json.dumps([{ + 'action': child.action, + 'description': child.natural_language_description, + 'visits': child.visits, + 'value': child.value + } for child in current_node.children], indent=2)} + + Return a JSON response with: + {{ + "selected_child_index": int, # Index of the selected child + "explanation": str # Brief explanation of the selection + }}""" + + try: + response = openai_client.chat.completions.create( + model=self.config.evaluation_model, + messages=[ + {"role": "system", "content": "You are an expert at selecting promising paths in a search tree."}, + {"role": "user", "content": prompt} + ], + response_format={"type": "json_object"} + ) + + selection = json.loads(response.choices[0].message.content) + selected_index = selection["selected_child_index"] + + if 0 <= selected_index < len(current_node.children): + current_node = current_node.children[selected_index] + path.append(current_node) + logger.info(f"Selected child {selected_index + 1}: {current_node.action}") + logger.info(f"Selection explanation: {selection['explanation']}") + + # Send selection update if websocket is provided + await websocket.send_json({ + "type": "node_selected", + "node_id": id(current_node), + "explanation": selection["explanation"], + "timestamp": datetime.utcnow().isoformat() + }) + else: + logger.warning(f"Invalid child index {selected_index}, breaking selection") + break + + except Exception as e: + logger.error(f"Error in node selection: {str(e)}") + await websocket.send_json({ + "type": "selection_error", + "error": str(e), + "timestamp": datetime.utcnow().isoformat() + }) + break + + selection_depth += 1 + + # Expansion: Expand the selected node if possible + if not current_node.is_terminal and current_node.depth < 
self.config.max_depth: + logger.info(f"\nExpansion Step:") + logger.info(f"Expanding node: {current_node.action}") + + await websocket.send_json({ + "type": "node_expanding", + "node_id": id(current_node), + "timestamp": datetime.utcnow().isoformat() + }) + + try: + await self.expand(current_node, websocket) + logger.info(f"Successfully expanded node with {len(current_node.children)} children") + except Exception as e: + logger.error(f"Error expanding node: {str(e)}") + current_node.is_terminal = True + await websocket.send_json({ + "type": "expansion_error", + "node_id": id(current_node), + "error": str(e), + "timestamp": datetime.utcnow().isoformat() + }) + + # Simulation: Evaluate the current path + logger.info(f"\nSimulation Step:") + logger.info(f"Evaluating path of length {len(path) - 1}") + + await websocket.send_json({ + "type": "simulation_start", + "path_length": len(path) - 1, + "timestamp": datetime.utcnow().isoformat() + }) + + try: + trajectory = [] + for node in path[1:]: # Skip root node + trajectory.append({ + "natural_language_description": node.natural_language_description, + "action": node.action, + "feedback": node.feedback + }) + + # Score the trajectory + prompt = create_llm_prompt(trajectory, self.goal) + result = score_trajectory_with_openai(prompt, openai_client, model=self.config.evaluation_model) + score = result["overall_score"] + + logger.info(f"Simulation Results:") + logger.info(f"Overall Score: {score:.3f}") + logger.info(f"Efficiency Score: {result['efficiency_score']:.3f}") + logger.info(f"Accuracy Score: {result['accuracy_score']:.3f}") + logger.info(f"Robustness Score: {result['robustness_score']:.3f}") + + # Send simulation results if websocket is provided + await websocket.send_json({ + "type": "simulation_results", + "score": score, + "efficiency_score": result["efficiency_score"], + "accuracy_score": result["accuracy_score"], + "robustness_score": result["robustness_score"], + "timestamp": datetime.utcnow().isoformat() + }) + + # Update best path if this score is better + if score > best_score: + best_score = score + best_path = path + logger.info(f"\nNew best path found!") + logger.info(f"Previous best score: {best_score:.3f}") + logger.info(f"New best score: {score:.3f}") + + # Send best path update if websocket is provided + await websocket.send_json({ + "type": "best_path_update", + "score": best_score, + "path": [{"id": id(node), "action": node.action} for node in best_path[1:]], + "timestamp": datetime.utcnow().isoformat() + }) + + # Reflection-based backpropagation + if score < 0.75: # If the path is not satisfactory + logger.info(f"\nReflection Step (Score {score:.3f} < 0.75):") + + await websocket.send_json({ + "type": "reflection_start", + "score": score, + "timestamp": datetime.utcnow().isoformat() + }) + + # Generate reflection prompt + reflection_prompt = f"""Analyze the current trajectory and suggest improvements. 
+ + Goal: {self.goal} + + Current Trajectory: + {json.dumps(trajectory, indent=2)} + + Score: {score} + + Return a JSON response with: + {{ + "backtrack_to_step": int, # Which step to backtrack to (0-based index) + "reason": str, # Why backtrack to this step + "suggested_improvements": [str] # List of suggested improvements + }}""" + + try: + reflection = openai_client.chat.completions.create( + model=self.config.evaluation_model, + messages=[ + {"role": "system", "content": "You are an expert at analyzing and improving search trajectories."}, + {"role": "user", "content": reflection_prompt} + ], + response_format={"type": "json_object"} + ) + + reflection_result = json.loads(reflection.choices[0].message.content) + backtrack_step = reflection_result["backtrack_to_step"] + + # Backtrack to the suggested step + if 0 <= backtrack_step < len(path): + current_node = path[backtrack_step] + # Remove nodes after the backtrack point + while len(path) > backtrack_step + 1: + path.pop() + logger.info(f"Backtracking to step {backtrack_step}") + logger.info(f"Reason: {reflection_result['reason']}") + logger.info("Suggested improvements:") + for improvement in reflection_result["suggested_improvements"]: + logger.info(f"- {improvement}") + + # Send backtracking update if websocket is provided + await websocket.send_json({ + "type": "backtracking", + "step": backtrack_step, + "reason": reflection_result["reason"], + "suggested_improvements": reflection_result["suggested_improvements"], + "timestamp": datetime.utcnow().isoformat() + }) + + except Exception as e: + logger.error(f"Error in reflection: {str(e)}") + await websocket.send_json({ + "type": "reflection_error", + "error": str(e), + "timestamp": datetime.utcnow().isoformat() + }) + + # If we've found a satisfactory solution, return it + if score >= 0.75: + logger.info(f"\nFound satisfactory solution with score {score:.3f}") + + # Send completion update if websocket is provided + await websocket.send_json({ + "type": "search_complete", + "status": "success", + "score": score, + "path": [{"id": id(node), "action": node.action} for node in path[1:]], + "timestamp": datetime.utcnow().isoformat() + }) + + return [{"action": node.action} for node in path[1:]] + + except Exception as e: + logger.error(f"Error in simulation: {str(e)}") + await websocket.send_json({ + "type": "simulation_error", + "error": str(e), + "timestamp": datetime.utcnow().isoformat() + }) + continue + + # Update node statistics + logger.info(f"\nBackpropagation Step:") + for node in path: + old_value = node.value + node.visits += 1 + node.value = (node.value * (node.visits - 1) + score) / node.visits + logger.info(f"Node {node.action}:") + logger.info(f" Visits: {node.visits}") + logger.info(f" Value: {old_value:.3f} -> {node.value:.3f}") + + # Send backpropagation update if websocket is provided + await websocket.send_json({ + "type": "backpropagation_complete", + "updated_nodes": [{"id": id(node), "visits": node.visits, "value": node.value} for node in path], + "timestamp": datetime.utcnow().isoformat() + }) + + # If we've exhausted all iterations and haven't found a perfect solution, + # return the best path we found + if best_path and len(best_path) > 1: + logger.info(f"\nSearch complete. 
Returning best path found with score {best_score:.3f}") + + # Send completion update if websocket is provided + await websocket.send_json({ + "type": "search_complete", + "status": "partial_success", + "score": best_score, + "path": [{"id": id(node), "action": node.action} for node in best_path[1:]], + "timestamp": datetime.utcnow().isoformat() + }) + + return [{"action": node.action} for node in best_path[1:]] + + # If no path was found at all + logger.warning("\nNo valid path found") + + # Send failure update if websocket is provided + await websocket.send_json({ + "type": "search_complete", + "status": "failure", + "message": "No valid path found", + "timestamp": datetime.utcnow().isoformat() + }) + + # If no valid path was found or path was just the root, return a default action + logger.warning("\nNo valid path found, returning fallback action") + return [{"action": "refresh()", "description": "Fallback action - no valid path found"}] + + except Exception as e: + error_msg = f"Error in RMCTS search: {str(e)}" + logger.error(error_msg) + + # Send error update if websocket is provided + await websocket.send_json({ + "type": "search_error", + "error": error_msg, + "timestamp": datetime.utcnow().isoformat() + }) + + if best_path: + logger.info(f"\nReturning best path found before error with score {best_score:.3f}") + return [{"action": node.action} for node in best_path[1:]] + return [] + + async def _reset_browser(self, websocket=None) -> Optional[tuple]: + """Reset the browser to initial state and return the live browser URL if available.""" + await self.playwright_manager.close() + + ## reset account using api-based account reset + if self.config.account_reset: + if websocket: + await websocket.send_json({ + "type": "account_reset", + "status": "started", + "timestamp": datetime.utcnow().isoformat() + }) + + try: + # Use aiohttp instead of curl + async with aiohttp.ClientSession() as session: + headers = {'Connection': 'close'} # Similar to curl -N + async with session.get(self.reset_url, headers=headers) as response: + if response.status == 200: + data = await response.json() + print(f"Account reset successful: {data}") + if websocket: + await websocket.send_json({ + "type": "account_reset", + "status": "success", + "data": data, + "timestamp": datetime.utcnow().isoformat() + }) + else: + error_msg = f"Account reset failed with status {response.status}" + print(error_msg) + if websocket: + await websocket.send_json({ + "type": "account_reset", + "status": "failed", + "reason": error_msg, + "timestamp": datetime.utcnow().isoformat() + }) + + except Exception as e: + print(f"Error during account reset: {e}") + if websocket: + await websocket.send_json({ + "type": "account_reset", + "status": "failed", + "reason": str(e), + "timestamp": datetime.utcnow().isoformat() + }) + + try: + # Create new playwright manager + self.playwright_manager = await setup_playwright( + storage_state=self.config.storage_state, + headless=self.config.headless, + mode=self.config.browser_mode + ) + page = await self.playwright_manager.get_page() + live_browser_url = None + if self.config.browser_mode == "browserbase": + live_browser_url = await self.playwright_manager.get_live_browser_url() + session_id = await self.playwright_manager.get_session_id() + else: + session_id = None + live_browser_url = None + await page.goto(self.starting_url, wait_until="networkidle") + + # Send success message if websocket is provided + if websocket: + if self.config.storage_state: + await websocket.send_json({ + "type": 
"browser_setup", + "status": "success", + "message": f"Browser successfully initialized with storage state file: {self.config.storage_state}", + "live_browser_url": live_browser_url, + "session_id": session_id, + "timestamp": datetime.utcnow().isoformat() + }) + else: + await websocket.send_json({ + "type": "browser_setup", + "status": "success", + "message": "Browser successfully initialized", + "live_browser_url": live_browser_url, + "session_id": session_id, + "timestamp": datetime.utcnow().isoformat() + }) + + return live_browser_url, session_id + except Exception as e: + print(f"Error setting up browser: {e}") + if websocket: + await websocket.send_json({ + "type": "browser_setup", + "status": "failed", + "reason": str(e), + "timestamp": datetime.utcnow().isoformat() + }) + return None, None + + async def expand(self, node: LATSNode, websocket=None) -> bool: + """ + Expand a node by generating its children. If no children are generated, + mark the node as terminal and return False to trigger backtracking. + + Args: + node: Node to expand. + websocket: Optional WebSocket connection to send updates. + + Returns: + bool: True if expansion succeeded (children generated), False otherwise. + """ + try: + children_state = await self.generate_children(node, websocket) + except Exception as e: + logger.error(f"Exception during generation of children for node {node.action}: {e}") + children_state = [] + + if not children_state: + logger.warning("No children generated. Marking node as terminal and triggering backtracking.") + node.is_terminal = True + return False # Indicate that expansion did not generate children. + + for child_state in children_state: + try: + child = LATSNode( + natural_language_description=child_state.get("natural_language_description", ""), + action=child_state.get("action", ""), + prob=child_state.get("prob", 0.0), + element=child_state.get("element", None), + goal=node.goal, + parent=node + ) + node.children.append(child) + + if websocket: + await websocket.send_json({ + "type": "node_created", + "node_id": id(child), + "parent_id": id(node), + "action": child.action, + "description": child.natural_language_description, + "timestamp": datetime.utcnow().isoformat() + }) + except Exception as e: + logger.error(f"Error creating child node from state {child_state}: {e}") + return True # Expansion succeeded (children were generated). + + async def generate_children(self, node: LATSNode, websocket=None) -> list[dict]: + """ + Generate child nodes for a given node. 
+ + Args: + node: Parent node to generate children for + websocket: Optional WebSocket connection to send updates to + + Returns: + list[dict]: List of child state dictionaries + """ + # Reset browser and get live URL + live_browser_url, session_id = await self._reset_browser(websocket) + path = self.get_path_to_root(node) + logger.info(f"######### Generating children for path with {len(path)} nodes") + # Execute path + for n in path[1:]: # Skip root node + if websocket: + await websocket.send_json({ + "type": "replaying_action", + "node_id": id(n), + "action": n.action, + "timestamp": datetime.utcnow().isoformat() + }) + try: + success = await playwright_step_execution( + n, + self.goal, + self.playwright_manager, + is_replay=False, + log_folder=self.config.log_folder + ) + logger.info(f"#########Success: {success}") + + if not success: + logger.warning(f"Action execution failed: {n.action}") + n.is_terminal = True + if websocket: + await websocket.send_json({ + "type": "replay_failed", + "node_id": id(n), + "timestamp": datetime.utcnow().isoformat() + }) + return [{ + "natural_language_description": "Recover from failed action", + "action": "refresh()", + "prob": 0.1, + "element": None + }] + except Exception as e: + logger.error(f"Error executing action {n.action}: {str(e)}") + # Provide fallback actions instead of bubbling up the exception + return [{ + "natural_language_description": "Recover from action error", + "action": "refresh()", + "prob": 0.1, + "element": None + }] + + + if not n.feedback: + n.feedback = await generate_feedback( + self.goal, + n.natural_language_description, + self.playwright_manager, + ) + if websocket: + await websocket.send_json({ + "type": "feedback_generated", + "node_id": id(n), + "feedback": n.feedback, + "timestamp": datetime.utcnow().isoformat() + }) + + time.sleep(3) + page = await self.playwright_manager.get_page() + page_info = await extract_page_info(page, self.config.fullpage, self.config.log_folder) + + messages = [{"role": "user", "content": f"Action is: {n.action}"} for n in path[1:]] + + if websocket: + await websocket.send_json({ + "type": "generating_actions", + "node_id": id(node), + "timestamp": datetime.utcnow().isoformat() + }) + + next_actions = await extract_top_actions( + [{"natural_language_description": n.natural_language_description, "action": n.action, "feedback": n.feedback} for n in path[1:]], + self.goal, + self.images, + page_info, + self.action_set, + openai_client, + features=self.config.features, + elements_filter=self.config.elements_filter, + branching_factor=self.config.branching_factor, + log_folder=self.config.log_folder, + fullpage=self.config.fullpage, + action_generation_model=self.config.action_generation_model, + action_grounding_model=self.config.action_grounding_model + ) + + children = [] + for action in next_actions: + if action["action"] == "FINISH": + logger.info(f"Found FINISH action with probability: {action['prob']}") + if action["prob"] > 0.99: + node.is_terminal = True + if websocket: + await websocket.send_json({ + "type": "node_terminal", + "node_id": id(node), + "reason": "finish_action", + "timestamp": datetime.utcnow().isoformat() + }) + continue + # return [] + continue + + page = await self.playwright_manager.get_page() + code, function_calls = self.action_set.to_python_code(action["action"]) + + if len(function_calls) == 1: + try: + for function_name, function_args in function_calls: + extracted_number = parse_function_args(function_args) + element = await locate_element(page, 
extracted_number) + action["element"] = element + except Exception as e: + logger.warning(f"Element location failed for action: {action['action']}, error: {str(e)}") + action["element"] = None + if websocket: + await websocket.send_json({ + "type": "element_location_failed", + "action": action["action"], + "error": str(e), + "timestamp": datetime.utcnow().isoformat() + }) + children.append(action) + + if not children: + logger.warning("No viable children, creating fallback exploration actions") + # Keep the search alive with generic recovery actions instead of returning an empty list + children.extend([ + { + "natural_language_description": "Navigate back to try a different approach", + "action": "navigate_backward()", + "prob": 0.15, + "element": None + }, + { + "natural_language_description": "Try refreshing the page", + "action": "refresh()", + "prob": 0.1, + "element": None + }, + { + "natural_language_description": "Try clicking on a random element", + "action": "click('random')", + "prob": 0.05, + "element": None + } + ]) + logger.info(f"Generated children: {children}") + return children + + def get_path_to_root(self, node: LATSNode) -> List[LATSNode]: + path = [] + current = node + while current: + path.append(current) + current = current.parent + return list(reversed(path)) \ No newline at end of file diff --git a/visual-tree-search-backend/app/api/lwats/agents_async/SimpleSearchAgents/simple_search_agent.py b/visual-tree-search-backend/app/api/lwats/agents_async/SimpleSearchAgents/simple_search_agent.py index fc41795..b38302f 100644 --- a/visual-tree-search-backend/app/api/lwats/agents_async/SimpleSearchAgents/simple_search_agent.py +++ b/visual-tree-search-backend/app/api/lwats/agents_async/SimpleSearchAgents/simple_search_agent.py @@ -1146,4 +1146,3 @@ def _get_tree_data(self): tree_data.append(node_data) return tree_data - diff --git a/visual-tree-search-backend/app/api/routes/new_tree_search_websocket.py b/visual-tree-search-backend/app/api/routes/new_tree_search_websocket.py index 5277676..cd7d0f9 100644 --- a/visual-tree-search-backend/app/api/routes/new_tree_search_websocket.py +++ b/visual-tree-search-backend/app/api/routes/new_tree_search_websocket.py @@ -128,6 +128,8 @@ async def handle_search_request(websocket: WebSocket, message: Dict[str, Any]): await agent.dfs_with_websocket(websocket) elif search_algorithm.lower() == "lats": await agent.run(websocket) + elif search_algorithm.lower() == "mcts": + await agent.run(websocket) else: await websocket.send_json({ "type": "error", diff --git a/visual-tree-search-backend/app/api/routes/tree_search_websocket.py b/visual-tree-search-backend/app/api/routes/tree_search_websocket.py index 144950f..f391e33 100644 --- a/visual-tree-search-backend/app/api/routes/tree_search_websocket.py +++ b/visual-tree-search-backend/app/api/routes/tree_search_websocket.py @@ -128,6 +128,8 @@ async def handle_search_request(websocket: WebSocket, message: Dict[str, Any]): await agent.dfs_with_websocket(websocket) elif search_algorithm.lower() == "lats": await agent.run(websocket) + elif search_algorithm.lower() == "mcts": + await agent.run(websocket) else: await websocket.send_json({ "type": "error", diff --git a/visual-tree-search-backend/app/api/shopping.json
b/visual-tree-search-backend/app/api/shopping.json index 4de8ead..6a2c7ab 100644 --- a/visual-tree-search-backend/app/api/shopping.json +++ b/visual-tree-search-backend/app/api/shopping.json @@ -24,7 +24,7 @@ "value": "", "domain": "128.105.145.205", "path": "/", - "expires": 1775272527, + "expires": 1775370120, "httpOnly": false, "secure": false, "sameSite": "Strict" @@ -81,20 +81,20 @@ }, { "name": "private_content_version", - "value": "c514ad01bc9816ee30ab1f02510aa34a", + "value": "ff4bba58081243f67b1adee2cc6974bd", "domain": "128.105.145.205", "path": "/", - "expires": 1778296522.505323, + "expires": 1778394118.522057, "httpOnly": false, "secure": false, "sameSite": "Lax" }, { "name": "PHPSESSID", - "value": "007c737fca0eb4173ab5362c2d9c8b09", + "value": "30247306b1ad824f37f3e0384d86d991", "domain": "128.105.145.205", "path": "/", - "expires": 1775272526.468528, + "expires": 1775370122.385659, "httpOnly": true, "secure": false, "sameSite": "Lax" @@ -104,17 +104,17 @@ "value": "9bf9a599123e6402b85cde67144717a08b817412", "domain": "128.105.145.205", "path": "/", - "expires": 1775272526.468754, + "expires": 1775370122.385877, "httpOnly": true, "secure": false, "sameSite": "Lax" }, { "name": "form_key", - "value": "Hsr3n5ycGkOPfr1K", + "value": "IEPmx1hKh4NWjeUa", "domain": "128.105.145.205", "path": "/", - "expires": 1775272526.468692, + "expires": 1775370122.385817, "httpOnly": false, "secure": false, "sameSite": "Lax" @@ -124,17 +124,17 @@ "value": "true", "domain": "128.105.145.205", "path": "/", - "expires": 1775272527, + "expires": 1775370120, "httpOnly": false, "secure": false, "sameSite": "Lax" }, { "name": "section_data_ids", - "value": "{%22messages%22:1743736525%2C%22customer%22:1743736525%2C%22compare-products%22:1743736525%2C%22last-ordered-items%22:1743736525%2C%22cart%22:1743736525%2C%22directory-data%22:1743736525%2C%22captcha%22:1743736525%2C%22instant-purchase%22:1743736525%2C%22loggedAsCustomer%22:1743736525%2C%22persistent%22:1743736525%2C%22review%22:1743736525%2C%22wishlist%22:1743736525%2C%22recently_viewed_product%22:1743736525%2C%22recently_compared_product%22:1743736525%2C%22product_data_storage%22:1743736525%2C%22paypal-billing-agreement%22:1743736525}", + "value": "{%22messages%22:1743834120%2C%22customer%22:1743834120%2C%22compare-products%22:1743834120%2C%22last-ordered-items%22:1743834120%2C%22cart%22:1743834120%2C%22directory-data%22:1743834120%2C%22captcha%22:1743834120%2C%22instant-purchase%22:1743834120%2C%22loggedAsCustomer%22:1743834120%2C%22persistent%22:1743834120%2C%22review%22:1743834120%2C%22wishlist%22:1743834120%2C%22recently_viewed_product%22:1743834120%2C%22recently_compared_product%22:1743834120%2C%22product_data_storage%22:1743834120%2C%22paypal-billing-agreement%22:1743834120}", "domain": "128.105.145.205", "path": "/", - "expires": 1775272524, + "expires": 1775370120, "httpOnly": false, "secure": false, "sameSite": "Lax" diff --git a/visual-tree-search-backend/test/test-tree-search-ws-lats.py b/visual-tree-search-backend/test/test-tree-search-ws-lats.py index 567c790..c9f3cfb 100644 --- a/visual-tree-search-backend/test/test-tree-search-ws-lats.py +++ b/visual-tree-search-backend/test/test-tree-search-ws-lats.py @@ -138,7 +138,7 @@ def parse_arguments(): parser.add_argument("--goal", type=str, default=DEFAULT_GOAL, help=f"Goal to achieve (default: {DEFAULT_GOAL})") - parser.add_argument("--algorithm", type=str, choices=["bfs", "dfs", "lats"], default="lats", + parser.add_argument("--algorithm", type=str, choices=["bfs", "dfs", "lats", 
"mcts"], default="lats", help="Search algorithm to use (default: lats)") parser.add_argument("--max-depth", type=int, default=3, diff --git a/visual-tree-search-backend/test/test-tree-search-ws-mcts.py b/visual-tree-search-backend/test/test-tree-search-ws-mcts.py new file mode 100644 index 0000000..9046be8 --- /dev/null +++ b/visual-tree-search-backend/test/test-tree-search-ws-mcts.py @@ -0,0 +1,169 @@ +import asyncio +import json +import websockets +import argparse +import logging +from datetime import datetime + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + +# Default values +DEFAULT_WS_URL = "ws://localhost:3000/new-tree-search-ws" +DEFAULT_STARTING_URL = "http://128.105.145.205:7770/" +DEFAULT_GOAL = "search running shoes, click on the first result" + +async def connect_and_test_search( + ws_url: str, + starting_url: str, + goal: str, + search_algorithm: str = "bfs", + max_depth: int = 3 +): + """ + Connect to the WebSocket endpoint and test the tree search functionality. + + Args: + ws_url: WebSocket URL to connect to + starting_url: URL to start the search from + goal: Goal to achieve + search_algorithm: Search algorithm to use (bfs or dfs) + max_depth: Maximum depth for the search tree + """ + logger.info(f"Connecting to WebSocket at {ws_url}") + + async with websockets.connect(ws_url) as websocket: + logger.info("Connected to WebSocket") + + # Wait for connection established message + response = await websocket.recv() + data = json.loads(response) + if data.get("type") == "connection_established": + logger.info(f"Connection established with ID: {data.get('connection_id')}") + + # Send search request + request = { + "type": "start_search", + "agent_type": "MCTSAgent", + "starting_url": starting_url, + "goal": goal, + "search_algorithm": search_algorithm, + "max_depth": max_depth + } + + logger.info(f"Sending search request: {request}") + await websocket.send(json.dumps(request)) + + # Process responses + while True: + try: + response = await websocket.recv() + data = json.loads(response) + + # Log the message type and some key information + msg_type = data.get("type", "unknown") + + if msg_type == "status_update": + logger.info(f"Status update: {data.get('status')} - {data.get('message')}") + + elif msg_type == "iteration_start": + logger.info(f"Iteration start: {data.get('iteration')}") + + elif msg_type == "step_start": + logger.info(f"Step start: {data.get('step')} - {data.get('step_name')}") + + elif msg_type == "node_update": + node_id = data.get("node_id") + status = data.get("status") + logger.info(f"Node update: {node_id} - {status}") + + # If node was scored, log the score + if status == "scored": + logger.info(f"Node score: {data.get('score')}") + + elif msg_type == "trajectory_update": + logger.info(f"Trajectory update received with {data.get('trajectory')}") + + elif msg_type == "tree_update": + logger.info(f"Tree update received with {data.get('tree')}") + + elif msg_type == "best_path_update": + logger.info(f"Best path update: score={data.get('score')}, path length={len(data.get('path', []))}") + + elif msg_type == "search_complete": + status = data.get("status") + score = data.get("score", "N/A") + path_length = len(data.get("path", [])) + + logger.info(f"Search complete: {status}, score={score}, path length={path_length}") + logger.info("Path actions:") + + for i, node in enumerate(data.get("path", [])): + logger.info(f" {i+1}. 
{node.get('action')}") + + # Exit the loop when search is complete + break + + elif msg_type == "error": + logger.error(f"Error: {data.get('message')}") + break + + else: + logger.info(f"Received message of type {msg_type}") + logger.info(f"Message: {data}") + + except websockets.exceptions.ConnectionClosed: + logger.warning("WebSocket connection closed") + break + except Exception as e: + logger.error(f"Error processing message: {e}") + break + + logger.info("Test completed") + +def parse_arguments(): + """Parse command line arguments""" + parser = argparse.ArgumentParser(description="Test the tree search WebSocket functionality") + + parser.add_argument("--ws-url", type=str, default=DEFAULT_WS_URL, + help=f"WebSocket URL (default: {DEFAULT_WS_URL})") + + parser.add_argument("--starting-url", type=str, default=DEFAULT_STARTING_URL, + help=f"Starting URL for the search (default: {DEFAULT_STARTING_URL})") + + parser.add_argument("--goal", type=str, default=DEFAULT_GOAL, + help=f"Goal to achieve (default: {DEFAULT_GOAL})") + + parser.add_argument("--algorithm", type=str, choices=["bfs", "dfs", "lats", "mcts"], default="mcts", + help="Search algorithm to use (default: lats)") + + parser.add_argument("--max-depth", type=int, default=3, + help="Maximum depth for the search tree (default: 3)") + + return parser.parse_args() + +async def main(): + """Main entry point""" + args = parse_arguments() + + logger.info("Starting tree search WebSocket test") + logger.info(f"WebSocket URL: {args.ws_url}") + logger.info(f"Starting URL: {args.starting_url}") + logger.info(f"Goal: {args.goal}") + logger.info(f"Algorithm: {args.algorithm}") + logger.info(f"Max depth: {args.max_depth}") + + await connect_and_test_search( + ws_url=args.ws_url, + starting_url=args.starting_url, + goal=args.goal, + search_algorithm=args.algorithm, + max_depth=args.max_depth + ) + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file