Feature/add jina reranker rebased #312

Merged 2 commits on Apr 20, 2024
32 changes: 16 additions & 16 deletions config.json
@@ -1,21 +1,10 @@
 {
-  "embedding": {
-    "provider": "openai",
-    "model": "text-embedding-3-small",
-    "dimension": 1536,
-    "batch_size": 32
-  },
-  "evals": {
-    "provider": "parea",
-    "frequency": 1.0
-  },
   "language_model": {
     "provider": "litellm"
   },
-  "logging_database": {
+  "vector_database": {
     "provider": "local",
-    "collection_name": "demo_logs",
-    "level": "INFO"
+    "collection_name": "demo_vecs"
   },
   "ingestion":{
     "provider": "local",
@@ -25,12 +14,23 @@
       "chunk_overlap": 20
     }
   },
-  "vector_database": {
-    "provider": "local",
-    "collection_name": "demo_vecs"
+  "embedding": {
+    "provider": "openai",
+    "search_model": "text-embedding-3-small",
+    "search_dimension": 512,
+    "batch_size": 32
   },
+  "evals": {
+    "provider": "parea",
+    "frequency": 1.0
+  },
   "app": {
     "max_logs": 100,
     "max_file_size_in_mb": 100
   },
+  "logging_database": {
+    "provider": "local",
+    "collection_name": "demo_logs",
+    "level": "INFO"
+  }
 }
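
For orientation, here is a minimal sketch of reading the reshaped config. The key names come from the diff above; the R2RConfig loader itself is not part of this diff, so the plain-json access pattern is an assumption.

    import json

    # Read the restructured config.json from this PR. Key names match the
    # diff above; how R2RConfig actually consumes them is not shown here.
    with open("config.json") as f:
        config = json.load(f)

    embedding_cfg = config["embedding"]
    search_model = embedding_cfg["search_model"]          # "text-embedding-3-small"
    search_dimension = embedding_cfg["search_dimension"]  # 512, down from 1536
    batch_size = embedding_cfg["batch_size"]              # 32

    vector_cfg = config["vector_database"]
    collection = vector_cfg["collection_name"]            # "demo_vecs"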
12 changes: 5 additions & 7 deletions docs/pages/deep-dive/embedding.mdx
@@ -8,7 +8,7 @@ The `BasicEmbeddingPipeline` is a simple implementation of the `EmbeddingPipelin
 
 ### Initialization
 
-The `BasicEmbeddingPipeline` is initialized with the `embedding_model`, `embeddings_provider`, `db`, `text_splitter`, an optional `logging_connection`, `embedding_batch_size`, and `id_prefix`.
+The `BasicEmbeddingPipeline` is initialized with an `embedding_provider`, `vector_db_provider`, `text_splitter`, an optional `logging_connection`, `embedding_batch_size`, and `id_prefix`.
 
 ### Text Extraction
 
@@ -70,18 +70,16 @@ import textstat
 class CustomEmbeddingPipeline(BasicEmbeddingPipeline):
     def __init__(
         self,
-        embedding_model: str,
-        embeddings_provider: OpenAIEmbeddingProvider,
-        db: VectorDBProvider,
+        embedding_provider: OpenAIEmbeddingProvider,
+        vector_db_provider: VectorDBProvider,
         text_splitter: TextSplitter,
         logging_connection: Optional[LoggingDatabaseConnection] = None,
         embedding_batch_size: int = 1,
         id_prefix: str = "demo",
     ):
         super().__init__(
-            embedding_model,
-            embeddings_provider,
-            db,
+            embedding_provider,
+            vector_db_provider,
             text_splitter,
             logging_connection,
             embedding_batch_size,
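
A usage sketch of the renamed constructor follows, assuming `embedding_provider`, `vector_db_provider`, and `text_splitter` have already been constructed; the parameter names match the docs above, while the provider setup is an assumption.

    # Sketch: constructing the pipeline with the renamed parameters.
    # The provider/splitter variables are assumed to exist already.
    pipeline = CustomEmbeddingPipeline(
        embedding_provider=embedding_provider,  # e.g. an OpenAIEmbeddingProvider
        vector_db_provider=vector_db_provider,  # e.g. a LocalVectorDB
        text_splitter=text_splitter,
        embedding_batch_size=32,
        id_prefix="demo",
    )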
14 changes: 7 additions & 7 deletions docs/pages/deep-dive/factory.mdx
@@ -11,9 +11,9 @@ from r2r.main import E2EPipelineFactory, R2RConfig
 
 app = E2EPipelineFactory.create_pipeline(
     config=R2RConfig.load_config(),
-    db=None,
-    embeddings_provider=None,
-    llm=None,
+    vector_db_provider=None,
+    embedding_provider=None,
+    llm_provider=None,
     text_splitter=None,
     adapters=None,
     ingestion_pipeline_impl=BasicIngestionPipeline,
@@ -37,7 +37,7 @@ This method retrieves the appropriate vector database based on the `database_con
 - `pgvector`: Returns an instance of `PGVectorDB`.
 - `local`: Returns an instance of `LocalVectorDB`.
 
-### `get_embeddings_provider(embedding_config: dict[str, Any])`
+### `get_embedding_provider(embedding_config: dict[str, Any])`
 
 This method retrieves the appropriate embeddings provider based on the `embedding_config`. It supports the following embeddings providers:
 - `openai`: Returns an instance of `OpenAIEmbeddingProvider`.
@@ -58,10 +58,10 @@ This method retrieves the appropriate text splitter based on the `text_splitter_
 This method creates the end-to-end pipeline by assembling various components based on the provided `config` and optional parameters. It performs the following steps:
 
 1. Sets up logging based on the `logging_database` configuration.
-2. Retrieves the embeddings provider using `get_embeddings_provider` or uses the provided `embeddings_provider`.
-3. Retrieves the vector database using `get_vector_db` or uses the provided `db`.
+2. Retrieves the embeddings provider using `get_embedding_provider` or uses the provided `embedding_provider`.
+3. Retrieves the vector database using `get_vector_db_provider` or uses the provided `vector_db_provider`.
 4. Initializes the vector database collection with the specified `collection_name` and `embedding_dimension`.
-5. Retrieves the language model using `get_llm` or uses the provided `llm`.
+5. Retrieves the language model using `get_llm_provider` or uses the provided `llm`.
 6. Creates a `LoggingDatabaseConnection` instance for logging purposes.
 7. Creates an instance of the `rag_pipeline_impl` (default: `QnARAGPipeline`) with the language model, vector database, embedding model, embeddings provider, and logging connection.
 8. Retrieves the text splitter using `get_text_splitter` or uses the provided `text_splitter`.
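
Steps 2 and 3 describe a name-based dispatch on the config. A sketch of what `get_embedding_provider` plausibly looks like is below; the factory internals are not part of this diff, so the body is illustrative only.

    from typing import Any

    # Illustrative sketch of the dispatch described in steps 2-3; the real
    # E2EPipelineFactory body is not shown in this diff.
    def get_embedding_provider(embedding_config: dict[str, Any]):
        provider = embedding_config["provider"]
        if provider == "openai":
            # Constructor signature assumed for illustration.
            return OpenAIEmbeddingProvider(embedding_config)
        raise ValueError(f"Unsupported embedding provider: {provider}")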
30 changes: 13 additions & 17 deletions docs/pages/deep-dive/rag.mdx
@@ -9,10 +9,9 @@ The `QnARAGPipeline` is a simple implementation of the `RAGPipeline` abstract ba
 
 ### Initialization
 
 The `QnARAGPipeline` is initialized with the following parameters:
-- `llm`: An instance of `LLMProvider` for generating completions.
-- `db`: An instance of `VectorDBProvider` for searching and retrieving documents.
-- `embedding_model`: The name of the embedding model to use.
-- `embeddings_provider`: An instance of `OpenAIEmbeddingProvider` for generating embeddings.
+- `llm_provider`: An instance of `LLMProvider` for generating completions.
+- `vector_db_provider`: An instance of `VectorDBProvider` for searching and retrieving documents.
+- `embedding_provider`: An instance of `OpenAIEmbeddingProvider` for generating embeddings.
 - `prompt_provider` (optional): An instance of `PromptProvider` for providing prompts (default is `BasicPromptProvider`).
 - `logging_connection` (optional): An instance of `LoggingDatabaseConnection` for logging.
 
@@ -40,10 +39,9 @@ To create a custom RAG pipeline, you can subclass the `RAGPipeline` abstract bas
 class CustomRAGPipeline(RAGPipeline):
     def __init__(
         self,
-        llm: LLMProvider,
-        db: VectorDBProvider,
-        embedding_model: str,
-        embeddings_provider: OpenAIEmbeddingProvider,
+        llm_provider: LLMProvider,
+        vector_db_provider: VectorDBProvider,
+        embedding_provider: OpenAIEmbeddingProvider,
         prompt_provider: Optional[PromptProvider] = None,
         logging_connection: Optional[LoggingDatabaseConnection] = None,
     ) -> None:
@@ -52,13 +50,12 @@ class CustomRAGPipeline(RAGPipeline):
         self.prompt_provider = prompt_provider
 
         super().__init__(
-            llm,
+            llm_provider=llm_provider,
             prompt_provider=prompt_provider,
+            embedding_provider=embedding_provider,
+            vector_db_provider=vector_db_provider,
             logging_connection=logging_connection,
         )
-        self.embedding_model = embedding_model
-        self.embeddings_provider = embeddings_provider
-        self.db = db
         self.pipeline_run_info = None
 
     def transform_query(self, query: str) -> str:
@@ -75,21 +72,20 @@
         **kwargs,
     ) -> list[VectorSearchResult]:
         # Custom document retrieval logic
-        results = self.db.search(
-            query_vector=self.embeddings_provider.get_embedding(
+        results = self.vector_db_provider.search(
+            query_vector=self.embedding_provider.get_embedding(
                 transformed_query,
-                self.embedding_model,
             ),
             filters=filters,
             limit=limit,
         )
         return results
 
     def rerank_results(
-        self, results: list[VectorSearchResult]
+        self, transformed_query: str, results: list[VectorSearchResult], limit
     ) -> list[VectorSearchResult]:

Contributor comment: The `rerank_results` function should check if `limit` is a positive integer.

Suggested change:

    def rerank_results(
        self, transformed_query: str, results: list[VectorSearchResult], limit: int
    ) -> list[VectorSearchResult]:
        if not isinstance(limit, int) or limit <= 0:
            raise ValueError("'limit' must be a positive integer")
        return list(reversed(results))[0:limit]

         # Custom result reranking logic
-        return list(reversed(results))
+        return list(reversed(results))[0:limit]
 
     def _format_results(self, results: list[VectorSearchResult]) -> str:
         # Custom result formatting logic
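
Since this PR wires a reranker into the pipeline, a more realistic `rerank_results` would score query-document pairs instead of reversing the list. In the sketch below, `score_pairs` and the `metadata["text"]` access are hypothetical stand-ins; the actual Jina reranking provider is not shown in this diff.

    def rerank_results(
        self, transformed_query: str, results: list[VectorSearchResult], limit: int
    ) -> list[VectorSearchResult]:
        # Hypothetical: score each (query, document) pair with a reranking
        # model and keep the top `limit` hits. `score_pairs` and the
        # metadata access are stand-ins, not APIs from this diff.
        texts = [result.metadata["text"] for result in results]
        scores = score_pairs(transformed_query, texts)
        ranked = sorted(
            zip(scores, results), key=lambda pair: pair[0], reverse=True
        )
        return [result for _, result in ranked[:limit]]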
18 changes: 12 additions & 6 deletions r2r/client/base.py
@@ -76,15 +76,17 @@ def add_entries(
     def search(

Contributor comment: The `search` function should check if `search_limit` and `rerank_limit` are positive integers.

Suggested change:

    def search(
        self,
        query: str,
        search_limit: Optional[int] = 25,
        rerank_limit: Optional[int] = 15,
        filters: Optional[Dict[str, Any]] = None,
        settings: Optional[Dict[str, Any]] = None,
    ):
        if not isinstance(search_limit, int) or search_limit <= 0:
            raise ValueError("'search_limit' must be a positive integer")
        if not isinstance(rerank_limit, int) or rerank_limit <= 0:
            raise ValueError("'rerank_limit' must be a positive integer")
        json_data = {
            "query": query,
            "filters": filters or {},
            "search_limit": search_limit,
            "rerank_limit": rerank_limit,
            "settings": settings or {},
        }
        response = requests.post(url, json=json_data)
        return response.json()

         self,
         query: str,
-        limit: Optional[int] = 10,
+        search_limit: Optional[int] = 25,
+        rerank_limit: Optional[int] = 15,
         filters: Optional[Dict[str, Any]] = None,
         settings: Optional[Dict[str, Any]] = None,
     ):
         url = f"{self.base_url}/search/"
         json_data = {
             "query": query,
             "filters": filters or {},
-            "limit": limit,
+            "search_limit": search_limit,
+            "rerank_limit": rerank_limit,
             "settings": settings or {},
         }
         response = requests.post(url, json=json_data)
@@ -96,7 +98,8 @@ def search(
     def rag_completion(
         self,
         query: str,
-        limit: Optional[int] = 10,
+        search_limit: Optional[int] = 25,
+        rerank_limit: Optional[int] = 15,
         filters: Optional[Dict[str, Any]] = None,
         settings: Optional[Dict[str, Any]] = None,
         generation_config: Optional[Dict[str, Any]] = None,
@@ -114,7 +117,8 @@ def rag_completion(
         json_data = {
             "query": query,
             "filters": filters or {},
-            "limit": limit,
+            "search_limit": search_limit,
+            "rerank_limit": rerank_limit,
             "settings": settings or {},
             "generation_config": generation_config or {},
         }
@@ -143,7 +147,8 @@ def eval(
     async def stream_rag_completion(
         self,
         query: str,
-        limit: Optional[int] = 10,
+        search_limit: Optional[int] = 25,
+        rerank_limit: Optional[int] = 15,
         filters: Optional[Dict[str, Any]] = None,
         settings: Optional[Dict[str, Any]] = None,
         generation_config: Optional[Dict[str, Any]] = None,
@@ -165,7 +170,8 @@ async def stream_rag_completion(
         json_data = {
             "query": query,
             "filters": filters or {},
-            "limit": limit,
+            "search_limit": search_limit,
+            "rerank_limit": rerank_limit,
             "settings": settings or {},
             "generation_config": generation_config or {},
         }
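
A quick usage sketch of the updated client API. The client class name and constructor are assumptions; the `search_limit`/`rerank_limit` parameters match the diff: fetch a wide candidate set, then let the reranker trim it.

    # Sketch: calling the updated search endpoint. R2RClient and its
    # constructor are assumed; the parameters match the diff above.
    client = R2RClient(base_url="http://localhost:8000")

    results = client.search(
        query="What does the reranker change?",
        search_limit=25,   # candidates retrieved from the vector DB
        rerank_limit=15,   # results kept after reranking
        filters={"source": "docs"},
    )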
3 changes: 2 additions & 1 deletion r2r/core/__init__.py
@@ -5,7 +5,7 @@
 from .pipelines.ingestion import IngestionPipeline
 from .pipelines.rag import RAGPipeline
 from .pipelines.scraping import ScraperPipeline
-from .providers.embedding import EmbeddingProvider
+from .providers.embedding import EmbeddingProvider, PipelineStage
 from .providers.eval import EvalProvider
 from .providers.llm import GenerationConfig, LLMConfig, LLMProvider
 from .providers.logging import LoggingDatabaseConnection, log_execution_to_db
@@ -30,6 +30,7 @@
     "PromptProvider",
     "EvalProvider",
     "EmbeddingProvider",
+    "PipelineStage",
     "GenerationConfig",
     "LLMConfig",
     "LLMProvider",
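
`PipelineStage` itself is not defined anywhere in this diff. Given the config's separate `search_model`/`search_dimension` keys, it plausibly tags which stage (search vs. rerank) an embedding model serves; the enum below is a guess at its shape, not the actual definition.

    from enum import Enum

    # Guessed shape of the new PipelineStage export; the real definition
    # lives in r2r/core/providers/embedding.py and is not shown here.
    class PipelineStage(Enum):
        SEARCH = 1
        RERANK = 2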
10 changes: 4 additions & 6 deletions r2r/core/pipelines/embedding.py
@@ -18,16 +18,14 @@
 class EmbeddingPipeline(Pipeline):
     def __init__(
         self,
-        embedding_model: str,
-        embeddings_provider: EmbeddingProvider,
-        db: VectorDBProvider,
+        embedding_provider: EmbeddingProvider,
+        vector_db_provider: VectorDBProvider,
         logging_connection: Optional[LoggingDatabaseConnection] = None,
         *args,
         **kwargs,
     ):
-        self.embedding_model = embedding_model
-        self.embeddings_provider = embeddings_provider
-        self.db = db
+        self.embedding_provider = embedding_provider
+        self.vector_db_provider = vector_db_provider
         super().__init__(logging_connection=logging_connection, **kwargs)
 
     def initialize_pipeline(self, *args, **kwargs) -> None:
35 changes: 25 additions & 10 deletions r2r/core/pipelines/rag.py
@@ -11,9 +11,11 @@
 from openai.types.chat import ChatCompletion
 
 from ..abstractions.output import RAGPipelineOutput
+from ..providers.embedding import EmbeddingProvider
 from ..providers.llm import GenerationConfig, LLMProvider
 from ..providers.logging import LoggingDatabaseConnection, log_execution_to_db
 from ..providers.prompt import PromptProvider
+from ..providers.vector_db import VectorDBProvider
 from .pipeline import Pipeline
 
 logger = logging.getLogger(__name__)
@@ -26,14 +28,19 @@ class RAGPipeline(Pipeline):
 
     def __init__(
         self,
-        llm: "LLMProvider",
         prompt_provider: PromptProvider,
+        embedding_provider: EmbeddingProvider,
+        llm_provider: LLMProvider,
+        vector_db_provider: VectorDBProvider,
         logging_connection: Optional[LoggingDatabaseConnection] = None,
         *args,
         **kwargs,
     ):
-        self.llm = llm
         self.prompt_provider = prompt_provider
+        self.llm_provider = llm_provider
+        self.embedding_provider = embedding_provider
+        self.vector_db_provider = vector_db_provider
+
         super().__init__(logging_connection=logging_connection, **kwargs)
 
     def initialize_pipeline(
@@ -63,7 +70,7 @@ def transform_query(self, query: str) -> Any:
     @abstractmethod
     def search(
         self,
-        transformed_query,
+        transformed_query: Any,
         filters: dict[str, Any],
         limit: int,
         *args,
@@ -76,7 +83,9 @@ def search(
         pass
 
     @abstractmethod
-    def rerank_results(self, results: list) -> list:
+    def rerank_results(
+        self, transformed_query: Any, results: list, limit: int
+    ) -> list:
         """
         Reranks the retrieved results based on relevance or other criteria.
         """
Expand All @@ -94,8 +103,7 @@ def construct_context(
self,
results: list,
) -> str:
reranked_results = self.rerank_results(results)
return self._format_results(reranked_results)
return self._format_results(results)

@log_execution_to_db
def construct_prompt(self, inputs: dict[str, str]) -> str:
@@ -141,14 +149,16 @@ def generate_completion(
         ]
 
         if not generation_config.stream:
-            return self.llm.get_completion(messages, generation_config)
+            return self.llm_provider.get_completion(
+                messages, generation_config
+            )
 
         return self._stream_generate_completion(messages, generation_config)

Contributor comment: The `run` function should check if `search_limit` and `rerank_limit` are positive integers.

Suggested change:

    def run(
        self,
        query,
        filters={},
        search_limit=25,
        rerank_limit=15,
        search_only=False,
        generation_config: Optional[GenerationConfig] = None,
        *args,
        **kwargs,
    ):
        if not isinstance(search_limit, int) or search_limit <= 0:
            raise ValueError("'search_limit' must be a positive integer")
        if not isinstance(rerank_limit, int) or rerank_limit <= 0:
            raise ValueError("'rerank_limit' must be a positive integer")
        self.initialize_pipeline(query, search_only)
        transformed_query = self.transform_query(query)
        search_results = self.search(transformed_query, filters, search_limit)
        search_results = self.rerank_results(
            transformed_query, search_results, rerank_limit
        )
        if search_only:
            return RAGPipelineOutput(search_results, None, None)
        elif not generation_config:
            return self.generate_completion(transformed_query, search_results)
        else:
            return self.generate_completion(
                transformed_query, search_results, generation_config
            )

     def _stream_generate_completion(
         self, messages: list[dict], generation_config: GenerationConfig
     ) -> Generator[str, None, None]:
-        for result in self.llm.get_completion_stream(
+        for result in self.llm_provider.get_completion_stream(
             messages, generation_config
         ):
             yield result.choices[0].delta.content or ""  # type: ignore
@@ -157,7 +167,8 @@ def run(
         self,
         query,
         filters={},
-        limit=10,
+        search_limit=25,
+        rerank_limit=15,
         search_only=False,
         generation_config: Optional[GenerationConfig] = None,
         *args,
@@ -169,7 +180,11 @@
         self.initialize_pipeline(query, search_only)
 
         transformed_query = self.transform_query(query)
-        search_results = self.search(transformed_query, filters, limit)
+        search_results = self.search(transformed_query, filters, search_limit)
+        search_results = self.rerank_results(
+            transformed_query, search_results, rerank_limit
+        )
+
         if search_only:
             return RAGPipelineOutput(search_results, None, None)
         elif not generation_config:
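
End to end, `run` now retrieves `search_limit` candidates and reranks them down to `rerank_limit` before generation. A usage sketch, assuming `pipeline` is an instance of a concrete `RAGPipeline` subclass:

    # Sketch: the two-stage retrieval flow in use. `pipeline` is assumed
    # to be a constructed RAGPipeline subclass (e.g. QnARAGPipeline).
    output = pipeline.run(
        query="How do rerankers improve RAG quality?",
        search_limit=25,   # candidates fetched by vector search
        rerank_limit=15,   # candidates kept after reranking
        search_only=False, # continue to completion generation
    )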