Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@

openai_client = OpenAI()

## TODO: add best_path_update

class LATSAgent:
"""
Language-based Action Tree Search Agent implementation.
Expand Down Expand Up @@ -117,6 +119,13 @@ async def run(self, websocket=None) -> list[LATSNode]:
print_trajectory(best_node)

if websocket:
# trajectory_data = self._get_trajectory_data(best_node)
# await websocket.send_json({
# "type": "trajectory_update",
# "trajectory": trajectory_data,
# "timestamp": datetime.utcnow().isoformat()
# })
# TODO: use score instead of reward to determine success
await websocket.send_json({
"type": "search_complete",
"status": "success" if best_node.reward == 1 else "partial_success",
Expand Down Expand Up @@ -158,12 +167,19 @@ async def lats_search(self, websocket=None) -> LATSNode:
if websocket:
await websocket.send_json({
"type": "step_start",
"step": "selection",
"step": 1,
"step_name": "selection",
"iteration": i + 1,
"timestamp": datetime.utcnow().isoformat()
})

node = self.select_node(self.root_node)
if websocket:
await websocket.send_json({
"type": "node_selected",
"node_id": id(node),
"timestamp": datetime.utcnow().isoformat()
})

if node is None:
print("All paths lead to terminal nodes with reward 0. Ending search.")
Expand All @@ -177,7 +193,8 @@ async def lats_search(self, websocket=None) -> LATSNode:
if websocket:
await websocket.send_json({
"type": "step_start",
"step": "expansion",
"step": 2,
"step_name": "expansion",
"iteration": i + 1,
"timestamp": datetime.utcnow().isoformat()
})
Expand All @@ -202,28 +219,65 @@ async def lats_search(self, websocket=None) -> LATSNode:
print(f"{GREEN}Tree:{RESET}")
better_print(self.root_node)
print(f"")
tree_data = self._get_tree_data()
await websocket.send_json({
"type": "tree_update",
"tree": tree_data,
"timestamp": datetime.utcnow().isoformat()
})

# Step 3: Evaluation
print(f"")
print(f"{GREEN}Step 3: evaluation{RESET}")
if websocket:
await websocket.send_json({
"type": "step_start",
"step": 3,
"step_name": "evaluation",
"iteration": i + 1,
"timestamp": datetime.utcnow().isoformat()
})
await self.evaluate_node(node)

print(f"{GREEN}Tree:{RESET}")
better_print(self.root_node)
print(f"")
## send tree update, since evaluation is added to the tree
if websocket:
tree_data = self._get_tree_data()
await websocket.send_json({
"type": "tree_update",
"tree": tree_data,
"timestamp": datetime.utcnow().isoformat()
})


# Step 4: Simulation
print(f"{GREEN}Step 4: simulation{RESET}")
# # Find the child with the highest value
if websocket:
await websocket.send_json({
"type": "step_start",
"step": 4,
"step_name": "simulation",
"iteration": i + 1,
"timestamp": datetime.utcnow().isoformat()
})
## always = 1
reward, terminal_node = await self.simulate(max(node.children, key=lambda child: child.value), max_depth=self.config.max_depth, num_simulations=1)
reward, terminal_node = await self.simulate(max(node.children, key=lambda child: child.value), max_depth=self.config.max_depth, num_simulations=1, websocket=websocket)
terminal_nodes.append(terminal_node)

if reward == 1:
return terminal_node

# Step 5: Backpropagation
print(f"{GREEN}Step 5: backpropagation{RESET}")
if websocket:
await websocket.send_json({
"type": "step_start",
"step": 5,
"step_name": "backpropagation",
"timestamp": datetime.utcnow().isoformat()
})
self.backpropagate(terminal_node, reward)
print(f"{GREEN}Tree:{RESET}")
better_print(self.root_node)
Expand Down Expand Up @@ -335,7 +389,8 @@ async def evaluate_node(self, node: LATSNode) -> None:
child.value = score
child.reward = score

async def simulate(self, node: LATSNode, max_depth: int = 2, num_simulations=1) -> tuple[float, LATSNode]:
## TODO: make number of simulations configurable
async def simulate(self, node: LATSNode, max_depth: int = 2, num_simulations=1, websocket=None) -> tuple[float, LATSNode]:
"""
Perform a rollout simulation from a node.

Expand All @@ -351,13 +406,39 @@ async def simulate(self, node: LATSNode, max_depth: int = 2, num_simulations=1)
print_trajectory(node)
print("print the entire tree")
print_entire_tree(self.root_node)
return await self.rollout(node, max_depth=max_depth)
if websocket:
tree_data = self._get_tree_data()
await websocket.send_json({
"type": "tree_update",
"tree": tree_data,
"timestamp": datetime.utcnow().isoformat()
})
trajectory_data = self._get_trajectory_data(node)
await websocket.send_json({
"type": "trajectory_update",
"trajectory": trajectory_data,
"timestamp": datetime.utcnow().isoformat()
})
return await self.rollout(node, max_depth=max_depth, websocket=websocket)

async def send_completion_request(self, plan, depth, node, trajectory=[]):
async def send_completion_request(self, plan, depth, node, trajectory=[], websocket=None):
print("print the trajectory")
print_trajectory(node)
print("print the entire tree")
print_entire_tree(self.root_node)
if websocket:
# tree_data = self._get_tree_data()
# await websocket.send_json({
# "type": "tree_update",
# "tree": tree_data,
# "timestamp": datetime.utcnow().isoformat()
# })
trajectory_data = self._get_trajectory_data(node)
await websocket.send_json({
"type": "trajectory_update",
"trajectory": trajectory_data,
"timestamp": datetime.utcnow().isoformat()
})

if depth >= self.config.max_depth:
return trajectory, node
Expand Down Expand Up @@ -420,20 +501,20 @@ async def send_completion_request(self, plan, depth, node, trajectory=[]):
if goal_finished:
return trajectory, new_node

return await self.send_completion_request(plan, depth + 1, new_node, trajectory)
return await self.send_completion_request(plan, depth + 1, new_node, trajectory, websocket)

except Exception as e:
print(f"Attempt {attempt + 1} failed with error: {e}")
if attempt + 1 == retry_count:
print("Max retries reached. Skipping this step and retrying the whole request.")
# Retry the entire request from the same state
return await self.send_completion_request(plan, depth, node, trajectory)
return await self.send_completion_request(plan, depth, node, trajectory, websocket)

# If all retries and retries of retries fail, return the current trajectory and node
return trajectory, node


async def rollout(self, node: LATSNode, max_depth: int = 2)-> tuple[float, LATSNode]:
async def rollout(self, node: LATSNode, max_depth: int = 2, websocket=None)-> tuple[float, LATSNode]:
# Reset browser state
await self._reset_browser()
path = self.get_path_to_root(node)
Expand Down Expand Up @@ -467,11 +548,24 @@ async def rollout(self, node: LATSNode, max_depth: int = 2)-> tuple[float, LATSN
## call the prompt agent
print("current depth: ", len(path) - 1)
print("max depth: ", self.config.max_depth)
trajectory, node = await self.send_completion_request(self.goal, len(path) - 1, node=n, trajectory=trajectory)
trajectory, node = await self.send_completion_request(self.goal, len(path) - 1, node=n, trajectory=trajectory, websocket=websocket)
print("print the trajectory")
print_trajectory(node)
print("print the entire tree")
print_entire_tree(self.root_node)
if websocket:
# tree_data = self._get_tree_data()
# await websocket.send_json({
# "type": "tree_update",
# "tree": tree_data,
# "timestamp": datetime.utcnow().isoformat()
# })
trajectory_data = self._get_trajectory_data(node)
await websocket.send_json({
"type": "trajectory_update",
"trajectory": trajectory_data,
"timestamp": datetime.utcnow().isoformat()
})

page = await self.playwright_manager.get_page()
page_info = await extract_page_info(page, self.config.fullpage, self.config.log_folder)
Expand Down Expand Up @@ -769,8 +863,46 @@ def _get_tree_data(self):
"is_terminal": node.is_terminal,
"value": node.value,
"visits": node.visits,
"feedback": node.feedback,
"reward": node.reward
}
tree_data.append(node_data)

return tree_data

def _get_trajectory_data(self, terminal_node: LATSNode) -> list[dict]:
    """Get trajectory data in a format suitable for visualization.

    Walks the ``parent`` links from ``terminal_node`` up to the node whose
    parent is ``None`` and emits one dictionary per node, ordered from the
    root (level 0) down to ``terminal_node``.

    Args:
        terminal_node: The leaf node to start the trajectory from. Passing
            ``None`` yields an empty trajectory.

    Returns:
        list: List of node data dictionaries representing the trajectory.
        ``value`` and ``reward`` are rounded to three decimals, or ``None``
        when the node has no such attribute or the attribute is ``None``.
    """

    def _rounded(raw):
        # A node may carry the attribute but not yet have a score; formatting
        # None with ".3f" would raise TypeError, so pass None through as-is.
        if raw is None:
            return None
        return float(f"{raw:.3f}")

    # Collect path from terminal to root.
    path = []
    current = terminal_node
    while current is not None:
        path.append(current)
        current = current.parent

    # Process nodes in order from root to terminal.
    trajectory_data = []
    for level, node in enumerate(reversed(path)):
        trajectory_data.append({
            "id": id(node),
            "level": level,
            "action": node.action if node.action else "ROOT",
            "description": node.natural_language_description,
            "visits": node.visits,
            "value": _rounded(getattr(node, "value", None)),
            "reward": _rounded(getattr(node, "reward", None)),
            "is_terminal": node.is_terminal,
            "feedback": getattr(node, "feedback", None),
            "is_root": getattr(node, "parent", None) is None,
            # Identity, not equality: we want "is this the very node the
            # trajectory was requested for", regardless of any __eq__ override.
            "is_terminal_node": node is terminal_node,
        })

    return trajectory_data
Loading