Commit ab3cc84: checkin

emrgnt-cmplxty committed Jun 18, 2024
1 parent eadc193 commit ab3cc84
Showing 4 changed files with 67 additions and 23 deletions.
r2r/core/pipeline/rag_pipeline.py (6 changes: 2 additions & 4 deletions)
@@ -74,17 +74,15 @@ async def multi_query_generator(input):
                 yield (query, await task)

         rag_results = await self.rag_pipeline.run(
-            input=multi_query_generator(
-                input
-            ), # to_async_generator([(input, search_results)]),
+            input=multi_query_generator(input),
             state=state,
             streaming=streaming,
             run_manager=run_manager,
             rag_generation_config=rag_generation_config,
             *args,
             **kwargs,
         )
-        print("rag_results = ", rag_results)
         return rag_results

     def add_pipe(
         self,
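
For context on the call above: `input` is now passed straight through as a nested async generator object. The sketch below shows one way such a generator can be structured; the `run_search` stub, the `demo` driver, and the task-scheduling loop are assumptions for illustration, not code from this repository.

    import asyncio
    from typing import Any, AsyncGenerator


    async def run_search(query: str) -> Any:
        # Stand-in for a real per-query search call (assumption for this sketch).
        await asyncio.sleep(0)
        return f"results for {query!r}"


    async def multi_query_generator(
        queries: AsyncGenerator[str, None],
    ) -> AsyncGenerator[tuple[str, Any], None]:
        # Schedule one task per incoming query, then yield (query, awaited result)
        # pairs, mirroring the `yield (query, await task)` line in the hunk above.
        tasks = []
        async for query in queries:
            tasks.append((query, asyncio.create_task(run_search(query))))
        for query, task in tasks:
            yield (query, await task)


    async def demo() -> None:
        async def queries() -> AsyncGenerator[str, None]:
            for q in ("who is aristotle?", "what is a rag pipeline?"):
                yield q

        async for query, result in multi_query_generator(queries()):
            print(query, "->", result)


    asyncio.run(demo())
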
r2r/examples/quickstart.py (61 changes: 54 additions & 7 deletions)
@@ -12,9 +12,11 @@
     AnalysisTypes,
     Document,
     FilterCriteria,
+    KGSearchSettings,
     R2RAppBuilder,
     R2RClient,
     R2RConfig,
+    VectorSearchSettings,
     generate_id_from_label,
 )
 from r2r.core.abstractions.llm import GenerationConfig
@@ -247,24 +249,69 @@ def update_as_files(self, file_tuples: Optional[list[tuple]] = None):
         print(f"Time taken to update files: {t1-t0:.2f} seconds")
         print(response)

-    def search(self, query: str, do_hybrid_search: bool = False):
+    def search(
+        self,
+        query: str,
+        use_vector_search: bool = True,
+        search_filters: Optional[str] = None,
+        search_limit: int = 10,
+        do_hybrid_search: bool = False,
+        use_kg: bool = False,
+        agent_generation_config: Optional[str] = None,
+    ):
+        search_filters_dict = {}
+        if search_filters:
+            search_filters_dict = dict(
+                item.split("=") for item in search_filters.split(",")
+            )
+
+        vector_settings = VectorSearchSettings(
+            use_vector_search=use_vector_search,
+            search_filters=search_filters_dict,
+            search_limit=search_limit,
+            do_hybrid_search=do_hybrid_search,
+        )
+
+        agent_gen_config = {}
+        if agent_generation_config:
+            agent_gen_config = dict(
+                item.split("=") for item in agent_generation_config.split(",")
+            )
+
+        kg_settings = KGSearchSettings(
+            use_kg=use_kg,
+            agent_generation_config=(
+                GenerationConfig(**agent_gen_config)
+                if agent_generation_config
+                else GenerationConfig(model="gpt-4o")
+            ),
+        )
+
         t0 = time.time()
         if hasattr(self, "client"):
             results = self.client.search(
                 query,
-                search_filters={"user_id": self.user_id},
-                do_hybrid_search=do_hybrid_search,
+                vector_search_settings=vector_settings,
+                kg_search_settings=kg_settings,
             )
         else:
             results = self.r2r.search(
-                query,
-                search_filters={"user_id": self.user_id},
-                do_hybrid_search=do_hybrid_search,
+                query=query,
+                vector_search_settings=vector_settings,
+                kg_search_settings=kg_settings,
             )

+        if "vector_search_results" in results["results"]:
+            print("Vector search results:")
+            for result in results["results"]["vector_search_results"]:
+                print(result)
+        if "kg_search_results" in results["results"]:
+            print(
+                "KG search results:", results["results"]["kg_search_results"]
+            )
+
         t1 = time.time()
         print(f"Time taken to search: {t1-t0:.2f} seconds")
-        print("Results:", results)

     def rag(
         self,
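
The new quickstart search() takes its filters and generation options as comma-separated key=value strings (convenient when the method is driven from a command line) and converts them to dicts before building VectorSearchSettings and KGSearchSettings. A small self-contained illustration of that parsing idiom, with made-up example strings:

    def parse_kv_string(raw: str) -> dict[str, str]:
        # "user_id=abc,category=docs" -> {"user_id": "abc", "category": "docs"}
        # Mirrors the dict(item.split("=") for item in raw.split(",")) pattern in
        # the diff; note that every parsed value stays a string, so numeric
        # options (e.g. a temperature) need casting before use.
        return dict(item.split("=") for item in raw.split(","))


    print(parse_kv_string("user_id=abc,category=docs"))
    print(parse_kv_string("model=gpt-4o,temperature=0.1"))
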
r2r/main/r2r_app.py (4 changes: 2 additions & 2 deletions)
@@ -928,8 +928,8 @@ async def asearch(

         results = await self.search_pipeline.run(
             input=to_async_generator([query]),
-            vector_settings=vector_search_settings,
-            kg_settings=kg_search_settings,
+            vector_search_settings=vector_search_settings,
+            kg_search_settings=kg_search_settings,
             run_manager=self.run_manager,
         )

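
to_async_generator, used to build the pipeline input above, wraps a plain list in an async stream. A minimal sketch of what such a helper typically looks like, as an illustration of the idea rather than the exact definition in r2r (the demo driver is an assumption):

    import asyncio
    from typing import AsyncGenerator, Iterable, TypeVar

    T = TypeVar("T")


    async def to_async_generator(items: Iterable[T]) -> AsyncGenerator[T, None]:
        # Yield each item of an ordinary iterable so it can feed a pipeline
        # that expects an async input stream, e.g. to_async_generator([query]).
        for item in items:
            yield item


    async def demo() -> None:
        async for value in to_async_generator(["who is aristotle?"]):
            print(value)


    asyncio.run(demo())
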
r2r/main/r2r_client.py (19 changes: 9 additions & 10 deletions)
@@ -128,22 +128,21 @@ def update_files(
         response.raise_for_status()
         return response.json()

-    async def search(
+    def search(
         self,
         query: str,
-        vector_settings: VectorSearchSettings,
-        kg_settings: KGSearchSettings,
+        vector_search_settings: VectorSearchSettings,
+        kg_search_settings: KGSearchSettings,
     ) -> dict:
-        url = f"{self.base_url}/search_app"
+        url = f"{self.base_url}/search"
         search_request = SearchRequest(
             query=query,
-            vector_settings=vector_settings,
-            kg_settings=kg_settings,
+            vector_settings=vector_search_settings,
+            kg_settings=kg_search_settings,
         )
-        async with httpx.AsyncClient() as client:
-            response = await client.post(url, json=search_request.dict())
-            response.raise_for_status()
-            return response.json()
+        response = requests.post(url, json=search_request.dict())
+        response.raise_for_status()
+        return response.json()

     def rag(
         self,
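
With the client's search now synchronous, it can be called without an event loop. A hedged usage sketch, assuming a running R2R server at http://localhost:8000 and reusing the settings classes imported in the quickstart diff above; the base URL, the constructor argument, and the field values are illustrative, not taken from this commit:

    from r2r import KGSearchSettings, R2RClient, VectorSearchSettings
    from r2r.core.abstractions.llm import GenerationConfig

    client = R2RClient("http://localhost:8000")  # base URL is an assumption

    results = client.search(
        query="who is aristotle?",
        vector_search_settings=VectorSearchSettings(
            use_vector_search=True,
            search_filters={"user_id": "some-user-id"},
            search_limit=10,
            do_hybrid_search=False,
        ),
        kg_search_settings=KGSearchSettings(
            use_kg=False,
            agent_generation_config=GenerationConfig(model="gpt-4o"),
        ),
    )
    print(results)

Dropping httpx.AsyncClient in favor of requests keeps non-async call sites, such as the quickstart's search method above, simple, at the cost of blocking I/O.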
