-
Notifications
You must be signed in to change notification settings - Fork 248
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feature/add jina reranker rebased #312
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -76,15 +76,17 @@ def add_entries( | |||||||||||||||||||||||||||||||||||||||||||||
def search( | ||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||
self, | ||||||||||||||||||||||||||||||||||||||||||||||
query: str, | ||||||||||||||||||||||||||||||||||||||||||||||
limit: Optional[int] = 10, | ||||||||||||||||||||||||||||||||||||||||||||||
search_limit: Optional[int] = 25, | ||||||||||||||||||||||||||||||||||||||||||||||
rerank_limit: Optional[int] = 15, | ||||||||||||||||||||||||||||||||||||||||||||||
filters: Optional[Dict[str, Any]] = None, | ||||||||||||||||||||||||||||||||||||||||||||||
settings: Optional[Dict[str, Any]] = None, | ||||||||||||||||||||||||||||||||||||||||||||||
): | ||||||||||||||||||||||||||||||||||||||||||||||
url = f"{self.base_url}/search/" | ||||||||||||||||||||||||||||||||||||||||||||||
json_data = { | ||||||||||||||||||||||||||||||||||||||||||||||
"query": query, | ||||||||||||||||||||||||||||||||||||||||||||||
"filters": filters or {}, | ||||||||||||||||||||||||||||||||||||||||||||||
"limit": limit, | ||||||||||||||||||||||||||||||||||||||||||||||
"search_limit": search_limit, | ||||||||||||||||||||||||||||||||||||||||||||||
"rerank_limit": rerank_limit, | ||||||||||||||||||||||||||||||||||||||||||||||
"settings": settings or {}, | ||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||
response = requests.post(url, json=json_data) | ||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -96,7 +98,8 @@ def search( | |||||||||||||||||||||||||||||||||||||||||||||
def rag_completion( | ||||||||||||||||||||||||||||||||||||||||||||||
self, | ||||||||||||||||||||||||||||||||||||||||||||||
query: str, | ||||||||||||||||||||||||||||||||||||||||||||||
limit: Optional[int] = 10, | ||||||||||||||||||||||||||||||||||||||||||||||
search_limit: Optional[int] = 25, | ||||||||||||||||||||||||||||||||||||||||||||||
rerank_limit: Optional[int] = 15, | ||||||||||||||||||||||||||||||||||||||||||||||
filters: Optional[Dict[str, Any]] = None, | ||||||||||||||||||||||||||||||||||||||||||||||
settings: Optional[Dict[str, Any]] = None, | ||||||||||||||||||||||||||||||||||||||||||||||
generation_config: Optional[Dict[str, Any]] = None, | ||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -114,7 +117,8 @@ def rag_completion( | |||||||||||||||||||||||||||||||||||||||||||||
json_data = { | ||||||||||||||||||||||||||||||||||||||||||||||
"query": query, | ||||||||||||||||||||||||||||||||||||||||||||||
"filters": filters or {}, | ||||||||||||||||||||||||||||||||||||||||||||||
"limit": limit, | ||||||||||||||||||||||||||||||||||||||||||||||
"search_limit": search_limit, | ||||||||||||||||||||||||||||||||||||||||||||||
"rerank_limit": rerank_limit, | ||||||||||||||||||||||||||||||||||||||||||||||
"settings": settings or {}, | ||||||||||||||||||||||||||||||||||||||||||||||
"generation_config": generation_config or {}, | ||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -143,7 +147,8 @@ def eval( | |||||||||||||||||||||||||||||||||||||||||||||
async def stream_rag_completion( | ||||||||||||||||||||||||||||||||||||||||||||||
self, | ||||||||||||||||||||||||||||||||||||||||||||||
query: str, | ||||||||||||||||||||||||||||||||||||||||||||||
limit: Optional[int] = 10, | ||||||||||||||||||||||||||||||||||||||||||||||
search_limit: Optional[int] = 25, | ||||||||||||||||||||||||||||||||||||||||||||||
rerank_limit: Optional[int] = 15, | ||||||||||||||||||||||||||||||||||||||||||||||
filters: Optional[Dict[str, Any]] = None, | ||||||||||||||||||||||||||||||||||||||||||||||
settings: Optional[Dict[str, Any]] = None, | ||||||||||||||||||||||||||||||||||||||||||||||
generation_config: Optional[Dict[str, Any]] = None, | ||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -165,7 +170,8 @@ async def stream_rag_completion( | |||||||||||||||||||||||||||||||||||||||||||||
json_data = { | ||||||||||||||||||||||||||||||||||||||||||||||
"query": query, | ||||||||||||||||||||||||||||||||||||||||||||||
"filters": filters or {}, | ||||||||||||||||||||||||||||||||||||||||||||||
"limit": limit, | ||||||||||||||||||||||||||||||||||||||||||||||
"search_limit": search_limit, | ||||||||||||||||||||||||||||||||||||||||||||||
"rerank_limit": rerank_limit, | ||||||||||||||||||||||||||||||||||||||||||||||
"settings": settings or {}, | ||||||||||||||||||||||||||||||||||||||||||||||
"generation_config": generation_config or {}, | ||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -11,9 +11,11 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from openai.types.chat import ChatCompletion | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from ..abstractions.output import RAGPipelineOutput | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from ..providers.embedding import EmbeddingProvider | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from ..providers.llm import GenerationConfig, LLMProvider | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from ..providers.logging import LoggingDatabaseConnection, log_execution_to_db | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from ..providers.prompt import PromptProvider | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from ..providers.vector_db import VectorDBProvider | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from .pipeline import Pipeline | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
logger = logging.getLogger(__name__) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -26,14 +28,19 @@ class RAGPipeline(Pipeline): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def __init__( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
llm: "LLMProvider", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
prompt_provider: PromptProvider, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
embedding_provider: EmbeddingProvider, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
llm_provider: LLMProvider, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
vector_db_provider: VectorDBProvider, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
logging_connection: Optional[LoggingDatabaseConnection] = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
*args, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
**kwargs, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.llm = llm | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.prompt_provider = prompt_provider | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.llm_provider = llm_provider | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.embedding_provider = embedding_provider | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.vector_db_provider = vector_db_provider | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
super().__init__(logging_connection=logging_connection, **kwargs) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def initialize_pipeline( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -63,7 +70,7 @@ def transform_query(self, query: str) -> Any: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
@abstractmethod | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def search( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
transformed_query, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
transformed_query: Any, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
filters: dict[str, Any], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
limit: int, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
*args, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -76,7 +83,9 @@ def search( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
pass | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
@abstractmethod | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def rerank_results(self, results: list) -> list: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def rerank_results( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self, transformed_query: Any, results: list, limit: int | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) -> list: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Reranks the retrieved results based on relevance or other criteria. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -94,8 +103,7 @@ def construct_context( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
results: list, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) -> str: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
reranked_results = self.rerank_results(results) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return self._format_results(reranked_results) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return self._format_results(results) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
@log_execution_to_db | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def construct_prompt(self, inputs: dict[str, str]) -> str: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -141,14 +149,16 @@ def generate_completion( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if not generation_config.stream: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return self.llm.get_completion(messages, generation_config) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return self.llm_provider.get_completion( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
messages, generation_config | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return self._stream_generate_completion(messages, generation_config) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def _stream_generate_completion( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self, messages: list[dict], generation_config: GenerationConfig | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) -> Generator[str, None, None]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
for result in self.llm.get_completion_stream( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
for result in self.llm_provider.get_completion_stream( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
messages, generation_config | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
yield result.choices[0].delta.content or "" # type: ignore | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -157,7 +167,8 @@ def run( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
query, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
filters={}, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
limit=10, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
search_limit=25, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
rerank_limit=15, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
search_only=False, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
generation_config: Optional[GenerationConfig] = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
*args, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -169,7 +180,11 @@ def run( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.initialize_pipeline(query, search_only) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
transformed_query = self.transform_query(query) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
search_results = self.search(transformed_query, filters, limit) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
search_results = self.search(transformed_query, filters, search_limit) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
search_results = self.rerank_results( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
transformed_query, search_results, rerank_limit | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if search_only: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return RAGPipelineOutput(search_results, None, None) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
elif not generation_config: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
rerank_results
function should check iflimit
is a positive integer.