From 9fb206e86f1c2f6e4c23a47b0a1a41f4d53c3bec Mon Sep 17 00:00:00 2001 From: zhizhi <928570418@qq.com> Date: Wed, 15 Apr 2026 15:01:06 +0800 Subject: [PATCH 1/5] =?UTF-8?q?=E2=9C=A8=20Add=20display=20name=20to=20ind?= =?UTF-8?q?ex=20name=20mapping=20for=20KnowledgeBaseSearchTool?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Introduced `get_knowledge_name_map_by_index_names` function to retrieve a mapping of index names to their corresponding display names. - Updated `create_agent_config` and `create_tool_config_list` to utilize the new mapping for generating user-friendly summaries. - Enhanced `KnowledgeBaseSearchTool` to support conversion from display names to index names during queries. - Added tests to verify the functionality of the new mapping and its integration within the tool configuration process. --- backend/agents/create_agent_info.py | 17 +- backend/database/knowledge_db.py | 39 ++ .../services/tool_configuration_service.py | 10 + sdk/nexent/core/agents/nexent_agent.py | 4 +- .../core/tools/knowledge_base_search_tool.py | 39 ++ test/backend/agents/test_create_agent_info.py | 298 ++++++++++- test/backend/database/test_knowledge_db.py | 140 ++++- .../test_tool_configuration_service.py | 78 ++- test/sdk/core/agents/test_nexent_agent.py | 116 +++++ .../tools/test_knowledge_base_search_tool.py | 489 ++++++++++++++++++ 10 files changed, 1197 insertions(+), 33 deletions(-) diff --git a/backend/agents/create_agent_info.py b/backend/agents/create_agent_info.py index faa1aa2d8..6bef02bef 100644 --- a/backend/agents/create_agent_info.py +++ b/backend/agents/create_agent_info.py @@ -23,12 +23,12 @@ from database.agent_version_db import query_current_version_no from database.tool_db import search_tools_for_sub_agent from database.model_management_db import get_model_records, get_model_by_model_id +from database.knowledge_db import get_knowledge_name_map_by_index_names from database.client import minio_client from 
utils.model_name_utils import add_repo_to_name from utils.prompt_template_utils import get_agent_prompt_template from utils.config_utils import tenant_config_manager, get_model_name_from_config from consts.const import LOCAL_MCP_SERVER, MODEL_CONFIG_MAPPING, LANGUAGE, DATA_PROCESS_SERVICE -import re logger = logging.getLogger("create_agent_info") logger.setLevel(logging.DEBUG) @@ -262,11 +262,14 @@ async def create_agent_config( if "KnowledgeBaseSearchTool" == tool.class_name: index_names = tool.params.get("index_names") if index_names: + # Batch query to get display names (knowledge_name) for all index_names + knowledge_name_map = get_knowledge_name_map_by_index_names(index_names) for index_name in index_names: try: + display_name = knowledge_name_map.get(index_name, index_name) message = ElasticSearchService().get_summary(index_name=index_name) summary = message.get("summary", "") - knowledge_base_summary += f"**{index_name}**: {summary}\n\n" + knowledge_base_summary += f"**{display_name}**: {summary}\n\n" except Exception as e: logger.warning( f"Failed to get summary for knowledge base {index_name}: {e}") @@ -359,10 +362,20 @@ async def create_tool_config_list(agent_id, tenant_id, user_id, version_no: int tenant_id=tenant_id, model_name=rerank_model_name ) + # Build display_name to index_name mapping for LLM parameter conversion + index_names = param_dict.get("index_names", []) + display_name_to_index_map = {} + if index_names: + knowledge_name_map = get_knowledge_name_map_by_index_names(index_names) + # Reverse the mapping: display_name (knowledge_name) -> index_name + for idx_name, kb_name in knowledge_name_map.items(): + display_name_to_index_map[kb_name] = idx_name + tool_config.metadata = { "vdb_core": get_vector_db_core(), "embedding_model": get_embedding_model(tenant_id=tenant_id), "rerank_model": rerank_model, + "display_name_to_index_map": display_name_to_index_map, } elif tool_config.class_name in ["DifySearchTool", "DataMateSearchTool"]: rerank = 
param_dict.get("rerank", False) diff --git a/backend/database/knowledge_db.py b/backend/database/knowledge_db.py index df42e1888..0d13eb9f7 100644 --- a/backend/database/knowledge_db.py +++ b/backend/database/knowledge_db.py @@ -374,3 +374,42 @@ def get_index_name_by_knowledge_name(knowledge_name: str, tenant_id: str) -> str ) except SQLAlchemyError as e: raise e + + +def get_knowledge_name_map_by_index_names(index_names: List[str]) -> Dict[str, str]: + """ + Get a mapping from index_name to knowledge_name (display name) for the given index_names. + Used to build user-friendly knowledge base summaries in prompts. + + Args: + index_names: List of internal index names + + Returns: + Dict[str, str]: Mapping of index_name -> knowledge_name. + If a knowledge base is not found in the database, + the index_name itself is used as the fallback value. + """ + if not index_names: + return {} + + try: + with get_db_session() as session: + result = session.query( + KnowledgeRecord.index_name, + KnowledgeRecord.knowledge_name + ).filter( + KnowledgeRecord.index_name.in_(index_names), + KnowledgeRecord.delete_flag != 'Y' + ).all() + + knowledge_name_map = {} + for row in result: + knowledge_name_map[row.index_name] = row.knowledge_name + + for index_name in index_names: + if index_name not in knowledge_name_map: + knowledge_name_map[index_name] = index_name + + return knowledge_name_map + except SQLAlchemyError as e: + raise e diff --git a/backend/services/tool_configuration_service.py b/backend/services/tool_configuration_service.py index d7240db26..233faa6b7 100644 --- a/backend/services/tool_configuration_service.py +++ b/backend/services/tool_configuration_service.py @@ -36,6 +36,7 @@ search_last_tool_instance_by_tool_id, update_tool_table_from_scan_tool_list, ) +from database.knowledge_db import get_knowledge_name_map_by_index_names from mcpadapt.smolagents_adapter import _sanitize_function_name from services.file_management_service import get_llm_model from 
services.vectordatabase_service import get_embedding_model, get_rerank_model, get_vector_db_core @@ -714,11 +715,20 @@ def _validate_local_tool( if rerank and rerank_model_name: rerank_model = get_rerank_model(tenant_id=tenant_id, model_name=rerank_model_name) + # Build display_name to index_name mapping for LLM parameter conversion + index_names = instantiation_params.get("index_names", []) + display_name_to_index_map = {} + if index_names: + knowledge_name_map = get_knowledge_name_map_by_index_names(index_names) + for idx_name, kb_name in knowledge_name_map.items(): + display_name_to_index_map[kb_name] = idx_name + params = { **instantiation_params, 'vdb_core': vdb_core, 'embedding_model': embedding_model, 'rerank_model': rerank_model, + 'display_name_to_index_map': display_name_to_index_map, } tool_instance = tool_class(**params) elif tool_name in ["dify_search", "datamate_search"]: diff --git a/sdk/nexent/core/agents/nexent_agent.py b/sdk/nexent/core/agents/nexent_agent.py index 3878e05dd..7a9c67bdc 100644 --- a/sdk/nexent/core/agents/nexent_agent.py +++ b/sdk/nexent/core/agents/nexent_agent.py @@ -73,7 +73,7 @@ def create_local_tool(self, tool_config: ToolConfig): # These parameters have exclude=True and cannot be passed to __init__ # due to smolagents.tools.Tool wrapper restrictions filtered_params = {k: v for k, v in params.items() - if k not in ["vdb_core", "embedding_model", "observer", "rerank_model"]} + if k not in ["vdb_core", "embedding_model", "observer", "rerank_model", "display_name_to_index_map"]} # Create instance with only non-excluded parameters tools_obj = tool_class(**filtered_params) # Set excluded parameters directly as attributes after instantiation @@ -85,6 +85,8 @@ def create_local_tool(self, tool_config: ToolConfig): "embedding_model", None) if tool_config.metadata else None tools_obj.rerank_model = tool_config.metadata.get( "rerank_model", None) if tool_config.metadata else None + tools_obj.display_name_to_index_map = 
tool_config.metadata.get( + "display_name_to_index_map", {}) if tool_config.metadata else {} elif class_name in ["DifySearchTool", "DataMateSearchTool"]: # These parameters have exclude=True and cannot be passed to __init__ filtered_params = {k: v for k, v in params.items() diff --git a/sdk/nexent/core/tools/knowledge_base_search_tool.py b/sdk/nexent/core/tools/knowledge_base_search_tool.py index a8863caaf..dcb4c4cef 100644 --- a/sdk/nexent/core/tools/knowledge_base_search_tool.py +++ b/sdk/nexent/core/tools/knowledge_base_search_tool.py @@ -86,12 +86,18 @@ def __init__( description="The rerank model to use", default=None, exclude=True), vdb_core: VectorDatabaseCore = Field( description="Vector database client", default=None, exclude=True), + display_name_to_index_map: dict = Field( + description="Mapping from display_name (knowledge_name) to index_name", + default_factory=dict, exclude=True), ): """Initialize the KBSearchTool. Args: top_k (int, optional): Number of results to return. Defaults to 3. observer (MessageObserver, optional): Message observer instance. Defaults to None. + display_name_to_index_map (dict, optional): Mapping from display_name to index_name. + When LLM passes display_name as index_names parameter, it will be converted + to the actual index_name for ES queries. Raises: ValueError: If language is not supported @@ -106,16 +112,49 @@ def __init__( self.rerank = rerank self.rerank_model_name = rerank_model_name self.rerank_model = rerank_model + self.display_name_to_index_map = display_name_to_index_map self.record_ops = 1 # To record serial number self.running_prompt_zh = "知识库检索中..." self.running_prompt_en = "Searching the knowledge base..." + def _convert_to_index_names(self, names: List[str]) -> List[str]: + """Convert display names (knowledge_name) to index names if necessary. + + When LLM passes display_name as the index_names parameter, + this method converts it to the actual index_name for ES queries. 
+ + Args: + names: List of names that could be either display_name or index_name + + Returns: + List of actual index_names for ES queries + """ + # Handle FieldInfo case (smolagents doesn't expand Field defaults) + display_map = self.display_name_to_index_map + if isinstance(display_map, FieldInfo): + display_map = display_map.default + if not display_map: + return names + + converted_names = [] + for name in names: + # If the name is in the map as a display_name, convert it to index_name + if name in display_map: + converted_names.append(display_map[name]) + else: + # Otherwise, assume it's already an index_name + converted_names.append(name) + return converted_names + def forward(self, query: str, index_names: Optional[List[str]] = None) -> str: # Parse index_names from string (always required) search_index_names = index_names if index_names is not None else self.index_names + # Convert display names to index names if necessary + search_index_names = self._convert_to_index_names(search_index_names) + # Use the instance search_mode search_mode = self.search_mode diff --git a/test/backend/agents/test_create_agent_info.py b/test/backend/agents/test_create_agent_info.py index 7d4706c5e..0db11733e 100644 --- a/test/backend/agents/test_create_agent_info.py +++ b/test/backend/agents/test_create_agent_info.py @@ -90,6 +90,8 @@ def _create_stub_module(name: str, **attrs): sys.modules['database.tool_db'] = MagicMock() sys.modules['database.model_management_db'] = MagicMock() sys.modules['database.agent_version_db'] = MagicMock() +sys.modules['database.knowledge_db'] = MagicMock() +sys.modules['database.knowledge_db'].get_knowledge_name_map_by_index_names = MagicMock() sys.modules['services.vectordatabase_service'] = MagicMock() sys.modules['services.tenant_config_service'] = MagicMock() sys.modules['utils.prompt_template_utils'] = MagicMock() @@ -731,13 +733,13 @@ async def test_create_tool_config_list_with_knowledge_base_tool_metadata(self): 
mock_get_vector_db_core.assert_called_once() mock_embedding.assert_called_once_with(tenant_id="tenant_1") - # Verify metadata contains vdb_core, embedding_model and rerank_model - expected_metadata = { - "vdb_core": mock_vdb_core, - "embedding_model": mock_embedding_model, - "rerank_model": mock_rerank.return_value, - } - assert mock_tool_instance.metadata == expected_metadata + # Verify metadata contains vdb_core, embedding_model, rerank_model and display_name_to_index_map + assert "vdb_core" in mock_tool_instance.metadata + assert "embedding_model" in mock_tool_instance.metadata + assert "rerank_model" in mock_tool_instance.metadata + assert "display_name_to_index_map" in mock_tool_instance.metadata + # display_name_to_index_map should be empty dict when index_names is empty + assert mock_tool_instance.metadata["display_name_to_index_map"] == {} # Explicitly verify that old fields are NOT present assert "index_names" not in mock_tool_instance.metadata @@ -798,12 +800,11 @@ async def test_create_tool_config_list_with_knowledge_base_tool_multiple_tools(s assert len(result) == 2 - # Verify KnowledgeBaseSearchTool has correct metadata - assert mock_tool_kb.metadata == { - "vdb_core": "vdb_core_instance", - "embedding_model": "embedding_instance", - "rerank_model": mock_rerank.return_value, - } + # Verify KnowledgeBaseSearchTool has correct metadata including display_name_to_index_map + assert "vdb_core" in mock_tool_kb.metadata + assert "embedding_model" in mock_tool_kb.metadata + assert "rerank_model" in mock_tool_kb.metadata + assert "display_name_to_index_map" in mock_tool_kb.metadata # Verify OtherTool has no special metadata (should not have metadata attribute set) # Note: MagicMock will return a new MagicMock for unset attributes, so we check call_args @@ -851,11 +852,9 @@ async def test_create_tool_config_list_with_knowledge_base_tool_mixed_sources(se assert len(result) == 1 # Even for MCP-sourced KnowledgeBaseSearchTool, metadata should be set - assert 
mock_tool_instance.metadata == { - "vdb_core": "vdb_core", - "embedding_model": "embedding", - "rerank_model": mock_rerank.return_value, - } + assert "vdb_core" in mock_tool_instance.metadata + assert "embedding_model" in mock_tool_instance.metadata + assert "display_name_to_index_map" in mock_tool_instance.metadata @pytest.mark.asyncio async def test_create_tool_config_list_with_datamate_tool(self): @@ -1000,14 +999,13 @@ async def test_create_tool_config_list_multiple_tools_same_type(self): assert len(result) == 2 - # Both tools should have the same simplified metadata - expected_metadata = { - "vdb_core": "vdb_core", - "embedding_model": "embedding", - "rerank_model": mock_rerank.return_value, - } - assert mock_tool_1.metadata == expected_metadata - assert mock_tool_2.metadata == expected_metadata + # Both tools should have the same metadata including display_name_to_index_map + assert "vdb_core" in mock_tool_1.metadata + assert "embedding_model" in mock_tool_1.metadata + assert "rerank_model" in mock_tool_1.metadata + assert "display_name_to_index_map" in mock_tool_1.metadata + assert mock_tool_1.metadata["display_name_to_index_map"] == {} + assert mock_tool_2.metadata["display_name_to_index_map"] == {} @pytest.mark.asyncio async def test_create_tool_config_list_with_dify_tool(self): @@ -1899,6 +1897,9 @@ async def test_create_agent_config_with_knowledge_base_summary_filtering(self): patch( "backend.agents.create_agent_info._get_skill_script_tools" ) as mock_get_skill_tools, + patch( + "backend.agents.create_agent_info.get_knowledge_name_map_by_index_names" + ) as mock_get_knowledge_name_map, ): mock_search_agent.return_value = { "name": "test_agent", @@ -1941,6 +1942,8 @@ async def test_create_agent_config_with_knowledge_base_summary_filtering(self): mock_get_model_by_id.return_value = {"display_name": "test_model"} mock_get_skills.return_value = [] mock_get_skill_tools.return_value = [] + # Mock knowledge_name_map to return index_name as fallback + 
mock_get_knowledge_name_map.return_value = {"idx_a": "idx_a", "idx_b": "idx_b"} mock_es_instance = Mock() mock_es_instance.get_summary.side_effect = [ @@ -2981,5 +2984,248 @@ async def test_prepare_prompt_templates_worker_en(self): assert result["test"] == "template" +class TestCreateToolConfigListWithDisplayNameMap: + """Tests for create_tool_config_list with display_name_to_index_map functionality""" + + @pytest.mark.asyncio + async def test_knowledge_base_with_display_name_to_index_map(self): + """Test that KnowledgeBaseSearchTool gets correct display_name_to_index_map from index_names""" + mock_tool_instance = MagicMock() + mock_tool_instance.class_name = "KnowledgeBaseSearchTool" + + with patch('backend.agents.create_agent_info.ToolConfig') as mock_tool_config, \ + patch('backend.agents.create_agent_info.discover_langchain_tools', return_value=[]), \ + patch('backend.agents.create_agent_info.search_tools_for_sub_agent') as mock_search_tools, \ + patch('backend.agents.create_agent_info.get_vector_db_core') as mock_get_vector_db_core, \ + patch('backend.agents.create_agent_info.get_embedding_model') as mock_embedding, \ + patch('backend.agents.create_agent_info.get_rerank_model') as mock_rerank, \ + patch('backend.agents.create_agent_info.get_knowledge_name_map_by_index_names') as mock_get_knowledge_map: + + mock_tool_config.return_value = mock_tool_instance + + mock_search_tools.return_value = [ + { + "class_name": "KnowledgeBaseSearchTool", + "name": "knowledge_search", + "description": "Knowledge search tool", + "inputs": "string", + "output_type": "string", + "params": [ + {"name": "index_names", "default": ["idx1", "idx2"]}, + {"name": "rerank", "default": False}, + ], + "source": "local", + "usage": None + } + ] + mock_get_vector_db_core.return_value = "vdb_core_instance" + mock_embedding.return_value = "embedding_instance" + mock_rerank.return_value = None + # Mock the knowledge name map: index_name -> knowledge_name (display_name) + 
mock_get_knowledge_map.return_value = { + "idx1": "Knowledge Base 1", + "idx2": "Knowledge Base 2" + } + + result = await create_tool_config_list("agent_1", "tenant_1", "user_1") + + assert len(result) == 1 + # Verify get_knowledge_name_map_by_index_names was called + mock_get_knowledge_map.assert_called_once_with(["idx1", "idx2"]) + # Verify display_name_to_index_map contains reversed mapping + assert result[0].metadata["display_name_to_index_map"] == { + "Knowledge Base 1": "idx1", + "Knowledge Base 2": "idx2" + } + + @pytest.mark.asyncio + async def test_knowledge_base_with_empty_index_names(self): + """Test that KnowledgeBaseSearchTool gets empty display_name_to_index_map when no index_names""" + mock_tool_instance = MagicMock() + mock_tool_instance.class_name = "KnowledgeBaseSearchTool" + + with patch('backend.agents.create_agent_info.ToolConfig') as mock_tool_config, \ + patch('backend.agents.create_agent_info.discover_langchain_tools', return_value=[]), \ + patch('backend.agents.create_agent_info.search_tools_for_sub_agent') as mock_search_tools, \ + patch('backend.agents.create_agent_info.get_vector_db_core') as mock_get_vector_db_core, \ + patch('backend.agents.create_agent_info.get_embedding_model') as mock_embedding, \ + patch('backend.agents.create_agent_info.get_rerank_model') as mock_rerank, \ + patch('backend.agents.create_agent_info.get_knowledge_name_map_by_index_names') as mock_get_knowledge_map: + + mock_tool_config.return_value = mock_tool_instance + + mock_search_tools.return_value = [ + { + "class_name": "KnowledgeBaseSearchTool", + "name": "knowledge_search", + "description": "Knowledge search tool", + "inputs": "string", + "output_type": "string", + "params": [ + {"name": "index_names", "default": []}, + {"name": "rerank", "default": False}, + ], + "source": "local", + "usage": None + } + ] + mock_get_vector_db_core.return_value = "vdb_core_instance" + mock_embedding.return_value = "embedding_instance" + mock_rerank.return_value = None + + 
result = await create_tool_config_list("agent_1", "tenant_1", "user_1") + + # get_knowledge_name_map_by_index_names should NOT be called with empty index_names + mock_get_knowledge_map.assert_not_called() + assert result[0].metadata["display_name_to_index_map"] == {} + + @pytest.mark.asyncio + async def test_knowledge_base_with_partial_name_mapping(self): + """Test that KnowledgeBaseSearchTool handles partial name mapping correctly""" + mock_tool_instance = MagicMock() + mock_tool_instance.class_name = "KnowledgeBaseSearchTool" + + with patch('backend.agents.create_agent_info.ToolConfig') as mock_tool_config, \ + patch('backend.agents.create_agent_info.discover_langchain_tools', return_value=[]), \ + patch('backend.agents.create_agent_info.search_tools_for_sub_agent') as mock_search_tools, \ + patch('backend.agents.create_agent_info.get_vector_db_core') as mock_get_vector_db_core, \ + patch('backend.agents.create_agent_info.get_embedding_model') as mock_embedding, \ + patch('backend.agents.create_agent_info.get_rerank_model') as mock_rerank, \ + patch('backend.agents.create_agent_info.get_knowledge_name_map_by_index_names') as mock_get_knowledge_map: + + mock_tool_config.return_value = mock_tool_instance + + mock_search_tools.return_value = [ + { + "class_name": "KnowledgeBaseSearchTool", + "name": "knowledge_search", + "description": "Knowledge search tool", + "inputs": "string", + "output_type": "string", + "params": [ + {"name": "index_names", "default": ["idx1", "idx2", "idx3"]}, + {"name": "rerank", "default": False}, + ], + "source": "local", + "usage": None + } + ] + mock_get_vector_db_core.return_value = "vdb_core_instance" + mock_embedding.return_value = "embedding_instance" + mock_rerank.return_value = None + # Only idx1 is found in database, idx2 and idx3 are not found + mock_get_knowledge_map.return_value = { + "idx1": "Knowledge Base 1" + } + + result = await create_tool_config_list("agent_1", "tenant_1", "user_1") + + # display_name_to_index_map 
should only contain the found mappings + # Unfound indices will use index_name as fallback (which is not in get_knowledge_name_map result) + assert "Knowledge Base 1" in result[0].metadata["display_name_to_index_map"] + + +class TestFilterMcpServersAndTools: + """Tests for filter_mcp_servers_and_tools function""" + + def test_filter_mcp_servers_with_multiple_tools(self): + """Test filtering with multiple MCP tools""" + mock_tool1 = MagicMock() + mock_tool1.source = "mcp" + mock_tool1.usage = "server1" + + mock_tool2 = MagicMock() + mock_tool2.source = "local" + mock_tool2.usage = None + + mock_tool3 = MagicMock() + mock_tool3.source = "mcp" + mock_tool3.usage = "server2" + + mock_sub_agent = MagicMock() + mock_sub_agent.tools = [] + mock_sub_agent.managed_agents = [] + + mock_agent_config = MagicMock() + mock_agent_config.tools = [mock_tool1, mock_tool2, mock_tool3] + mock_agent_config.managed_agents = [mock_sub_agent] + + mcp_info_dict = { + "server1": {"remote_mcp_server": "http://server1.example.com"}, + "server2": {"remote_mcp_server": "http://server2.example.com"}, + } + + result = filter_mcp_servers_and_tools(mock_agent_config, mcp_info_dict) + + assert len(result) == 2 + assert "http://server1.example.com" in result + assert "http://server2.example.com" in result + + def test_filter_mcp_servers_with_nested_sub_agents(self): + """Test filtering with nested sub-agents""" + mock_tool1 = MagicMock() + mock_tool1.source = "mcp" + mock_tool1.usage = "nested_server" + + mock_sub_sub_agent = MagicMock() + mock_sub_sub_agent.tools = [mock_tool1] + mock_sub_sub_agent.managed_agents = [] + + mock_sub_agent = MagicMock() + mock_sub_agent.tools = [] + mock_sub_agent.managed_agents = [mock_sub_sub_agent] + + mock_agent_config = MagicMock() + mock_agent_config.tools = [] + mock_agent_config.managed_agents = [mock_sub_agent] + + mcp_info_dict = { + "nested_server": {"remote_mcp_server": "http://nested.example.com"}, + } + + result = 
filter_mcp_servers_and_tools(mock_agent_config, mcp_info_dict) + + assert len(result) == 1 + assert "http://nested.example.com" in result + + def test_filter_mcp_servers_with_disabled_server(self): + """Test filtering excludes servers not in mcp_info_dict""" + mock_tool1 = MagicMock() + mock_tool1.source = "mcp" + mock_tool1.usage = "enabled_server" + + mock_tool2 = MagicMock() + mock_tool2.source = "mcp" + mock_tool2.usage = "disabled_server" + + mock_agent_config = MagicMock() + mock_agent_config.tools = [mock_tool1, mock_tool2] + mock_agent_config.managed_agents = [] + + mcp_info_dict = { + "enabled_server": {"remote_mcp_server": "http://enabled.example.com"}, + # disabled_server is not in the dict + } + + result = filter_mcp_servers_and_tools(mock_agent_config, mcp_info_dict) + + assert len(result) == 1 + assert "http://enabled.example.com" in result + + def test_filter_mcp_servers_with_empty_tools(self): + """Test filtering with no tools returns empty list""" + mock_agent_config = MagicMock() + mock_agent_config.tools = [] + mock_agent_config.managed_agents = [] + + mcp_info_dict = { + "server1": {"remote_mcp_server": "http://server1.example.com"}, + } + + result = filter_mcp_servers_and_tools(mock_agent_config, mcp_info_dict) + + assert result == [] + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/test/backend/database/test_knowledge_db.py b/test/backend/database/test_knowledge_db.py index 496e04b19..724a62c68 100644 --- a/test/backend/database/test_knowledge_db.py +++ b/test/backend/database/test_knowledge_db.py @@ -51,7 +51,8 @@ get_index_name_by_knowledge_name, get_knowledge_info_by_tenant_and_source, upsert_knowledge_record, - _generate_index_name + _generate_index_name, + get_knowledge_name_map_by_index_names, ) @@ -1948,3 +1949,140 @@ def mock_exit(exc_type, exc_val, exc_tb): with pytest.raises(MockSQLAlchemyError, match="Database error"): get_knowledge_info_by_tenant_and_source("tenant1", "datamate") + + +def 
test_get_knowledge_name_map_by_index_names_success(monkeypatch, mock_session): + """Test successfully getting knowledge name map by index names""" + session, query = mock_session + + # Create mock records with index_name and knowledge_name + class MockRow: + def __init__(self, index_name, knowledge_name): + self.index_name = index_name + self.knowledge_name = knowledge_name + + mock_rows = [ + MockRow("index1", "Knowledge Base 1"), + MockRow("index2", "Knowledge Base 2"), + ] + + mock_filter = MagicMock() + mock_filter.all.return_value = mock_rows + query.filter.return_value = mock_filter + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + + def mock_exit(exc_type, exc_val, exc_tb): + if exc_type is not None: + session.rollback() + return None + mock_ctx.__exit__.side_effect = mock_exit + monkeypatch.setattr( + "backend.database.knowledge_db.get_db_session", lambda: mock_ctx) + + index_names = ["index1", "index2"] + result = get_knowledge_name_map_by_index_names(index_names) + + expected = { + "index1": "Knowledge Base 1", + "index2": "Knowledge Base 2", + } + assert result == expected + + +def test_get_knowledge_name_map_by_index_names_with_fallback(monkeypatch, mock_session): + """Test get_knowledge_name_map_by_index_names uses index_name as fallback when not found""" + session, query = mock_session + + # Only return one of the two index names + class MockRow: + def __init__(self, index_name, knowledge_name): + self.index_name = index_name + self.knowledge_name = knowledge_name + + mock_rows = [ + MockRow("index1", "Knowledge Base 1"), + # index2 is not found in database + ] + + mock_filter = MagicMock() + mock_filter.all.return_value = mock_rows + query.filter.return_value = mock_filter + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + + def mock_exit(exc_type, exc_val, exc_tb): + if exc_type is not None: + session.rollback() + return None + mock_ctx.__exit__.side_effect = mock_exit + monkeypatch.setattr( + 
"backend.database.knowledge_db.get_db_session", lambda: mock_ctx) + + index_names = ["index1", "index2"] + result = get_knowledge_name_map_by_index_names(index_names) + + expected = { + "index1": "Knowledge Base 1", + "index2": "index2", # Falls back to index_name + } + assert result == expected + + +def test_get_knowledge_name_map_by_index_names_empty_list(monkeypatch): + """Test get_knowledge_name_map_by_index_names with empty list returns empty dict""" + result = get_knowledge_name_map_by_index_names([]) + + assert result == {} + + +def test_get_knowledge_name_map_by_index_names_no_results(monkeypatch, mock_session): + """Test get_knowledge_name_map_by_index_names when no records found""" + session, query = mock_session + + mock_filter = MagicMock() + mock_filter.all.return_value = [] + query.filter.return_value = mock_filter + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + + def mock_exit(exc_type, exc_val, exc_tb): + if exc_type is not None: + session.rollback() + return None + mock_ctx.__exit__.side_effect = mock_exit + monkeypatch.setattr( + "backend.database.knowledge_db.get_db_session", lambda: mock_ctx) + + index_names = ["nonexistent1", "nonexistent2"] + result = get_knowledge_name_map_by_index_names(index_names) + + # Should return index_names as fallback for all + expected = { + "nonexistent1": "nonexistent1", + "nonexistent2": "nonexistent2", + } + assert result == expected + + +def test_get_knowledge_name_map_by_index_names_exception(monkeypatch, mock_session): + """Test exception during get_knowledge_name_map_by_index_names""" + session, query = mock_session + query.filter.side_effect = MockSQLAlchemyError("Database error") + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + + def mock_exit(exc_type, exc_val, exc_tb): + if exc_type is not None: + session.rollback() + return None + mock_ctx.__exit__.side_effect = mock_exit + monkeypatch.setattr( + "backend.database.knowledge_db.get_db_session", lambda: 
mock_ctx) + + with pytest.raises(MockSQLAlchemyError, match="Database error"): + get_knowledge_name_map_by_index_names(["index1", "index2"]) diff --git a/test/backend/services/test_tool_configuration_service.py b/test/backend/services/test_tool_configuration_service.py index f970c284d..eacd33b3e 100644 --- a/test/backend/services/test_tool_configuration_service.py +++ b/test/backend/services/test_tool_configuration_service.py @@ -269,6 +269,12 @@ def validate(self): sys.modules['nexent.storage.storage_client_factory'] = storage_factory_module sys.modules['nexent.storage.minio_config'] = storage_config_module +# Mock nexent.memory module to break import chain before loading backend modules +memory_service_module = types.ModuleType('nexent.memory.memory_service') +memory_service_module.clear_memory = MagicMock() +sys.modules['nexent.memory'] = _create_package_mock('nexent.memory') +sys.modules['nexent.memory.memory_service'] = memory_service_module + # Load actual backend modules so that patch targets resolve correctly import importlib # noqa: E402 backend_module = importlib.import_module('backend') @@ -321,6 +327,7 @@ def validate(self): patch('services.tenant_config_service.build_knowledge_name_mapping', MagicMock()).start() patch('services.image_service.get_vlm_model', MagicMock()).start() +patch('backend.database.knowledge_db.get_knowledge_name_map_by_index_names', MagicMock()).start() # Import consts after patching dependencies from consts.model import ToolInfo, ToolSourceEnum, ToolInstanceInfoRequest, ToolValidateRequest # noqa: E402 @@ -2312,12 +2319,13 @@ async def test_validate_tool_langchain_tool_not_found(self, mock_validate_tool_i class TestValidateLocalToolKnowledgeBaseSearch: """Test cases for _validate_local_tool function with knowledge_base_search tool""" + @patch('backend.services.tool_configuration_service.get_knowledge_name_map_by_index_names') @patch('backend.services.tool_configuration_service._get_tool_class_by_name') 
@patch('backend.services.tool_configuration_service.inspect.signature') @patch('backend.services.tool_configuration_service.get_embedding_model') @patch('backend.services.tool_configuration_service.get_vector_db_core') def test_validate_local_tool_knowledge_base_search_success(self, mock_get_vector_db_core, mock_get_embedding_model, - mock_signature, mock_get_class): + mock_signature, mock_get_class, mock_get_knowledge_map): """Test successful knowledge_base_search tool validation with proper dependencies""" # Mock tool class mock_tool_class = Mock() @@ -2345,6 +2353,9 @@ def test_validate_local_tool_knowledge_base_search_success(self, mock_get_vector mock_vdb_core = Mock() mock_get_vector_db_core.return_value = mock_vdb_core + # Mock knowledge name map to return empty dict for this test + mock_get_knowledge_map.return_value = {} + from backend.services.tool_configuration_service import _validate_local_tool result = _validate_local_tool( @@ -2365,6 +2376,7 @@ def test_validate_local_tool_knowledge_base_search_success(self, mock_get_vector "vdb_core": mock_vdb_core, "embedding_model": "mock_embedding_model", "rerank_model": None, + "display_name_to_index_map": {}, } mock_tool_class.assert_called_once_with(**expected_params) mock_tool_instance.forward.assert_called_once_with(query="test query") @@ -2372,6 +2384,61 @@ def test_validate_local_tool_knowledge_base_search_success(self, mock_get_vector # Verify service calls mock_get_embedding_model.assert_called_once_with(tenant_id="tenant1") + @patch('backend.services.tool_configuration_service.get_knowledge_name_map_by_index_names') + @patch('backend.services.tool_configuration_service._get_tool_class_by_name') + @patch('backend.services.tool_configuration_service.get_embedding_model') + @patch('backend.services.tool_configuration_service.get_vector_db_core') + def test_validate_local_tool_knowledge_base_search_with_display_name_mapping( + self, mock_get_vector_db_core, mock_get_embedding_model, mock_get_class, 
mock_get_knowledge_map): + """Test knowledge_base_search tool with display_name_to_index_map parameter""" + mock_tool_class = Mock() + mock_tool_instance = Mock() + mock_tool_instance.forward.return_value = "mapped knowledge result" + mock_tool_class.return_value = mock_tool_instance + mock_get_class.return_value = mock_tool_class + + mock_get_embedding_model.return_value = "mock_embedding_model" + mock_vdb_core = Mock() + mock_get_vector_db_core.return_value = mock_vdb_core + + # Mock the knowledge name map for display_name to index_name mapping + mock_get_knowledge_map.return_value = { + "test_index_1": "Display Knowledge 1", + "test_index_2": "Display Knowledge 2" + } + + from backend.services.tool_configuration_service import _validate_local_tool + + result = _validate_local_tool( + "knowledge_base_search", + {"query": "test query"}, + {"index_names": ["test_index_1", "test_index_2"]}, + "tenant1", + "user1" + ) + + assert result == "mapped knowledge result" + + # Verify tool class was called exactly once + assert mock_tool_class.call_count == 1, f"Expected 1 call, got {mock_tool_class.call_count}" + + # Get the actual call arguments + actual_call = mock_tool_class.call_args + actual_kwargs = actual_call.kwargs if actual_call.kwargs else actual_call[1] + + # Verify each expected parameter + assert actual_kwargs.get("index_names") == ["test_index_1", "test_index_2"] + assert actual_kwargs.get("vdb_core") == mock_vdb_core + assert actual_kwargs.get("embedding_model") == "mock_embedding_model" + assert actual_kwargs.get("rerank_model") is None + assert actual_kwargs.get("display_name_to_index_map") == { + "Display Knowledge 1": "test_index_1", + "Display Knowledge 2": "test_index_2" + } + + # Verify knowledge name map was called with index_names + mock_get_knowledge_map.assert_called_once_with(["test_index_1", "test_index_2"]) + @patch('backend.services.tool_configuration_service._get_tool_class_by_name') 
@patch('backend.services.tool_configuration_service.get_embedding_model') @patch('backend.services.tool_configuration_service.get_vector_db_core') @@ -2456,6 +2523,7 @@ def test_validate_local_tool_knowledge_base_search_missing_both_ids(self, mock_g assert result == "knowledge base search result" + @patch('backend.services.tool_configuration_service.get_knowledge_name_map_by_index_names') @patch('backend.services.tool_configuration_service._get_tool_class_by_name') @patch('backend.services.tool_configuration_service.inspect.signature') @patch('backend.services.tool_configuration_service.get_embedding_model') @@ -2463,7 +2531,8 @@ def test_validate_local_tool_knowledge_base_search_missing_both_ids(self, mock_g def test_validate_local_tool_knowledge_base_search_empty_knowledge_list(self, mock_get_vector_db_core, mock_get_embedding_model, mock_signature, - mock_get_class): + mock_get_class, + mock_get_knowledge_map): """Test knowledge_base_search tool validation with empty knowledge list""" # Mock tool class mock_tool_class = Mock() @@ -2509,11 +2578,13 @@ def test_validate_local_tool_knowledge_base_search_empty_knowledge_list(self, mo "vdb_core": mock_vdb_core, "embedding_model": "mock_embedding_model", "rerank_model": None, + "display_name_to_index_map": {}, } mock_tool_class.assert_called_once_with(**expected_params) mock_tool_instance.forward.assert_called_once_with(query="test query") + @patch('backend.services.tool_configuration_service.get_knowledge_name_map_by_index_names') @patch('backend.services.tool_configuration_service._get_tool_class_by_name') @patch('backend.services.tool_configuration_service.inspect.signature') @patch('backend.services.tool_configuration_service.get_embedding_model') @@ -2521,7 +2592,8 @@ def test_validate_local_tool_knowledge_base_search_empty_knowledge_list(self, mo def test_validate_local_tool_knowledge_base_search_execution_error(self, mock_get_vector_db_core, mock_get_embedding_model, mock_signature, - mock_get_class): + 
mock_get_class, + mock_get_knowledge_map): """Test knowledge_base_search tool validation when execution fails""" # Mock tool class mock_tool_class = Mock() diff --git a/test/sdk/core/agents/test_nexent_agent.py b/test/sdk/core/agents/test_nexent_agent.py index 435a336d1..f87d4aa06 100644 --- a/test/sdk/core/agents/test_nexent_agent.py +++ b/test/sdk/core/agents/test_nexent_agent.py @@ -880,6 +880,122 @@ def test_create_local_tool_knowledge_base_search_tool_with_none_defaults(nexent_ assert result == mock_kb_tool_instance +def test_create_local_tool_knowledge_base_with_display_name_map(nexent_agent_instance): + """Test KnowledgeBaseSearchTool creation sets display_name_to_index_map from metadata.""" + mock_kb_tool_class = MagicMock() + mock_kb_tool_instance = MagicMock() + mock_kb_tool_class.return_value = mock_kb_tool_instance + + display_name_map = { + "Knowledge A": "es_index_knowledge_a", + "Knowledge B": "es_index_knowledge_b", + } + + tool_config = ToolConfig( + class_name="KnowledgeBaseSearchTool", + name="knowledge_base_search", + description="desc", + inputs="{}", + output_type="string", + params={"top_k": 10}, + source="local", + metadata={ + "vdb_core": "mock_vdb_core", + "embedding_model": "mock_embedding_model", + "rerank_model": "mock_rerank_model", + "display_name_to_index_map": display_name_map, + }, + ) + + original_value = nexent_agent.__dict__.get("KnowledgeBaseSearchTool") + nexent_agent.__dict__["KnowledgeBaseSearchTool"] = mock_kb_tool_class + + try: + result = nexent_agent_instance.create_local_tool(tool_config) + finally: + if original_value is not None: + nexent_agent.__dict__["KnowledgeBaseSearchTool"] = original_value + elif "KnowledgeBaseSearchTool" in nexent_agent.__dict__: + del nexent_agent.__dict__["KnowledgeBaseSearchTool"] + + # Verify display_name_to_index_map was set correctly from metadata + assert result.display_name_to_index_map == display_name_map + assert result.vdb_core == "mock_vdb_core" + assert result.embedding_model == 
"mock_embedding_model" + assert result.rerank_model == "mock_rerank_model" + + +def test_create_local_tool_knowledge_base_with_empty_display_name_map(nexent_agent_instance): + """Test KnowledgeBaseSearchTool creation handles empty display_name_to_index_map.""" + mock_kb_tool_class = MagicMock() + mock_kb_tool_instance = MagicMock() + mock_kb_tool_class.return_value = mock_kb_tool_instance + + tool_config = ToolConfig( + class_name="KnowledgeBaseSearchTool", + name="knowledge_base_search", + description="desc", + inputs="{}", + output_type="string", + params={"top_k": 10}, + source="local", + metadata={ + "vdb_core": "mock_vdb_core", + "embedding_model": "mock_embedding_model", + "display_name_to_index_map": {}, + }, + ) + + original_value = nexent_agent.__dict__.get("KnowledgeBaseSearchTool") + nexent_agent.__dict__["KnowledgeBaseSearchTool"] = mock_kb_tool_class + + try: + result = nexent_agent_instance.create_local_tool(tool_config) + finally: + if original_value is not None: + nexent_agent.__dict__["KnowledgeBaseSearchTool"] = original_value + elif "KnowledgeBaseSearchTool" in nexent_agent.__dict__: + del nexent_agent.__dict__["KnowledgeBaseSearchTool"] + + # Verify empty display_name_to_index_map was set + assert result.display_name_to_index_map == {} + + +def test_create_local_tool_knowledge_base_without_metadata(nexent_agent_instance): + """Test KnowledgeBaseSearchTool creation handles missing metadata.""" + mock_kb_tool_class = MagicMock() + mock_kb_tool_instance = MagicMock() + mock_kb_tool_class.return_value = mock_kb_tool_instance + + tool_config = ToolConfig( + class_name="KnowledgeBaseSearchTool", + name="knowledge_base_search", + description="desc", + inputs="{}", + output_type="string", + params={"top_k": 10}, + source="local", + metadata=None, + ) + + original_value = nexent_agent.__dict__.get("KnowledgeBaseSearchTool") + nexent_agent.__dict__["KnowledgeBaseSearchTool"] = mock_kb_tool_class + + try: + result = 
nexent_agent_instance.create_local_tool(tool_config) + finally: + if original_value is not None: + nexent_agent.__dict__["KnowledgeBaseSearchTool"] = original_value + elif "KnowledgeBaseSearchTool" in nexent_agent.__dict__: + del nexent_agent.__dict__["KnowledgeBaseSearchTool"] + + # Verify defaults were set when metadata is None + assert result.display_name_to_index_map == {} + assert result.vdb_core is None + assert result.embedding_model is None + assert result.rerank_model is None + + def test_create_local_tool_analyze_text_file_tool(nexent_agent_instance): """Test AnalyzeTextFileTool creation injects observer and metadata.""" mock_analyze_tool_class = MagicMock() diff --git a/test/sdk/core/tools/test_knowledge_base_search_tool.py b/test/sdk/core/tools/test_knowledge_base_search_tool.py index ad6c7987b..141ce5ca9 100644 --- a/test/sdk/core/tools/test_knowledge_base_search_tool.py +++ b/test/sdk/core/tools/test_knowledge_base_search_tool.py @@ -40,6 +40,7 @@ def knowledge_base_search_tool(mock_observer, mock_vdb_core, mock_embedding_mode vdb_core=mock_vdb_core, search_mode="hybrid", rerank=False, + display_name_to_index_map={}, ) return tool @@ -395,6 +396,7 @@ def test_forward_with_rerank_enabled(self, mock_observer, mock_vdb_core, mock_em vdb_core=mock_vdb_core, embedding_model=mock_embedding_model, observer=mock_observer, + display_name_to_index_map={}, ) result = tool.forward("test query") @@ -433,6 +435,7 @@ def test_forward_rerank_disabled(self, mock_observer, mock_vdb_core, mock_embedd vdb_core=mock_vdb_core, embedding_model=mock_embedding_model, observer=mock_observer, + display_name_to_index_map={}, ) result = tool.forward("test query") @@ -472,6 +475,7 @@ def test_forward_rerank_error_continues(self, mock_observer, mock_vdb_core, mock vdb_core=mock_vdb_core, embedding_model=mock_embedding_model, observer=mock_observer, + display_name_to_index_map={}, ) # Should not raise, should continue with original results @@ -536,3 +540,488 @@ def 
test_forward_with_whitespace_in_index_names(self, knowledge_base_search_tool embedding_model=knowledge_base_search_tool.embedding_model, top_k=5 ) + + +class TestConvertToIndexNames: + """Tests for _convert_to_index_names method.""" + + def test_convert_with_empty_map(self, mock_observer, mock_vdb_core, mock_embedding_model): + """Test conversion when display_name_to_index_map is empty.""" + tool = KnowledgeBaseSearchTool( + index_names=["index1", "index2"], + search_mode="hybrid", + vdb_core=mock_vdb_core, + embedding_model=mock_embedding_model, + observer=mock_observer, + display_name_to_index_map={}, + ) + + result = tool._convert_to_index_names(["index1", "index2"]) + + assert result == ["index1", "index2"] + + def test_convert_with_matching_names(self, mock_observer, mock_vdb_core, mock_embedding_model): + """Test conversion when names are in the map.""" + tool = KnowledgeBaseSearchTool( + index_names=[], + search_mode="hybrid", + vdb_core=mock_vdb_core, + embedding_model=mock_embedding_model, + observer=mock_observer, + display_name_to_index_map={ + "Knowledge A": "es_index_knowledge_a", + "Knowledge B": "es_index_knowledge_b", + }, + ) + + result = tool._convert_to_index_names(["Knowledge A", "Knowledge B"]) + + assert result == ["es_index_knowledge_a", "es_index_knowledge_b"] + + def test_convert_with_mixed_names(self, mock_observer, mock_vdb_core, mock_embedding_model): + """Test conversion when some names are in the map and some are not.""" + tool = KnowledgeBaseSearchTool( + index_names=[], + search_mode="hybrid", + vdb_core=mock_vdb_core, + embedding_model=mock_embedding_model, + observer=mock_observer, + display_name_to_index_map={ + "Knowledge A": "es_index_knowledge_a", + }, + ) + + result = tool._convert_to_index_names(["Knowledge A", "raw_index_name"]) + + assert result == ["es_index_knowledge_a", "raw_index_name"] + + def test_convert_with_unmatched_names(self, mock_observer, mock_vdb_core, mock_embedding_model): + """Test conversion when no names 
are in the map.""" + tool = KnowledgeBaseSearchTool( + index_names=[], + search_mode="hybrid", + vdb_core=mock_vdb_core, + embedding_model=mock_embedding_model, + observer=mock_observer, + display_name_to_index_map={ + "Knowledge A": "es_index_knowledge_a", + }, + ) + + result = tool._convert_to_index_names(["raw_index1", "raw_index2"]) + + assert result == ["raw_index1", "raw_index2"] + + def test_convert_forward_integration(self, mock_observer, mock_vdb_core, mock_embedding_model): + """Test that forward method uses _convert_to_index_names correctly.""" + mock_results = create_mock_search_result(1) + mock_vdb_core.hybrid_search.return_value = mock_results + + tool = KnowledgeBaseSearchTool( + index_names=[], + search_mode="hybrid", + vdb_core=mock_vdb_core, + embedding_model=mock_embedding_model, + observer=mock_observer, + display_name_to_index_map={ + "Knowledge A": "es_index_knowledge_a", + }, + ) + + tool.forward("test query", index_names=["Knowledge A"]) + + mock_vdb_core.hybrid_search.assert_called_once_with( + index_names=["es_index_knowledge_a"], + query_text="test query", + embedding_model=mock_embedding_model, + top_k=3 + ) + + +class TestEffectiveTopK: + """Tests for effective_top_k calculation with rerank.""" + + def test_effective_top_k_increases_with_rerank(self, mock_observer, mock_vdb_core, mock_embedding_model): + """Test that effective_top_k is multiplied when rerank is enabled.""" + from sdk.nexent.core.utils.constants import RERANK_OVERSEARCH_MULTIPLIER + + mock_results = create_mock_search_result(10) + mock_vdb_core.hybrid_search.return_value = mock_results + + tool = KnowledgeBaseSearchTool( + index_names=["kb1"], + search_mode="hybrid", + top_k=5, + rerank=True, + vdb_core=mock_vdb_core, + embedding_model=mock_embedding_model, + observer=mock_observer, + display_name_to_index_map={}, + ) + + tool.forward("test query") + + call_kwargs = mock_vdb_core.hybrid_search.call_args[1] + assert call_kwargs["top_k"] == 5 * RERANK_OVERSEARCH_MULTIPLIER 
+ + def test_effective_top_k_unchanged_without_rerank(self, mock_observer, mock_vdb_core, mock_embedding_model): + """Test that effective_top_k remains the same when rerank is disabled.""" + mock_results = create_mock_search_result(5) + mock_vdb_core.hybrid_search.return_value = mock_results + + tool = KnowledgeBaseSearchTool( + index_names=["kb1"], + search_mode="hybrid", + top_k=5, + rerank=False, + vdb_core=mock_vdb_core, + embedding_model=mock_embedding_model, + observer=mock_observer, + display_name_to_index_map={}, + ) + + tool.forward("test query") + + call_kwargs = mock_vdb_core.hybrid_search.call_args[1] + assert call_kwargs["top_k"] == 5 + + +class TestSourceTypeConversion: + """Tests for source_type conversion (local/minio -> file).""" + + def test_source_type_local_converted_to_file(self, knowledge_base_search_tool, mock_vdb_core): + """Test that source_type 'local' is converted to 'file'.""" + mock_results = [ + { + "document": { + "title": "Local Doc", + "content": "Content from local file", + "filename": "local.txt", + "path_or_url": "/path/local.txt", + "create_time": "2024-01-01T12:00:00Z", + "source_type": "local" + }, + "score": 0.9, + "index": "kb1" + } + ] + mock_vdb_core.hybrid_search.return_value = mock_results + knowledge_base_search_tool.vdb_core = mock_vdb_core + + knowledge_base_search_tool.forward("test query", index_names=["kb1"]) + + # Check the SEARCH_CONTENT message which contains full results via to_dict() + search_content_call = [ + call for call in knowledge_base_search_tool.observer.add_message.call_args_list + if call[0][1] == ProcessType.SEARCH_CONTENT + ][0] + full_results = json.loads(search_content_call[0][2]) + + assert full_results[0]["source_type"] == "file" + + def test_source_type_minio_converted_to_file(self, knowledge_base_search_tool, mock_vdb_core): + """Test that source_type 'minio' is converted to 'file'.""" + mock_results = [ + { + "document": { + "title": "Minio Doc", + "content": "Content from minio storage", + 
"filename": "minio.txt", + "path_or_url": "/minio/bucket/minio.txt", + "create_time": "2024-01-01T12:00:00Z", + "source_type": "minio" + }, + "score": 0.9, + "index": "kb1" + } + ] + mock_vdb_core.hybrid_search.return_value = mock_results + knowledge_base_search_tool.vdb_core = mock_vdb_core + + knowledge_base_search_tool.forward("test query", index_names=["kb1"]) + + # Check the SEARCH_CONTENT message + search_content_call = [ + call for call in knowledge_base_search_tool.observer.add_message.call_args_list + if call[0][1] == ProcessType.SEARCH_CONTENT + ][0] + full_results = json.loads(search_content_call[0][2]) + + assert full_results[0]["source_type"] == "file" + + def test_source_type_other_unchanged(self, knowledge_base_search_tool, mock_vdb_core): + """Test that source_type other than local/minio remains unchanged.""" + mock_results = [ + { + "document": { + "title": "Web Doc", + "content": "Content from web", + "filename": "web.html", + "path_or_url": "https://example.com/page.html", + "create_time": "2024-01-01T12:00:00Z", + "source_type": "web" + }, + "score": 0.9, + "index": "kb1" + } + ] + mock_vdb_core.hybrid_search.return_value = mock_results + knowledge_base_search_tool.vdb_core = mock_vdb_core + + knowledge_base_search_tool.forward("test query", index_names=["kb1"]) + + # Check the SEARCH_CONTENT message + search_content_call = [ + call for call in knowledge_base_search_tool.observer.add_message.call_args_list + if call[0][1] == ProcessType.SEARCH_CONTENT + ][0] + full_results = json.loads(search_content_call[0][2]) + + assert full_results[0]["source_type"] == "web" + + +class TestRecordOps: + """Tests for record_ops counter functionality.""" + + def test_record_ops_increments_by_result_count(self, knowledge_base_search_tool): + """Test that record_ops increases by the number of results returned.""" + mock_results = create_mock_search_result(2) + knowledge_base_search_tool.vdb_core.hybrid_search.return_value = mock_results + + initial_ops = 
knowledge_base_search_tool.record_ops + + knowledge_base_search_tool.forward("test query", index_names=["kb1"]) + + assert knowledge_base_search_tool.record_ops == initial_ops + 2 + + def test_record_ops_accumulates_across_calls(self, knowledge_base_search_tool): + """Test that record_ops accumulates across multiple forward calls.""" + mock_results = create_mock_search_result(1) + knowledge_base_search_tool.vdb_core.hybrid_search.return_value = mock_results + + knowledge_base_search_tool.record_ops = 0 + knowledge_base_search_tool.forward("query1", index_names=["kb1"]) + first_call_ops = knowledge_base_search_tool.record_ops + + knowledge_base_search_tool.forward("query2", index_names=["kb1"]) + second_call_ops = knowledge_base_search_tool.record_ops + + # Each call with 1 result adds 1 to record_ops + assert first_call_ops == 1 + assert second_call_ops == 2 + + def test_cite_index_in_results(self, knowledge_base_search_tool): + """Test that cite_index in results starts from record_ops + index + 1.""" + mock_results = create_mock_search_result(2) + knowledge_base_search_tool.vdb_core.hybrid_search.return_value = mock_results + + # Two results are returned, so their cite_index values are expected to be 1 and 2 + knowledge_base_search_tool.forward("test query", index_names=["kb1"]) + + # Check the SEARCH_CONTENT message for cite_index values + search_content_call = [ + call for call in knowledge_base_search_tool.observer.add_message.call_args_list + if call[0][1] == ProcessType.SEARCH_CONTENT + ][0] + full_results = json.loads(search_content_call[0][2]) + + assert full_results[0]["cite_index"] == 1 + assert full_results[1]["cite_index"] == 2 + + +class TestSearchContentObserver: + """Tests for SEARCH_CONTENT observer message.""" + + def test_forward_sends_search_content_to_observer(self, knowledge_base_search_tool): + """Test that forward sends SEARCH_CONTENT message to observer.""" + mock_results = create_mock_search_result(1) +
knowledge_base_search_tool.vdb_core.hybrid_search.return_value = mock_results + + knowledge_base_search_tool.forward("test query", index_names=["kb1"]) + + search_content_calls = [ + call for call in knowledge_base_search_tool.observer.add_message.call_args_list + if call[0][1] == ProcessType.SEARCH_CONTENT + ] + + assert len(search_content_calls) == 1 + message = search_content_calls[0][0][2] + parsed = json.loads(message) + assert isinstance(parsed, list) + assert len(parsed) == 1 + + def test_forward_no_search_content_without_observer(self, mock_vdb_core, mock_embedding_model): + """Test that forward works without observer and doesn't send SEARCH_CONTENT.""" + mock_results = create_mock_search_result(1) + mock_vdb_core.hybrid_search.return_value = mock_results + + tool = KnowledgeBaseSearchTool( + index_names=["kb1"], + search_mode="hybrid", + vdb_core=mock_vdb_core, + embedding_model=mock_embedding_model, + observer=None, + display_name_to_index_map={}, + ) + + result = tool.forward("test query") + + assert result is not None + + +class TestToolMetadata: + """Tests for tool metadata attributes.""" + + def test_tool_name(self, knowledge_base_search_tool): + """Test tool name is correctly set.""" + assert knowledge_base_search_tool.name == "knowledge_base_search" + + def test_tool_category(self, knowledge_base_search_tool): + """Test tool category is SEARCH.""" + from sdk.nexent.core.utils.tools_common_message import ToolCategory + assert knowledge_base_search_tool.category == ToolCategory.SEARCH.value + + def test_tool_sign(self, knowledge_base_search_tool): + """Test tool_sign is KNOWLEDGE_BASE.""" + from sdk.nexent.core.utils.tools_common_message import ToolSign + assert knowledge_base_search_tool.tool_sign == ToolSign.KNOWLEDGE_BASE.value + + def test_output_type(self, knowledge_base_search_tool): + """Test output_type is string.""" + assert knowledge_base_search_tool.output_type == "string" + + def test_inputs_contain_required_fields(self): + """Test that 
inputs dict contains required fields.""" + assert "query" in KnowledgeBaseSearchTool.inputs + assert "index_names" in KnowledgeBaseSearchTool.inputs + assert KnowledgeBaseSearchTool.inputs["query"]["type"] == "string" + assert KnowledgeBaseSearchTool.inputs["index_names"]["type"] == "array" + + def test_running_prompts(self, knowledge_base_search_tool): + """Test running prompts for both languages.""" + assert knowledge_base_search_tool.running_prompt_zh == "知识库检索中..." + assert knowledge_base_search_tool.running_prompt_en == "Searching the knowledge base..." + + +class TestEdgeCases: + """Tests for edge cases and boundary conditions.""" + + def test_forward_with_score_details(self, knowledge_base_search_tool, mock_vdb_core): + """Test forward includes score_details in results via SEARCH_CONTENT.""" + mock_results = [ + { + "document": { + "title": "Doc", + "content": "Content", + "filename": "doc.txt", + "path_or_url": "/path/doc.txt", + "create_time": "2024-01-01T12:00:00Z", + "source_type": "file", + "score_details": {"bm25": 0.5, "knn": 0.4} + }, + "score": 0.9, + "index": "kb1" + } + ] + mock_vdb_core.hybrid_search.return_value = mock_results + knowledge_base_search_tool.vdb_core = mock_vdb_core + + knowledge_base_search_tool.forward("test query", index_names=["kb1"]) + + # Check the SEARCH_CONTENT message which contains full results via to_dict() + search_content_call = [ + call for call in knowledge_base_search_tool.observer.add_message.call_args_list + if call[0][1] == ProcessType.SEARCH_CONTENT + ][0] + full_results = json.loads(search_content_call[0][2]) + + assert "score_details" in full_results[0] + assert full_results[0]["score_details"]["bm25"] == 0.5 + + def test_forward_with_empty_content(self, knowledge_base_search_tool, mock_vdb_core): + """Test forward handles empty content gracefully.""" + mock_results = [ + { + "document": { + "title": "Doc with no content", + "content": "", + "filename": "empty.txt", + "path_or_url": "/path/empty.txt", + 
"create_time": "2024-01-01T12:00:00Z", + "source_type": "file" + }, + "score": 0.5, + "index": "kb1" + } + ] + mock_vdb_core.hybrid_search.return_value = mock_results + knowledge_base_search_tool.vdb_core = mock_vdb_core + + result = knowledge_base_search_tool.forward("test query", index_names=["kb1"]) + search_results = json.loads(result) + + assert search_results[0]["text"] == "" + + def test_forward_multiple_indices(self, knowledge_base_search_tool, mock_vdb_core): + """Test forward searches across multiple indices.""" + mock_results = [ + { + "document": { + "title": "Doc from index1", + "content": "Content", + "filename": "doc1.txt", + "path_or_url": "/path/doc1.txt", + "create_time": "2024-01-01T12:00:00Z", + "source_type": "file", + }, + "score": 0.9, + "index": "index1" + }, + { + "document": { + "title": "Doc from index2", + "content": "Content", + "filename": "doc2.txt", + "path_or_url": "/path/doc2.txt", + "create_time": "2024-01-01T12:00:00Z", + "source_type": "file", + }, + "score": 0.8, + "index": "index2" + } + ] + mock_vdb_core.hybrid_search.return_value = mock_results + knowledge_base_search_tool.vdb_core = mock_vdb_core + + result = knowledge_base_search_tool.forward("test query", index_names=["index1", "index2"]) + search_results = json.loads(result) + + assert len(search_results) == 2 + + def test_rerank_trims_to_top_k(self, mock_observer, mock_vdb_core, mock_embedding_model): + """Test that rerank results are trimmed to original top_k.""" + mock_results = create_mock_search_result(10) + mock_vdb_core.hybrid_search.return_value = mock_results + + mock_rerank_model = MagicMock() + mock_rerank_model.rerank.return_value = [ + {"index": i, "relevance_score": 0.9 - i * 0.05} + for i in range(10) + ] + + tool = KnowledgeBaseSearchTool( + index_names=["kb1"], + search_mode="hybrid", + top_k=3, + rerank=True, + rerank_model=mock_rerank_model, + vdb_core=mock_vdb_core, + embedding_model=mock_embedding_model, + observer=mock_observer, + 
display_name_to_index_map={}, + ) + + result = tool.forward("test query") + search_results = json.loads(result) + + assert len(search_results) == 3 From 7104317ae02c474f285af4b0fc5b19123d67f65a Mon Sep 17 00:00:00 2001 From: zhizhi <928570418@qq.com> Date: Fri, 17 Apr 2026 17:21:43 +0800 Subject: [PATCH 2/5] =?UTF-8?q?=E2=9C=A8=20Enhance=20prompt=20generation?= =?UTF-8?q?=20with=20knowledge=20base=20display=20names?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Added `knowledge_base_display_names` to the `GeneratePromptRequest` model to allow frontend-configured names for knowledge bases. - Updated backend functions to utilize these display names, improving few-shot example generation without requiring database lookups. - Modified frontend components to capture and pass knowledge base display names during prompt generation. - Enhanced tests to cover the new functionality and ensure proper integration of knowledge base display names in the prompt generation process. 
--- backend/apps/prompt_app.py | 3 +- backend/consts/model.py | 2 + backend/prompts/utils/prompt_generate_en.yaml | 9 +- backend/prompts/utils/prompt_generate_zh.yaml | 7 + backend/services/prompt_service.py | 109 ++++++- .../agentConfig/tool/ToolConfigModal.tsx | 12 +- .../agentInfo/AgentGenerateDetail.tsx | 14 + frontend/types/agentConfig.ts | 13 + test/backend/services/test_prompt_service.py | 298 +++++++++++++++++- 9 files changed, 445 insertions(+), 22 deletions(-) diff --git a/backend/apps/prompt_app.py b/backend/apps/prompt_app.py index 7c0b799dc..a9bd8d3a6 100644 --- a/backend/apps/prompt_app.py +++ b/backend/apps/prompt_app.py @@ -29,7 +29,8 @@ async def generate_and_save_system_prompt_api( tenant_id=tenant_id, language=language, tool_ids=prompt_request.tool_ids, - sub_agent_ids=prompt_request.sub_agent_ids + sub_agent_ids=prompt_request.sub_agent_ids, + knowledge_base_display_names=prompt_request.knowledge_base_display_names ), media_type="text/event-stream") except Exception as e: logger.exception(f"Error occurred while generating system prompt: {e}") diff --git a/backend/consts/model.py b/backend/consts/model.py index 2728d95ca..e9d3aae5f 100644 --- a/backend/consts/model.py +++ b/backend/consts/model.py @@ -250,6 +250,8 @@ class GeneratePromptRequest(BaseModel): None, description="Optional: tool IDs from frontend (takes precedence over database query)") sub_agent_ids: Optional[List[int]] = Field( None, description="Optional: sub-agent IDs from frontend (takes precedence over database query)") + knowledge_base_display_names: Optional[List[str]] = Field( + None, description="Optional: knowledge base display names from frontend (takes precedence over database query)") class GenerateTitleRequest(BaseModel): diff --git a/backend/prompts/utils/prompt_generate_en.yaml b/backend/prompts/utils/prompt_generate_en.yaml index 7f55becd3..d2f2291b7 100644 --- a/backend/prompts/utils/prompt_generate_en.yaml +++ b/backend/prompts/utils/prompt_generate_en.yaml @@ 
-244,7 +244,7 @@ USER_PROMPT: |- {% else %} You have no available tools. {% endif %} - + ### Available Assistants List: {% if assistant_description %} {{assistant_description}} @@ -252,6 +252,13 @@ USER_PROMPT: |- You have no available assistants {% endif %} + {% if knowledge_base_names %} + ### Knowledge Base Configuration Note: + When generating few-shot examples, if using the knowledge_base_search tool, you MUST use the following actual configured knowledge base names: + {{knowledge_base_names}} + Please use these names directly in examples, e.g.: knowledge_base_search(query="xxx", index_names=[{{knowledge_base_names}}]) + {% endif %} + AGENT_NAME_REGENERATE_SYSTEM_PROMPT: |- ### You are an [Agent Variable Name Refinement Expert] diff --git a/backend/prompts/utils/prompt_generate_zh.yaml b/backend/prompts/utils/prompt_generate_zh.yaml index d513bc860..0afc58052 100644 --- a/backend/prompts/utils/prompt_generate_zh.yaml +++ b/backend/prompts/utils/prompt_generate_zh.yaml @@ -249,6 +249,13 @@ USER_PROMPT: |- 你没有可用的助手 {% endif %} + {% if knowledge_base_names %} + ### 知识库配置说明: + 在生成 few-shot 示例时,如果使用 knowledge_base_search 工具,必须使用以下实际配置的知识库名称: + {{knowledge_base_names}} + 请将这些名称直接用于示例中,例如:knowledge_base_search(query="xxx", index_names=[{{knowledge_base_names}}]) + {% endif %} + AGENT_NAME_REGENERATE_SYSTEM_PROMPT: |- ### 你是【Agent变量名调整专家】 diff --git a/backend/services/prompt_service.py b/backend/services/prompt_service.py index 3706c3cc5..535b1ed25 100644 --- a/backend/services/prompt_service.py +++ b/backend/services/prompt_service.py @@ -12,7 +12,8 @@ from consts.exceptions import AppException from database.agent_db import search_agent_info_by_agent_id, query_all_agent_info_by_tenant_id, \ query_sub_agents_id_list -from database.tool_db import query_tools_by_ids +from database.knowledge_db import get_knowledge_name_map_by_index_names +from database.tool_db import query_tools_by_ids, query_tool_instances_by_id from services.agent_service import ( 
get_enable_tool_id_by_agent_id, _check_agent_name_duplicate, @@ -29,7 +30,7 @@ logger = logging.getLogger("prompt_service") -def gen_system_prompt_streamable(agent_id: int, model_id: int, task_description: str, user_id: str, tenant_id: str, language: str, tool_ids: Optional[List[int]] = None, sub_agent_ids: Optional[List[int]] = None): +def gen_system_prompt_streamable(agent_id: int, model_id: int, task_description: str, user_id: str, tenant_id: str, language: str, tool_ids: Optional[List[int]] = None, sub_agent_ids: Optional[List[int]] = None, knowledge_base_display_names: Optional[List[str]] = None): try: for system_prompt in generate_and_save_system_prompt_impl( agent_id=agent_id, @@ -39,7 +40,8 @@ def gen_system_prompt_streamable(agent_id: int, model_id: int, task_description: tenant_id=tenant_id, language=language, tool_ids=tool_ids, - sub_agent_ids=sub_agent_ids + sub_agent_ids=sub_agent_ids, + knowledge_base_display_names=knowledge_base_display_names ): # SSE format, each message ends with \n\n yield f"data: {json.dumps({'success': True, 'data': system_prompt}, ensure_ascii=False)}\n\n" @@ -63,7 +65,8 @@ def generate_and_save_system_prompt_impl(agent_id: int, tenant_id: str, language: str, tool_ids: Optional[List[int]] = None, - sub_agent_ids: Optional[List[int]] = None): + sub_agent_ids: Optional[List[int]] = None, + knowledge_base_display_names: Optional[List[str]] = None): # Get description of tool and agent from frontend-provided IDs # Frontend always provides tool_ids and sub_agent_ids (could be empty arrays) @@ -77,6 +80,18 @@ def generate_and_save_system_prompt_impl(agent_id: int, tool_info_list = get_enabled_tool_description_for_generate_prompt( tenant_id=tenant_id, agent_id=agent_id) + # Get knowledge base display names for few-shot examples + # Priority: frontend-provided > database query + if knowledge_base_display_names: + logger.debug(f"Using frontend-provided knowledge base display names: {knowledge_base_display_names}") + else: + 
knowledge_base_display_names = get_knowledge_base_display_names( + tool_info_list=tool_info_list, + agent_id=agent_id, + tenant_id=tenant_id + ) + logger.debug(f"Using database query for knowledge base display names: {knowledge_base_display_names}") + # Handle sub-agent IDs if sub_agent_ids and len(sub_agent_ids) > 0: sub_agent_info_list = [] @@ -114,7 +129,7 @@ def generate_and_save_system_prompt_impl(agent_id: int, # Collect results and yield non-name fields immediately, but hold name fields for duplicate checking for result_data in generate_system_prompt(sub_agent_info_list, task_description, tool_info_list, tenant_id, - model_id, language): + model_id, language, knowledge_base_display_names): result_type = result_data["type"] final_results[result_type] = result_data["content"] @@ -223,7 +238,7 @@ def generate_and_save_system_prompt_impl(agent_id: int, raise Exception("Failed to generate prompt content.") -def generate_system_prompt(sub_agent_info_list, task_description, tool_info_list, tenant_id: str, model_id: int, language: str = LANGUAGE["ZH"]): +def generate_system_prompt(sub_agent_info_list, task_description, tool_info_list, tenant_id: str, model_id: int, language: str = LANGUAGE["ZH"], knowledge_base_display_names: Optional[List[str]] = None): """Main function for generating system prompts""" prompt_for_generate = get_prompt_generate_prompt_template(language) @@ -233,7 +248,8 @@ def generate_system_prompt(sub_agent_info_list, task_description, tool_info_list sub_agent_info_list=sub_agent_info_list, task_description=task_description, tool_info_list=tool_info_list, - language=language + language=language, + knowledge_base_display_names=knowledge_base_display_names ) # Initialize state @@ -352,7 +368,7 @@ def _stream_results(produce_queue, latest, stop_flags, threads, error_holder): last_results[tag] = latest[tag] -def join_info_for_generate_system_prompt(prompt_for_generate, sub_agent_info_list, task_description, tool_info_list, language: str = 
LANGUAGE["ZH"]): +def join_info_for_generate_system_prompt(prompt_for_generate, sub_agent_info_list, task_description, tool_info_list, language: str = LANGUAGE["ZH"], knowledge_base_display_names: Optional[List[str]] = None): input_label = "Inputs" if language == 'en' else "接受输入" output_label = "Output type" if language == 'en' else "返回输出类型" @@ -361,12 +377,21 @@ def join_info_for_generate_system_prompt(prompt_for_generate, sub_agent_info_lis for tool in tool_info_list]) assistant_description = "\n".join( [f"- {sub_agent_info['name']}: {sub_agent_info['description']}" for sub_agent_info in sub_agent_info_list]) - # Generate content using template - content = Template(prompt_for_generate["USER_PROMPT"], undefined=StrictUndefined).render({ + + # Build template context + template_context = { "task_description": task_description, "tool_description": tool_description, "assistant_description": assistant_description - }) + } + + # Add knowledge base display names for few-shot examples if available + if knowledge_base_display_names: + kb_names_str = ", ".join(f'"{name}"' for name in knowledge_base_display_names) + template_context["knowledge_base_names"] = kb_names_str + + # Generate content using template + content = Template(prompt_for_generate["USER_PROMPT"], undefined=StrictUndefined).render(template_context) return content @@ -379,6 +404,68 @@ def get_enabled_tool_description_for_generate_prompt(agent_id: int, tenant_id: s return tool_info_list +def get_knowledge_base_display_names(tool_info_list: List[dict], agent_id: int, tenant_id: str) -> Optional[List[str]]: + """ + Extract knowledge base display names from tool configurations. + This is used to ensure few-shot examples use actual configured knowledge base names. 
+ + Args: + tool_info_list: List of tool info dictionaries + agent_id: Agent ID for querying tool instances + tenant_id: Tenant ID for database queries + + Returns: + List of knowledge base display names if knowledge_base_search tool is configured, None otherwise + """ + # Check if knowledge_base_search tool is in the list + kb_tool_ids = [tool['tool_id'] for tool in tool_info_list if tool.get('name') == 'knowledge_base_search'] + if not kb_tool_ids: + logger.debug("No knowledge_base_search tool found in tool list") + return None + + # Get the index_names from ToolInstance for knowledge_base_search tool + all_index_names = [] + for kb_tool_id in kb_tool_ids: + try: + tool_instance = query_tool_instances_by_id( + agent_id=agent_id, + tool_id=kb_tool_id, + tenant_id=tenant_id + ) + if tool_instance and tool_instance.get('params', {}).get('index_names'): + index_names = tool_instance['params']['index_names'] + if isinstance(index_names, list): + all_index_names.extend(index_names) + elif isinstance(index_names, str): + # Handle JSON string format + try: + all_index_names.extend(json.loads(index_names)) + except json.JSONDecodeError: + logger.warning(f"Failed to parse index_names JSON: {index_names}") + except Exception as e: + logger.warning(f"Failed to get tool instance for tool_id {kb_tool_id}: {e}") + + if not all_index_names: + logger.debug("No index_names configured for knowledge_base_search tool") + return None + + # Remove duplicates while preserving order + unique_index_names = list(dict.fromkeys(all_index_names)) + + # Convert to display names + knowledge_name_map = get_knowledge_name_map_by_index_names(unique_index_names) + + # Return list of display names (knowledge_name) for each configured index_name + display_names = [] + for index_name in unique_index_names: + display_name = knowledge_name_map.get(index_name, index_name) + if display_name and display_name not in display_names: + display_names.append(display_name) + + logger.debug(f"Converted index_names 
{unique_index_names} to display_names: {display_names}") + return display_names if display_names else None + + def get_enabled_sub_agent_description_for_generate_prompt(agent_id: int, tenant_id: str): logger.info("Fetching sub-agents information") diff --git a/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx b/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx index d09a06039..f42355bee 100644 --- a/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx +++ b/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx @@ -952,8 +952,16 @@ export default function ToolConfigModal({ {} as Record ); - // Update local state: Add tool to selected tools with updated params - const updatedTool = { ...toolToSave, initParams: currentParams }; + // Update local state: Add tool to selected tools with updated params and display_names + // Include display_names for knowledge base tools to pass to prompt generation + const updatedTool: typeof toolToSave = { + ...toolToSave, + initParams: currentParams, + // Store knowledge base display names for prompt generation + ...(toolRequiresKbSelection && selectedKbDisplayNames.length > 0 + ? 
{ display_names: selectedKbDisplayNames } + : {}) + }; const currentTools = useAgentConfigStore.getState().editedAgent.tools; // Check if tool already exists, if so replace it, otherwise add it diff --git a/frontend/app/[locale]/agents/components/agentInfo/AgentGenerateDetail.tsx b/frontend/app/[locale]/agents/components/agentInfo/AgentGenerateDetail.tsx index 37687e1fb..286e3187c 100644 --- a/frontend/app/[locale]/agents/components/agentInfo/AgentGenerateDetail.tsx +++ b/frontend/app/[locale]/agents/components/agentInfo/AgentGenerateDetail.tsx @@ -476,6 +476,18 @@ export default function AgentGenerateDetail({ setIsGenerating(true); setActiveTab("few-shots"); + + // Extract knowledge base display names from selected tools + // This allows the backend to use frontend-configured display names without database lookup + const knowledgeBaseDisplayNames: string[] = []; + if (Array.isArray(editedAgent.tools)) { + for (const tool of editedAgent.tools) { + if (typeof tool === "object" && tool.display_names && Array.isArray(tool.display_names)) { + knowledgeBaseDisplayNames.push(...tool.display_names); + } + } + } + try { await generatePromptStream( { @@ -490,6 +502,8 @@ export default function AgentGenerateDetail({ : tool ) : [], + // Pass knowledge base display names from frontend-configured tools + knowledge_base_display_names: knowledgeBaseDisplayNames.length > 0 ? knowledgeBaseDisplayNames : undefined, }, (data) => { // Process streaming response data diff --git a/frontend/types/agentConfig.ts b/frontend/types/agentConfig.ts index 50c128146..7441d5efd 100644 --- a/frontend/types/agentConfig.ts +++ b/frontend/types/agentConfig.ts @@ -75,6 +75,12 @@ export interface Tool { usage?: string; inputs?: string; category?: string; + /** + * Knowledge base display names associated with this tool. + * This is populated when the tool (e.g., knowledge_base_search) has knowledge bases configured. 
+ * Used to pass knowledge base names to prompt generation without requiring database lookup. + */ + display_names?: string[]; } export interface ToolParam { @@ -401,6 +407,13 @@ export interface GeneratePromptParams { model_id: string; tool_ids?: number[]; // Optional: tool IDs selected in frontend (takes precedence over database query) sub_agent_ids?: number[]; // Optional: sub-agent IDs selected in frontend (takes precedence over database query) + /** + * Optional: Knowledge base display names for few-shot examples. + * If provided, the backend will use these instead of querying the database. + * This allows the frontend to pass the latest configured knowledge base names + * without waiting for tool config to be saved first. + */ + knowledge_base_display_names?: string[]; } /** diff --git a/test/backend/services/test_prompt_service.py b/test/backend/services/test_prompt_service.py index 3b33f1a5e..90049d006 100644 --- a/test/backend/services/test_prompt_service.py +++ b/test/backend/services/test_prompt_service.py @@ -114,16 +114,18 @@ def mock_generator(*args, **kwargs): self.assertEqual(call_args[0][1], "Test task") # task_description self.assertEqual(call_args[0][2], [mock_tool1, mock_tool2]) # tool_info_list - @patch('backend.services.prompt_service.generate_system_prompt') @patch('backend.services.prompt_service.query_all_agent_info_by_tenant_id') - @patch('backend.services.prompt_service.get_enabled_sub_agent_description_for_generate_prompt') + @patch('backend.services.prompt_service.generate_system_prompt') @patch('backend.services.prompt_service.get_enabled_tool_description_for_generate_prompt') + @patch('backend.services.prompt_service.get_enabled_sub_agent_description_for_generate_prompt') + @patch('backend.services.prompt_service.get_knowledge_base_display_names') def test_generate_and_save_system_prompt_impl_create_mode( self, - mock_get_enabled_tools, + mock_get_kb_display_names, mock_get_enabled_sub_agents, - mock_query_all_agents, + 
mock_get_enabled_tools, mock_generate_system_prompt, + mock_query_all_agents, ): """Test generate_and_save_system_prompt_impl in create mode (agent_id=0)""" # Setup - Mock the generator to return the expected data structure @@ -146,6 +148,7 @@ def mock_generator(*args, **kwargs): enabled_sub_agents = [{"name": "db_agent", "description": "DB agent"}] mock_get_enabled_tools.return_value = enabled_tools mock_get_enabled_sub_agents.return_value = enabled_sub_agents + mock_get_kb_display_names.return_value = None # Execute - test as a generator with agent_id=0 (create mode) and empty tool/sub-agent IDs result_gen = generate_and_save_system_prompt_impl( @@ -170,7 +173,8 @@ def mock_generator(*args, **kwargs): enabled_tools, # tool_info_list from helper "tenant456", self.test_model_id, - "zh" + "zh", + None # knowledge_base_display_names ) @patch('backend.services.prompt_service._regenerate_agent_display_name_with_llm') @@ -648,13 +652,14 @@ def mock_llm_call(model_id, content, sys_prompt, callback, tenant_id): # Verify template loading mock_get_prompt_template.assert_called_once_with(mock_language) - # Verify template joining + # Verify template joining - now includes knowledge_base_display_names parameter mock_join_info.assert_called_once_with( prompt_for_generate=mock_prompt_config, sub_agent_info_list=mock_sub_agents, task_description=mock_task_description, tool_info_list=mock_tools, - language=mock_language + language=mock_language, + knowledge_base_display_names=None ) # Verify LLM calls - should be called 6 times for each prompt type @@ -1187,3 +1192,282 @@ def test_join_info_for_generate_system_prompt_empty_tools_and_agents(self, mock_ # Assert self.assertEqual(result, "Rendered content") + @patch('backend.services.prompt_service.Template') + def test_join_info_for_generate_system_prompt_with_knowledge_base_names(self, mock_template): + """Test join_info_for_generate_system_prompt with knowledge_base_display_names""" + # Setup + mock_prompt_for_generate = 
{"USER_PROMPT": "Test User Prompt"} + mock_sub_agents = [] + mock_task_description = "Test task" + mock_tools = [ + {"name": "knowledge_base_search", "description": "Search knowledge base", + "inputs": "{}", "output_type": "string"} + ] + + mock_template_instance = MagicMock() + mock_template.return_value = mock_template_instance + mock_template_instance.render.return_value = "Rendered content with KB names" + + # Execute with knowledge base display names + result = join_info_for_generate_system_prompt( + mock_prompt_for_generate, mock_sub_agents, mock_task_description, mock_tools, + knowledge_base_display_names=["redis", "kafka"] + ) + + # Assert + self.assertEqual(result, "Rendered content with KB names") + # Verify that knowledge_base_names was passed to template + template_vars = mock_template_instance.render.call_args[0][0] + self.assertIn("knowledge_base_names", template_vars) + self.assertEqual(template_vars["knowledge_base_names"], '"redis", "kafka"') + + @patch('backend.services.prompt_service.Template') + def test_join_info_for_generate_system_prompt_without_knowledge_base_names(self, mock_template): + """Test join_info_for_generate_system_prompt without knowledge_base_display_names""" + # Setup + mock_prompt_for_generate = {"USER_PROMPT": "Test User Prompt"} + mock_sub_agents = [] + mock_task_description = "Test task" + mock_tools = [ + {"name": "web_search", "description": "Web search", + "inputs": "{}", "output_type": "string"} + ] + + mock_template_instance = MagicMock() + mock_template.return_value = mock_template_instance + mock_template_instance.render.return_value = "Rendered content" + + # Execute without knowledge base display names + result = join_info_for_generate_system_prompt( + mock_prompt_for_generate, mock_sub_agents, mock_task_description, mock_tools + ) + + # Assert + template_vars = mock_template_instance.render.call_args[0][0] + # knowledge_base_names should not be in template vars when not provided + 
self.assertNotIn("knowledge_base_names", template_vars) + + @patch('backend.services.prompt_service.get_knowledge_name_map_by_index_names') + @patch('backend.services.prompt_service.query_tool_instances_by_id') + def test_get_knowledge_base_display_names_with_configured_kb( + self, + mock_query_tool_instance, + mock_get_knowledge_map, + ): + """Test get_knowledge_base_display_names with configured knowledge base""" + from backend.services.prompt_service import get_knowledge_base_display_names + + # Setup + tool_info_list = [ + {"tool_id": 1, "name": "knowledge_base_search"}, + {"tool_id": 2, "name": "web_search"}, + ] + + mock_query_tool_instance.return_value = { + "params": { + "index_names": ["index-1", "index-2"] + } + } + mock_get_knowledge_map.return_value = { + "index-1": "redis", + "index-2": "kafka" + } + + # Execute + result = get_knowledge_base_display_names( + tool_info_list=tool_info_list, + agent_id=123, + tenant_id="tenant-abc" + ) + + # Assert + self.assertEqual(result, ["redis", "kafka"]) + mock_query_tool_instance.assert_called_once_with( + agent_id=123, tool_id=1, tenant_id="tenant-abc" + ) + mock_get_knowledge_map.assert_called_once_with(["index-1", "index-2"]) + + @patch('backend.services.prompt_service.query_tool_instances_by_id') + def test_get_knowledge_base_display_names_no_kb_tool(self, mock_query_tool_instance): + """Test get_knowledge_base_display_names when no knowledge_base_search tool exists""" + from backend.services.prompt_service import get_knowledge_base_display_names + + # Setup - no knowledge_base_search tool + tool_info_list = [ + {"tool_id": 2, "name": "web_search"}, + ] + + # Execute + result = get_knowledge_base_display_names( + tool_info_list=tool_info_list, + agent_id=123, + tenant_id="tenant-abc" + ) + + # Assert + self.assertIsNone(result) + mock_query_tool_instance.assert_not_called() + + @patch('backend.services.prompt_service.get_knowledge_name_map_by_index_names') + 
@patch('backend.services.prompt_service.query_tool_instances_by_id') + def test_get_knowledge_base_display_names_empty_index_names( + self, + mock_query_tool_instance, + mock_get_knowledge_map, + ): + """Test get_knowledge_base_display_names when index_names is empty""" + from backend.services.prompt_service import get_knowledge_base_display_names + + # Setup + tool_info_list = [ + {"tool_id": 1, "name": "knowledge_base_search"}, + ] + + mock_query_tool_instance.return_value = { + "params": {} + } + + # Execute + result = get_knowledge_base_display_names( + tool_info_list=tool_info_list, + agent_id=123, + tenant_id="tenant-abc" + ) + + # Assert + self.assertIsNone(result) + mock_get_knowledge_map.assert_not_called() + + @patch('backend.services.prompt_service.get_knowledge_name_map_by_index_names') + @patch('backend.services.prompt_service.query_tool_instances_by_id') + def test_get_knowledge_base_display_names_with_json_string( + self, + mock_query_tool_instance, + mock_get_knowledge_map, + ): + """Test get_knowledge_base_display_names when index_names is a JSON string""" + from backend.services.prompt_service import get_knowledge_base_display_names + + # Setup + tool_info_list = [ + {"tool_id": 1, "name": "knowledge_base_search"}, + ] + + mock_query_tool_instance.return_value = { + "params": { + "index_names": '["index-1", "index-2"]' # JSON string format + } + } + mock_get_knowledge_map.return_value = { + "index-1": "redis", + "index-2": "kafka" + } + + # Execute + result = get_knowledge_base_display_names( + tool_info_list=tool_info_list, + agent_id=123, + tenant_id="tenant-abc" + ) + + # Assert + self.assertEqual(result, ["redis", "kafka"]) + + @patch('backend.services.prompt_service.get_knowledge_name_map_by_index_names') + @patch('backend.services.prompt_service.query_tool_instances_by_id') + def test_get_knowledge_base_display_names_multiple_tools( + self, + mock_query_tool_instance, + mock_get_knowledge_map, + ): + """Test get_knowledge_base_display_names 
with multiple knowledge_base_search tools""" + from backend.services.prompt_service import get_knowledge_base_display_names + + # Setup - two knowledge_base_search tools + tool_info_list = [ + {"tool_id": 1, "name": "knowledge_base_search"}, + {"tool_id": 2, "name": "knowledge_base_search"}, + ] + + mock_query_tool_instance.side_effect = [ + {"params": {"index_names": ["index-1"]}}, + {"params": {"index_names": ["index-2"]}}, + ] + mock_get_knowledge_map.return_value = { + "index-1": "redis", + "index-2": "kafka" + } + + # Execute + result = get_knowledge_base_display_names( + tool_info_list=tool_info_list, + agent_id=123, + tenant_id="tenant-abc" + ) + + # Assert + self.assertEqual(result, ["redis", "kafka"]) + self.assertEqual(mock_query_tool_instance.call_count, 2) + + @patch('backend.services.prompt_service.get_knowledge_name_map_by_index_names') + @patch('backend.services.prompt_service.query_tool_instances_by_id') + def test_get_knowledge_base_display_names_duplicate_index_names( + self, + mock_query_tool_instance, + mock_get_knowledge_map, + ): + """Test get_knowledge_base_display_names handles duplicate index_names""" + from backend.services.prompt_service import get_knowledge_base_display_names + + # Setup + tool_info_list = [ + {"tool_id": 1, "name": "knowledge_base_search"}, + ] + + mock_query_tool_instance.return_value = { + "params": {"index_names": ["index-1", "index-1", "index-2"]} # Duplicates + } + mock_get_knowledge_map.return_value = { + "index-1": "redis", + "index-2": "kafka" + } + + # Execute + result = get_knowledge_base_display_names( + tool_info_list=tool_info_list, + agent_id=123, + tenant_id="tenant-abc" + ) + + # Assert - should deduplicate while preserving order + self.assertEqual(result, ["redis", "kafka"]) + # Should be called with deduplicated list + mock_get_knowledge_map.assert_called_once_with(["index-1", "index-2"]) + + @patch('backend.services.prompt_service.generate_and_save_system_prompt_impl') + def 
test_gen_system_prompt_streamable_knowledge_base_flow(self, mock_generate_impl): + """Test gen_system_prompt_streamable with knowledge base configuration""" + # Setup + test_data = [ + {"type": "duty", "content": "Test duty", "is_complete": False}, + {"type": "few_shots", "content": 'index_names=["redis", "kafka"]', "is_complete": True}, + ] + mock_generate_impl.return_value = iter(test_data) + + # Execute + result_list = list(gen_system_prompt_streamable( + agent_id=123, + model_id=self.test_model_id, + task_description="Test task with knowledge base", + user_id="user123", + tenant_id="tenant456", + language="zh" + )) + + # Assert + self.assertEqual(len(result_list), 2) + # Verify success format + import json + parsed = json.loads(result_list[0].replace("data: ", "").replace("\n\n", "")) + self.assertTrue(parsed['success']) + From 9d2929e664780a477abde2ea94b0a4de26df94e4 Mon Sep 17 00:00:00 2001 From: zhizhi <928570418@qq.com> Date: Tue, 21 Apr 2026 14:22:22 +0800 Subject: [PATCH 3/5] =?UTF-8?q?=E2=9C=A8=20Enhance=20prompt=20generation?= =?UTF-8?q?=20with=20knowledge=20base=20display=20names=20part2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/backend/services/test_prompt_service.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/backend/services/test_prompt_service.py b/test/backend/services/test_prompt_service.py index 90049d006..58ed33fbb 100644 --- a/test/backend/services/test_prompt_service.py +++ b/test/backend/services/test_prompt_service.py @@ -569,6 +569,7 @@ def test_gen_system_prompt_streamable(self, mock_generate_impl): language="zh", tool_ids=None, sub_agent_ids=None, + knowledge_base_display_names=None, ) # Verify output format - should be SSE format From 65f7dec3f643a1ff85447528261c75b8ac02a2c8 Mon Sep 17 00:00:00 2001 From: zhizhi <928570418@qq.com> Date: Thu, 23 Apr 2026 16:02:23 +0800 Subject: [PATCH 4/5] =?UTF-8?q?=E2=9C=A8=20Add=20display=20name=20to=20ind?= 
=?UTF-8?q?ex=20name=20mapping=20for=20KnowledgeBaseSearchTool?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/agents/create_agent_info.py | 11 +- .../core/tools/knowledge_base_search_tool.py | 18 +- test/backend/agents/test_create_agent_info.py | 337 +++++++++++++++++- .../tools/test_knowledge_base_search_tool.py | 191 ++++++++++ 4 files changed, 543 insertions(+), 14 deletions(-) diff --git a/backend/agents/create_agent_info.py b/backend/agents/create_agent_info.py index 4bf092165..ca2f25ee8 100644 --- a/backend/agents/create_agent_info.py +++ b/backend/agents/create_agent_info.py @@ -358,11 +358,12 @@ async def create_agent_config( if "KnowledgeBaseSearchTool" == tool.class_name: index_names = tool.params.get("index_names") if index_names: - # Batch query to get display names (knowledge_name) for all index_names - knowledge_name_map = get_knowledge_name_map_by_index_names(index_names) + # Reuse the index_name -> display_name mapping from tool.metadata + # (already computed in create_tool_config_list to avoid redundant DB query) + index_name_to_display_map = tool.metadata.get("index_name_to_display_map", {}) if tool.metadata else {} for index_name in index_names: try: - display_name = knowledge_name_map.get(index_name, index_name) + display_name = index_name_to_display_map.get(index_name, index_name) message = ElasticSearchService().get_summary(index_name=index_name) summary = message.get("summary", "") knowledge_base_summary += f"**{display_name}**: {summary}\n\n" @@ -462,19 +463,23 @@ async def create_tool_config_list(agent_id, tenant_id, user_id, version_no: int ) # Build display_name to index_name mapping for LLM parameter conversion + # Also build reverse mapping (index_name -> display_name) for knowledge_base_summary index_names = param_dict.get("index_names", []) display_name_to_index_map = {} + index_name_to_display_map = {} if index_names: knowledge_name_map = 
get_knowledge_name_map_by_index_names(index_names) # Reverse the mapping: display_name (knowledge_name) -> index_name for idx_name, kb_name in knowledge_name_map.items(): display_name_to_index_map[kb_name] = idx_name + index_name_to_display_map[idx_name] = kb_name tool_config.metadata = { "vdb_core": get_vector_db_core(), "embedding_model": get_embedding_model(tenant_id=tenant_id), "rerank_model": rerank_model, "display_name_to_index_map": display_name_to_index_map, + "index_name_to_display_map": index_name_to_display_map, } elif tool_config.class_name in ["DifySearchTool", "DataMateSearchTool"]: rerank = param_dict.get("rerank", False) diff --git a/sdk/nexent/core/tools/knowledge_base_search_tool.py b/sdk/nexent/core/tools/knowledge_base_search_tool.py index dcb4c4cef..e3fb2916c 100644 --- a/sdk/nexent/core/tools/knowledge_base_search_tool.py +++ b/sdk/nexent/core/tools/knowledge_base_search_tool.py @@ -131,20 +131,20 @@ def _convert_to_index_names(self, names: List[str]) -> List[str]: Returns: List of actual index_names for ES queries """ - # Handle FieldInfo case (smolagents doesn't expand Field defaults) display_map = self.display_name_to_index_map if isinstance(display_map, FieldInfo): - display_map = display_map.default + if display_map.default_factory is not None: + display_map = display_map.default_factory() + else: + display_map = display_map.default if not display_map: return names converted_names = [] for name in names: - # If the name is in the map as a display_name, convert it to index_name if name in display_map: converted_names.append(display_map[name]) else: - # Otherwise, assume it's already an index_name converted_names.append(name) return converted_names @@ -177,9 +177,15 @@ def forward(self, query: str, index_names: Optional[List[str]] = None) -> str: effective_top_k = self.top_k is_rerank = self.rerank if isinstance(effective_top_k, FieldInfo): - effective_top_k = effective_top_k.default + if effective_top_k.default_factory is not None: + 
effective_top_k = effective_top_k.default_factory() + else: + effective_top_k = effective_top_k.default if isinstance(is_rerank, FieldInfo): - is_rerank = is_rerank.default + if is_rerank.default_factory is not None: + is_rerank = is_rerank.default_factory() + else: + is_rerank = is_rerank.default if is_rerank: effective_top_k = effective_top_k * RERANK_OVERSEARCH_MULTIPLIER diff --git a/test/backend/agents/test_create_agent_info.py b/test/backend/agents/test_create_agent_info.py index d59139630..b92f62571 100644 --- a/test/backend/agents/test_create_agent_info.py +++ b/test/backend/agents/test_create_agent_info.py @@ -1910,9 +1910,6 @@ async def test_create_agent_config_with_knowledge_base_summary_filtering(self): patch( "backend.agents.create_agent_info._get_skill_script_tools" ) as mock_get_skill_tools, - patch( - "backend.agents.create_agent_info.get_knowledge_name_map_by_index_names" - ) as mock_get_knowledge_name_map, ): mock_search_agent.return_value = { "name": "test_agent", @@ -1930,6 +1927,9 @@ async def test_create_agent_config_with_knowledge_base_summary_filtering(self): kb_tool_1.class_name = "KnowledgeBaseSearchTool" kb_tool_1.name = "kb_tool_1" kb_tool_1.params = {"index_names": ["idx_a", "idx_b"]} + kb_tool_1.metadata = { + "index_name_to_display_map": {"idx_a": "idx_a", "idx_b": "idx_b"} + } other_tool = Mock() other_tool.class_name = "OtherTool" @@ -1940,6 +1940,9 @@ async def test_create_agent_config_with_knowledge_base_summary_filtering(self): kb_tool_2.class_name = "KnowledgeBaseSearchTool" kb_tool_2.name = "kb_tool_2" kb_tool_2.params = {"index_names": ["idx_c"]} + kb_tool_2.metadata = { + "index_name_to_display_map": {"idx_c": "idx_c"} + } mock_create_tools.return_value = [kb_tool_1, other_tool, kb_tool_2] mock_get_template.return_value = {"system_prompt": "{{ knowledge_base_summary }}"} @@ -1955,8 +1958,6 @@ async def test_create_agent_config_with_knowledge_base_summary_filtering(self): mock_get_model_by_id.return_value = {"display_name": 
"test_model"} mock_get_skills.return_value = [] mock_get_skill_tools.return_value = [] - # Mock knowledge_name_map to return index_name as fallback - mock_get_knowledge_name_map.return_value = {"idx_a": "idx_a", "idx_b": "idx_b"} mock_es_instance = Mock() mock_es_instance.get_summary.side_effect = [ @@ -1980,6 +1981,214 @@ async def test_create_agent_config_with_knowledge_base_summary_filtering(self): # Ensure only the first KnowledgeBaseSearchTool is processed. assert "idx_c" not in str(mock_es_instance.get_summary.call_args_list) + @pytest.mark.asyncio + async def test_create_agent_config_uses_metadata_index_name_to_display_map(self): + """Test that create_agent_config uses index_name_to_display_map from tool.metadata. + + This test verifies the refactored behavior where create_agent_config + reuses the index_name -> display_name mapping from tool.metadata instead of + making redundant database queries. + """ + with ( + patch( + "backend.agents.create_agent_info.search_agent_info_by_agent_id" + ) as mock_search_agent, + patch( + "backend.agents.create_agent_info.query_sub_agents_id_list" + ) as mock_query_sub, + patch( + "backend.agents.create_agent_info.create_tool_config_list" + ) as mock_create_tools, + patch( + "backend.agents.create_agent_info.get_agent_prompt_template" + ) as mock_get_template, + patch( + "backend.agents.create_agent_info.tenant_config_manager" + ) as mock_tenant_config, + patch( + "backend.agents.create_agent_info.build_memory_context" + ) as mock_build_memory, + patch( + "backend.agents.create_agent_info.ElasticSearchService" + ) as mock_es_service, + patch( + "backend.agents.create_agent_info.prepare_prompt_templates" + ) as mock_prepare_templates, + patch( + "backend.agents.create_agent_info.get_model_by_model_id" + ) as mock_get_model_by_id, + patch( + "backend.agents.create_agent_info._get_skills_for_template" + ) as mock_get_skills, + patch( + "backend.agents.create_agent_info._get_skill_script_tools" + ) as mock_get_skill_tools, + 
patch( + "backend.agents.create_agent_info.get_knowledge_name_map_by_index_names" + ) as mock_get_knowledge_name_map, + ): + mock_search_agent.return_value = { + "name": "test_agent", + "description": "test description", + "duty_prompt": "test duty", + "constraint_prompt": "test constraint", + "few_shots_prompt": "test few shots", + "max_steps": 5, + "model_id": 123, + "provide_run_summary": True, + } + mock_query_sub.return_value = [] + + # Create a tool with index_name_to_display_map in metadata + kb_tool = Mock() + kb_tool.class_name = "KnowledgeBaseSearchTool" + kb_tool.name = "kb_tool" + kb_tool.params = {"index_names": ["idx1", "idx2"]} + # The tool.metadata contains the index_name -> display_name mapping + kb_tool.metadata = { + "index_name_to_display_map": { + "idx1": "Custom Name 1", + "idx2": "Custom Name 2" + } + } + + mock_create_tools.return_value = [kb_tool] + mock_get_template.return_value = {"system_prompt": "{{ knowledge_base_summary }}"} + mock_tenant_config.get_app_config.side_effect = ["TestApp", "Test Description"] + mock_build_memory.return_value = Mock( + user_config=Mock(memory_switch=False), + memory_config={}, + tenant_id="tenant_1", + user_id="user_1", + agent_id="agent_1", + ) + mock_prepare_templates.return_value = {"system_prompt": "populated_system_prompt"} + mock_get_model_by_id.return_value = {"display_name": "test_model"} + mock_get_skills.return_value = [] + mock_get_skill_tools.return_value = [] + # This should NOT be called when tool.metadata has index_name_to_display_map + mock_get_knowledge_name_map.return_value = {"idx1": "idx1", "idx2": "idx2"} + + mock_es_instance = Mock() + mock_es_instance.get_summary.side_effect = [ + {"summary": "Summary 1"}, + {"summary": "Summary 2"}, + ] + mock_es_service.return_value = mock_es_instance + + await create_agent_config("agent_1", "tenant_1", "user_1", "zh", "test query") + + # Verify ElasticSearchService was called for both indices + assert mock_es_instance.get_summary.call_count == 2 + 
+ # Verify get_knowledge_name_map_by_index_names was NOT called + # because we're using the mapping from tool.metadata + mock_get_knowledge_name_map.assert_not_called() + + # Verify the system prompt uses the display names from metadata + mock_prepare_templates.assert_called_once() + system_prompt = mock_prepare_templates.call_args[1]["system_prompt"] + assert "**Custom Name 1**" in system_prompt + assert "**Custom Name 2**" in system_prompt + assert "idx1" not in system_prompt + assert "idx2" not in system_prompt + + @pytest.mark.asyncio + async def test_create_agent_config_metadata_without_index_name_to_display_map(self): + """Test that create_agent_config handles missing index_name_to_display_map gracefully. + + When tool.metadata exists but doesn't have index_name_to_display_map, + it should fall back to using index_name as display_name. + """ + with ( + patch( + "backend.agents.create_agent_info.search_agent_info_by_agent_id" + ) as mock_search_agent, + patch( + "backend.agents.create_agent_info.query_sub_agents_id_list" + ) as mock_query_sub, + patch( + "backend.agents.create_agent_info.create_tool_config_list" + ) as mock_create_tools, + patch( + "backend.agents.create_agent_info.get_agent_prompt_template" + ) as mock_get_template, + patch( + "backend.agents.create_agent_info.tenant_config_manager" + ) as mock_tenant_config, + patch( + "backend.agents.create_agent_info.build_memory_context" + ) as mock_build_memory, + patch( + "backend.agents.create_agent_info.ElasticSearchService" + ) as mock_es_service, + patch( + "backend.agents.create_agent_info.prepare_prompt_templates" + ) as mock_prepare_templates, + patch( + "backend.agents.create_agent_info.get_model_by_model_id" + ) as mock_get_model_by_id, + patch( + "backend.agents.create_agent_info._get_skills_for_template" + ) as mock_get_skills, + patch( + "backend.agents.create_agent_info._get_skill_script_tools" + ) as mock_get_skill_tools, + patch( + 
"backend.agents.create_agent_info.get_knowledge_name_map_by_index_names" + ) as mock_get_knowledge_name_map, + ): + mock_search_agent.return_value = { + "name": "test_agent", + "description": "test description", + "duty_prompt": "test duty", + "constraint_prompt": "test constraint", + "few_shots_prompt": "test few shots", + "max_steps": 5, + "model_id": 123, + "provide_run_summary": True, + } + mock_query_sub.return_value = [] + + # Create a tool with empty metadata (no index_name_to_display_map) + kb_tool = Mock() + kb_tool.class_name = "KnowledgeBaseSearchTool" + kb_tool.name = "kb_tool" + kb_tool.params = {"index_names": ["idx1", "idx2"]} + kb_tool.metadata = {} # Empty metadata + + mock_create_tools.return_value = [kb_tool] + mock_get_template.return_value = {"system_prompt": "{{ knowledge_base_summary }}"} + mock_tenant_config.get_app_config.side_effect = ["TestApp", "Test Description"] + mock_build_memory.return_value = Mock( + user_config=Mock(memory_switch=False), + memory_config={}, + tenant_id="tenant_1", + user_id="user_1", + agent_id="agent_1", + ) + mock_prepare_templates.return_value = {"system_prompt": "populated_system_prompt"} + mock_get_model_by_id.return_value = {"display_name": "test_model"} + mock_get_skills.return_value = [] + mock_get_skill_tools.return_value = [] + mock_get_knowledge_name_map.return_value = {} + + mock_es_instance = Mock() + mock_es_instance.get_summary.side_effect = [ + {"summary": "Summary 1"}, + {"summary": "Summary 2"}, + ] + mock_es_service.return_value = mock_es_instance + + await create_agent_config("agent_1", "tenant_1", "user_1", "zh", "test query") + + # When metadata is empty, it should fall back to using index_name + # as the display_name (no mapping available) + mock_prepare_templates.assert_called_once() + system_prompt = mock_prepare_templates.call_args[1]["system_prompt"] + assert "**idx1**" in system_prompt + assert "**idx2**" in system_prompt + @pytest.mark.parametrize( "language,expected_message", [ @@ 
-3451,6 +3660,124 @@ async def test_knowledge_base_with_partial_name_mapping(self): # Unfound indices will use index_name as fallback (which is not in get_knowledge_name_map result) assert "Knowledge Base 1" in result[0].metadata["display_name_to_index_map"] + @pytest.mark.asyncio + async def test_knowledge_base_with_index_name_to_display_map(self): + """Test that KnowledgeBaseSearchTool gets correct index_name_to_display_map from index_names. + + This test verifies the reverse mapping (index_name -> display_name) that was added + to avoid redundant database queries when building knowledge_base_summary. + """ + mock_tool_instance = MagicMock() + mock_tool_instance.class_name = "KnowledgeBaseSearchTool" + + with patch('backend.agents.create_agent_info.ToolConfig') as mock_tool_config, \ + patch('backend.agents.create_agent_info.discover_langchain_tools', return_value=[]), \ + patch('backend.agents.create_agent_info.search_tools_for_sub_agent') as mock_search_tools, \ + patch('backend.agents.create_agent_info.get_vector_db_core') as mock_get_vector_db_core, \ + patch('backend.agents.create_agent_info.get_embedding_model') as mock_embedding, \ + patch('backend.agents.create_agent_info.get_rerank_model') as mock_rerank, \ + patch('backend.agents.create_agent_info.get_knowledge_name_map_by_index_names') as mock_get_knowledge_map: + + mock_tool_config.return_value = mock_tool_instance + + mock_search_tools.return_value = [ + { + "class_name": "KnowledgeBaseSearchTool", + "name": "knowledge_search", + "description": "Knowledge search tool", + "inputs": "string", + "output_type": "string", + "params": [ + {"name": "index_names", "default": ["idx1", "idx2"]}, + {"name": "rerank", "default": False}, + ], + "source": "local", + "usage": None + } + ] + mock_get_vector_db_core.return_value = "vdb_core_instance" + mock_embedding.return_value = "embedding_instance" + mock_rerank.return_value = None + # Mock the knowledge name map: index_name -> knowledge_name (display_name) + 
mock_get_knowledge_map.return_value = { + "idx1": "Knowledge Base 1", + "idx2": "Knowledge Base 2" + } + + result = await create_tool_config_list("agent_1", "tenant_1", "user_1") + + assert len(result) == 1 + # Verify display_name_to_index_map (original mapping) + assert result[0].metadata["display_name_to_index_map"] == { + "Knowledge Base 1": "idx1", + "Knowledge Base 2": "idx2" + } + # Verify index_name_to_display_map (new reverse mapping) + assert result[0].metadata["index_name_to_display_map"] == { + "idx1": "Knowledge Base 1", + "idx2": "Knowledge Base 2" + } + # Both maps should be present + assert "display_name_to_index_map" in result[0].metadata + assert "index_name_to_display_map" in result[0].metadata + + @pytest.mark.asyncio + async def test_knowledge_base_with_partial_index_name_mapping(self): + """Test that KnowledgeBaseSearchTool handles partial index_name_to_display_map correctly. + + When some index_names are not found in the database, they should not be + added to the index_name_to_display_map. 
+ """ + mock_tool_instance = MagicMock() + mock_tool_instance.class_name = "KnowledgeBaseSearchTool" + + with patch('backend.agents.create_agent_info.ToolConfig') as mock_tool_config, \ + patch('backend.agents.create_agent_info.discover_langchain_tools', return_value=[]), \ + patch('backend.agents.create_agent_info.search_tools_for_sub_agent') as mock_search_tools, \ + patch('backend.agents.create_agent_info.get_vector_db_core') as mock_get_vector_db_core, \ + patch('backend.agents.create_agent_info.get_embedding_model') as mock_embedding, \ + patch('backend.agents.create_agent_info.get_rerank_model') as mock_rerank, \ + patch('backend.agents.create_agent_info.get_knowledge_name_map_by_index_names') as mock_get_knowledge_map: + + mock_tool_config.return_value = mock_tool_instance + + mock_search_tools.return_value = [ + { + "class_name": "KnowledgeBaseSearchTool", + "name": "knowledge_search", + "description": "Knowledge search tool", + "inputs": "string", + "output_type": "string", + "params": [ + {"name": "index_names", "default": ["idx1", "idx2", "idx3"]}, + {"name": "rerank", "default": False}, + ], + "source": "local", + "usage": None + } + ] + mock_get_vector_db_core.return_value = "vdb_core_instance" + mock_embedding.return_value = "embedding_instance" + mock_rerank.return_value = None + # Only idx1 and idx2 are found, idx3 is not in the database + mock_get_knowledge_map.return_value = { + "idx1": "Knowledge Base 1", + "idx2": "Knowledge Base 2" + } + + result = await create_tool_config_list("agent_1", "tenant_1", "user_1") + + # Verify both mappings contain only found entries + assert "idx1" in result[0].metadata["index_name_to_display_map"] + assert "idx2" in result[0].metadata["index_name_to_display_map"] + # idx3 was not found, so it should not be in the map + assert "idx3" not in result[0].metadata["index_name_to_display_map"] + + # Verify reverse mapping also contains only found entries + assert "Knowledge Base 1" in 
result[0].metadata["display_name_to_index_map"] + assert "Knowledge Base 2" in result[0].metadata["display_name_to_index_map"] + assert "idx3" not in result[0].metadata["display_name_to_index_map"] + class TestFilterMcpServersAndTools: """Tests for filter_mcp_servers_and_tools function""" diff --git a/test/sdk/core/tools/test_knowledge_base_search_tool.py b/test/sdk/core/tools/test_knowledge_base_search_tool.py index 141ce5ca9..bcfeaddc4 100644 --- a/test/sdk/core/tools/test_knowledge_base_search_tool.py +++ b/test/sdk/core/tools/test_knowledge_base_search_tool.py @@ -1025,3 +1025,194 @@ def test_rerank_trims_to_top_k(self, mock_observer, mock_vdb_core, mock_embeddin search_results = json.loads(result) assert len(search_results) == 3 + + +class TestFieldInfoDefaultFactory: + """Tests for FieldInfo default_factory handling. + + smolagents Tool may not properly expand Field defaults, so the code + handles FieldInfo objects with both .default and .default_factory attributes. + These tests verify the correct handling of both cases. 
+ """ + + def test_convert_to_index_names_with_fieldinfo_default_factory(self, mock_observer, mock_vdb_core, mock_embedding_model): + """Test _convert_to_index_names handles FieldInfo with default_factory correctly.""" + try: + from pydantic import FieldInfo + except ImportError: + from pydantic.fields import FieldInfo + + # Create a FieldInfo with default_factory only (Pydantic doesn't allow both) + field_info_with_factory = FieldInfo( + default_factory=lambda: {"Knowledge X": "es_index_x", "Knowledge Y": "es_index_y"} + ) + + tool = KnowledgeBaseSearchTool( + index_names=[], + search_mode="hybrid", + vdb_core=mock_vdb_core, + embedding_model=mock_embedding_model, + observer=mock_observer, + display_name_to_index_map=field_info_with_factory, + ) + + result = tool._convert_to_index_names(["Knowledge X", "Knowledge Y"]) + + # Should convert using the factory result + assert result == ["es_index_x", "es_index_y"] + + def test_convert_to_index_names_with_fieldinfo_default_only(self, mock_observer, mock_vdb_core, mock_embedding_model): + """Test _convert_to_index_names handles FieldInfo with only default correctly.""" + try: + from pydantic import FieldInfo + except ImportError: + from pydantic.fields import FieldInfo + + # Create a FieldInfo with default only (no factory) + field_info_with_default = FieldInfo( + default={"Knowledge A": "es_index_a"} + ) + + tool = KnowledgeBaseSearchTool( + index_names=[], + search_mode="hybrid", + vdb_core=mock_vdb_core, + embedding_model=mock_embedding_model, + observer=mock_observer, + display_name_to_index_map=field_info_with_default, + ) + + result = tool._convert_to_index_names(["Knowledge A"]) + + # Should convert using the default value + assert result == ["es_index_a"] + + def test_forward_with_fieldinfo_top_k_default_factory(self, mock_observer, mock_vdb_core, mock_embedding_model): + """Test forward handles FieldInfo top_k with default_factory correctly.""" + try: + from pydantic import FieldInfo + except ImportError: + 
from pydantic.fields import FieldInfo + + mock_results = create_mock_search_result(3) + mock_vdb_core.hybrid_search.return_value = mock_results + + # Create FieldInfo with default_factory only (Pydantic doesn't allow both) + field_info_top_k = FieldInfo( + default_factory=lambda: 5 + ) + + tool = KnowledgeBaseSearchTool( + index_names=["kb1"], + search_mode="hybrid", + vdb_core=mock_vdb_core, + embedding_model=mock_embedding_model, + observer=mock_observer, + display_name_to_index_map={}, + ) + # Override top_k with FieldInfo + tool.top_k = field_info_top_k + + result = tool.forward("test query") + + # Should use the factory result (5) for top_k + call_kwargs = mock_vdb_core.hybrid_search.call_args[1] + assert call_kwargs["top_k"] == 5 + + def test_forward_with_fieldinfo_rerank_default_factory(self, mock_observer, mock_vdb_core, mock_embedding_model): + """Test forward handles FieldInfo rerank with default_factory correctly.""" + try: + from pydantic import FieldInfo + except ImportError: + from pydantic.fields import FieldInfo + + mock_results = create_mock_search_result(10) + mock_vdb_core.hybrid_search.return_value = mock_results + + # Create FieldInfo with default_factory only (Pydantic doesn't allow both) + field_info_rerank = FieldInfo( + default_factory=lambda: True + ) + + tool = KnowledgeBaseSearchTool( + index_names=["kb1"], + search_mode="hybrid", + vdb_core=mock_vdb_core, + embedding_model=mock_embedding_model, + observer=mock_observer, + display_name_to_index_map={}, + ) + # Override rerank with FieldInfo + tool.rerank = field_info_rerank + + from sdk.nexent.core.utils.constants import RERANK_OVERSEARCH_MULTIPLIER + + result = tool.forward("test query") + + # Should use the factory result (True) and multiply top_k + call_kwargs = mock_vdb_core.hybrid_search.call_args[1] + # top_k from default is 3, multiplied by RERANK_OVERSEARCH_MULTIPLIER + assert call_kwargs["top_k"] == 3 * RERANK_OVERSEARCH_MULTIPLIER + + def 
test_forward_with_fieldinfo_top_k_default_only(self, mock_observer, mock_vdb_core, mock_embedding_model): + """Test forward handles FieldInfo top_k with only default correctly.""" + try: + from pydantic import FieldInfo + except ImportError: + from pydantic.fields import FieldInfo + + mock_results = create_mock_search_result(5) + mock_vdb_core.hybrid_search.return_value = mock_results + + # Create FieldInfo with default only (no factory) + field_info_top_k = FieldInfo(default=10) + + tool = KnowledgeBaseSearchTool( + index_names=["kb1"], + search_mode="hybrid", + vdb_core=mock_vdb_core, + embedding_model=mock_embedding_model, + observer=mock_observer, + display_name_to_index_map={}, + ) + # Override top_k with FieldInfo + tool.top_k = field_info_top_k + + result = tool.forward("test query") + + # Should use the default value (10) + call_kwargs = mock_vdb_core.hybrid_search.call_args[1] + assert call_kwargs["top_k"] == 10 + + def test_forward_with_fieldinfo_rerank_default_only(self, mock_observer, mock_vdb_core, mock_embedding_model): + """Test forward handles FieldInfo rerank with only default correctly.""" + try: + from pydantic import FieldInfo + except ImportError: + from pydantic.fields import FieldInfo + + mock_results = create_mock_search_result(5) + mock_vdb_core.hybrid_search.return_value = mock_results + + # Create FieldInfo with default only (no factory) + field_info_rerank = FieldInfo(default=True) + + tool = KnowledgeBaseSearchTool( + index_names=["kb1"], + search_mode="hybrid", + vdb_core=mock_vdb_core, + embedding_model=mock_embedding_model, + observer=mock_observer, + display_name_to_index_map={}, + ) + # Override rerank with FieldInfo + tool.rerank = field_info_rerank + + from sdk.nexent.core.utils.constants import RERANK_OVERSEARCH_MULTIPLIER + + result = tool.forward("test query") + + # Should use the default value (True) and multiply top_k + call_kwargs = mock_vdb_core.hybrid_search.call_args[1] + # top_k from default is 3, multiplied by 
RERANK_OVERSEARCH_MULTIPLIER + assert call_kwargs["top_k"] == 3 * RERANK_OVERSEARCH_MULTIPLIER From 66dd82a7fc5f2e8c81bc22b9bd49bfcbdb9ef839 Mon Sep 17 00:00:00 2001 From: zhizhi <928570418@qq.com> Date: Sat, 25 Apr 2026 14:49:01 +0800 Subject: [PATCH 5/5] =?UTF-8?q?=F0=9F=90=9B=20Bugfix:=20Enhance=20prompt?= =?UTF-8?q?=20generation=20with=20knowledge=20base=20display=20names=20par?= =?UTF-8?q?t2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/backend/services/test_prompt_service.py | 38 ++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/test/backend/services/test_prompt_service.py b/test/backend/services/test_prompt_service.py index 58ed33fbb..e6dc8aee3 100644 --- a/test/backend/services/test_prompt_service.py +++ b/test/backend/services/test_prompt_service.py @@ -1445,6 +1445,44 @@ def test_get_knowledge_base_display_names_duplicate_index_names( # Should be called with deduplicated list mock_get_knowledge_map.assert_called_once_with(["index-1", "index-2"]) + @patch('backend.services.prompt_service.get_knowledge_name_map_by_index_names') + @patch('backend.services.prompt_service.query_tool_instances_by_id') + def test_get_knowledge_base_display_names_query_tool_instance_exception( + self, + mock_query_tool_instance, + mock_get_knowledge_map, + ): + """Test get_knowledge_base_display_names handles query_tool_instances_by_id exception gracefully (lines 445-446)""" + from backend.services.prompt_service import get_knowledge_base_display_names + + # Setup - two knowledge_base_search tools + tool_info_list = [ + {"tool_id": 1, "name": "knowledge_base_search"}, + {"tool_id": 2, "name": "knowledge_base_search"}, + ] + + # First tool instance query fails with exception + mock_query_tool_instance.side_effect = [ + Exception("Database connection error"), + {"params": {"index_names": ["index-2"]}}, # Second tool succeeds + ] + mock_get_knowledge_map.return_value = { + "index-2": "kafka" + } + + # Execute - 
should handle exception gracefully and continue processing + result = get_knowledge_base_display_names( + tool_info_list=tool_info_list, + agent_id=123, + tenant_id="tenant-abc" + ) + + # Assert - should still return results from the tool that succeeded + self.assertEqual(result, ["kafka"]) + # Should have tried both tools + self.assertEqual(mock_query_tool_instance.call_count, 2) + mock_get_knowledge_map.assert_called_once_with(["index-2"]) + @patch('backend.services.prompt_service.generate_and_save_system_prompt_impl') def test_gen_system_prompt_streamable_knowledge_base_flow(self, mock_generate_impl): """Test gen_system_prompt_streamable with knowledge base configuration"""