Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion evaluation/.env-example
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,3 @@ MEMU_API_KEY="mu_xxx"
SUPERMEMORY_API_KEY="sm_xxx"
MEMOBASE_API_KEY="xxx"
MEMOBASE_PROJECT_URL="http://***.***.***.***:8019"

128 changes: 83 additions & 45 deletions src/memos/api/routers/server_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@
ExtractorFactory,
RetrieverFactory,
)
from memos.memories.textual.simple_preference import SimplePreferenceTextMemory
from memos.memories.textual.simple_tree import SimpleTreeTextMemory
from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager
from memos.memories.textual.tree_text_memory.retrieve.internet_retriever_factory import (
InternetRetrieverFactory,
Expand Down Expand Up @@ -195,18 +197,43 @@ def init_server():
internet_retriever = InternetRetrieverFactory.from_config(
internet_retriever_config, embedder=embedder
)

# Initialize memory manager
memory_manager = MemoryManager(
graph_db,
embedder,
llm,
memory_size=_get_default_memory_size(default_cube_config),
is_reorganize=getattr(default_cube_config.text_mem.config, "reorganize", False),
)

# Initialize text memory
text_mem = SimpleTreeTextMemory(
llm=llm,
embedder=embedder,
mem_reader=mem_reader,
graph_db=graph_db,
reranker=reranker,
memory_manager=memory_manager,
config=default_cube_config.text_mem.config,
internet_retriever=internet_retriever,
)

pref_extractor = ExtractorFactory.from_config(
config_factory=pref_extractor_config,
llm_provider=llm,
embedder=embedder,
vector_db=vector_db,
)

pref_adder = AdderFactory.from_config(
config_factory=pref_adder_config,
llm_provider=llm,
embedder=embedder,
vector_db=vector_db,
text_mem=text_mem,
)

pref_retriever = RetrieverFactory.from_config(
config_factory=pref_retriever_config,
llm_provider=llm,
Expand All @@ -215,33 +242,29 @@ def init_server():
vector_db=vector_db,
)

# Initialize memory manager
memory_manager = MemoryManager(
graph_db,
embedder,
llm,
memory_size=_get_default_memory_size(default_cube_config),
is_reorganize=getattr(default_cube_config.text_mem.config, "reorganize", False),
# Initialize preference memory
pref_mem = SimplePreferenceTextMemory(
extractor_llm=llm,
vector_db=vector_db,
embedder=embedder,
reranker=reranker,
extractor=pref_extractor,
adder=pref_adder,
retriever=pref_retriever,
)

mos_server = MOSServer(
mem_reader=mem_reader,
llm=llm,
online_bot=False,
)

# Create MemCube with pre-initialized memory instances
naive_mem_cube = NaiveMemCube(
llm=llm,
embedder=embedder,
mem_reader=mem_reader,
graph_db=graph_db,
reranker=reranker,
internet_retriever=internet_retriever,
memory_manager=memory_manager,
default_cube_config=default_cube_config,
vector_db=vector_db,
pref_extractor=pref_extractor,
pref_adder=pref_adder,
pref_retriever=pref_retriever,
text_mem=text_mem,
pref_mem=pref_mem,
act_mem=None,
para_mem=None,
)

# Initialize Scheduler
Expand Down Expand Up @@ -279,6 +302,8 @@ def init_server():
pref_extractor,
pref_adder,
pref_retriever,
text_mem,
pref_mem,
)


Expand All @@ -300,6 +325,8 @@ def init_server():
pref_extractor,
pref_adder,
pref_retriever,
text_mem,
pref_mem,
) = init_server()


Expand Down Expand Up @@ -361,36 +388,46 @@ def search_memories(search_req: APISearchRequest):
search_mode = search_req.mode

def _search_text():
if search_mode == SearchMode.FAST:
formatted_memories = fast_search_memories(
search_req=search_req, user_context=user_context
)
elif search_mode == SearchMode.FINE:
formatted_memories = fine_search_memories(
search_req=search_req, user_context=user_context
)
elif search_mode == SearchMode.MIXTURE:
formatted_memories = mix_search_memories(
search_req=search_req, user_context=user_context
)
else:
logger.error(f"Unsupported search mode: {search_mode}")
raise HTTPException(status_code=400, detail=f"Unsupported search mode: {search_mode}")
return formatted_memories
try:
if search_mode == SearchMode.FAST:
formatted_memories = fast_search_memories(
search_req=search_req, user_context=user_context
)
elif search_mode == SearchMode.FINE:
formatted_memories = fine_search_memories(
search_req=search_req, user_context=user_context
)
elif search_mode == SearchMode.MIXTURE:
formatted_memories = mix_search_memories(
search_req=search_req, user_context=user_context
)
else:
logger.error(f"Unsupported search mode: {search_mode}")
raise HTTPException(
status_code=400, detail=f"Unsupported search mode: {search_mode}"
)
return formatted_memories
except Exception as e:
logger.error("Error in search_text: %s; traceback: %s", e, traceback.format_exc())
return []

def _search_pref():
if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() != "true":
return []
results = naive_mem_cube.pref_mem.search(
query=search_req.query,
top_k=search_req.pref_top_k,
info={
"user_id": search_req.user_id,
"session_id": search_req.session_id,
"chat_history": search_req.chat_history,
},
)
return [_format_memory_item(data) for data in results]
try:
results = naive_mem_cube.pref_mem.search(
query=search_req.query,
top_k=search_req.pref_top_k,
info={
"user_id": search_req.user_id,
"session_id": search_req.session_id,
"chat_history": search_req.chat_history,
},
)
return [_format_memory_item(data) for data in results]
except Exception as e:
logger.error("Error in _search_pref: %s; traceback: %s", e, traceback.format_exc())
return []

with ContextThreadPoolExecutor(max_workers=2) as executor:
text_future = executor.submit(_search_text)
Expand Down Expand Up @@ -601,6 +638,7 @@ def _process_pref_mem() -> list[dict[str, str]]:
info={
"user_id": add_req.user_id,
"session_id": target_session_id,
"mem_cube_id": add_req.mem_cube_id,
},
)
pref_ids_local: list[str] = naive_mem_cube.pref_mem.add(pref_memories_local)
Expand Down
62 changes: 13 additions & 49 deletions src/memos/mem_cube/navie.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,13 @@

from typing import Literal

from memos.configs.mem_cube import GeneralMemCubeConfig
from memos.configs.utils import get_json_file_model_schema
from memos.embedders.base import BaseEmbedder
from memos.exceptions import ConfigurationError, MemCubeError
from memos.graph_dbs.base import BaseGraphDB
from memos.llms.base import BaseLLM
from memos.log import get_logger
from memos.mem_cube.base import BaseMemCube
from memos.mem_reader.base import BaseMemReader
from memos.memories.activation.base import BaseActMemory
from memos.memories.parametric.base import BaseParaMemory
from memos.memories.textual.base import BaseTextMemory
from memos.memories.textual.prefer_text_memory.adder import BaseAdder
from memos.memories.textual.prefer_text_memory.extractor import BaseExtractor
from memos.memories.textual.prefer_text_memory.retrievers import BaseRetriever
from memos.memories.textual.simple_preference import SimplePreferenceTextMemory
from memos.memories.textual.simple_tree import SimpleTreeTextMemory
from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager
from memos.reranker.base import BaseReranker
from memos.vec_dbs.base import BaseVecDB


logger = get_logger(__name__)
Expand All @@ -32,51 +19,28 @@ class NaiveMemCube(BaseMemCube):

def __init__(
self,
llm: BaseLLM,
embedder: BaseEmbedder,
mem_reader: BaseMemReader,
graph_db: BaseGraphDB,
reranker: BaseReranker,
memory_manager: MemoryManager,
default_cube_config: GeneralMemCubeConfig,
vector_db: BaseVecDB,
internet_retriever: None = None,
pref_extractor: BaseExtractor | None = None,
pref_adder: BaseAdder | None = None,
pref_retriever: BaseRetriever | None = None,
text_mem: BaseTextMemory | None = None,
pref_mem: BaseTextMemory | None = None,
act_mem: BaseActMemory | None = None,
para_mem: BaseParaMemory | None = None,
):
"""Initialize the MemCube with a configuration."""
self._text_mem: BaseTextMemory | None = SimpleTreeTextMemory(
llm,
embedder,
mem_reader,
graph_db,
reranker,
memory_manager,
default_cube_config.text_mem.config,
internet_retriever,
)
self._act_mem: BaseActMemory | None = None
self._para_mem: BaseParaMemory | None = None
self._pref_mem: BaseTextMemory | None = SimplePreferenceTextMemory(
extractor_llm=llm,
vector_db=vector_db,
embedder=embedder,
reranker=reranker,
extractor=pref_extractor,
adder=pref_adder,
retriever=pref_retriever,
)
"""Initialize the MemCube with memory instances."""
self._text_mem: BaseTextMemory = text_mem
self._act_mem: BaseActMemory | None = act_mem
self._para_mem: BaseParaMemory | None = para_mem
self._pref_mem: BaseTextMemory | None = pref_mem

def load(
self, dir: str, memory_types: list[Literal["text_mem", "act_mem", "para_mem"]] | None = None
self,
dir: str,
memory_types: list[Literal["text_mem", "act_mem", "para_mem", "pref_mem"]] | None = None,
) -> None:
"""Load memories.
Args:
dir (str): The directory containing the memory files.
memory_types (list[str], optional): List of memory types to load.
If None, loads all available memory types.
Options: ["text_mem", "act_mem", "para_mem"]
Options: ["text_mem", "act_mem", "para_mem", "pref_mem"]
"""
loaded_schema = get_json_file_model_schema(os.path.join(dir, self.config.config_filename))
if loaded_schema != self.config.model_schema:
Expand Down
1 change: 1 addition & 0 deletions src/memos/memories/textual/item.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ class PreferenceTextualMemoryMetadata(TextualMemoryMetadata):
embedding: list[float] | None = Field(default=None, description="Vector of the dialog.")
preference: str | None = Field(default=None, description="Preference.")
created_at: str | None = Field(default=None, description="Timestamp of the dialog.")
mem_cube_id: str | None = Field(default=None, description="ID of the MemCube.")


class TextualMemoryItem(BaseModel):
Expand Down
Loading
Loading