Feature/add rag pipeline ver (#347)
* fix pipeline

* add endpoint to fetch rag pipeline version

* clean and sort
emrgnt-cmplxty committed May 2, 2024
1 parent e44ae64 commit f4969a9
Showing 10 changed files with 53 additions and 23 deletions.
8 changes: 5 additions & 3 deletions r2r/core/pipelines/async_embedding.py
@@ -7,9 +7,9 @@
 from typing import Any, Optional
 
 from ..providers.embedding import EmbeddingProvider
-from ..utils.logging import LoggingDatabaseConnection
 from ..providers.vector_db import VectorDBProvider, VectorEntry
 from ..utils import generate_run_id
+from ..utils.logging import LoggingDatabaseConnection
 from .async_pipeline import AsyncPipeline
 
 logger = logging.getLogger(__name__)
@@ -53,9 +53,11 @@ async def embed_chunks(self, chunks: list[Any]) -> list[list[float]]:
         pass
 
     @abstractmethod
-    async def store_chunks(self, chunks: list[VectorEntry], *args, **kwargs) -> None:
+    async def store_chunks(
+        self, chunks: list[VectorEntry], *args, **kwargs
+    ) -> None:
         pass
 
-    @abstractmethod
+    @abstractmethod
     async def run(self, document: Any, **kwargs):
         pass
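
For orientation, the abstract interface above declares three chunk-level hooks: embed_chunks turns text chunks into vectors, store_chunks persists them, and run drives a document through both. A minimal self-contained sketch of that pattern follows; the toy classes and the length-based "embedding" are hypothetical and are not the r2r implementation.

import asyncio
from abc import ABC, abstractmethod
from typing import Any


class ToyAsyncEmbeddingPipeline(ABC):
    # Mirrors the abstract hooks declared in the diff above (toy version).
    @abstractmethod
    async def embed_chunks(self, chunks: list[Any]) -> list[list[float]]: ...

    @abstractmethod
    async def store_chunks(self, chunks: list[Any], *args, **kwargs) -> None: ...

    @abstractmethod
    async def run(self, document: Any, **kwargs): ...


class InMemoryPipeline(ToyAsyncEmbeddingPipeline):
    def __init__(self) -> None:
        self.store: list[tuple[str, list[float]]] = []

    async def embed_chunks(self, chunks: list[str]) -> list[list[float]]:
        # Stand-in "embedding": chunk length as a one-dimensional vector.
        return [[float(len(chunk))] for chunk in chunks]

    async def store_chunks(self, chunks: list[str], *args, **kwargs) -> None:
        vectors = await self.embed_chunks(chunks)
        self.store.extend(zip(chunks, vectors))

    async def run(self, document: str, **kwargs):
        await self.store_chunks(document.split("."))


asyncio.run(InMemoryPipeline().run("First sentence. Second sentence."))
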
8 changes: 6 additions & 2 deletions r2r/core/providers/embedding.py
@@ -50,7 +50,9 @@ def get_embedding(
     ):
         pass
 
-    async def async_get_embedding(self, text: str, stage: PipelineStage = PipelineStage.SEARCH):
+    async def async_get_embedding(
+        self, text: str, stage: PipelineStage = PipelineStage.SEARCH
+    ):
         return self.get_embedding(text, stage)
 
     @abstractmethod
@@ -59,7 +61,9 @@ def get_embeddings(
     ):
         pass
 
-    async def async_get_embeddings(self, texts: list[str], stage: PipelineStage = PipelineStage.SEARCH):
+    async def async_get_embeddings(
+        self, texts: list[str], stage: PipelineStage = PipelineStage.SEARCH
+    ):
         return self.get_embeddings(texts, stage)
 
     @abstractmethod
5 changes: 4 additions & 1 deletion r2r/embeddings/dummy/provider.py
@@ -67,7 +67,10 @@ def rerank(
         return texts[:limit]
 
     def tokenize_string(
-        self, text: str, model: str, stage: EmbeddingProvider.PipelineStage,
+        self,
+        text: str,
+        model: str,
+        stage: EmbeddingProvider.PipelineStage,
     ) -> list[int]:
         """Tokenizes the input string."""
         return [0]
10 changes: 7 additions & 3 deletions r2r/embeddings/openai/base.py
@@ -100,7 +100,7 @@ def get_embedding(
     async def async_get_embedding(
         self,
         text: str,
-        stage: EmbeddingProvider.PipelineStage = EmbeddingProvider.PipelineStage.SEARCH
+        stage: EmbeddingProvider.PipelineStage = EmbeddingProvider.PipelineStage.SEARCH,
     ) -> list[float]:
         if stage != EmbeddingProvider.PipelineStage.SEARCH:
             raise ValueError(
@@ -111,7 +111,9 @@ async def async_get_embedding(
             input=[text],
             model=self.search_model,
             dimensions=self.search_dimension
-            or OpenAIEmbeddingProvider.MODEL_TO_DIMENSIONS[self.search_model][-1],
+            or OpenAIEmbeddingProvider.MODEL_TO_DIMENSIONS[self.search_model][
+                -1
+            ],
         )
         return response.data[0].embedding
@@ -151,7 +153,9 @@ async def async_get_embeddings(
             input=texts,
             model=self.search_model,
             dimensions=self.search_dimension
-            or OpenAIEmbeddingProvider.MODEL_TO_DIMENSIONS[self.search_model][-1],
+            or OpenAIEmbeddingProvider.MODEL_TO_DIMENSIONS[self.search_model][
+                -1
+            ],
         )
         return [ele.embedding for ele in response.data]
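
The dimensions argument above uses Python's "or" fallback: when self.search_dimension is unset, the last entry in the model's dimension table is used instead. A standalone sketch of that idiom follows; the table and values here are made up for illustration, not r2r's actual constants.

from typing import Optional

# Illustration only; this dimension table is invented, not r2r's.
MODEL_TO_DIMENSIONS = {"text-embedding-3-small": [512, 1536]}


def resolve_dimensions(search_dimension: Optional[int], search_model: str) -> int:
    # `or` yields the right-hand operand when search_dimension is falsy (None),
    # so the largest listed dimension serves as the default.
    return search_dimension or MODEL_TO_DIMENSIONS[search_model][-1]


print(resolve_dimensions(None, "text-embedding-3-small"))  # 1536
print(resolve_dimensions(256, "text-embedding-3-small"))   # 256
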
2 changes: 1 addition & 1 deletion r2r/examples/clients/run_test_client.py
@@ -5,7 +5,7 @@
 from r2r.core.utils import generate_id_from_label
 
 # Initialize the client with the base URL of your API
-base_url = "http://localhost:8000"
+base_url = "http://localhost:8001"
 client = R2RClient(base_url)
9 changes: 9 additions & 0 deletions r2r/main/app.py
@@ -499,4 +499,13 @@ async def logs_summary(filter: LogFilterModel = Depends()):
             logging.error(f":logs_summary: [Error](error={str(e)})")
             raise HTTPException(status_code=500, detail=str(e))
 
+    @app.get("/get_rag_pipeline_var/")
+    async def get_rag_pipeline():
+        try:
+            rag_pipeline = os.getenv("RAG_PIPELINE", None)
+            return {"rag_pipeline": rag_pipeline}
+        except Exception as e:
+            logging.error(f":rag_pipeline: [Error](error={str(e)})")
+            raise HTTPException(status_code=500, detail=str(e))
+
     return app
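
The new /get_rag_pipeline_var/ route added here simply reports the RAG_PIPELINE environment variable of the running server. A minimal sketch of calling it from a client follows; the base URL is an assumption and should be adjusted to your deployment.

import requests

base_url = "http://localhost:8000"  # assumption; match your deployment
response = requests.get(f"{base_url}/get_rag_pipeline_var/")
response.raise_for_status()
print(response.json())  # {"rag_pipeline": <value of RAG_PIPELINE, or None if unset>}
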
1 change: 1 addition & 0 deletions r2r/main/factory.py
@@ -62,6 +62,7 @@ def get_embedding_provider(embedding_config: dict[str, Any]):
         return SentenceTransformerEmbeddingProvider(embedding_config)
     elif embedding_config.provider == "dummy":
         from r2r.embeddings import DummyEmbeddingProvider
+
         return DummyEmbeddingProvider(embedding_config)
     else:
         raise ValueError(
17 changes: 10 additions & 7 deletions r2r/pipelines/core/async_embedding.py
@@ -1,14 +1,15 @@
 """
 A simple example to demonstrate the usage of `BasicEmbeddingPipeline`.
 """
 
 import asyncio
 import copy
 import logging
-from typing import Any, Optional, Tuple, Generator
+from typing import Any, Generator, Optional, Tuple
 
 from r2r.core import (
-    DocumentPage,
     AsyncEmbeddingPipeline,
+    DocumentPage,
     EmbeddingProvider,
     LoggingDatabaseConnection,
     VectorDBProvider,
@@ -66,9 +67,7 @@ def ingress(self, document: DocumentPage) -> dict:
             "text": document.text,
         }
 
-    def initialize_pipeline(
-        self, *args, **kwargs
-    ) -> None:
+    def initialize_pipeline(self, *args, **kwargs) -> None:
         super().initialize_pipeline(*args, **kwargs)
 
     def transform_text(self, text: str) -> str:
@@ -141,7 +140,9 @@ async def run(
         self.ingress(document)
 
         chunks = (
-            self.chunk_text(document.text) if do_chunking else [document.text]
+            self.chunk_text(document.text)
+            if do_chunking
+            else [document.text]
         )
 
         for chunk_iter, chunk in enumerate(chunks):
@@ -165,7 +166,9 @@
 
         await asyncio.gather(*tasks)
 
-    async def _process_batches(self, batch_data: list[Tuple[str, str, dict]], do_upsert: bool):
+    async def _process_batches(
+        self, batch_data: list[Tuple[str, str, dict]], do_upsert: bool
+    ):
         logger.debug(f"Parsing batch of size {len(batch_data)}.")
 
         entries = []
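
The run and _process_batches methods above appear to fan work out as one coroutine per batch, awaited together with asyncio.gather. A self-contained sketch of that general fan-out pattern follows; process_batch and the batch size are hypothetical stand-ins, not the r2r code.

import asyncio


async def process_batch(batch: list[str]) -> None:
    # Stand-in for the real per-batch embed-and-upsert work.
    await asyncio.sleep(0)
    print(f"processed batch of {len(batch)} chunks")


async def run(chunks: list[str], batch_size: int = 2) -> None:
    tasks = [
        asyncio.create_task(process_batch(chunks[i : i + batch_size]))
        for i in range(0, len(chunks), batch_size)
    ]
    await asyncio.gather(*tasks)


asyncio.run(run(["a", "b", "c", "d", "e"]))
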
10 changes: 5 additions & 5 deletions r2r/pipelines/core/embedding.py
@@ -4,7 +4,7 @@
 
 import copy
 import logging
-from typing import Any, Optional, Tuple, Generator
+from typing import Any, Generator, Optional, Tuple
 
 from r2r.core import (
     DocumentPage,
@@ -66,9 +66,7 @@ def ingress(self, document: DocumentPage) -> dict:
             "text": document.text,
         }
 
-    def initialize_pipeline(
-        self, *args, **kwargs
-    ) -> None:
+    def initialize_pipeline(self, *args, **kwargs) -> None:
         super().initialize_pipeline(*args, **kwargs)
 
     def transform_text(self, text: str) -> str:
@@ -143,7 +141,9 @@ def run(
         self.ingress(document)
 
         chunks = (
-            self.chunk_text(document.text) if do_chunking else [document.text]
+            self.chunk_text(document.text)
+            if do_chunking
+            else [document.text]
         )
         for chunk_iter, chunk in enumerate(chunks):
             batch_data.append(
6 changes: 5 additions & 1 deletion r2r/pipelines/hyde/rag.py
@@ -222,7 +222,11 @@ def run_stream(
         ]
 
         return self._return_stream(
-            search_results, context, prompt, generation_config, metadata={"answers": answers},
+            search_results,
+            context,
+            prompt,
+            generation_config,
+            metadata={"answers": answers},
         )
 
     def _construct_joined_context(
