Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/memos/api/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ def get_memreader_config() -> dict[str, Any]:
"config": {
"model_name_or_path": os.getenv("MEMRADER_MODEL", "gpt-4o-mini"),
"temperature": 0.6,
"max_tokens": int(os.getenv("MEMRADER_MAX_TOKENS", "5000")),
"max_tokens": int(os.getenv("MEMRADER_MAX_TOKENS", "8000")),
"top_p": 0.95,
"top_k": 20,
"api_key": os.getenv("MEMRADER_API_KEY", "EMPTY"),
Expand Down Expand Up @@ -614,7 +614,7 @@ def get_scheduler_config() -> dict[str, Any]:
),
"context_window_size": int(os.getenv("MOS_SCHEDULER_CONTEXT_WINDOW_SIZE", "5")),
"thread_pool_max_workers": int(
os.getenv("MOS_SCHEDULER_THREAD_POOL_MAX_WORKERS", "10")
os.getenv("MOS_SCHEDULER_THREAD_POOL_MAX_WORKERS", "10000")
),
"consume_interval_seconds": float(
os.getenv("MOS_SCHEDULER_CONSUME_INTERVAL_SECONDS", "0.01")
Expand Down
152 changes: 84 additions & 68 deletions src/memos/api/routers/server_router.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import json
import os
import random as _random
import socket
import time
import traceback

from collections.abc import Iterable
from datetime import datetime
from typing import TYPE_CHECKING, Any

Expand Down Expand Up @@ -69,6 +72,16 @@
logger = get_logger(__name__)

router = APIRouter(prefix="/product", tags=["Server API"])
# Identifier for this server process, formatted as "<hostname>:<pid>:<random 4-digit int>".
# The random suffix is drawn once, when this module is imported.
INSTANCE_ID = f"{socket.gethostname()}:{os.getpid()}:{_random.randint(1000, 9999)}"


def _to_iter(running: Any) -> Iterable:
"""Normalize running tasks to an iterable of task objects."""
if running is None:
return []
if isinstance(running, dict):
return running.values()
return running # assume it's already an iterable (e.g., list)


def _build_graph_db_config(user_id: str = "default") -> dict[str, Any]:
Expand Down Expand Up @@ -607,46 +620,65 @@ def _process_pref_mem() -> list[dict[str, str]]:
)


@router.get("/scheduler/status", summary="Get scheduler running task count")
def scheduler_status():
"""
Return current running tasks count from scheduler dispatcher.
Shape is consistent with /scheduler/wait.
"""
@router.get("/scheduler/status", summary="Get scheduler running status")
def scheduler_status(user_name: str | None = None):
try:
running = mem_scheduler.dispatcher.get_running_tasks()
running_count = len(running)
now_ts = time.time()

return {
"message": "ok",
"data": {
"running_tasks": running_count,
"timestamp": now_ts,
},
}

if user_name:
running = mem_scheduler.dispatcher.get_running_tasks(
lambda task: getattr(task, "mem_cube_id", None) == user_name
)
tasks_iter = list(_to_iter(running))
running_count = len(tasks_iter)
return {
"message": "ok",
"data": {
"scope": "user",
"user_name": user_name,
"running_tasks": running_count,
"timestamp": time.time(),
"instance_id": INSTANCE_ID,
},
}
else:
running_all = mem_scheduler.dispatcher.get_running_tasks(lambda _t: True)
tasks_iter = list(_to_iter(running_all))
running_count = len(tasks_iter)

task_count_per_user: dict[str, int] = {}
for task in tasks_iter:
cube = getattr(task, "mem_cube_id", "unknown")
task_count_per_user[cube] = task_count_per_user.get(cube, 0) + 1

return {
"message": "ok",
"data": {
"scope": "global",
"running_tasks": running_count,
"task_count_per_user": task_count_per_user,
"timestamp": time.time(),
"instance_id": INSTANCE_ID,
},
}
except Exception as err:
logger.error("Failed to get scheduler status: %s", traceback.format_exc())

raise HTTPException(status_code=500, detail="Failed to get scheduler status") from err


@router.post("/scheduler/wait", summary="Wait until scheduler is idle")
def scheduler_wait(timeout_seconds: float = 120.0, poll_interval: float = 0.2):
@router.post("/scheduler/wait", summary="Wait until scheduler is idle for a specific user")
def scheduler_wait(
user_name: str,
timeout_seconds: float = 120.0,
poll_interval: float = 0.2,
):
"""
Block until scheduler has no running tasks, or timeout.
We return a consistent structured payload so callers can
tell whether this was a clean flush or a timeout.

Args:
timeout_seconds: max seconds to wait
poll_interval: seconds between polls
Block until scheduler has no running tasks for the given user_name, or timeout.
"""
start = time.time()
try:
while True:
running = mem_scheduler.dispatcher.get_running_tasks()
running = mem_scheduler.dispatcher.get_running_tasks(
lambda task: task.mem_cube_id == user_name
)
running_count = len(running)
elapsed = time.time() - start

Expand All @@ -658,6 +690,7 @@ def scheduler_wait(timeout_seconds: float = 120.0, poll_interval: float = 0.2):
"running_tasks": 0,
"waited_seconds": round(elapsed, 3),
"timed_out": False,
"user_name": user_name,
},
}

Expand All @@ -669,24 +702,23 @@ def scheduler_wait(timeout_seconds: float = 120.0, poll_interval: float = 0.2):
"running_tasks": running_count,
"waited_seconds": round(elapsed, 3),
"timed_out": True,
"user_name": user_name,
},
}

time.sleep(poll_interval)

except Exception as err:
logger.error(
"Failed while waiting for scheduler: %s",
traceback.format_exc(),
)
raise HTTPException(
status_code=500,
detail="Failed while waiting for scheduler",
) from err
logger.error("Failed while waiting for scheduler: %s", traceback.format_exc())
raise HTTPException(status_code=500, detail="Failed while waiting for scheduler") from err


@router.get("/scheduler/wait/stream", summary="Stream scheduler progress (SSE)")
def scheduler_wait_stream(timeout_seconds: float = 120.0, poll_interval: float = 0.2):
@router.get("/scheduler/wait/stream", summary="Stream scheduler progress for a user")
def scheduler_wait_stream(
user_name: str,
timeout_seconds: float = 120.0,
poll_interval: float = 0.2,
):
"""
Stream scheduler progress via Server-Sent Events (SSE).

Expand All @@ -704,38 +736,25 @@ def event_generator():
start = time.time()
try:
while True:
running = mem_scheduler.dispatcher.get_running_tasks()
running = mem_scheduler.dispatcher.get_running_tasks(
lambda task: task.mem_cube_id == user_name
)
running_count = len(running)
elapsed = time.time() - start

# heartbeat frame
heartbeat_payload = {
payload = {
"user_name": user_name,
"running_tasks": running_count,
"elapsed_seconds": round(elapsed, 3),
"status": "running" if running_count > 0 else "idle",
"instance_id": INSTANCE_ID,
}
yield "data: " + json.dumps(heartbeat_payload, ensure_ascii=False) + "\n\n"
yield "data: " + json.dumps(payload, ensure_ascii=False) + "\n\n"

# scheduler is idle -> final frame + break
if running_count == 0:
final_payload = {
"running_tasks": 0,
"elapsed_seconds": round(elapsed, 3),
"status": "idle",
"timed_out": False,
}
yield "data: " + json.dumps(final_payload, ensure_ascii=False) + "\n\n"
break

# timeout -> final frame + break
if elapsed > timeout_seconds:
final_payload = {
"running_tasks": running_count,
"elapsed_seconds": round(elapsed, 3),
"status": "timeout",
"timed_out": True,
}
yield "data: " + json.dumps(final_payload, ensure_ascii=False) + "\n\n"
if running_count == 0 or elapsed > timeout_seconds:
payload["status"] = "idle" if running_count == 0 else "timeout"
payload["timed_out"] = running_count > 0
yield "data: " + json.dumps(payload, ensure_ascii=False) + "\n\n"
break

time.sleep(poll_interval)
Expand All @@ -745,12 +764,9 @@ def event_generator():
"status": "error",
"detail": "stream_failed",
"exception": str(e),
"user_name": user_name,
}
logger.error(
"Failed streaming scheduler wait: %s: %s",
e,
traceback.format_exc(),
)
logger.error(f"Scheduler stream error for {user_name}: {traceback.format_exc()}")
yield "data: " + json.dumps(err_payload, ensure_ascii=False) + "\n\n"

return StreamingResponse(event_generator(), media_type="text/event-stream")
Expand Down
Loading