Chunking PR #400

Open · wants to merge 2 commits into base: chunking-embedding-dev
7 changes: 7 additions & 0 deletions ragna/chunking_models/__init__.py
@@ -0,0 +1,7 @@
__all__ = [
"GenericChunkingModel",
"NLTKChunkingModel",
]

from ._generic_chunking_model import GenericChunkingModel
from ._nltk_chunking_model import NLTKChunkingModel
54 changes: 54 additions & 0 deletions ragna/chunking_models/_generic_chunking_model.py
@@ -0,0 +1,54 @@
from collections import deque
from typing import TYPE_CHECKING, Deque, Iterable, Iterator, TypeVar

from ragna.core import Chunk, ChunkingModel, Document

if TYPE_CHECKING:
import tiktoken

T = TypeVar("T")


# The function is adapted from more_itertools.windowed to allow a ragged last window
# https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.windowed
def _windowed_ragged(
iterable: Iterable[T], *, n: int, step: int
) -> Iterator[tuple[T, ...]]:
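    """Yield ``n``-sized windows over ``iterable``, advancing ``step`` items at
    a time, ending with a ragged (shorter) final window instead of dropping the
    tail.

    Hedged example, not part of the PR: ``list(_windowed_ragged(range(6), n=3,
    step=2))`` returns ``[(0, 1, 2), (2, 3, 4), (4, 5)]``.
    """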
window: Deque[T] = deque(maxlen=n)
i = n
for _ in map(window.append, iterable):
i -= 1
if not i:
i = step
yield tuple(window)

if len(window) < n:
yield tuple(window)
elif 0 < i < min(step, n):
yield tuple(window)[i:]

class GenericChunkingModel(ChunkingModel):
    def chunk_documents(
        self,
        documents: list[Document],
        chunk_size: int = 500,
        chunk_overlap: int = 250,
    ) -> list[Chunk]:
chunks = []
for document in documents:
for window in _windowed_ragged(
(
(tokens, page.number)
for page in document.extract_pages()
for tokens in self.tokenizer.encode(page.text)
),
n=chunk_size,
step=chunk_size - chunk_overlap,
):
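                # Each window element pairs one token id with the page it came
                # from; unzip so the ids can be decoded back into chunk text.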
tokens, page_numbers = zip(*window)
chunks.append(Chunk(
text=self.tokenizer.decode(tokens), # type: ignore[arg-type]
document_id=document.id,
page_numbers=list(filter(lambda n: n is not None, page_numbers)) or None,
num_tokens=len(tokens),
))

return chunks
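For reference, a minimal usage sketch of the model above (hedged: LocalDocument.from_path is the ragna API referenced in the Chat docstring further down; "report.pdf" is a hypothetical file):

from ragna.chunking_models import GenericChunkingModel
from ragna.core import LocalDocument

# Chunk one local document into ~500-token windows overlapping by 250 tokens.
document = LocalDocument.from_path("report.pdf")
chunks = GenericChunkingModel().chunk_documents(
    [document], chunk_size=500, chunk_overlap=250
)
print(len(chunks), chunks[0].num_tokens)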
20 changes: 20 additions & 0 deletions ragna/chunking_models/_nltk_chunking_model.py
@@ -0,0 +1,20 @@
from ragna.core import Document, Chunk, ChunkingModel

class NLTKChunkingModel(ChunkingModel):
def __init__(self):
super().__init__()

        # Sentence splitting is delegated to langchain's NLTKTextSplitter.
        # The import is local so langchain stays an optional dependency.
from langchain.text_splitter import NLTKTextSplitter
self.text_splitter = NLTKTextSplitter()

    def chunk_documents(self, documents: list[Document]) -> list[Chunk]:
        # Known limitation: NLTKTextSplitter operates on plain text, so the
        # pages are concatenated and the chunk-to-page mapping is lost.
        chunks = []
        for document in documents:
            # Join with a newline so words at page boundaries do not merge.
            text = "\n".join(page.text for page in document.extract_pages())
            chunks += self.generate_chunks_from_text(
                self.text_splitter.split_text(text), document.id
            )

        return chunks
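A hedged sketch of what the delegated splitter returns (assumes langchain's NLTKTextSplitter API with nltk and its punkt data installed; the sample text is made up):

from langchain.text_splitter import NLTKTextSplitter

# The splitter tokenizes text into sentences, then packs sentences into
# strings of at most chunk_size characters.
splitter = NLTKTextSplitter(chunk_size=1000)
parts = splitter.split_text("First sentence. Second sentence. Third sentence.")
# parts is a list[str], ready for generate_chunks_from_text().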
2 changes: 2 additions & 0 deletions ragna/core/__init__.py
@@ -2,6 +2,7 @@
"Assistant",
"Chat",
"Chunk",
"ChunkingModel",
"Component",
"Document",
"DocumentHandler",
@@ -52,6 +53,7 @@
from ._components import (
Assistant,
Component,
ChunkingModel,
Embedding,
EmbeddingModel,
Message,
70 changes: 0 additions & 70 deletions ragna/core/_compat.py

This file was deleted.

19 changes: 18 additions & 1 deletion ragna/core/_components.py
@@ -11,6 +11,7 @@
AsyncIterable,
AsyncIterator,
Iterator,
Iterable,
Optional,
Type,
Union,
@@ -23,9 +24,11 @@
import pydantic
import pydantic.utils

from ._document import Chunk, Document
from ._document import Chunk, Document, Page
from ._utils import RequirementsMixin, merge_models

from uuid import UUID


class Component(RequirementsMixin):
"""Base class for RAG components.
@@ -92,6 +95,20 @@ def _protocol_model(cls) -> Type[pydantic.BaseModel]:
return merge_models(cls.display_name(), *cls._protocol_models().values())


class ChunkingModel(Component, ABC):
    def __init__(self) -> None:
        # Imported locally so tiktoken stays an optional dependency.
        import tiktoken

        self.tokenizer = tiktoken.get_encoding("cl100k_base")

    @abstractmethod
    def chunk_documents(self, documents: list[Document]) -> list[Chunk]:
        raise NotImplementedError

    def generate_chunks_from_text(
        self, chunks: list[str], document_id: UUID
    ) -> list[Chunk]:
        # page_numbers=[1] is a placeholder: plain-text splitters cannot
        # recover which page a chunk came from.
        return [
            Chunk(
                page_numbers=[1],
                text=chunk,
                document_id=document_id,
                num_tokens=len(self.tokenizer.encode(chunk)),
            )
            for chunk in chunks
        ]


@dataclass
class Embedding:
values: list[float]
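The ChunkingModel base class added above is what makes chunking strategies pluggable. A hedged sketch of a third-party subclass (LineChunkingModel is illustrative only, not part of this PR):

from ragna.core import Chunk, ChunkingModel, Document

class LineChunkingModel(ChunkingModel):
    # Naive strategy: one chunk per non-empty line, reusing the tokenizer
    # and the generate_chunks_from_text() helper from the base class.
    def chunk_documents(self, documents: list[Document]) -> list[Chunk]:
        chunks: list[Chunk] = []
        for document in documents:
            text = "\n".join(page.text for page in document.extract_pages())
            lines = [line for line in text.splitlines() if line.strip()]
            chunks += self.generate_chunks_from_text(lines, document.id)
        return chunks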
52 changes: 32 additions & 20 deletions ragna/core/_rag.py
@@ -21,10 +21,10 @@
import pydantic
from starlette.concurrency import iterate_in_threadpool, run_in_threadpool

from ._compat import chunk_pages
from ._components import (
Assistant,
Chunk,
ChunkingModel,
Component,
Embedding,
EmbeddingModel,
@@ -91,6 +91,7 @@ def chat(
source_storage: Union[Type[SourceStorage], SourceStorage],
assistant: Union[Type[Assistant], Assistant],
embedding_model: Optional[Union[Type[EmbeddingModel], EmbeddingModel]] = None,
chunking_model: Optional[Union[Type[ChunkingModel], ChunkingModel]] = None,
**params: Any,
) -> Chat:
"""Create a new [ragna.core.Chat][].
@@ -100,6 +101,8 @@
[ragna.core.LocalDocument.from_path][] is invoked on it.
source_storage: Source storage to use.
assistant: Assistant to use.
            embedding_model: Embedding model to use.
            chunking_model: Chunking model to use (token-based, NLTK, or spaCy).
**params: Additional parameters passed to the source storage and assistant.
"""
return Chat(
@@ -108,6 +111,7 @@
source_storage=source_storage,
assistant=assistant,
embedding_model=embedding_model,
chunking_model=chunking_model,
**params,
)

@@ -153,6 +157,8 @@ class Chat:
[ragna.core.LocalDocument.from_path][] is invoked on it.
source_storage: Source storage to use.
assistant: Assistant to use.
        embedding_model: Embedding model to use. Required for source storages
            that take embeddings as input.
        chunking_model: Chunking model to use. Required for source storages
            that take embeddings or chunks as input.
**params: Additional parameters passed to the source storage and assistant.
"""

@@ -164,22 +170,32 @@ def __init__(
source_storage: Union[Type[SourceStorage], SourceStorage],
assistant: Union[Type[Assistant], Assistant],
embedding_model: Optional[Union[Type[EmbeddingModel], EmbeddingModel]],
chunking_model: Optional[Union[Type[ChunkingModel], ChunkingModel]],
**params: Any,
) -> None:
self._rag = rag

self.documents = self._parse_documents(documents)

if embedding_model is None and issubclass(
if (embedding_model is None or chunking_model is None) and issubclass(
source_storage.__ragna_input_type__, Embedding
):
raise RagnaException
elif embedding_model is not None:
embedding_model = cast(
EmbeddingModel, self._rag._load_component(embedding_model)
)
else:
if embedding_model is not None:
embedding_model = cast(
EmbeddingModel, self._rag._load_component(embedding_model)
)

if chunking_model is not None:
chunking_model = cast(
ChunkingModel, self._rag._load_component(chunking_model)
)

self.embedding_model = embedding_model

self.chunking_model = chunking_model

self.source_storage = cast(
SourceStorage, self._rag._load_component(source_storage)
)
@@ -225,20 +241,14 @@ async def prepare(self) -> Message:
detail=RagnaException.EVENT,
)

chunks = [
chunk
for document in self.documents
for chunk in chunk_pages(
document.extract_pages(),
document_id=document.id,
chunk_size=self.params["chunk_size"],
chunk_overlap=self.params["chunk_overlap"],
)
]

input: Union[list[Document], list[Embedding]] = self.documents
        # The source storage may take any one of three input types: Document,
        # Chunk, or Embedding. Chunk and embed only as far as the storage needs.
input: Union[list[Document], list[Embedding], list[Chunk]] = self.documents
if not issubclass(self.source_storage.__ragna_input_type__, Document):
input = cast(EmbeddingModel, self.embedding_model).embed_chunks(chunks)
input = cast(ChunkingModel, self.chunking_model).chunk_documents(input)
if not issubclass(self.source_storage.__ragna_input_type__, Chunk):
input = cast(EmbeddingModel, self.embedding_model).embed_chunks(input)

await self._run(self.source_storage.store, input)

self._prepared = True
@@ -271,7 +281,9 @@ async def answer(self, prompt: str, *, stream: bool = False) -> Message:
self._messages.append(Message(content=prompt, role=MessageRole.USER))

input: Union[str, list[float]] = prompt
if not issubclass(self.source_storage.__ragna_input_type__, Document):
        # Document- and Chunk-based storages retrieve from the raw prompt;
        # only embedding-based storages need the prompt embedded first.
if (not issubclass(self.source_storage.__ragna_input_type__, Document)
and not issubclass(self.source_storage.__ragna_input_type__, Chunk)):
input = self._embed_text(prompt)
sources = await self._run(self.source_storage.retrieve, self.documents, input)

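Taken together, a hedged end-to-end sketch of the new keyword arguments (Rag, chat() and the chunking models come from this diff; MySourceStorage, MyAssistant and MyEmbeddingModel are hypothetical placeholders, as is "report.pdf"):

from ragna import Rag
from ragna.chunking_models import NLTKChunkingModel

chat = Rag().chat(
    documents=["report.pdf"],          # hypothetical file
    source_storage=MySourceStorage,    # placeholder SourceStorage subclass
    assistant=MyAssistant,             # placeholder Assistant subclass
    embedding_model=MyEmbeddingModel,  # placeholder EmbeddingModel subclass
    chunking_model=NLTKChunkingModel,  # consulted only for Chunk/Embedding storages
)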