From 55d5345bff4ef5987412e85c0767720025fa60e2 Mon Sep 17 00:00:00 2001 From: Joe Reuter Date: Wed, 13 Dec 2023 12:23:39 +0100 Subject: [PATCH] Vector DB CDK: Refactor to improve readability (#33255) Co-authored-by: flash1293 --- .../destinations/vector_db_based/embedder.py | 41 +++++++++++-------- .../destinations/vector_db_based/writer.py | 40 ++++++++++-------- .../vector_db_based/embedder_test.py | 18 ++++---- .../vector_db_based/writer_test.py | 10 ++--- 4 files changed, 62 insertions(+), 47 deletions(-) diff --git a/airbyte-cdk/python/airbyte_cdk/destinations/vector_db_based/embedder.py b/airbyte-cdk/python/airbyte_cdk/destinations/vector_db_based/embedder.py index a1f89b05648b6..7fb880fadaaea 100644 --- a/airbyte-cdk/python/airbyte_cdk/destinations/vector_db_based/embedder.py +++ b/airbyte-cdk/python/airbyte_cdk/destinations/vector_db_based/embedder.py @@ -4,6 +4,7 @@ import os from abc import ABC, abstractmethod +from dataclasses import dataclass from typing import List, Optional, Union, cast from airbyte_cdk.destinations.vector_db_based.config import ( @@ -15,8 +16,8 @@ OpenAIEmbeddingConfigModel, ProcessingConfigModel, ) -from airbyte_cdk.destinations.vector_db_based.document_processor import Chunk from airbyte_cdk.destinations.vector_db_based.utils import create_chunks, format_exception +from airbyte_cdk.models import AirbyteRecordMessage from airbyte_cdk.utils.traced_exception import AirbyteTracedException, FailureType from langchain.embeddings.cohere import CohereEmbeddings from langchain.embeddings.fake import FakeEmbeddings @@ -24,6 +25,12 @@ from langchain.embeddings.openai import OpenAIEmbeddings +@dataclass +class Document: + page_content: str + record: AirbyteRecordMessage + + class Embedder(ABC): """ Embedder is an abstract class that defines the interface for embedding text. @@ -41,7 +48,7 @@ def check(self) -> Optional[str]: pass @abstractmethod - def embed_chunks(self, chunks: List[Chunk]) -> List[Optional[List[float]]]: + def embed_documents(self, documents: List[Document]) -> List[Optional[List[float]]]: """ Embed the text of each chunk and return the resulting embedding vectors. If a chunk cannot be embedded or is configured to not be embedded, return None for that chunk. @@ -72,7 +79,7 @@ def check(self) -> Optional[str]: return format_exception(e) return None - def embed_chunks(self, chunks: List[Chunk]) -> List[Optional[List[float]]]: + def embed_documents(self, documents: List[Document]) -> List[Optional[List[float]]]: """ Embed the text of each chunk and return the resulting embedding vectors. @@ -80,9 +87,9 @@ def embed_chunks(self, chunks: List[Chunk]) -> List[Optional[List[float]]]: It's still possible to run into the rate limit between each embed call because the available token budget hasn't recovered between the calls, but the built-in retry mechanism of the OpenAI client handles that. """ - # Each chunk can hold at most self.chunk_size tokens, so tokens-per-minute by maximum tokens per chunk is the number of chunks that can be embedded at once without exhausting the limit in a single request + # Each chunk can hold at most self.chunk_size tokens, so tokens-per-minute by maximum tokens per chunk is the number of documents that can be embedded at once without exhausting the limit in a single request embedding_batch_size = OPEN_AI_TOKEN_LIMIT // self.chunk_size - batches = create_chunks(chunks, batch_size=embedding_batch_size) + batches = create_chunks(documents, batch_size=embedding_batch_size) embeddings: List[Optional[List[float]]] = [] for batch in batches: embeddings.extend(self.embeddings.embed_documents([chunk.page_content for chunk in batch])) @@ -121,8 +128,8 @@ def check(self) -> Optional[str]: return format_exception(e) return None - def embed_chunks(self, chunks: List[Chunk]) -> List[Optional[List[float]]]: - return cast(List[Optional[List[float]]], self.embeddings.embed_documents([chunk.page_content or "" for chunk in chunks])) + def embed_documents(self, documents: List[Document]) -> List[Optional[List[float]]]: + return cast(List[Optional[List[float]]], self.embeddings.embed_documents([document.page_content for document in documents])) @property def embedding_dimensions(self) -> int: @@ -142,8 +149,8 @@ def check(self) -> Optional[str]: return format_exception(e) return None - def embed_chunks(self, chunks: List[Chunk]) -> List[Optional[List[float]]]: - return cast(List[Optional[List[float]]], self.embeddings.embed_documents([chunk.page_content or "" for chunk in chunks])) + def embed_documents(self, documents: List[Document]) -> List[Optional[List[float]]]: + return cast(List[Optional[List[float]]], self.embeddings.embed_documents([document.page_content for document in documents])) @property def embedding_dimensions(self) -> int: @@ -173,8 +180,8 @@ def check(self) -> Optional[str]: return format_exception(e) return None - def embed_chunks(self, chunks: List[Chunk]) -> List[Optional[List[float]]]: - return cast(List[Optional[List[float]]], self.embeddings.embed_documents([chunk.page_content or "" for chunk in chunks])) + def embed_documents(self, documents: List[Document]) -> List[Optional[List[float]]]: + return cast(List[Optional[List[float]]], self.embeddings.embed_documents([document.page_content for document in documents])) @property def embedding_dimensions(self) -> int: @@ -190,32 +197,32 @@ def __init__(self, config: FromFieldEmbeddingConfigModel): def check(self) -> Optional[str]: return None - def embed_chunks(self, chunks: List[Chunk]) -> List[Optional[List[float]]]: + def embed_documents(self, documents: List[Document]) -> List[Optional[List[float]]]: """ From each chunk, pull the embedding from the field specified in the config. Check that the field exists, is a list of numbers and is the correct size. If not, raise an AirbyteTracedException explaining the problem. """ embeddings: List[Optional[List[float]]] = [] - for chunk in chunks: - data = chunk.record.data + for document in documents: + data = document.record.data if self.config.field_name not in data: raise AirbyteTracedException( internal_message="Embedding vector field not found", failure_type=FailureType.config_error, - message=f"Record {str(data)[:250]}... in stream {chunk.record.stream} does not contain embedding vector field {self.config.field_name}. Please check your embedding configuration, the embedding vector field has to be set correctly on every record.", + message=f"Record {str(data)[:250]}... in stream {document.record.stream} does not contain embedding vector field {self.config.field_name}. Please check your embedding configuration, the embedding vector field has to be set correctly on every record.", ) field = data[self.config.field_name] if not isinstance(field, list) or not all(isinstance(x, (int, float)) for x in field): raise AirbyteTracedException( internal_message="Embedding vector field not a list of numbers", failure_type=FailureType.config_error, - message=f"Record {str(data)[:250]}... in stream {chunk.record.stream} does contain embedding vector field {self.config.field_name}, but it is not a list of numbers. Please check your embedding configuration, the embedding vector field has to be a list of numbers of length {self.config.dimensions} on every record.", + message=f"Record {str(data)[:250]}... in stream {document.record.stream} does contain embedding vector field {self.config.field_name}, but it is not a list of numbers. Please check your embedding configuration, the embedding vector field has to be a list of numbers of length {self.config.dimensions} on every record.", ) if len(field) != self.config.dimensions: raise AirbyteTracedException( internal_message="Embedding vector field has wrong length", failure_type=FailureType.config_error, - message=f"Record {str(data)[:250]}... in stream {chunk.record.stream} does contain embedding vector field {self.config.field_name}, but it has length {len(field)} instead of the configured {self.config.dimensions}. Please check your embedding configuration, the embedding vector field has to be a list of numbers of length {self.config.dimensions} on every record.", + message=f"Record {str(data)[:250]}... in stream {document.record.stream} does contain embedding vector field {self.config.field_name}, but it has length {len(field)} instead of the configured {self.config.dimensions}. Please check your embedding configuration, the embedding vector field has to be a list of numbers of length {self.config.dimensions} on every record.", ) embeddings.append(field) diff --git a/airbyte-cdk/python/airbyte_cdk/destinations/vector_db_based/writer.py b/airbyte-cdk/python/airbyte_cdk/destinations/vector_db_based/writer.py index e8d58abb4ad6a..0f764c366b54b 100644 --- a/airbyte-cdk/python/airbyte_cdk/destinations/vector_db_based/writer.py +++ b/airbyte-cdk/python/airbyte_cdk/destinations/vector_db_based/writer.py @@ -8,7 +8,7 @@ from airbyte_cdk.destinations.vector_db_based.config import ProcessingConfigModel from airbyte_cdk.destinations.vector_db_based.document_processor import Chunk, DocumentProcessor -from airbyte_cdk.destinations.vector_db_based.embedder import Embedder +from airbyte_cdk.destinations.vector_db_based.embedder import Document, Embedder from airbyte_cdk.destinations.vector_db_based.indexer import Indexer from airbyte_cdk.models import AirbyteMessage, ConfiguredAirbyteCatalog, Type @@ -16,14 +16,14 @@ class Writer: """ The Writer class is orchestrating the document processor, the embedder and the indexer: - * Incoming records are passed through the document processor to generate documents - * One the configured batch size is reached, the documents are passed to the embedder to generate embeddings - * The embedder embeds the documents - * The indexer deletes old documents by the associated record id before indexing the new ones + * Incoming records are passed through the document processor to generate chunks + * One the configured batch size is reached, the chunks are passed to the embedder to generate embeddings + * The embedder embeds the chunks + * The indexer deletes old chunks by the associated record id before indexing the new ones The destination connector is responsible to create a writer instance and pass the input messages iterable to the write method. The batch size can be configured by the destination connector to give the freedom of either letting the user configure it or hardcoding it to a sensible value depending on the destination. - The omit_raw_text parameter can be used to omit the raw text from the documents. This can be useful if the raw text is very large and not needed for the destination. + The omit_raw_text parameter can be used to omit the raw text from the chunks. This can be useful if the raw text is very large and not needed for the destination. """ def __init__( @@ -37,21 +37,29 @@ def __init__( self._init_batch() def _init_batch(self) -> None: - self.documents: Dict[Tuple[str, str], List[Chunk]] = defaultdict(list) + self.chunks: Dict[Tuple[str, str], List[Chunk]] = defaultdict(list) self.ids_to_delete: Dict[Tuple[str, str], List[str]] = defaultdict(list) - self.number_of_documents = 0 + self.number_of_chunks = 0 + + def _convert_to_document(self, chunk: Chunk) -> Document: + """ + Convert a chunk to a document for the embedder. + """ + if chunk.page_content is None: + raise ValueError("Cannot embed a chunk without page content") + return Document(page_content=chunk.page_content, record=chunk.record) def _process_batch(self) -> None: for (namespace, stream), ids in self.ids_to_delete.items(): self.indexer.delete(ids, namespace, stream) - for (namespace, stream), documents in self.documents.items(): - embeddings = self.embedder.embed_chunks(documents) - for i, document in enumerate(documents): + for (namespace, stream), chunks in self.chunks.items(): + embeddings = self.embedder.embed_documents([self._convert_to_document(chunk) for chunk in chunks]) + for i, document in enumerate(chunks): document.embedding = embeddings[i] if self.omit_raw_text: document.page_content = None - self.indexer.index(documents, namespace, stream) + self.indexer.index(chunks, namespace, stream) self._init_batch() @@ -65,12 +73,12 @@ def write(self, configured_catalog: ConfiguredAirbyteCatalog, input_messages: It self._process_batch() yield message elif message.type == Type.RECORD: - record_documents, record_id_to_delete = self.processor.process(message.record) - self.documents[(message.record.namespace, message.record.stream)].extend(record_documents) + record_chunks, record_id_to_delete = self.processor.process(message.record) + self.chunks[(message.record.namespace, message.record.stream)].extend(record_chunks) if record_id_to_delete is not None: self.ids_to_delete[(message.record.namespace, message.record.stream)].append(record_id_to_delete) - self.number_of_documents += len(record_documents) - if self.number_of_documents >= self.batch_size: + self.number_of_chunks += len(record_chunks) + if self.number_of_chunks >= self.batch_size: self._process_batch() self._process_batch() diff --git a/airbyte-cdk/python/unit_tests/destinations/vector_db_based/embedder_test.py b/airbyte-cdk/python/unit_tests/destinations/vector_db_based/embedder_test.py index 088b4f85fb8ea..3cf8e4114e5bf 100644 --- a/airbyte-cdk/python/unit_tests/destinations/vector_db_based/embedder_test.py +++ b/airbyte-cdk/python/unit_tests/destinations/vector_db_based/embedder_test.py @@ -13,12 +13,12 @@ OpenAICompatibleEmbeddingConfigModel, OpenAIEmbeddingConfigModel, ) -from airbyte_cdk.destinations.vector_db_based.document_processor import Chunk from airbyte_cdk.destinations.vector_db_based.embedder import ( COHERE_VECTOR_SIZE, OPEN_AI_VECTOR_SIZE, AzureOpenAIEmbedder, CohereEmbedder, + Document, FakeEmbedder, FromFieldEmbedder, OpenAICompatibleEmbedder, @@ -82,10 +82,10 @@ def test_embedder(embedder_class, args, dimensions): mock_embedding_instance.embed_documents.return_value = [[0] * dimensions] * 2 chunks = [ - Chunk(page_content="a", metadata={}, record=AirbyteRecordMessage(stream="mystream", data={}, emitted_at=0)), - Chunk(page_content="b", metadata={}, record=AirbyteRecordMessage(stream="mystream", data={}, emitted_at=0)), + Document(page_content="a", record=AirbyteRecordMessage(stream="mystream", data={}, emitted_at=0)), + Document(page_content="b", record=AirbyteRecordMessage(stream="mystream", data={}, emitted_at=0)), ] - assert embedder.embed_chunks(chunks) == mock_embedding_instance.embed_documents.return_value + assert embedder.embed_documents(chunks) == mock_embedding_instance.embed_documents.return_value mock_embedding_instance.embed_documents.assert_called_with(["a", "b"]) @@ -102,12 +102,12 @@ def test_embedder(embedder_class, args, dimensions): ) def test_from_field_embedder(field_name, dimensions, metadata, expected_embedding, expected_error): embedder = FromFieldEmbedder(FromFieldEmbeddingConfigModel(mode="from_field", dimensions=dimensions, field_name=field_name)) - chunks = [Chunk(page_content="a", metadata=metadata, record=AirbyteRecordMessage(stream="mystream", data=metadata, emitted_at=0))] + chunks = [Document(page_content="a", record=AirbyteRecordMessage(stream="mystream", data=metadata, emitted_at=0))] if expected_error: with pytest.raises(AirbyteTracedException): - embedder.embed_chunks(chunks) + embedder.embed_documents(chunks) else: - assert embedder.embed_chunks(chunks) == [expected_embedding] + assert embedder.embed_documents(chunks) == [expected_embedding] def test_openai_chunking(): @@ -119,7 +119,7 @@ def test_openai_chunking(): mock_embedding_instance.embed_documents.side_effect = lambda texts: [[0] * OPEN_AI_VECTOR_SIZE] * len(texts) chunks = [ - Chunk(page_content="a", metadata={}, record=AirbyteRecordMessage(stream="mystream", data={}, emitted_at=0)) for _ in range(1005) + Document(page_content="a", record=AirbyteRecordMessage(stream="mystream", data={}, emitted_at=0)) for _ in range(1005) ] - assert embedder.embed_chunks(chunks) == [[0] * OPEN_AI_VECTOR_SIZE] * 1005 + assert embedder.embed_documents(chunks) == [[0] * OPEN_AI_VECTOR_SIZE] * 1005 mock_embedding_instance.embed_documents.assert_has_calls([call(["a"] * 1000), call(["a"] * 5)]) diff --git a/airbyte-cdk/python/unit_tests/destinations/vector_db_based/writer_test.py b/airbyte-cdk/python/unit_tests/destinations/vector_db_based/writer_test.py index dff570d6e698e..c906d0f3e9b57 100644 --- a/airbyte-cdk/python/unit_tests/destinations/vector_db_based/writer_test.py +++ b/airbyte-cdk/python/unit_tests/destinations/vector_db_based/writer_test.py @@ -48,8 +48,8 @@ def generate_stream(name: str = "example_stream", namespace: Optional[str] = Non def generate_mock_embedder(): mock_embedder = MagicMock() - mock_embedder.embed_chunks.return_value = [[0] * 1536] * (BATCH_SIZE + 5 + 5) - mock_embedder.embed_chunks.side_effect = lambda chunks: [[0] * 1536] * len(chunks) + mock_embedder.embed_documents.return_value = [[0] * 1536] * (BATCH_SIZE + 5 + 5) + mock_embedder.embed_documents.side_effect = lambda chunks: [[0] * 1536] * len(chunks) return mock_embedder @@ -88,7 +88,7 @@ def test_write(omit_raw_text: bool): # 1 batches due to max batch size reached and 1 batch due to state message assert mock_indexer.index.call_count == 2 assert mock_indexer.delete.call_count == 2 - assert mock_embedder.embed_chunks.call_count == 2 + assert mock_embedder.embed_documents.call_count == 2 if omit_raw_text: for call_args in mock_indexer.index.call_args_list: @@ -110,7 +110,7 @@ def test_write(omit_raw_text: bool): # 1 batch due to end of message stream assert mock_indexer.index.call_count == 3 assert mock_indexer.delete.call_count == 3 - assert mock_embedder.embed_chunks.call_count == 3 + assert mock_embedder.embed_documents.call_count == 3 mock_indexer.post_sync.assert_called() @@ -169,4 +169,4 @@ def test_write_stream_namespace_split(): call(ANY, None, "example_stream2"), ] ) - assert mock_embedder.embed_chunks.call_count == 4 + assert mock_embedder.embed_documents.call_count == 4