Skip to content

Commit

Permalink
mostly complete.
Browse files Browse the repository at this point in the history
  • Loading branch information
emrgnt-cmplxty committed Mar 14, 2024
1 parent 0ada9c9 commit 542c74d
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 18 deletions.
3 changes: 2 additions & 1 deletion r2r/llms/openai/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
import os
from dataclasses import dataclass
from typing import Union

from openai.types.chat import ChatCompletion, ChatCompletionChunk

Expand Down Expand Up @@ -81,7 +82,7 @@ def _get_completion(
messages: list[dict],
generation_config: GenerationConfig,
**kwargs,
) -> ChatCompletion:
) -> Union[ChatCompletion, ChatCompletionChunk]:
"""Get a completion from the OpenAI API based on the provided messages."""

# Create a dictionary with the default arguments
Expand Down
39 changes: 23 additions & 16 deletions r2r/main/app.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import json
import logging
from dataclasses import asdict
from pathlib import Path
from typing import cast, AsyncGenerator, Generator, Optional, Union
from typing import AsyncGenerator, Generator, Optional, Union, cast

from fastapi import (
BackgroundTasks,
Expand All @@ -16,13 +17,12 @@
from r2r.core import (
EmbeddingPipeline,
EvalPipeline,
IngestionPipeline,
GenerationConfig,
IngestionPipeline,
LoggingDatabaseConnection,
RAGPipeline,
RAGPipelineOutput,
)
from dataclasses import asdict

from r2r.main.utils import (
apply_cors,
configure_logging,
Expand Down Expand Up @@ -172,24 +172,28 @@ async def rag_completion(
try:
stream = query.generation_config.stream
if not stream:
rag_completion = rag_pipeline.run(
untyped_completion = rag_pipeline.run(
query.query,
query.filters,
query.limit,
generation_config=query.generation_config,
)

# Tell the type checker that rag_completion is a RAGPipelineOutput
rag_completion = cast(RAGPipelineOutput, untyped_completion)
if not rag_completion.completion:
raise ValueError("No completion found in RAGPipelineOutput.")

completion_text = rag_completion.completion.choices[
0
].message.content
rag_run_id = rag_pipeline.pipeline_run_info["run_id"]

# TODO - Run with task manager for Cloud deployments
background_tasks.add_task(
eval_pipeline.run,
query.query,
rag_completion.context,
completion_text,
rag_run_id,
rag_completion.context or "",
completion_text or "",
rag_pipeline.pipeline_run_info["run_id"], # type: ignore
**query.settings.rag_settings.dict(),
)
return rag_completion
Expand All @@ -214,12 +218,15 @@ async def _stream_rag_completion(
raise ValueError(
"Must pass `stream` as True to stream completions."
)
completion_generator = cast(Generator[str, None, None], rag_pipeline.run(
query.query,
query.filters,
query.limit,
generation_config=gen_config,
))
completion_generator = cast(
Generator[str, None, None],
rag_pipeline.run(
query.query,
query.filters,
query.limit,
generation_config=gen_config,
),
)

for item in completion_generator:
yield item
Expand Down
4 changes: 4 additions & 0 deletions r2r/main/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,10 @@ def process_event(event: dict[str, Any], pipeline_type: str) -> dict[str, Any]:
id_match = re.search(r"'id': '([^']+)'", result)
text_match = re.search(r"'text': '([^']+)'", result)
metadata_match = re.search(r"'metadata': (\{[^}]+\})", result)
if not id_match or not text_match or not metadata_match:
raise ValueError(
f"Missing 'id', 'text', or 'metadata' in result: {result}"
)
metadata = metadata_match.group(1).replace("'", '"')
metadata_json = json.loads(metadata)

Expand Down
4 changes: 3 additions & 1 deletion r2r/pipelines/basic/prompt_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ class BasicPromptProvider(DefaultPromptProvider):
"""

def __init__(
self, system_prompt: Optional[str] = None, task_prompt: Optional[str] = None
self,
system_prompt: Optional[str] = None,
task_prompt: Optional[str] = None,
):
super().__init__()
self.add_prompt(
Expand Down

0 comments on commit 542c74d

Please sign in to comment.