Skip to content

Commit

Permalink
Vector DB CDK: Refactor to improve readability (#33255)
Browse files Browse the repository at this point in the history
Co-authored-by: flash1293 <flash1293@users.noreply.github.com>
  • Loading branch information
Joe Reuter and flash1293 committed Dec 13, 2023
1 parent c1e428f commit 55d5345
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Optional, Union, cast

from airbyte_cdk.destinations.vector_db_based.config import (
Expand All @@ -15,15 +16,21 @@
OpenAIEmbeddingConfigModel,
ProcessingConfigModel,
)
from airbyte_cdk.destinations.vector_db_based.document_processor import Chunk
from airbyte_cdk.destinations.vector_db_based.utils import create_chunks, format_exception
from airbyte_cdk.models import AirbyteRecordMessage
from airbyte_cdk.utils.traced_exception import AirbyteTracedException, FailureType
from langchain.embeddings.cohere import CohereEmbeddings
from langchain.embeddings.fake import FakeEmbeddings
from langchain.embeddings.localai import LocalAIEmbeddings
from langchain.embeddings.openai import OpenAIEmbeddings


# Input unit handed to Embedder.embed_documents: the text to embed plus the record it came from.
@dataclass
class Document:
# Text to embed; annotated as a plain str (not Optional).
page_content: str
# Originating Airbyte record; the from-field embedder reads record.data and record.stream from it.
record: AirbyteRecordMessage


class Embedder(ABC):
"""
Embedder is an abstract class that defines the interface for embedding text.
Expand All @@ -41,7 +48,7 @@ def check(self) -> Optional[str]:
pass

@abstractmethod
def embed_chunks(self, chunks: List[Chunk]) -> List[Optional[List[float]]]:
def embed_documents(self, documents: List[Document]) -> List[Optional[List[float]]]:
"""
Embed the text of each document and return the resulting embedding vectors.
If a document cannot be embedded or is configured to not be embedded, return None for that document.
Expand Down Expand Up @@ -72,17 +79,17 @@ def check(self) -> Optional[str]:
return format_exception(e)
return None

def embed_chunks(self, chunks: List[Chunk]) -> List[Optional[List[float]]]:
def embed_documents(self, documents: List[Document]) -> List[Optional[List[float]]]:
"""
Embed the text of each document and return the resulting embedding vectors.
As the OpenAI API will fail if more than the per-minute limit worth of tokens is sent at once, we split the request into batches and embed each batch separately.
It's still possible to run into the rate limit between each embed call because the available token budget hasn't recovered between the calls,
but the built-in retry mechanism of the OpenAI client handles that.
"""
# Each chunk can hold at most self.chunk_size tokens, so tokens-per-minute by maximum tokens per chunk is the number of chunks that can be embedded at once without exhausting the limit in a single request
# Each document can hold at most self.chunk_size tokens, so tokens-per-minute divided by maximum tokens per document is the number of documents that can be embedded at once without exhausting the limit in a single request
embedding_batch_size = OPEN_AI_TOKEN_LIMIT // self.chunk_size
batches = create_chunks(chunks, batch_size=embedding_batch_size)
batches = create_chunks(documents, batch_size=embedding_batch_size)
embeddings: List[Optional[List[float]]] = []
for batch in batches:
embeddings.extend(self.embeddings.embed_documents([chunk.page_content for chunk in batch]))
Expand Down Expand Up @@ -121,8 +128,8 @@ def check(self) -> Optional[str]:
return format_exception(e)
return None

def embed_chunks(self, chunks: List[Chunk]) -> List[Optional[List[float]]]:
return cast(List[Optional[List[float]]], self.embeddings.embed_documents([chunk.page_content or "" for chunk in chunks]))
def embed_documents(self, documents: List[Document]) -> List[Optional[List[float]]]:
# Embed the page_content of every document with one call to the wrapped embeddings client; cast() only adjusts the static return type.
return cast(List[Optional[List[float]]], self.embeddings.embed_documents([document.page_content for document in documents]))

@property
def embedding_dimensions(self) -> int:
Expand All @@ -142,8 +149,8 @@ def check(self) -> Optional[str]:
return format_exception(e)
return None

def embed_chunks(self, chunks: List[Chunk]) -> List[Optional[List[float]]]:
return cast(List[Optional[List[float]]], self.embeddings.embed_documents([chunk.page_content or "" for chunk in chunks]))
def embed_documents(self, documents: List[Document]) -> List[Optional[List[float]]]:
# Embed the page_content of every document with one call to the wrapped embeddings client; cast() only adjusts the static return type.
return cast(List[Optional[List[float]]], self.embeddings.embed_documents([document.page_content for document in documents]))

@property
def embedding_dimensions(self) -> int:
Expand Down Expand Up @@ -173,8 +180,8 @@ def check(self) -> Optional[str]:
return format_exception(e)
return None

def embed_chunks(self, chunks: List[Chunk]) -> List[Optional[List[float]]]:
return cast(List[Optional[List[float]]], self.embeddings.embed_documents([chunk.page_content or "" for chunk in chunks]))
def embed_documents(self, documents: List[Document]) -> List[Optional[List[float]]]:
# Embed the page_content of every document with one call to the wrapped embeddings client; cast() only adjusts the static return type.
return cast(List[Optional[List[float]]], self.embeddings.embed_documents([document.page_content for document in documents]))

@property
def embedding_dimensions(self) -> int:
Expand All @@ -190,32 +197,32 @@ def __init__(self, config: FromFieldEmbeddingConfigModel):
def check(self) -> Optional[str]:
return None

def embed_chunks(self, chunks: List[Chunk]) -> List[Optional[List[float]]]:
def embed_documents(self, documents: List[Document]) -> List[Optional[List[float]]]:
"""
From each document's record, pull the embedding from the field specified in the config.
Check that the field exists, is a list of numbers and is the correct size. If not, raise an AirbyteTracedException explaining the problem.
"""
embeddings: List[Optional[List[float]]] = []
for chunk in chunks:
data = chunk.record.data
for document in documents:
data = document.record.data
if self.config.field_name not in data:
raise AirbyteTracedException(
internal_message="Embedding vector field not found",
failure_type=FailureType.config_error,
message=f"Record {str(data)[:250]}... in stream {chunk.record.stream} does not contain embedding vector field {self.config.field_name}. Please check your embedding configuration, the embedding vector field has to be set correctly on every record.",
message=f"Record {str(data)[:250]}... in stream {document.record.stream} does not contain embedding vector field {self.config.field_name}. Please check your embedding configuration, the embedding vector field has to be set correctly on every record.",
)
field = data[self.config.field_name]
if not isinstance(field, list) or not all(isinstance(x, (int, float)) for x in field):
raise AirbyteTracedException(
internal_message="Embedding vector field not a list of numbers",
failure_type=FailureType.config_error,
message=f"Record {str(data)[:250]}... in stream {chunk.record.stream} does contain embedding vector field {self.config.field_name}, but it is not a list of numbers. Please check your embedding configuration, the embedding vector field has to be a list of numbers of length {self.config.dimensions} on every record.",
message=f"Record {str(data)[:250]}... in stream {document.record.stream} does contain embedding vector field {self.config.field_name}, but it is not a list of numbers. Please check your embedding configuration, the embedding vector field has to be a list of numbers of length {self.config.dimensions} on every record.",
)
if len(field) != self.config.dimensions:
raise AirbyteTracedException(
internal_message="Embedding vector field has wrong length",
failure_type=FailureType.config_error,
message=f"Record {str(data)[:250]}... in stream {chunk.record.stream} does contain embedding vector field {self.config.field_name}, but it has length {len(field)} instead of the configured {self.config.dimensions}. Please check your embedding configuration, the embedding vector field has to be a list of numbers of length {self.config.dimensions} on every record.",
message=f"Record {str(data)[:250]}... in stream {document.record.stream} does contain embedding vector field {self.config.field_name}, but it has length {len(field)} instead of the configured {self.config.dimensions}. Please check your embedding configuration, the embedding vector field has to be a list of numbers of length {self.config.dimensions} on every record.",
)
embeddings.append(field)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,22 @@

from airbyte_cdk.destinations.vector_db_based.config import ProcessingConfigModel
from airbyte_cdk.destinations.vector_db_based.document_processor import Chunk, DocumentProcessor
from airbyte_cdk.destinations.vector_db_based.embedder import Embedder
from airbyte_cdk.destinations.vector_db_based.embedder import Document, Embedder
from airbyte_cdk.destinations.vector_db_based.indexer import Indexer
from airbyte_cdk.models import AirbyteMessage, ConfiguredAirbyteCatalog, Type


class Writer:
"""
The Writer class is orchestrating the document processor, the embedder and the indexer:
* Incoming records are passed through the document processor to generate documents
* One the configured batch size is reached, the documents are passed to the embedder to generate embeddings
* The embedder embeds the documents
* The indexer deletes old documents by the associated record id before indexing the new ones
* Incoming records are passed through the document processor to generate chunks
* Once the configured batch size is reached, the chunks are passed to the embedder to generate embeddings
* The embedder embeds the chunks
* The indexer deletes old chunks by the associated record id before indexing the new ones
The destination connector is responsible for creating a writer instance and passing the input messages iterable to the write method.
The batch size can be configured by the destination connector to give the freedom of either letting the user configure it or hardcoding it to a sensible value depending on the destination.
The omit_raw_text parameter can be used to omit the raw text from the documents. This can be useful if the raw text is very large and not needed for the destination.
The omit_raw_text parameter can be used to omit the raw text from the chunks. This can be useful if the raw text is very large and not needed for the destination.
"""

def __init__(
Expand All @@ -37,21 +37,29 @@ def __init__(
self._init_batch()

def _init_batch(self) -> None:
# Reset all per-batch state; invoked from __init__ and again at the end of _process_batch.
# Buffered chunks, grouped by a (namespace, stream) key.
self.documents: Dict[Tuple[str, str], List[Chunk]] = defaultdict(list)
self.chunks: Dict[Tuple[str, str], List[Chunk]] = defaultdict(list)
# Record ids to delete from the index before new data is written, grouped by the same key.
self.ids_to_delete: Dict[Tuple[str, str], List[str]] = defaultdict(list)
# Running count that write() compares against batch_size to decide when to flush.
self.number_of_documents = 0
self.number_of_chunks = 0

def _convert_to_document(self, chunk: Chunk) -> Document:
"""
Convert a chunk to a document for the embedder.

The resulting Document carries the chunk's page_content and record.
Raises a ValueError if the chunk has no page content.
"""
# Chunk.page_content can be None (_process_batch clears it when omit_raw_text is set),
# whereas Document.page_content is annotated as a plain str, so None is rejected here.
if chunk.page_content is None:
raise ValueError("Cannot embed a chunk without page content")
return Document(page_content=chunk.page_content, record=chunk.record)

def _process_batch(self) -> None:
# Flush the current batch: delete stale entries, embed and index the buffered chunks, then reset the batch state.
# Deletions are issued first, before any new data for the batch is indexed.
for (namespace, stream), ids in self.ids_to_delete.items():
self.indexer.delete(ids, namespace, stream)

# One embedder call and one indexer call per (namespace, stream) group.
for (namespace, stream), documents in self.documents.items():
embeddings = self.embedder.embed_chunks(documents)
for i, document in enumerate(documents):
for (namespace, stream), chunks in self.chunks.items():
embeddings = self.embedder.embed_documents([self._convert_to_document(chunk) for chunk in chunks])
for i, document in enumerate(chunks):
# Embeddings are matched to their source item by position in the list.
document.embedding = embeddings[i]
# Raw text is cleared only after the embedding call above has consumed it.
if self.omit_raw_text:
document.page_content = None
self.indexer.index(documents, namespace, stream)
self.indexer.index(chunks, namespace, stream)

# Start a fresh, empty batch.
self._init_batch()

Expand All @@ -65,12 +73,12 @@ def write(self, configured_catalog: ConfiguredAirbyteCatalog, input_messages: It
self._process_batch()
yield message
elif message.type == Type.RECORD:
record_documents, record_id_to_delete = self.processor.process(message.record)
self.documents[(message.record.namespace, message.record.stream)].extend(record_documents)
record_chunks, record_id_to_delete = self.processor.process(message.record)
self.chunks[(message.record.namespace, message.record.stream)].extend(record_chunks)
if record_id_to_delete is not None:
self.ids_to_delete[(message.record.namespace, message.record.stream)].append(record_id_to_delete)
self.number_of_documents += len(record_documents)
if self.number_of_documents >= self.batch_size:
self.number_of_chunks += len(record_chunks)
if self.number_of_chunks >= self.batch_size:
self._process_batch()

self._process_batch()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
OpenAICompatibleEmbeddingConfigModel,
OpenAIEmbeddingConfigModel,
)
from airbyte_cdk.destinations.vector_db_based.document_processor import Chunk
from airbyte_cdk.destinations.vector_db_based.embedder import (
COHERE_VECTOR_SIZE,
OPEN_AI_VECTOR_SIZE,
AzureOpenAIEmbedder,
CohereEmbedder,
Document,
FakeEmbedder,
FromFieldEmbedder,
OpenAICompatibleEmbedder,
Expand Down Expand Up @@ -82,10 +82,10 @@ def test_embedder(embedder_class, args, dimensions):
mock_embedding_instance.embed_documents.return_value = [[0] * dimensions] * 2

chunks = [
Chunk(page_content="a", metadata={}, record=AirbyteRecordMessage(stream="mystream", data={}, emitted_at=0)),
Chunk(page_content="b", metadata={}, record=AirbyteRecordMessage(stream="mystream", data={}, emitted_at=0)),
Document(page_content="a", record=AirbyteRecordMessage(stream="mystream", data={}, emitted_at=0)),
Document(page_content="b", record=AirbyteRecordMessage(stream="mystream", data={}, emitted_at=0)),
]
assert embedder.embed_chunks(chunks) == mock_embedding_instance.embed_documents.return_value
assert embedder.embed_documents(chunks) == mock_embedding_instance.embed_documents.return_value
mock_embedding_instance.embed_documents.assert_called_with(["a", "b"])


Expand All @@ -102,12 +102,12 @@ def test_embedder(embedder_class, args, dimensions):
)
def test_from_field_embedder(field_name, dimensions, metadata, expected_embedding, expected_error):
embedder = FromFieldEmbedder(FromFieldEmbeddingConfigModel(mode="from_field", dimensions=dimensions, field_name=field_name))
chunks = [Chunk(page_content="a", metadata=metadata, record=AirbyteRecordMessage(stream="mystream", data=metadata, emitted_at=0))]
chunks = [Document(page_content="a", record=AirbyteRecordMessage(stream="mystream", data=metadata, emitted_at=0))]
if expected_error:
with pytest.raises(AirbyteTracedException):
embedder.embed_chunks(chunks)
embedder.embed_documents(chunks)
else:
assert embedder.embed_chunks(chunks) == [expected_embedding]
assert embedder.embed_documents(chunks) == [expected_embedding]


def test_openai_chunking():
Expand All @@ -119,7 +119,7 @@ def test_openai_chunking():
mock_embedding_instance.embed_documents.side_effect = lambda texts: [[0] * OPEN_AI_VECTOR_SIZE] * len(texts)

chunks = [
Chunk(page_content="a", metadata={}, record=AirbyteRecordMessage(stream="mystream", data={}, emitted_at=0)) for _ in range(1005)
Document(page_content="a", record=AirbyteRecordMessage(stream="mystream", data={}, emitted_at=0)) for _ in range(1005)
]
assert embedder.embed_chunks(chunks) == [[0] * OPEN_AI_VECTOR_SIZE] * 1005
assert embedder.embed_documents(chunks) == [[0] * OPEN_AI_VECTOR_SIZE] * 1005
mock_embedding_instance.embed_documents.assert_has_calls([call(["a"] * 1000), call(["a"] * 5)])
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ def generate_stream(name: str = "example_stream", namespace: Optional[str] = Non

def generate_mock_embedder():
# Build a MagicMock embedder whose embedding call yields one 1536-element zero vector per input item.
mock_embedder = MagicMock()
# NOTE: with a callable side_effect, unittest.mock returns the callable's result (unless it is DEFAULT),
# so the return_value assignments are effectively overridden by the side_effect lambdas.
mock_embedder.embed_chunks.return_value = [[0] * 1536] * (BATCH_SIZE + 5 + 5)
mock_embedder.embed_chunks.side_effect = lambda chunks: [[0] * 1536] * len(chunks)
mock_embedder.embed_documents.return_value = [[0] * 1536] * (BATCH_SIZE + 5 + 5)
mock_embedder.embed_documents.side_effect = lambda chunks: [[0] * 1536] * len(chunks)

return mock_embedder

Expand Down Expand Up @@ -88,7 +88,7 @@ def test_write(omit_raw_text: bool):
# 1 batch due to max batch size reached and 1 batch due to state message
assert mock_indexer.index.call_count == 2
assert mock_indexer.delete.call_count == 2
assert mock_embedder.embed_chunks.call_count == 2
assert mock_embedder.embed_documents.call_count == 2

if omit_raw_text:
for call_args in mock_indexer.index.call_args_list:
Expand All @@ -110,7 +110,7 @@ def test_write(omit_raw_text: bool):
# 1 batch due to end of message stream
assert mock_indexer.index.call_count == 3
assert mock_indexer.delete.call_count == 3
assert mock_embedder.embed_chunks.call_count == 3
assert mock_embedder.embed_documents.call_count == 3

mock_indexer.post_sync.assert_called()

Expand Down Expand Up @@ -169,4 +169,4 @@ def test_write_stream_namespace_split():
call(ANY, None, "example_stream2"),
]
)
assert mock_embedder.embed_chunks.call_count == 4
assert mock_embedder.embed_documents.call_count == 4

0 comments on commit 55d5345

Please sign in to comment.