Skip to content

Commit

Permalink
make sure we can configure kg
Browse files (browse the repository at this point in the history)
  • Loading branch information
emrgnt-cmplxty committed Jun 17, 2024
1 parent fcfb346 commit 897dc3f
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 24 deletions.
1 change: 1 addition & 0 deletions r2r/core/abstractions/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def dict(self) -> dict:


class VectorSearchSettings(BaseModel):
use_vector_search: bool = True
search_filters: Optional[dict[str, Any]] = Field(default_factory=dict)
search_limit: int = 10
do_hybrid_search: bool = False
Expand Down
20 changes: 13 additions & 7 deletions r2r/core/pipeline/search_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,13 @@ async def run(
**kwargs: Any,
):
self.state = state or AsyncState()

do_vector_search = (
self.vector_search_pipeline is not None
and vector_search_settings.use_vector_search
)
do_kg = (
self.kg_search_pipeline is not None and kg_search_settings.use_kg
)
async with manage_run(run_manager, self.pipeline_type):
await run_manager.log_run_info(
key="pipeline_type",
Expand All @@ -56,9 +62,9 @@ async def run(

async def enqueue_requests():
async for message in input:
if self.vector_search_pipeline:
if do_vector_search:
await vector_search_queue.put(message)
if self.kg_search_pipeline:
if do_kg:
await kg_queue.put(message)

await vector_search_queue.put(None)
Expand All @@ -68,7 +74,7 @@ async def enqueue_requests():
enqueue_task = asyncio.create_task(enqueue_requests())

# Start the vector search and KG search pipelines in parallel
if self.vector_search_pipeline:
if do_vector_search:
vector_search_task = asyncio.create_task(
self.vector_search_pipeline.run(
dequeue_requests(vector_search_queue),
Expand All @@ -79,7 +85,7 @@ async def enqueue_requests():
)
)

if self.kg_search_pipeline:
if do_kg:
kg_task = asyncio.create_task(
self.kg_search_pipeline.run(
dequeue_requests(kg_queue),
Expand All @@ -93,9 +99,9 @@ async def enqueue_requests():
await enqueue_task

vector_search_results = (
await vector_search_task if self.vector_search_pipeline else None
await vector_search_task if do_vector_search else None
)
kg_results = await kg_task if self.kg_search_pipeline else None
kg_results = await kg_task if do_kg else None

return AggregateSearchResult(
vector_search_results=vector_search_results,
Expand Down
4 changes: 2 additions & 2 deletions r2r/main/r2r_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -1018,8 +1018,8 @@ async def stream_response():
input=to_async_generator([message]),
streaming=False,
run_manager=self.run_manager,
vector_settings=vector_search_settings,
kg_settings=kg_search_settings,
vector_search_settings=vector_search_settings,
kg_search_settings=kg_search_settings,
rag_generation_config=rag_generation_config,
)

Expand Down
32 changes: 17 additions & 15 deletions r2r/pipes/search_rag_pipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,19 +105,21 @@ async def _collect_context(
total_results: int,
) -> Tuple[str, int]:
context = f"Query:\n{query}\n\n"
context += f"Vector Search Results({iteration}):\n"
it = total_results + 1
for result in results.vector_search_results:
context += f"[{it}]: {result.metadata['text']}\n\n"
it += 1
total_results = (
it - 1
) # Update total_results based on the last index used
context += f"Knowledge Graph Search Results({iteration}):\n"
for result in results.kg_search_results:
context += f"[{it}]: {result}\n\n"
it += 1
total_results = (
it - 1
) # Update total_results based on the last index used
if results.vector_search_results:
context += f"Vector Search Results({iteration}):\n"
it = total_results + 1
for result in results.vector_search_results:
context += f"[{it}]: {result.metadata['text']}\n\n"
it += 1
total_results = (
it - 1
) # Update total_results based on the last index used
if results.kg_search_results:
context += f"Knowledge Graph Search Results({iteration}):\n"
for result in results.kg_search_results:
context += f"[{it}]: {result}\n\n"
it += 1
total_results = (
it - 1
) # Update total_results based on the last index used
return context, total_results

0 comments on commit 897dc3f

Please sign in to comment.