diff --git a/airbyte-lib/airbyte_lib/_file_writers/base.py b/airbyte-lib/airbyte_lib/_file_writers/base.py
index 3f16953f12f54..ece3bb5512eff 100644
--- a/airbyte-lib/airbyte_lib/_file_writers/base.py
+++ b/airbyte-lib/airbyte_lib/_file_writers/base.py
@@ -51,7 +51,7 @@ def _write_batch(
         self,
         stream_name: str,
         batch_id: str,
-        record_batch: pa.Table | pa.RecordBatch,
+        record_batch: pa.Table,
     ) -> FileWriterBatchHandle:
         """Process a record batch.
 
@@ -64,7 +64,7 @@ def write_batch(
         self,
         stream_name: str,
         batch_id: str,
-        record_batch: pa.Table | pa.RecordBatch,
+        record_batch: pa.Table,
     ) -> FileWriterBatchHandle:
         """Write a batch of records to the cache.
 
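Because `write_batch` now accepts only `pa.Table`, any caller still holding a `pa.RecordBatch` must convert before calling it. A minimal sketch of that conversion (column names here are illustrative, not taken from this diff):

import pyarrow as pa

batch = pa.RecordBatch.from_pydict({"column1": ["a", "b"], "column2": [1, 2]})
table = pa.Table.from_batches([batch])  # wraps the single batch without copying data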
diff --git a/airbyte-lib/airbyte_lib/_file_writers/parquet.py b/airbyte-lib/airbyte_lib/_file_writers/parquet.py
index aeb2113f2a285..fbe02776eb4e0 100644
--- a/airbyte-lib/airbyte_lib/_file_writers/parquet.py
+++ b/airbyte-lib/airbyte_lib/_file_writers/parquet.py
@@ -11,7 +11,12 @@
 from overrides import overrides
 from pyarrow import parquet
 
-from .base import FileWriterBase, FileWriterBatchHandle, FileWriterConfigBase
+from airbyte_lib import exceptions as exc
+from airbyte_lib._file_writers.base import (
+    FileWriterBase,
+    FileWriterBatchHandle,
+    FileWriterConfigBase,
+)
 
 
 class ParquetWriterConfig(FileWriterConfigBase):
@@ -37,12 +42,24 @@ def get_new_cache_file_path(
         target_dir.mkdir(parents=True, exist_ok=True)
         return target_dir / f"{stream_name}_{batch_id}.parquet"
 
+    def _get_missing_columns(
+        self,
+        stream_name: str,
+        record_batch: pa.Table,
+    ) -> list[str]:
+        """Return a list of columns that are missing in the batch."""
+        if not self._catalog_manager:
+            raise exc.AirbyteLibInternalError(message="Catalog manager should exist but does not.")
+        stream = self._catalog_manager.get_stream_config(stream_name)
+        stream_property_names = stream.stream.json_schema["properties"].keys()
+        return [col for col in stream_property_names if col not in record_batch.schema.names]
+
     @overrides
     def _write_batch(
         self,
         stream_name: str,
         batch_id: str,
-        record_batch: pa.Table | pa.RecordBatch,
+        record_batch: pa.Table,
     ) -> FileWriterBatchHandle:
         """Process a record batch.
 
@@ -51,8 +68,15 @@ def _write_batch(
         _ = batch_id  # unused
         output_file_path = self.get_new_cache_file_path(stream_name)
 
-        with parquet.ParquetWriter(output_file_path, record_batch.schema) as writer:
-            writer.write_table(cast(pa.Table, record_batch))
+        missing_columns = self._get_missing_columns(stream_name, record_batch)
+        if missing_columns:
+            # We need to append columns with the missing column name(s) and a null type
+            null_array = cast(pa.Array, pa.array([None] * len(record_batch), type=pa.null()))
+            for col in missing_columns:
+                record_batch = record_batch.append_column(col, null_array)
+
+        with parquet.ParquetWriter(output_file_path, schema=record_batch.schema) as writer:
+            writer.write_table(record_batch)
 
         batch_handle = FileWriterBatchHandle()
         batch_handle.files.append(output_file_path)
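For reference, the null-padding behavior that `_write_batch` implements above can be reproduced in a self-contained sketch (stream, file, and column names are made up for illustration):

import pyarrow as pa
from pyarrow import parquet

table = pa.table({"column1": ["a", "b"], "column2": [1, 2]})
expected_columns = ["column1", "column2", "empty_column"]  # per the stream's JSON schema

missing = [col for col in expected_columns if col not in table.schema.names]
null_array = pa.array([None] * len(table), type=pa.null())
for col in missing:
    table = table.append_column(col, null_array)  # adds an all-null column

parquet.write_table(table, "stream1_batch1.parquet")
print(table.schema)  # empty_column now present with type null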
diff --git a/airbyte-lib/airbyte_lib/_processors.py b/airbyte-lib/airbyte_lib/_processors.py
index 45f18ba76b3ab..a483b13459499 100644
--- a/airbyte-lib/airbyte_lib/_processors.py
+++ b/airbyte-lib/airbyte_lib/_processors.py
@@ -27,9 +27,11 @@
     AirbyteStateType,
     AirbyteStreamState,
     ConfiguredAirbyteCatalog,
+    ConfiguredAirbyteStream,
     Type,
 )
 
+from airbyte_lib import exceptions as exc
 from airbyte_lib._util import protocol_util  # Internal utility functions
 from airbyte_lib.progress import progress
 
@@ -37,6 +39,7 @@
 if TYPE_CHECKING:
     from collections.abc import Generator, Iterable, Iterator
 
+    from airbyte_lib.caches._catalog_manager import CatalogManager
     from airbyte_lib.config import CacheConfigBase
 
 
@@ -60,6 +63,8 @@ class RecordProcessor(abc.ABC):
     def __init__(
         self,
         config: CacheConfigBase | dict | None,
+        *,
+        catalog_manager: CatalogManager | None = None,
     ) -> None:
         if isinstance(config, dict):
             config = self.config_class(**config)
@@ -72,8 +77,6 @@ def __init__(
             )
             raise TypeError(err_msg)
 
-        self.source_catalog: ConfiguredAirbyteCatalog | None = None
-
         self._pending_batches: dict[str, dict[str, Any]] = defaultdict(lambda: {}, {})
         self._finalized_batches: dict[str, dict[str, Any]] = defaultdict(lambda: {}, {})
 
@@ -83,22 +86,25 @@ def __init__(
             list[AirbyteStateMessage],
         ] = defaultdict(list, {})
 
+        self._catalog_manager: CatalogManager | None = catalog_manager
         self._setup()
 
     def register_source(
         self,
         source_name: str,
         incoming_source_catalog: ConfiguredAirbyteCatalog,
+        stream_names: set[str],
     ) -> None:
-        """Register the source name and catalog.
-
-        For now, only one source at a time is supported.
-        If this method is called multiple times, the last call will overwrite the previous one.
-
-        TODO: Expand this to handle multiple sources.
-        """
-        _ = source_name
-        self.source_catalog = incoming_source_catalog
+        """Register the source name and catalog."""
+        if not self._catalog_manager:
+            raise exc.AirbyteLibInternalError(
+                message="Catalog manager should exist but does not.",
+            )
+        self._catalog_manager.register_source(
+            source_name,
+            incoming_source_catalog=incoming_source_catalog,
+            incoming_stream_names=stream_names,
+        )
 
     @property
     def _streams_with_data(self) -> set[str]:
@@ -215,7 +221,7 @@ def _write_batch(
         self,
         stream_name: str,
         batch_id: str,
-        record_batch: pa.Table | pa.RecordBatch,
+        record_batch: pa.Table,
     ) -> BatchHandle:
         """Process a single batch.
 
@@ -319,3 +325,24 @@ def _teardown(self) -> None:
     def __del__(self) -> None:
         """Teardown temporary resources when instance is unloaded from memory."""
         self._teardown()
+
+    @final
+    def _get_stream_config(
+        self,
+        stream_name: str,
+    ) -> ConfiguredAirbyteStream:
+        """Return the stream configuration for the given stream."""
+        if not self._catalog_manager:
+            raise exc.AirbyteLibInternalError(
+                message="Catalog manager should exist but does not.",
+            )
+
+        return self._catalog_manager.get_stream_config(stream_name)
+
+    @final
+    def _get_stream_json_schema(
+        self,
+        stream_name: str,
+    ) -> dict[str, Any]:
+        """Return the JSON schema for the given stream."""
+        return self._get_stream_config(stream_name).stream.json_schema
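For readers unfamiliar with the protocol models used here, this is roughly the shape of the `ConfiguredAirbyteCatalog` that `register_source` receives — a minimal hand-built example (stream name and schema invented for illustration):

from airbyte_protocol.models import (
    AirbyteStream,
    ConfiguredAirbyteCatalog,
    ConfiguredAirbyteStream,
    DestinationSyncMode,
    SyncMode,
)

catalog = ConfiguredAirbyteCatalog(
    streams=[
        ConfiguredAirbyteStream(
            stream=AirbyteStream(
                name="stream1",
                json_schema={
                    "type": "object",
                    "properties": {"column1": {"type": "string"}},
                },
                supported_sync_modes=[SyncMode.full_refresh],
            ),
            sync_mode=SyncMode.full_refresh,
            destination_sync_mode=DestinationSyncMode.overwrite,
        ),
    ],
)
print(catalog.streams[0].stream.json_schema["properties"].keys())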
diff --git a/airbyte-lib/airbyte_lib/caches/_catalog_manager.py b/airbyte-lib/airbyte_lib/caches/_catalog_manager.py
index 1a0b322f1a290..978c268af87f3 100644
--- a/airbyte-lib/airbyte_lib/caches/_catalog_manager.py
+++ b/airbyte-lib/airbyte_lib/caches/_catalog_manager.py
@@ -46,8 +46,23 @@ def __init__(
     ) -> None:
         self._engine: Engine = engine
         self._table_name_resolver = table_name_resolver
-        self.source_catalog: ConfiguredAirbyteCatalog | None = None
+        self._source_catalog: ConfiguredAirbyteCatalog | None = None
         self._load_catalog_from_internal_table()
+        assert self._source_catalog is not None
+
+    @property
+    def source_catalog(self) -> ConfiguredAirbyteCatalog:
+        """Return the source catalog.
+
+        Raises:
+            AirbyteLibInternalError: If the source catalog is not set.
+        """
+        if not self._source_catalog:
+            raise exc.AirbyteLibInternalError(
+                message="Source catalog should be initialized but is not.",
+            )
+
+        return self._source_catalog
 
     def _ensure_internal_tables(self) -> None:
         engine = self._engine
@@ -57,34 +72,70 @@ def register_source(
         self,
         source_name: str,
         incoming_source_catalog: ConfiguredAirbyteCatalog,
+        incoming_stream_names: set[str],
     ) -> None:
-        if not self.source_catalog:
-            self.source_catalog = incoming_source_catalog
-        else:
-            # merge in the new streams, keyed by name
-            new_streams = {stream.stream.name: stream for stream in incoming_source_catalog.streams}
-            for stream in self.source_catalog.streams:
-                if stream.stream.name not in new_streams:
-                    new_streams[stream.stream.name] = stream
-            self.source_catalog = ConfiguredAirbyteCatalog(
-                streams=list(new_streams.values()),
+        """Register a source and its streams in the cache."""
+        self._update_catalog(
+            incoming_source_catalog=incoming_source_catalog,
+            incoming_stream_names=incoming_stream_names,
+        )
+        self._save_catalog_to_internal_table(
+            source_name=source_name,
+            incoming_source_catalog=incoming_source_catalog,
+            incoming_stream_names=incoming_stream_names,
+        )
+
+    def _update_catalog(
+        self,
+        incoming_source_catalog: ConfiguredAirbyteCatalog,
+        incoming_stream_names: set[str],
+    ) -> None:
+        if not self._source_catalog:
+            self._source_catalog = ConfiguredAirbyteCatalog(
+                streams=[
+                    stream
+                    for stream in incoming_source_catalog.streams
+                    if stream.stream.name in incoming_stream_names
+                ],
             )
+            assert len(self._source_catalog.streams) == len(incoming_stream_names)
+            return
+
+        # Keep existing streams untouched if they are not in the incoming set
+        unchanged_streams: list[ConfiguredAirbyteStream] = [
+            stream
+            for stream in self._source_catalog.streams
+            if stream.stream.name not in incoming_stream_names
+        ]
+        new_streams: list[ConfiguredAirbyteStream] = [
+            stream
+            for stream in incoming_source_catalog.streams
+            if stream.stream.name in incoming_stream_names
+        ]
+        self._source_catalog = ConfiguredAirbyteCatalog(streams=unchanged_streams + new_streams)
+
+    def _save_catalog_to_internal_table(
+        self,
+        source_name: str,
+        incoming_source_catalog: ConfiguredAirbyteCatalog,
+        incoming_stream_names: set[str],
+    ) -> None:
         self._ensure_internal_tables()
         engine = self._engine
         with Session(engine) as session:
-            # delete all existing streams from the db
-            session.query(CachedStream).filter(
-                CachedStream.table_name.in_(
-                    [
-                        self._table_name_resolver(stream.stream.name)
-                        for stream in self.source_catalog.streams
-                    ]
-                )
-            ).delete()
+            # Delete and replace existing stream entries from the catalog cache
+            table_name_entries_to_delete = [
+                self._table_name_resolver(incoming_stream_name)
+                for incoming_stream_name in incoming_stream_names
+            ]
+            result = (
+                session.query(CachedStream)
+                .filter(CachedStream.table_name.in_(table_name_entries_to_delete))
+                .delete()
+            )
+            _ = result
             session.commit()
-            # add the new ones
-            streams = [
+            insert_streams = [
                 CachedStream(
                     source_name=source_name,
                     stream_name=stream.stream.name,
@@ -93,8 +144,7 @@ def register_source(
                 )
                 for stream in incoming_source_catalog.streams
             ]
-            session.add_all(streams)
-
+            session.add_all(insert_streams)
             session.commit()
 
     def get_stream_config(
@@ -113,6 +163,11 @@ def get_stream_config(
         if not matching_streams:
             raise exc.AirbyteStreamNotFoundError(
                 stream_name=stream_name,
+                context={
+                    "available_streams": [
+                        stream.stream.name for stream in self.source_catalog.streams
+                    ],
+                },
             )
 
         if len(matching_streams) > 1:
@@ -133,10 +188,13 @@ def _load_catalog_from_internal_table(self) -> None:
             streams: list[CachedStream] = session.query(CachedStream).all()
             if not streams:
                 # no streams means the cache is pristine
+                if not self._source_catalog:
+                    self._source_catalog = ConfiguredAirbyteCatalog(streams=[])
+
                 return
 
             # load the catalog
-            self.source_catalog = ConfiguredAirbyteCatalog(
+            self._source_catalog = ConfiguredAirbyteCatalog(
                 streams=[
                     ConfiguredAirbyteStream(
                         stream=AirbyteStream(
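The replace-or-keep rule in `_update_catalog` boils down to: incoming streams overwrite same-named entries, everything else is preserved. A behavioral sketch with plain dictionaries (not library code; names invented):

existing = {"users": {"version": 1}, "orders": {"version": 1}}
incoming = {"orders": {"version": 2}, "events": {"version": 1}}
incoming_names = {"orders", "events"}

# Keep existing entries that are not being re-registered...
unchanged = {name: cfg for name, cfg in existing.items() if name not in incoming_names}
# ...and take the incoming definition for everything in the incoming set.
merged = {**unchanged, **{name: cfg for name, cfg in incoming.items() if name in incoming_names}}
assert merged == {"users": {"version": 1}, "orders": {"version": 2}, "events": {"version": 1}}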
diff --git a/airbyte-lib/airbyte_lib/caches/base.py b/airbyte-lib/airbyte_lib/caches/base.py
index 3d40ecf69fd1d..d58af67e55ab4 100644
--- a/airbyte-lib/airbyte_lib/caches/base.py
+++ b/airbyte-lib/airbyte_lib/caches/base.py
@@ -7,7 +7,7 @@
 import enum
 from contextlib import contextmanager
 from functools import cached_property
-from typing import TYPE_CHECKING, Any, cast, final
+from typing import TYPE_CHECKING, cast, final
 
 import pandas as pd
 import pyarrow as pa
@@ -38,7 +38,6 @@
     from airbyte_protocol.models import (
         ConfiguredAirbyteCatalog,
-        ConfiguredAirbyteStream,
     )
 
     from airbyte_lib.datasets._base import DatasetBase
@@ -113,18 +112,19 @@ def __init__(
         self,
         config: SQLCacheConfigBase | None = None,
         file_writer: FileWriterBase | None = None,
-        **kwargs: dict[str, Any],  # Added for future proofing purposes.
     ) -> None:
         self.config: SQLCacheConfigBase
         self._engine: Engine | None = None
         self._connection_to_reuse: Connection | None = None
-        super().__init__(config, **kwargs)
+        super().__init__(config)
         self._ensure_schema_exists()
-        self._catalog_manager: CatalogManager = CatalogManager(
-            self.get_sql_engine(), lambda stream_name: self.get_sql_table_name(stream_name)
+        self._catalog_manager = CatalogManager(
+            engine=self.get_sql_engine(),
+            table_name_resolver=lambda stream_name: self.get_sql_table_name(stream_name),
+        )
+        self.file_writer = file_writer or self.file_writer_class(
+            config, catalog_manager=self._catalog_manager
         )
-
-        self.file_writer = file_writer or self.file_writer_class(config)
         self.type_converter = self.type_converter_class()
 
     def __getitem__(self, stream: str) -> DatasetBase:
@@ -447,28 +447,12 @@ def _get_sql_column_definitions(
         # columns["_airbyte_loaded_at"] = sqlalchemy.TIMESTAMP()
         return columns
 
-    @final
-    def _get_stream_config(
-        self,
-        stream_name: str,
-    ) -> ConfiguredAirbyteStream:
-        """Return the column definitions for the given stream."""
-        return self._catalog_manager.get_stream_config(stream_name)
-
-    @final
-    def _get_stream_json_schema(
-        self,
-        stream_name: str,
-    ) -> dict[str, Any]:
-        """Return the column definitions for the given stream."""
-        return self._get_stream_config(stream_name).stream.json_schema
-
     @overrides
     def _write_batch(
         self,
         stream_name: str,
         batch_id: str,
-        record_batch: pa.Table | pa.RecordBatch,
+        record_batch: pa.Table,
     ) -> FileWriterBatchHandle:
         """Process a record batch.
 
@@ -756,15 +740,27 @@ def register_source(
         self,
         source_name: str,
         incoming_source_catalog: ConfiguredAirbyteCatalog,
+        stream_names: set[str],
     ) -> None:
+        """Register the source with the cache.
+
+        We use stream_names to determine which streams will receive data, and
+        we only register the stream if it is expected to receive data.
+
+        This method is called by the source when it is initialized.
+        """
         self._ensure_schema_exists()
-        self._catalog_manager.register_source(source_name, incoming_source_catalog)
+        super().register_source(
+            source_name,
+            incoming_source_catalog,
+            stream_names=stream_names,
+        )
 
     @property
     @overrides
     def _streams_with_data(self) -> set[str]:
         """Return a list of known streams."""
-        if not self._catalog_manager.source_catalog:
+        if not self._catalog_manager:
             raise exc.AirbyteLibInternalError(
                 message="Cannot get streams with data without a catalog.",
             )
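`CatalogManager` takes a `table_name_resolver` callback so it can persist table names without knowing the cache's naming scheme. A self-contained illustration of that pattern (the prefix is hypothetical, not from this diff):

from collections.abc import Callable

def make_table_name_resolver(prefix: str) -> Callable[[str], str]:
    """Map a stream name to the table name a cache would use for it."""
    return lambda stream_name: f"{prefix}{stream_name.lower()}"

resolver = make_table_name_resolver("airbyte_raw_")  # hypothetical prefix
assert resolver("Stream1") == "airbyte_raw_stream1"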
diff --git a/airbyte-lib/airbyte_lib/source.py b/airbyte-lib/airbyte_lib/source.py
index 274c0c2652618..3cfaaf649b11d 100644
--- a/airbyte-lib/airbyte_lib/source.py
+++ b/airbyte-lib/airbyte_lib/source.py
@@ -110,6 +110,16 @@ def set_streams(self, streams: list[str]) -> None:
             )
         self._selected_stream_names = streams
 
+    def get_selected_streams(self) -> list[str]:
+        """Get the selected streams.
+
+        If no streams are selected, return all available streams.
+        """
+        if self._selected_stream_names:
+            return self._selected_stream_names
+
+        return self.get_available_streams()
+
     def set_config(
         self,
         config: dict[str, Any],
@@ -274,8 +284,20 @@ def get_records(self, stream: str) -> LazyDataset:
             },
         ) from KeyError(stream)
 
-        iterator: Iterator[dict[str, Any]] = protocol_util.airbyte_messages_to_record_dicts(
-            self._read_with_catalog(streaming_cache_info, configured_catalog),
+        configured_stream = configured_catalog.streams[0]
+        col_list = configured_stream.stream.json_schema["properties"].keys()
+
+        def _with_missing_columns(records: Iterable[dict[str, Any]]) -> Iterator[dict[str, Any]]:
+            """Add missing columns to the record with null values."""
+            for record in records:
+                appended_columns = set(col_list) - set(record.keys())
+                appended_dict = {col: None for col in appended_columns}
+                yield {**record, **appended_dict}
+
+        iterator: Iterator[dict[str, Any]] = _with_missing_columns(
+            protocol_util.airbyte_messages_to_record_dicts(
+                self._read_with_catalog(streaming_cache_info, configured_catalog),
+            )
         )
         return LazyDataset(iterator)
 
@@ -423,7 +445,9 @@ def read(self, cache: SQLCacheBase | None = None) -> ReadResult:
             cache = get_default_cache()
 
         cache.register_source(
-            source_name=self.name, incoming_source_catalog=self.configured_catalog
+            source_name=self.name,
+            incoming_source_catalog=self.configured_catalog,
+            stream_names=set(self.get_selected_streams()),
         )
         cache.process_airbyte_messages(self._tally_records(self._read(cache.get_telemetry_info())))
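The generator added to `get_records` is easy to verify in isolation. A standalone version with a small usage check (column names borrowed from the test fixture below):

from collections.abc import Iterable, Iterator
from typing import Any

def with_missing_columns(
    records: Iterable[dict[str, Any]],
    col_list: list[str],
) -> Iterator[dict[str, Any]]:
    """Yield each record with any schema columns it lacks filled in as None."""
    for record in records:
        missing = set(col_list) - set(record.keys())
        yield {**record, **{col: None for col in missing}}

rows = [{"column1": "value1", "column2": 1}]
assert list(with_missing_columns(rows, ["column1", "column2", "empty_column"])) == [
    {"column1": "value1", "column2": 1, "empty_column": None}
]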
diff --git a/airbyte-lib/docs/generated/airbyte_lib.html b/airbyte-lib/docs/generated/airbyte_lib.html
index 59867359c01e8..93c5dc21a6514 100644
--- a/airbyte-lib/docs/generated/airbyte_lib.html
+++ b/airbyte-lib/docs/generated/airbyte_lib.html
[regenerated pdoc HTML; markup not recoverable from this extraction. Net changes: `source_catalog` is removed from the RecordProcessor inherited-members list, and the new `get_selected_streams(self) -> list[str]` method is documented with the docstring "Get the selected streams. If no streams are selected, return all available streams."]
diff --git a/airbyte-lib/docs/generated/airbyte_lib/caches.html b/airbyte-lib/docs/generated/airbyte_lib/caches.html
index b39a5230e6582..45ec953c35713 100644
--- a/airbyte-lib/docs/generated/airbyte_lib/caches.html
+++ b/airbyte-lib/docs/generated/airbyte_lib/caches.html
[regenerated pdoc HTML; markup not recoverable from this extraction. Net changes: `source_catalog` is removed from each cache class's inherited-members list, and the documented `register_source` signature gains a `stream_names: set[str]` parameter, with its docstring replaced by the updated text shown in caches/base.py above.]
diff --git a/airbyte-lib/tests/integration_tests/fixtures/source-test/source_test/run.py b/airbyte-lib/tests/integration_tests/fixtures/source-test/source_test/run.py
index b200e4a84f109..d17502ba4c42f 100644
--- a/airbyte-lib/tests/integration_tests/fixtures/source-test/source_test/run.py
+++ b/airbyte-lib/tests/integration_tests/fixtures/source-test/source_test/run.py
@@ -31,6 +31,7 @@
                 "properties": {
                     "column1": {"type": "string"},
                     "column2": {"type": "number"},
+                    "empty_column": {"type": "string"},
                 },
             },
         },
diff --git a/airbyte-lib/tests/integration_tests/test_integration.py b/airbyte-lib/tests/integration_tests/test_integration.py
index a122df84899e3..8b45a37a3cc8e 100644
--- a/airbyte-lib/tests/integration_tests/test_integration.py
+++ b/airbyte-lib/tests/integration_tests/test_integration.py
@@ -63,7 +63,7 @@ def expected_test_stream_data() -> dict[str, list[dict[str, str | int]]]:
             {"column1": "value2", "column2": 2},
         ],
         "stream2": [
-            {"column1": "value1", "column2": 1},
+            {"column1": "value1", "column2": 1, "empty_column": None},
         ],
     }
@@ -219,7 +219,7 @@ def test_sync_to_duckdb(expected_test_stream_data: dict[str, list[dict[str, str
 
 def test_read_result_mapping():
     source = ab.get_connector("source-test", config={"apiKey": "test"})
-    result: ReadResult = source.read()
+    result: ReadResult = source.read(ab.new_local_cache())
     assert len(result) == 2
     assert isinstance(result, Mapping)
     assert "stream1" in result
@@ -230,7 +230,7 @@ def test_dataset_list_and_len(expected_test_stream_data):
 
 def test_dataset_list_and_len(expected_test_stream_data):
     source = ab.get_connector("source-test", config={"apiKey": "test"})
-    result: ReadResult = source.read()
+    result: ReadResult = source.read(ab.new_local_cache())
     stream_1 = result["stream1"]
     assert len(stream_1) == 2
     assert len(list(stream_1)) == 2
@@ -584,7 +584,7 @@ def test_tracking(mock_datetime: Mock, mock_requests: Mock, raises: bool, api_ke
     mock_requests.post = mock_post
 
     source = ab.get_connector("source-test", config={"apiKey": api_key})
-    cache = ab.get_default_cache()
+    cache = ab.new_local_cache()
 
     if request_call_fails:
         mock_post.side_effect = Exception("test exception")
diff --git a/airbyte-lib/tests/unit_tests/test_writers.py b/airbyte-lib/tests/unit_tests/test_writers.py
index 5d0432606b136..2578ae10b4835 100644
--- a/airbyte-lib/tests/unit_tests/test_writers.py
+++ b/airbyte-lib/tests/unit_tests/test_writers.py
@@ -30,9 +30,7 @@ def test_parquet_writer_has_config():
 def test_parquet_writer_has_source_catalog():
     config = ParquetWriterConfig(cache_dir='test_path')
     writer = ParquetWriter(config)
-    assert hasattr(writer, 'source_catalog')
 
 def test_parquet_writer_source_catalog_is_none():
     config = ParquetWriterConfig(cache_dir='test_path')
     writer = ParquetWriter(config)
-    assert writer.source_catalog is None
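The test updates above replace the shared default cache with per-test local caches, and end-user code can follow the same pattern. A usage sketch (connector name and config mirror the test fixture; `to_pandas()` is assumed from the airbyte-lib dataset API rather than shown in this diff):

import airbyte_lib as ab

source = ab.get_connector("source-test", config={"apiKey": "test"})
source.set_streams(["stream1"])  # optional; leaving streams unselected means "all"
result = source.read(ab.new_local_cache())  # fresh, isolated local cache for this run
print(result["stream1"].to_pandas())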