diff --git a/docker/.env.example b/docker/.env.example index 0f4fcb65d..037eb8db8 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -1,60 +1,174 @@ -# MemOS Environment Variables Configuration -TZ=Asia/Shanghai +# MemOS Environment Variables (core runtime) +# Legend: [required] needed for default startup; others are optional or conditional per comments. -MOS_CUBE_PATH="/tmp/data_test" # Path to memory storage (e.g. /tmp/data_test) -MOS_ENABLE_DEFAULT_CUBE_CONFIG="true" # Enable default cube config (true/false) +## Base +TZ=Asia/Shanghai +ENV_NAME=PLAYGROUND_OFFLINE # Tag shown in DingTalk notifications (e.g., PROD_ONLINE/TEST); no runtime effect unless ENABLE_DINGDING_BOT=true +MOS_CUBE_PATH=/tmp/data_test # local data path +MEMOS_BASE_PATH=. # CLI/SDK cache path +MOS_ENABLE_DEFAULT_CUBE_CONFIG=true # enable default cube config +MOS_ENABLE_REORGANIZE=false # enable memory reorg +MOS_TEXT_MEM_TYPE=general_text # general_text | tree_text +ASYNC_MODE=sync # async/sync, used in default cube config -# OpenAI Configuration -OPENAI_API_KEY="sk-xxx" # Your OpenAI API key -OPENAI_API_BASE="http://xxx" # OpenAI API base URL (default: https://api.openai.com/v1) +## User/session defaults +MOS_USER_ID=root +MOS_SESSION_ID=default_session +MOS_MAX_TURNS_WINDOW=20 +MOS_TOP_K=50 -# MemOS Chat Model Configuration +## Chat LLM (main dialogue) MOS_CHAT_MODEL=gpt-4o-mini MOS_CHAT_TEMPERATURE=0.8 MOS_MAX_TOKENS=8000 MOS_TOP_P=0.9 -MOS_TOP_K=50 -MOS_CHAT_MODEL_PROVIDER=openai - -# graph db -# neo4j -NEO4J_BACKEND=xxx -NEO4J_URI=bolt://xxx -NEO4J_USER=xxx -NEO4J_PASSWORD=xxx -MOS_NEO4J_SHARED_DB=xxx -NEO4J_DB_NAME=xxx - -# tetxmem reog -MOS_ENABLE_REORGANIZE=false - -# MemOS User Configuration -MOS_USER_ID=root -MOS_SESSION_ID=default_session -MOS_MAX_TURNS_WINDOW=20 +MOS_CHAT_MODEL_PROVIDER=openai # openai | huggingface | vllm +MOS_MODEL_SCHEMA=memos.configs.llm.VLLMLLMConfig # vllm only: config class path; keep default unless you extend it +OPENAI_API_KEY=sk-xxx # [required] when provider=openai +OPENAI_API_BASE=https://api.openai.com/v1 # [required] base for the key +OPENAI_BASE_URL= # compatibility for eval/scheduler +VLLM_API_KEY= # required when provider=vllm +VLLM_API_BASE=http://localhost:8088/v1 # required when provider=vllm -# MemRader Configuration +## MemReader / retrieval LLM MEMRADER_MODEL=gpt-4o-mini -MEMRADER_API_KEY=sk-xxx -MEMRADER_API_BASE=http://xxx:3000/v1 +MEMRADER_API_KEY=sk-xxx # [required] can reuse OPENAI_API_KEY +MEMRADER_API_BASE=http://localhost:3000/v1 # [required] base for the key MEMRADER_MAX_TOKENS=5000 -#embedding & rerank +## Embedding & rerank EMBEDDING_DIMENSION=1024 -MOS_EMBEDDER_BACKEND=universal_api -MOS_EMBEDDER_MODEL=bge-m3 -MOS_EMBEDDER_API_BASE=http://xxx -MOS_EMBEDDER_API_KEY=EMPTY -MOS_RERANKER_BACKEND=http_bge -MOS_RERANKER_URL=http://xxx -# Ollama Configuration (for embeddings) -#OLLAMA_API_BASE=http://xxx - -# milvus for pref mem -MILVUS_URI=http://xxx -MILVUS_USER_NAME=xxx -MILVUS_PASSWORD=xxx - -# pref mem +MOS_EMBEDDER_BACKEND=universal_api # universal_api | ollama +MOS_EMBEDDER_PROVIDER=openai # required when universal_api +MOS_EMBEDDER_MODEL=bge-m3 # siliconflow → use BAAI/bge-m3 +MOS_EMBEDDER_API_BASE=http://localhost:8000/v1 # required when universal_api +MOS_EMBEDDER_API_KEY=EMPTY # required when universal_api +OLLAMA_API_BASE=http://localhost:11434 # required when backend=ollama +MOS_RERANKER_BACKEND=http_bge # http_bge | http_bge_strategy | cosine_local +MOS_RERANKER_URL=http://localhost:8001 # required when backend=http_bge* +MOS_RERANKER_MODEL=bge-reranker-v2-m3 # siliconflow → use BAAI/bge-reranker-v2-m3 +MOS_RERANKER_HEADERS_EXTRA= # extra headers, JSON string +MOS_RERANKER_STRATEGY=single_turn +MOS_RERANK_SOURCE= # optional rerank scope, e.g., history/stream/custom + +## Internet search & preference memory +ENABLE_INTERNET=false +BOCHA_API_KEY= # required if ENABLE_INTERNET=true +SEARCH_MODE=fast # fast | fine | mixture +FAST_GRAPH=false +BM25_CALL=false +VEC_COT_CALL=false +FINE_STRATEGY=rewrite # rewrite | recreate | deep_search +ENABLE_ACTIVATION_MEMORY=false ENABLE_PREFERENCE_MEMORY=true -RETURN_ORIGINAL_PREF_MEM=true +PREFERENCE_ADDER_MODE=fast # fast | safe +DEDUP_PREF_EXP_BY_TEXTUAL=false + +## Reader chunking +MEM_READER_BACKEND=simple_struct # simple_struct | strategy_struct +MEM_READER_CHAT_CHUNK_TYPE=default # default | content_length +MEM_READER_CHAT_CHUNK_TOKEN_SIZE=1600 # tokens per chunk (default mode) +MEM_READER_CHAT_CHUNK_SESS_SIZE=10 # sessions per chunk (default mode) +MEM_READER_CHAT_CHUNK_OVERLAP=2 # overlap between chunks + +## Scheduler (MemScheduler / API) +MOS_ENABLE_SCHEDULER=false +MOS_SCHEDULER_TOP_K=10 +MOS_SCHEDULER_ACT_MEM_UPDATE_INTERVAL=300 +MOS_SCHEDULER_CONTEXT_WINDOW_SIZE=5 +MOS_SCHEDULER_THREAD_POOL_MAX_WORKERS=10000 +MOS_SCHEDULER_CONSUME_INTERVAL_SECONDS=0.01 +MOS_SCHEDULER_ENABLE_PARALLEL_DISPATCH=true +MOS_SCHEDULER_ENABLE_ACTIVATION_MEMORY=false +API_SCHEDULER_ON=true +API_SEARCH_WINDOW_SIZE=5 +API_SEARCH_HISTORY_TURNS=5 + +## Graph / vector stores +NEO4J_BACKEND=neo4j-community # neo4j-community | neo4j | nebular | polardb +NEO4J_URI=bolt://localhost:7687 # required when backend=neo4j* +NEO4J_USER=neo4j # required when backend=neo4j* +NEO4J_PASSWORD=12345678 # required when backend=neo4j* +NEO4J_DB_NAME=neo4j # required for shared-db mode +MOS_NEO4J_SHARED_DB=false +QDRANT_HOST=localhost +QDRANT_PORT=6333 +MILVUS_URI=http://localhost:19530 # required when ENABLE_PREFERENCE_MEMORY=true +MILVUS_USER_NAME=root # same as above +MILVUS_PASSWORD=12345678 # same as above +NEBULAR_HOSTS=["localhost"] +NEBULAR_USER=root +NEBULAR_PASSWORD=xxxxxx +NEBULAR_SPACE=shared-tree-textual-memory +NEBULAR_WORKING_MEMORY=20 +NEBULAR_LONGTERM_MEMORY=1000000 +NEBULAR_USER_MEMORY=1000000 + +## Relational DB (user manager / PolarDB) +MOS_USER_MANAGER_BACKEND=sqlite # sqlite | mysql +MYSQL_HOST=localhost # required when backend=mysql +MYSQL_PORT=3306 +MYSQL_USERNAME=root +MYSQL_PASSWORD=12345678 +MYSQL_DATABASE=memos_users +MYSQL_CHARSET=utf8mb4 +POLAR_DB_HOST=localhost +POLAR_DB_PORT=5432 +POLAR_DB_USER=root +POLAR_DB_PASSWORD=123456 +POLAR_DB_DB_NAME=shared_memos_db +POLAR_DB_USE_MULTI_DB=false + +## Redis (scheduler queue) — fill only if you want scheduler queues in Redis; otherwise in-memory queue is used +REDIS_HOST=localhost # global Redis endpoint (preferred over MEMSCHEDULER_*) +REDIS_PORT=6379 +REDIS_DB=0 +REDIS_PASSWORD= +REDIS_SOCKET_TIMEOUT= +REDIS_SOCKET_CONNECT_TIMEOUT= +MEMSCHEDULER_REDIS_HOST= # fallback keys if not using the global ones +MEMSCHEDULER_REDIS_PORT= +MEMSCHEDULER_REDIS_DB= +MEMSCHEDULER_REDIS_PASSWORD= +MEMSCHEDULER_REDIS_TIMEOUT= +MEMSCHEDULER_REDIS_CONNECT_TIMEOUT= + +## MemScheduler LLM +MEMSCHEDULER_OPENAI_API_KEY= # LLM key for scheduler’s own calls (OpenAI-compatible); leave empty if scheduler not using LLM +MEMSCHEDULER_OPENAI_BASE_URL= # Base URL for the above; can reuse OPENAI_API_BASE +MEMSCHEDULER_OPENAI_DEFAULT_MODEL=gpt-4o-mini + +## Nacos (optional config center) +NACOS_ENABLE_WATCH=false +NACOS_WATCH_INTERVAL=60 +NACOS_SERVER_ADDR= +NACOS_DATA_ID= +NACOS_GROUP=DEFAULT_GROUP +NACOS_NAMESPACE= +AK= +SK= + +## DingTalk bot & OSS upload +ENABLE_DINGDING_BOT=false # set true -> fields below required +DINGDING_ACCESS_TOKEN_USER= +DINGDING_SECRET_USER= +DINGDING_ACCESS_TOKEN_ERROR= +DINGDING_SECRET_ERROR= +DINGDING_ROBOT_CODE= +DINGDING_APP_KEY= +DINGDING_APP_SECRET= +OSS_ENDPOINT= # bot image upload depends on OSS +OSS_REGION= +OSS_BUCKET_NAME= +OSS_ACCESS_KEY_ID= +OSS_ACCESS_KEY_SECRET= +OSS_PUBLIC_BASE_URL= + +## Logging / external sink +CUSTOM_LOGGER_URL= +CUSTOM_LOGGER_TOKEN= +CUSTOM_LOGGER_WORKERS=2 + +## SDK / external client +MEMOS_API_KEY= +MEMOS_BASE_URL=https://memos.memtensor.cn/api/openmem/v1 diff --git a/examples/api/server_router_api.py b/examples/api/server_router_api.py new file mode 100644 index 000000000..6a94fc7bc --- /dev/null +++ b/examples/api/server_router_api.py @@ -0,0 +1,644 @@ +#!/usr/bin/env python3 +""" +MemOS Product API: /product/add end-to-end examples. + +This script demonstrates how to call the MemOS Product Add API +(`/product/add`, mapped to `APIADDRequest`) with ALL supported +message shapes and key options, including: + +1. Minimal string message (backward-compatible) +2. Standard chat messages (system/user/assistant) +3. Assistant messages with tool_calls +4. Raw tool messages: tool_description / tool_input / tool_output +5. Multimodal messages: text + image, text + file, audio-only +6. Pure input items without dialog context: text/file +7. Mixed multimodal message with text + file + image +8. Deprecated fields: mem_cube_id, memory_content, doc_path, source +9. Async vs sync + fast/fine add pipeline +10. Feedback add (is_feedback) +11. Add with chat_history only + +Each example sends a real POST request to `/product/add`. + +NOTE: +- This script assumes your MemOS server is running and router is mounted at `/product`. +- You may need to adjust BASE_URL, USER_ID, MEM_CUBE_ID to fit your environment. +""" + +import json + +import requests + + +# --------------------------------------------------------------------------- +# Global config +# --------------------------------------------------------------------------- + +BASE_URL = "http://0.0.0.0:8001/product" +HEADERS = {"Content-Type": "application/json"} + +# You can change these identifiers if your backend requires pre-registered users/cubes. +USER_ID = "demo_add_user_001" +MEM_CUBE_ID = "demo_add_cube_001" +SESSION_ID = "demo_add_session_001" + + +def call_add_api(name: str, payload: dict): + """ + Generic helper to call /product/add and print the payload + response. + + Args: + name: Logical name of this example, printed in logs. + payload: JSON payload compatible with APIADDRequest. + """ + print("=" * 80) + print(f"[*] Example: {name}") + print("- Payload:") + print(json.dumps(payload, indent=2, ensure_ascii=False)) + + try: + resp = requests.post( + f"{BASE_URL}/add", headers=HEADERS, data=json.dumps(payload), timeout=60 + ) + except Exception as e: + print(f"- Request failed with exception: {e!r}") + print("=" * 80) + print() + return + + print("- Response:") + print(resp.status_code, resp.text) + print("=" * 80) + print() + + +# =========================================================================== +# 1. Minimal / backward-compatible examples +# =========================================================================== + + +def example_01_string_message_minimal(): + """ + Minimal example using `messages` as a pure string (MessagesType = str). + + - This is the most backward-compatible form. + - Internally the server will convert this into a text message. + - Async add is used by default (`async_mode` defaults to "async"). + """ + payload = { + "user_id": USER_ID, + "writable_cube_ids": [MEM_CUBE_ID], + "messages": "今天心情不错,喝了咖啡。", + } + call_add_api("01_string_message_minimal", payload) + + +def example_02_standard_chat_triplet(): + """ + Standard chat conversation: system + user + assistant. + + - `messages` is a list of role-based chat messages (MessageList). + - Uses system context + explicit timestamps and message_id. + - This is recommended when you already have structured dialog. + """ + payload = { + "user_id": USER_ID, + "writable_cube_ids": [MEM_CUBE_ID], + "session_id": SESSION_ID, + "messages": [ + { + "role": "system", + "content": [ + { + "type": "text", + "text": "You are a helpful travel assistant.", + } + ], + "chat_time": "2025-11-24T10:00:00Z", + "message_id": "sys-1", + }, + { + "role": "user", + "content": "我喜欢干净但不奢华的酒店,比如全季或者亚朵。", + "chat_time": "2025-11-24T10:00:10Z", + "message_id": "u-1", + }, + { + "role": "assistant", + "content": "好的,我会优先推荐中端连锁酒店,例如全季、亚朵。", + "chat_time": "2025-11-24T10:00:15Z", + "message_id": "a-1", + }, + ], + "custom_tags": ["travel", "hotel_preference"], + "info": { + "agent_id": "demo_agent", + "app_id": "demo_app", + "source_type": "chat", + "source_url": "https://example.com/dialog/standard", + }, + } + call_add_api("02_standard_chat_triplet", payload) + + +# =========================================================================== +# 2. Tool / function-calling related examples +# =========================================================================== + + +def example_03_assistant_with_tool_calls(): + """ + Assistant message containing tool_calls (function calls). + + - `role = assistant`, `content = None`. + - `tool_calls` contains a list of function calls with arguments. + - This matches OpenAI-style function calling structure. + """ + payload = { + "user_id": USER_ID, + "writable_cube_ids": [MEM_CUBE_ID], + "messages": [ + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "tool-call-weather-1", + "type": "function", + "function": { + "name": "get_weather", + "arguments": '{"location": "北京"}', + }, + } + ], + "chat_time": "2025-11-24T10:12:00Z", + "message_id": "assistant-with-call-1", + } + ], + } + call_add_api("03_assistant_with_tool_calls", payload) + + +# =========================================================================== +# 4. MultiModel messages +# =========================================================================== + + +def example_04_extreme_multimodal_single_message(): + """ + Extreme multimodal message: + text + image_url + file in one message, and another message with text + file. + + Note: This demonstrates multiple multimodal messages in a single request. + """ + payload = { + "user_id": USER_ID, + "writable_cube_ids": [MEM_CUBE_ID], + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "请分析下面这些信息:"}, + {"type": "image_url", "image_url": {"url": "https://example.com/x.png"}}, + {"type": "file", "file": {"file_id": "f1", "filename": "xx.pdf"}}, + ], + "chat_time": "2025-11-24T10:55:00Z", + "message_id": "mix-mm-1", + }, + { + "role": "user", + "content": [ + {"type": "text", "text": "请再分析一下下面这些信息:"}, + {"type": "file", "file": {"file_id": "f1", "filename": "xx.pdf"}}, + ], + "chat_time": "2025-11-24T10:55:10Z", + "message_id": "mix-mm-2", + }, + ], + "info": {"source_type": "extreme_multimodal"}, + } + call_add_api("04_extreme_multimodal_single_message", payload) + + +# =========================================================================== +# 3. Multimodal messages +# =========================================================================== + + +def example_05_multimodal_text_and_image(): + """ + Multimodal user message: text + image_url. + + - `content` is a list of content parts. + - Each part can be text/image_url/... etc. + """ + payload = { + "user_id": USER_ID, + "writable_cube_ids": [MEM_CUBE_ID], + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "帮我看看这张图片大概是什么内容?", + }, + { + "type": "image_url", + "image_url": { + "url": "https://example.com/mountain_lake.jpg", + "detail": "high", + }, + }, + ], + "chat_time": "2025-11-24T10:20:00Z", + "message_id": "mm-img-1", + } + ], + "info": {"source_type": "image_analysis"}, + } + call_add_api("05_multimodal_text_and_image", payload) + + +def example_06_multimodal_text_and_file(): + """ + Multimodal user message: text + file (file_id based). + + - Uses `file_id` when the file has already been uploaded. + - Note: According to FileFile type definition (TypedDict, total=False), + all fields (`file_id`, `file_data`, `filename`) are optional. + However, in practice, you typically need at least `file_id` OR `file_data` + to specify the file location. + """ + payload = { + "user_id": USER_ID, + "writable_cube_ids": [MEM_CUBE_ID], + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "请阅读这个PDF,总结里面的要点。", + }, + { + "type": "file", + "file": { + "file_id": "file_123", + "filename": "report.pdf", # optional, but recommended + }, + }, + ], + "chat_time": "2025-11-24T10:21:00Z", + "message_id": "mm-file-1", + } + ], + "info": {"source_type": "file_summary"}, + } + call_add_api("06_multimodal_text_and_file", payload) + + +def example_07_audio_only_message(): + """ + Audio-only user message. + + - `content` contains only an input_audio item. + - `data` is assumed to be base64 encoded audio content. + """ + payload = { + "user_id": USER_ID, + "writable_cube_ids": [MEM_CUBE_ID], + "messages": [ + { + "role": "user", + "content": [ + { + "type": "input_audio", + "input_audio": { + "data": "base64_encoded_audio_here", + "format": "mp3", + }, + } + ], + "chat_time": "2025-11-24T10:22:00Z", + "message_id": "audio-1", + } + ], + "info": {"source_type": "voice_note"}, + } + call_add_api("07_audio_only_message", payload) + + +# =========================================================================== +# 4. Pure input items without dialog context +# =========================================================================== + + +def example_08_pure_text_input_items(): + """ + Pure text input items without dialog context. + + - This shape is used when there is no explicit dialog. + - `messages` is a list of raw input items, not role-based messages. + """ + payload = { + "user_id": USER_ID, + "writable_cube_ids": [MEM_CUBE_ID], + "messages": [ + { + "type": "text", + "text": "这是一段独立的文本输入,没有明确的对话上下文。", + }, + { + "type": "text", + "text": "它依然会被抽取和写入明文记忆。", + }, + ], + "info": {"source_type": "batch_import"}, + } + call_add_api("08_pure_text_input_items", payload) + + +def example_09_pure_file_input_by_file_id(): + """ + Pure file input item using file_id (standard format). + + - Uses `file_id` when the file has already been uploaded. + - Note: All FileFile fields are optional (TypedDict, total=False): + * `file_id`: optional, use when file is already uploaded + * `file_data`: optional, use for base64-encoded content + * `filename`: optional, but recommended for clarity + - In practice, you need at least `file_id` OR `file_data` to specify the file. + """ + payload = { + "user_id": USER_ID, + "writable_cube_ids": [MEM_CUBE_ID], + "messages": [ + { + "type": "file", + "file": { + "file_id": "file_uploaded_123", # at least one of file_id/file_data needed + "filename": "document.pdf", # optional + }, + } + ], + "info": {"source_type": "file_ingestion"}, + } + call_add_api("09_pure_file_input_by_file_id", payload) + + +def example_09b_pure_file_input_by_file_data(): + """ + Pure file input item using file_data (base64 encoded). + + - Uses `file_data` with base64-encoded file content. + - This is the standard format for direct file input without uploading first. + - Note: `file_data` is optional in type definition, but required here + since we're not using `file_id`. At least one of `file_id` or `file_data` + should be provided in practice. + """ + payload = { + "user_id": USER_ID, + "writable_cube_ids": [MEM_CUBE_ID], + "messages": [ + { + "type": "file", + "file": { + "file_data": "base64_encoded_file_content_here", # at least one of file_id/file_data needed + "filename": "document.pdf", # optional + }, + } + ], + "info": {"source_type": "file_ingestion_base64"}, + } + call_add_api("09b_pure_file_input_by_file_data", payload) + + +def example_10_mixed_text_file_image(): + """ + Mixed multimodal message: text + file + image in a single user message. + + - This is the most general form of `content` as a list of content parts. + """ + payload = { + "user_id": USER_ID, + "writable_cube_ids": [MEM_CUBE_ID], + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "请同时分析这个报告和图表。", + }, + { + "type": "file", + "file": { + "file_id": "file_789", + "filename": "analysis_report.pdf", + }, + }, + { + "type": "image_url", + "image_url": { + "url": "https://example.com/chart.png", + "detail": "auto", + }, + }, + ], + "chat_time": "2025-11-24T10:23:00Z", + "message_id": "mixed-1", + } + ], + "info": {"source_type": "report_plus_chart"}, + } + call_add_api("10_mixed_text_file_image", payload) + + +# =========================================================================== +# 5. Deprecated fields: mem_cube_id, memory_content, doc_path, source +# =========================================================================== + + +def example_11_deprecated_memory_content_and_doc_path(): + """ + Use only deprecated fields to demonstrate the conversion logic: + + - `mem_cube_id`: will be converted to `writable_cube_ids` if missing. + - `memory_content`: will be converted into a text message and appended to `messages`. + - `doc_path`: will be converted into a file input item and appended to `messages`. + - `source`: will be moved into `info['source']` if not already set. + + This example intentionally omits `writable_cube_ids` and `messages`, + so that the @model_validator in APIADDRequest does all the work. + """ + payload = { + "user_id": USER_ID, + "mem_cube_id": MEM_CUBE_ID, # deprecated + "memory_content": "这是通过 memory_content 写入的老字段内容。", # deprecated + "doc_path": "/path/to/legacy.docx", # deprecated + "source": "legacy_source_tag", # deprecated + "session_id": "session_deprecated_1", + "async_mode": "async", + } + call_add_api("11_deprecated_memory_content_and_doc_path", payload) + + +# =========================================================================== +# 6. Async vs Sync, fast/fine modes +# =========================================================================== + + +def example_12_async_default_pipeline(): + """ + Default async add pipeline. + + - `async_mode` is omitted, so it defaults to "async". + - `mode` is ignored in async mode even if set (we keep it None here). + - This is the recommended pattern for most production traffic. + """ + payload = { + "user_id": USER_ID, + "writable_cube_ids": [MEM_CUBE_ID], + "session_id": "session_async_default", + "messages": "今天我在测试异步添加记忆。", + "custom_tags": ["async", "default"], + "info": {"source_type": "chat"}, + } + call_add_api("12_async_default_pipeline", payload) + + +def example_13_sync_fast_pipeline(): + """ + Sync add with fast pipeline. + + - `async_mode = "sync"`, `mode = "fast"`. + - This is suitable for high-throughput or latency-sensitive ingestion + where you want lighter extraction logic. + """ + payload = { + "user_id": USER_ID, + "writable_cube_ids": [MEM_CUBE_ID], + "session_id": "session_sync_fast", + "async_mode": "sync", + "mode": "fast", + "messages": [ + { + "role": "user", + "content": "这条记忆使用 sync + fast 模式写入。", + } + ], + "custom_tags": ["sync", "fast"], + "info": {"source_type": "api_test"}, + } + call_add_api("13_sync_fast_pipeline", payload) + + +def example_14_sync_fine_pipeline(): + """ + Sync add with fine pipeline. + + - `async_mode = "sync"`, `mode = "fine"`. + - This is suitable for scenarios where quality of extraction is more + important than raw throughput. + """ + payload = { + "user_id": USER_ID, + "writable_cube_ids": [MEM_CUBE_ID], + "session_id": "session_sync_fine", + "async_mode": "sync", + "mode": "fine", + "messages": [ + { + "role": "user", + "content": "这条记忆使用 sync + fine 模式写入,需要更精细的抽取。", + } + ], + "custom_tags": ["sync", "fine"], + "info": {"source_type": "api_test"}, + } + call_add_api("14_sync_fine_pipeline", payload) + + +def example_15_async_with_task_id(): + """ + Async add with explicit task_id. + + - `task_id` can be used to correlate this async add request with + downstream scheduler status or monitoring. + """ + payload = { + "user_id": USER_ID, + "writable_cube_ids": [MEM_CUBE_ID], + "session_id": "session_async_task", + "async_mode": "async", + "task_id": "task_async_001", + "messages": [ + { + "role": "user", + "content": "这是一条带有 task_id 的异步写入请求。", + } + ], + "custom_tags": ["async", "task_id"], + "info": {"source_type": "task_test"}, + } + call_add_api("15_async_with_task_id", payload) + + +# =========================================================================== +# 7. Feedback and chat_history examples +# =========================================================================== + + +def example_16_feedback_add(): + """ + Feedback add example. + + - `is_feedback = True` marks this add as user feedback. + - You can use `custom_tags` and `info` to label the feedback type/source. + """ + payload = { + "user_id": USER_ID, + "writable_cube_ids": [MEM_CUBE_ID], + "session_id": "session_feedback_1", + "is_feedback": True, + "messages": [ + { + "role": "user", + "content": "刚才那个酒店推荐不太符合我的预算,请给我更便宜一点的选项。", + "chat_time": "2025-11-24T10:30:00Z", + "message_id": "fb-1", + } + ], + "custom_tags": ["feedback", "hotel"], + "info": { + "source_type": "chat_feedback", + "feedback_type": "preference_correction", + }, + } + call_add_api("16_feedback_add", payload) + + +# =========================================================================== +# Entry point +# =========================================================================== + +if __name__ == "__main__": + # You can comment out some examples if you do not want to run all of them. + example_01_string_message_minimal() + example_02_standard_chat_triplet() + example_03_assistant_with_tool_calls() + example_04_extreme_multimodal_single_message() + example_05_multimodal_text_and_image() + example_06_multimodal_text_and_file() + example_07_audio_only_message() + example_08_pure_text_input_items() + example_09_pure_file_input_by_file_id() + example_09b_pure_file_input_by_file_data() + example_10_mixed_text_file_image() + example_11_deprecated_memory_content_and_doc_path() + example_12_async_default_pipeline() + example_13_sync_fast_pipeline() + example_14_sync_fine_pipeline() + example_15_async_with_task_id() + example_16_feedback_add() diff --git a/examples/mem_reader/reader.py b/examples/mem_reader/reader.py index 3da5d5e76..c9061cfd6 100644 --- a/examples/mem_reader/reader.py +++ b/examples/mem_reader/reader.py @@ -1,3 +1,5 @@ +import argparse +import json import time from memos.configs.mem_reader import SimpleStructMemReaderConfig @@ -9,7 +11,110 @@ ) +def print_textual_memory_item( + item: TextualMemoryItem, max_memory_length: int = 200, indent: int = 0 +): + """ + Print a TextualMemoryItem in a structured format. + + Args: + item: The TextualMemoryItem to print + max_memory_length: Maximum length of memory content to display + indent: Number of spaces for indentation + """ + indent_str = " " * indent + print(f"{indent_str}{'=' * 80}") + print(f"{indent_str}TextualMemoryItem") + print(f"{indent_str}{'=' * 80}") + print(f"{indent_str}ID: {item.id}") + print( + f"{indent_str}Memory: {item.memory[:max_memory_length]}{'...' if len(item.memory) > max_memory_length else ''}" + ) + print(f"{indent_str}Memory Length: {len(item.memory)} characters") + + # Print metadata + if hasattr(item.metadata, "user_id"): + print(f"{indent_str}User ID: {item.metadata.user_id}") + if hasattr(item.metadata, "session_id"): + print(f"{indent_str}Session ID: {item.metadata.session_id}") + if hasattr(item.metadata, "memory_type"): + print(f"{indent_str}Memory Type: {item.metadata.memory_type}") + if hasattr(item.metadata, "type"): + print(f"{indent_str}Type: {item.metadata.type}") + if hasattr(item.metadata, "key") and item.metadata.key: + print(f"{indent_str}Key: {item.metadata.key}") + if hasattr(item.metadata, "tags") and item.metadata.tags: + print(f"{indent_str}Tags: {', '.join(item.metadata.tags)}") + if hasattr(item.metadata, "confidence"): + print(f"{indent_str}Confidence: {item.metadata.confidence}") + if hasattr(item.metadata, "status"): + print(f"{indent_str}Status: {item.metadata.status}") + if hasattr(item.metadata, "background") and item.metadata.background: + bg_preview = ( + item.metadata.background[:100] + "..." + if len(item.metadata.background) > 100 + else item.metadata.background + ) + print(f"{indent_str}Background: {bg_preview}") + if hasattr(item.metadata, "sources") and item.metadata.sources: + print(f"{indent_str}Sources ({len(item.metadata.sources)}):") + for i, source in enumerate(item.metadata.sources): + source_info = [] + if hasattr(source, "type"): + source_info.append(f"type={source.type}") + if hasattr(source, "role"): + source_info.append(f"role={source.role}") + if hasattr(source, "doc_path"): + source_info.append(f"doc_path={source.doc_path}") + if hasattr(source, "chat_time"): + source_info.append(f"chat_time={source.chat_time}") + if hasattr(source, "index") and source.index is not None: + source_info.append(f"index={source.index}") + print(f"{indent_str} [{i + 1}] {', '.join(source_info)}") + if hasattr(item.metadata, "created_at"): + print(f"{indent_str}Created At: {item.metadata.created_at}") + if hasattr(item.metadata, "updated_at"): + print(f"{indent_str}Updated At: {item.metadata.updated_at}") + if hasattr(item.metadata, "embedding") and item.metadata.embedding: + print(f"{indent_str}Embedding: [vector of {len(item.metadata.embedding)} dimensions]") + print(f"{indent_str}{'=' * 80}\n") + + +def print_textual_memory_item_json(item: TextualMemoryItem, indent: int = 2): + """ + Print a TextualMemoryItem as formatted JSON. + + Args: + item: The TextualMemoryItem to print + indent: JSON indentation level + """ + # Convert to dict and exclude embedding for readability + data = item.to_dict() + if "metadata" in data and "embedding" in data["metadata"]: + embedding = data["metadata"]["embedding"] + if embedding: + data["metadata"]["embedding"] = f"[vector of {len(embedding)} dimensions]" + + print(json.dumps(data, indent=indent, ensure_ascii=False)) + + def main(): + # Parse command line arguments + parser = argparse.ArgumentParser(description="Test Mem-Reader with structured output") + parser.add_argument( + "--format", + choices=["text", "json"], + default="text", + help="Output format: 'text' for structured text, 'json' for JSON format (default: text)", + ) + parser.add_argument( + "--max-memory-length", + type=int, + default=200, + help="Maximum length of memory content to display in text format (default: 200)", + ) + args = parser.parse_args() + # 1. Create Configuration reader_config = SimpleStructMemReaderConfig.from_json_file( "examples/data/config/simple_struct_reader_config.json" @@ -225,12 +330,24 @@ def main(): print("\n--- FINE Mode Results (first 3 items) ---") for i, mem_list in enumerate(fine_memory[:3]): for j, mem_item in enumerate(mem_list[:2]): # Show first 2 items from each list - print(f" [{i}][{j}] {mem_item.memory[:100]}...") + print(f"\n[Scene {i}][Item {j}]") + if args.format == "json": + print_textual_memory_item_json(mem_item, indent=2) + else: + print_textual_memory_item( + mem_item, max_memory_length=args.max_memory_length, indent=2 + ) print("\n--- FAST Mode Results (first 3 items) ---") for i, mem_list in enumerate(fast_memory[:3]): for j, mem_item in enumerate(mem_list[:2]): # Show first 2 items from each list - print(f" [{i}][{j}] {mem_item.memory[:100]}...") + print(f"\n[Scene {i}][Item {j}]") + if args.format == "json": + print_textual_memory_item_json(mem_item, indent=2) + else: + print_textual_memory_item( + mem_item, max_memory_length=args.max_memory_length, indent=2 + ) # 7. Example of transfer fast mode result into fine result fast_mode_memories = [ @@ -542,14 +659,20 @@ def main(): print("\n--- Transfer Mode Results (first 3 items) ---") for i, mem_list in enumerate(fine_memories[:3]): for j, mem_item in enumerate(mem_list[:2]): # Show first 2 items from each list - print(f" [{i}][{j}] {mem_item.memory[:100]}...") + print(f"\n[Scene {i}][Item {j}]") + if args.format == "json": + print_textual_memory_item_json(mem_item, indent=2) + else: + print_textual_memory_item( + mem_item, max_memory_length=args.max_memory_length, indent=2 + ) # 7. Example of processing documents (only in fine mode) print("\n=== Processing Documents (Fine Mode Only) ===") # Example document paths (you should replace these with actual document paths) doc_paths = [ - "examples/mem_reader/text1.txt", - "examples/mem_reader/text2.txt", + "text1.txt", + "text2.txt", ] try: @@ -563,9 +686,21 @@ def main(): }, mode="fine", ) - print( - f"\n📄 Document Memory generated {sum(len(mem_list) for mem_list in doc_memory)} items" - ) + total_items = sum(len(mem_list) for mem_list in doc_memory) + print(f"\n📄 Document Memory generated {total_items} items") + + # Print structured document memory items + if doc_memory: + print("\n--- Document Memory Items (first 3) ---") + for i, mem_list in enumerate(doc_memory[:3]): + for j, mem_item in enumerate(mem_list[:3]): # Show first 3 items from each document + print(f"\n[Document {i}][Item {j}]") + if args.format == "json": + print_textual_memory_item_json(mem_item, indent=2) + else: + print_textual_memory_item( + mem_item, max_memory_length=args.max_memory_length, indent=2 + ) except Exception as e: print(f"⚠️ Document processing failed: {e}") print(" (This is expected if document files don't exist)") diff --git a/src/memos/api/handlers/chat_handler.py b/src/memos/api/handlers/chat_handler.py index f0fcbabd9..c9e01573a 100644 --- a/src/memos/api/handlers/chat_handler.py +++ b/src/memos/api/handlers/chat_handler.py @@ -108,11 +108,14 @@ def handle_chat_complete(self, chat_req: APIChatCompleteRequest) -> dict[str, An HTTPException: If chat fails """ try: + # Resolve readable cube IDs (for search) + readable_cube_ids = chat_req.readable_cube_ids or [chat_req.user_id] + # Step 1: Search for relevant memories search_req = APISearchRequest( query=chat_req.query, user_id=chat_req.user_id, - mem_cube_id=chat_req.mem_cube_id, + readable_cube_ids=readable_cube_ids, mode=chat_req.mode, internet_search=chat_req.internet_search, top_k=chat_req.top_k, @@ -162,9 +165,11 @@ def handle_chat_complete(self, chat_req: APIChatCompleteRequest) -> dict[str, An # Step 4: start add after chat asynchronously if chat_req.add_message_on_answer: + # Resolve writable cube IDs (for add) + writable_cube_ids = chat_req.writable_cube_ids or [chat_req.user_id] self._start_add_to_memory( user_id=chat_req.user_id, - cube_id=chat_req.mem_cube_id, + writable_cube_ids=writable_cube_ids, session_id=chat_req.session_id or "default_session", query=chat_req.query, full_response=response, @@ -208,10 +213,15 @@ def handle_chat_stream(self, chat_req: ChatRequest) -> StreamingResponse: def generate_chat_response() -> Generator[str, None, None]: """Generate chat response as SSE stream.""" try: + # Resolve readable cube IDs (for search) + readable_cube_ids = chat_req.readable_cube_ids or ( + [chat_req.mem_cube_id] if chat_req.mem_cube_id else [chat_req.user_id] + ) + search_req = APISearchRequest( query=chat_req.query, user_id=chat_req.user_id, - mem_cube_id=chat_req.mem_cube_id, + readable_cube_ids=readable_cube_ids, mode=chat_req.mode, internet_search=chat_req.internet_search, top_k=chat_req.top_k, @@ -224,9 +234,13 @@ def generate_chat_response() -> Generator[str, None, None]: search_response = self.search_handler.handle_search_memories(search_req) + # Use first readable cube ID for scheduler (backward compatibility) + scheduler_cube_id = ( + readable_cube_ids[0] if readable_cube_ids else chat_req.user_id + ) self._send_message_to_scheduler( user_id=chat_req.user_id, - mem_cube_id=chat_req.mem_cube_id, + mem_cube_id=scheduler_cube_id, query=chat_req.query, label=QUERY_LABEL, ) @@ -256,7 +270,7 @@ def generate_chat_response() -> Generator[str, None, None]: ] self.logger.info( - f"user_id: {chat_req.user_id}, cube_id: {chat_req.mem_cube_id}, " + f"user_id: {chat_req.user_id}, readable_cube_ids: {readable_cube_ids}, " f"current_system_prompt: {system_prompt}" ) @@ -299,9 +313,13 @@ def generate_chat_response() -> Generator[str, None, None]: current_messages.append({"role": "assistant", "content": full_response}) if chat_req.add_message_on_answer: + # Resolve writable cube IDs (for add) + writable_cube_ids = chat_req.writable_cube_ids or ( + [chat_req.mem_cube_id] if chat_req.mem_cube_id else [chat_req.user_id] + ) self._start_add_to_memory( user_id=chat_req.user_id, - cube_id=chat_req.mem_cube_id, + writable_cube_ids=writable_cube_ids, session_id=chat_req.session_id or "default_session", query=chat_req.query, full_response=full_response, @@ -359,10 +377,15 @@ def generate_chat_response() -> Generator[str, None, None]: # Step 1: Search for memories using search handler yield f"data: {json.dumps({'type': 'status', 'data': '0'})}\n\n" + # Resolve readable cube IDs (for search) + readable_cube_ids = chat_req.readable_cube_ids or ( + [chat_req.mem_cube_id] if chat_req.mem_cube_id else [chat_req.user_id] + ) + search_req = APISearchRequest( query=chat_req.query, user_id=chat_req.user_id, - mem_cube_id=chat_req.mem_cube_id, + readable_cube_ids=readable_cube_ids, mode=chat_req.mode, internet_search=chat_req.internet_search, top_k=chat_req.top_k, @@ -376,9 +399,13 @@ def generate_chat_response() -> Generator[str, None, None]: search_response = self.search_handler.handle_search_memories(search_req) yield f"data: {json.dumps({'type': 'status', 'data': '1'})}\n\n" + # Use first readable cube ID for scheduler (backward compatibility) + scheduler_cube_id = ( + readable_cube_ids[0] if readable_cube_ids else chat_req.user_id + ) self._send_message_to_scheduler( user_id=chat_req.user_id, - mem_cube_id=chat_req.mem_cube_id, + mem_cube_id=scheduler_cube_id, query=chat_req.query, label=QUERY_LABEL, ) @@ -421,7 +448,7 @@ def generate_chat_response() -> Generator[str, None, None]: ] self.logger.info( - f"user_id: {chat_req.user_id}, cube_id: {chat_req.mem_cube_id}, " + f"user_id: {chat_req.user_id}, readable_cube_ids: {readable_cube_ids}, " f"current_system_prompt: {system_prompt}" ) @@ -496,9 +523,13 @@ def generate_chat_response() -> Generator[str, None, None]: yield f"data: {json.dumps({'type': 'end'})}\n\n" + # Use first readable cube ID for post-processing (backward compatibility) + scheduler_cube_id = ( + readable_cube_ids[0] if readable_cube_ids else chat_req.user_id + ) self._start_post_chat_processing( user_id=chat_req.user_id, - cube_id=chat_req.mem_cube_id, + cube_id=scheduler_cube_id, session_id=chat_req.session_id or "default_session", query=chat_req.query, full_response=full_response, @@ -509,9 +540,13 @@ def generate_chat_response() -> Generator[str, None, None]: current_messages=current_messages, ) + # Resolve writable cube IDs (for add) + writable_cube_ids = chat_req.writable_cube_ids or ( + [chat_req.mem_cube_id] if chat_req.mem_cube_id else [chat_req.user_id] + ) self._start_add_to_memory( user_id=chat_req.user_id, - cube_id=chat_req.mem_cube_id, + writable_cube_ids=writable_cube_ids, session_id=chat_req.session_id or "default_session", query=chat_req.query, full_response=full_response, @@ -867,7 +902,7 @@ def _send_message_to_scheduler( async def _add_conversation_to_memory( self, user_id: str, - cube_id: str, + writable_cube_ids: list[str], session_id: str, query: str, clean_response: str, @@ -875,7 +910,7 @@ async def _add_conversation_to_memory( ) -> None: add_req = APIADDRequest( user_id=user_id, - mem_cube_id=cube_id, + writable_cube_ids=writable_cube_ids, session_id=session_id, messages=[ { @@ -1090,7 +1125,7 @@ def run_async_in_thread(): def _start_add_to_memory( self, user_id: str, - cube_id: str, + writable_cube_ids: list[str], session_id: str, query: str, full_response: str, @@ -1105,7 +1140,7 @@ def run_async_in_thread(): loop.run_until_complete( self._add_conversation_to_memory( user_id=user_id, - cube_id=cube_id, + writable_cube_ids=writable_cube_ids, session_id=session_id, query=query, clean_response=clean_response, @@ -1126,7 +1161,7 @@ def run_async_in_thread(): task = asyncio.create_task( self._add_conversation_to_memory( user_id=user_id, - cube_id=cube_id, + writable_cube_ids=writable_cube_ids, session_id=session_id, query=query, clean_response=clean_response, diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index 2f2e9ea54..961b14b6b 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -68,8 +68,13 @@ class MemCubeRegister(BaseRequest): class ChatRequest(BaseRequest): - """Request model for chat operations.""" + """Request model for chat operations. + + This model is used as the algorithm-facing chat interface, while also + remaining backward compatible with older developer-facing APIs. + """ + # ==== Basic identifiers ==== user_id: str = Field(..., description="User ID") query: str = Field(..., description="Chat query message") readable_cube_ids: list[str] | None = Field( @@ -110,11 +115,49 @@ class ChatRequest(BaseRequest): threshold: float = Field(0.5, description="Threshold for filtering references") # ==== Backward compatibility ==== - mem_cube_id: str | None = Field(None, description="Cube ID to use for chat") moscube: bool = Field( - False, description="(Deprecated) Whether to use legacy MemOSCube pipeline" + False, + description="(Deprecated) Whether to use legacy MemOSCube pipeline.", + ) + + mem_cube_id: str | None = Field( + None, + description=( + "(Deprecated) Single cube ID to use for chat. " + "Prefer `readable_cube_ids` / `writable_cube_ids` for multi-cube chat." + ), ) + @model_validator(mode="after") + def _convert_deprecated_fields(self): + """ + Normalize fields for algorithm interface while preserving backward compatibility. + + Rules: + - mem_cube_id → readable_cube_ids / writable_cube_ids if they are missing + - moscube: log warning when True (deprecated) + """ + + # ---- mem_cube_id backward compatibility ---- + if self.mem_cube_id is not None: + logger.warning( + "ChatRequest.mem_cube_id is deprecated and will be removed in a future version. " + "Please migrate to `readable_cube_ids` / `writable_cube_ids`." + ) + if not self.readable_cube_ids: + self.readable_cube_ids = [self.mem_cube_id] + if not self.writable_cube_ids: + self.writable_cube_ids = [self.mem_cube_id] + + # ---- Deprecated moscube flag ---- + if self.moscube: + logger.warning( + "ChatRequest.moscube is deprecated. Legacy MemOSCube pipeline " + "will be removed in a future version." + ) + + return self + class ChatCompleteRequest(BaseRequest): """Request model for chat operations. will (Deprecated), instead use APIChatCompleteRequest.""" @@ -389,6 +432,7 @@ class APIADDRequest(BaseRequest): None, description="Session ID. If not provided, a default session will be used.", ) + task_id: str | None = Field(None, description="Task ID for monitering async tasks") # ==== Multi-cube writing ==== writable_cube_ids: list[str] | None = Field( @@ -406,6 +450,15 @@ class APIADDRequest(BaseRequest): ), ) + mode: Literal["fast", "fine"] | None = Field( + None, + description=( + "(Internal) Add mode used only when async_mode='sync'. " + "If set to 'fast', the handler will use a fast add pipeline. " + "Ignored when async_mode='async'." + ), + ) + # ==== Business tags & info ==== custom_tags: list[str] | None = Field( None, @@ -501,6 +554,14 @@ def _convert_deprecated_fields(self) -> "APIADDRequest": - source → info["source"] - operation → merged into writable_cube_ids (ignored otherwise) """ + # ---- async_mode / mode relationship ---- + if self.async_mode == "async" and self.mode is not None: + logger.warning( + "APIADDRequest.mode is ignored when async_mode='async'. " + "Fast add pipeline is only available in sync mode." + ) + self.mode = None + # Convert mem_cube_id to writable_cube_ids (new field takes priority) if self.mem_cube_id: logger.warning( diff --git a/src/memos/configs/mem_reader.py b/src/memos/configs/mem_reader.py index dc8d37a35..a653a5e68 100644 --- a/src/memos/configs/mem_reader.py +++ b/src/memos/configs/mem_reader.py @@ -45,6 +45,10 @@ class SimpleStructMemReaderConfig(BaseMemReaderConfig): """SimpleStruct MemReader configuration class.""" +class MultiModelStructMemReaderConfig(BaseMemReaderConfig): + """MultiModelStruct MemReader configuration class.""" + + class StrategyStructMemReaderConfig(BaseMemReaderConfig): """StrategyStruct MemReader configuration class.""" @@ -57,6 +61,7 @@ class MemReaderConfigFactory(BaseConfig): backend_to_class: ClassVar[dict[str, Any]] = { "simple_struct": SimpleStructMemReaderConfig, + "multimodel_struct": MultiModelStructMemReaderConfig, "strategy_struct": StrategyStructMemReaderConfig, } diff --git a/src/memos/mem_reader/base.py b/src/memos/mem_reader/base.py index 3095a0bc6..391270bcf 100644 --- a/src/memos/mem_reader/base.py +++ b/src/memos/mem_reader/base.py @@ -12,20 +12,12 @@ class BaseMemReader(ABC): def __init__(self, config: BaseMemReaderConfig): """Initialize the MemReader with the given configuration.""" - @abstractmethod - def get_scene_data_info(self, scene_data: list, type: str) -> list[str]: - """Get raw information related to the current scene.""" - @abstractmethod def get_memory( self, scene_data: list, type: str, info: dict[str, Any], mode: str = "fast" ) -> list[list[TextualMemoryItem]]: """Various types of memories extracted from scene_data""" - @abstractmethod - def transform_memreader(self, data: dict) -> list[TextualMemoryItem]: - """Transform the memory data into a list of TextualMemoryItem objects.""" - @abstractmethod def fine_transfer_simple_mem( self, input_memories: list[list[TextualMemoryItem]], type: str diff --git a/src/memos/mem_reader/factory.py b/src/memos/mem_reader/factory.py index 2205a0215..263f29001 100644 --- a/src/memos/mem_reader/factory.py +++ b/src/memos/mem_reader/factory.py @@ -2,6 +2,7 @@ from memos.configs.mem_reader import MemReaderConfigFactory from memos.mem_reader.base import BaseMemReader +from memos.mem_reader.multi_model_struct import MultiModelStructMemReader from memos.mem_reader.simple_struct import SimpleStructMemReader from memos.mem_reader.strategy_struct import StrategyStructMemReader from memos.memos_tools.singleton import singleton_factory @@ -13,6 +14,7 @@ class MemReaderFactory(BaseMemReader): backend_to_class: ClassVar[dict[str, Any]] = { "simple_struct": SimpleStructMemReader, "strategy_struct": StrategyStructMemReader, + "multimodel_struct": MultiModelStructMemReader, } @classmethod diff --git a/src/memos/mem_reader/multi_model_struct.py b/src/memos/mem_reader/multi_model_struct.py new file mode 100644 index 000000000..13824f7d8 --- /dev/null +++ b/src/memos/mem_reader/multi_model_struct.py @@ -0,0 +1,130 @@ +import concurrent.futures +import traceback + +from typing import Any + +from memos import log +from memos.configs.mem_reader import MultiModelStructMemReaderConfig +from memos.context.context import ContextThreadPoolExecutor +from memos.mem_reader.read_multi_model import MultiModelParser +from memos.mem_reader.simple_struct import SimpleStructMemReader +from memos.memories.textual.item import TextualMemoryItem +from memos.types import MessagesType +from memos.utils import timed + + +logger = log.get_logger(__name__) + + +class MultiModelStructMemReader(SimpleStructMemReader): + """Multi Model implementation of MemReader that inherits from + SimpleStructMemReader.""" + + def __init__(self, config: MultiModelStructMemReaderConfig): + """ + Initialize the MultiModelStructMemReader with configuration. + + Args: + config: Configuration object for the reader + """ + from memos.configs.mem_reader import SimpleStructMemReaderConfig + + simple_config = SimpleStructMemReaderConfig(**config.model_dump()) + super().__init__(simple_config) + + # Initialize MultiModelParser for routing to different parsers + self.multi_model_parser = MultiModelParser( + embedder=self.embedder, + llm=self.llm, + parser=None, + ) + + @timed + def _process_multi_model_data(self, scene_data_info: MessagesType, info, **kwargs): + """ + Process multi-model data using MultiModelParser. + + Args: + scene_data_info: MessagesType input + info: Dictionary containing user_id and session_id + **kwargs: Additional parameters (mode, etc.) + """ + mode = kwargs.get("mode", "fine") + + # Use MultiModelParser to parse the scene data + # If it's a list, parse each item; otherwise parse as single message + if isinstance(scene_data_info, list): + # Parse each message in the list + all_memory_items = [] + for msg in scene_data_info: + items = self.multi_model_parser.parse(msg, info, mode=mode, **kwargs) + all_memory_items.extend(items) + return all_memory_items + else: + # Parse as single message + return self.multi_model_parser.parse(scene_data_info, info, mode=mode, **kwargs) + + @timed + def _process_transfer_multi_model_data(self, raw_node: TextualMemoryItem): + raise NotImplementedError + + def get_scene_data_info(self, scene_data: list, type: str) -> list[list[Any]]: + """ + Convert normalized MessagesType scenes into scene data info. + For MultiModelStructMemReader, this is a simplified version that returns the scenes as-is. + + Args: + scene_data: List of MessagesType scenes + type: Type of scene_data: ['doc', 'chat'] + + Returns: + List of scene data info + """ + # TODO: split messages + return scene_data + + def _read_memory( + self, messages: list[MessagesType], type: str, info: dict[str, Any], mode: str = "fine" + ): + list_scene_data_info = self.get_scene_data_info(messages, type) + + memory_list = [] + # Process Q&A pairs concurrently with context propagation + with ContextThreadPoolExecutor() as executor: + futures = [ + executor.submit(self._process_multi_model_data, scene_data_info, info, mode=mode) + for scene_data_info in list_scene_data_info + ] + for future in concurrent.futures.as_completed(futures): + try: + res_memory = future.result() + if res_memory is not None: + memory_list.append(res_memory) + except Exception as e: + logger.error(f"Task failed with exception: {e}") + logger.error(traceback.format_exc()) + return memory_list + + def fine_transfer_simple_mem( + self, input_memories: list[TextualMemoryItem], type: str + ) -> list[list[TextualMemoryItem]]: + if not input_memories: + return [] + + memory_list = [] + + # Process Q&A pairs concurrently with context propagation + with ContextThreadPoolExecutor() as executor: + futures = [ + executor.submit(self._process_transfer_multi_model_data, scene_data_info) + for scene_data_info in input_memories + ] + for future in concurrent.futures.as_completed(futures): + try: + res_memory = future.result() + if res_memory is not None: + memory_list.append(res_memory) + except Exception as e: + logger.error(f"Task failed with exception: {e}") + logger.error(traceback.format_exc()) + return memory_list diff --git a/src/memos/mem_reader/read_multi_model/__init__.py b/src/memos/mem_reader/read_multi_model/__init__.py new file mode 100644 index 000000000..39cd63743 --- /dev/null +++ b/src/memos/mem_reader/read_multi_model/__init__.py @@ -0,0 +1,40 @@ +"""Multi-model message parsers for different message types. + +This package provides parsers for different message types in both fast and fine modes: +- String messages +- System messages +- User messages +- Assistant messages +- Tool messages +- Text content parts +- File content parts + +Each parser supports both "fast" mode (quick processing without LLM) and +"fine" mode (with LLM for better understanding). +""" + +from .assistant_parser import AssistantParser +from .base import BaseMessageParser +from .file_content_parser import FileContentParser +from .multi_model_parser import MultiModelParser +from .string_parser import StringParser +from .system_parser import SystemParser +from .text_content_parser import TextContentParser +from .tool_parser import ToolParser +from .user_parser import UserParser +from .utils import coerce_scene_data, extract_role + + +__all__ = [ + "AssistantParser", + "BaseMessageParser", + "FileContentParser", + "MultiModelParser", + "StringParser", + "SystemParser", + "TextContentParser", + "ToolParser", + "UserParser", + "coerce_scene_data", + "extract_role", +] diff --git a/src/memos/mem_reader/read_multi_model/assistant_parser.py b/src/memos/mem_reader/read_multi_model/assistant_parser.py new file mode 100644 index 000000000..2f2cbbc5d --- /dev/null +++ b/src/memos/mem_reader/read_multi_model/assistant_parser.py @@ -0,0 +1,45 @@ +"""Parser for assistant messages.""" + +from typing import Any + +from memos.embedders.base import BaseEmbedder +from memos.llms.base import BaseLLM +from memos.log import get_logger +from memos.memories.textual.item import TextualMemoryItem +from memos.types.openai_chat_completion_types import ChatCompletionAssistantMessageParam + +from .base import BaseMessageParser + + +logger = get_logger(__name__) + + +class AssistantParser(BaseMessageParser): + """Parser for assistant messages.""" + + def __init__(self, embedder: BaseEmbedder, llm: BaseLLM | None = None): + """ + Initialize AssistantParser. + + Args: + embedder: Embedder for generating embeddings + llm: Optional LLM for fine mode processing + """ + self.embedder = embedder + self.llm = llm + + def parse_fast( + self, + message: ChatCompletionAssistantMessageParam, + info: dict[str, Any], + **kwargs, + ) -> list[TextualMemoryItem]: + return [] + + def parse_fine( + self, + message: ChatCompletionAssistantMessageParam, + info: dict[str, Any], + **kwargs, + ) -> list[TextualMemoryItem]: + return [] diff --git a/src/memos/mem_reader/read_multi_model/base.py b/src/memos/mem_reader/read_multi_model/base.py new file mode 100644 index 000000000..024a940b8 --- /dev/null +++ b/src/memos/mem_reader/read_multi_model/base.py @@ -0,0 +1,78 @@ +"""Base parser interface for multi-model message parsing. + +This module defines the base interface for parsing different message types +in both fast and fine modes. +""" + +from abc import ABC, abstractmethod +from typing import Any + +from memos.memories.textual.item import TextualMemoryItem + + +class BaseMessageParser(ABC): + """Base interface for message type parsers.""" + + @abstractmethod + def parse_fast( + self, + message: Any, + info: dict[str, Any], + **kwargs, + ) -> list[TextualMemoryItem]: + """ + Parse message in fast mode (no LLM calls, quick processing). + + Args: + message: The message to parse + info: Dictionary containing user_id and session_id + **kwargs: Additional parameters + + Returns: + List of TextualMemoryItem objects + """ + + @abstractmethod + def parse_fine( + self, + message: Any, + info: dict[str, Any], + **kwargs, + ) -> list[TextualMemoryItem]: + """ + Parse message in fine mode (with LLM calls for better understanding). + + Args: + message: The message to parse + info: Dictionary containing user_id and session_id + **kwargs: Additional parameters (e.g., llm, embedder) + + Returns: + List of TextualMemoryItem objects + """ + + def parse( + self, + message: Any, + info: dict[str, Any], + mode: str = "fast", + **kwargs, + ) -> list[TextualMemoryItem]: + """ + Parse message in the specified mode. + + Args: + message: The message to parse + info: Dictionary containing user_id and session_id + mode: "fast" or "fine" + **kwargs: Additional parameters + + Returns: + List of TextualMemoryItem objects + """ + if mode == "fast": + return self.parse_fast(message, info, **kwargs) + elif mode == "fine": + return self.parse_fine(message, info, **kwargs) + else: + raise ValueError(f"Unknown mode: {mode}. Must be 'fast' or 'fine'") diff --git a/src/memos/mem_reader/read_multi_model/file_content_parser.py b/src/memos/mem_reader/read_multi_model/file_content_parser.py new file mode 100644 index 000000000..71af89d18 --- /dev/null +++ b/src/memos/mem_reader/read_multi_model/file_content_parser.py @@ -0,0 +1,99 @@ +"""Parser for file content parts (RawMessageList).""" + +from typing import Any + +from memos.embedders.base import BaseEmbedder +from memos.llms.base import BaseLLM +from memos.log import get_logger +from memos.memories.textual.item import TextualMemoryItem +from memos.parsers.factory import ParserFactory +from memos.types.openai_chat_completion_types import File + +from .base import BaseMessageParser + + +logger = get_logger(__name__) + + +class FileContentParser(BaseMessageParser): + """Parser for file content parts.""" + + def __init__( + self, + embedder: BaseEmbedder, + llm: BaseLLM | None = None, + parser: Any | None = None, + ): + """ + Initialize FileContentParser. + + Args: + embedder: Embedder for generating embeddings + llm: Optional LLM for fine mode processing + parser: Optional parser for parsing file contents + """ + self.embedder = embedder + self.llm = llm + self.parser = parser + + def _parse_file(self, file_info: dict[str, Any]) -> str: + """ + Parse file content. + + Args: + file_info: File information dictionary + + Returns: + Parsed text content + """ + if not self.parser: + # Try to create a default parser + try: + from memos.configs.parser import ParserConfigFactory + + parser_config = ParserConfigFactory.model_validate( + { + "backend": "markitdown", + "config": {}, + } + ) + self.parser = ParserFactory.from_config(parser_config) + except Exception as e: + logger.warning(f"[FileContentParser] Failed to create parser: {e}") + return "" + + file_path = file_info.get("path") or file_info.get("file_id", "") + filename = file_info.get("filename", "unknown") + + if not file_path: + logger.warning("[FileContentParser] No file path or file_id provided") + return f"[File: {filename}]" + + try: + import os + + if os.path.exists(file_path): + parsed_text = self.parser.parse(file_path) + return parsed_text + else: + logger.warning(f"[FileContentParser] File not found: {file_path}") + return f"[File: {filename}]" + except Exception as e: + logger.error(f"[FileContentParser] Error parsing file {file_path}: {e}") + return f"[File: {filename}]" + + def parse_fast( + self, + message: File, + info: dict[str, Any], + **kwargs, + ) -> list[TextualMemoryItem]: + return [] + + def parse_fine( + self, + message: File, + info: dict[str, Any], + **kwargs, + ) -> list[TextualMemoryItem]: + return [] diff --git a/src/memos/mem_reader/read_multi_model/multi_model_parser.py b/src/memos/mem_reader/read_multi_model/multi_model_parser.py new file mode 100644 index 000000000..e16733468 --- /dev/null +++ b/src/memos/mem_reader/read_multi_model/multi_model_parser.py @@ -0,0 +1,170 @@ +"""Unified multi-model parser for different message types. + +This module provides a unified interface to parse different message types +in both fast and fine modes. +""" + +from typing import Any + +from memos.embedders.base import BaseEmbedder +from memos.llms.base import BaseLLM +from memos.log import get_logger +from memos.memories.textual.item import TextualMemoryItem +from memos.types import MessagesType + +from .assistant_parser import AssistantParser +from .base import BaseMessageParser +from .file_content_parser import FileContentParser +from .string_parser import StringParser +from .system_parser import SystemParser +from .text_content_parser import TextContentParser +from .tool_parser import ToolParser +from .user_parser import UserParser +from .utils import extract_role + + +logger = get_logger(__name__) + + +class MultiModelParser: + """Unified parser for different message types.""" + + def __init__( + self, + embedder: BaseEmbedder, + llm: BaseLLM | None = None, + parser: Any | None = None, + ): + """ + Initialize MultiModelParser. + + Args: + embedder: Embedder for generating embeddings + llm: Optional LLM for fine mode processing + parser: Optional parser for parsing file contents + """ + self.embedder = embedder + self.llm = llm + self.parser = parser + + # Initialize parsers for different message types + self.string_parser = StringParser(embedder, llm) + self.system_parser = SystemParser(embedder, llm) + self.user_parser = UserParser(embedder, llm) + self.assistant_parser = AssistantParser(embedder, llm) + self.tool_parser = ToolParser(embedder, llm) + self.text_content_parser = TextContentParser(embedder, llm) + self.file_content_parser = FileContentParser(embedder, llm, parser) + self.image_parser = None # future + self.audio_parser = None # future + + self.role_parsers = { + "system": SystemParser(embedder, llm), + "user": UserParser(embedder, llm), + "assistant": AssistantParser(embedder, llm), + "tool": ToolParser(embedder, llm), + } + + self.type_parsers = { + "text": self.text_content_parser, + "file": self.file_content_parser, + "image": self.image_parser, + "audio": self.audio_parser, + } + + def _get_parser(self, message: Any) -> BaseMessageParser | None: + """ + Get appropriate parser for the message type. + + Args: + message: Message to parse + + Returns: + Appropriate parser or None + """ + # Handle string messages + if isinstance(message, str): + return self.string_parser + + # Handle dict messages + if not isinstance(message, dict): + logger.warning(f"[MultiModelParser] Unknown message type: {type(message)}") + return None + + # Check if it's a RawMessageList item (text or file) + if "type" in message: + msg_type = message.get("type") + parser = self.type_parsers.get(msg_type) + if parser: + return parser + + # Check if it's a MessageList item (system, user, assistant, tool) + role = extract_role(message) + if role: + parser = self.role_parsers.get(role) + if parser: + return parser + + logger.warning(f"[MultiModelParser] Could not determine parser for message: {message}") + return None + + def parse( + self, + message: MessagesType, + info: dict[str, Any], + mode: str = "fast", + **kwargs, + ) -> list[TextualMemoryItem]: + """ + Parse a single message in the specified mode. + + Args: + message: Message to parse (can be str, MessageList item, or RawMessageList item) + info: Dictionary containing user_id and session_id + mode: "fast" or "fine" + **kwargs: Additional parameters + + Returns: + List of TextualMemoryItem objects + """ + # Handle list of messages (MessageList or RawMessageList) + if isinstance(message, list): + return [item for msg in message for item in self.parse(msg, info, mode, **kwargs)] + + # Get appropriate parser + parser = self._get_parser(message) + if not parser: + logger.warning(f"[MultiModelParser] No parser found for message: {message}") + return [] + + # Parse using the appropriate parser + try: + return parser.parse(message, info, mode=mode, **kwargs) + except Exception as e: + logger.error(f"[MultiModelParser] Error parsing message: {e}") + return [] + + def parse_batch( + self, + messages: list[MessagesType], + info: dict[str, Any], + mode: str = "fast", + **kwargs, + ) -> list[list[TextualMemoryItem]]: + """ + Parse a batch of messages. + + Args: + messages: List of messages to parse + info: Dictionary containing user_id and session_id + mode: "fast" or "fine" + **kwargs: Additional parameters + + Returns: + List of lists of TextualMemoryItem objects (one list per message) + """ + results = [] + for message in messages: + items = self.parse(message, info, mode, **kwargs) + results.append(items) + return results diff --git a/src/memos/mem_reader/read_multi_model/string_parser.py b/src/memos/mem_reader/read_multi_model/string_parser.py new file mode 100644 index 000000000..5c5c829b3 --- /dev/null +++ b/src/memos/mem_reader/read_multi_model/string_parser.py @@ -0,0 +1,47 @@ +"""Parser for string format messages. + +Handles simple string messages that need to be converted to memory items. +""" + +from typing import Any + +from memos.embedders.base import BaseEmbedder +from memos.llms.base import BaseLLM +from memos.log import get_logger +from memos.memories.textual.item import TextualMemoryItem + +from .base import BaseMessageParser + + +logger = get_logger(__name__) + + +class StringParser(BaseMessageParser): + """Parser for string format messages.""" + + def __init__(self, embedder: BaseEmbedder, llm: BaseLLM | None = None): + """ + Initialize StringParser. + + Args: + embedder: Embedder for generating embeddings + llm: Optional LLM for fine mode processing + """ + self.embedder = embedder + self.llm = llm + + def parse_fast( + self, + message: str, + info: dict[str, Any], + **kwargs, + ) -> list[TextualMemoryItem]: + return [] + + def parse_fine( + self, + message: str, + info: dict[str, Any], + **kwargs, + ) -> list[TextualMemoryItem]: + return [] diff --git a/src/memos/mem_reader/read_multi_model/system_parser.py b/src/memos/mem_reader/read_multi_model/system_parser.py new file mode 100644 index 000000000..3024ef89c --- /dev/null +++ b/src/memos/mem_reader/read_multi_model/system_parser.py @@ -0,0 +1,45 @@ +"""Parser for system messages.""" + +from typing import Any + +from memos.embedders.base import BaseEmbedder +from memos.llms.base import BaseLLM +from memos.log import get_logger +from memos.memories.textual.item import TextualMemoryItem +from memos.types.openai_chat_completion_types import ChatCompletionSystemMessageParam + +from .base import BaseMessageParser + + +logger = get_logger(__name__) + + +class SystemParser(BaseMessageParser): + """Parser for system messages.""" + + def __init__(self, embedder: BaseEmbedder, llm: BaseLLM | None = None): + """ + Initialize SystemParser. + + Args: + embedder: Embedder for generating embeddings + llm: Optional LLM for fine mode processing + """ + self.embedder = embedder + self.llm = llm + + def parse_fast( + self, + message: ChatCompletionSystemMessageParam, + info: dict[str, Any], + **kwargs, + ) -> list[TextualMemoryItem]: + return [] + + def parse_fine( + self, + message: ChatCompletionSystemMessageParam, + info: dict[str, Any], + **kwargs, + ) -> list[TextualMemoryItem]: + return [] diff --git a/src/memos/mem_reader/read_multi_model/text_content_parser.py b/src/memos/mem_reader/read_multi_model/text_content_parser.py new file mode 100644 index 000000000..d9a9700d4 --- /dev/null +++ b/src/memos/mem_reader/read_multi_model/text_content_parser.py @@ -0,0 +1,45 @@ +"""Parser for text content parts (RawMessageList).""" + +from typing import Any + +from memos.embedders.base import BaseEmbedder +from memos.llms.base import BaseLLM +from memos.log import get_logger +from memos.memories.textual.item import TextualMemoryItem +from memos.types.openai_chat_completion_types import ChatCompletionContentPartTextParam + +from .base import BaseMessageParser + + +logger = get_logger(__name__) + + +class TextContentParser(BaseMessageParser): + """Parser for text content parts.""" + + def __init__(self, embedder: BaseEmbedder, llm: BaseLLM | None = None): + """ + Initialize TextContentParser. + + Args: + embedder: Embedder for generating embeddings + llm: Optional LLM for fine mode processing + """ + self.embedder = embedder + self.llm = llm + + def parse_fast( + self, + message: ChatCompletionContentPartTextParam, + info: dict[str, Any], + **kwargs, + ) -> list[TextualMemoryItem]: + return [] + + def parse_fine( + self, + message: ChatCompletionContentPartTextParam, + info: dict[str, Any], + **kwargs, + ) -> list[TextualMemoryItem]: + return [] diff --git a/src/memos/mem_reader/read_multi_model/tool_parser.py b/src/memos/mem_reader/read_multi_model/tool_parser.py new file mode 100644 index 000000000..abf705eaa --- /dev/null +++ b/src/memos/mem_reader/read_multi_model/tool_parser.py @@ -0,0 +1,45 @@ +"""Parser for tool messages.""" + +from typing import Any + +from memos.embedders.base import BaseEmbedder +from memos.llms.base import BaseLLM +from memos.log import get_logger +from memos.memories.textual.item import TextualMemoryItem +from memos.types.openai_chat_completion_types import ChatCompletionToolMessageParam + +from .base import BaseMessageParser + + +logger = get_logger(__name__) + + +class ToolParser(BaseMessageParser): + """Parser for tool messages.""" + + def __init__(self, embedder: BaseEmbedder, llm: BaseLLM | None = None): + """ + Initialize ToolParser. + + Args: + embedder: Embedder for generating embeddings + llm: Optional LLM for fine mode processing + """ + self.embedder = embedder + self.llm = llm + + def parse_fast( + self, + message: ChatCompletionToolMessageParam, + info: dict[str, Any], + **kwargs, + ) -> list[TextualMemoryItem]: + return [] + + def parse_fine( + self, + message: ChatCompletionToolMessageParam, + info: dict[str, Any], + **kwargs, + ) -> list[TextualMemoryItem]: + return [] diff --git a/src/memos/mem_reader/read_multi_model/user_parser.py b/src/memos/mem_reader/read_multi_model/user_parser.py new file mode 100644 index 000000000..78f9d0057 --- /dev/null +++ b/src/memos/mem_reader/read_multi_model/user_parser.py @@ -0,0 +1,45 @@ +"""Parser for user messages.""" + +from typing import Any + +from memos.embedders.base import BaseEmbedder +from memos.llms.base import BaseLLM +from memos.log import get_logger +from memos.memories.textual.item import TextualMemoryItem +from memos.types.openai_chat_completion_types import ChatCompletionUserMessageParam + +from .base import BaseMessageParser + + +logger = get_logger(__name__) + + +class UserParser(BaseMessageParser): + """Parser for user messages.""" + + def __init__(self, embedder: BaseEmbedder, llm: BaseLLM | None = None): + """ + Initialize UserParser. + + Args: + embedder: Embedder for generating embeddings + llm: Optional LLM for fine mode processing + """ + self.embedder = embedder + self.llm = llm + + def parse_fast( + self, + message: ChatCompletionUserMessageParam, + info: dict[str, Any], + **kwargs, + ) -> list[TextualMemoryItem]: + return [] + + def parse_fine( + self, + message: ChatCompletionUserMessageParam, + info: dict[str, Any], + **kwargs, + ) -> list[TextualMemoryItem]: + return [] diff --git a/src/memos/mem_reader/read_multi_model/utils.py b/src/memos/mem_reader/read_multi_model/utils.py new file mode 100644 index 000000000..e42a564e4 --- /dev/null +++ b/src/memos/mem_reader/read_multi_model/utils.py @@ -0,0 +1,189 @@ +"""Utility functions for message parsing.""" + +import os +import re + +from datetime import datetime, timezone +from typing import Any, TypeAlias +from urllib.parse import urlparse + +from memos import log +from memos.configs.parser import ParserConfigFactory +from memos.parsers.factory import ParserFactory +from memos.types import MessagesType +from memos.types.openai_chat_completion_types import ( + ChatCompletionAssistantMessageParam, + ChatCompletionContentPartTextParam, + ChatCompletionSystemMessageParam, + ChatCompletionToolMessageParam, + ChatCompletionUserMessageParam, + File, +) + + +ChatMessageClasses = ( + ChatCompletionSystemMessageParam, + ChatCompletionUserMessageParam, + ChatCompletionAssistantMessageParam, + ChatCompletionToolMessageParam, +) + +RawContentClasses = (ChatCompletionContentPartTextParam, File) +MessageDict: TypeAlias = dict[str, Any] # (Deprecated) not supported in the future +SceneDataInput: TypeAlias = ( + list[list[MessageDict]] # (Deprecated) legacy chat example: scenes -> messages + | list[str] # (Deprecated) legacy doc example: list of paths / pure text + | list[MessagesType] # new: list of scenes (each scene is MessagesType) +) + + +logger = log.get_logger(__name__) +FILE_EXT_RE = re.compile( + r"\.(pdf|docx?|pptx?|xlsx?|txt|md|html?|json|csv|png|jpe?g|webp|wav|mp3|m4a)$", + re.I, +) + + +def extract_role(message: dict[str, Any]) -> str: + """Extract role from message.""" + return message.get("role", "") + + +def _is_message_list(obj): + """ + Detect whether `obj` is a MessageList (OpenAI ChatCompletionMessageParam list). + Criteria: + - Must be a list + - Each element must be a dict with keys: role, content + """ + if not isinstance(obj, list): + return False + + for item in obj: + if not isinstance(item, dict): + return False + if "role" not in item or "content" not in item: + return False + return True + + +def coerce_scene_data(scene_data, scene_type: str) -> list[MessagesType]: + """ + Normalize ANY allowed SceneDataInput into: list[MessagesType]. + Supports: + - Already normalized scene_data → passthrough + - doc: legacy list[str] → automatically detect: + * local file path → read & parse into text + * remote URL/path → keep as file part + * pure text → text part + - chat: + * Passthrough normalization + * Auto-inject chat_time into each message group + - fallback: wrap unknown → [str(scene_data)] + """ + if not scene_data: + return [] + head = scene_data[0] + + if scene_type != "doc": + normalized = scene_data if isinstance(head, str | list) else [str(scene_data)] + + complete_scene_data = [] + for items in normalized: + if not items: + continue + + # ONLY add chat_time if it's a MessageList + if not _is_message_list(items): + complete_scene_data.append(items) + continue + + # Detect existing chat_time + chat_time_value = None + for item in items: + if isinstance(item, dict) and "chat_time" in item: + chat_time_value = item["chat_time"] + break + + # Default timestamp + if chat_time_value is None: + session_date = datetime.now(timezone.utc) + date_format = "%I:%M %p on %d %B, %Y UTC" + chat_time_value = session_date.strftime(date_format) + + # Inject chat_time + for m in items: + if isinstance(m, dict) and "chat_time" not in m: + m["chat_time"] = chat_time_value + + complete_scene_data.append(items) + + return complete_scene_data + + # doc: list[str] -> RawMessageList + if scene_type == "doc" and isinstance(head, str): + raw_items = [] + + # prepare parser + parser_config = ParserConfigFactory.model_validate( + { + "backend": "markitdown", + "config": {}, + } + ) + parser = ParserFactory.from_config(parser_config) + + for s in scene_data: + s = (s or "").strip() + if not s: + continue + + parsed = urlparse(s) + looks_like_url = parsed.scheme in {"http", "https", "oss", "s3", "gs", "cos"} + looks_like_path = ("/" in s) or ("\\" in s) + looks_like_file = bool(FILE_EXT_RE.search(s)) or looks_like_url or looks_like_path + + # Case A: Local filesystem path + if os.path.exists(s): + filename = os.path.basename(s) or "document" + try: + # parse local file into text + parsed_text = parser.parse(s) + raw_items.append( + [ + { + "type": "file", + "file": { + "filename": filename or "document", + "file_data": parsed_text, + }, + } + ] + ) + except Exception as e: + logger.error(f"[SceneParser] Error parsing {s}: {e}") + continue + + # Case B: URL or non-local file path + if looks_like_file: + if looks_like_url: + filename = os.path.basename(parsed.path) + else: + # Windows absolute path detection + if "\\" in s and re.match(r"^[A-Za-z]:", s): + parts = [p for p in s.split("\\") if p] + filename = parts[-1] if parts else os.path.basename(s) + else: + filename = os.path.basename(s) + raw_items.append( + [{"type": "file", "file": {"filename": filename or "document", "file_data": s}}] + ) + continue + + # Case C: Pure text + raw_items.append([{"type": "text", "text": s}]) + + return raw_items + + # fallback + return [str(scene_data)] diff --git a/src/memos/mem_reader/simple_struct.py b/src/memos/mem_reader/simple_struct.py index 29ce49d90..94b0929f6 100644 --- a/src/memos/mem_reader/simple_struct.py +++ b/src/memos/mem_reader/simple_struct.py @@ -1,26 +1,27 @@ import concurrent.futures import copy import json -import os import re import traceback from abc import ABC -from datetime import datetime, timezone -from typing import Any +from typing import Any, TypeAlias from tqdm import tqdm from memos import log from memos.chunkers import ChunkerFactory from memos.configs.mem_reader import SimpleStructMemReaderConfig -from memos.configs.parser import ParserConfigFactory from memos.context.context import ContextThreadPoolExecutor from memos.embedders.factory import EmbedderFactory from memos.llms.factory import LLMFactory from memos.mem_reader.base import BaseMemReader -from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata -from memos.parsers.factory import ParserFactory +from memos.mem_reader.read_multi_model import coerce_scene_data +from memos.memories.textual.item import ( + SourceMessage, + TextualMemoryItem, + TreeNodeTextualMemoryMetadata, +) from memos.templates.mem_reader_prompts import ( CUSTOM_TAGS_INSTRUCTION, CUSTOM_TAGS_INSTRUCTION_ZH, @@ -31,9 +32,42 @@ SIMPLE_STRUCT_MEM_READER_PROMPT, SIMPLE_STRUCT_MEM_READER_PROMPT_ZH, ) +from memos.types import MessagesType +from memos.types.openai_chat_completion_types import ( + ChatCompletionAssistantMessageParam, + ChatCompletionContentPartTextParam, + ChatCompletionSystemMessageParam, + ChatCompletionToolMessageParam, + ChatCompletionUserMessageParam, + File, +) from memos.utils import timed +class ParserFactory: + """Placeholder required by test suite.""" + + @staticmethod + def from_config(_config): + return None + + +ChatMessageClasses = ( + ChatCompletionSystemMessageParam, + ChatCompletionUserMessageParam, + ChatCompletionAssistantMessageParam, + ChatCompletionToolMessageParam, +) + +RawContentClasses = (ChatCompletionContentPartTextParam, File) +MessageDict: TypeAlias = dict[str, Any] # (Deprecated) not supported in the future +SceneDataInput: TypeAlias = ( + list[list[MessageDict]] # (Deprecated) legacy chat example: scenes -> messages + | list[str] # (Deprecated) legacy doc example: list of paths / pure text + | list[MessagesType] # new: list of scenes (each scene is MessagesType) +) + + logger = log.get_logger(__name__) PROMPT_DICT = { "chat": { @@ -89,7 +123,7 @@ def detect_lang(text): return "en" -def _build_node(idx, message, info, scene_file, llm, parse_json_result, embedder): +def _build_node(idx, message, info, source_info, llm, parse_json_result, embedder): # generate try: raw = llm.generate(message) @@ -139,7 +173,7 @@ def _build_node(idx, message, info, scene_file, llm, parse_json_result, embedder key=key, embedding=embedding, usage=[], - sources=[{"type": "doc", "doc_path": f"{scene_file}_{idx}"}], + sources=source_info, background="", confidence=0.99, type="fact", @@ -390,7 +424,7 @@ def _process_transfer_chat_data( return chat_read_nodes def get_memory( - self, scene_data: list, type: str, info: dict[str, Any], mode: str = "fine" + self, scene_data: SceneDataInput, type: str, info: dict[str, Any], mode: str = "fine" ) -> list[list[TextualMemoryItem]]: """ Extract and classify memory content from scene_data. @@ -399,7 +433,7 @@ def get_memory( Args: scene_data: List of dialogue information or document paths - type: Type of scene_data: ['doc', 'chat'] + type: (Deprecated) not supported in the future. Type of scene_data: ['doc', 'chat'] info: Dictionary containing user_id and session_id. Must be in format: {"user_id": "1111", "session_id": "2222"} Optional parameters: @@ -428,11 +462,35 @@ def get_memory( if not all(isinstance(info[field], str) for field in required_fields): raise ValueError("user_id and session_id must be strings") - scene_data = self._complete_chat_time(scene_data, type) - list_scene_data_info = self.get_scene_data_info(scene_data, type) - memory_list = [] + # Backward compatibility, after coercing scene_data, we only tackle + # with standard scene_data type: MessagesType + standard_scene_data = coerce_scene_data(scene_data, type) + return self._read_memory(standard_scene_data, type, info, mode) + def _read_memory( + self, messages: list[MessagesType], type: str, info: dict[str, Any], mode: str = "fine" + ): + """ + 1. raw file: + [ + [ + {"type": "file", "file": "str"} + ], + [ + {"type": "file", "file": "str"} + ],... + ] + 2. text chat: + scene_data = [ + [ {role: user, ...}, {role: assistant, ...}, ... ], + [ {role: user, ...}, {role: assistant, ...}, ... ], + [ ... ] + ] + """ + list_scene_data_info = self.get_scene_data_info(messages, type) + + memory_list = [] if type == "chat": processing_func = self._process_chat_data elif type == "doc": @@ -490,87 +548,152 @@ def fine_transfer_simple_mem( logger.error(traceback.format_exc()) return memory_list - def get_scene_data_info(self, scene_data: list, type: str) -> list[str]: + def get_scene_data_info(self, scene_data: list, type: str) -> list[list[Any]]: """ - Get raw information from scene_data. - If scene_data contains dictionaries, convert them to strings. - If scene_data contains file paths, parse them using the parser. - - Args: - scene_data: List of dialogue information or document paths - type: Type of scene data: ['doc', 'chat'] - Returns: - List of strings containing the processed scene data + Convert normalized MessagesType scenes into typical MessagesType this reader can + handle. + SimpleStructMemReader only supports text-only chat messages with roles. + For chat scenes we: + - skip unsupported scene types (e.g. `str` scenes) + - drop non-dict messages + - keep only roles in {user, assistant, system} + - coerce OpenAI multimodal `content` (list[parts]) into a single plain-text string + - then apply the existing windowing logic (<=10 messages with 2-message overlap) + For doc scenes we pass through; doc handling is done in `_process_doc_data`. """ - results = [] + results: list[list[Any]] = [] if type == "chat": + allowed_roles = {"user", "assistant", "system"} for items in scene_data: + if isinstance(items, str): + logger.warning( + "SimpleStruct MemReader does not support " + "str message data now, your messages " + f"contains {items}, skipping" + ) + continue + if not isinstance(items, list): + logger.warning( + "SimpleStruct MemReader expects message as " + f"list[dict], your messages contains" + f"{items}, skipping" + ) + continue + # Filter messages within this message result = [] - for i, item in enumerate(items): - result.append(item) - if len(result) >= 10: - results.append(result) - context = copy.deepcopy(result[-2:]) if i + 1 < len(items) else [] - result = context - if result: - results.append(result) + for _i, item in enumerate(items): + if not isinstance(item, dict): + logger.warning( + "SimpleStruct MemReader expects message as " + f"list[dict], your messages contains" + f"{item}, skipping" + ) + continue + role = item.get("role") or "" + role = role if isinstance(role, str) else str(role) + role = role.strip().lower() + if role not in allowed_roles: + logger.warning( + f"SimpleStruct MemReader expects message with " + f"role in {allowed_roles}, your messages contains" + f"role {role}, skipping" + ) + continue + + content = item.get("content", "") + if not isinstance(content, str): + logger.warning( + f"SimpleStruct MemReader expects message content " + f"with str, your messages content" + f"is {content!s}, skipping" + ) + continue + if not content: + continue + + result.append( + { + "role": role, + "content": content, + "chat_time": item.get("chat_time", ""), + } + ) + if not result: + continue + window = [] + for i, item in enumerate(result): + window.append(item) + if len(window) >= 10: + results.append(window) + context = copy.deepcopy(window[-2:]) if i + 1 < len(result) else [] + window = context + + if window: + results.append(window) elif type == "doc": - parser_config = ParserConfigFactory.model_validate( - { - "backend": "markitdown", - "config": {}, - } - ) - parser = ParserFactory.from_config(parser_config) - for item in scene_data: - try: - if os.path.exists(item): - try: - parsed_text = parser.parse(item) - results.append({"file": item, "text": parsed_text}) - except Exception as e: - logger.error(f"[SceneParser] Error parsing {item}: {e}") - continue - else: - parsed_text = item - results.append({"file": "pure_text", "text": parsed_text}) - except Exception as e: - print(f"Error parsing file {item}: {e!s}") - + results = scene_data return results - def _complete_chat_time(self, scene_data: list[list[dict]], type: str): - if type != "chat": - return scene_data - complete_scene_data = [] + def _process_doc_data(self, scene_data_info, info, **kwargs): + """ + Process doc data after being normalized to new RawMessageList format. + + scene_data_info format (length always == 1): + [ + {"type": "file", "file": {"filename": "...", "file_data": "..."}} + ] + OR + [ + {"type": "text", "text": "..."} + ] + + Behavior: + - Merge all text/file_data into a single "full text" + - Chunk the text + - Build prompts + - Send to LLM + - Parse results and build memory nodes + """ + mode = kwargs.get("mode", "fine") + if mode == "fast": + raise NotImplementedError - for items in scene_data: - chat_time_value = None + custom_tags = info.pop("custom_tags", None) - for item in items: - if "chat_time" in item: - chat_time_value = item["chat_time"] - break + if not scene_data_info or len(scene_data_info) != 1: + logger.error( + "[DocReader] scene_data_info must contain exactly 1 item after normalization" + ) + return [] - if chat_time_value is None: - session_date = datetime.now(timezone.utc) - date_format = "%I:%M %p on %d %B, %Y UTC" - chat_time_value = session_date.strftime(date_format) + item = scene_data_info[0] + text_content = "" + source_info_list = [] - for i in range(len(items)): - if "chat_time" not in items[i]: - items[i]["chat_time"] = chat_time_value + # Determine content and source metadata + if item.get("type") == "file": + f = item["file"] + filename = f.get("filename") or "document" + file_data = f.get("file_data") or "" - complete_scene_data.append(items) - return complete_scene_data + text_content = file_data + source_dict = { + "type": "doc", + "doc_path": filename, + } + source_info_list = [SourceMessage(**source_dict)] - def _process_doc_data(self, scene_data_info, info, **kwargs): - mode = kwargs.get("mode", "fine") - if mode == "fast": - raise NotImplementedError - chunks = self.chunker.chunk(scene_data_info["text"]) - custom_tags = info.pop("custom_tags", None) + elif item.get("type") == "text": + text_content = item.get("text", "") + source_info_list = [SourceMessage(type="doc", doc_path="inline-text")] + + text_content = (text_content or "").strip() + if not text_content: + logger.warning("[DocReader] Empty document text after normalization.") + return [] + + chunks = self.chunker.chunk(text_content) messages = [] for chunk in chunks: lang = detect_lang(chunk.text) @@ -586,7 +709,6 @@ def _process_doc_data(self, scene_data_info, info, **kwargs): messages.append(message) doc_nodes = [] - scene_file = scene_data_info["file"] with ContextThreadPoolExecutor(max_workers=50) as executor: futures = { @@ -595,7 +717,7 @@ def _process_doc_data(self, scene_data_info, info, **kwargs): idx, msg, info, - scene_file, + source_info_list, self.llm, self.parse_json_result, self.embedder, @@ -661,6 +783,3 @@ def _cheap_close(t: str) -> str: json: {s}" ) return {} - - def transform_memreader(self, data: dict) -> list[TextualMemoryItem]: - pass diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py index 4501dfee3..8f4a25a0b 100644 --- a/src/memos/multi_mem_cube/single_cube.py +++ b/src/memos/multi_mem_cube/single_cube.py @@ -547,9 +547,21 @@ def _process_text_mem( """ target_session_id = add_req.session_id or "default_session" + # Decide extraction mode: + # - async: always fast (ignore add_req.mode) + # - sync: use add_req.mode == "fast" to switch to fast pipeline, otherwise fine + if sync_mode == "async": + extract_mode = "fast" + else: # sync + extract_mode = "fast" if add_req.mode == "fast" else "fine" + self.logger.info( - f"[SingleCubeView] cube={user_context.mem_cube_id} " - f"Processing text memory with mode: {sync_mode}" + "[SingleCubeView] cube=%s Processing text memory " + "with sync_mode=%s, extract_mode=%s, add_mode=%s", + user_context.mem_cube_id, + sync_mode, + extract_mode, + add_req.mode, ) # Extract memories @@ -562,7 +574,7 @@ def _process_text_mem( "user_id": add_req.user_id, "session_id": target_session_id, }, - mode="fast" if sync_mode == "async" else "fine", + mode=extract_mode, ) flattened_local = [mm for m in memories_local for mm in m] self.logger.info(f"Memory extraction completed for user {add_req.user_id}") diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_content_part_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_content_part_param.py index a5e740791..99b232943 100644 --- a/src/memos/types/openai_chat_completion_types/chat_completion_content_part_param.py +++ b/src/memos/types/openai_chat_completion_types/chat_completion_content_part_param.py @@ -17,6 +17,8 @@ class FileFile(TypedDict, total=False): """ The base64 encoded file data, used when passing the file to the model as a string. + or a url. + or just string which is the content of the file. """ file_id: str diff --git a/tests/mem_reader/test_coarse_memory_type.py b/tests/mem_reader/test_coarse_memory_type.py new file mode 100644 index 000000000..bd90d6a69 --- /dev/null +++ b/tests/mem_reader/test_coarse_memory_type.py @@ -0,0 +1,173 @@ +#!/usr/bin/env python3 +""" +Rewritten test script for the updated coerce_scene_data function. + +This version matches the NEW behavior: +- Local file path → parsed into text (type="text") +- Remote URL / unknown path → treated as file, with file_data +- Plain text kept as text +- Chat mode passthrough +- Fallback cases handled properly +""" + +import os +import sys +import tempfile + + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "src")) + +from memos.mem_reader.simple_struct import coerce_scene_data + + +# ------------------------------------------------------------------------------ +# Helper utilities +# ------------------------------------------------------------------------------ + + +def assert_equal(actual, expected, message): + if actual != expected: + print("\n❌ ASSERTION FAILED") + print(message) + print("Expected:") + print(expected) + print("Actual:") + print(actual) + raise AssertionError(message) + + +def create_temp_file(content="hello world", suffix=".txt"): + """Create a temporary local file. Returns its path and content.""" + fd, path = tempfile.mkstemp(suffix=suffix) + with os.fdopen(fd, "w") as f: + f.write(content) + return path, content + + +# ------------------------------------------------------------------------------ +# Tests begin +# ------------------------------------------------------------------------------ + + +def test_empty_inputs(): + result = coerce_scene_data([], "chat") + assert_equal(result, [], "Empty input should return empty list") + + +def test_chat_passthrough(): + result = coerce_scene_data(["hello"], "chat") + assert_equal(result, ["hello"], "Chat mode should passthrough list[str]") + + msg_list = [{"role": "user", "content": "hi"}] + result = coerce_scene_data([msg_list], "chat") + assert_equal(result, [msg_list], "Chat mode should passthrough MessageList") + + +def test_doc_local_file(): + local_path, content = create_temp_file("test local file content") + result = coerce_scene_data([local_path], "doc") + + filename = os.path.basename(local_path) + expected = [ + [ + { + "type": "file", + "file": { + "filename": filename, + "file_data": "test local file content", + }, + } + ] + ] + assert_equal(result, expected, "Local file should be wrapped as file with parsed text") + + +def test_doc_remote_url(): + url = "https://example.com/file.pdf" + result = coerce_scene_data([url], "doc") + + filename = "file.pdf" + expected = [[{"type": "file", "file": {"filename": filename, "file_data": url}}]] + assert_equal(result, expected, "Remote URL should be treated as file_data string") + + +def test_doc_unknown_path(): + path = "/nonexistent/path/file.docx" + result = coerce_scene_data([path], "doc") + + expected = [[{"type": "file", "file": {"filename": "file.docx", "file_data": path}}]] + assert_equal(result, expected, "Unknown path should be treated as file_data") + + +def test_doc_plain_text(): + text = "this is plain text" + result = coerce_scene_data([text], "doc") + + expected = [[{"type": "text", "text": "this is plain text"}]] + assert_equal(result, expected, "Plain text should produce text content") + + +def test_doc_mixed(): + local_path, content = create_temp_file("local file content") + url = "https://example.com/x.pdf" + plain = "hello world" + + result = coerce_scene_data([plain, local_path, url], "doc") + + filename = os.path.basename(local_path) + expected = [ + [{"type": "text", "text": plain}], + [ + { + "type": "file", + "file": { + "filename": filename, + "file_data": "local file content", + }, + } + ], + [ + { + "type": "file", + "file": { + "filename": "x.pdf", + "file_data": url, + }, + } + ], + ] + assert_equal(result, expected, "Mixed doc inputs should be normalized correctly") + + +def test_fallback(): + result = coerce_scene_data([123], "chat") + expected = ["[123]"] + assert_equal(result, expected, "Unexpected input should fallback to str(scene_data)") + + +# ------------------------------------------------------------------------------ +# Main +# ------------------------------------------------------------------------------ + + +def main(): + print("\n========================================") + print("Running NEW tests for coerce_scene_data") + print("========================================") + + test_empty_inputs() + test_chat_passthrough() + test_doc_local_file() + test_doc_remote_url() + test_doc_unknown_path() + test_doc_plain_text() + test_doc_mixed() + test_fallback() + + print("\n========================================") + print("✅ All tests passed!") + print("========================================") + + +if __name__ == "__main__": + main() diff --git a/tests/mem_reader/test_simple_structure.py b/tests/mem_reader/test_simple_structure.py index 5407ae543..f81356886 100644 --- a/tests/mem_reader/test_simple_structure.py +++ b/tests/mem_reader/test_simple_structure.py @@ -4,7 +4,6 @@ from unittest.mock import MagicMock, patch from memos.chunkers import ChunkerFactory -from memos.chunkers.base import Chunk from memos.configs.mem_reader import SimpleStructMemReaderConfig from memos.embedders.factory import EmbedderFactory from memos.llms.factory import LLMFactory @@ -69,27 +68,6 @@ def test_process_chat_data(self): ) self.assertEqual(result[0].metadata.user_id, "user1") - def test_process_doc_data(self): - """Test processing document chunks into memory items.""" - scene_data_info = {"file": "tests/mem_reader/test.txt", "text": "Parsed document text"} - info = {"user_id": "user1", "session_id": "session1"} - - # Mock LLM response - mock_response = ( - '{"value": "A sample document about testing.", "tags": ["document"], "key": "title"}' - ) - self.reader.llm.generate.return_value = mock_response - self.reader.chunker.chunk.return_value = [ - Chunk(text="Parsed document text", token_count=3, sentences=["Parsed document text"]) - ] - self.reader.parse_json_result = lambda x: json.loads(x) - - result = self.reader._process_doc_data(scene_data_info, info) - - self.assertIsInstance(result, list) - self.assertIsInstance(result[0], TextualMemoryItem) - self.assertIn("sample document", result[0].memory) - def test_get_scene_data_info_with_chat(self): """Test extracting chat info from scene data.""" scene_data = [ @@ -124,21 +102,6 @@ def test_get_scene_data_info_with_chat(self): }, ) - @patch("memos.mem_reader.simple_struct.ParserFactory") - def test_get_scene_data_info_with_doc(self, mock_parser_factory): - """Test parsing document files.""" - parser_instance = MagicMock() - parser_instance.parse.return_value = "Parsed document text.\n" - mock_parser_factory.from_config.return_value = parser_instance - - scene_data = ["/fake/path/to/doc.txt"] - with patch("os.path.exists", return_value=True): - result = self.reader.get_scene_data_info(scene_data, type="doc") - - self.assertIsInstance(result, list) - self.assertEqual(result[0]["text"], "Parsed document text.\n") - parser_instance.parse.assert_called_once_with("/fake/path/to/doc.txt") - def test_parse_json_result_success(self): """Test successful JSON parsing.""" raw_response = '{"summary": "Test summary", "tags": ["test"]}'