Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions visual-tree-search-backend/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,4 +90,12 @@ python run_demo_treesearch_async.py \
```
uvicorn app.main:app --host 0.0.0.0 --port 3000
python test/test-tree-search-ws-lats.py
```

## 7. Add MCTS agent
* test run_demo_treesearch_async.py
* test web socket
```
uvicorn app.main:app --host 0.0.0.0 --port 3000
python test/test-tree-search-ws-mcts.py
```

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -1146,4 +1146,3 @@ def _get_tree_data(self):
tree_data.append(node_data)

return tree_data

Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ async def handle_search_request(websocket: WebSocket, message: Dict[str, Any]):
await agent.dfs_with_websocket(websocket)
elif search_algorithm.lower() == "lats":
await agent.run(websocket)
elif search_algorithm.lower() == "mcts":
await agent.run(websocket)
else:
await websocket.send_json({
"type": "error",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ async def handle_search_request(websocket: WebSocket, message: Dict[str, Any]):
await agent.dfs_with_websocket(websocket)
elif search_algorithm.lower() == "lats":
await agent.run(websocket)
elif search_algorithm.lower() == "mcts":
await agent.run(websocket)
else:
await websocket.send_json({
"type": "error",
Expand Down
22 changes: 11 additions & 11 deletions visual-tree-search-backend/app/api/shopping.json
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
"value": "",
"domain": "128.105.145.205",
"path": "/",
"expires": 1775272527,
"expires": 1775370120,
"httpOnly": false,
"secure": false,
"sameSite": "Strict"
Expand Down Expand Up @@ -81,20 +81,20 @@
},
{
"name": "private_content_version",
"value": "c514ad01bc9816ee30ab1f02510aa34a",
"value": "ff4bba58081243f67b1adee2cc6974bd",
"domain": "128.105.145.205",
"path": "/",
"expires": 1778296522.505323,
"expires": 1778394118.522057,
"httpOnly": false,
"secure": false,
"sameSite": "Lax"
},
{
"name": "PHPSESSID",
"value": "007c737fca0eb4173ab5362c2d9c8b09",
"value": "30247306b1ad824f37f3e0384d86d991",
"domain": "128.105.145.205",
"path": "/",
"expires": 1775272526.468528,
"expires": 1775370122.385659,
"httpOnly": true,
"secure": false,
"sameSite": "Lax"
Expand All @@ -104,17 +104,17 @@
"value": "9bf9a599123e6402b85cde67144717a08b817412",
"domain": "128.105.145.205",
"path": "/",
"expires": 1775272526.468754,
"expires": 1775370122.385877,
"httpOnly": true,
"secure": false,
"sameSite": "Lax"
},
{
"name": "form_key",
"value": "Hsr3n5ycGkOPfr1K",
"value": "IEPmx1hKh4NWjeUa",
"domain": "128.105.145.205",
"path": "/",
"expires": 1775272526.468692,
"expires": 1775370122.385817,
"httpOnly": false,
"secure": false,
"sameSite": "Lax"
Expand All @@ -124,17 +124,17 @@
"value": "true",
"domain": "128.105.145.205",
"path": "/",
"expires": 1775272527,
"expires": 1775370120,
"httpOnly": false,
"secure": false,
"sameSite": "Lax"
},
{
"name": "section_data_ids",
"value": "{%22messages%22:1743736525%2C%22customer%22:1743736525%2C%22compare-products%22:1743736525%2C%22last-ordered-items%22:1743736525%2C%22cart%22:1743736525%2C%22directory-data%22:1743736525%2C%22captcha%22:1743736525%2C%22instant-purchase%22:1743736525%2C%22loggedAsCustomer%22:1743736525%2C%22persistent%22:1743736525%2C%22review%22:1743736525%2C%22wishlist%22:1743736525%2C%22recently_viewed_product%22:1743736525%2C%22recently_compared_product%22:1743736525%2C%22product_data_storage%22:1743736525%2C%22paypal-billing-agreement%22:1743736525}",
"value": "{%22messages%22:1743834120%2C%22customer%22:1743834120%2C%22compare-products%22:1743834120%2C%22last-ordered-items%22:1743834120%2C%22cart%22:1743834120%2C%22directory-data%22:1743834120%2C%22captcha%22:1743834120%2C%22instant-purchase%22:1743834120%2C%22loggedAsCustomer%22:1743834120%2C%22persistent%22:1743834120%2C%22review%22:1743834120%2C%22wishlist%22:1743834120%2C%22recently_viewed_product%22:1743834120%2C%22recently_compared_product%22:1743834120%2C%22product_data_storage%22:1743834120%2C%22paypal-billing-agreement%22:1743834120}",
"domain": "128.105.145.205",
"path": "/",
"expires": 1775272524,
"expires": 1775370120,
"httpOnly": false,
"secure": false,
"sameSite": "Lax"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def parse_arguments():
parser.add_argument("--goal", type=str, default=DEFAULT_GOAL,
help=f"Goal to achieve (default: {DEFAULT_GOAL})")

parser.add_argument("--algorithm", type=str, choices=["bfs", "dfs", "lats"], default="lats",
parser.add_argument("--algorithm", type=str, choices=["bfs", "dfs", "lats", "mcts"], default="lats",
help="Search algorithm to use (default: lats)")

parser.add_argument("--max-depth", type=int, default=3,
Expand Down
169 changes: 169 additions & 0 deletions visual-tree-search-backend/test/test-tree-search-ws-mcts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
import asyncio
import json
import websockets
import argparse
import logging
from datetime import datetime

# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Default values
DEFAULT_WS_URL = "ws://localhost:3000/new-tree-search-ws"
DEFAULT_STARTING_URL = "http://128.105.145.205:7770/"
DEFAULT_GOAL = "search running shoes, click on the first result"

async def connect_and_test_search(
ws_url: str,
starting_url: str,
goal: str,
search_algorithm: str = "bfs",
max_depth: int = 3
):
"""
Connect to the WebSocket endpoint and test the tree search functionality.

Args:
ws_url: WebSocket URL to connect to
starting_url: URL to start the search from
goal: Goal to achieve
search_algorithm: Search algorithm to use (bfs or dfs)
max_depth: Maximum depth for the search tree
"""
logger.info(f"Connecting to WebSocket at {ws_url}")

async with websockets.connect(ws_url) as websocket:
logger.info("Connected to WebSocket")

# Wait for connection established message
response = await websocket.recv()
data = json.loads(response)
if data.get("type") == "connection_established":
logger.info(f"Connection established with ID: {data.get('connection_id')}")

# Send search request
request = {
"type": "start_search",
"agent_type": "MCTSAgent",
"starting_url": starting_url,
"goal": goal,
"search_algorithm": search_algorithm,
"max_depth": max_depth
}

logger.info(f"Sending search request: {request}")
await websocket.send(json.dumps(request))

# Process responses
while True:
try:
response = await websocket.recv()
data = json.loads(response)

# Log the message type and some key information
msg_type = data.get("type", "unknown")

if msg_type == "status_update":
logger.info(f"Status update: {data.get('status')} - {data.get('message')}")

elif msg_type == "iteration_start":
logger.info(f"Iteration start: {data.get('iteration')}")

elif msg_type == "step_start":
logger.info(f"Step start: {data.get('step')} - {data.get('step_name')}")

elif msg_type == "node_update":
node_id = data.get("node_id")
status = data.get("status")
logger.info(f"Node update: {node_id} - {status}")

# If node was scored, log the score
if status == "scored":
logger.info(f"Node score: {data.get('score')}")

elif msg_type == "trajectory_update":
logger.info(f"Trajectory update received with {data.get('trajectory')}")

elif msg_type == "tree_update":
logger.info(f"Tree update received with {data.get('tree')}")

elif msg_type == "best_path_update":
logger.info(f"Best path update: score={data.get('score')}, path length={len(data.get('path', []))}")

elif msg_type == "search_complete":
status = data.get("status")
score = data.get("score", "N/A")
path_length = len(data.get("path", []))

logger.info(f"Search complete: {status}, score={score}, path length={path_length}")
logger.info("Path actions:")

for i, node in enumerate(data.get("path", [])):
logger.info(f" {i+1}. {node.get('action')}")

# Exit the loop when search is complete
break

elif msg_type == "error":
logger.error(f"Error: {data.get('message')}")
break

else:
logger.info(f"Received message of type {msg_type}")
logger.info(f"Message: {data}")

except websockets.exceptions.ConnectionClosed:
logger.warning("WebSocket connection closed")
break
except Exception as e:
logger.error(f"Error processing message: {e}")
break

logger.info("Test completed")

def parse_arguments():
"""Parse command line arguments"""
parser = argparse.ArgumentParser(description="Test the tree search WebSocket functionality")

parser.add_argument("--ws-url", type=str, default=DEFAULT_WS_URL,
help=f"WebSocket URL (default: {DEFAULT_WS_URL})")

parser.add_argument("--starting-url", type=str, default=DEFAULT_STARTING_URL,
help=f"Starting URL for the search (default: {DEFAULT_STARTING_URL})")

parser.add_argument("--goal", type=str, default=DEFAULT_GOAL,
help=f"Goal to achieve (default: {DEFAULT_GOAL})")

parser.add_argument("--algorithm", type=str, choices=["bfs", "dfs", "lats", "mcts"], default="mcts",
help="Search algorithm to use (default: lats)")

parser.add_argument("--max-depth", type=int, default=3,
help="Maximum depth for the search tree (default: 3)")

return parser.parse_args()

async def main():
"""Main entry point"""
args = parse_arguments()

logger.info("Starting tree search WebSocket test")
logger.info(f"WebSocket URL: {args.ws_url}")
logger.info(f"Starting URL: {args.starting_url}")
logger.info(f"Goal: {args.goal}")
logger.info(f"Algorithm: {args.algorithm}")
logger.info(f"Max depth: {args.max_depth}")

await connect_and_test_search(
ws_url=args.ws_url,
starting_url=args.starting_url,
goal=args.goal,
search_algorithm=args.algorithm,
max_depth=args.max_depth
)

if __name__ == "__main__":
asyncio.run(main())