From 6176c3396643366d8ebbaa9233929df603b7d04b Mon Sep 17 00:00:00 2001 From: SyHeee Date: Sun, 30 Mar 2025 23:51:03 -0400 Subject: [PATCH 1/6] check in local rmcts --- .../SimpleSearchAgents/simple_search_agent.py | 225 ++++++++++++++++++ 1 file changed, 225 insertions(+) diff --git a/visual-tree-search-backend/app/api/lwats/agents_async/SimpleSearchAgents/simple_search_agent.py b/visual-tree-search-backend/app/api/lwats/agents_async/SimpleSearchAgents/simple_search_agent.py index fc41795..7fede0b 100644 --- a/visual-tree-search-backend/app/api/lwats/agents_async/SimpleSearchAgents/simple_search_agent.py +++ b/visual-tree-search-backend/app/api/lwats/agents_async/SimpleSearchAgents/simple_search_agent.py @@ -92,6 +92,9 @@ async def run(self, websocket=None) -> List[Dict[str, Any]]: return await self.dfs_with_websocket(websocket) else: return await self.dfs() + elif algorithm == "rmcts": + logger.info("Starting Reflective MCTS algorithm") + return await self.rmcts() else: error_msg = f"Unsupported algorithm: {algorithm}" logger.error(error_msg) @@ -1147,3 +1150,225 @@ def _get_tree_data(self): return tree_data + async def rmcts(self) -> List[Dict[str, Any]]: + """ + Performs Monte Carlo Tree Search starting from the root node. + Uses GPT-4 for node selection and reflection-based backpropagation. 
+ + Returns: + List[Dict[str, Any]]: List of actions in the best path found + """ + best_score = float('-inf') + best_path = None + visited = set() # Track visited nodes to avoid cycles + max_iterations = self.config.iterations # Use configured number of iterations + + try: + # Initial browser setup + live_browser_url, session_id = await self._reset_browser() + if not live_browser_url: + logger.error("Failed to initialize browser") + return [] + + for iteration in range(max_iterations): + logger.info(f"\n{'='*50}") + logger.info(f"RMCTS Iteration {iteration + 1}/{max_iterations}") + logger.info(f"{'='*50}\n") + + # Selection: Use GPT-4 to select a promising path + current_node = self.root_node + path = [current_node] + selection_depth = 0 + + while current_node.children and not current_node.is_terminal: + logger.info(f"\nSelection Step {selection_depth + 1}:") + logger.info(f"Current node action: {current_node.action}") + logger.info(f"Number of children: {len(current_node.children)}") + + # Get trajectory for GPT-4 to evaluate + trajectory = [] + for node in path[1:]: # Skip root node + trajectory.append({ + "natural_language_description": node.natural_language_description, + "action": node.action, + "feedback": node.feedback + }) + + # Create prompt for GPT-4 to select next node + prompt = f"""Given the current trajectory and goal, select the most promising child node to explore next. + Consider the overall progress, efficiency, and likelihood of success. 
+ + Goal: {self.goal} + + Current Trajectory: + {json.dumps(trajectory, indent=2)} + + Available Children: + {json.dumps([{ + 'action': child.action, + 'description': child.natural_language_description, + 'visits': child.visits, + 'value': child.value + } for child in current_node.children], indent=2)} + + Return a JSON response with: + {{ + "selected_child_index": int, # Index of the selected child + "explanation": str # Brief explanation of the selection + }}""" + + try: + response = openai_client.chat.completions.create( + model=self.config.evaluation_model, + messages=[ + {"role": "system", "content": "You are an expert at selecting promising paths in a search tree."}, + {"role": "user", "content": prompt} + ], + response_format={"type": "json_object"} + ) + + selection = json.loads(response.choices[0].message.content) + selected_index = selection["selected_child_index"] + + if 0 <= selected_index < len(current_node.children): + current_node = current_node.children[selected_index] + path.append(current_node) + logger.info(f"Selected child {selected_index + 1}: {current_node.action}") + logger.info(f"Selection explanation: {selection['explanation']}") + else: + logger.warning(f"Invalid child index {selected_index}, breaking selection") + break + + except Exception as e: + logger.error(f"Error in node selection: {str(e)}") + break + + selection_depth += 1 + + # Expansion: Expand the selected node if possible + if not current_node.is_terminal and current_node.depth < self.config.max_depth: + logger.info(f"\nExpansion Step:") + logger.info(f"Expanding node: {current_node.action}") + try: + await self.expand(current_node) + logger.info(f"Successfully expanded node with {len(current_node.children)} children") + except Exception as e: + logger.error(f"Error expanding node: {str(e)}") + current_node.is_terminal = True + + # Simulation: Evaluate the current path + logger.info(f"\nSimulation Step:") + logger.info(f"Evaluating path of length {len(path) - 1}") + try: + 
trajectory = [] + for node in path[1:]: # Skip root node + trajectory.append({ + "natural_language_description": node.natural_language_description, + "action": node.action, + "feedback": node.feedback + }) + + # Score the trajectory + prompt = create_llm_prompt(trajectory, self.goal) + result = score_trajectory_with_openai(prompt, openai_client, model=self.config.evaluation_model) + score = result["overall_score"] + + logger.info(f"Simulation Results:") + logger.info(f"Overall Score: {score:.3f}") + logger.info(f"Efficiency Score: {result['efficiency_score']:.3f}") + logger.info(f"Accuracy Score: {result['accuracy_score']:.3f}") + logger.info(f"Robustness Score: {result['robustness_score']:.3f}") + + # Update best path if this score is better + if score > best_score: + best_score = score + best_path = path + logger.info(f"\nNew best path found!") + logger.info(f"Previous best score: {best_score:.3f}") + logger.info(f"New best score: {score:.3f}") + + # Reflection-based backpropagation + if score < 0.75: # If the path is not satisfactory + logger.info(f"\nReflection Step (Score {score:.3f} < 0.75):") + # Generate reflection prompt + reflection_prompt = f"""Analyze the current trajectory and suggest improvements. 
+ + Goal: {self.goal} + + Current Trajectory: + {json.dumps(trajectory, indent=2)} + + Score: {score} + + Return a JSON response with: + {{ + "backtrack_to_step": int, # Which step to backtrack to (0-based index) + "reason": str, # Why backtrack to this step + "suggested_improvements": [str] # List of suggested improvements + }}""" + + try: + reflection = openai_client.chat.completions.create( + model=self.config.evaluation_model, + messages=[ + {"role": "system", "content": "You are an expert at analyzing and improving search trajectories."}, + {"role": "user", "content": reflection_prompt} + ], + response_format={"type": "json_object"} + ) + + reflection_result = json.loads(reflection.choices[0].message.content) + backtrack_step = reflection_result["backtrack_to_step"] + + # Backtrack to the suggested step + if 0 <= backtrack_step < len(path): + current_node = path[backtrack_step] + # Remove nodes after the backtrack point + while len(path) > backtrack_step + 1: + path.pop() + logger.info(f"Backtracking to step {backtrack_step}") + logger.info(f"Reason: {reflection_result['reason']}") + logger.info("Suggested improvements:") + for improvement in reflection_result["suggested_improvements"]: + logger.info(f"- {improvement}") + + except Exception as e: + logger.error(f"Error in reflection: {str(e)}") + + # If we've found a satisfactory solution, return it + if score >= 0.75: + logger.info(f"\nFound satisfactory solution with score {score:.3f}") + return [{"action": node.action} for node in path[1:]] + + except Exception as e: + logger.error(f"Error in simulation: {str(e)}") + continue + + # Update node statistics + logger.info(f"\nBackpropagation Step:") + for node in path: + old_value = node.value + node.visits += 1 + node.value = (node.value * (node.visits - 1) + score) / node.visits + logger.info(f"Node {node.action}:") + logger.info(f" Visits: {node.visits}") + logger.info(f" Value: {old_value:.3f} -> {node.value:.3f}") + + # If we've exhausted all iterations 
and haven't found a perfect solution, + # return the best path we found + if best_path: + logger.info(f"\nSearch complete. Returning best path found with score {best_score:.3f}") + return [{"action": node.action} for node in best_path[1:]] + + # If no path was found at all + logger.warning("\nNo valid path found") + return [] + + except Exception as e: + error_msg = f"Error in RMCTS search: {str(e)}" + logger.error(error_msg) + if best_path: + logger.info(f"\nReturning best path found before error with score {best_score:.3f}") + return [{"action": node.action} for node in best_path[1:]] + return [] + From 1252222c8d24651f323237ebc28b020a34d966a4 Mon Sep 17 00:00:00 2001 From: SyHeee Date: Tue, 1 Apr 2025 23:32:49 -0400 Subject: [PATCH 2/6] merge to new module, need fix termination --- .../SimpleSearchAgents/mcts_agent.py | 880 +++++++++++++++++- 1 file changed, 879 insertions(+), 1 deletion(-) diff --git a/visual-tree-search-backend/app/api/lwats/agents_async/SimpleSearchAgents/mcts_agent.py b/visual-tree-search-backend/app/api/lwats/agents_async/SimpleSearchAgents/mcts_agent.py index 9f1734c..e1e3a1c 100644 --- a/visual-tree-search-backend/app/api/lwats/agents_async/SimpleSearchAgents/mcts_agent.py +++ b/visual-tree-search-backend/app/api/lwats/agents_async/SimpleSearchAgents/mcts_agent.py @@ -65,4 +65,882 @@ def __init__( self.reset_url = os.environ["ACCOUNT_RESET_URL"] async def run(self, websocket=None) -> List[Dict[str, Any]]: - pass \ No newline at end of file + """ + Run the MCTS algorithm based on configuration. + + Args: + websocket: Optional WebSocket connection to send updates to + + Returns: + List[Dict[str, Any]]: List of actions in the best path found + """ + logger.info("Starting Reflective MCTS algorithm") + if websocket: + return await self.rmcts_with_websocket(websocket) + else: + return await self.rmcts() + + async def rmcts(self) -> List[Dict[str, Any]]: + """ + Performs Monte Carlo Tree Search starting from the root node. 
+ Uses GPT-4 for node selection and reflection-based backpropagation. + + Returns: + List[Dict[str, Any]]: List of actions in the best path found + """ + best_score = float('-inf') + best_path = None + visited = set() # Track visited nodes to avoid cycles + max_iterations = self.config.iterations # Use configured number of iterations + + try: + # Initial browser setup + live_browser_url, session_id = await self._reset_browser() + + for iteration in range(max_iterations): + logger.info(f"\n{'='*50}") + logger.info(f"RMCTS Iteration {iteration + 1}/{max_iterations}") + logger.info(f"{'='*50}\n") + + # Selection: Use GPT-4 to select a promising path + current_node = self.root_node + path = [current_node] + selection_depth = 0 + + while current_node.children and not current_node.is_terminal: + logger.info(f"\nSelection Step {selection_depth + 1}:") + logger.info(f"Current node action: {current_node.action}") + logger.info(f"Number of children: {len(current_node.children)}") + + # Get trajectory for GPT-4 to evaluate + trajectory = [] + for node in path[1:]: # Skip root node + trajectory.append({ + "natural_language_description": node.natural_language_description, + "action": node.action, + "feedback": node.feedback + }) + + # Create prompt for GPT-4 to select next node + prompt = f"""Given the current trajectory and goal, select the most promising child node to explore next. + Consider the overall progress, efficiency, and likelihood of success. 
+ + Goal: {self.goal} + + Current Trajectory: + {json.dumps(trajectory, indent=2)} + + Available Children: + {json.dumps([{ + 'action': child.action, + 'description': child.natural_language_description, + 'visits': child.visits, + 'value': child.value + } for child in current_node.children], indent=2)} + + Return a JSON response with: + {{ + "selected_child_index": int, # Index of the selected child + "explanation": str # Brief explanation of the selection + }}""" + + try: + response = openai_client.chat.completions.create( + model=self.config.evaluation_model, + messages=[ + {"role": "system", "content": "You are an expert at selecting promising paths in a search tree."}, + {"role": "user", "content": prompt} + ], + response_format={"type": "json_object"} + ) + + selection = json.loads(response.choices[0].message.content) + selected_index = selection["selected_child_index"] + + if 0 <= selected_index < len(current_node.children): + current_node = current_node.children[selected_index] + path.append(current_node) + logger.info(f"Selected child {selected_index + 1}: {current_node.action}") + logger.info(f"Selection explanation: {selection['explanation']}") + else: + logger.warning(f"Invalid child index {selected_index}, breaking selection") + break + + except Exception as e: + logger.error(f"Error in node selection: {str(e)}") + break + + selection_depth += 1 + + # Expansion: Expand the selected node if possible + if not current_node.is_terminal and current_node.depth < self.config.max_depth: + logger.info(f"\nExpansion Step:") + logger.info(f"Expanding node: {current_node.action}") + + try: + await self.expand(current_node) + logger.info(f"Successfully expanded node with {len(current_node.children)} children") + except Exception as e: + logger.error(f"Error expanding node: {str(e)}") + current_node.is_terminal = True + + # Simulation: Evaluate the current path + logger.info(f"\nSimulation Step:") + logger.info(f"Evaluating path of length {len(path) - 1}") + + try: + 
trajectory = [] + for node in path[1:]: # Skip root node + trajectory.append({ + "natural_language_description": node.natural_language_description, + "action": node.action, + "feedback": node.feedback + }) + + # Score the trajectory + prompt = create_llm_prompt(trajectory, self.goal) + result = score_trajectory_with_openai(prompt, openai_client, model=self.config.evaluation_model) + score = result["overall_score"] + + logger.info(f"Simulation Results:") + logger.info(f"Overall Score: {score:.3f}") + logger.info(f"Efficiency Score: {result['efficiency_score']:.3f}") + logger.info(f"Accuracy Score: {result['accuracy_score']:.3f}") + logger.info(f"Robustness Score: {result['robustness_score']:.3f}") + + # Update best path if this score is better + if score > best_score: + best_score = score + best_path = path + logger.info(f"\nNew best path found!") + logger.info(f"Previous best score: {best_score:.3f}") + logger.info(f"New best score: {score:.3f}") + + # Reflection-based backpropagation + if score < 0.75: # If the path is not satisfactory + logger.info(f"\nReflection Step (Score {score:.3f} < 0.5):") + + # Generate reflection prompt + reflection_prompt = f"""Analyze the current trajectory and suggest improvements. 
+ + Goal: {self.goal} + + Current Trajectory: + {json.dumps(trajectory, indent=2)} + + Score: {score} + + Return a JSON response with: + {{ + "backtrack_to_step": int, # Which step to backtrack to (0-based index) + "reason": str, # Why backtrack to this step + "suggested_improvements": [str] # List of suggested improvements + }}""" + + try: + reflection = openai_client.chat.completions.create( + model=self.config.evaluation_model, + messages=[ + {"role": "system", "content": "You are an expert at analyzing and improving search trajectories."}, + {"role": "user", "content": reflection_prompt} + ], + response_format={"type": "json_object"} + ) + + reflection_result = json.loads(reflection.choices[0].message.content) + backtrack_step = reflection_result["backtrack_to_step"] + + # Backtrack to the suggested step + if 0 <= backtrack_step < len(path): + current_node = path[backtrack_step] + # Remove nodes after the backtrack point + while len(path) > backtrack_step + 1: + path.pop() + logger.info(f"Backtracking to step {backtrack_step}") + logger.info(f"Reason: {reflection_result['reason']}") + logger.info("Suggested improvements:") + for improvement in reflection_result["suggested_improvements"]: + logger.info(f"- {improvement}") + + except Exception as e: + logger.error(f"Error in reflection: {str(e)}") + + # # If we've found a satisfactory solution, return it + # if score >= 0.75: + # logger.info(f"\nFound satisfactory solution with score {score:.3f}") + # return [{"action": node.action} for node in path[1:]] + + except Exception as e: + logger.error(f"Error in simulation: {str(e)}") + continue + + # Update node statistics + logger.info(f"\nBackpropagation Step:") + for node in path: + old_value = node.value + node.visits += 1 + node.value = (node.value * (node.visits - 1) + score) / node.visits + logger.info(f"Node {node.action}:") + logger.info(f" Visits: {node.visits}") + logger.info(f" Value: {old_value:.3f} -> {node.value:.3f}") + + # If we've exhausted all 
iterations and haven't found a perfect solution, + # return the best path we found + if best_path: + logger.info(f"\nSearch complete. Returning best path found with score {best_score:.3f}") + return [{"action": node.action} for node in best_path[1:]] + + # If no path was found at all + logger.warning("\nNo valid path found") + return [] + + except Exception as e: + error_msg = f"Error in RMCTS search: {str(e)}" + logger.error(error_msg) + + if best_path: + logger.info(f"\nReturning best path found before error with score {best_score:.3f}") + return [{"action": node.action} for node in best_path[1:]] + return [] + + async def rmcts_with_websocket(self, websocket) -> List[Dict[str, Any]]: + """ + Performs Monte Carlo Tree Search starting from the root node with WebSocket updates. + Uses GPT-4 for node selection and reflection-based backpropagation. + + Args: + websocket: WebSocket connection to send updates to + + Returns: + List[Dict[str, Any]]: List of actions in the best path found + """ + best_score = float('-inf') + best_path = None + visited = set() # Track visited nodes to avoid cycles + max_iterations = self.config.iterations # Use configured number of iterations + + try: + # Initial browser setup + live_browser_url, session_id = await self._reset_browser(websocket) + + for iteration in range(max_iterations): + logger.info(f"\n{'='*50}") + logger.info(f"RMCTS Iteration {iteration + 1}/{max_iterations}") + logger.info(f"{'='*50}\n") + + # Send iteration update if websocket is provided + await websocket.send_json({ + "type": "rmcts_iteration", + "iteration": iteration + 1, + "max_iterations": max_iterations, + "timestamp": datetime.utcnow().isoformat() + }) + + # Selection: Use GPT-4 to select a promising path + current_node = self.root_node + path = [current_node] + selection_depth = 0 + + while current_node.children and not current_node.is_terminal: + logger.info(f"\nSelection Step {selection_depth + 1}:") + logger.info(f"Current node action: 
{current_node.action}") + logger.info(f"Number of children: {len(current_node.children)}") + + # Get trajectory for GPT-4 to evaluate + trajectory = [] + for node in path[1:]: # Skip root node + trajectory.append({ + "natural_language_description": node.natural_language_description, + "action": node.action, + "feedback": node.feedback + }) + + # Create prompt for GPT-4 to select next node + prompt = f"""Given the current trajectory and goal, select the most promising child node to explore next. + Consider the overall progress, efficiency, and likelihood of success. + + Goal: {self.goal} + + Current Trajectory: + {json.dumps(trajectory, indent=2)} + + Available Children: + {json.dumps([{ + 'action': child.action, + 'description': child.natural_language_description, + 'visits': child.visits, + 'value': child.value + } for child in current_node.children], indent=2)} + + Return a JSON response with: + {{ + "selected_child_index": int, # Index of the selected child + "explanation": str # Brief explanation of the selection + }}""" + + try: + response = openai_client.chat.completions.create( + model=self.config.evaluation_model, + messages=[ + {"role": "system", "content": "You are an expert at selecting promising paths in a search tree."}, + {"role": "user", "content": prompt} + ], + response_format={"type": "json_object"} + ) + + selection = json.loads(response.choices[0].message.content) + selected_index = selection["selected_child_index"] + + if 0 <= selected_index < len(current_node.children): + current_node = current_node.children[selected_index] + path.append(current_node) + logger.info(f"Selected child {selected_index + 1}: {current_node.action}") + logger.info(f"Selection explanation: {selection['explanation']}") + + # Send selection update if websocket is provided + await websocket.send_json({ + "type": "node_selected", + "node_id": id(current_node), + "explanation": selection["explanation"], + "timestamp": datetime.utcnow().isoformat() + }) + else: + 
logger.warning(f"Invalid child index {selected_index}, breaking selection") + break + + except Exception as e: + logger.error(f"Error in node selection: {str(e)}") + await websocket.send_json({ + "type": "selection_error", + "error": str(e), + "timestamp": datetime.utcnow().isoformat() + }) + break + + selection_depth += 1 + + # Expansion: Expand the selected node if possible + if not current_node.is_terminal and current_node.depth < self.config.max_depth: + logger.info(f"\nExpansion Step:") + logger.info(f"Expanding node: {current_node.action}") + + await websocket.send_json({ + "type": "node_expanding", + "node_id": id(current_node), + "timestamp": datetime.utcnow().isoformat() + }) + + try: + await self.expand(current_node, websocket) + logger.info(f"Successfully expanded node with {len(current_node.children)} children") + except Exception as e: + logger.error(f"Error expanding node: {str(e)}") + current_node.is_terminal = True + await websocket.send_json({ + "type": "expansion_error", + "node_id": id(current_node), + "error": str(e), + "timestamp": datetime.utcnow().isoformat() + }) + + # Simulation: Evaluate the current path + logger.info(f"\nSimulation Step:") + logger.info(f"Evaluating path of length {len(path) - 1}") + + await websocket.send_json({ + "type": "simulation_start", + "path_length": len(path) - 1, + "timestamp": datetime.utcnow().isoformat() + }) + + try: + trajectory = [] + for node in path[1:]: # Skip root node + trajectory.append({ + "natural_language_description": node.natural_language_description, + "action": node.action, + "feedback": node.feedback + }) + + # Score the trajectory + prompt = create_llm_prompt(trajectory, self.goal) + result = score_trajectory_with_openai(prompt, openai_client, model=self.config.evaluation_model) + score = result["overall_score"] + + logger.info(f"Simulation Results:") + logger.info(f"Overall Score: {score:.3f}") + logger.info(f"Efficiency Score: {result['efficiency_score']:.3f}") + logger.info(f"Accuracy 
Score: {result['accuracy_score']:.3f}") + logger.info(f"Robustness Score: {result['robustness_score']:.3f}") + + # Send simulation results if websocket is provided + await websocket.send_json({ + "type": "simulation_results", + "score": score, + "efficiency_score": result["efficiency_score"], + "accuracy_score": result["accuracy_score"], + "robustness_score": result["robustness_score"], + "timestamp": datetime.utcnow().isoformat() + }) + + # Update best path if this score is better + if score > best_score: + best_score = score + best_path = path + logger.info(f"\nNew best path found!") + logger.info(f"Previous best score: {best_score:.3f}") + logger.info(f"New best score: {score:.3f}") + + # Send best path update if websocket is provided + await websocket.send_json({ + "type": "best_path_update", + "score": best_score, + "path": [{"id": id(node), "action": node.action} for node in best_path[1:]], + "timestamp": datetime.utcnow().isoformat() + }) + + # Reflection-based backpropagation + if score < 0.75: # If the path is not satisfactory + logger.info(f"\nReflection Step (Score {score:.3f} < 0.75):") + + await websocket.send_json({ + "type": "reflection_start", + "score": score, + "timestamp": datetime.utcnow().isoformat() + }) + + # Generate reflection prompt + reflection_prompt = f"""Analyze the current trajectory and suggest improvements. 
+ + Goal: {self.goal} + + Current Trajectory: + {json.dumps(trajectory, indent=2)} + + Score: {score} + + Return a JSON response with: + {{ + "backtrack_to_step": int, # Which step to backtrack to (0-based index) + "reason": str, # Why backtrack to this step + "suggested_improvements": [str] # List of suggested improvements + }}""" + + try: + reflection = openai_client.chat.completions.create( + model=self.config.evaluation_model, + messages=[ + {"role": "system", "content": "You are an expert at analyzing and improving search trajectories."}, + {"role": "user", "content": reflection_prompt} + ], + response_format={"type": "json_object"} + ) + + reflection_result = json.loads(reflection.choices[0].message.content) + backtrack_step = reflection_result["backtrack_to_step"] + + # Backtrack to the suggested step + if 0 <= backtrack_step < len(path): + current_node = path[backtrack_step] + # Remove nodes after the backtrack point + while len(path) > backtrack_step + 1: + path.pop() + logger.info(f"Backtracking to step {backtrack_step}") + logger.info(f"Reason: {reflection_result['reason']}") + logger.info("Suggested improvements:") + for improvement in reflection_result["suggested_improvements"]: + logger.info(f"- {improvement}") + + # Send backtracking update if websocket is provided + await websocket.send_json({ + "type": "backtracking", + "step": backtrack_step, + "reason": reflection_result["reason"], + "suggested_improvements": reflection_result["suggested_improvements"], + "timestamp": datetime.utcnow().isoformat() + }) + + except Exception as e: + logger.error(f"Error in reflection: {str(e)}") + await websocket.send_json({ + "type": "reflection_error", + "error": str(e), + "timestamp": datetime.utcnow().isoformat() + }) + + # # If we've found a satisfactory solution, return it + # if score >= 0.75: + # logger.info(f"\nFound satisfactory solution with score {score:.3f}") + + # # Send completion update if websocket is provided + # await websocket.send_json({ + # 
"type": "search_complete", + # "status": "success", + # "score": score, + # "path": [{"id": id(node), "action": node.action} for node in path[1:]], + # "timestamp": datetime.utcnow().isoformat() + # }) + + # return [{"action": node.action} for node in path[1:]] + + except Exception as e: + logger.error(f"Error in simulation: {str(e)}") + await websocket.send_json({ + "type": "simulation_error", + "error": str(e), + "timestamp": datetime.utcnow().isoformat() + }) + continue + + # Update node statistics + logger.info(f"\nBackpropagation Step:") + for node in path: + old_value = node.value + node.visits += 1 + node.value = (node.value * (node.visits - 1) + score) / node.visits + logger.info(f"Node {node.action}:") + logger.info(f" Visits: {node.visits}") + logger.info(f" Value: {old_value:.3f} -> {node.value:.3f}") + + # Send backpropagation update if websocket is provided + await websocket.send_json({ + "type": "backpropagation_complete", + "updated_nodes": [{"id": id(node), "visits": node.visits, "value": node.value} for node in path], + "timestamp": datetime.utcnow().isoformat() + }) + + # If we've exhausted all iterations and haven't found a perfect solution, + # return the best path we found + if best_path: + logger.info(f"\nSearch complete. 
Returning best path found with score {best_score:.3f}") + + # Send completion update if websocket is provided + await websocket.send_json({ + "type": "search_complete", + "status": "partial_success", + "score": best_score, + "path": [{"id": id(node), "action": node.action} for node in best_path[1:]], + "timestamp": datetime.utcnow().isoformat() + }) + + return [{"action": node.action} for node in best_path[1:]] + + # If no path was found at all + logger.warning("\nNo valid path found") + + # Send failure update if websocket is provided + await websocket.send_json({ + "type": "search_complete", + "status": "failure", + "message": "No valid path found", + "timestamp": datetime.utcnow().isoformat() + }) + + return [] + + except Exception as e: + error_msg = f"Error in RMCTS search: {str(e)}" + logger.error(error_msg) + + # Send error update if websocket is provided + await websocket.send_json({ + "type": "search_error", + "error": error_msg, + "timestamp": datetime.utcnow().isoformat() + }) + + if best_path: + logger.info(f"\nReturning best path found before error with score {best_score:.3f}") + return [{"action": node.action} for node in best_path[1:]] + return [] + + async def _reset_browser(self, websocket=None) -> Optional[tuple]: + """Reset the browser to initial state and return the live browser URL if available.""" + await self.playwright_manager.close() + + ## reset account using api-based account reset + if self.config.account_reset: + if websocket: + await websocket.send_json({ + "type": "account_reset", + "status": "started", + "timestamp": datetime.utcnow().isoformat() + }) + + try: + # Use aiohttp instead of curl + async with aiohttp.ClientSession() as session: + headers = {'Connection': 'close'} # Similar to curl -N + async with session.get(self.reset_url, headers=headers) as response: + if response.status == 200: + data = await response.json() + print(f"Account reset successful: {data}") + if websocket: + await websocket.send_json({ + "type": 
"account_reset", + "status": "success", + "data": data, + "timestamp": datetime.utcnow().isoformat() + }) + else: + error_msg = f"Account reset failed with status {response.status}" + print(error_msg) + if websocket: + await websocket.send_json({ + "type": "account_reset", + "status": "failed", + "reason": error_msg, + "timestamp": datetime.utcnow().isoformat() + }) + + except Exception as e: + print(f"Error during account reset: {e}") + if websocket: + await websocket.send_json({ + "type": "account_reset", + "status": "failed", + "reason": str(e), + "timestamp": datetime.utcnow().isoformat() + }) + + try: + # Create new playwright manager + self.playwright_manager = await setup_playwright( + storage_state=self.config.storage_state, + headless=self.config.headless, + mode=self.config.browser_mode + ) + page = await self.playwright_manager.get_page() + live_browser_url = None + if self.config.browser_mode == "browserbase": + live_browser_url = await self.playwright_manager.get_live_browser_url() + session_id = await self.playwright_manager.get_session_id() + else: + session_id = None + live_browser_url = None + await page.goto(self.starting_url, wait_until="networkidle") + + # Send success message if websocket is provided + if websocket: + if self.config.storage_state: + await websocket.send_json({ + "type": "browser_setup", + "status": "success", + "message": f"Browser successfully initialized with storage state file: {self.config.storage_state}", + "live_browser_url": live_browser_url, + "session_id": session_id, + "timestamp": datetime.utcnow().isoformat() + }) + else: + await websocket.send_json({ + "type": "browser_setup", + "status": "success", + "message": "Browser successfully initialized", + "live_browser_url": live_browser_url, + "session_id": session_id, + "timestamp": datetime.utcnow().isoformat() + }) + + return live_browser_url, session_id + except Exception as e: + print(f"Error setting up browser: {e}") + if websocket: + await websocket.send_json({ + 
"type": "browser_setup", + "status": "failed", + "reason": str(e), + "timestamp": datetime.utcnow().isoformat() + }) + return None, None + + async def expand(self, node: LATSNode, websocket=None) -> None: + """ + Expand a node by generating its children. + + Args: + node: Node to expand + websocket: Optional WebSocket connection to send updates to + """ + children_state = await self.generate_children(node, websocket) + logger.info(f"Generated {len(children_state)} children for node: {node.action}") + if not children_state: + logger.warning(f"No valid children found for node: {node.action}") + # Mark the node as terminal but don't halt the entire search + node.is_terminal = True + return + + for child_state in children_state: + child = LATSNode( + natural_language_description=child_state["natural_language_description"], + action=child_state["action"], + prob=child_state["prob"], + element=child_state["element"], + goal=node.goal, + parent=node + ) + node.children.append(child) + + # Send child creation update if websocket is provided + if websocket: + await websocket.send_json({ + "type": "node_created", + "node_id": id(child), + "parent_id": id(node), + "action": child.action, + "description": child.natural_language_description, + "timestamp": datetime.utcnow().isoformat() + }) + + async def generate_children(self, node: LATSNode, websocket=None) -> list[dict]: + """ + Generate child nodes for a given node. 
+ + Args: + node: Parent node to generate children for + websocket: Optional WebSocket connection to send updates to + + Returns: + list[dict]: List of child state dictionaries + """ + # Reset browser and get live URL + live_browser_url, session_id = await self._reset_browser(websocket) + path = self.get_path_to_root(node) + + # Execute path + for n in path[1:]: # Skip root node + if websocket: + await websocket.send_json({ + "type": "replaying_action", + "node_id": id(n), + "action": n.action, + "timestamp": datetime.utcnow().isoformat() + }) + + success = await playwright_step_execution( + n, + self.goal, + self.playwright_manager, + is_replay=False, + log_folder=self.config.log_folder + ) + if not success: + n.is_terminal = True + if websocket: + await websocket.send_json({ + "type": "replay_failed", + "node_id": id(n), + "timestamp": datetime.utcnow().isoformat() + }) + return [] + + if not n.feedback: + n.feedback = await generate_feedback( + self.goal, + n.natural_language_description, + self.playwright_manager, + ) + if websocket: + await websocket.send_json({ + "type": "feedback_generated", + "node_id": id(n), + "feedback": n.feedback, + "timestamp": datetime.utcnow().isoformat() + }) + + time.sleep(3) + page = await self.playwright_manager.get_page() + page_info = await extract_page_info(page, self.config.fullpage, self.config.log_folder) + + messages = [{"role": "user", "content": f"Action is: {n.action}"} for n in path[1:]] + + if websocket: + await websocket.send_json({ + "type": "generating_actions", + "node_id": id(node), + "timestamp": datetime.utcnow().isoformat() + }) + + next_actions = await extract_top_actions( + [{"natural_language_description": n.natural_language_description, "action": n.action, "feedback": n.feedback} for n in path[1:]], + self.goal, + self.images, + page_info, + self.action_set, + openai_client, + features=self.config.features, + elements_filter=self.config.elements_filter, + branching_factor=self.config.branching_factor, + 
log_folder=self.config.log_folder, + fullpage=self.config.fullpage, + action_generation_model=self.config.action_generation_model, + action_grounding_model=self.config.action_grounding_model + ) + + children = [] + for action in next_actions: + if action["action"] == "FINISH": + logger.info(f"Found FINISH action with probability: {action['prob']}") + if action["prob"] > 0.8: + node.is_terminal = True + if websocket: + await websocket.send_json({ + "type": "node_terminal", + "node_id": id(node), + "reason": "finish_action", + "timestamp": datetime.utcnow().isoformat() + }) + continue + # return [] + continue + + page = await self.playwright_manager.get_page() + code, function_calls = self.action_set.to_python_code(action["action"]) + + if len(function_calls) == 1: + try: + for function_name, function_args in function_calls: + extracted_number = parse_function_args(function_args) + element = await locate_element(page, extracted_number) + action["element"] = element + except Exception as e: + logger.warning(f"Element location failed for action: {action['action']}, error: {str(e)}") + action["element"] = None + children.append(action) + if websocket: + await websocket.send_json({ + "type": "element_location_failed", + "action": action["action"], + "error": str(e), + "timestamp": datetime.utcnow().isoformat() + }) + children.append(action) + + if not children: + node.is_terminal = True + if websocket: + await websocket.send_json({ + "type": "node_terminal", + "node_id": id(node), + "reason": "no_valid_actions", + "timestamp": datetime.utcnow().isoformat() + }) + logger.warning("No children generated") + # logger.warning("No children generated, creating a dummy 'retry' child to keep search alive") + + # # If empty list would terminate search, create a "fallback" child + # children.append({ + # "natural_language_description": "Retry with different approach", + # "action": "refresh()", # Or some other generic action + # "prob": 0.1, + # "element": None + # }) + 
print(f"****** Generated children: {children}") + return children + + def get_path_to_root(self, node: LATSNode) -> List[LATSNode]: + path = [] + current = node + while current: + path.append(current) + current = current.parent + return list(reversed(path)) \ No newline at end of file From 8c39e14b2431f2306ebf7cc4e507646dd3b99e91 Mon Sep 17 00:00:00 2001 From: SyHeee Date: Tue, 1 Apr 2025 23:36:52 -0400 Subject: [PATCH 3/6] revert simple_search_agent --- .../SimpleSearchAgents/simple_search_agent.py | 226 ------------------ 1 file changed, 226 deletions(-) diff --git a/visual-tree-search-backend/app/api/lwats/agents_async/SimpleSearchAgents/simple_search_agent.py b/visual-tree-search-backend/app/api/lwats/agents_async/SimpleSearchAgents/simple_search_agent.py index 7fede0b..b38302f 100644 --- a/visual-tree-search-backend/app/api/lwats/agents_async/SimpleSearchAgents/simple_search_agent.py +++ b/visual-tree-search-backend/app/api/lwats/agents_async/SimpleSearchAgents/simple_search_agent.py @@ -92,9 +92,6 @@ async def run(self, websocket=None) -> List[Dict[str, Any]]: return await self.dfs_with_websocket(websocket) else: return await self.dfs() - elif algorithm == "rmcts": - logger.info("Starting Reflective MCTS algorithm") - return await self.rmcts() else: error_msg = f"Unsupported algorithm: {algorithm}" logger.error(error_msg) @@ -1149,226 +1146,3 @@ def _get_tree_data(self): tree_data.append(node_data) return tree_data - - async def rmcts(self) -> List[Dict[str, Any]]: - """ - Performs Monte Carlo Tree Search starting from the root node. - Uses GPT-4 for node selection and reflection-based backpropagation. 
- - Returns: - List[Dict[str, Any]]: List of actions in the best path found - """ - best_score = float('-inf') - best_path = None - visited = set() # Track visited nodes to avoid cycles - max_iterations = self.config.iterations # Use configured number of iterations - - try: - # Initial browser setup - live_browser_url, session_id = await self._reset_browser() - if not live_browser_url: - logger.error("Failed to initialize browser") - return [] - - for iteration in range(max_iterations): - logger.info(f"\n{'='*50}") - logger.info(f"RMCTS Iteration {iteration + 1}/{max_iterations}") - logger.info(f"{'='*50}\n") - - # Selection: Use GPT-4 to select a promising path - current_node = self.root_node - path = [current_node] - selection_depth = 0 - - while current_node.children and not current_node.is_terminal: - logger.info(f"\nSelection Step {selection_depth + 1}:") - logger.info(f"Current node action: {current_node.action}") - logger.info(f"Number of children: {len(current_node.children)}") - - # Get trajectory for GPT-4 to evaluate - trajectory = [] - for node in path[1:]: # Skip root node - trajectory.append({ - "natural_language_description": node.natural_language_description, - "action": node.action, - "feedback": node.feedback - }) - - # Create prompt for GPT-4 to select next node - prompt = f"""Given the current trajectory and goal, select the most promising child node to explore next. - Consider the overall progress, efficiency, and likelihood of success. 
- - Goal: {self.goal} - - Current Trajectory: - {json.dumps(trajectory, indent=2)} - - Available Children: - {json.dumps([{ - 'action': child.action, - 'description': child.natural_language_description, - 'visits': child.visits, - 'value': child.value - } for child in current_node.children], indent=2)} - - Return a JSON response with: - {{ - "selected_child_index": int, # Index of the selected child - "explanation": str # Brief explanation of the selection - }}""" - - try: - response = openai_client.chat.completions.create( - model=self.config.evaluation_model, - messages=[ - {"role": "system", "content": "You are an expert at selecting promising paths in a search tree."}, - {"role": "user", "content": prompt} - ], - response_format={"type": "json_object"} - ) - - selection = json.loads(response.choices[0].message.content) - selected_index = selection["selected_child_index"] - - if 0 <= selected_index < len(current_node.children): - current_node = current_node.children[selected_index] - path.append(current_node) - logger.info(f"Selected child {selected_index + 1}: {current_node.action}") - logger.info(f"Selection explanation: {selection['explanation']}") - else: - logger.warning(f"Invalid child index {selected_index}, breaking selection") - break - - except Exception as e: - logger.error(f"Error in node selection: {str(e)}") - break - - selection_depth += 1 - - # Expansion: Expand the selected node if possible - if not current_node.is_terminal and current_node.depth < self.config.max_depth: - logger.info(f"\nExpansion Step:") - logger.info(f"Expanding node: {current_node.action}") - try: - await self.expand(current_node) - logger.info(f"Successfully expanded node with {len(current_node.children)} children") - except Exception as e: - logger.error(f"Error expanding node: {str(e)}") - current_node.is_terminal = True - - # Simulation: Evaluate the current path - logger.info(f"\nSimulation Step:") - logger.info(f"Evaluating path of length {len(path) - 1}") - try: - 
trajectory = [] - for node in path[1:]: # Skip root node - trajectory.append({ - "natural_language_description": node.natural_language_description, - "action": node.action, - "feedback": node.feedback - }) - - # Score the trajectory - prompt = create_llm_prompt(trajectory, self.goal) - result = score_trajectory_with_openai(prompt, openai_client, model=self.config.evaluation_model) - score = result["overall_score"] - - logger.info(f"Simulation Results:") - logger.info(f"Overall Score: {score:.3f}") - logger.info(f"Efficiency Score: {result['efficiency_score']:.3f}") - logger.info(f"Accuracy Score: {result['accuracy_score']:.3f}") - logger.info(f"Robustness Score: {result['robustness_score']:.3f}") - - # Update best path if this score is better - if score > best_score: - best_score = score - best_path = path - logger.info(f"\nNew best path found!") - logger.info(f"Previous best score: {best_score:.3f}") - logger.info(f"New best score: {score:.3f}") - - # Reflection-based backpropagation - if score < 0.75: # If the path is not satisfactory - logger.info(f"\nReflection Step (Score {score:.3f} < 0.75):") - # Generate reflection prompt - reflection_prompt = f"""Analyze the current trajectory and suggest improvements. 
- - Goal: {self.goal} - - Current Trajectory: - {json.dumps(trajectory, indent=2)} - - Score: {score} - - Return a JSON response with: - {{ - "backtrack_to_step": int, # Which step to backtrack to (0-based index) - "reason": str, # Why backtrack to this step - "suggested_improvements": [str] # List of suggested improvements - }}""" - - try: - reflection = openai_client.chat.completions.create( - model=self.config.evaluation_model, - messages=[ - {"role": "system", "content": "You are an expert at analyzing and improving search trajectories."}, - {"role": "user", "content": reflection_prompt} - ], - response_format={"type": "json_object"} - ) - - reflection_result = json.loads(reflection.choices[0].message.content) - backtrack_step = reflection_result["backtrack_to_step"] - - # Backtrack to the suggested step - if 0 <= backtrack_step < len(path): - current_node = path[backtrack_step] - # Remove nodes after the backtrack point - while len(path) > backtrack_step + 1: - path.pop() - logger.info(f"Backtracking to step {backtrack_step}") - logger.info(f"Reason: {reflection_result['reason']}") - logger.info("Suggested improvements:") - for improvement in reflection_result["suggested_improvements"]: - logger.info(f"- {improvement}") - - except Exception as e: - logger.error(f"Error in reflection: {str(e)}") - - # If we've found a satisfactory solution, return it - if score >= 0.75: - logger.info(f"\nFound satisfactory solution with score {score:.3f}") - return [{"action": node.action} for node in path[1:]] - - except Exception as e: - logger.error(f"Error in simulation: {str(e)}") - continue - - # Update node statistics - logger.info(f"\nBackpropagation Step:") - for node in path: - old_value = node.value - node.visits += 1 - node.value = (node.value * (node.visits - 1) + score) / node.visits - logger.info(f"Node {node.action}:") - logger.info(f" Visits: {node.visits}") - logger.info(f" Value: {old_value:.3f} -> {node.value:.3f}") - - # If we've exhausted all iterations 
and haven't found a perfect solution, - # return the best path we found - if best_path: - logger.info(f"\nSearch complete. Returning best path found with score {best_score:.3f}") - return [{"action": node.action} for node in best_path[1:]] - - # If no path was found at all - logger.warning("\nNo valid path found") - return [] - - except Exception as e: - error_msg = f"Error in RMCTS search: {str(e)}" - logger.error(error_msg) - if best_path: - logger.info(f"\nReturning best path found before error with score {best_score:.3f}") - return [{"action": node.action} for node in best_path[1:]] - return [] - From ceb65981586ede1ebb76d21dfbc6d5f909b7c4ba Mon Sep 17 00:00:00 2001 From: SyHeee Date: Fri, 4 Apr 2025 20:35:00 -0400 Subject: [PATCH 4/6] checkin buggy algo --- .../SimpleSearchAgents/mcts_agent.py | 138 ++++++++++++------ 1 file changed, 90 insertions(+), 48 deletions(-) diff --git a/visual-tree-search-backend/app/api/lwats/agents_async/SimpleSearchAgents/mcts_agent.py b/visual-tree-search-backend/app/api/lwats/agents_async/SimpleSearchAgents/mcts_agent.py index e1e3a1c..d3b43aa 100644 --- a/visual-tree-search-backend/app/api/lwats/agents_async/SimpleSearchAgents/mcts_agent.py +++ b/visual-tree-search-backend/app/api/lwats/agents_async/SimpleSearchAgents/mcts_agent.py @@ -217,8 +217,8 @@ async def rmcts(self) -> List[Dict[str, Any]]: logger.info(f"New best score: {score:.3f}") # Reflection-based backpropagation - if score < 0.75: # If the path is not satisfactory - logger.info(f"\nReflection Step (Score {score:.3f} < 0.5):") + if score < 0.25: # If the path is not satisfactory + logger.info(f"\nReflection Step (Score {score:.3f} < 0.25):") # Generate reflection prompt reflection_prompt = f"""Analyze the current trajectory and suggest improvements. 
@@ -500,8 +500,8 @@ async def rmcts_with_websocket(self, websocket) -> List[Dict[str, Any]]: }) # Reflection-based backpropagation - if score < 0.75: # If the path is not satisfactory - logger.info(f"\nReflection Step (Score {score:.3f} < 0.75):") + if score < 0.25: # If the path is not satisfactory + logger.info(f"\nReflection Step (Score {score:.3f} < 0.25):") await websocket.send_json({ "type": "reflection_start", @@ -761,35 +761,63 @@ async def expand(self, node: LATSNode, websocket=None) -> None: node: Node to expand websocket: Optional WebSocket connection to send updates to """ - children_state = await self.generate_children(node, websocket) - logger.info(f"Generated {len(children_state)} children for node: {node.action}") + try: + children_state = await self.generate_children(node, websocket) + logger.info(f"Generated {len(children_state)} children for node: {node.action}") + except Exception as e: + logger.error(f"Exception during generation of children for node {node.action}: {e}") + children_state = [] + + # if not children_state: + # logger.warning(f"No valid children found for node: {node.action}") + # # Mark the node as terminal but don't halt the entire search + # node.is_terminal = True + # return if not children_state: - logger.warning(f"No valid children found for node: {node.action}") - # Mark the node as terminal but don't halt the entire search - node.is_terminal = True - return - + logger.warning("No valid children returned, creating fallback children") + children_state = [ + { + "natural_language_description": "Navigate back to try a different approach", + "action": "navigate_backward()", + "prob": 0.15, + "element": None + }, + { + "natural_language_description": "Refresh the page to reinitialize search", + "action": "refresh()", + "prob": 0.1, + "element": None + }, + { + "natural_language_description": "Click a random element for exploration", + "action": "click('random')", + "prob": 0.05, + "element": None + } + ] for child_state in 
children_state: - child = LATSNode( - natural_language_description=child_state["natural_language_description"], - action=child_state["action"], - prob=child_state["prob"], - element=child_state["element"], - goal=node.goal, - parent=node - ) - node.children.append(child) - - # Send child creation update if websocket is provided - if websocket: - await websocket.send_json({ - "type": "node_created", - "node_id": id(child), - "parent_id": id(node), - "action": child.action, - "description": child.natural_language_description, - "timestamp": datetime.utcnow().isoformat() - }) + try: + child = LATSNode( + natural_language_description=child_state.get("natural_language_description", ""), + action=child_state.get("action", ""), + prob=child_state.get("prob", 0.0), + element=child_state.get("element", None), + goal=node.goal, + parent=node + ) + node.children.append(child) + + if websocket: + await websocket.send_json({ + "type": "node_created", + "node_id": id(child), + "parent_id": id(node), + "action": child.action, + "description": child.natural_language_description, + "timestamp": datetime.utcnow().isoformat() + }) + except Exception as e: + logger.error(f"Error creating child node from state {child_state}: {e}") async def generate_children(self, node: LATSNode, websocket=None) -> list[dict]: """ @@ -880,7 +908,7 @@ async def generate_children(self, node: LATSNode, websocket=None) -> list[dict]: for action in next_actions: if action["action"] == "FINISH": logger.info(f"Found FINISH action with probability: {action['prob']}") - if action["prob"] > 0.8: + if action["prob"] > 0.99: node.is_terminal = True if websocket: await websocket.send_json({ @@ -916,24 +944,38 @@ async def generate_children(self, node: LATSNode, websocket=None) -> list[dict]: children.append(action) if not children: - node.is_terminal = True - if websocket: - await websocket.send_json({ - "type": "node_terminal", - "node_id": id(node), - "reason": "no_valid_actions", - "timestamp": 
datetime.utcnow().isoformat() - }) - logger.warning("No children generated") - # logger.warning("No children generated, creating a dummy 'retry' child to keep search alive") + # node.is_terminal = True + # if websocket: + # await websocket.send_json({ + # "type": "node_terminal", + # "node_id": id(node), + # "reason": "no_valid_actions", + # "timestamp": datetime.utcnow().isoformat() + # }) + # logger.warning("No children generated") + logger.warning("No viable children, creating fallback exploration actions") # # If empty list would terminate search, create a "fallback" child - # children.append({ - # "natural_language_description": "Retry with different approach", - # "action": "refresh()", # Or some other generic action - # "prob": 0.1, - # "element": None - # }) + children.extend([ + { + "natural_language_description": "Navigate back to try a different approach", + "action": "navigate_backward()", + "prob": 0.15, + "element": None + }, + { + "natural_language_description": "Try refreshing the page", + "action": "refresh()", + "prob": 0.1, + "element": None + }, + { + "natural_language_description": "Try clicking on a random element", + "action": "click('random')", + "prob": 0.05, + "element": None + } + ]) print(f"****** Generated children: {children}") return children From 39a809426a2842e2301ef5a20a0dc382ec702816 Mon Sep 17 00:00:00 2001 From: SyHeee Date: Sat, 5 Apr 2025 01:00:12 -0400 Subject: [PATCH 5/6] adding rmcts --- .../SimpleSearchAgents/mcts_agent.py | 176 ++++++----- .../app/api/shopping.json | 282 +++++++++--------- 2 files changed, 239 insertions(+), 219 deletions(-) diff --git a/visual-tree-search-backend/app/api/lwats/agents_async/SimpleSearchAgents/mcts_agent.py b/visual-tree-search-backend/app/api/lwats/agents_async/SimpleSearchAgents/mcts_agent.py index d3b43aa..f1cb9ea 100644 --- a/visual-tree-search-backend/app/api/lwats/agents_async/SimpleSearchAgents/mcts_agent.py +++ 
b/visual-tree-search-backend/app/api/lwats/agents_async/SimpleSearchAgents/mcts_agent.py @@ -27,6 +27,7 @@ from ...webagent_utils_async.utils.utils import urls_to_images logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) openai_client = OpenAI() class MCTSAgent: @@ -183,6 +184,23 @@ async def rmcts(self) -> List[Dict[str, Any]]: except Exception as e: logger.error(f"Error expanding node: {str(e)}") current_node.is_terminal = True + # Expansion Step: Expand the selected node if possible + if not current_node.is_terminal and current_node.depth < self.config.max_depth: + logger.info(f"\nExpansion Step:") + logger.info(f"Expanding node: {current_node.action}") + + expansion_success = await self.expand(current_node, None) + if not expansion_success: + # No children were generated; backtrack if possible. + if len(path) > 1: + logger.info("Backtracking due to expansion failure (no children generated).") + path.pop() # Remove the current dead-end node. + current_node = path[-1] # Set current_node to its parent. + else: + logger.warning("Expansion failed at root; no further backtracking possible.") + break + else: + logger.info(f"Successfully expanded node with {len(current_node.children)} children") # Simulation: Evaluate the current path logger.info(f"\nSimulation Step:") @@ -217,8 +235,8 @@ async def rmcts(self) -> List[Dict[str, Any]]: logger.info(f"New best score: {score:.3f}") # Reflection-based backpropagation - if score < 0.25: # If the path is not satisfactory - logger.info(f"\nReflection Step (Score {score:.3f} < 0.25):") + if score < 0.75: # If the path is not satisfactory + logger.info(f"\nReflection Step (Score {score:.3f} < 0.75):") # Generate reflection prompt reflection_prompt = f"""Analyze the current trajectory and suggest improvements. 
@@ -265,10 +283,10 @@ async def rmcts(self) -> List[Dict[str, Any]]: except Exception as e: logger.error(f"Error in reflection: {str(e)}") - # # If we've found a satisfactory solution, return it - # if score >= 0.75: - # logger.info(f"\nFound satisfactory solution with score {score:.3f}") - # return [{"action": node.action} for node in path[1:]] + # If we've found a satisfactory solution, return it + if score >= 0.75: + logger.info(f"\nFound satisfactory solution with score {score:.3f}") + return [{"action": node.action} for node in path[1:]] except Exception as e: logger.error(f"Error in simulation: {str(e)}") @@ -286,13 +304,13 @@ async def rmcts(self) -> List[Dict[str, Any]]: # If we've exhausted all iterations and haven't found a perfect solution, # return the best path we found - if best_path: + if best_path and len(best_path) > 1: logger.info(f"\nSearch complete. Returning best path found with score {best_score:.3f}") return [{"action": node.action} for node in best_path[1:]] - - # If no path was found at all - logger.warning("\nNo valid path found") - return [] + + # If no valid path was found or path was just the root, return a default action + logger.warning("\nNo valid path found, returning fallback action") + return [{"action": "refresh()", "description": "Fallback action - no valid path found"}] except Exception as e: error_msg = f"Error in RMCTS search: {str(e)}" @@ -500,8 +518,8 @@ async def rmcts_with_websocket(self, websocket) -> List[Dict[str, Any]]: }) # Reflection-based backpropagation - if score < 0.25: # If the path is not satisfactory - logger.info(f"\nReflection Step (Score {score:.3f} < 0.25):") + if score < 0.75: # If the path is not satisfactory + logger.info(f"\nReflection Step (Score {score:.3f} < 0.75):") await websocket.send_json({ "type": "reflection_start", @@ -568,20 +586,20 @@ async def rmcts_with_websocket(self, websocket) -> List[Dict[str, Any]]: "timestamp": datetime.utcnow().isoformat() }) - # # If we've found a satisfactory 
solution, return it - # if score >= 0.75: - # logger.info(f"\nFound satisfactory solution with score {score:.3f}") + # If we've found a satisfactory solution, return it + if score >= 0.75: + logger.info(f"\nFound satisfactory solution with score {score:.3f}") - # # Send completion update if websocket is provided - # await websocket.send_json({ - # "type": "search_complete", - # "status": "success", - # "score": score, - # "path": [{"id": id(node), "action": node.action} for node in path[1:]], - # "timestamp": datetime.utcnow().isoformat() - # }) + # Send completion update if websocket is provided + await websocket.send_json({ + "type": "search_complete", + "status": "success", + "score": score, + "path": [{"id": id(node), "action": node.action} for node in path[1:]], + "timestamp": datetime.utcnow().isoformat() + }) - # return [{"action": node.action} for node in path[1:]] + return [{"action": node.action} for node in path[1:]] except Exception as e: logger.error(f"Error in simulation: {str(e)}") @@ -611,7 +629,7 @@ async def rmcts_with_websocket(self, websocket) -> List[Dict[str, Any]]: # If we've exhausted all iterations and haven't found a perfect solution, # return the best path we found - if best_path: + if best_path and len(best_path) > 1: logger.info(f"\nSearch complete. 
Returning best path found with score {best_score:.3f}") # Send completion update if websocket is provided @@ -635,8 +653,10 @@ async def rmcts_with_websocket(self, websocket) -> List[Dict[str, Any]]: "message": "No valid path found", "timestamp": datetime.utcnow().isoformat() }) - - return [] + + # If no valid path was found or path was just the root, return a default action + logger.warning("\nNo valid path found, returning fallback action") + return [{"action": "refresh()", "description": "Fallback action - no valid path found"}] except Exception as e: error_msg = f"Error in RMCTS search: {str(e)}" @@ -753,48 +773,29 @@ async def _reset_browser(self, websocket=None) -> Optional[tuple]: }) return None, None - async def expand(self, node: LATSNode, websocket=None) -> None: + async def expand(self, node: LATSNode, websocket=None) -> bool: """ - Expand a node by generating its children. + Expand a node by generating its children. If no children are generated, + mark the node as terminal and return False to trigger backtracking. Args: - node: Node to expand - websocket: Optional WebSocket connection to send updates to + node: Node to expand. + websocket: Optional WebSocket connection to send updates. + + Returns: + bool: True if expansion succeeded (children generated), False otherwise. 
""" try: children_state = await self.generate_children(node, websocket) - logger.info(f"Generated {len(children_state)} children for node: {node.action}") except Exception as e: logger.error(f"Exception during generation of children for node {node.action}: {e}") children_state = [] - - # if not children_state: - # logger.warning(f"No valid children found for node: {node.action}") - # # Mark the node as terminal but don't halt the entire search - # node.is_terminal = True - # return + if not children_state: - logger.warning("No valid children returned, creating fallback children") - children_state = [ - { - "natural_language_description": "Navigate back to try a different approach", - "action": "navigate_backward()", - "prob": 0.15, - "element": None - }, - { - "natural_language_description": "Refresh the page to reinitialize search", - "action": "refresh()", - "prob": 0.1, - "element": None - }, - { - "natural_language_description": "Click a random element for exploration", - "action": "click('random')", - "prob": 0.05, - "element": None - } - ] + logger.warning("No children generated. Marking node as terminal and triggering backtracking.") + node.is_terminal = True + return False # Indicate that expansion did not generate children. + for child_state in children_state: try: child = LATSNode( @@ -818,6 +819,7 @@ async def expand(self, node: LATSNode, websocket=None) -> None: }) except Exception as e: logger.error(f"Error creating child node from state {child_state}: {e}") + return True # Expansion succeeded (children were generated). 
async def generate_children(self, node: LATSNode, websocket=None) -> list[dict]: """ @@ -833,7 +835,7 @@ async def generate_children(self, node: LATSNode, websocket=None) -> list[dict]: # Reset browser and get live URL live_browser_url, session_id = await self._reset_browser(websocket) path = self.get_path_to_root(node) - + logger.info(f"######### Generating children for path with {len(path)} nodes") # Execute path for n in path[1:]: # Skip root node if websocket: @@ -843,23 +845,41 @@ async def generate_children(self, node: LATSNode, websocket=None) -> list[dict]: "action": n.action, "timestamp": datetime.utcnow().isoformat() }) - - success = await playwright_step_execution( - n, - self.goal, - self.playwright_manager, - is_replay=False, - log_folder=self.config.log_folder - ) - if not success: - n.is_terminal = True - if websocket: - await websocket.send_json({ - "type": "replay_failed", - "node_id": id(n), - "timestamp": datetime.utcnow().isoformat() - }) - return [] + try: + success = await playwright_step_execution( + n, + self.goal, + self.playwright_manager, + is_replay=False, + log_folder=self.config.log_folder + ) + logger.info(f"#########Success: {success}") + + if not success: + logger.warning(f"Action execution failed: {n.action}") + n.is_terminal = True + if websocket: + await websocket.send_json({ + "type": "replay_failed", + "node_id": id(n), + "timestamp": datetime.utcnow().isoformat() + }) + return [{ + "natural_language_description": "Recover from failed action", + "action": "refresh()", + "prob": 0.1, + "element": None + }] + except Exception as e: + logger.error(f"Error executing action {n.action}: {str(e)}") + # Provide fallback actions instead of bubbling up the exception + return [{ + "natural_language_description": "Recover from action error", + "action": "refresh()", + "prob": 0.1, + "element": None + }] + if not n.feedback: n.feedback = await generate_feedback( diff --git a/visual-tree-search-backend/app/api/shopping.json 
b/visual-tree-search-backend/app/api/shopping.json index 4de8ead..d3b1819 100644 --- a/visual-tree-search-backend/app/api/shopping.json +++ b/visual-tree-search-backend/app/api/shopping.json @@ -1,142 +1,142 @@ [ - { - "name": "mage-cache-storage", - "value": "{}", - "domain": "128.105.145.205", - "path": "/", - "expires": 1775110732, - "httpOnly": false, - "secure": false, - "sameSite": "Lax" - }, - { - "name": "mage-cache-storage-section-invalidation", - "value": "{}", - "domain": "128.105.145.205", - "path": "/", - "expires": 1775110732, - "httpOnly": false, - "secure": false, - "sameSite": "Lax" - }, - { - "name": "mage-messages", - "value": "", - "domain": "128.105.145.205", - "path": "/", - "expires": 1775272527, - "httpOnly": false, - "secure": false, - "sameSite": "Strict" - }, - { - "name": "recently_viewed_product", - "value": "{}", - "domain": "128.105.145.205", - "path": "/", - "expires": 1775110732, - "httpOnly": false, - "secure": false, - "sameSite": "Lax" - }, - { - "name": "recently_viewed_product_previous", - "value": "{}", - "domain": "128.105.145.205", - "path": "/", - "expires": 1775110732, - "httpOnly": false, - "secure": false, - "sameSite": "Lax" - }, - { - "name": "recently_compared_product", - "value": "{}", - "domain": "128.105.145.205", - "path": "/", - "expires": 1775110732, - "httpOnly": false, - "secure": false, - "sameSite": "Lax" - }, - { - "name": "recently_compared_product_previous", - "value": "{}", - "domain": "128.105.145.205", - "path": "/", - "expires": 1775110732, - "httpOnly": false, - "secure": false, - "sameSite": "Lax" - }, - { - "name": "product_data_storage", - "value": "{}", - "domain": "128.105.145.205", - "path": "/", - "expires": 1775110732, - "httpOnly": false, - "secure": false, - "sameSite": "Lax" - }, - { - "name": "private_content_version", - "value": "c514ad01bc9816ee30ab1f02510aa34a", - "domain": "128.105.145.205", - "path": "/", - "expires": 1778296522.505323, - "httpOnly": false, - "secure": false, - 
"sameSite": "Lax" - }, - { - "name": "PHPSESSID", - "value": "007c737fca0eb4173ab5362c2d9c8b09", - "domain": "128.105.145.205", - "path": "/", - "expires": 1775272526.468528, - "httpOnly": true, - "secure": false, - "sameSite": "Lax" - }, - { - "name": "X-Magento-Vary", - "value": "9bf9a599123e6402b85cde67144717a08b817412", - "domain": "128.105.145.205", - "path": "/", - "expires": 1775272526.468754, - "httpOnly": true, - "secure": false, - "sameSite": "Lax" - }, - { - "name": "form_key", - "value": "Hsr3n5ycGkOPfr1K", - "domain": "128.105.145.205", - "path": "/", - "expires": 1775272526.468692, - "httpOnly": false, - "secure": false, - "sameSite": "Lax" - }, - { - "name": "mage-cache-sessid", - "value": "true", - "domain": "128.105.145.205", - "path": "/", - "expires": 1775272527, - "httpOnly": false, - "secure": false, - "sameSite": "Lax" - }, - { - "name": "section_data_ids", - "value": "{%22messages%22:1743736525%2C%22customer%22:1743736525%2C%22compare-products%22:1743736525%2C%22last-ordered-items%22:1743736525%2C%22cart%22:1743736525%2C%22directory-data%22:1743736525%2C%22captcha%22:1743736525%2C%22instant-purchase%22:1743736525%2C%22loggedAsCustomer%22:1743736525%2C%22persistent%22:1743736525%2C%22review%22:1743736525%2C%22wishlist%22:1743736525%2C%22recently_viewed_product%22:1743736525%2C%22recently_compared_product%22:1743736525%2C%22product_data_storage%22:1743736525%2C%22paypal-billing-agreement%22:1743736525}", - "domain": "128.105.145.205", - "path": "/", - "expires": 1775272524, - "httpOnly": false, - "secure": false, - "sameSite": "Lax" - } -] \ No newline at end of file + { + "name": "mage-cache-storage", + "value": "{}", + "domain": "128.105.145.205", + "path": "/", + "expires": 1775110732, + "httpOnly": false, + "secure": false, + "sameSite": "Lax" + }, + { + "name": "mage-cache-storage-section-invalidation", + "value": "{}", + "domain": "128.105.145.205", + "path": "/", + "expires": 1775110732, + "httpOnly": false, + "secure": false, + 
"sameSite": "Lax" + }, + { + "name": "mage-messages", + "value": "", + "domain": "128.105.145.205", + "path": "/", + "expires": 1775272527, + "httpOnly": false, + "secure": false, + "sameSite": "Strict" + }, + { + "name": "recently_viewed_product", + "value": "{}", + "domain": "128.105.145.205", + "path": "/", + "expires": 1775110732, + "httpOnly": false, + "secure": false, + "sameSite": "Lax" + }, + { + "name": "recently_viewed_product_previous", + "value": "{}", + "domain": "128.105.145.205", + "path": "/", + "expires": 1775110732, + "httpOnly": false, + "secure": false, + "sameSite": "Lax" + }, + { + "name": "recently_compared_product", + "value": "{}", + "domain": "128.105.145.205", + "path": "/", + "expires": 1775110732, + "httpOnly": false, + "secure": false, + "sameSite": "Lax" + }, + { + "name": "recently_compared_product_previous", + "value": "{}", + "domain": "128.105.145.205", + "path": "/", + "expires": 1775110732, + "httpOnly": false, + "secure": false, + "sameSite": "Lax" + }, + { + "name": "product_data_storage", + "value": "{}", + "domain": "128.105.145.205", + "path": "/", + "expires": 1775110732, + "httpOnly": false, + "secure": false, + "sameSite": "Lax" + }, + { + "name": "private_content_version", + "value": "c514ad01bc9816ee30ab1f02510aa34a", + "domain": "128.105.145.205", + "path": "/", + "expires": 1778296522.505323, + "httpOnly": false, + "secure": false, + "sameSite": "Lax" + }, + { + "name": "PHPSESSID", + "value": "007c737fca0eb4173ab5362c2d9c8b09", + "domain": "128.105.145.205", + "path": "/", + "expires": 1775272526.468528, + "httpOnly": true, + "secure": false, + "sameSite": "Lax" + }, + { + "name": "X-Magento-Vary", + "value": "9bf9a599123e6402b85cde67144717a08b817412", + "domain": "128.105.145.205", + "path": "/", + "expires": 1775272526.468754, + "httpOnly": true, + "secure": false, + "sameSite": "Lax" + }, + { + "name": "form_key", + "value": "Hsr3n5ycGkOPfr1K", + "domain": "128.105.145.205", + "path": "/", + "expires": 
1775272526.468692, + "httpOnly": false, + "secure": false, + "sameSite": "Lax" + }, + { + "name": "mage-cache-sessid", + "value": "true", + "domain": "128.105.145.205", + "path": "/", + "expires": 1775272527, + "httpOnly": false, + "secure": false, + "sameSite": "Lax" + }, + { + "name": "section_data_ids", + "value": "{%22messages%22:1743736525%2C%22customer%22:1743736525%2C%22compare-products%22:1743736525%2C%22last-ordered-items%22:1743736525%2C%22cart%22:1743736525%2C%22directory-data%22:1743736525%2C%22captcha%22:1743736525%2C%22instant-purchase%22:1743736525%2C%22loggedAsCustomer%22:1743736525%2C%22persistent%22:1743736525%2C%22review%22:1743736525%2C%22wishlist%22:1743736525%2C%22recently_viewed_product%22:1743736525%2C%22recently_compared_product%22:1743736525%2C%22product_data_storage%22:1743736525%2C%22paypal-billing-agreement%22:1743736525}", + "domain": "128.105.145.205", + "path": "/", + "expires": 1775272524, + "httpOnly": false, + "secure": false, + "sameSite": "Lax" + } + ] \ No newline at end of file From 8dc5d138fecb8d72173c73986bd05c9bee728d60 Mon Sep 17 00:00:00 2001 From: Tata0703 Date: Sat, 5 Apr 2025 00:08:03 -0700 Subject: [PATCH 6/6] add mcts test --- visual-tree-search-backend/README.md | 8 + .../agents_async/SearchAgents/mcts_agent.py | 942 +++++++++++++++++- .../api/routes/new_tree_search_websocket.py | 2 + .../app/api/routes/tree_search_websocket.py | 2 + .../app/api/shopping.json | 282 +++--- .../test/test-tree-search-ws-lats.py | 2 +- .../test/test-tree-search-ws-mcts.py | 169 ++++ 7 files changed, 1264 insertions(+), 143 deletions(-) create mode 100644 visual-tree-search-backend/test/test-tree-search-ws-mcts.py diff --git a/visual-tree-search-backend/README.md b/visual-tree-search-backend/README.md index 9ddb92d..3cb4402 100644 --- a/visual-tree-search-backend/README.md +++ b/visual-tree-search-backend/README.md @@ -90,4 +90,12 @@ python run_demo_treesearch_async.py \ ``` uvicorn app.main:app --host 0.0.0.0 --port 3000 python 
test/test-tree-search-ws-lats.py +``` + +## 7. Add MCTS agent +* test run_demo_treesearch_async.py +* test web socket +``` +uvicorn app.main:app --host 0.0.0.0 --port 3000 +python test/test-tree-search-ws-mcts.py ``` \ No newline at end of file diff --git a/visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/mcts_agent.py b/visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/mcts_agent.py index 9f1734c..f1cb9ea 100644 --- a/visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/mcts_agent.py +++ b/visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/mcts_agent.py @@ -27,6 +27,7 @@ from ...webagent_utils_async.utils.utils import urls_to_images logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) openai_client = OpenAI() class MCTSAgent: @@ -65,4 +66,943 @@ def __init__( self.reset_url = os.environ["ACCOUNT_RESET_URL"] async def run(self, websocket=None) -> List[Dict[str, Any]]: - pass \ No newline at end of file + """ + Run the MCTS algorithm based on configuration. + + Args: + websocket: Optional WebSocket connection to send updates to + + Returns: + List[Dict[str, Any]]: List of actions in the best path found + """ + logger.info("Starting Reflective MCTS algorithm") + if websocket: + return await self.rmcts_with_websocket(websocket) + else: + return await self.rmcts() + + async def rmcts(self) -> List[Dict[str, Any]]: + """ + Performs Monte Carlo Tree Search starting from the root node. + Uses GPT-4 for node selection and reflection-based backpropagation. 
+ + Returns: + List[Dict[str, Any]]: List of actions in the best path found + """ + best_score = float('-inf') + best_path = None + visited = set() # Track visited nodes to avoid cycles + max_iterations = self.config.iterations # Use configured number of iterations + + try: + # Initial browser setup + live_browser_url, session_id = await self._reset_browser() + + for iteration in range(max_iterations): + logger.info(f"\n{'='*50}") + logger.info(f"RMCTS Iteration {iteration + 1}/{max_iterations}") + logger.info(f"{'='*50}\n") + + # Selection: Use GPT-4 to select a promising path + current_node = self.root_node + path = [current_node] + selection_depth = 0 + + while current_node.children and not current_node.is_terminal: + logger.info(f"\nSelection Step {selection_depth + 1}:") + logger.info(f"Current node action: {current_node.action}") + logger.info(f"Number of children: {len(current_node.children)}") + + # Get trajectory for GPT-4 to evaluate + trajectory = [] + for node in path[1:]: # Skip root node + trajectory.append({ + "natural_language_description": node.natural_language_description, + "action": node.action, + "feedback": node.feedback + }) + + # Create prompt for GPT-4 to select next node + prompt = f"""Given the current trajectory and goal, select the most promising child node to explore next. + Consider the overall progress, efficiency, and likelihood of success. 
+ + Goal: {self.goal} + + Current Trajectory: + {json.dumps(trajectory, indent=2)} + + Available Children: + {json.dumps([{ + 'action': child.action, + 'description': child.natural_language_description, + 'visits': child.visits, + 'value': child.value + } for child in current_node.children], indent=2)} + + Return a JSON response with: + {{ + "selected_child_index": int, # Index of the selected child + "explanation": str # Brief explanation of the selection + }}""" + + try: + response = openai_client.chat.completions.create( + model=self.config.evaluation_model, + messages=[ + {"role": "system", "content": "You are an expert at selecting promising paths in a search tree."}, + {"role": "user", "content": prompt} + ], + response_format={"type": "json_object"} + ) + + selection = json.loads(response.choices[0].message.content) + selected_index = selection["selected_child_index"] + + if 0 <= selected_index < len(current_node.children): + current_node = current_node.children[selected_index] + path.append(current_node) + logger.info(f"Selected child {selected_index + 1}: {current_node.action}") + logger.info(f"Selection explanation: {selection['explanation']}") + else: + logger.warning(f"Invalid child index {selected_index}, breaking selection") + break + + except Exception as e: + logger.error(f"Error in node selection: {str(e)}") + break + + selection_depth += 1 + + # Expansion: Expand the selected node if possible + if not current_node.is_terminal and current_node.depth < self.config.max_depth: + logger.info(f"\nExpansion Step:") + logger.info(f"Expanding node: {current_node.action}") + + try: + await self.expand(current_node) + logger.info(f"Successfully expanded node with {len(current_node.children)} children") + except Exception as e: + logger.error(f"Error expanding node: {str(e)}") + current_node.is_terminal = True + # Expansion Step: Expand the selected node if possible + if not current_node.is_terminal and current_node.depth < self.config.max_depth: + 
logger.info(f"\nExpansion Step:") + logger.info(f"Expanding node: {current_node.action}") + + expansion_success = await self.expand(current_node, None) + if not expansion_success: + # No children were generated; backtrack if possible. + if len(path) > 1: + logger.info("Backtracking due to expansion failure (no children generated).") + path.pop() # Remove the current dead-end node. + current_node = path[-1] # Set current_node to its parent. + else: + logger.warning("Expansion failed at root; no further backtracking possible.") + break + else: + logger.info(f"Successfully expanded node with {len(current_node.children)} children") + + # Simulation: Evaluate the current path + logger.info(f"\nSimulation Step:") + logger.info(f"Evaluating path of length {len(path) - 1}") + + try: + trajectory = [] + for node in path[1:]: # Skip root node + trajectory.append({ + "natural_language_description": node.natural_language_description, + "action": node.action, + "feedback": node.feedback + }) + + # Score the trajectory + prompt = create_llm_prompt(trajectory, self.goal) + result = score_trajectory_with_openai(prompt, openai_client, model=self.config.evaluation_model) + score = result["overall_score"] + + logger.info(f"Simulation Results:") + logger.info(f"Overall Score: {score:.3f}") + logger.info(f"Efficiency Score: {result['efficiency_score']:.3f}") + logger.info(f"Accuracy Score: {result['accuracy_score']:.3f}") + logger.info(f"Robustness Score: {result['robustness_score']:.3f}") + + # Update best path if this score is better + if score > best_score: + best_score = score + best_path = path + logger.info(f"\nNew best path found!") + logger.info(f"Previous best score: {best_score:.3f}") + logger.info(f"New best score: {score:.3f}") + + # Reflection-based backpropagation + if score < 0.75: # If the path is not satisfactory + logger.info(f"\nReflection Step (Score {score:.3f} < 0.75):") + + # Generate reflection prompt + reflection_prompt = f"""Analyze the current trajectory and 
suggest improvements. + + Goal: {self.goal} + + Current Trajectory: + {json.dumps(trajectory, indent=2)} + + Score: {score} + + Return a JSON response with: + {{ + "backtrack_to_step": int, # Which step to backtrack to (0-based index) + "reason": str, # Why backtrack to this step + "suggested_improvements": [str] # List of suggested improvements + }}""" + + try: + reflection = openai_client.chat.completions.create( + model=self.config.evaluation_model, + messages=[ + {"role": "system", "content": "You are an expert at analyzing and improving search trajectories."}, + {"role": "user", "content": reflection_prompt} + ], + response_format={"type": "json_object"} + ) + + reflection_result = json.loads(reflection.choices[0].message.content) + backtrack_step = reflection_result["backtrack_to_step"] + + # Backtrack to the suggested step + if 0 <= backtrack_step < len(path): + current_node = path[backtrack_step] + # Remove nodes after the backtrack point + while len(path) > backtrack_step + 1: + path.pop() + logger.info(f"Backtracking to step {backtrack_step}") + logger.info(f"Reason: {reflection_result['reason']}") + logger.info("Suggested improvements:") + for improvement in reflection_result["suggested_improvements"]: + logger.info(f"- {improvement}") + + except Exception as e: + logger.error(f"Error in reflection: {str(e)}") + + # If we've found a satisfactory solution, return it + if score >= 0.75: + logger.info(f"\nFound satisfactory solution with score {score:.3f}") + return [{"action": node.action} for node in path[1:]] + + except Exception as e: + logger.error(f"Error in simulation: {str(e)}") + continue + + # Update node statistics + logger.info(f"\nBackpropagation Step:") + for node in path: + old_value = node.value + node.visits += 1 + node.value = (node.value * (node.visits - 1) + score) / node.visits + logger.info(f"Node {node.action}:") + logger.info(f" Visits: {node.visits}") + logger.info(f" Value: {old_value:.3f} -> {node.value:.3f}") + + # If we've 
exhausted all iterations and haven't found a perfect solution, + # return the best path we found + if best_path and len(best_path) > 1: + logger.info(f"\nSearch complete. Returning best path found with score {best_score:.3f}") + return [{"action": node.action} for node in best_path[1:]] + + # If no valid path was found or path was just the root, return a default action + logger.warning("\nNo valid path found, returning fallback action") + return [{"action": "refresh()", "description": "Fallback action - no valid path found"}] + + except Exception as e: + error_msg = f"Error in RMCTS search: {str(e)}" + logger.error(error_msg) + + if best_path: + logger.info(f"\nReturning best path found before error with score {best_score:.3f}") + return [{"action": node.action} for node in best_path[1:]] + return [] + + async def rmcts_with_websocket(self, websocket) -> List[Dict[str, Any]]: + """ + Performs Monte Carlo Tree Search starting from the root node with WebSocket updates. + Uses GPT-4 for node selection and reflection-based backpropagation. 
+ + Args: + websocket: WebSocket connection to send updates to + + Returns: + List[Dict[str, Any]]: List of actions in the best path found + """ + best_score = float('-inf') + best_path = None + visited = set() # Track visited nodes to avoid cycles + max_iterations = self.config.iterations # Use configured number of iterations + + try: + # Initial browser setup + live_browser_url, session_id = await self._reset_browser(websocket) + + for iteration in range(max_iterations): + logger.info(f"\n{'='*50}") + logger.info(f"RMCTS Iteration {iteration + 1}/{max_iterations}") + logger.info(f"{'='*50}\n") + + # Send iteration update if websocket is provided + await websocket.send_json({ + "type": "rmcts_iteration", + "iteration": iteration + 1, + "max_iterations": max_iterations, + "timestamp": datetime.utcnow().isoformat() + }) + + # Selection: Use GPT-4 to select a promising path + current_node = self.root_node + path = [current_node] + selection_depth = 0 + + while current_node.children and not current_node.is_terminal: + logger.info(f"\nSelection Step {selection_depth + 1}:") + logger.info(f"Current node action: {current_node.action}") + logger.info(f"Number of children: {len(current_node.children)}") + + # Get trajectory for GPT-4 to evaluate + trajectory = [] + for node in path[1:]: # Skip root node + trajectory.append({ + "natural_language_description": node.natural_language_description, + "action": node.action, + "feedback": node.feedback + }) + + # Create prompt for GPT-4 to select next node + prompt = f"""Given the current trajectory and goal, select the most promising child node to explore next. + Consider the overall progress, efficiency, and likelihood of success. 
+ + Goal: {self.goal} + + Current Trajectory: + {json.dumps(trajectory, indent=2)} + + Available Children: + {json.dumps([{ + 'action': child.action, + 'description': child.natural_language_description, + 'visits': child.visits, + 'value': child.value + } for child in current_node.children], indent=2)} + + Return a JSON response with: + {{ + "selected_child_index": int, # Index of the selected child + "explanation": str # Brief explanation of the selection + }}""" + + try: + response = openai_client.chat.completions.create( + model=self.config.evaluation_model, + messages=[ + {"role": "system", "content": "You are an expert at selecting promising paths in a search tree."}, + {"role": "user", "content": prompt} + ], + response_format={"type": "json_object"} + ) + + selection = json.loads(response.choices[0].message.content) + selected_index = selection["selected_child_index"] + + if 0 <= selected_index < len(current_node.children): + current_node = current_node.children[selected_index] + path.append(current_node) + logger.info(f"Selected child {selected_index + 1}: {current_node.action}") + logger.info(f"Selection explanation: {selection['explanation']}") + + # Send selection update if websocket is provided + await websocket.send_json({ + "type": "node_selected", + "node_id": id(current_node), + "explanation": selection["explanation"], + "timestamp": datetime.utcnow().isoformat() + }) + else: + logger.warning(f"Invalid child index {selected_index}, breaking selection") + break + + except Exception as e: + logger.error(f"Error in node selection: {str(e)}") + await websocket.send_json({ + "type": "selection_error", + "error": str(e), + "timestamp": datetime.utcnow().isoformat() + }) + break + + selection_depth += 1 + + # Expansion: Expand the selected node if possible + if not current_node.is_terminal and current_node.depth < self.config.max_depth: + logger.info(f"\nExpansion Step:") + logger.info(f"Expanding node: {current_node.action}") + + await 
websocket.send_json({ + "type": "node_expanding", + "node_id": id(current_node), + "timestamp": datetime.utcnow().isoformat() + }) + + try: + await self.expand(current_node, websocket) + logger.info(f"Successfully expanded node with {len(current_node.children)} children") + except Exception as e: + logger.error(f"Error expanding node: {str(e)}") + current_node.is_terminal = True + await websocket.send_json({ + "type": "expansion_error", + "node_id": id(current_node), + "error": str(e), + "timestamp": datetime.utcnow().isoformat() + }) + + # Simulation: Evaluate the current path + logger.info(f"\nSimulation Step:") + logger.info(f"Evaluating path of length {len(path) - 1}") + + await websocket.send_json({ + "type": "simulation_start", + "path_length": len(path) - 1, + "timestamp": datetime.utcnow().isoformat() + }) + + try: + trajectory = [] + for node in path[1:]: # Skip root node + trajectory.append({ + "natural_language_description": node.natural_language_description, + "action": node.action, + "feedback": node.feedback + }) + + # Score the trajectory + prompt = create_llm_prompt(trajectory, self.goal) + result = score_trajectory_with_openai(prompt, openai_client, model=self.config.evaluation_model) + score = result["overall_score"] + + logger.info(f"Simulation Results:") + logger.info(f"Overall Score: {score:.3f}") + logger.info(f"Efficiency Score: {result['efficiency_score']:.3f}") + logger.info(f"Accuracy Score: {result['accuracy_score']:.3f}") + logger.info(f"Robustness Score: {result['robustness_score']:.3f}") + + # Send simulation results if websocket is provided + await websocket.send_json({ + "type": "simulation_results", + "score": score, + "efficiency_score": result["efficiency_score"], + "accuracy_score": result["accuracy_score"], + "robustness_score": result["robustness_score"], + "timestamp": datetime.utcnow().isoformat() + }) + + # Update best path if this score is better + if score > best_score: + best_score = score + best_path = path + 
logger.info(f"\nNew best path found!") + logger.info(f"Previous best score: {best_score:.3f}") + logger.info(f"New best score: {score:.3f}") + + # Send best path update if websocket is provided + await websocket.send_json({ + "type": "best_path_update", + "score": best_score, + "path": [{"id": id(node), "action": node.action} for node in best_path[1:]], + "timestamp": datetime.utcnow().isoformat() + }) + + # Reflection-based backpropagation + if score < 0.75: # If the path is not satisfactory + logger.info(f"\nReflection Step (Score {score:.3f} < 0.75):") + + await websocket.send_json({ + "type": "reflection_start", + "score": score, + "timestamp": datetime.utcnow().isoformat() + }) + + # Generate reflection prompt + reflection_prompt = f"""Analyze the current trajectory and suggest improvements. + + Goal: {self.goal} + + Current Trajectory: + {json.dumps(trajectory, indent=2)} + + Score: {score} + + Return a JSON response with: + {{ + "backtrack_to_step": int, # Which step to backtrack to (0-based index) + "reason": str, # Why backtrack to this step + "suggested_improvements": [str] # List of suggested improvements + }}""" + + try: + reflection = openai_client.chat.completions.create( + model=self.config.evaluation_model, + messages=[ + {"role": "system", "content": "You are an expert at analyzing and improving search trajectories."}, + {"role": "user", "content": reflection_prompt} + ], + response_format={"type": "json_object"} + ) + + reflection_result = json.loads(reflection.choices[0].message.content) + backtrack_step = reflection_result["backtrack_to_step"] + + # Backtrack to the suggested step + if 0 <= backtrack_step < len(path): + current_node = path[backtrack_step] + # Remove nodes after the backtrack point + while len(path) > backtrack_step + 1: + path.pop() + logger.info(f"Backtracking to step {backtrack_step}") + logger.info(f"Reason: {reflection_result['reason']}") + logger.info("Suggested improvements:") + for improvement in 
reflection_result["suggested_improvements"]: + logger.info(f"- {improvement}") + + # Send backtracking update if websocket is provided + await websocket.send_json({ + "type": "backtracking", + "step": backtrack_step, + "reason": reflection_result["reason"], + "suggested_improvements": reflection_result["suggested_improvements"], + "timestamp": datetime.utcnow().isoformat() + }) + + except Exception as e: + logger.error(f"Error in reflection: {str(e)}") + await websocket.send_json({ + "type": "reflection_error", + "error": str(e), + "timestamp": datetime.utcnow().isoformat() + }) + + # If we've found a satisfactory solution, return it + if score >= 0.75: + logger.info(f"\nFound satisfactory solution with score {score:.3f}") + + # Send completion update if websocket is provided + await websocket.send_json({ + "type": "search_complete", + "status": "success", + "score": score, + "path": [{"id": id(node), "action": node.action} for node in path[1:]], + "timestamp": datetime.utcnow().isoformat() + }) + + return [{"action": node.action} for node in path[1:]] + + except Exception as e: + logger.error(f"Error in simulation: {str(e)}") + await websocket.send_json({ + "type": "simulation_error", + "error": str(e), + "timestamp": datetime.utcnow().isoformat() + }) + continue + + # Update node statistics + logger.info(f"\nBackpropagation Step:") + for node in path: + old_value = node.value + node.visits += 1 + node.value = (node.value * (node.visits - 1) + score) / node.visits + logger.info(f"Node {node.action}:") + logger.info(f" Visits: {node.visits}") + logger.info(f" Value: {old_value:.3f} -> {node.value:.3f}") + + # Send backpropagation update if websocket is provided + await websocket.send_json({ + "type": "backpropagation_complete", + "updated_nodes": [{"id": id(node), "visits": node.visits, "value": node.value} for node in path], + "timestamp": datetime.utcnow().isoformat() + }) + + # If we've exhausted all iterations and haven't found a perfect solution, + # return 
async def _reset_browser(self, websocket=None) -> Optional[tuple]:
    """Close the current browser, optionally reset the account, and start a fresh session.

    The existing ``self.playwright_manager`` is closed first. When
    ``self.config.account_reset`` is truthy, a GET request is sent to
    ``self.reset_url``; that step is best effort, so a failure is logged and
    reported but does not stop the browser from being recreated. Finally a
    new playwright manager is created and navigated to ``self.starting_url``.

    Args:
        websocket: Optional connection exposing ``send_json``. When given,
            ``account_reset`` and ``browser_setup`` events are sent to it.

    Returns:
        A ``(live_browser_url, session_id)`` tuple. Both values are ``None``
        unless ``self.config.browser_mode`` is ``"browserbase"``;
        ``(None, None)`` is returned when browser setup raises.
    """
    await self.playwright_manager.close()

    # Reset the account through the reset API before launching a new browser.
    if self.config.account_reset:
        if websocket:
            await websocket.send_json({
                "type": "account_reset",
                "status": "started",
                "timestamp": datetime.utcnow().isoformat()
            })

        try:
            async with aiohttp.ClientSession() as session:
                headers = {'Connection': 'close'}  # Similar to curl -N
                async with session.get(self.reset_url, headers=headers) as response:
                    if response.status == 200:
                        data = await response.json()
                        logger.info(f"Account reset successful: {data}")
                        if websocket:
                            await websocket.send_json({
                                "type": "account_reset",
                                "status": "success",
                                "data": data,
                                "timestamp": datetime.utcnow().isoformat()
                            })
                    else:
                        error_msg = f"Account reset failed with status {response.status}"
                        logger.error(error_msg)
                        if websocket:
                            await websocket.send_json({
                                "type": "account_reset",
                                "status": "failed",
                                "reason": error_msg,
                                "timestamp": datetime.utcnow().isoformat()
                            })
        except Exception as e:
            # Best effort: a failed account reset must not prevent the browser restart.
            logger.error(f"Error during account reset: {e}")
            if websocket:
                await websocket.send_json({
                    "type": "account_reset",
                    "status": "failed",
                    "reason": str(e),
                    "timestamp": datetime.utcnow().isoformat()
                })

    try:
        # Create new playwright manager
        self.playwright_manager = await setup_playwright(
            storage_state=self.config.storage_state,
            headless=self.config.headless,
            mode=self.config.browser_mode
        )
        page = await self.playwright_manager.get_page()

        # Only the "browserbase" mode exposes a live URL and a session id.
        live_browser_url = None
        session_id = None
        if self.config.browser_mode == "browserbase":
            live_browser_url = await self.playwright_manager.get_live_browser_url()
            session_id = await self.playwright_manager.get_session_id()
        await page.goto(self.starting_url, wait_until="networkidle")

        if websocket:
            if self.config.storage_state:
                message = (
                    "Browser successfully initialized with storage state file: "
                    f"{self.config.storage_state}"
                )
            else:
                message = "Browser successfully initialized"
            await websocket.send_json({
                "type": "browser_setup",
                "status": "success",
                "message": message,
                "live_browser_url": live_browser_url,
                "session_id": session_id,
                "timestamp": datetime.utcnow().isoformat()
            })

        return live_browser_url, session_id
    except Exception as e:
        logger.error(f"Error setting up browser: {e}")
        if websocket:
            await websocket.send_json({
                "type": "browser_setup",
                "status": "failed",
                "reason": str(e),
                "timestamp": datetime.utcnow().isoformat()
            })
        return None, None


async def expand(self, node: LATSNode, websocket=None) -> bool:
    """Expand a node by generating its children.

    Child states come from ``self.generate_children``; each one is turned
    into a ``LATSNode`` attached to ``node``. When no child state is
    produced, or none of them can be turned into a node, ``node`` is marked
    terminal and ``False`` is returned so the caller can backtrack.

    Args:
        node: Node to expand.
        websocket: Optional connection exposing ``send_json``; a
            ``node_created`` event is sent for every child that is attached.

    Returns:
        bool: True if at least one child node was attached, False otherwise.
    """
    try:
        children_state = await self.generate_children(node, websocket)
    except Exception as e:
        logger.error(f"Exception during generation of children for node {node.action}: {e}")
        children_state = []

    if not children_state:
        logger.warning("No children generated. Marking node as terminal and triggering backtracking.")
        node.is_terminal = True
        return False  # Indicate that expansion did not generate children.

    created = 0
    for child_state in children_state:
        try:
            child = LATSNode(
                natural_language_description=child_state.get("natural_language_description", ""),
                action=child_state.get("action", ""),
                prob=child_state.get("prob", 0.0),
                element=child_state.get("element", None),
                goal=node.goal,
                parent=node
            )
            node.children.append(child)
            created += 1

            if websocket:
                await websocket.send_json({
                    "type": "node_created",
                    "node_id": id(child),
                    "parent_id": id(node),
                    "action": child.action,
                    "description": child.natural_language_description,
                    "timestamp": datetime.utcnow().isoformat()
                })
        except Exception as e:
            logger.error(f"Error creating child node from state {child_state}: {e}")

    if not created:
        # Every child state failed to become a node: treat it like "no children"
        # instead of reporting a successful expansion with nothing attached.
        logger.warning("No child node could be created. Marking node as terminal and triggering backtracking.")
        node.is_terminal = True
        return False
    return True  # Expansion succeeded (children were generated).


async def generate_children(self, node: LATSNode, websocket=None) -> list[dict]:
    """Generate candidate child states for ``node``.

    The browser is reset, the actions on the path from the root to ``node``
    are replayed, and ``extract_top_actions`` is asked for the next actions
    on the resulting page. If a replayed action fails or raises, a single
    ``refresh()`` recovery state is returned instead. If no usable action is
    produced, three generic fallback states are returned so that the search
    does not stop on an empty list.

    Args:
        node: Parent node to generate children for.
        websocket: Optional connection exposing ``send_json``; progress
            events are sent to it while replaying and generating.

    Returns:
        list[dict]: Child state dictionaries with the keys
        ``natural_language_description``, ``action``, ``prob`` and ``element``.
    """
    # Local import: the module's import header is outside this block.
    import asyncio

    # Reset browser; the returned live URL and session id are not needed here.
    await self._reset_browser(websocket)
    path = self.get_path_to_root(node)
    logger.info(f"######### Generating children for path with {len(path)} nodes")

    # Replay every action on the path to bring the page to this node's state.
    for n in path[1:]:  # Skip root node
        if websocket:
            await websocket.send_json({
                "type": "replaying_action",
                "node_id": id(n),
                "action": n.action,
                "timestamp": datetime.utcnow().isoformat()
            })
        try:
            success = await playwright_step_execution(
                n,
                self.goal,
                self.playwright_manager,
                is_replay=False,
                log_folder=self.config.log_folder
            )
            logger.info(f"#########Success: {success}")

            if not success:
                logger.warning(f"Action execution failed: {n.action}")
                n.is_terminal = True
                if websocket:
                    await websocket.send_json({
                        "type": "replay_failed",
                        "node_id": id(n),
                        "timestamp": datetime.utcnow().isoformat()
                    })
                return [{
                    "natural_language_description": "Recover from failed action",
                    "action": "refresh()",
                    "prob": 0.1,
                    "element": None
                }]
        except Exception as e:
            logger.error(f"Error executing action {n.action}: {str(e)}")
            # Provide fallback actions instead of bubbling up the exception
            return [{
                "natural_language_description": "Recover from action error",
                "action": "refresh()",
                "prob": 0.1,
                "element": None
            }]

        if not n.feedback:
            n.feedback = await generate_feedback(
                self.goal,
                n.natural_language_description,
                self.playwright_manager,
            )
            if websocket:
                await websocket.send_json({
                    "type": "feedback_generated",
                    "node_id": id(n),
                    "feedback": n.feedback,
                    "timestamp": datetime.utcnow().isoformat()
                })

    # Give the page time to settle. asyncio.sleep yields to the event loop,
    # whereas time.sleep would block every other coroutine for 3 seconds.
    await asyncio.sleep(3)
    page = await self.playwright_manager.get_page()
    page_info = await extract_page_info(page, self.config.fullpage, self.config.log_folder)

    if websocket:
        await websocket.send_json({
            "type": "generating_actions",
            "node_id": id(node),
            "timestamp": datetime.utcnow().isoformat()
        })

    trajectory = [
        {
            "natural_language_description": n.natural_language_description,
            "action": n.action,
            "feedback": n.feedback
        }
        for n in path[1:]
    ]
    next_actions = await extract_top_actions(
        trajectory,
        self.goal,
        self.images,
        page_info,
        self.action_set,
        openai_client,
        features=self.config.features,
        elements_filter=self.config.elements_filter,
        branching_factor=self.config.branching_factor,
        log_folder=self.config.log_folder,
        fullpage=self.config.fullpage,
        action_generation_model=self.config.action_generation_model,
        action_grounding_model=self.config.action_grounding_model
    )

    children = []
    for action in next_actions:
        if action["action"] == "FINISH":
            logger.info(f"Found FINISH action with probability: {action['prob']}")
            if action["prob"] > 0.99:
                node.is_terminal = True
                if websocket:
                    await websocket.send_json({
                        "type": "node_terminal",
                        "node_id": id(node),
                        "reason": "finish_action",
                        "timestamp": datetime.utcnow().isoformat()
                    })
            # FINISH never becomes a child state.
            continue

        page = await self.playwright_manager.get_page()
        _, function_calls = self.action_set.to_python_code(action["action"])

        # Only actions that translate to exactly one function call become children.
        if len(function_calls) == 1:
            try:
                for _function_name, function_args in function_calls:
                    extracted_number = parse_function_args(function_args)
                    element = await locate_element(page, extracted_number)
                    action["element"] = element
            except Exception as e:
                logger.warning(f"Element location failed for action: {action['action']}, error: {str(e)}")
                action["element"] = None
                if websocket:
                    await websocket.send_json({
                        "type": "element_location_failed",
                        "action": action["action"],
                        "error": str(e),
                        "timestamp": datetime.utcnow().isoformat()
                    })
            # Append exactly once, whether or not the element was located.
            children.append(action)

    if not children:
        logger.warning("No viable children, creating fallback exploration actions")

        # An empty list would terminate the search, so offer generic fallbacks.
        children.extend([
            {
                "natural_language_description": "Navigate back to try a different approach",
                "action": "navigate_backward()",
                "prob": 0.15,
                "element": None
            },
            {
                "natural_language_description": "Try refreshing the page",
                "action": "refresh()",
                "prob": 0.1,
                "element": None
            },
            {
                "natural_language_description": "Try clicking on a random element",
                "action": "click('random')",
                "prob": 0.05,
                "element": None
            }
        ])
    logger.info(f"****** Generated children: {children}")
    return children


def get_path_to_root(self, node: LATSNode) -> List[LATSNode]:
    """Return the chain of nodes from the root down to ``node``, root first.

    The chain is built by following ``parent`` links until one is ``None``.
    An explicit ``None`` check is used so that a node object that happens
    to be falsy cannot cut the path short.

    Args:
        node: Node whose ancestry is wanted; ``None`` yields an empty list.

    Returns:
        List[LATSNode]: ``[root, ..., node]``.
    """
    path = []
    current = node
    while current is not None:
        path.append(current)
        current = current.parent
    path.reverse()
    return path
"httpOnly": false, - "secure": false, - "sameSite": "Lax" - }, - { - "name": "mage-messages", - "value": "", - "domain": "128.105.145.205", - "path": "/", - "expires": 1775272527, - "httpOnly": false, - "secure": false, - "sameSite": "Strict" - }, - { - "name": "recently_viewed_product", - "value": "{}", - "domain": "128.105.145.205", - "path": "/", - "expires": 1775110732, - "httpOnly": false, - "secure": false, - "sameSite": "Lax" - }, - { - "name": "recently_viewed_product_previous", - "value": "{}", - "domain": "128.105.145.205", - "path": "/", - "expires": 1775110732, - "httpOnly": false, - "secure": false, - "sameSite": "Lax" - }, - { - "name": "recently_compared_product", - "value": "{}", - "domain": "128.105.145.205", - "path": "/", - "expires": 1775110732, - "httpOnly": false, - "secure": false, - "sameSite": "Lax" - }, - { - "name": "recently_compared_product_previous", - "value": "{}", - "domain": "128.105.145.205", - "path": "/", - "expires": 1775110732, - "httpOnly": false, - "secure": false, - "sameSite": "Lax" - }, - { - "name": "product_data_storage", - "value": "{}", - "domain": "128.105.145.205", - "path": "/", - "expires": 1775110732, - "httpOnly": false, - "secure": false, - "sameSite": "Lax" - }, - { - "name": "private_content_version", - "value": "c514ad01bc9816ee30ab1f02510aa34a", - "domain": "128.105.145.205", - "path": "/", - "expires": 1778296522.505323, - "httpOnly": false, - "secure": false, - "sameSite": "Lax" - }, - { - "name": "PHPSESSID", - "value": "007c737fca0eb4173ab5362c2d9c8b09", - "domain": "128.105.145.205", - "path": "/", - "expires": 1775272526.468528, - "httpOnly": true, - "secure": false, - "sameSite": "Lax" - }, - { - "name": "X-Magento-Vary", - "value": "9bf9a599123e6402b85cde67144717a08b817412", - "domain": "128.105.145.205", - "path": "/", - "expires": 1775272526.468754, - "httpOnly": true, - "secure": false, - "sameSite": "Lax" - }, - { - "name": "form_key", - "value": "Hsr3n5ycGkOPfr1K", - "domain": 
"128.105.145.205", - "path": "/", - "expires": 1775272526.468692, - "httpOnly": false, - "secure": false, - "sameSite": "Lax" - }, - { - "name": "mage-cache-sessid", - "value": "true", - "domain": "128.105.145.205", - "path": "/", - "expires": 1775272527, - "httpOnly": false, - "secure": false, - "sameSite": "Lax" - }, - { - "name": "section_data_ids", - "value": "{%22messages%22:1743736525%2C%22customer%22:1743736525%2C%22compare-products%22:1743736525%2C%22last-ordered-items%22:1743736525%2C%22cart%22:1743736525%2C%22directory-data%22:1743736525%2C%22captcha%22:1743736525%2C%22instant-purchase%22:1743736525%2C%22loggedAsCustomer%22:1743736525%2C%22persistent%22:1743736525%2C%22review%22:1743736525%2C%22wishlist%22:1743736525%2C%22recently_viewed_product%22:1743736525%2C%22recently_compared_product%22:1743736525%2C%22product_data_storage%22:1743736525%2C%22paypal-billing-agreement%22:1743736525}", - "domain": "128.105.145.205", - "path": "/", - "expires": 1775272524, - "httpOnly": false, - "secure": false, - "sameSite": "Lax" - } - ] \ No newline at end of file + { + "name": "mage-cache-storage", + "value": "{}", + "domain": "128.105.145.205", + "path": "/", + "expires": 1775110732, + "httpOnly": false, + "secure": false, + "sameSite": "Lax" + }, + { + "name": "mage-cache-storage-section-invalidation", + "value": "{}", + "domain": "128.105.145.205", + "path": "/", + "expires": 1775110732, + "httpOnly": false, + "secure": false, + "sameSite": "Lax" + }, + { + "name": "mage-messages", + "value": "", + "domain": "128.105.145.205", + "path": "/", + "expires": 1775370120, + "httpOnly": false, + "secure": false, + "sameSite": "Strict" + }, + { + "name": "recently_viewed_product", + "value": "{}", + "domain": "128.105.145.205", + "path": "/", + "expires": 1775110732, + "httpOnly": false, + "secure": false, + "sameSite": "Lax" + }, + { + "name": "recently_viewed_product_previous", + "value": "{}", + "domain": "128.105.145.205", + "path": "/", + "expires": 1775110732, + 
"httpOnly": false, + "secure": false, + "sameSite": "Lax" + }, + { + "name": "recently_compared_product", + "value": "{}", + "domain": "128.105.145.205", + "path": "/", + "expires": 1775110732, + "httpOnly": false, + "secure": false, + "sameSite": "Lax" + }, + { + "name": "recently_compared_product_previous", + "value": "{}", + "domain": "128.105.145.205", + "path": "/", + "expires": 1775110732, + "httpOnly": false, + "secure": false, + "sameSite": "Lax" + }, + { + "name": "product_data_storage", + "value": "{}", + "domain": "128.105.145.205", + "path": "/", + "expires": 1775110732, + "httpOnly": false, + "secure": false, + "sameSite": "Lax" + }, + { + "name": "private_content_version", + "value": "ff4bba58081243f67b1adee2cc6974bd", + "domain": "128.105.145.205", + "path": "/", + "expires": 1778394118.522057, + "httpOnly": false, + "secure": false, + "sameSite": "Lax" + }, + { + "name": "PHPSESSID", + "value": "30247306b1ad824f37f3e0384d86d991", + "domain": "128.105.145.205", + "path": "/", + "expires": 1775370122.385659, + "httpOnly": true, + "secure": false, + "sameSite": "Lax" + }, + { + "name": "X-Magento-Vary", + "value": "9bf9a599123e6402b85cde67144717a08b817412", + "domain": "128.105.145.205", + "path": "/", + "expires": 1775370122.385877, + "httpOnly": true, + "secure": false, + "sameSite": "Lax" + }, + { + "name": "form_key", + "value": "IEPmx1hKh4NWjeUa", + "domain": "128.105.145.205", + "path": "/", + "expires": 1775370122.385817, + "httpOnly": false, + "secure": false, + "sameSite": "Lax" + }, + { + "name": "mage-cache-sessid", + "value": "true", + "domain": "128.105.145.205", + "path": "/", + "expires": 1775370120, + "httpOnly": false, + "secure": false, + "sameSite": "Lax" + }, + { + "name": "section_data_ids", + "value": 
"{%22messages%22:1743834120%2C%22customer%22:1743834120%2C%22compare-products%22:1743834120%2C%22last-ordered-items%22:1743834120%2C%22cart%22:1743834120%2C%22directory-data%22:1743834120%2C%22captcha%22:1743834120%2C%22instant-purchase%22:1743834120%2C%22loggedAsCustomer%22:1743834120%2C%22persistent%22:1743834120%2C%22review%22:1743834120%2C%22wishlist%22:1743834120%2C%22recently_viewed_product%22:1743834120%2C%22recently_compared_product%22:1743834120%2C%22product_data_storage%22:1743834120%2C%22paypal-billing-agreement%22:1743834120}", + "domain": "128.105.145.205", + "path": "/", + "expires": 1775370120, + "httpOnly": false, + "secure": false, + "sameSite": "Lax" + } +] \ No newline at end of file diff --git a/visual-tree-search-backend/test/test-tree-search-ws-lats.py b/visual-tree-search-backend/test/test-tree-search-ws-lats.py index 567c790..c9f3cfb 100644 --- a/visual-tree-search-backend/test/test-tree-search-ws-lats.py +++ b/visual-tree-search-backend/test/test-tree-search-ws-lats.py @@ -138,7 +138,7 @@ def parse_arguments(): parser.add_argument("--goal", type=str, default=DEFAULT_GOAL, help=f"Goal to achieve (default: {DEFAULT_GOAL})") - parser.add_argument("--algorithm", type=str, choices=["bfs", "dfs", "lats"], default="lats", + parser.add_argument("--algorithm", type=str, choices=["bfs", "dfs", "lats", "mcts"], default="lats", help="Search algorithm to use (default: lats)") parser.add_argument("--max-depth", type=int, default=3, diff --git a/visual-tree-search-backend/test/test-tree-search-ws-mcts.py b/visual-tree-search-backend/test/test-tree-search-ws-mcts.py new file mode 100644 index 0000000..9046be8 --- /dev/null +++ b/visual-tree-search-backend/test/test-tree-search-ws-mcts.py @@ -0,0 +1,169 @@ +import asyncio +import json +import websockets +import argparse +import logging +from datetime import datetime + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = 
logger = logging.getLogger(__name__)

# Default values
DEFAULT_WS_URL = "ws://localhost:3000/new-tree-search-ws"
DEFAULT_STARTING_URL = "http://128.105.145.205:7770/"
DEFAULT_GOAL = "search running shoes, click on the first result"


async def connect_and_test_search(
    ws_url: str,
    starting_url: str,
    goal: str,
    search_algorithm: str = "bfs",
    max_depth: int = 3
):
    """
    Connect to the WebSocket endpoint and test the tree search functionality.

    A ``start_search`` request is sent, then every incoming message is logged
    until the search completes, the server reports an error, or the
    connection closes.

    Args:
        ws_url: WebSocket URL to connect to
        starting_url: URL to start the search from
        goal: Goal to achieve
        search_algorithm: Search algorithm to use (bfs, dfs, lats or mcts)
        max_depth: Maximum depth for the search tree
    """
    logger.info(f"Connecting to WebSocket at {ws_url}")

    async with websockets.connect(ws_url) as websocket:
        logger.info("Connected to WebSocket")

        # Wait for connection established message
        response = await websocket.recv()
        data = json.loads(response)
        if data.get("type") == "connection_established":
            logger.info(f"Connection established with ID: {data.get('connection_id')}")

        # Send search request
        request = {
            "type": "start_search",
            "agent_type": "MCTSAgent",
            "starting_url": starting_url,
            "goal": goal,
            "search_algorithm": search_algorithm,
            "max_depth": max_depth
        }

        logger.info(f"Sending search request: {request}")
        await websocket.send(json.dumps(request))

        # Process responses
        while True:
            try:
                response = await websocket.recv()
                data = json.loads(response)

                # Log the message type and some key information
                msg_type = data.get("type", "unknown")

                if msg_type == "status_update":
                    logger.info(f"Status update: {data.get('status')} - {data.get('message')}")

                elif msg_type == "iteration_start":
                    logger.info(f"Iteration start: {data.get('iteration')}")

                elif msg_type == "step_start":
                    logger.info(f"Step start: {data.get('step')} - {data.get('step_name')}")

                elif msg_type == "node_update":
                    node_id = data.get("node_id")
                    status = data.get("status")
                    logger.info(f"Node update: {node_id} - {status}")

                    # If node was scored, log the score
                    if status == "scored":
                        logger.info(f"Node score: {data.get('score')}")

                elif msg_type == "trajectory_update":
                    logger.info(f"Trajectory update received with {data.get('trajectory')}")

                elif msg_type == "tree_update":
                    logger.info(f"Tree update received with {data.get('tree')}")

                elif msg_type == "best_path_update":
                    logger.info(f"Best path update: score={data.get('score')}, path length={len(data.get('path', []))}")

                elif msg_type == "search_complete":
                    status = data.get("status")
                    score = data.get("score", "N/A")
                    path = data.get("path", [])

                    logger.info(f"Search complete: {status}, score={score}, path length={len(path)}")
                    logger.info("Path actions:")

                    for i, node in enumerate(path):
                        logger.info(f"  {i+1}. {node.get('action')}")

                    # Exit the loop when search is complete
                    break

                elif msg_type in ("error", "search_error"):
                    # The search itself reports failures as "search_error" with an
                    # "error" field; stop on those too instead of waiting forever.
                    logger.error(f"Error: {data.get('message') or data.get('error')}")
                    break

                else:
                    logger.info(f"Received message of type {msg_type}")
                    logger.info(f"Message: {data}")

            except websockets.exceptions.ConnectionClosed:
                logger.warning("WebSocket connection closed")
                break
            except Exception as e:
                logger.error(f"Error processing message: {e}")
                break

    logger.info("Test completed")


def parse_arguments(argv=None):
    """Parse command line arguments.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the original behavior.

    Returns:
        argparse.Namespace with ``ws_url``, ``starting_url``, ``goal``,
        ``algorithm`` and ``max_depth``.
    """
    parser = argparse.ArgumentParser(description="Test the tree search WebSocket functionality")

    parser.add_argument("--ws-url", type=str, default=DEFAULT_WS_URL,
                        help=f"WebSocket URL (default: {DEFAULT_WS_URL})")

    parser.add_argument("--starting-url", type=str, default=DEFAULT_STARTING_URL,
                        help=f"Starting URL for the search (default: {DEFAULT_STARTING_URL})")

    parser.add_argument("--goal", type=str, default=DEFAULT_GOAL,
                        help=f"Goal to achieve (default: {DEFAULT_GOAL})")

    # %(default)s keeps the help text in sync with the actual default.
    parser.add_argument("--algorithm", type=str, choices=["bfs", "dfs", "lats", "mcts"], default="mcts",
                        help="Search algorithm to use (default: %(default)s)")

    parser.add_argument("--max-depth", type=int, default=3,
                        help="Maximum depth for the search tree (default: 3)")

    return parser.parse_args(argv)


async def main():
    """Main entry point"""
    args = parse_arguments()

    logger.info("Starting tree search WebSocket test")
    logger.info(f"WebSocket URL: {args.ws_url}")
    logger.info(f"Starting URL: {args.starting_url}")
    logger.info(f"Goal: {args.goal}")
    logger.info(f"Algorithm: {args.algorithm}")
    logger.info(f"Max depth: {args.max_depth}")

    await connect_and_test_search(
        ws_url=args.ws_url,
        starting_url=args.starting_url,
        goal=args.goal,
        search_algorithm=args.algorithm,
        max_depth=args.max_depth
    )


if __name__ == "__main__":
    asyncio.run(main())