Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

first commit for rd #251

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions r2r/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .abstractions.document import BasicDocument
from .abstractions.output import RAGPipelineOutput
from .pipelines.embedding import EmbeddingPipeline
from .pipelines.extraction import EntityExtractionPipeline
from .pipelines.eval import EvalPipeline
from .pipelines.ingestion import IngestionPipeline
from .pipelines.rag import RAGPipeline
Expand All @@ -20,6 +21,7 @@
"DefaultPromptProvider",
"RAGPipelineOutput",
"EmbeddingPipeline",
"EntityExtractionPipeline",
"EvalPipeline",
"IngestionPipeline",
"RAGPipeline",
Expand Down
37 changes: 37 additions & 0 deletions r2r/core/pipelines/extraction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from abc import abstractmethod
from typing import Iterator, Optional

from ..providers.prompt import PromptProvider
from ..providers.llm import LLMProvider
from ..providers.logging import LoggingDatabaseConnection, log_execution_to_db
from r2r.core import BasicDocument, GenerationConfig
from r2r.pipelines import Pipeline
Copy link

@taahan0810 taahan0810 Mar 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the Pipeline class is defined in `.pipeline`, not in `r2r.pipelines`.

Should this import be:

from .pipeline import Pipeline


class EntityExtractionPipeline(Pipeline):
    """Abstract interface for pipelines that extract entities from documents.

    Stores the LLM provider and prompt provider on the instance; the actual
    preprocessing, extraction, postprocessing and run steps are declared
    abstract and left to concrete subclasses.
    """

    def __init__(
        self,
        llm: LLMProvider,
        prompt_provider: PromptProvider,
        logging_connection: Optional[LoggingDatabaseConnection] = None,
        *args,
        **kwargs,
    ):
        """Store the providers and initialise the base pipeline.

        Args:
            llm: Provider used to obtain completions.
            prompt_provider: Provider that supplies the prompts.
            logging_connection: Optional logging connection, forwarded to
                the base ``Pipeline`` initialiser.
            *args: Accepted but not forwarded anywhere.
            **kwargs: Forwarded unchanged to the base ``Pipeline``
                initialiser.
        """
        # Providers are assigned before the base initialiser runs.
        self.llm, self.prompt_provider = llm, prompt_provider
        super().__init__(logging_connection=logging_connection, **kwargs)

    @abstractmethod
    def preprocess_text(self, text: str) -> str:
        """Return the text to be handed to ``extract_entities``."""

    @abstractmethod
    def extract_entities(
        self, text: str, generation_config: GenerationConfig
    ) -> list[str]:
        """Return the entities found in ``text`` using ``generation_config``."""

    @abstractmethod
    def postprocess_entities(self, entities: list[str]) -> list[str]:
        """Return the final form of the extracted ``entities``."""

    @abstractmethod
    def run(self, documents: Iterator[BasicDocument]) -> Iterator[BasicDocument]:
        """Process ``documents`` and return an iterator of documents."""
2 changes: 1 addition & 1 deletion r2r/core/pipelines/rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class RAGPipeline(Pipeline):

def __init__(
self,
llm: "LLMProvider",
llm: LLMProvider,
prompt_provider: PromptProvider,
logging_connection: Optional[LoggingDatabaseConnection] = None,
*args,
Expand Down
53 changes: 53 additions & 0 deletions r2r/pipelines/basic/extraction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from typing import Iterator

from r2r.core import BasicDocument, EntityExtractionPipeline, GenerationConfig
from r2r.pipelines import BasicPromptProvider

class BasicEntityExtractionPipeline(EntityExtractionPipeline):
    """Entity extraction pipeline that asks an LLM for a comma-separated list.

    For every document, the text is preprocessed, sent to the LLM, the reply
    is split on commas into entity names, the names are upper-cased, and the
    result is written to ``document.metadata["entities"]``.
    """

    BASIC_SYSTEM_PROMPT = "You are a helpful assistant."
    BASIC_TASK_PROMPT = """
## Task:
Extract the named entities from the following text document, and return them in a comma-separated list.

## Response:
"""

    def __init__(self, llm, logging_connection=None, *args, **kwargs):
        """Build the pipeline with the basic system and task prompts.

        Args:
            llm: LLM provider used to obtain completions.
            logging_connection: Optional logging connection, forwarded to
                the base class.
            *args: Accepted but not forwarded anywhere.
            **kwargs: Forwarded unchanged to the base class initialiser.
        """
        prompt_provider = BasicPromptProvider(
            BasicEntityExtractionPipeline.BASIC_SYSTEM_PROMPT,
            BasicEntityExtractionPipeline.BASIC_TASK_PROMPT,
        )
        # ``llm`` is a required parameter of the base initialiser, which also
        # stores it on ``self.llm``; omitting it raised TypeError.
        super().__init__(
            llm=llm,
            prompt_provider=prompt_provider,
            logging_connection=logging_connection,
            **kwargs,
        )

    def preprocess_text(self, text: str) -> str:
        """Return ``text`` unchanged; override to add preprocessing."""
        return text

    def extract_entities(
        self, text: str, generation_config: GenerationConfig
    ) -> list[str]:
        """Ask the LLM for entities in ``text`` and parse its reply.

        Args:
            text: Document text passed to the LLM.
            generation_config: Generation settings passed through to the LLM.

        Returns:
            The comma-separated pieces of the completion with surrounding
            whitespace removed and empty pieces dropped. A reply without any
            comma yields a single entity; a blank reply yields ``[]``.
        """
        self._check_pipeline_initialized()
        # NOTE(review): the system/task prompts held by ``self.prompt_provider``
        # are not sent to the LLM here; only the raw text is. Wiring them in
        # depends on the LLM provider's completion API -- confirm and follow up.
        # Assumes ``get_completion`` returns a plain string -- TODO confirm.
        completion = self.llm.get_completion(text, generation_config)
        stripped_pieces = (piece.strip() for piece in completion.split(","))
        return [name for name in stripped_pieces if name]

    def postprocess_entities(self, entities: list[str]) -> list[str]:
        """Return the entities converted to upper case."""
        return [entity.upper() for entity in entities]

    def run(
        self,
        documents: Iterator[BasicDocument],
        generation_config: "GenerationConfig | None" = None,
    ) -> Iterator[BasicDocument]:
        """Annotate each document with its extracted entities.

        Args:
            documents: Documents to process; each is modified in place.
            generation_config: Generation settings forwarded to
                ``extract_entities``. Required in practice; it is optional in
                the signature only to stay compatible with the base
                ``run(documents)`` declaration.

        Yields:
            Each input document after ``metadata["entities"]`` has been set.

        Raises:
            TypeError: If ``generation_config`` is not supplied. Because this
                is a generator, the error surfaces on first iteration.
        """
        if generation_config is None:
            raise TypeError(
                "run() requires a generation_config to call extract_entities()"
            )
        for document in documents:
            preprocessed_text = self.preprocess_text(document.text)
            entities = self.extract_entities(preprocessed_text, generation_config)
            document.metadata["entities"] = self.postprocess_entities(entities)
            yield document