diff --git a/r2r/core/abstractions/search.py b/r2r/core/abstractions/search.py
index f9dea53c..2b558fd0 100644
--- a/r2r/core/abstractions/search.py
+++ b/r2r/core/abstractions/search.py
@@ -68,6 +68,7 @@ def dict(self) -> dict:
 
 
 class VectorSearchSettings(BaseModel):
+    use_vector_search: bool = True
     search_filters: Optional[dict[str, Any]] = Field(default_factory=dict)
     search_limit: int = 10
     do_hybrid_search: bool = False
diff --git a/r2r/core/pipeline/search_pipeline.py b/r2r/core/pipeline/search_pipeline.py
index af443292..e30dd322 100644
--- a/r2r/core/pipeline/search_pipeline.py
+++ b/r2r/core/pipeline/search_pipeline.py
@@ -43,7 +43,13 @@ async def run(
         **kwargs: Any,
     ):
         self.state = state or AsyncState()
-
+        do_vector_search = (
+            self.vector_search_pipeline is not None
+            and vector_search_settings.use_vector_search
+        )
+        do_kg = (
+            self.kg_search_pipeline is not None and kg_search_settings.use_kg
+        )
         async with manage_run(run_manager, self.pipeline_type):
             await run_manager.log_run_info(
                 key="pipeline_type",
@@ -56,9 +62,9 @@ async def enqueue_requests():
             async for message in input:
-                if self.vector_search_pipeline:
+                if do_vector_search:
                     await vector_search_queue.put(message)
-                if self.kg_search_pipeline:
+                if do_kg:
                     await kg_queue.put(message)
 
             await vector_search_queue.put(None)
@@ -68,7 +74,7 @@ async def enqueue_requests():
         enqueue_task = asyncio.create_task(enqueue_requests())
 
         # Start the embedding and KG pipelines in parallel
-        if self.vector_search_pipeline:
+        if do_vector_search:
             vector_search_task = asyncio.create_task(
                 self.vector_search_pipeline.run(
                     dequeue_requests(vector_search_queue),
                     state,
                     streaming,
                     run_manager,
                     vector_search_settings=vector_search_settings,
                 )
             )
@@ -79,7 +85,7 @@ async def enqueue_requests():
-        if self.kg_search_pipeline:
+        if do_kg:
             kg_task = asyncio.create_task(
                 self.kg_search_pipeline.run(
                     dequeue_requests(kg_queue),
                     state,
                     streaming,
                     run_manager,
                     kg_search_settings=kg_search_settings,
                 )
             )
@@ -93,9 +99,9 @@ async def enqueue_requests():
         await enqueue_task
 
         vector_search_results = (
-            await vector_search_task if self.vector_search_pipeline else None
+            await vector_search_task if do_vector_search else None
         )
-        kg_results = await kg_task if self.kg_search_pipeline else None
+        kg_results = await kg_task if do_kg else None
 
         return AggregateSearchResult(
             vector_search_results=vector_search_results,
diff --git a/r2r/main/r2r_app.py b/r2r/main/r2r_app.py
index 84650743..0e620d15 100644
--- a/r2r/main/r2r_app.py
+++ b/r2r/main/r2r_app.py
@@ -1018,8 +1018,8 @@ async def stream_response():
                 input=to_async_generator([message]),
                 streaming=False,
                 run_manager=self.run_manager,
-                vector_settings=vector_search_settings,
-                kg_settings=kg_search_settings,
+                vector_search_settings=vector_search_settings,
+                kg_search_settings=kg_search_settings,
                 rag_generation_config=rag_generation_config,
             )
diff --git a/r2r/pipes/search_rag_pipe.py b/r2r/pipes/search_rag_pipe.py
index 5d3c6996..5834d4a4 100644
--- a/r2r/pipes/search_rag_pipe.py
+++ b/r2r/pipes/search_rag_pipe.py
@@ -105,19 +105,21 @@ async def _collect_context(
         total_results: int,
     ) -> Tuple[str, int]:
         context = f"Query:\n{query}\n\n"
-        context += f"Vector Search Results({iteration}):\n"
-        it = total_results + 1
-        for result in results.vector_search_results:
-            context += f"[{it}]: {result.metadata['text']}\n\n"
-            it += 1
-        total_results = (
-            it - 1
-        )  # Update total_results based on the last index used
-        context += f"Knowledge Graph Search Results({iteration}):\n"
-        for result in results.kg_search_results:
-            context += f"[{it}]: {result}\n\n"
-            it += 1
-        total_results = (
-            it - 1
-        )  # Update total_results based on the last index used
+        if results.vector_search_results:
+            context += f"Vector Search Results({iteration}):\n"
+            it = total_results + 1
+            for result in results.vector_search_results:
+                context += f"[{it}]: {result.metadata['text']}\n\n"
+                it += 1
+            total_results = (
+                it - 1
+            )  # Update total_results based on the last index used
+        if results.kg_search_results:
+            context += f"Knowledge Graph Search Results({iteration}):\n"
+            for result in results.kg_search_results:
+                context += f"[{it}]: {result}\n\n"
+                it += 1
+            total_results = (
+                it - 1
+            )  # Update total_results based on the last index used
         return context, total_results