Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
220 changes: 97 additions & 123 deletions src/memos/mem_reader/multi_modal_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
import re
import traceback
import uuid

from typing import TYPE_CHECKING, Any

Expand All @@ -15,6 +16,8 @@
from memos.mem_reader.simple_struct import PROMPT_DICT, SimpleStructMemReader
from memos.mem_reader.utils import parse_json_result
from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata
from memos.plugins.hook_defs import H
from memos.plugins.hooks import trigger_single_hook
from memos.templates.mem_reader_prompts import MEMORY_MERGE_PROMPT_EN, MEMORY_MERGE_PROMPT_ZH
from memos.templates.tool_mem_prompts import TOOL_TRAJECTORY_PROMPT_EN, TOOL_TRAJECTORY_PROMPT_ZH
from memos.types import MessagesType
Expand Down Expand Up @@ -58,6 +61,8 @@ def __init__(self, config: MultiModalStructMemReaderConfig):
simple_config = SimpleStructMemReaderConfig(**config_dict)
super().__init__(simple_config)

self.memory_version_switch = getattr(config, "memory_version_switch", "off")

# Image parser LLM (requires vision model)
# Falls back to general_llm if not configured (general_llm itself falls back to main llm)
self.image_parser_llm = (
Expand Down Expand Up @@ -124,33 +129,48 @@ def _split_large_memory_item(
try:
chunks = self.chunker.chunk(item_text)
split_items = []
source_info = dict(item.metadata.info or {})
source_internal_info = dict(item.metadata.internal_info or {})
ingest_batch_id = str(source_internal_info.get("ingest_batch_id") or uuid.uuid4())
chunk_total = len(chunks)

def _create_chunk_item(chunk):
def _create_chunk_item(chunk_idx: int, chunk):
# Different chunkers are not fully consistent:
# some return Chunk-like objects with `.text`, while others return raw strings.
chunk_text = chunk.text if hasattr(chunk, "text") else chunk
if not chunk_text or not chunk_text.strip():
return None
chunk_info = {
"user_id": item.metadata.user_id,
"session_id": item.metadata.session_id,
**source_info,
}
chunk_internal_info = {
**source_internal_info,
"ingest_batch_id": ingest_batch_id,
"chunk_index": chunk_idx,
"chunk_total": chunk_total,
}
# Create a new memory item for each chunk, preserving original metadata
split_item = self._make_memory_item(
value=chunk_text,
info={
"user_id": item.metadata.user_id,
"session_id": item.metadata.session_id,
**(item.metadata.info or {}),
},
info=chunk_info,
memory_type=item.metadata.memory_type,
tags=item.metadata.tags or [],
key=item.metadata.key,
sources=item.metadata.sources or [],
background=item.metadata.background or "",
need_embed=False,
)
split_item.metadata.internal_info = chunk_internal_info
return split_item

# Use a thread pool to process chunks in parallel, but keep the original order
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
futures = [executor.submit(_create_chunk_item, chunk) for chunk in chunks]
futures = [
executor.submit(_create_chunk_item, chunk_idx, chunk)
for chunk_idx, chunk in enumerate(chunks)
]
for future in futures:
split_item = future.result()
if split_item is not None:
Expand Down Expand Up @@ -306,6 +326,7 @@ def _build_window_from_items(
all_sources = []
roles = set()
aggregated_file_ids: list[str] = []
ingest_batch_ids: set[str] = set()

for item in items:
if item.memory:
Expand Down Expand Up @@ -334,6 +355,11 @@ def _build_window_from_items(
for fid in item_file_ids:
if fid and fid not in aggregated_file_ids:
aggregated_file_ids.append(fid)
item_internal_info = getattr(metadata, "internal_info", None)
if isinstance(item_internal_info, dict):
batch_id = item_internal_info.get("ingest_batch_id")
if batch_id:
ingest_batch_ids.add(str(batch_id))

# Determine memory_type based on roles (same logic as simple_struct)
# UserMemory if only user role, else LongTermMemory
Expand Down Expand Up @@ -368,7 +394,6 @@ def _build_window_from_items(
info_ = info.copy()
user_id = info_.pop("user_id", "")
session_id = info_.pop("session_id", "")

# Create memory item without embedding (set to None, will be filled in batch)
aggregated_item = TextualMemoryItem(
memory=merged_text,
Expand All @@ -389,6 +414,10 @@ def _build_window_from_items(
**extra_kwargs,
),
)
if len(ingest_batch_ids) == 1:
aggregated_item.metadata.internal_info = {
"ingest_batch_id": next(iter(ingest_batch_ids))
}

return aggregated_item

Expand Down Expand Up @@ -458,6 +487,9 @@ def _get_llm_response(

if self.config.remove_prompt_example and examples:
prompt = prompt.replace(examples, "")

logger.info(f"[MultiModalParser] Process String Fine Prompt: {prompt}")

messages = [{"role": "user", "content": prompt}]
try:
response_text = self.llm.generate(messages)
Expand Down Expand Up @@ -506,6 +538,7 @@ def _get_maybe_merged_memory(
sources: list,
**kwargs,
) -> dict:
# TODO: delete this function
"""
Check if extracted memory should be merged with similar existing memories.
If merge is needed, return merged memory dict with merged_from field.
Expand All @@ -520,102 +553,7 @@ def _get_maybe_merged_memory(
Returns:
Memory dict (possibly merged) with merged_from field if merged
"""
# If no graph_db or user_name, return original
if not self.graph_db or "user_name" not in kwargs:
return extracted_memory_dict
user_name = kwargs.get("user_name")

# Detect language
lang = "en"
if sources:
for source in sources:
if hasattr(source, "lang") and source.lang:
lang = source.lang
break
elif isinstance(source, dict) and source.get("lang"):
lang = source.get("lang")
break
if lang is None:
lang = detect_lang(mem_text)

# Search for similar memories
merge_threshold = kwargs.get("merge_similarity_threshold", 0.3)

try:
search_results = self.graph_db.search_by_embedding(
vector=self.embedder.embed(mem_text)[0],
top_k=20,
status="activated",
threshold=merge_threshold,
user_name=user_name,
)

if not search_results:
return extracted_memory_dict

# Get full memory details
similar_memory_ids = [r["id"] for r in search_results if r.get("id")]
similar_memories_list = [
self.graph_db.get_node(mem_id, include_embedding=False, user_name=user_name)
for mem_id in similar_memory_ids
]

# Filter out None and mode:fast memories
filtered_similar = []
for mem in similar_memories_list:
if not mem:
continue
mem_metadata = mem.get("metadata", {})
tags = mem_metadata.get("tags", [])
if isinstance(tags, list) and "mode:fast" in tags:
continue
filtered_similar.append(
{
"id": mem.get("id"),
"memory": mem.get("memory", ""),
}
)
logger.info(
f"Valid similar memories for {mem_text} is "
f"{len(filtered_similar)}: {filtered_similar}"
)

if not filtered_similar:
return extracted_memory_dict

# Create a temporary TextualMemoryItem for merge check
temp_memory_item = TextualMemoryItem(
memory=mem_text,
metadata=TreeNodeTextualMemoryMetadata(
user_id="",
session_id="",
memory_type=extracted_memory_dict.get("memory_type", "LongTermMemory"),
status="activated",
tags=extracted_memory_dict.get("tags", []),
key=extracted_memory_dict.get("key", ""),
),
)

# Try to merge with LLM
merge_result = self._merge_memories_with_llm(
temp_memory_item, filtered_similar, lang=lang
)

if merge_result:
# Return merged memory dict
merged_dict = extracted_memory_dict.copy()
merged_content = merge_result.get("value", mem_text)
merged_dict["value"] = merged_content
merged_from_ids = merge_result.get("merged_from", [])
merged_dict["merged_from"] = merged_from_ids
return merged_dict
else:
return extracted_memory_dict

except Exception as e:
logger.error(f"[MultiModalFine] Error in get_maybe_merged_memory: {e}")
# On error, return original
return extracted_memory_dict
return extracted_memory_dict

def _merge_memories_with_llm(
self,
Expand Down Expand Up @@ -717,6 +655,35 @@ def _process_one_item(
# Determine prompt type based on sources
prompt_type = self._determine_prompt_type(sources)

# ========== Stage 0: Memory version async extraction/update pipeline ==========
if getattr(self, "memory_version_switch", "off") == "on":
try:
user_name = kwargs.get("user_name")
should_use_version_pipeline = trigger_single_hook(
H.MEMORY_VERSION_PREPARE_UPDATES,
item=fast_item,
user_name=user_name,
judge_llm=self.general_llm,
)
if should_use_version_pipeline:
lang = detect_lang(kwargs.get("chat_history") or mem_str)
custom_tags_prompt_template = PROMPT_DICT["custom_tags"][lang]
new_items = trigger_single_hook(
H.MEMORY_VERSION_APPLY_UPDATES,
item=fast_item,
user_name=user_name,
version_llm=self.qwen_llm,
merge_llm=self.general_llm,
custom_tags=custom_tags,
custom_tags_prompt_template=custom_tags_prompt_template,
timeout_sec=30,
)
return new_items
except RuntimeError as ex:
logger.warning(f"[MultiModalFine] Memory version hook unavailable: {ex}")
except Exception as ex:
logger.warning(f"[MultiModalFine] Fine memory version pipeline failed: {ex}")

# ========== Stage 1: Normal extraction (without reference) ==========
try:
resp = self._get_llm_response(mem_str, custom_tags, sources, prompt_type)
Expand All @@ -727,14 +694,15 @@ def _process_one_item(
if resp.get("memory list", []):
for m in resp.get("memory list", []):
try:
# Check and merge with similar memories if needed
m_maybe_merged = self._get_maybe_merged_memory(
extracted_memory_dict=m,
mem_text=m.get("value", ""),
sources=sources,
original_query=mem_str,
**kwargs,
)
m_maybe_merged = m
if getattr(self, "memory_version_switch", "off") != "on":
m_maybe_merged = self._get_maybe_merged_memory(
extracted_memory_dict=m,
mem_text=m.get("value", ""),
sources=sources,
original_query=mem_str,
**kwargs,
)
# Normalize memory_type (same as simple_struct)
memory_type = (
m_maybe_merged.get("memory_type", "LongTermMemory")
Expand All @@ -752,8 +720,10 @@ def _process_one_item(
background=resp.get("summary", ""),
**extra_kwargs,
)
# Add merged_from to info if present
if "merged_from" in m_maybe_merged:
if (
getattr(self, "memory_version_switch", "off") != "on"
and "merged_from" in m_maybe_merged
):
node.metadata.info = node.metadata.info or {}
node.metadata.info["merged_from"] = m_maybe_merged["merged_from"]
fine_items.append(node)
Expand All @@ -762,13 +732,15 @@ def _process_one_item(
elif resp.get("value") and resp.get("key"):
try:
# Check and merge with similar memories if needed
resp_maybe_merged = self._get_maybe_merged_memory(
extracted_memory_dict=resp,
mem_text=resp.get("value", "").strip(),
sources=sources,
original_query=mem_str,
**kwargs,
)
resp_maybe_merged = resp
if getattr(self, "memory_version_switch", "off") != "on":
resp_maybe_merged = self._get_maybe_merged_memory(
extracted_memory_dict=resp,
mem_text=resp.get("value", "").strip(),
sources=sources,
original_query=mem_str,
**kwargs,
)
node = self._make_memory_item(
value=resp_maybe_merged.get("value", "").strip(),
info=info_per_item,
Expand All @@ -779,8 +751,10 @@ def _process_one_item(
background=resp.get("summary", ""),
**extra_kwargs,
)
# Add merged_from to info if present
if "merged_from" in resp_maybe_merged:
if (
getattr(self, "memory_version_switch", "off") != "on"
and "merged_from" in resp_maybe_merged
):
node.metadata.info = node.metadata.info or {}
node.metadata.info["merged_from"] = resp_maybe_merged["merged_from"]
fine_items.append(node)
Expand Down
18 changes: 11 additions & 7 deletions src/memos/memories/textual/item.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,9 @@ class ArchivedTextualMemory(BaseModel):
memory: str | None = Field(
default_factory=lambda: "", description="The content of the archived version of the memory."
)
update_type: Literal["conflict", "duplicate", "extract", "unrelated"] = Field(
update_type: Literal["conflict", "duplicate", "extract", "unrelated", "feedback"] = Field(
default="unrelated",
description="The type of the memory (e.g., `conflict`, `duplicate`, `extract`, `unrelated`).",
description="The type of the memory (e.g., `conflict`, `duplicate`, `extract`, `unrelated`, `feedback`).",
)
archived_memory_id: str | None = Field(
default=None,
Expand Down Expand Up @@ -106,15 +106,15 @@ class TextualMemoryMetadata(BaseModel):
default=None,
description="Whether or not the memory was created in fast mode, carrying raw memory contents that haven't been edited by llms yet.",
)
evolve_to: list[str] | None = Field(
evolve_to: list[str] = Field(
default_factory=list,
description="Only valid if a node was once a (raw)fast node. Recording which new memory nodes it 'evolves' to after llm extraction.",
description="Recording which new memory nodes it 'evolves' to after llm extraction.",
)
version: int | None = Field(
default=None,
version: int = Field(
default=1,
description="The version of the memory. Will be incremented when the memory is updated.",
)
history: list[ArchivedTextualMemory] | None = Field(
history: list[ArchivedTextualMemory] = Field(
default_factory=list,
description="Storing the archived versions of the memory. Only preserving core information of each version.",
)
Expand Down Expand Up @@ -146,6 +146,10 @@ class TextualMemoryMetadata(BaseModel):
default=None,
description="Arbitrary key-value pairs for additional metadata.",
)
internal_info: dict | None = Field(
default=None,
description="Internal algorithm metadata reserved for system use.",
)

model_config = ConfigDict(extra="allow")

Expand Down
Loading
Loading