Skip to content

Commit

Permalink
File-based CDK: make full refresh concurrent (#34411)
Browse files Browse the repository at this point in the history
  • Loading branch information
clnoll committed Jan 30, 2024
1 parent d2171e4 commit eb31e4d
Show file tree
Hide file tree
Showing 24 changed files with 1,042 additions and 195 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from airbyte_cdk.sources.concurrent_source.concurrent_source import ConcurrentSource
from airbyte_cdk.sources.streams import Stream
from airbyte_cdk.sources.streams.concurrent.abstract_stream import AbstractStream
from airbyte_cdk.sources.streams.concurrent.adapters import StreamFacade
from airbyte_cdk.sources.streams.concurrent.abstract_stream_facade import AbstractStreamFacade


class ConcurrentSourceAdapter(AbstractSource, ABC):
Expand Down Expand Up @@ -58,6 +58,6 @@ def _select_abstract_streams(self, config: Mapping[str, Any], configured_catalog
f"The stream {configured_stream.stream.name} no longer exists in the configuration. "
f"Refresh the schema in replication settings and remove this stream from future sync attempts."
)
if isinstance(stream_instance, StreamFacade):
abstract_streams.append(stream_instance._abstract_stream)
if isinstance(stream_instance, AbstractStreamFacade):
abstract_streams.append(stream_instance.get_underlying_stream())
return abstract_streams
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from .abstract_file_based_availability_strategy import AbstractFileBasedAvailabilityStrategy
from .abstract_file_based_availability_strategy import (
AbstractFileBasedAvailabilityStrategy,
AbstractFileBasedAvailabilityStrategyWrapper,
)
from .default_file_based_availability_strategy import DefaultFileBasedAvailabilityStrategy

__all__ = ["AbstractFileBasedAvailabilityStrategy", "DefaultFileBasedAvailabilityStrategy"]
__all__ = ["AbstractFileBasedAvailabilityStrategy", "AbstractFileBasedAvailabilityStrategyWrapper", "DefaultFileBasedAvailabilityStrategy"]
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@

from airbyte_cdk.sources import Source
from airbyte_cdk.sources.streams.availability_strategy import AvailabilityStrategy
from airbyte_cdk.sources.streams.concurrent.availability_strategy import (
AbstractAvailabilityStrategy,
StreamAvailability,
StreamAvailable,
StreamUnavailable,
)
from airbyte_cdk.sources.streams.core import Stream

if TYPE_CHECKING:
Expand Down Expand Up @@ -35,3 +41,17 @@ def check_availability_and_parsability(
Returns (True, None) if successful, otherwise (False, <error message>).
"""
...


class AbstractFileBasedAvailabilityStrategyWrapper(AbstractAvailabilityStrategy):
def __init__(self, stream: "AbstractFileBasedStream"):
self.stream = stream

def check_availability(self, logger: logging.Logger) -> StreamAvailability:
is_available, reason = self.stream.availability_strategy.check_availability(self.stream, logger, None)
if is_available:
return StreamAvailable()
return StreamUnavailable(reason or "")

def check_availability_and_parsability(self, logger: logging.Logger) -> Tuple[bool, Optional[str]]:
return self.stream.availability_strategy.check_availability_and_parsability(self.stream, logger, None)
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,18 @@
from collections import Counter
from typing import Any, Iterator, List, Mapping, MutableMapping, Optional, Tuple, Type, Union

from airbyte_cdk.models import AirbyteMessage, AirbyteStateMessage, ConfiguredAirbyteCatalog, ConnectorSpecification
from airbyte_cdk.sources import AbstractSource
from airbyte_cdk.logger import AirbyteLogFormatter, init_logger
from airbyte_cdk.models import (
AirbyteMessage,
AirbyteStateMessage,
ConfiguredAirbyteCatalog,
ConnectorSpecification,
FailureType,
Level,
SyncMode,
)
from airbyte_cdk.sources.concurrent_source.concurrent_source import ConcurrentSource
from airbyte_cdk.sources.concurrent_source.concurrent_source_adapter import ConcurrentSourceAdapter
from airbyte_cdk.sources.file_based.availability_strategy import AbstractFileBasedAvailabilityStrategy, DefaultFileBasedAvailabilityStrategy
from airbyte_cdk.sources.file_based.config.abstract_file_based_spec import AbstractFileBasedSpec
from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig, ValidationPolicy
Expand All @@ -20,19 +30,33 @@
from airbyte_cdk.sources.file_based.file_types.file_type_parser import FileTypeParser
from airbyte_cdk.sources.file_based.schema_validation_policies import DEFAULT_SCHEMA_VALIDATION_POLICIES, AbstractSchemaValidationPolicy
from airbyte_cdk.sources.file_based.stream import AbstractFileBasedStream, DefaultFileBasedStream
from airbyte_cdk.sources.file_based.stream.concurrent.adapters import FileBasedStreamFacade
from airbyte_cdk.sources.file_based.stream.concurrent.cursor import FileBasedNoopCursor
from airbyte_cdk.sources.file_based.stream.cursor import AbstractFileBasedCursor
from airbyte_cdk.sources.file_based.stream.cursor.default_file_based_cursor import DefaultFileBasedCursor
from airbyte_cdk.sources.message.repository import InMemoryMessageRepository, MessageRepository
from airbyte_cdk.sources.source import TState
from airbyte_cdk.sources.streams import Stream
from airbyte_cdk.utils.analytics_message import create_analytics_message
from airbyte_cdk.utils.traced_exception import AirbyteTracedException
from pydantic.error_wrappers import ValidationError

DEFAULT_CONCURRENCY = 100
MAX_CONCURRENCY = 100
INITIAL_N_PARTITIONS = MAX_CONCURRENCY // 2


class FileBasedSource(ConcurrentSourceAdapter, ABC):
# We make each source override the concurrency level to give control over when they are upgraded.
_concurrency_level = None

class FileBasedSource(AbstractSource, ABC):
def __init__(
self,
stream_reader: AbstractFileBasedStreamReader,
spec_class: Type[AbstractFileBasedSpec],
catalog_path: Optional[str] = None,
catalog: Optional[ConfiguredAirbyteCatalog],
config: Optional[Mapping[str, Any]],
state: Optional[TState],
availability_strategy: Optional[AbstractFileBasedAvailabilityStrategy] = None,
discovery_policy: AbstractDiscoveryPolicy = DefaultDiscoveryPolicy(),
parsers: Mapping[Type[Any], FileTypeParser] = default_parsers,
Expand All @@ -41,15 +65,29 @@ def __init__(
):
self.stream_reader = stream_reader
self.spec_class = spec_class
self.config = config
self.catalog = catalog
self.state = state
self.availability_strategy = availability_strategy or DefaultFileBasedAvailabilityStrategy(stream_reader)
self.discovery_policy = discovery_policy
self.parsers = parsers
self.validation_policies = validation_policies
catalog = self.read_catalog(catalog_path) if catalog_path else None
self.stream_schemas = {s.stream.name: s.stream.json_schema for s in catalog.streams} if catalog else {}
self.cursor_cls = cursor_cls
self.logger = logging.getLogger(f"airbyte.{self.name}")
self.logger = init_logger(f"airbyte.{self.name}")
self.errors_collector: FileBasedErrorsCollector = FileBasedErrorsCollector()
self._message_repository: Optional[MessageRepository] = None
concurrent_source = ConcurrentSource.create(
MAX_CONCURRENCY, INITIAL_N_PARTITIONS, self.logger, self._slice_logger, self.message_repository
)
self._state = None
super().__init__(concurrent_source)

@property
def message_repository(self) -> MessageRepository:
if self._message_repository is None:
self._message_repository = InMemoryMessageRepository(Level(AirbyteLogFormatter.level_mapping[self.logger.level]))
return self._message_repository

def check_connection(self, logger: logging.Logger, config: Mapping[str, Any]) -> Tuple[bool, Optional[Any]]:
"""
Expand All @@ -61,7 +99,15 @@ def check_connection(self, logger: logging.Logger, config: Mapping[str, Any]) ->
Otherwise, the "error" object should describe what went wrong.
"""
streams = self.streams(config)
try:
streams = self.streams(config)
except Exception as config_exception:
raise AirbyteTracedException(
internal_message="Please check the logged errors for more information.",
message=FileBasedSourceError.CONFIG_VALIDATION_ERROR.value,
exception=AirbyteTracedException(exception=config_exception),
failure_type=FailureType.config_error,
)
if len(streams) == 0:
return (
False,
Expand All @@ -80,7 +126,7 @@ def check_connection(self, logger: logging.Logger, config: Mapping[str, Any]) ->
reason,
) = stream.availability_strategy.check_availability_and_parsability(stream, logger, self)
except Exception:
errors.append(f"Unable to connect to stream {stream} - {''.join(traceback.format_exc())}")
errors.append(f"Unable to connect to stream {stream.name} - {''.join(traceback.format_exc())}")
else:
if not stream_is_available and reason:
errors.append(reason)
Expand All @@ -91,10 +137,26 @@ def streams(self, config: Mapping[str, Any]) -> List[Stream]:
"""
Return a list of this source's streams.
"""
file_based_streams = self._get_file_based_streams(config)

configured_streams: List[Stream] = []

for stream in file_based_streams:
sync_mode = self._get_sync_mode_from_catalog(stream)
if sync_mode == SyncMode.full_refresh and hasattr(self, "_concurrency_level") and self._concurrency_level is not None:
configured_streams.append(
FileBasedStreamFacade.create_from_stream(stream, self, self.logger, None, FileBasedNoopCursor(stream.config))
)
else:
configured_streams.append(stream)

return configured_streams

def _get_file_based_streams(self, config: Mapping[str, Any]) -> List[AbstractFileBasedStream]:
try:
parsed_config = self._get_parsed_config(config)
self.stream_reader.config = parsed_config
streams: List[Stream] = []
streams: List[AbstractFileBasedStream] = []
for stream_config in parsed_config.streams:
self._validate_input_schema(stream_config)
streams.append(
Expand All @@ -115,6 +177,14 @@ def streams(self, config: Mapping[str, Any]) -> List[Stream]:
except ValidationError as exc:
raise ConfigValidationError(FileBasedSourceError.CONFIG_VALIDATION_ERROR) from exc

def _get_sync_mode_from_catalog(self, stream: Stream) -> Optional[SyncMode]:
if self.catalog:
for catalog_stream in self.catalog.streams:
if stream.name == catalog_stream.stream.name:
return catalog_stream.sync_mode
raise RuntimeError(f"No sync mode was found for {stream.name}.")
return None

def read(
self,
logger: logging.Logger,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from functools import partial
from io import IOBase
from typing import Any, Callable, Dict, Generator, Iterable, List, Mapping, Optional, Set, Tuple
from uuid import uuid4

from airbyte_cdk.models import FailureType
from airbyte_cdk.sources.file_based.config.csv_format import CsvFormat, CsvHeaderAutogenerated, CsvHeaderUserProvided, InferenceType
Expand Down Expand Up @@ -38,8 +39,10 @@ def read_data(

# Formats are configured individually per-stream so a unique dialect should be registered for each stream.
# We don't unregister the dialect because we are lazily parsing each csv file to generate records
# This will potentially be a problem if we ever process multiple streams concurrently
dialect_name = config.name + DIALECT_NAME
# Give each stream's dialect a unique name; otherwise, when we are doing a concurrent sync we can end up
# with a race condition where a thread attempts to use a dialect before a separate thread has finished
# registering it.
dialect_name = f"{config.name}_{str(uuid4())}_{DIALECT_NAME}"
csv.register_dialect(
dialect_name,
delimiter=config_format.delimiter,
Expand Down
Empty file.
Loading

0 comments on commit eb31e4d

Please sign in to comment.