Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/memos/api/routers/server_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,6 +653,11 @@ def scheduler_status(user_name: str | None = None):
cube = getattr(task, "mem_cube_id", "unknown")
task_count_per_user[cube] = task_count_per_user.get(cube, 0) + 1

try:
metrics_snapshot = mem_scheduler.dispatcher.metrics.snapshot()
except Exception:
metrics_snapshot = {}

return {
"message": "ok",
"data": {
Expand All @@ -661,6 +666,7 @@ def scheduler_status(user_name: str | None = None):
"task_count_per_user": task_count_per_user,
"timestamp": time.time(),
"instance_id": INSTANCE_ID,
"metrics": metrics_snapshot,
},
}
except Exception as err:
Expand Down
111 changes: 7 additions & 104 deletions src/memos/mem_scheduler/base_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@
from memos.memories.activation.kv import KVCacheMemory
from memos.memories.activation.vllmkv import VLLMKVCacheItem, VLLMKVCacheMemory
from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory
from memos.memos_tools.notification_utils import send_online_bot_notification
from memos.templates.mem_scheduler_prompts import MEMORY_ASSEMBLY_TEMPLATE


Expand Down Expand Up @@ -127,21 +126,6 @@ def __init__(self, config: BaseSchedulerConfig):
"consume_interval_seconds", DEFAULT_CONSUME_INTERVAL_SECONDS
)

# queue monitor (optional)
self._queue_monitor_thread: threading.Thread | None = None
self._queue_monitor_running: bool = False
self.queue_monitor_interval_seconds: float = self.config.get(
"queue_monitor_interval_seconds", 60.0
)
self.queue_monitor_warn_utilization: float = self.config.get(
"queue_monitor_warn_utilization", 0.7
)
self.queue_monitor_crit_utilization: float = self.config.get(
"queue_monitor_crit_utilization", 0.9
)
self.enable_queue_monitor: bool = self.config.get("enable_queue_monitor", False)
self._online_bot_callable = None # type: ignore[var-annotated]

# other attributes
self._context_lock = threading.Lock()
self.current_user_id: UserID | str | None = None
Expand Down Expand Up @@ -541,6 +525,10 @@ def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageIt
logger.error(error_msg)
raise TypeError(error_msg)

if getattr(message, "timestamp", None) is None:
with contextlib.suppress(Exception):
message.timestamp = datetime.utcnow()

if self.disable_handlers and message.label in self.disable_handlers:
logger.info(f"Skipping disabled handler: {message.label} - {message.content}")
continue
Expand All @@ -555,6 +543,9 @@ def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageIt
logger.info(
f"Submitted message to local queue: {message.label} - {message.content}"
)
with contextlib.suppress(Exception):
if messages:
self.dispatcher.on_messages_enqueued(messages)

def _submit_web_logs(
self, messages: ScheduleLogForWebItem | list[ScheduleLogForWebItem]
Expand Down Expand Up @@ -706,13 +697,6 @@ def start(self) -> None:
self._consumer_thread.start()
logger.info("Message consumer thread started")

# optionally start queue monitor if enabled and bot callable present
if self.enable_queue_monitor and self._online_bot_callable is not None:
try:
self.start_queue_monitor(self._online_bot_callable)
except Exception as e:
logger.warning(f"Failed to start queue monitor: {e}")

def stop(self) -> None:
"""Stop all scheduler components gracefully.

Expand Down Expand Up @@ -762,9 +746,6 @@ def stop(self) -> None:
self._cleanup_queues()
logger.info("Memory Scheduler stopped completely")

# Stop queue monitor
self.stop_queue_monitor()

@property
def handlers(self) -> dict[str, Callable]:
"""
Expand Down Expand Up @@ -997,16 +978,6 @@ def _fmt_eta(seconds: float | None) -> str:

return True

# ---------------- Queue monitor & notifications ----------------
def set_notification_bots(self, online_bot=None):
    """Register the external notification callable used for queue reports.

    Args:
        online_bot: callable with the ``dinding_report_bot.online_bot``
            signature, or ``None`` to clear the registration.
    """
    self._online_bot_callable = online_bot

def _gather_queue_stats(self) -> dict:
"""Collect queue/dispatcher stats for reporting."""
stats: dict[str, int | float | str] = {}
Expand Down Expand Up @@ -1044,71 +1015,3 @@ def _gather_queue_stats(self) -> dict:
except Exception:
stats.update({"running": 0, "inflight": 0, "handlers": 0})
return stats

def _queue_monitor_loop(self, online_bot) -> None:
    """Background loop that periodically reports queue/dispatcher health.

    Sleeps ``queue_monitor_interval_seconds`` between iterations, collects
    stats via ``_gather_queue_stats`` and forwards a formatted report to
    *online_bot* through ``send_online_bot_notification``. Runs until
    ``self._queue_monitor_running`` is cleared (see ``stop_queue_monitor``).

    Args:
        online_bot: notification callable passed straight through to
            ``send_online_bot_notification``.
    """
    logger.info(f"Queue monitor started (interval={self.queue_monitor_interval_seconds}s)")
    self._queue_monitor_running = True
    while self._queue_monitor_running:
        # Sleep first, so a freshly started monitor waits a full interval
        # before its first report.
        time.sleep(self.queue_monitor_interval_seconds)
        try:
            stats = self._gather_queue_stats()
            # decide severity based on utilization if local queue
            title_color = "#00956D"  # healthy default
            subtitle = "Scheduler"
            if not stats.get("use_redis_queue"):
                util = float(stats.get("utilization", 0.0))
                # Severity thresholds come from config (crit checked first).
                if util >= self.queue_monitor_crit_utilization:
                    title_color = "#C62828"  # red
                    subtitle = "Scheduler (CRITICAL)"
                elif util >= self.queue_monitor_warn_utilization:
                    title_color = "#E65100"  # orange
                    subtitle = "Scheduler (WARNING)"

            # Runtime/dispatcher figures, reported in every mode.
            other_data1 = {
                "use_redis_queue": stats.get("use_redis_queue"),
                "handlers": stats.get("handlers"),
                "running": stats.get("running"),
                "inflight": stats.get("inflight"),
            }
            # Queue-depth details only exist for the local (non-Redis) queue.
            if not stats.get("use_redis_queue"):
                other_data2 = {
                    "qsize": stats.get("qsize"),
                    "unfinished_tasks": stats.get("unfinished_tasks"),
                    "maxsize": stats.get("maxsize"),
                    "utilization": f"{float(stats.get('utilization', 0.0)):.2%}",
                }
            else:
                other_data2 = {
                    "redis_mode": True,
                }

            send_online_bot_notification(
                online_bot=online_bot,
                header_name="Scheduler Queue",
                sub_title_name=subtitle,
                title_color=title_color,
                other_data1=other_data1,
                other_data2=other_data2,
                emoji={"Runtime": "🧠", "Queue": "📬"},
            )
        except Exception as e:
            # Best-effort: a failed report must not kill the monitor thread.
            logger.warning(f"Queue monitor iteration failed: {e}")
    logger.info("Queue monitor stopped")

def start_queue_monitor(self, online_bot) -> None:
    """Spawn the daemon queue-monitor thread; no-op if one is already alive."""
    existing = self._queue_monitor_thread
    if existing is not None and existing.is_alive():
        return
    self._online_bot_callable = online_bot
    monitor = threading.Thread(
        target=self._queue_monitor_loop,
        args=(online_bot,),
        daemon=True,
        name="QueueMonitorThread",
    )
    self._queue_monitor_thread = monitor
    monitor.start()

def stop_queue_monitor(self) -> None:
    """Signal the monitor loop to exit and briefly wait for its thread."""
    self._queue_monitor_running = False
    thread = self._queue_monitor_thread
    if thread is not None and thread.is_alive():
        # Bounded join; suppress any join-time error so shutdown never raises.
        with contextlib.suppress(Exception):
            thread.join(timeout=2.0)
47 changes: 47 additions & 0 deletions src/memos/mem_scheduler/general_modules/dispatcher.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import concurrent
import threading
import time

from collections import defaultdict
from collections.abc import Callable
from datetime import timezone
from typing import Any

from memos.context.context import ContextThreadPoolExecutor
Expand All @@ -11,6 +13,7 @@
from memos.mem_scheduler.general_modules.task_threads import ThreadManager
from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
from memos.mem_scheduler.schemas.task_schemas import RunningTaskItem
from memos.mem_scheduler.utils.metrics import MetricsRegistry


logger = get_logger(__name__)
Expand Down Expand Up @@ -70,6 +73,19 @@ def __init__(self, max_workers=30, enable_parallel_dispatch=True, config=None):
self._completed_tasks = []
self.completed_tasks_max_show_size = 10

self.metrics = MetricsRegistry(
topk_per_label=(self.config or {}).get("metrics_topk_per_label", 50)
)

def on_messages_enqueued(self, msgs: list[ScheduleMessageItem]) -> None:
    """Record one enqueue event per message in the metrics registry.

    All messages in the batch share a single ``time.time()`` stamp; an
    empty batch is a no-op.
    """
    if not msgs:
        return
    enqueue_ts = time.time()
    for message in msgs:
        self.metrics.on_enqueue(
            label=message.label,
            mem_cube_id=message.mem_cube_id,
            inst_rate=1.0,
            now=enqueue_ts,
        )

def _create_task_wrapper(self, handler: Callable, task_item: RunningTaskItem):
"""
Create a wrapper around the handler to track task execution and capture results.
Expand All @@ -84,9 +100,37 @@ def _create_task_wrapper(self, handler: Callable, task_item: RunningTaskItem):

def wrapped_handler(messages: list[ScheduleMessageItem]):
try:
# --- mark start: record queuing time(now - enqueue_ts)---
now = time.time()
for m in messages:
enq_ts = getattr(m, "timestamp", None)

# Path 1: epoch seconds (preferred)
if isinstance(enq_ts, int | float):
enq_epoch = float(enq_ts)

# Path 2: datetime -> normalize to UTC epoch
elif hasattr(enq_ts, "timestamp"):
dt = enq_ts
if dt.tzinfo is None:
# treat naive as UTC to neutralize +8h skew
dt = dt.replace(tzinfo=timezone.utc)
enq_epoch = dt.timestamp()
else:
# fallback: treat as "just now"
enq_epoch = now

wait_sec = max(0.0, now - enq_epoch)
self.metrics.on_start(
label=m.label, mem_cube_id=m.mem_cube_id, wait_sec=wait_sec, now=now
)

# Execute the original handler
result = handler(messages)

# --- mark done ---
for m in messages:
self.metrics.on_done(label=m.label, mem_cube_id=m.mem_cube_id, now=time.time())
# Mark task as completed and remove from tracking
with self._task_lock:
if task_item.item_id in self._running_tasks:
Expand All @@ -100,6 +144,9 @@ def wrapped_handler(messages: list[ScheduleMessageItem]):

except Exception as e:
# Mark task as failed and remove from tracking
for m in messages:
self.metrics.on_done(label=m.label, mem_cube_id=m.mem_cube_id, now=time.time())
# Mark task as failed and remove from tracking
with self._task_lock:
if task_item.item_id in self._running_tasks:
task_item.mark_failed(str(e))
Expand Down
Loading