diff --git a/poetry.lock b/poetry.lock index 4d6f3de4..fee627b4 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3946,4 +3946,4 @@ local-embedding = ["sentence-transformers"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.13" -content-hash = "b84fa623b6fa140d03748bf787f988e615ee91bfd7b114a2a32a764ee88bc2d9" \ No newline at end of file +content-hash = "b84fa623b6fa140d03748bf787f988e615ee91bfd7b114a2a32a764ee88bc2d9" diff --git a/r2r/main/r2r_client.py b/r2r/main/r2r_client.py index cac1f484..39c33b31 100644 --- a/r2r/main/r2r_client.py +++ b/r2r/main/r2r_client.py @@ -149,79 +149,37 @@ def rag( message=message, vector_settings=vector_search_settings, kg_settings=kg_search_settings, + rag_generation_config=rag_generation_config, ) if streaming: - return self._stream_rag_sync( - message=message, - vector_search_settings=vector_search_settings, - kg_search_settings=kg_search_settings, - rag_generation_config=rag_generation_config, - ) + return self._stream_rag_sync(rag_request) else: try: url = f"{self.base_url}/rag" - data = { - "message": message, - "search_filters": ( - json.dumps(search_filters) if search_filters else None - ), - "search_limit": search_limit, - "rag_generation_config": ( - json.dumps(rag_generation_config) - if rag_generation_config - else None - ), - "streaming": streaming, - } - - response = requests.post(url, json=data) + response = requests.post(url, json=rag_request.dict()) response.raise_for_status() return response.json() except requests.exceptions.RequestException as e: raise e async def _stream_rag( - self, - message: str, - search_filters: Optional[dict] = None, - search_limit: int = 10, - rag_generation_config: Optional[dict] = None, + self, rag_request: R2RRAGRequest ) -> AsyncGenerator[str, None]: url = f"{self.base_url}/rag" - data = { - "message": message, - "search_filters": ( - json.dumps(search_filters) if search_filters else None - ), - "search_limit": search_limit, - "rag_generation_config": ( - json.dumps(rag_generation_config) - if rag_generation_config - else None - ), - "streaming": True, - } async with httpx.AsyncClient() as client: - async with client.stream("POST", url, json=data) as response: + async with client.stream( + "POST", url, json=rag_request.dict() + ) as response: response.raise_for_status() async for chunk in response.aiter_text(): yield chunk def _stream_rag_sync( - self, - message: str, - vector_search_settings: VectorSearchSettings, - kg_search_settings: KGSearchSettings, - rag_generation_config: Optional[GenerationConfig] = None, + self, rag_request: R2RRAGRequest ) -> Generator[str, None, None]: async def run_async_generator(): - async for chunk in self._stream_rag( - message=message, - vector_search_settings=vector_search_settings, - kg_search_settings=kg_search_settings, - rag_generation_config=rag_generation_config, - ): + async for chunk in self._stream_rag(rag_request): yield chunk loop = asyncio.new_event_loop() diff --git a/r2r/providers/vector_dbs/pgvector/pgvector_db.py b/r2r/providers/vector_dbs/pgvector/pgvector_db.py index 1961de62..f18a2f04 100644 --- a/r2r/providers/vector_dbs/pgvector/pgvector_db.py +++ b/r2r/providers/vector_dbs/pgvector/pgvector_db.py @@ -2,7 +2,6 @@ import logging import os import time -import uuid from typing import Optional, Union from sqlalchemy import exc, text