Skip to content

Commit

Permalink
up
Browse files Browse the repository at this point in the history
  • Loading branch information
emrgnt-cmplxty committed Jun 18, 2024
1 parent bfa6692 commit 7449d1c
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 53 deletions.
2 changes: 1 addition & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

60 changes: 9 additions & 51 deletions r2r/main/r2r_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,79 +149,37 @@ def rag(
message=message,
vector_settings=vector_search_settings,
kg_settings=kg_search_settings,
rag_generation_config=rag_generation_config,
)

if streaming:
return self._stream_rag_sync(
message=message,
vector_search_settings=vector_search_settings,
kg_search_settings=kg_search_settings,
rag_generation_config=rag_generation_config,
)
return self._stream_rag_sync(rag_request)
else:
try:
url = f"{self.base_url}/rag"
data = {
"message": message,
"search_filters": (
json.dumps(search_filters) if search_filters else None
),
"search_limit": search_limit,
"rag_generation_config": (
json.dumps(rag_generation_config)
if rag_generation_config
else None
),
"streaming": streaming,
}

response = requests.post(url, json=data)
response = requests.post(url, json=rag_request.dict())
response.raise_for_status()
return response.json()
except requests.exceptions.RequestException as e:
raise e

async def _stream_rag(
self,
message: str,
search_filters: Optional[dict] = None,
search_limit: int = 10,
rag_generation_config: Optional[dict] = None,
self, rag_request: R2RRAGRequest
) -> AsyncGenerator[str, None]:
url = f"{self.base_url}/rag"
data = {
"message": message,
"search_filters": (
json.dumps(search_filters) if search_filters else None
),
"search_limit": search_limit,
"rag_generation_config": (
json.dumps(rag_generation_config)
if rag_generation_config
else None
),
"streaming": True,
}
async with httpx.AsyncClient() as client:
async with client.stream("POST", url, json=data) as response:
async with client.stream(
"POST", url, json=rag_request.dict()
) as response:
response.raise_for_status()
async for chunk in response.aiter_text():
yield chunk

def _stream_rag_sync(
self,
message: str,
vector_search_settings: VectorSearchSettings,
kg_search_settings: KGSearchSettings,
rag_generation_config: Optional[GenerationConfig] = None,
self, rag_request: R2RRAGRequest
) -> Generator[str, None, None]:
async def run_async_generator():
async for chunk in self._stream_rag(
message=message,
vector_search_settings=vector_search_settings,
kg_search_settings=kg_search_settings,
rag_generation_config=rag_generation_config,
):
async for chunk in self._stream_rag(rag_request):
yield chunk

loop = asyncio.new_event_loop()
Expand Down
1 change: 0 additions & 1 deletion r2r/providers/vector_dbs/pgvector/pgvector_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import logging
import os
import time
import uuid
from typing import Optional, Union

from sqlalchemy import exc, text
Expand Down

0 comments on commit 7449d1c

Please sign in to comment.