-
-
Notifications
You must be signed in to change notification settings - Fork 3.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: quivr core minimal chat (#2803)
# Description Minimal working example of `quivr-core` rag with minimal dependencies. --------- Co-authored-by: aminediro <aminedirhoussi@gmail.com>
- Loading branch information
Showing
23 changed files
with
2,417 additions
and
194 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -87,3 +87,5 @@ backend/modules/sync/controller/credentials.json | |
backend/.env.test | ||
|
||
**/*.egg-info | ||
|
||
.coverage |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
from langchain_core.embeddings import DeterministicFakeEmbedding | ||
from langchain_core.language_models import FakeListChatModel | ||
|
||
from quivr_core import Brain | ||
from quivr_core.processor.default_parsers import DEFAULT_PARSERS | ||
from quivr_core.processor.pdf_processor import TikaParser | ||
|
||
if __name__ == "__main__":
    # Minimal smoke test: build a brain over a local PDF using fake
    # (deterministic, network-free) LLM and embedding implementations.
    pdf_paths = ["../tests/processor/data/dummy.pdf"]
    brain = Brain.from_files(
        name="test_brain",
        # BUG FIX: `pdf_paths` was defined but never used — the original
        # passed `file_paths=[]`, so the example indexed nothing.
        file_paths=pdf_paths,
        llm=FakeListChatModel(responses=["good"]),
        embedder=DeterministicFakeEmbedding(size=20),
        processors_mapping={
            **DEFAULT_PARSERS,
            # Route PDFs to the Tika-backed parser on top of the defaults.
            ".pdf": TikaParser(),
        },
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
import tempfile | ||
|
||
from quivr_core import Brain | ||
|
||
if __name__ == "__main__":
    # End-to-end smoke test: index one throwaway text file, then query it.
    with tempfile.NamedTemporaryFile(mode="w", suffix=".txt") as fp:
        fp.write("Gold is metal.")
        fp.flush()  # ensure content is on disk before Brain reads the file

        brain = Brain.from_files(name="test_brain", file_paths=[fp.name])
        response = brain.ask("Property of gold?")
        print("answer :", response.answer)
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# Public package surface: re-export `Brain` so consumers can simply write
# `from quivr_core import Brain`.
from .brain import Brain

__all__ = ["Brain"]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,198 @@ | ||
import asyncio | ||
import logging | ||
from pathlib import Path | ||
from typing import Mapping, Self | ||
from uuid import UUID, uuid4 | ||
|
||
from langchain_core.documents import Document | ||
from langchain_core.embeddings import Embeddings | ||
from langchain_core.language_models import BaseChatModel | ||
from langchain_core.vectorstores import VectorStore | ||
|
||
from quivr_core.config import RAGConfig | ||
from quivr_core.models import ParsedRAGResponse | ||
from quivr_core.processor.default_parsers import DEFAULT_PARSERS | ||
from quivr_core.processor.processor_base import ProcessorBase | ||
from quivr_core.quivr_rag import QuivrQARAG | ||
from quivr_core.storage.file import QuivrFile | ||
from quivr_core.storage.local_storage import TransparentStorage | ||
from quivr_core.storage.storage_base import StorageBase | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
async def _process_files(
    storage: StorageBase,
    skip_file_error: bool,
    processors_mapping: Mapping[str, ProcessorBase],
) -> list[Document]:
    """Parse every file in `storage` with the processor matching its extension.

    Args:
        storage: source of the files to parse.
        skip_file_error: when True, files with no extension or no registered
            processor are logged and skipped instead of aborting the run.
        processors_mapping: maps a file extension (e.g. ".txt") to the
            processor instance responsible for it.

    Returns:
        The concatenated list of documents produced by all processors.

    Raises:
        ValueError: a file has no extension, or no processor is registered for
            its extension, and `skip_file_error` is False.
    """
    knowledge: list[Document] = []
    for file in storage.get_files():
        try:
            if file.file_extension:
                processor = processors_mapping[file.file_extension]
                docs = await processor.process_file(file)
                knowledge.extend(docs)
            else:
                logger.error(f"can't find processor for {file}")
                if skip_file_error:
                    continue
                raise ValueError(f"can't parse {file}. can't find file extension")
        except KeyError as e:
            # Raised by the mapping lookup when the extension is unknown.
            if skip_file_error:
                logger.error(f"can't find processor for {file}, skipping")
                continue
            # BUG FIX: was a bare `raise Exception(...)`; ValueError is
            # consistent with the missing-extension branch above and remains
            # backward compatible with callers catching `Exception`.
            raise ValueError(f"Can't parse {file}. No available processor") from e

    return knowledge
|
||
|
||
class Brain:
    """A minimal RAG "brain": uploaded files, their vector index, and an LLM.

    Construct via the `from_files` / `afrom_files` alternate constructors,
    which parse the given files, embed them, and build the vector store.
    """

    def __init__(
        self,
        *,
        name: str,
        id: UUID,
        vector_db: VectorStore,
        llm: BaseChatModel,
        embedder: Embeddings,
        storage: StorageBase,
    ):
        self.id = id
        self.name = name
        self.storage = storage

        # Chat history — NOTE(review): never appended to by `ask` yet.
        self.chat_history: list[str] = []

        # RAG dependencies:
        self.llm = llm
        self.vector_db = vector_db
        self.embedder = embedder

    @classmethod
    async def afrom_files(
        cls,
        *,
        name: str,
        file_paths: list[str | Path],
        vector_db: VectorStore | None = None,
        storage: StorageBase | None = None,
        llm: BaseChatModel | None = None,
        embedder: Embeddings | None = None,
        processors_mapping: Mapping[str, ProcessorBase] = DEFAULT_PARSERS,
        skip_file_error: bool = False,
    ):
        """Asynchronously build a Brain from local files.

        Args:
            name: human-readable brain name.
            file_paths: paths of files to upload and index.
            vector_db: existing vector store to add documents to; when None a
                FAISS store is built from the parsed documents.
            storage: file storage backend; defaults to a fresh
                `TransparentStorage` per call.
            llm: chat model; defaults to `ChatOpenAI` when installed.
            embedder: embedding model; defaults to `OpenAIEmbeddings`.
            processors_mapping: extension -> processor used to parse files.
            skip_file_error: skip unparsable files instead of raising.

        Raises:
            ImportError: a default dependency (OpenAI / FAISS) is missing.
            ValueError: no documents were produced and no vector_db was given.
        """
        # BUG FIX: the default was `storage=TransparentStorage()`, evaluated
        # once at definition time — every Brain built with the default shared
        # a single storage instance. Create a fresh one per call instead.
        if storage is None:
            storage = TransparentStorage()

        if llm is None:
            try:
                from langchain_openai import ChatOpenAI

                logger.debug("Loaded ChatOpenAI as default LLM for brain")
                llm = ChatOpenAI()
            except ImportError as e:
                raise ImportError(
                    "Please provide a valid BaseLLM or install quivr-core['base'] package"
                ) from e

        if embedder is None:
            try:
                from langchain_openai import OpenAIEmbeddings

                # BUG FIX: log message said "default LLM" for the embedder.
                logger.debug("Loaded OpenAIEmbeddings as default embedder for brain")
                embedder = OpenAIEmbeddings()
            except ImportError as e:
                # BUG FIX: message typo "defaultone" -> "default one".
                raise ImportError(
                    "Please provide a valid Embedder or install quivr-core['base'] package for using the default one."
                ) from e

        brain_id = uuid4()

        # Upload every input file to the storage backend.
        for path in file_paths:
            file = QuivrFile.from_path(brain_id, path)
            storage.upload_file(file)

        # Parse files into langchain documents.
        docs = await _process_files(
            storage=storage,
            processors_mapping=processors_mapping,
            skip_file_error=skip_file_error,
        )

        # Building brain's vectordb
        if vector_db is None:
            # Raised outside the try: it is not an import failure (the
            # original raised it inside, but `except ImportError` never
            # caught it, so behavior is unchanged).
            if not docs:
                raise ValueError("can't initialize brain without documents")
            try:
                from langchain_community.vectorstores import FAISS

                logger.debug("Using Faiss-CPU as vector store.")
                # TODO(@aminediro) : embedding call is not concurrent for all documents but waits
                # We can actually wait on all processing
                vector_db = await FAISS.afrom_documents(
                    documents=docs, embedding=embedder
                )
            except ImportError as e:
                # BUG FIX: message typo "vectore store" -> "vector store".
                raise ImportError(
                    "Please provide a valid vector store or install quivr-core['base'] package for using the default one."
                ) from e
        else:
            # NOTE(review): synchronous add on the async path — consider
            # `await vector_db.aadd_documents(docs)`; left as-is to preserve
            # behavior.
            vector_db.add_documents(docs)

        return cls(
            id=brain_id,
            name=name,
            storage=storage,
            llm=llm,
            embedder=embedder,
            vector_db=vector_db,
        )

    @classmethod
    def from_files(
        cls,
        *,
        name: str,
        file_paths: list[str | Path],
        vector_db: VectorStore | None = None,
        storage: StorageBase | None = None,
        llm: BaseChatModel | None = None,
        embedder: Embeddings | None = None,
        processors_mapping: Mapping[str, ProcessorBase] = DEFAULT_PARSERS,
        skip_file_error: bool = False,
    ) -> Self:
        """Synchronous wrapper around `afrom_files` (see it for parameters)."""
        return asyncio.run(
            cls.afrom_files(
                name=name,
                file_paths=file_paths,
                vector_db=vector_db,
                storage=storage,
                llm=llm,
                embedder=embedder,
                processors_mapping=processors_mapping,
                skip_file_error=skip_file_error,
            )
        )

    # TODO(@aminediro)
    def add_file(self) -> None:
        # add it to storage
        # add it to vectorstore
        raise NotImplementedError

    def ask(
        self, question: str, rag_config: RAGConfig | None = None
    ) -> ParsedRAGResponse:
        """Answer `question` with a fresh RAG pipeline over this brain's store.

        Args:
            question: the user's question.
            rag_config: pipeline configuration; a fresh default is used when
                None. (BUG FIX: was a shared `RAGConfig()` default argument.)
        """
        if rag_config is None:
            rag_config = RAGConfig()
        rag_pipeline = QuivrQARAG(
            rag_config=rag_config, llm=self.llm, vector_store=self.vector_db
        )

        # transformed_history = format_chat_history(history)
        parsed_response = rag_pipeline.answer(question, [], [])

        # TODO: save answer to the chat history
        return parsed_response
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
from quivr_core.processor.processor_base import ProcessorBase | ||
from quivr_core.processor.txt_parser import TxtProcessor | ||
|
||
# Built-in extension -> processor registry, used as the default
# `processors_mapping` by Brain. NOTE(review): instantiated once at import
# time, so the same TxtProcessor instance is shared by every consumer —
# confirm processors are stateless before adding stateful ones here.
DEFAULT_PARSERS: dict[str, ProcessorBase] = {
    ".txt": TxtProcessor(),
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
import logging | ||
from typing import AsyncIterable | ||
|
||
import httpx | ||
from langchain_core.documents import Document | ||
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter | ||
|
||
from quivr_core.processor.processor_base import ProcessorBase | ||
from quivr_core.processor.splitter import SplitterConfig | ||
from quivr_core.storage.file import QuivrFile | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class TikaParser(ProcessorBase):
    """PDF processor that delegates text extraction to an Apache Tika server.

    The extracted plain text is split into chunks by the configured text
    splitter and returned as langchain `Document`s.
    """

    supported_extensions = [".pdf"]

    def __init__(
        self,
        tika_url: str = "http://localhost:9998/tika",
        splitter: TextSplitter | None = None,
        splitter_config: SplitterConfig | None = None,
        timeout: float = 5.0,
        max_retries: int = 3,
    ) -> None:
        """
        Args:
            tika_url: endpoint of the Tika server's text-extraction resource.
            splitter: custom text splitter; when given, `splitter_config` is
                only stored, not used to build the splitter.
            splitter_config: chunking parameters for the default splitter.
                (BUG FIX: was a `SplitterConfig()` default argument, evaluated
                once and shared across all instances.)
            timeout: per-request timeout in seconds for the HTTP client.
            max_retries: number of attempts against Tika before giving up.
        """
        if splitter_config is None:
            splitter_config = SplitterConfig()

        self.tika_url = tika_url
        self.max_retries = max_retries
        # NOTE(review): this AsyncClient is never closed — consider exposing
        # an aclose()/async context manager on the processor's lifecycle.
        self._client = httpx.AsyncClient(timeout=timeout)

        self.splitter_config = splitter_config

        if splitter:
            self.text_splitter = splitter
        else:
            self.text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
                chunk_size=splitter_config.chunk_size,
                chunk_overlap=splitter_config.chunk_overlap,
            )

    async def _send_parse_tika(self, f: AsyncIterable[bytes]) -> str:
        """PUT the file stream to Tika and return the extracted plain text.

        Retries up to `max_retries` times on HTTP errors, then raises a
        `RuntimeError` chained to the last failure.
        """
        headers = {"Accept": "text/plain"}
        last_exc: Exception | None = None
        # NOTE(review): `f` is an async stream; it may be exhausted after the
        # first attempt, so retries could send an empty body — TODO confirm.
        for attempt in range(1, self.max_retries + 1):
            try:
                resp = await self._client.put(self.tika_url, headers=headers, content=f)
                resp.raise_for_status()
                return resp.content.decode("utf-8")
            # BUG FIX: catch only HTTP-related errors instead of a bare
            # `Exception`, and remember the cause for the final error.
            except httpx.HTTPError as e:
                last_exc = e
                logger.debug(f"tika url error :{e}. retrying for the {attempt} time...")
        # BUG FIX: chain the last failure instead of discarding it.
        raise RuntimeError("can't send parse request to tika server") from last_exc

    async def process_file(self, file: QuivrFile) -> list[Document]:
        """Extract text from `file` via Tika and split it into documents."""
        assert file.file_extension in self.supported_extensions

        async with file.open() as f:
            txt = await self._send_parse_tika(f)
            document = Document(page_content=txt)

        # Use the default splitter
        docs = self.text_splitter.split_documents([document])
        return docs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
from abc import ABC, abstractmethod | ||
from typing import Generic, TypeVar | ||
|
||
from langchain_core.documents import Document | ||
|
||
from quivr_core.storage.file import QuivrFile | ||
|
||
|
||
class ProcessorBase(ABC):
    """Abstract interface turning a `QuivrFile` into langchain documents."""

    # File extensions (e.g. ".pdf") this processor accepts.
    supported_extensions: list[str]

    @abstractmethod
    async def process_file(self, file: QuivrFile) -> list[Document]:
        """Parse `file` and return the resulting documents."""
        pass
|
||
|
||
# Type variable constrained to ProcessorBase subclasses, used to keep the
# mapping homogeneous.
P = TypeVar("P", bound=ProcessorBase)


class ProcessorsMapping(Generic[P]):
    """Mutable registry mapping a file extension to its processor."""

    def __init__(self, mapping: dict[str, P]) -> None:
        # Store the given extension -> processor mapping as-is (not copied).
        self.ext_parser: dict[str, P] = mapping

    def add_parser(self, extension: str, parser: P):
        # Register `parser` for `extension`, overwriting any existing entry.
        # TODO: deal with existing ext keys
        self.ext_parser[extension] = parser
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
from pydantic import BaseModel | ||
|
||
|
||
class SplitterConfig(BaseModel):
    """Chunking parameters for text splitters."""

    # Maximum chunk size — unit depends on the consuming splitter
    # (token-based for tiktoken-backed splitters); TODO confirm at call sites.
    chunk_size: int = 400
    # Overlap between consecutive chunks, in the same unit as `chunk_size`.
    chunk_overlap: int = 100
Oops, something went wrong.