From 210fadbe3a88aff9a87f869f47c194a5663c4eb4 Mon Sep 17 00:00:00 2001 From: Tata0703 Date: Fri, 4 Apr 2025 01:50:56 -0700 Subject: [PATCH 1/3] finish 1) steo start, 2) node_expanding, node created, 3) best_path_update and search_complete, 4) iteration --- .../agents_async/SearchAgents/lats_agent.py | 33 +++++++++++++++++-- .../SearchAgents/simple_search_agent.py | 18 ++++++---- .../test/test-tree-search-ws-lats.py | 6 ++++ .../test/test-tree-search-ws-simple.py | 7 ++-- 4 files changed, 50 insertions(+), 14 deletions(-) diff --git a/visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/lats_agent.py b/visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/lats_agent.py index 8fe68db..146ed62 100644 --- a/visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/lats_agent.py +++ b/visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/lats_agent.py @@ -37,6 +37,8 @@ openai_client = OpenAI() +## TODO: add best_path_update + class LATSAgent: """ Language-based Action Tree Search Agent implementation. @@ -117,6 +119,7 @@ async def run(self, websocket=None) -> list[LATSNode]: print_trajectory(best_node) if websocket: + # TODO: use score instead of reward to determine success await websocket.send_json({ "type": "search_complete", "status": "success" if best_node.reward == 1 else "partial_success", @@ -158,7 +161,8 @@ async def lats_search(self, websocket=None) -> LATSNode: if websocket: await websocket.send_json({ "type": "step_start", - "step": "selection", + "step": 1, + "step_name": "selection", "iteration": i + 1, "timestamp": datetime.utcnow().isoformat() }) @@ -177,7 +181,8 @@ async def lats_search(self, websocket=None) -> LATSNode: if websocket: await websocket.send_json({ "type": "step_start", - "step": "expansion", + "step": 2, + "step_name": "expansion", "iteration": i + 1, "timestamp": datetime.utcnow().isoformat() }) @@ -206,6 +211,14 @@ async def lats_search(self, websocket=None) -> LATSNode: # Step 3: Evaluation print(f"") print(f"{GREEN}Step 3: evaluation{RESET}") + if websocket: + await websocket.send_json({ + "type": "step_start", + "step": 3, + "step_name": "evaluation", + "iteration": i + 1, + "timestamp": datetime.utcnow().isoformat() + }) await self.evaluate_node(node) print(f"{GREEN}Tree:{RESET}") @@ -214,7 +227,14 @@ async def lats_search(self, websocket=None) -> LATSNode: # Step 4: Simulation print(f"{GREEN}Step 4: simulation{RESET}") - # # Find the child with the highest value + if websocket: + await websocket.send_json({ + "type": "step_start", + "step": 4, + "step_name": "simulation", + "iteration": i + 1, + "timestamp": datetime.utcnow().isoformat() + }) ## always = 1 reward, terminal_node = await self.simulate(max(node.children, key=lambda child: child.value), max_depth=self.config.max_depth, num_simulations=1) terminal_nodes.append(terminal_node) @@ -224,6 +244,13 @@ async def lats_search(self, websocket=None) -> LATSNode: # Step 5: Backpropagation print(f"{GREEN}Step 5: backpropagation{RESET}") + if websocket: + await websocket.send_json({ + "type": "step_start", + "step": 5, + "step_name": "backpropagation", + "timestamp": datetime.utcnow().isoformat() + }) self.backpropagate(terminal_node, reward) print(f"{GREEN}Tree:{RESET}") better_print(self.root_node) diff --git a/visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/simple_search_agent.py b/visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/simple_search_agent.py index fc41795..851f231 100644 --- a/visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/simple_search_agent.py +++ b/visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/simple_search_agent.py @@ -627,6 +627,7 @@ async def bfs_with_websocket(self, websocket=None) -> List[Dict[str, Any]]: queue_set = {self.root_node} # Track nodes in queue best_score = float('-inf') best_path = None + best_node = None visited = set() # Track visited nodes to avoid cycles current_level = 0 # Track current level for BFS @@ -797,13 +798,14 @@ async def bfs_with_websocket(self, websocket=None) -> List[Dict[str, Any]]: if score > best_score: best_score = score best_path = path - + best_node = current_node + # Send best path update if websocket is provided if websocket: await websocket.send_json({ "type": "best_path_update", "score": best_score, - "path": [{"id": id(node), "action": node.action} for node in best_path[1:]], + "path": best_node.get_trajectory(), "timestamp": datetime.utcnow().isoformat() }) @@ -819,7 +821,7 @@ async def bfs_with_websocket(self, websocket=None) -> List[Dict[str, Any]]: "type": "search_complete", "status": "success", "score": score, - "path": [{"id": id(node), "action": node.action} for node in path[1:]], + "path":best_node.get_trajectory(), "timestamp": datetime.utcnow().isoformat() }) @@ -843,7 +845,7 @@ async def bfs_with_websocket(self, websocket=None) -> List[Dict[str, Any]]: "type": "search_complete", "status": "partial_success", "score": best_score, - "path": [{"id": id(node), "action": node.action} for node in best_path[1:]], + "path": best_node.get_trajectory(), "timestamp": datetime.utcnow().isoformat() }) @@ -892,6 +894,7 @@ async def dfs_with_websocket(self, websocket=None) -> List[Dict[str, Any]]: stack_set = {self.root_node} # Track nodes in stack best_score = float('-inf') best_path = None + best_node = None visited = set() # Track visited nodes to avoid cycles current_path = [] # Track current path for DFS @@ -1025,13 +1028,14 @@ async def dfs_with_websocket(self, websocket=None) -> List[Dict[str, Any]]: if score > best_score: best_score = score best_path = path + best_node = current_node # Send best path update if websocket is provided if websocket: await websocket.send_json({ "type": "best_path_update", "score": best_score, - "path": [{"id": id(node), "action": node.action} for node in best_path[1:]], + "path": best_node.get_trajectory(), "timestamp": datetime.utcnow().isoformat() }) @@ -1047,7 +1051,7 @@ async def dfs_with_websocket(self, websocket=None) -> List[Dict[str, Any]]: "type": "search_complete", "status": "success", "score": score, - "path": [{"id": id(node), "action": node.action} for node in path[1:]], + "path": best_node.get_trajectory(), "timestamp": datetime.utcnow().isoformat() }) @@ -1095,7 +1099,7 @@ async def dfs_with_websocket(self, websocket=None) -> List[Dict[str, Any]]: "type": "search_complete", "status": "partial_success", "score": best_score, - "path": [{"id": id(node), "action": node.action} for node in best_path[1:]], + "path": best_node.get_trajectory(), "timestamp": datetime.utcnow().isoformat() }) diff --git a/visual-tree-search-backend/test/test-tree-search-ws-lats.py b/visual-tree-search-backend/test/test-tree-search-ws-lats.py index ea0291d..49c5c79 100644 --- a/visual-tree-search-backend/test/test-tree-search-ws-lats.py +++ b/visual-tree-search-backend/test/test-tree-search-ws-lats.py @@ -70,6 +70,12 @@ async def connect_and_test_search( if msg_type == "status_update": logger.info(f"Status update: {data.get('status')} - {data.get('message')}") + elif msg_type == "iteration_start": + logger.info(f"Iteration start: {data.get('iteration')}") + + elif msg_type == "step_start": + logger.info(f"Step start: {data.get('step')} - {data.get('step_name')}") + elif msg_type == "node_update": node_id = data.get("node_id") status = data.get("status") diff --git a/visual-tree-search-backend/test/test-tree-search-ws-simple.py b/visual-tree-search-backend/test/test-tree-search-ws-simple.py index a4fbf28..1db7708 100644 --- a/visual-tree-search-backend/test/test-tree-search-ws-simple.py +++ b/visual-tree-search-backend/test/test-tree-search-ws-simple.py @@ -83,15 +83,14 @@ async def connect_and_test_search( logger.info(f"Tree update received with {len(data.get('nodes', []))} nodes") elif msg_type == "best_path_update": - logger.info(f"Best path update: score={data.get('score')}, path length={len(data.get('path', []))}") + logger.info(f"Best path update: score={data.get('score')}, path={data.get('path')}") elif msg_type == "search_complete": status = data.get("status") score = data.get("score", "N/A") - path_length = len(data.get("path", [])) + path = data.get("path") - logger.info(f"Search complete: {status}, score={score}, path length={path_length}") - logger.info("Path actions:") + logger.info(f"Search complete: {status}, score={score}, path={path}") for i, node in enumerate(data.get("path", [])): logger.info(f" {i+1}. {node.get('action')}") From cf8bca07f51099999f64075d175c5653f950dcd0 Mon Sep 17 00:00:00 2001 From: Tata0703 Date: Fri, 4 Apr 2025 02:21:46 -0700 Subject: [PATCH 2/3] add 5) Browser Set Up, Account Reset, 6) Tree Update & Trajectory Update, 7) node_selected for LATS --- .../agents_async/SearchAgents/lats_agent.py | 120 ++++++++++++++++-- .../SearchAgents/simple_search_agent.py | 44 ++++++- .../app/api/lwats/agents_async/lats_agent.py | 36 ++++++ .../test/test-tree-search-ws-lats.py | 5 +- .../test/test-tree-search-ws-simple.py | 2 +- 5 files changed, 196 insertions(+), 11 deletions(-) create mode 100644 visual-tree-search-backend/app/api/lwats/agents_async/lats_agent.py diff --git a/visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/lats_agent.py b/visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/lats_agent.py index 146ed62..7e33974 100644 --- a/visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/lats_agent.py +++ b/visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/lats_agent.py @@ -119,6 +119,12 @@ async def run(self, websocket=None) -> list[LATSNode]: print_trajectory(best_node) if websocket: + # trajectory_data = self._get_trajectory_data(best_node) + # await websocket.send_json({ + # "type": "trajectory_update", + # "trajectory": trajectory_data, + # "timestamp": datetime.utcnow().isoformat() + # }) # TODO: use score instead of reward to determine success await websocket.send_json({ "type": "search_complete", @@ -168,6 +174,12 @@ async def lats_search(self, websocket=None) -> LATSNode: }) node = self.select_node(self.root_node) + if websocket: + await websocket.send_json({ + "type": "node_selected", + "node_id": id(node), + "timestamp": datetime.utcnow().isoformat() + }) if node is None: print("All paths lead to terminal nodes with reward 0. Ending search.") @@ -207,6 +219,12 @@ async def lats_search(self, websocket=None) -> LATSNode: print(f"{GREEN}Tree:{RESET}") better_print(self.root_node) print(f"") + tree_data = self._get_tree_data() + await websocket.send_json({ + "type": "tree_update", + "tree": tree_data, + "timestamp": datetime.utcnow().isoformat() + }) # Step 3: Evaluation print(f"") @@ -224,6 +242,15 @@ async def lats_search(self, websocket=None) -> LATSNode: print(f"{GREEN}Tree:{RESET}") better_print(self.root_node) print(f"") + ## send tree update, since evaluation is added to the tree + if websocket: + tree_data = self._get_tree_data() + await websocket.send_json({ + "type": "tree_update", + "tree": tree_data, + "timestamp": datetime.utcnow().isoformat() + }) + # Step 4: Simulation print(f"{GREEN}Step 4: simulation{RESET}") @@ -236,7 +263,7 @@ async def lats_search(self, websocket=None) -> LATSNode: "timestamp": datetime.utcnow().isoformat() }) ## always = 1 - reward, terminal_node = await self.simulate(max(node.children, key=lambda child: child.value), max_depth=self.config.max_depth, num_simulations=1) + reward, terminal_node = await self.simulate(max(node.children, key=lambda child: child.value), max_depth=self.config.max_depth, num_simulations=1, websocket=websocket) terminal_nodes.append(terminal_node) if reward == 1: @@ -362,7 +389,8 @@ async def evaluate_node(self, node: LATSNode) -> None: child.value = score child.reward = score - async def simulate(self, node: LATSNode, max_depth: int = 2, num_simulations=1) -> tuple[float, LATSNode]: + ## TODO: make number of simulations configurable + async def simulate(self, node: LATSNode, max_depth: int = 2, num_simulations=1, websocket=None) -> tuple[float, LATSNode]: """ Perform a rollout simulation from a node. @@ -378,13 +406,39 @@ async def simulate(self, node: LATSNode, max_depth: int = 2, num_simulations=1) print_trajectory(node) print("print the entire tree") print_entire_tree(self.root_node) - return await self.rollout(node, max_depth=max_depth) + if websocket: + tree_data = self._get_tree_data() + await websocket.send_json({ + "type": "tree_update", + "tree": tree_data, + "timestamp": datetime.utcnow().isoformat() + }) + trajectory_data = self._get_trajectory_data(node) + await websocket.send_json({ + "type": "trajectory_update", + "trajectory": trajectory_data, + "timestamp": datetime.utcnow().isoformat() + }) + return await self.rollout(node, max_depth=max_depth, websocket=websocket) - async def send_completion_request(self, plan, depth, node, trajectory=[]): + async def send_completion_request(self, plan, depth, node, trajectory=[], websocket=None): print("print the trajectory") print_trajectory(node) print("print the entire tree") print_entire_tree(self.root_node) + if websocket: + # tree_data = self._get_tree_data() + # await websocket.send_json({ + # "type": "tree_update", + # "tree": tree_data, + # "timestamp": datetime.utcnow().isoformat() + # }) + trajectory_data = self._get_trajectory_data(node) + await websocket.send_json({ + "type": "trajectory_update", + "trajectory": trajectory_data, + "timestamp": datetime.utcnow().isoformat() + }) if depth >= self.config.max_depth: return trajectory, node @@ -447,20 +501,20 @@ async def send_completion_request(self, plan, depth, node, trajectory=[]): if goal_finished: return trajectory, new_node - return await self.send_completion_request(plan, depth + 1, new_node, trajectory) + return await self.send_completion_request(plan, depth + 1, new_node, trajectory, websocket) except Exception as e: print(f"Attempt {attempt + 1} failed with error: {e}") if attempt + 1 == retry_count: print("Max retries reached. Skipping this step and retrying the whole request.") # Retry the entire request from the same state - return await self.send_completion_request(plan, depth, node, trajectory) + return await self.send_completion_request(plan, depth, node, trajectory, websocket) # If all retries and retries of retries fail, return the current trajectory and node return trajectory, node - async def rollout(self, node: LATSNode, max_depth: int = 2)-> tuple[float, LATSNode]: + async def rollout(self, node: LATSNode, max_depth: int = 2, websocket=None)-> tuple[float, LATSNode]: # Reset browser state await self._reset_browser() path = self.get_path_to_root(node) @@ -494,11 +548,24 @@ async def rollout(self, node: LATSNode, max_depth: int = 2)-> tuple[float, LATSN ## call the prompt agent print("current depth: ", len(path) - 1) print("max depth: ", self.config.max_depth) - trajectory, node = await self.send_completion_request(self.goal, len(path) - 1, node=n, trajectory=trajectory) + trajectory, node = await self.send_completion_request(self.goal, len(path) - 1, node=n, trajectory=trajectory, websocket=websocket) print("print the trajectory") print_trajectory(node) print("print the entire tree") print_entire_tree(self.root_node) + if websocket: + # tree_data = self._get_tree_data() + # await websocket.send_json({ + # "type": "tree_update", + # "tree": tree_data, + # "timestamp": datetime.utcnow().isoformat() + # }) + trajectory_data = self._get_trajectory_data(node) + await websocket.send_json({ + "type": "trajectory_update", + "trajectory": trajectory_data, + "timestamp": datetime.utcnow().isoformat() + }) page = await self.playwright_manager.get_page() page_info = await extract_page_info(page, self.config.fullpage, self.config.log_folder) @@ -801,3 +868,40 @@ def _get_tree_data(self): tree_data.append(node_data) return tree_data + + def _get_trajectory_data(self, terminal_node: LATSNode): + """Get trajectory data in a format suitable for visualization + + Args: + terminal_node: The leaf node to start the trajectory from + + Returns: + list: List of node data dictionaries representing the trajectory + """ + trajectory_data = [] + path = [] + + # Collect path from terminal to root + current = terminal_node + while current is not None: + path.append(current) + current = current.parent + + # Process nodes in order from root to terminal + for level, node in enumerate(reversed(path)): + node_data = { + "id": id(node), + "level": level, + "action": node.action if node.action else "ROOT", + "description": node.natural_language_description, + "visits": node.visits, + "value": float(f"{node.value:.3f}") if hasattr(node, 'value') else None, + "reward": float(f"{node.reward:.3f}") if hasattr(node, 'reward') else None, + "is_terminal": node.is_terminal, + "feedback": node.feedback if hasattr(node, 'feedback') else None, + "is_root": not hasattr(node, 'parent') or node.parent is None, + "is_terminal_node": node == terminal_node + } + trajectory_data.append(node_data) + + return trajectory_data diff --git a/visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/simple_search_agent.py b/visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/simple_search_agent.py index 851f231..a92d55e 100644 --- a/visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/simple_search_agent.py +++ b/visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/simple_search_agent.py @@ -1145,9 +1145,51 @@ def _get_tree_data(self): "action": node.action if node.action else "ROOT", "description": node.natural_language_description, "depth": node.depth, - "is_terminal": node.is_terminal + "is_terminal": node.is_terminal, + "value": node.value, + "visits": node.visits, + "reward": node.reward } tree_data.append(node_data) return tree_data + + def _get_trajectory_data(self, terminal_node: LATSNode): + """Get trajectory data in a format suitable for visualization + + Args: + terminal_node: The leaf node to start the trajectory from + + Returns: + list: List of node data dictionaries representing the trajectory + """ + trajectory_data = [] + path = [] + + # Collect path from terminal to root + current = terminal_node + while current is not None: + path.append(current) + current = current.parent + + # Process nodes in order from root to terminal + for level, node in enumerate(reversed(path)): + node_data = { + "id": id(node), + "level": level, + "action": node.action if node.action else "ROOT", + "description": node.natural_language_description, + "visits": node.visits, + "value": float(f"{node.value:.3f}") if hasattr(node, 'value') else None, + "reward": float(f"{node.reward:.3f}") if hasattr(node, 'reward') else None, + "is_terminal": node.is_terminal, + "feedback": node.feedback if hasattr(node, 'feedback') else None, + "is_root": not hasattr(node, 'parent') or node.parent is None, + "is_terminal_node": node == terminal_node + } + trajectory_data.append(node_data) + + return trajectory_data + + diff --git a/visual-tree-search-backend/app/api/lwats/agents_async/lats_agent.py b/visual-tree-search-backend/app/api/lwats/agents_async/lats_agent.py new file mode 100644 index 0000000..c937dcc --- /dev/null +++ b/visual-tree-search-backend/app/api/lwats/agents_async/lats_agent.py @@ -0,0 +1,36 @@ + def _get_trajectory_data(self, terminal_node: LATSNode): + """Get trajectory data in a format suitable for visualization + + Args: + terminal_node: The leaf node to start the trajectory from + + Returns: + list: List of node data dictionaries representing the trajectory + """ + trajectory_data = [] + path = [] + + # Collect path from terminal to root + current = terminal_node + while current is not None: + path.append(current) + current = current.parent + + # Process nodes in order from root to terminal + for level, node in enumerate(reversed(path)): + node_data = { + "id": id(node), + "level": level, + "action": node.action if node.action else "ROOT", + "description": node.natural_language_description, + "visits": node.visits, + "value": float(f"{node.value:.3f}") if hasattr(node, 'value') else None, + "reward": float(f"{node.reward:.3f}") if hasattr(node, 'reward') else None, + "is_terminal": node.is_terminal, + "feedback": node.feedback if hasattr(node, 'feedback') else None, + "is_root": not hasattr(node, 'parent') or node.parent is None, + "is_terminal_node": node == terminal_node + } + trajectory_data.append(node_data) + + return trajectory_data \ No newline at end of file diff --git a/visual-tree-search-backend/test/test-tree-search-ws-lats.py b/visual-tree-search-backend/test/test-tree-search-ws-lats.py index 49c5c79..567c790 100644 --- a/visual-tree-search-backend/test/test-tree-search-ws-lats.py +++ b/visual-tree-search-backend/test/test-tree-search-ws-lats.py @@ -85,8 +85,11 @@ async def connect_and_test_search( if status == "scored": logger.info(f"Node score: {data.get('score')}") + elif msg_type == "trajectory_update": + logger.info(f"Trajectory update received with {data.get('trajectory')}") + elif msg_type == "tree_update": - logger.info(f"Tree update received with {len(data.get('nodes', []))} nodes") + logger.info(f"Tree update received with {data.get('tree')}") elif msg_type == "best_path_update": logger.info(f"Best path update: score={data.get('score')}, path length={len(data.get('path', []))}") diff --git a/visual-tree-search-backend/test/test-tree-search-ws-simple.py b/visual-tree-search-backend/test/test-tree-search-ws-simple.py index 1db7708..5808149 100644 --- a/visual-tree-search-backend/test/test-tree-search-ws-simple.py +++ b/visual-tree-search-backend/test/test-tree-search-ws-simple.py @@ -80,7 +80,7 @@ async def connect_and_test_search( logger.info(f"Node score: {data.get('score')}") elif msg_type == "tree_update": - logger.info(f"Tree update received with {len(data.get('nodes', []))} nodes") + logger.info(f"Tree update received with {data.get('tree')}") elif msg_type == "best_path_update": logger.info(f"Best path update: score={data.get('score')}, path={data.get('path')}") From bcb3eccd599511a13881bcebcb38a8666a4cb6b1 Mon Sep 17 00:00:00 2001 From: Tata0703 Date: Fri, 4 Apr 2025 02:28:25 -0700 Subject: [PATCH 3/3] wrap up the backend work --- .../app/api/lwats/agents_async/SearchAgents/lats_agent.py | 1 + .../api/lwats/agents_async/SearchAgents/simple_search_agent.py | 1 + 2 files changed, 2 insertions(+) diff --git a/visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/lats_agent.py b/visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/lats_agent.py index 7e33974..5e70f28 100644 --- a/visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/lats_agent.py +++ b/visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/lats_agent.py @@ -863,6 +863,7 @@ def _get_tree_data(self): "is_terminal": node.is_terminal, "value": node.value, "visits": node.visits, + "feedback": node.feedback, "reward": node.reward } tree_data.append(node_data) diff --git a/visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/simple_search_agent.py b/visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/simple_search_agent.py index a92d55e..111e99c 100644 --- a/visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/simple_search_agent.py +++ b/visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/simple_search_agent.py @@ -1148,6 +1148,7 @@ def _get_tree_data(self): "is_terminal": node.is_terminal, "value": node.value, "visits": node.visits, + "feedback": node.feedback, "reward": node.reward } tree_data.append(node_data)