Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,10 @@ def post_retrieve(
top_k: int,
user_name: str | None = None,
info=None,
plugin=False,
):
deduped = self._deduplicate_results(retrieved_results)
final_results = self._sort_and_trim(deduped, top_k)
final_results = self._sort_and_trim(deduped, top_k, plugin)
self._update_usage_history(final_results, info, user_name)
return final_results

Expand Down Expand Up @@ -170,6 +171,7 @@ def search(
top_k=top_k,
user_name=user_name,
info=None,
plugin=kwargs.get("plugin", False),
)

logger.info(f"[SEARCH] Done. Total {len(final_results)} results.")
Expand Down Expand Up @@ -281,6 +283,9 @@ def _retrieve_simple(
similarity_matrix = cosine_similarity_matrix(documents_embeddings)
selected_indices, _ = find_best_unrelated_subgroup(documents, similarity_matrix)
selected_items = [items[i] for i in selected_indices]
logger.info(
f"[SIMPLESEARCH] after unrelated subgroup selection items count: {len(selected_items)}"
)
return self.reranker.rerank(
query=query,
query_embedding=query_embeddings[0],
Expand Down Expand Up @@ -540,12 +545,14 @@ def _deduplicate_results(self, results):
return list(deduped.values())

@timed
def _sort_and_trim(self, results, top_k):
def _sort_and_trim(self, results, top_k, plugin=False):
"""Sort results by score and trim to top_k"""

sorted_results = sorted(results, key=lambda pair: pair[1], reverse=True)[:top_k]
final_items = []
for item, score in sorted_results:
if plugin and round(score, 2) == 0.00:
continue
meta_data = item.metadata.model_dump()
meta_data["relativity"] = score
final_items.append(
Expand Down
Loading