Feature/prompt provider and type reno #174

Merged · 3 commits · Mar 14, 2024
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -64,7 +64,7 @@ line-length = 79

[tool.mypy]
ignore_missing_imports = true
exclude = 'playground/.*|deprecated/.*|dump/.*|docs/source'
exclude = 'playground/.*|deprecated/.*|dump/.*|docs/source|vecs/*'

[[tool.mypy.overrides]]
module = "yaml"
4 changes: 1 addition & 3 deletions r2r/client/base.py
@@ -18,8 +18,6 @@ def upload_and_process_file(
):
url = f"{self.base_url}/upload_and_process_file/"
with open(file_path, "rb") as file:
import json

files = {
"file": (file_path.split("/")[-1], file, "application/pdf")
}
@@ -155,7 +153,7 @@ async def stream_rag_completion(

async with httpx.AsyncClient() as client:
async with client.stream(
"POST", url, headers=headers, data=json.dumps(json_data)
"POST", url, headers=headers, json=json_data
) as response:
async for chunk in response.aiter_bytes():
yield chunk.decode()
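Note: the streaming request now passes the payload through httpx's `json=` parameter instead of `data=json.dumps(...)`, so the stray `import json` can be dropped and httpx sets the `Content-Type: application/json` header itself. A minimal sketch of the same pattern in isolation (the URL and payload below are placeholders, not values from this client):

```python
import asyncio
import httpx


async def stream_post(url: str, payload: dict) -> None:
    # httpx serializes `payload` and sets the JSON content-type header,
    # so no manual json.dumps(...) call is required.
    async with httpx.AsyncClient() as client:
        async with client.stream("POST", url, json=payload) as response:
            async for chunk in response.aiter_bytes():
                print(chunk.decode(), end="")


# asyncio.run(stream_post("http://localhost:8000/some_endpoint/", {"query": "..."}))
```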
8 changes: 5 additions & 3 deletions r2r/core/__init__.py
@@ -1,5 +1,5 @@
from .abstractions.completion import Completion, RAGCompletion
from .abstractions.document import BasicDocument
from .abstractions.output import RAGPipelineOutput
from .pipelines.embedding import EmbeddingPipeline
from .pipelines.eval import EvalPipeline
from .pipelines.ingestion import IngestionPipeline
@@ -9,6 +9,7 @@
from .providers.eval import EvalProvider
from .providers.llm import GenerationConfig, LLMConfig, LLMProvider
from .providers.logging import LoggingDatabaseConnection, log_execution_to_db
from .providers.prompt import DefaultPromptProvider, PromptProvider
from .providers.vector_db import (
VectorDBProvider,
VectorEntry,
@@ -17,14 +18,15 @@

__all__ = [
"BasicDocument",
"Completion",
"RAGCompletion",
"DefaultPromptProvider",
"RAGPipelineOutput",
"EmbeddingPipeline",
"EvalPipeline",
"IngestionPipeline",
"RAGPipeline",
"LoggingDatabaseConnection",
"log_execution_to_db",
"PromptProvider",
"EvalProvider",
"DatasetConfig",
"DatasetProvider",
25 changes: 0 additions & 25 deletions r2r/core/abstractions/completion.py

This file was deleted.

27 changes: 27 additions & 0 deletions r2r/core/abstractions/output.py
@@ -0,0 +1,27 @@
from typing import Optional

from openai.types.chat import ChatCompletion


class RAGPipelineOutput:
def __init__(
self,
search_results: list,
context: Optional[str] = None,
completion: Optional[ChatCompletion] = None,
):
self.search_results = search_results
self.context = context
self.completion = completion

def to_dict(self):
return {
"search_results": self.search_results,
"context": self.context,
"completion": self.completion.to_dict()
if self.completion
else None,
}

def __repr__(self):
return f"RAGPipelineOutput(search_results={self.search_results}, context={self.context}, completion={self.completion})"
50 changes: 18 additions & 32 deletions r2r/core/pipelines/rag.py
@@ -7,30 +7,16 @@
from abc import abstractmethod
from typing import Any, Generator, Optional, Union

from ..abstractions.completion import RAGCompletion
from openai.types.chat import ChatCompletion

from ..abstractions.output import RAGPipelineOutput
from ..providers.llm import GenerationConfig, LLMProvider
from ..providers.logging import LoggingDatabaseConnection, log_execution_to_db
from ..providers.prompt import PromptProvider
from .pipeline import Pipeline

logger = logging.getLogger(__name__)

DEFAULT_SYSTEM_PROMPT = "You are a helpful assistant."
DEFAULT_TASK_PROMPT = """
## Task:
Answer the query given immediately below given the context which follows later.

### Query:
{query}

### Context:
{context}

### Query:
{query}

## Response:
"""


class RAGPipeline(Pipeline):
SEARCH_STREAM_MARKER = "search"
@@ -40,15 +26,13 @@ class RAGPipeline(Pipeline):
def __init__(
self,
llm: "LLMProvider",
system_prompt: Optional[str] = None,
task_prompt: Optional[str] = None,
prompt_provider: PromptProvider,
logging_connection: Optional[LoggingDatabaseConnection] = None,
*args,
**kwargs,
):
self.llm = llm
self.system_prompt = system_prompt or DEFAULT_SYSTEM_PROMPT
self.task_prompt = task_prompt or DEFAULT_TASK_PROMPT
self.prompt_provider = prompt_provider
super().__init__(logging_connection=logging_connection, **kwargs)

def initialize_pipeline(
@@ -117,40 +101,42 @@ def construct_prompt(self, inputs: dict[str, str]) -> str:
"""
Constructs a prompt for generation based on the reranked chunks.
"""
return self.task_prompt.format(**inputs)
return self.prompt_provider.get_prompt("task_prompt", inputs).format(
**inputs
)

@log_execution_to_db
def generate_completion(
self,
prompt: str,
generation_config: GenerationConfig,
) -> Union[Generator[str, None, None], RAGCompletion]:
) -> Union[Generator[str, None, None], ChatCompletion]:
"""
Generates a completion based on the prompt.
"""
self._check_pipeline_initialized()
messages = [
{
"role": "system",
"content": self.system_prompt,
"content": self.prompt_provider.get_prompt("system_prompt"),
},
{
"role": "user",
"content": prompt,
},
]
if not generation_config.stream:
return self.llm.get_chat_completion(messages, generation_config)
return self.llm.get_completion(messages, generation_config)

return self._stream_generate_completion(messages, generation_config)

def _stream_generate_completion(
self, messages: dict, generation_config: GenerationConfig
self, messages: list[dict], generation_config: GenerationConfig
) -> Generator[str, None, None]:
for result in self.llm.get_chat_completion(
for result in self.llm.get_completion_stream(
messages, generation_config
):
yield result.choices[0].delta.content or ""
yield result.choices[0].delta.content or "" # type: ignore

def run(
self,
@@ -161,7 +147,7 @@ def run(
generation_config: Optional[GenerationConfig] = None,
*args,
**kwargs,
) -> Union[Generator[str, None, None], RAGCompletion]:
) -> Union[Generator[str, None, None], RAGPipelineOutput]:
"""
Runs the completion pipeline.
"""
@@ -170,7 +156,7 @@ def run(
transformed_query = self.transform_query(query)
search_results = self.search(transformed_query, filters, limit)
if search_only:
return RAGCompletion(search_results, None, None)
return RAGPipelineOutput(search_results, None, None)
elif not generation_config:
raise ValueError(
"GenerationConfig is required for completion generation."
@@ -188,7 +174,7 @@ def run(

if not generation_config.stream:
completion = self.generate_completion(prompt, generation_config)
return RAGCompletion(search_results, context, completion)
return RAGPipelineOutput(search_results, context, completion)

return self._stream_run(
search_results, context, prompt, generation_config
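With this change the pipeline no longer stores `system_prompt`/`task_prompt` strings; `construct_prompt` and `generate_completion` look the prompts up by the names `"task_prompt"` and `"system_prompt"` on the injected `prompt_provider`. A sketch of the registrations a provider would need for those lookups to succeed (the template text is illustrative, echoing the defaults removed from this file; `DefaultPromptProvider` is defined in `r2r/core/providers/prompt.py` below):

```python
provider = DefaultPromptProvider()
provider.add_prompt("system_prompt", "You are a helpful assistant.")
provider.add_prompt(
    "task_prompt",
    "## Task:\n"
    "Answer the query given below using the provided context.\n\n"
    "### Query:\n{query}\n\n"
    "### Context:\n{context}\n\n"
    "## Response:\n",
)

# A concrete RAGPipeline subclass would then receive the provider at construction,
# e.g. pipeline = MyRAGPipeline(llm=llm_provider, prompt_provider=provider)
```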
13 changes: 6 additions & 7 deletions r2r/core/providers/llm.py
@@ -3,8 +3,7 @@
from dataclasses import dataclass, field, fields
from typing import Optional

from openai.types import Completion
from openai.types.chat import ChatCompletion
from openai.types.chat import ChatCompletion, ChatCompletionChunk


@dataclass
@@ -46,7 +45,7 @@ def __init__(
pass

@abstractmethod
def get_chat_completion(
def get_completion(
self,
messages: list[dict],
generation_config: GenerationConfig,
@@ -56,11 +55,11 @@ def get_chat_completion(
pass

@abstractmethod
def get_instruct_completion(
def get_completion_stream(
self,
prompt: str,
messages: list[dict],
generation_config: GenerationConfig,
**kwargs,
) -> Completion:
"""Abstract method to get an instruction completion from the provider."""
) -> ChatCompletionChunk:
"""Abstract method to get a completion stream from the provider."""
pass
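The abstract interface is reshaped around chat completions: `get_chat_completion` becomes `get_completion` (returning a `ChatCompletion`), and the prompt-based `get_instruct_completion` is replaced by `get_completion_stream`, which takes the same message list and yields `ChatCompletionChunk`s. A rough sketch of a concrete provider against the OpenAI SDK (the subclass, constructor wiring, and model name are assumptions, not part of this PR):

```python
from openai import OpenAI
from openai.types.chat import ChatCompletion

from r2r.core import GenerationConfig, LLMProvider


class OpenAILLMProvider(LLMProvider):
    """Hypothetical concrete provider, shown for illustration only."""

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)  # base-class constructor signature assumed
        self.client = OpenAI()  # reads OPENAI_API_KEY from the environment

    def get_completion(
        self, messages: list[dict], generation_config: GenerationConfig, **kwargs
    ) -> ChatCompletion:
        # Non-streaming chat completion; the model name is an assumed example.
        return self.client.chat.completions.create(
            model="gpt-3.5-turbo", messages=messages
        )

    def get_completion_stream(
        self, messages: list[dict], generation_config: GenerationConfig, **kwargs
    ):
        # Streaming variant; iterating the result yields ChatCompletionChunk objects.
        return self.client.chat.completions.create(
            model="gpt-3.5-turbo", messages=messages, stream=True
        )
```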
42 changes: 42 additions & 0 deletions r2r/core/providers/prompt.py
@@ -0,0 +1,42 @@
from abc import ABC, abstractmethod
from typing import Any, Optional


class PromptProvider(ABC):
@abstractmethod
def add_prompt(self, prompt_name: str, prompt: str) -> None:
pass

@abstractmethod
def get_prompt(
self, prompt_name: str, inputs: Optional[dict[str, Any]] = None
) -> str:
pass

@abstractmethod
def get_all_prompts(self) -> dict[str, str]:
pass


class DefaultPromptProvider(PromptProvider):
def __init__(self) -> None:
self.prompts: dict[str, str] = {}

def add_prompt(self, prompt_name: str, prompt: str) -> None:
self.prompts[prompt_name] = prompt

def get_prompt(
self, prompt_name: str, inputs: Optional[dict[str, Any]] = None
) -> str:
prompt = self.prompts.get(prompt_name)
if prompt is None:
raise ValueError(f"Prompt '{prompt_name}' not found.")
return prompt.format(**(inputs or {}))

def set_prompt(self, prompt_name: str, prompt: str) -> None:
if prompt_name not in self.prompts:
raise ValueError(f"Prompt '{prompt_name}' not found.")
self.prompts[prompt_name] = prompt

def get_all_prompts(self) -> dict[str, str]:
return self.prompts.copy()
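The new `PromptProvider` abstraction and its in-memory `DefaultPromptProvider` decouple prompt text from the pipelines. A short usage sketch (the prompt name and template below are illustrative):

```python
provider = DefaultPromptProvider()
provider.add_prompt("greeting", "Hello, {name}!")

# get_prompt formats the stored template with the supplied inputs.
assert provider.get_prompt("greeting", {"name": "R2R"}) == "Hello, R2R!"

# set_prompt only updates prompts that already exist; unknown names raise ValueError.
provider.set_prompt("greeting", "Hi, {name}.")

# get_all_prompts returns a copy, so mutating it does not touch the provider's state.
print(provider.get_all_prompts())
```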
17 changes: 9 additions & 8 deletions r2r/examples/academy/app.py
@@ -5,14 +5,15 @@
GenerationConfig,
LLMProvider,
LoggingDatabaseConnection,
RAGCompletion,
PromptProvider,
RAGPipelineOutput,
VectorDBProvider,
VectorSearchResult,
log_execution_to_db,
)
from r2r.embeddings import OpenAIEmbeddingProvider
from r2r.main import E2EPipelineFactory, load_config
from r2r.pipelines import BasicRAGPipeline
from r2r.pipelines import BasicPromptProvider, BasicRAGPipeline

logger = logging.getLogger(__name__)

@@ -43,8 +44,9 @@ def __init__(
embedding_model: str,
embeddings_provider: OpenAIEmbeddingProvider,
logging_connection: Optional[LoggingDatabaseConnection] = None,
system_prompt: Optional[str] = DEFAULT_SYSTEM_PROMPT,
task_prompt: Optional[str] = DEFAULT_TASK_PROMPT,
prompt_provider: Optional[PromptProvider] = BasicPromptProvider(
DEFAULT_SYSTEM_PROMPT, DEFAULT_TASK_PROMPT
),
) -> None:
logger.debug(f"Initalizing `SyntheticRAGPipeline`")

@@ -54,8 +56,7 @@ def __init__(
embedding_model,
embeddings_provider,
logging_connection=logging_connection,
system_prompt=system_prompt,
task_prompt=task_prompt,
prompt_provider=prompt_provider,
)

def transform_query(self, query: str, generation_config: GenerationConfig) -> list[str]: # type: ignore
@@ -138,14 +139,14 @@ def run(
for transformed_query in transformed_queries
]
if search_only:
return RAGCompletion(search_results, None, None)
return RAGPipelineOutput(search_results, None, None)

context = self.construct_context(search_results)
prompt = self.construct_prompt({"query": query, "context": context})

if not generation_config.stream:
completion = self.generate_completion(prompt, generation_config)
return RAGCompletion(search_results, context, completion)
return RAGPipelineOutput(search_results, context, completion)

return self._stream_run(
search_results, context, prompt, generation_config