Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/memos/api/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -872,7 +872,7 @@ def get_scheduler_config() -> dict[str, Any]:
),
"context_window_size": int(os.getenv("MOS_SCHEDULER_CONTEXT_WINDOW_SIZE", "5")),
"thread_pool_max_workers": int(
os.getenv("MOS_SCHEDULER_THREAD_POOL_MAX_WORKERS", "200")
os.getenv("MOS_SCHEDULER_THREAD_POOL_MAX_WORKERS", "50")
),
"consume_interval_seconds": float(
os.getenv("MOS_SCHEDULER_CONSUME_INTERVAL_SECONDS", "0.01")
Expand Down
3 changes: 3 additions & 0 deletions src/memos/mem_scheduler/task_schedule_modules/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,9 @@ def stats(self) -> dict[str, int]:
running = 0
try:
with self._task_lock:
done = {f for f in self._futures if f.done()}
if done:
self._futures -= done
inflight = len(self._futures)
except Exception:
inflight = 0
Expand Down
113 changes: 89 additions & 24 deletions src/memos/mem_scheduler/task_schedule_modules/redis_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,11 @@ def __init__(
self.task_broker_flush_bar = 10
self._refill_lock = threading.Lock()
self._refill_thread: ContextThread | None = None
self._refill_in_progress = False
self._refill_thread_start: float = 0.0
self._refill_thread_timeout: float = float(
os.getenv("MEMSCHEDULER_REDIS_REFILL_TIMEOUT_SEC", "30") or 30
)

# Track empty streams first-seen time to avoid zombie keys
self._empty_stream_seen_times: dict[str, float] = {}
Expand All @@ -110,8 +115,11 @@ def __init__(

self.seen_streams = set()

# Task Orchestrator
self.message_pack_cache = deque()
# Task Orchestrator — cap in-memory cache to avoid unbounded growth
self._cache_max_packs = int(os.getenv("MEMSCHEDULER_REDIS_CACHE_MAX_PACKS", "50") or 50)
self.message_pack_cache: deque[list[ScheduleMessageItem]] = deque(
maxlen=self._cache_max_packs
)

self.orchestrator = SchedulerOrchestrator() if orchestrator is None else orchestrator

Expand Down Expand Up @@ -349,38 +357,78 @@ def task_broker(
def _async_refill_cache(self, batch_size: int) -> None:
"""Background thread to refill message cache without blocking get_messages."""
try:
logger.debug(f"Starting async cache refill with batch_size={batch_size}")
with self._refill_lock:
remaining = self._cache_max_packs - len(self.message_pack_cache)
if remaining <= 0:
logger.debug("Async refill skipped: cache already at capacity")
return
self._refill_in_progress = True

logger.debug(
f"Starting async cache refill with batch_size={batch_size}, remaining_capacity={remaining}"
)
new_packs = self.task_broker(consume_batch_size=batch_size)
logger.debug(f"task_broker returned {len(new_packs)} packs")

with self._refill_lock:
added = 0
for pack in new_packs:
if pack: # Only add non-empty packs
if pack:
self.message_pack_cache.append(pack)
logger.debug(f"Added pack with {len(pack)} messages to cache")
logger.debug(f"Cache refill complete, cache size now: {len(self.message_pack_cache)}")
added += 1
if added >= remaining:
break
logger.debug(
f"Cache refill complete, added={added}, cache size now: {len(self.message_pack_cache)}"
)
except Exception as e:
logger.warning(f"Async cache refill failed: {e}", exc_info=True)
finally:
with self._refill_lock:
self._refill_in_progress = False

def _is_refill_thread_available(self) -> bool:
    """Check whether a new refill thread can be started.

    Returns:
        True when no refill thread exists, when the previous refill thread
        has finished, or when the previous one has been running longer than
        ``self._refill_thread_timeout`` seconds (measured from
        ``self._refill_thread_start``) and is therefore treated as stale.
        False only while a refill thread is alive and within its timeout.
    """
    # No refill thread yet, or the last one has already completed.
    if self._refill_thread is None or not self._refill_thread.is_alive():
        return True
    # A refill thread is still alive: treat it as stale once its runtime
    # exceeds the timeout, so a hung refill cannot block future refills.
    # NOTE(review): the stale thread itself is not stopped here — presumably
    # a replacement is started by the caller; confirm the old thread exits.
    if (time.time() - self._refill_thread_start) > self._refill_thread_timeout:
        logger.warning(
            f"Refill thread has been running for >{self._refill_thread_timeout}s, treating as stale"
        )
        return True
    return False

def get_messages(self, batch_size: int) -> list[ScheduleMessageItem]:
if self.message_pack_cache:
# Trigger async refill if below threshold (non-blocking)
if len(self.message_pack_cache) < self.task_broker_flush_bar and (
self._refill_thread is None or not self._refill_thread.is_alive()
if (
len(self.message_pack_cache) < self.task_broker_flush_bar
and self._is_refill_thread_available()
):
logger.debug(
f"Triggering async cache refill: cache size {len(self.message_pack_cache)} < {self.task_broker_flush_bar}"
)
self._refill_thread = ContextThread(
target=self._async_refill_cache, args=(batch_size,), name="redis-cache-refill"
)
self._refill_thread_start = time.time()
self._refill_thread.start()
else:
logger.debug(f"The size of message_pack_cache is {len(self.message_pack_cache)}")
else:
new_packs = self.task_broker(consume_batch_size=batch_size)
for pack in new_packs:
if pack: # Only add non-empty packs
self.message_pack_cache.append(pack)
should_fetch = False
with self._refill_lock:
if not self.message_pack_cache and not self._refill_in_progress:
self._refill_in_progress = True
should_fetch = True
if should_fetch:
try:
new_packs = self.task_broker(consume_batch_size=batch_size)
with self._refill_lock:
for pack in new_packs:
if pack:
self.message_pack_cache.append(pack)
finally:
with self._refill_lock:
self._refill_in_progress = False
if len(self.message_pack_cache) == 0:
return []
else:
Expand Down Expand Up @@ -443,12 +491,17 @@ def put(
with self._stream_keys_lock:
if stream_key not in self.seen_streams:
self.seen_streams.add(stream_key)
self._ensure_consumer_group(stream_key=stream_key)
need_create_group = True
else:
need_create_group = False

if stream_key not in self._stream_keys_cache:
self._stream_keys_cache.append(stream_key)
self._stream_keys_last_refresh = time.time()

if need_create_group:
self._ensure_consumer_group(stream_key=stream_key)

message.stream_key = stream_key

# Convert message to dictionary for Redis storage
Expand Down Expand Up @@ -1054,14 +1107,9 @@ def get_stream_keys(self, stream_key_prefix: str | None = None) -> list[str]:
with self._stream_keys_lock:
cache_snapshot = list(self._stream_keys_cache)

# Validate that cached keys conform to the expected prefix
escaped_prefix = re.escape(effective_prefix)
regex_pattern = f"^{escaped_prefix}:"
for key in cache_snapshot:
if not re.match(regex_pattern, key):
logger.error(
f"[REDIS_QUEUE] Cached stream key '{key}' does not match prefix '{effective_prefix}:'"
)
if effective_prefix != self.stream_key_prefix:
pattern = re.compile(f"^{re.escape(effective_prefix)}:")
cache_snapshot = [k for k in cache_snapshot if pattern.match(k)]

return cache_snapshot

Expand Down Expand Up @@ -1211,7 +1259,7 @@ def __del__(self):

@property
def unfinished_tasks(self) -> int:
return self.qsize()
return self.size()

def _scan_candidate_stream_keys(
self,
Expand Down Expand Up @@ -1396,6 +1444,23 @@ def _update_stream_cache_with_log(
self._stream_keys_cache = active_stream_keys
self._stream_keys_last_refresh = time.time()
cache_count = len(self._stream_keys_cache)

active_set = set(active_stream_keys)
stale = self.seen_streams - active_set
if stale:
self.seen_streams -= stale
logger.debug(f"Pruned {len(stale)} stale entries from seen_streams")

candidate_set = set(candidate_keys)
with self._empty_stream_seen_lock:
orphaned = [k for k in self._empty_stream_seen_times if k not in candidate_set]
for k in orphaned:
del self._empty_stream_seen_times[k]
if orphaned:
logger.debug(
f"Pruned {len(orphaned)} orphaned entries from _empty_stream_seen_times"
)

logger.debug(
f"Refreshed stream keys cache: {cache_count} active keys, "
f"{deleted_count} deleted, {len(candidate_keys)} candidates examined."
Expand Down
Loading