Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/memos/embedders/universal_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,11 @@ def __init__(self, config: UniversalAPIEmbedderConfig):
else:
raise ValueError(f"Embeddings unsupported provider: {self.provider}")

@timed(log=True, log_prefix="model_timed_embedding")
@timed(
log=True,
log_prefix="model_timed_embedding",
log_extra_args={"model_name_or_path": "text-embedding-3-large"},
)
def embed(self, texts: list[str]) -> list[list[float]]:
if self.provider == "openai" or self.provider == "azure":
try:
Expand Down
2 changes: 1 addition & 1 deletion src/memos/llms/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __init__(self, config: OpenAILLMConfig):
)
logger.info("OpenAI LLM instance initialized")

@timed(log=True, log_prefix="OpenAI LLM")
@timed(log=True, log_prefix="OpenAI LLM", log_args=["model_name_or_path"])
def generate(self, messages: MessageList, **kwargs) -> str:
"""Generate a response from OpenAI LLM, optionally overriding generation params."""
response = self.client.chat.completions.create(
Expand Down
4 changes: 3 additions & 1 deletion src/memos/reranker/http_bge.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,9 @@ def __init__(
self.warn_unknown_filter_keys = bool(warn_unknown_filter_keys)
self._warned_missing_keys: set[str] = set()

@timed(log=True, log_prefix="model_timed_rerank")
@timed(
log=True, log_prefix="model_timed_rerank", log_extra_args={"model_name_or_path": "reranker"}
)
def rerank(
self,
query: str,
Expand Down
44 changes: 36 additions & 8 deletions src/memos/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,48 @@
logger = get_logger(__name__)


def timed(func=None, *, log=True, log_prefix=""):
"""Decorator to measure and optionally log time of retrieval steps.

Can be used as @timed or @timed(log=True)
def timed(func=None, *, log=True, log_prefix="", log_args=None, log_extra_args=None):
    """Decorator that measures wall-clock time of a call and logs it.

    Can be used bare (``@timed``) or with arguments (``@timed(log_prefix=...)``).

    Parameters:
    - log: enable timing logs (default True; any truthy value enables).
    - log_prefix: label for the log line; falls back to the function name.
    - log_args: names to include in logs (str or list/tuple of str).
      Value priority: kwargs -> args[0].config.<name> (if available).
      Non-string items are ignored.
    - log_extra_args: static key/value pairs appended to the log line.

    Examples:
    - @timed(log=True, log_prefix="OpenAI LLM", log_args=["model_name_or_path", "temperature"])
    - @timed(log=True, log_prefix="OpenAI LLM", log_args=["temperature"])
    - @timed()  # defaults
    """
    # Local import keeps the visible block self-contained; wraps preserves
    # the decorated function's __name__/__doc__ for introspection and logs.
    from functools import wraps

    def decorator(fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            start = time.perf_counter()
            result = fn(*args, **kwargs)
            elapsed_ms = (time.perf_counter() - start) * 1000.0

            # Truthiness (not identity) so log=1 etc. still enables logging.
            if not log:
                return result

            ctx_parts = []
            if log_args:
                # Accept a single name as a plain string, per the docstring.
                names = [log_args] if isinstance(log_args, str) else log_args
                for key in names:
                    if not isinstance(key, str):
                        continue  # non-string items are ignored
                    if key in kwargs:
                        val = kwargs[key]
                    else:
                        # Fallback: args[0].config.<key> — typical for methods
                        # whose self carries a config object (e.g. self.config).
                        cfg = getattr(args[0], "config", None) if args else None
                        val = getattr(cfg, key, None)
                    ctx_parts.append(f"{key}={val}")

            if log_extra_args:
                ctx_parts.extend(f"{key}={val}" for key, val in log_extra_args.items())

            # Only emit the args suffix when there is context to show.
            ctx_str = f" [{', '.join(ctx_parts)}]" if ctx_parts else ""
            logger.info(
                f"[TIMER] {log_prefix or fn.__name__} took {elapsed_ms:.0f} ms, args:{ctx_str}"
            )

            return result

        return wrapper
Expand Down
Loading