diff --git a/airbyte-cdk/java/airbyte-cdk/core/src/main/kotlin/io/airbyte/cdk/db/factory/DSLContextFactory.kt b/airbyte-cdk/java/airbyte-cdk/core/src/main/kotlin/io/airbyte/cdk/db/factory/DSLContextFactory.kt index 1d95b235a8f5c..2ecf74e0fd94c 100644 --- a/airbyte-cdk/java/airbyte-cdk/core/src/main/kotlin/io/airbyte/cdk/db/factory/DSLContextFactory.kt +++ b/airbyte-cdk/java/airbyte-cdk/core/src/main/kotlin/io/airbyte/cdk/db/factory/DSLContextFactory.kt @@ -67,7 +67,7 @@ object DSLContextFactory { driverClassName: String, jdbcConnectionString: String?, dialect: SQLDialect?, - connectionProperties: Map?, + connectionProperties: Map?, connectionTimeout: Duration? ): DSLContext { return DSL.using( diff --git a/airbyte-cdk/java/airbyte-cdk/core/src/main/kotlin/io/airbyte/cdk/db/factory/DataSourceFactory.kt b/airbyte-cdk/java/airbyte-cdk/core/src/main/kotlin/io/airbyte/cdk/db/factory/DataSourceFactory.kt index 507a4f366bdb6..0b3625d18dd29 100644 --- a/airbyte-cdk/java/airbyte-cdk/core/src/main/kotlin/io/airbyte/cdk/db/factory/DataSourceFactory.kt +++ b/airbyte-cdk/java/airbyte-cdk/core/src/main/kotlin/io/airbyte/cdk/db/factory/DataSourceFactory.kt @@ -50,7 +50,7 @@ object DataSourceFactory { password: String?, driverClassName: String, jdbcConnectionString: String?, - connectionProperties: Map?, + connectionProperties: Map?, connectionTimeout: Duration? ): DataSource { return DataSourceBuilder(username, password, driverClassName, jdbcConnectionString) @@ -100,7 +100,7 @@ object DataSourceFactory { port: Int, database: String?, driverClassName: String, - connectionProperties: Map? + connectionProperties: Map? ): DataSource { return DataSourceBuilder(username, password, driverClassName, host, port, database) .withConnectionProperties(connectionProperties) @@ -152,7 +152,7 @@ object DataSourceFactory { private var password: String?, private var driverClassName: String ) { - private var connectionProperties: Map = java.util.Map.of() + private var connectionProperties: Map = java.util.Map.of() private var database: String? = null private var host: String? = null private var jdbcUrl: String? = null @@ -185,7 +185,7 @@ object DataSourceFactory { } fun withConnectionProperties( - connectionProperties: Map? + connectionProperties: Map? ): DataSourceBuilder { if (connectionProperties != null) { this.connectionProperties = connectionProperties diff --git a/airbyte-cdk/java/airbyte-cdk/core/src/main/kotlin/io/airbyte/cdk/db/jdbc/JdbcSourceOperations.kt b/airbyte-cdk/java/airbyte-cdk/core/src/main/kotlin/io/airbyte/cdk/db/jdbc/JdbcSourceOperations.kt index 01e976ee8d71d..c98a13b9d6042 100644 --- a/airbyte-cdk/java/airbyte-cdk/core/src/main/kotlin/io/airbyte/cdk/db/jdbc/JdbcSourceOperations.kt +++ b/airbyte-cdk/java/airbyte-cdk/core/src/main/kotlin/io/airbyte/cdk/db/jdbc/JdbcSourceOperations.kt @@ -16,7 +16,7 @@ import org.slf4j.LoggerFactory /** Implementation of source operations with standard JDBC types. */ class JdbcSourceOperations : - AbstractJdbcCompatibleSourceOperations(), SourceOperations { + AbstractJdbcCompatibleSourceOperations(), SourceOperations { protected fun safeGetJdbcType(columnTypeInt: Int): JDBCType { return try { JDBCType.valueOf(columnTypeInt) @@ -65,7 +65,7 @@ class JdbcSourceOperations : preparedStatement: PreparedStatement, parameterIndex: Int, cursorFieldType: JDBCType?, - value: String + value: String? ) { when (cursorFieldType) { JDBCType.TIMESTAMP -> setTimestamp(preparedStatement, parameterIndex, value) @@ -80,12 +80,12 @@ class JdbcSourceOperations : JDBCType.TINYINT, JDBCType.SMALLINT -> setShortInt(preparedStatement, parameterIndex, value!!) JDBCType.INTEGER -> setInteger(preparedStatement, parameterIndex, value!!) - JDBCType.BIGINT -> setBigInteger(preparedStatement, parameterIndex, value) + JDBCType.BIGINT -> setBigInteger(preparedStatement, parameterIndex, value!!) JDBCType.FLOAT, JDBCType.DOUBLE -> setDouble(preparedStatement, parameterIndex, value!!) JDBCType.REAL -> setReal(preparedStatement, parameterIndex, value!!) JDBCType.NUMERIC, - JDBCType.DECIMAL -> setDecimal(preparedStatement, parameterIndex, value) + JDBCType.DECIMAL -> setDecimal(preparedStatement, parameterIndex, value!!) JDBCType.CHAR, JDBCType.NCHAR, JDBCType.NVARCHAR, @@ -147,7 +147,7 @@ class JdbcSourceOperations : return JdbcUtils.ALLOWED_CURSOR_TYPES.contains(type) } - override fun getAirbyteType(jdbcType: JDBCType?): JsonSchemaType { + override fun getAirbyteType(jdbcType: JDBCType): JsonSchemaType { return when (jdbcType) { JDBCType.BIT, JDBCType.BOOLEAN -> JsonSchemaType.BOOLEAN diff --git a/airbyte-cdk/java/airbyte-cdk/core/src/main/kotlin/io/airbyte/cdk/integrations/JdbcConnector.kt b/airbyte-cdk/java/airbyte-cdk/core/src/main/kotlin/io/airbyte/cdk/integrations/JdbcConnector.kt index 297925119c870..e8ff27ccad66a 100644 --- a/airbyte-cdk/java/airbyte-cdk/core/src/main/kotlin/io/airbyte/cdk/integrations/JdbcConnector.kt +++ b/airbyte-cdk/java/airbyte-cdk/core/src/main/kotlin/io/airbyte/cdk/integrations/JdbcConnector.kt @@ -11,7 +11,7 @@ import java.util.* abstract class JdbcConnector protected constructor(@JvmField protected val driverClassName: String) : BaseConnector() { - protected fun getConnectionTimeout(connectionProperties: Map): Duration { + protected fun getConnectionTimeout(connectionProperties: Map): Duration { return getConnectionTimeout(connectionProperties, driverClassName) } @@ -37,7 +37,7 @@ protected constructor(@JvmField protected val driverClassName: String) : BaseCon * @return DataSourceBuilder class used to create dynamic fields for DataSource */ fun getConnectionTimeout( - connectionProperties: Map, + connectionProperties: Map, driverClassName: String? ): Duration { val parsedConnectionTimeout = diff --git a/airbyte-cdk/java/airbyte-cdk/core/src/main/kotlin/io/airbyte/cdk/integrations/base/Source.kt b/airbyte-cdk/java/airbyte-cdk/core/src/main/kotlin/io/airbyte/cdk/integrations/base/Source.kt index 9a39069444ecd..a1943109bae2b 100644 --- a/airbyte-cdk/java/airbyte-cdk/core/src/main/kotlin/io/airbyte/cdk/integrations/base/Source.kt +++ b/airbyte-cdk/java/airbyte-cdk/core/src/main/kotlin/io/airbyte/cdk/integrations/base/Source.kt @@ -42,7 +42,7 @@ interface Source : Integration { @Throws(Exception::class) fun read( config: JsonNode, - catalog: ConfiguredAirbyteCatalog?, + catalog: ConfiguredAirbyteCatalog, state: JsonNode? ): AutoCloseableIterator @@ -65,7 +65,7 @@ interface Source : Integration { @Throws(Exception::class) fun readStreams( config: JsonNode, - catalog: ConfiguredAirbyteCatalog?, + catalog: ConfiguredAirbyteCatalog, state: JsonNode? ): Collection>? { return List.of(read(config, catalog, state)) diff --git a/airbyte-cdk/java/airbyte-cdk/core/src/main/kotlin/io/airbyte/cdk/integrations/base/spec_modification/SpecModifyingSource.kt b/airbyte-cdk/java/airbyte-cdk/core/src/main/kotlin/io/airbyte/cdk/integrations/base/spec_modification/SpecModifyingSource.kt index 4a6306d28f132..07231c743c575 100644 --- a/airbyte-cdk/java/airbyte-cdk/core/src/main/kotlin/io/airbyte/cdk/integrations/base/spec_modification/SpecModifyingSource.kt +++ b/airbyte-cdk/java/airbyte-cdk/core/src/main/kotlin/io/airbyte/cdk/integrations/base/spec_modification/SpecModifyingSource.kt @@ -35,7 +35,7 @@ abstract class SpecModifyingSource(private val source: Source) : Source { @Throws(Exception::class) override fun read( config: JsonNode, - catalog: ConfiguredAirbyteCatalog?, + catalog: ConfiguredAirbyteCatalog, state: JsonNode? ): AutoCloseableIterator { return source.read(config, catalog, state) @@ -44,7 +44,7 @@ abstract class SpecModifyingSource(private val source: Source) : Source { @Throws(Exception::class) override fun readStreams( config: JsonNode, - catalog: ConfiguredAirbyteCatalog?, + catalog: ConfiguredAirbyteCatalog, state: JsonNode? ): Collection>? { return source.readStreams(config, catalog, state) diff --git a/airbyte-cdk/java/airbyte-cdk/core/src/main/kotlin/io/airbyte/cdk/integrations/base/ssh/SshWrappedSource.kt b/airbyte-cdk/java/airbyte-cdk/core/src/main/kotlin/io/airbyte/cdk/integrations/base/ssh/SshWrappedSource.kt index cd77066827d9e..db045767eb8e7 100644 --- a/airbyte-cdk/java/airbyte-cdk/core/src/main/kotlin/io/airbyte/cdk/integrations/base/ssh/SshWrappedSource.kt +++ b/airbyte-cdk/java/airbyte-cdk/core/src/main/kotlin/io/airbyte/cdk/integrations/base/ssh/SshWrappedSource.kt @@ -76,7 +76,7 @@ class SshWrappedSource : Source { @Throws(Exception::class) override fun read( config: JsonNode, - catalog: ConfiguredAirbyteCatalog?, + catalog: ConfiguredAirbyteCatalog, state: JsonNode? ): AutoCloseableIterator { val tunnel: SshTunnel = SshTunnel.Companion.getInstance(config, hostKey, portKey) @@ -97,7 +97,7 @@ class SshWrappedSource : Source { @Throws(Exception::class) override fun readStreams( config: JsonNode, - catalog: ConfiguredAirbyteCatalog?, + catalog: ConfiguredAirbyteCatalog, state: JsonNode? ): Collection>? { val tunnel: SshTunnel = SshTunnel.Companion.getInstance(config, hostKey, portKey) diff --git a/airbyte-cdk/java/airbyte-cdk/core/src/main/resources/version.properties b/airbyte-cdk/java/airbyte-cdk/core/src/main/resources/version.properties index 3a7b1b0955713..c7be3358f5502 100644 --- a/airbyte-cdk/java/airbyte-cdk/core/src/main/resources/version.properties +++ b/airbyte-cdk/java/airbyte-cdk/core/src/main/resources/version.properties @@ -1 +1 @@ -version=0.28.10 +version=0.28.11 diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/build.gradle b/airbyte-cdk/java/airbyte-cdk/db-sources/build.gradle index 5ac716385f162..3f21973ce07a6 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/build.gradle +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/build.gradle @@ -11,6 +11,11 @@ java { } } +compileKotlin.compilerOptions.allWarningsAsErrors = false +compileTestFixturesKotlin.compilerOptions.allWarningsAsErrors = false +compileTestKotlin.compilerOptions.allWarningsAsErrors = false + + // Convert yaml to java: relationaldb.models jsonSchema2Pojo { sourceType = SourceType.YAMLSCHEMA @@ -53,4 +58,5 @@ dependencies { testImplementation testFixtures(project(':airbyte-cdk:java:airbyte-cdk:datastore-postgres')) testImplementation 'uk.org.webcompere:system-stubs-jupiter:2.0.1' + testImplementation 'org.mockito.kotlin:mockito-kotlin:5.2.1' } diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/AirbyteDebeziumHandler.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/AirbyteDebeziumHandler.kt index 24e7eebb2aa37..2d619a02831d4 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/AirbyteDebeziumHandler.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/AirbyteDebeziumHandler.kt @@ -14,34 +14,43 @@ import io.airbyte.protocol.models.v0.ConfiguredAirbyteStream import io.airbyte.protocol.models.v0.SyncMode import io.debezium.engine.ChangeEvent import io.debezium.engine.DebeziumEngine -import org.slf4j.Logger -import org.slf4j.LoggerFactory import java.time.Duration import java.time.Instant import java.time.temporal.ChronoUnit import java.util.* import java.util.concurrent.LinkedBlockingQueue +import org.slf4j.Logger +import org.slf4j.LoggerFactory /** * This class acts as the bridge between Airbyte DB connectors and debezium. If a DB connector wants * to use debezium for CDC, it should use this class */ -class AirbyteDebeziumHandler(private val config: JsonNode, - private val targetPosition: CdcTargetPosition, - private val trackSchemaHistory: Boolean, - private val firstRecordWaitTime: Duration, - private val subsequentRecordWaitTime: Duration, - private val queueSize: Int, - private val addDbNameToOffsetState: Boolean) { - internal inner class CapacityReportingBlockingQueue(capacity: Int) : LinkedBlockingQueue(capacity) { +class AirbyteDebeziumHandler( + private val config: JsonNode, + private val targetPosition: CdcTargetPosition, + private val trackSchemaHistory: Boolean, + private val firstRecordWaitTime: Duration, + private val subsequentRecordWaitTime: Duration, + private val queueSize: Int, + private val addDbNameToOffsetState: Boolean +) { + internal inner class CapacityReportingBlockingQueue(capacity: Int) : + LinkedBlockingQueue(capacity) { private var lastReport: Instant? = null private fun reportQueueUtilization() { - if (lastReport == null || Duration.between(lastReport, Instant.now()).compareTo(Companion.REPORT_DURATION) > 0) { - LOGGER.info("CDC events queue size: {}. remaining {}", this.size, this.remainingCapacity()) - synchronized(this) { - lastReport = Instant.now() - } + if ( + lastReport == null || + Duration.between(lastReport, Instant.now()) + .compareTo(Companion.REPORT_DURATION) > 0 + ) { + LOGGER.info( + "CDC events queue size: {}. remaining {}", + this.size, + this.remainingCapacity() + ) + synchronized(this) { lastReport = Instant.now() } } } @@ -55,44 +64,62 @@ class AirbyteDebeziumHandler(private val config: JsonNode, reportQueueUtilization() return super.poll() } - - companion object { - private val REPORT_DURATION: Duration = Duration.of(10, ChronoUnit.SECONDS) - } } - fun getIncrementalIterators(debeziumPropertiesManager: DebeziumPropertiesManager, - eventConverter: DebeziumEventConverter, - cdcSavedInfoFetcher: CdcSavedInfoFetcher, - cdcStateHandler: CdcStateHandler): AutoCloseableIterator { + fun getIncrementalIterators( + debeziumPropertiesManager: DebeziumPropertiesManager, + eventConverter: DebeziumEventConverter, + cdcSavedInfoFetcher: CdcSavedInfoFetcher, + cdcStateHandler: CdcStateHandler + ): AutoCloseableIterator { LOGGER.info("Using CDC: {}", true) - LOGGER.info("Using DBZ version: {}", DebeziumEngine::class.java.getPackage().implementationVersion) - val offsetManager: AirbyteFileOffsetBackingStore = AirbyteFileOffsetBackingStore.Companion.initializeState( + LOGGER.info( + "Using DBZ version: {}", + DebeziumEngine::class.java.getPackage().implementationVersion + ) + val offsetManager: AirbyteFileOffsetBackingStore = + AirbyteFileOffsetBackingStore.Companion.initializeState( cdcSavedInfoFetcher.savedOffset, - if (addDbNameToOffsetState) Optional.ofNullable(config[JdbcUtils.DATABASE_KEY].asText()) else Optional.empty()) - val schemaHistoryManager: Optional = if (trackSchemaHistory - ) Optional.of(AirbyteSchemaHistoryStorage.Companion.initializeDBHistory( - cdcSavedInfoFetcher.savedSchemaHistory, cdcStateHandler.compressSchemaHistoryForState())) - else Optional.empty() + if (addDbNameToOffsetState) + Optional.ofNullable(config[JdbcUtils.DATABASE_KEY].asText()) + else Optional.empty() + ) + val schemaHistoryManager: Optional = + if (trackSchemaHistory) + Optional.of( + AirbyteSchemaHistoryStorage.Companion.initializeDBHistory( + cdcSavedInfoFetcher.savedSchemaHistory, + cdcStateHandler.compressSchemaHistoryForState() + ) + ) + else Optional.empty() val publisher = DebeziumRecordPublisher(debeziumPropertiesManager) - val queue: CapacityReportingBlockingQueue> = CapacityReportingBlockingQueue>(queueSize) + val queue: CapacityReportingBlockingQueue> = + CapacityReportingBlockingQueue(queueSize) publisher.start(queue, offsetManager, schemaHistoryManager) // handle state machine around pub/sub logic. - val eventIterator: AutoCloseableIterator = DebeziumRecordIterator( + val eventIterator: AutoCloseableIterator = + DebeziumRecordIterator( queue, targetPosition, { publisher.hasClosed() }, DebeziumShutdownProcedure(queue, { publisher.close() }, { publisher.hasClosed() }), firstRecordWaitTime, - subsequentRecordWaitTime) + subsequentRecordWaitTime + ) - val syncCheckpointDuration = if (config.has(DebeziumIteratorConstants.SYNC_CHECKPOINT_DURATION_PROPERTY) - ) Duration.ofSeconds(config[DebeziumIteratorConstants.SYNC_CHECKPOINT_DURATION_PROPERTY].asLong()) - else DebeziumIteratorConstants.SYNC_CHECKPOINT_DURATION - val syncCheckpointRecords = if (config.has(DebeziumIteratorConstants.SYNC_CHECKPOINT_RECORDS_PROPERTY) - ) config[DebeziumIteratorConstants.SYNC_CHECKPOINT_RECORDS_PROPERTY].asLong() - else DebeziumIteratorConstants.SYNC_CHECKPOINT_RECORDS.toLong() - return AutoCloseableIterators.fromIterator(DebeziumStateDecoratingIterator( + val syncCheckpointDuration = + if (config.has(DebeziumIteratorConstants.SYNC_CHECKPOINT_DURATION_PROPERTY)) + Duration.ofSeconds( + config[DebeziumIteratorConstants.SYNC_CHECKPOINT_DURATION_PROPERTY].asLong() + ) + else DebeziumIteratorConstants.SYNC_CHECKPOINT_DURATION + val syncCheckpointRecords = + if (config.has(DebeziumIteratorConstants.SYNC_CHECKPOINT_RECORDS_PROPERTY)) + config[DebeziumIteratorConstants.SYNC_CHECKPOINT_RECORDS_PROPERTY].asLong() + else DebeziumIteratorConstants.SYNC_CHECKPOINT_RECORDS.toLong() + return AutoCloseableIterators.fromIterator( + DebeziumStateDecoratingIterator( eventIterator, cdcStateHandler, targetPosition, @@ -101,11 +128,14 @@ class AirbyteDebeziumHandler(private val config: JsonNode, trackSchemaHistory, schemaHistoryManager.orElse(null), syncCheckpointDuration, - syncCheckpointRecords)) + syncCheckpointRecords + ) + ) } companion object { private val LOGGER: Logger = LoggerFactory.getLogger(AirbyteDebeziumHandler::class.java) + private val REPORT_DURATION: Duration = Duration.of(10, ChronoUnit.SECONDS) /** * We use 10000 as capacity cause the default queue size and batch size of debezium is : @@ -115,8 +145,10 @@ class AirbyteDebeziumHandler(private val config: JsonNode, const val QUEUE_CAPACITY: Int = 10000 fun isAnyStreamIncrementalSyncMode(catalog: ConfiguredAirbyteCatalog): Boolean { - return catalog.streams.stream().map { obj: ConfiguredAirbyteStream -> obj.syncMode } - .anyMatch { syncMode: SyncMode -> syncMode == SyncMode.INCREMENTAL } + return catalog.streams + .stream() + .map { obj: ConfiguredAirbyteStream -> obj.syncMode } + .anyMatch { syncMode: SyncMode -> syncMode == SyncMode.INCREMENTAL } } } } diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/CdcMetadataInjector.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/CdcMetadataInjector.kt index fb2d5ce38cf01..5ccd4b3666002 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/CdcMetadataInjector.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/CdcMetadataInjector.kt @@ -17,12 +17,16 @@ interface CdcMetadataInjector { * https://debezium.io/documentation/reference/1.9/connectors/mysql.html#mysql-create-events * * @param event is the actual record which contains data and would be written to the destination - * @param source contains the metadata about the record and we need to extract that metadata and add - * it to the event before writing it to destination + * @param source contains the metadata about the record and we need to extract that metadata and + * add it to the event before writing it to destination */ fun addMetaData(event: ObjectNode?, source: JsonNode?) - fun addMetaDataToRowsFetchedOutsideDebezium(record: ObjectNode?, transactionTimestamp: String?, metadataToAdd: T) { + fun addMetaDataToRowsFetchedOutsideDebezium( + record: ObjectNode?, + transactionTimestamp: String?, + metadataToAdd: T + ) { throw RuntimeException("Not Supported") } diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/CdcSavedInfoFetcher.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/CdcSavedInfoFetcher.kt index 80c025bba5873..abcc9e5915394 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/CdcSavedInfoFetcher.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/CdcSavedInfoFetcher.kt @@ -14,5 +14,5 @@ import java.util.* interface CdcSavedInfoFetcher { val savedOffset: JsonNode? - val savedSchemaHistory: AirbyteSchemaHistoryStorage.SchemaHistory?>? + val savedSchemaHistory: AirbyteSchemaHistoryStorage.SchemaHistory>? } diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/CdcStateHandler.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/CdcStateHandler.kt index 2c48d0b8dd1e6..317c87e1cfbcc 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/CdcStateHandler.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/CdcStateHandler.kt @@ -11,7 +11,10 @@ import io.airbyte.protocol.models.v0.AirbyteMessage * which suits them. Also, it adds some utils to verify CDC event status. */ interface CdcStateHandler { - fun saveState(offset: Map?, dbHistory: AirbyteSchemaHistoryStorage.SchemaHistory?): AirbyteMessage? + fun saveState( + offset: Map?, + dbHistory: AirbyteSchemaHistoryStorage.SchemaHistory? + ): AirbyteMessage? fun saveStateAfterCompletionOfSnapshotOfNewStreams(): AirbyteMessage? @@ -21,10 +24,11 @@ interface CdcStateHandler { val isCdcCheckpointEnabled: Boolean /** - * This function is used as feature flag for sending state messages as checkpoints in CDC syncs. + * This function is used as feature flag for sending state messages as checkpoints in CDC + * syncs. * - * @return Returns `true` if checkpoint state messages are enabled for CDC syncs. Otherwise, it - * returns `false` + * @return Returns `true` if checkpoint state messages are enabled for CDC syncs. Otherwise, + * it returns `false` */ get() = false } diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/CdcTargetPosition.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/CdcTargetPosition.kt index 889ffcfe73f3c..f0a8e12dad5a5 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/CdcTargetPosition.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/CdcTargetPosition.kt @@ -25,8 +25,8 @@ interface CdcTargetPosition { /** * Reads a position value (lsn) from a change event and compares it to target lsn * - * @param positionFromHeartbeat is the position extracted out of a heartbeat event (if the connector - * supports heartbeat) + * @param positionFromHeartbeat is the position extracted out of a heartbeat event (if the + * connector supports heartbeat) * @return true if heartbeat position is equal or greater than target position */ fun reachedTargetPosition(positionFromHeartbeat: T): Boolean { @@ -50,24 +50,28 @@ interface CdcTargetPosition { fun extractPositionFromHeartbeatOffset(sourceOffset: Map?): T /** - * This function checks if the event we are processing in the loop is already behind the offset so - * the process can safety save the state. + * This function checks if the event we are processing in the loop is already behind the offset + * so the process can safety save the state. * * @param offset DB CDC offset * @param event Event from the CDC load * @return Returns `true` when the event is ahead of the offset. Otherwise, it returns `false` */ - fun isEventAheadOffset(offset: Map?, event: ChangeEventWithMetadata?): Boolean { + fun isEventAheadOffset( + offset: Map?, + event: ChangeEventWithMetadata? + ): Boolean { return false } /** - * This function compares two offsets to make sure both are not pointing to the same position. The - * main purpose is to avoid sending same offset multiple times. + * This function compares two offsets to make sure both are not pointing to the same position. + * The main purpose is to avoid sending same offset multiple times. * * @param offsetA Offset to compare * @param offsetB Offset to compare - * @return Returns `true` if both offsets are at the same position. Otherwise, it returns `false` + * @return Returns `true` if both offsets are at the same position. Otherwise, it returns + * `false` */ fun isSameOffset(offsetA: Map?, offsetB: Map?): Boolean { return false diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/DebeziumIteratorConstants.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/DebeziumIteratorConstants.kt index 45dbcc81f6e59..f143e084e32a3 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/DebeziumIteratorConstants.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/DebeziumIteratorConstants.kt @@ -9,7 +9,8 @@ object DebeziumIteratorConstants { const val SYNC_CHECKPOINT_DURATION_PROPERTY: String = "sync_checkpoint_seconds" const val SYNC_CHECKPOINT_RECORDS_PROPERTY: String = "sync_checkpoint_records" - // TODO: Move these variables to a separate class IteratorConstants, as they will be used in state + // TODO: Move these variables to a separate class IteratorConstants, as they will be used in + // state // iterators for non debezium cases too. val SYNC_CHECKPOINT_DURATION: Duration = Duration.ofMinutes(15) const val SYNC_CHECKPOINT_RECORDS: Int = 10000 diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/AirbyteFileOffsetBackingStore.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/AirbyteFileOffsetBackingStore.kt index 275f59fe368bb..ce9fa8cd00358 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/AirbyteFileOffsetBackingStore.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/AirbyteFileOffsetBackingStore.kt @@ -6,11 +6,6 @@ package io.airbyte.cdk.integrations.debezium.internals import com.fasterxml.jackson.databind.JsonNode import com.google.common.base.Preconditions import io.airbyte.commons.json.Jsons -import org.apache.commons.io.FileUtils -import org.apache.kafka.connect.errors.ConnectException -import org.apache.kafka.connect.util.SafeObjectInputStream -import org.slf4j.Logger -import org.slf4j.LoggerFactory import java.io.EOFException import java.io.IOException import java.io.ObjectOutputStream @@ -23,6 +18,11 @@ import java.util.* import java.util.function.BiFunction import java.util.function.Function import java.util.stream.Collectors +import org.apache.commons.io.FileUtils +import org.apache.kafka.connect.errors.ConnectException +import org.apache.kafka.connect.util.SafeObjectInputStream +import org.slf4j.Logger +import org.slf4j.LoggerFactory /** * This class handles reading and writing a debezium offset file. In many cases it is duplicating @@ -32,24 +32,44 @@ import java.util.stream.Collectors * we ever discover that any of the contents of these offset files is not string serializable we * will likely have to drop the human readability support and just base64 encode it. */ -class AirbyteFileOffsetBackingStore(private val offsetFilePath: Path, private val dbName: Optional) { +class AirbyteFileOffsetBackingStore( + private val offsetFilePath: Path, + private val dbName: Optional +) { fun read(): Map { val raw = load() - return raw.entries.stream().collect(Collectors.toMap( - Function { e: Map.Entry -> byteBufferToString(e.key) }, - Function { e: Map.Entry -> byteBufferToString(e.value) })) + return raw.entries + .stream() + .collect( + Collectors.toMap( + Function { e: Map.Entry -> + byteBufferToString(e.key) + }, + Function { e: Map.Entry -> + byteBufferToString(e.value) + } + ) + ) } fun persist(cdcState: JsonNode?) { - val mapAsString: Map = - if (cdcState != null) Jsons.`object`>(cdcState, MutableMap::class.java) else emptyMap() + val mapAsString: Map = + if (cdcState != null) + Jsons.`object`(cdcState, MutableMap::class.java) as Map + else emptyMap() val updatedMap = updateStateForDebezium2_1(mapAsString) - val mappedAsStrings = updatedMap.entries.stream().collect(Collectors.toMap( - Function { e: Map.Entry -> stringToByteBuffer(e.key) }, - Function { e: Map.Entry -> stringToByteBuffer(e.value) })) + val mappedAsStrings = + updatedMap.entries + .stream() + .collect( + Collectors.toMap( + Function { e: Map.Entry -> stringToByteBuffer(e.key) }, + Function { e: Map.Entry -> stringToByteBuffer(e.value) } + ) + ) FileUtils.deleteQuietly(offsetFilePath.toFile()) save(mappedAsStrings) @@ -68,7 +88,10 @@ class AirbyteFileOffsetBackingStore(private val offsetFilePath: Path, private va } LOGGER.info("Mutating sate to make it Debezium 2.1 compatible") - val newKey = if (dbName.isPresent) SQL_SERVER_STATE_MUTATION.apply(key.substring(i, i1 + 1), dbName.get()) else key.substring(i, i1 + 1) + val newKey = + if (dbName.isPresent) + SQL_SERVER_STATE_MUTATION.apply(key.substring(i, i1 + 1), dbName.get()) + else key.substring(i, i1 + 1) val value = mapAsString[key] updatedMap[newKey] = value } @@ -77,16 +100,19 @@ class AirbyteFileOffsetBackingStore(private val offsetFilePath: Path, private va /** * See FileOffsetBackingStore#load - logic is mostly borrowed from here. duplicated because this - * method is not public. Reduced the try catch block to only the read operation from original code - * to reduce errors when reading the file. + * method is not public. Reduced the try catch block to only the read operation from original + * code to reduce errors when reading the file. */ private fun load(): Map { var obj: Any try { SafeObjectInputStream(Files.newInputStream(offsetFilePath)).use { `is` -> - // todo (cgardens) - we currently suppress a security warning for this line. use of readObject from - // untrusted sources is considered unsafe. Since the source is controlled by us in this case it - // should be safe. That said, changing this implementation to not use readObject would remove some + // todo (cgardens) - we currently suppress a security warning for this line. use of + // readObject from + // untrusted sources is considered unsafe. Since the source is controlled by us in + // this case it + // should be safe. That said, changing this implementation to not use readObject + // would remove some // headache. obj = `is`.readObject() } @@ -102,7 +128,8 @@ class AirbyteFileOffsetBackingStore(private val offsetFilePath: Path, private va throw ConnectException(e) } - if (obj !is HashMap<*, *>) throw ConnectException("Expected HashMap but found " + obj.javaClass) + if (obj !is HashMap<*, *>) + throw ConnectException("Expected HashMap but found " + obj.javaClass) val raw = obj as Map val data: MutableMap = HashMap() for ((key1, value1) in raw) { @@ -137,16 +164,23 @@ class AirbyteFileOffsetBackingStore(private val offsetFilePath: Path, private va fun setDebeziumProperties(props: Properties) { // debezium engine configuration // https://debezium.io/documentation/reference/2.2/development/engine.html#engine-properties - props.setProperty("offset.storage", "org.apache.kafka.connect.storage.FileOffsetBackingStore") + props.setProperty( + "offset.storage", + "org.apache.kafka.connect.storage.FileOffsetBackingStore" + ) props.setProperty("offset.storage.file.filename", offsetFilePath.toString()) props.setProperty("offset.flush.interval.ms", "1000") // todo: make this longer } companion object { - private val LOGGER: Logger = LoggerFactory.getLogger(AirbyteFileOffsetBackingStore::class.java) + private val LOGGER: Logger = + LoggerFactory.getLogger(AirbyteFileOffsetBackingStore::class.java) private val SQL_SERVER_STATE_MUTATION = BiFunction { key: String, databaseName: String -> - (key.substring(0, key.length - 2) - + ",\"database\":\"" + databaseName + "\"" + key.substring(key.length - 2)) + (key.substring(0, key.length - 2) + + ",\"database\":\"" + + databaseName + + "\"" + + key.substring(key.length - 2)) } private fun byteBufferToString(byteBuffer: ByteBuffer?): String { @@ -159,7 +193,10 @@ class AirbyteFileOffsetBackingStore(private val offsetFilePath: Path, private va return ByteBuffer.wrap(s!!.toByteArray(StandardCharsets.UTF_8)) } - fun initializeState(cdcState: JsonNode?, dbName: Optional): AirbyteFileOffsetBackingStore { + fun initializeState( + cdcState: JsonNode?, + dbName: Optional + ): AirbyteFileOffsetBackingStore { val cdcWorkingDir: Path try { cdcWorkingDir = Files.createTempDirectory(Path.of("/tmp"), "cdc-state-offset") diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/AirbyteSchemaHistoryStorage.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/AirbyteSchemaHistoryStorage.kt index 537a4c2c837f7..d6829b1d3d6f7 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/AirbyteSchemaHistoryStorage.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/AirbyteSchemaHistoryStorage.kt @@ -8,9 +8,6 @@ import com.google.common.annotations.VisibleForTesting import io.airbyte.commons.json.Jsons import io.debezium.document.DocumentReader import io.debezium.document.DocumentWriter -import org.apache.commons.io.FileUtils -import org.slf4j.Logger -import org.slf4j.LoggerFactory import java.io.* import java.nio.charset.Charset import java.nio.charset.StandardCharsets @@ -21,14 +18,20 @@ import java.nio.file.StandardOpenOption import java.util.* import java.util.zip.GZIPInputStream import java.util.zip.GZIPOutputStream +import org.apache.commons.io.FileUtils +import org.slf4j.Logger +import org.slf4j.LoggerFactory /** - * The purpose of this class is : to , 1. Read the contents of the file [.path] which contains - * the schema history at the end of the sync so that it can be saved in state for future syncs. - * Check [.read] 2. Write the saved content back to the file [.path] at the beginning - * of the sync so that debezium can function smoothly. Check persist(Optional<JsonNode>). + * The purpose of this class is : to , 1. Read the contents of the file [.path] which contains the + * schema history at the end of the sync so that it can be saved in state for future syncs. Check + * [.read] 2. Write the saved content back to the file [.path] at the beginning of the sync so that + * debezium can function smoothly. Check persist(Optional<JsonNode>). */ -class AirbyteSchemaHistoryStorage(private val path: Path, private val compressSchemaHistoryForState: Boolean) { +class AirbyteSchemaHistoryStorage( + private val path: Path, + private val compressSchemaHistoryForState: Boolean +) { private val reader: DocumentReader = DocumentReader.defaultReader() private val writer: DocumentWriter = DocumentWriter.defaultWriter() @@ -37,17 +40,6 @@ class AirbyteSchemaHistoryStorage(private val path: Path, private val compressSc val isCompressed: Boolean init { - this.streamName = streamName - this.primaryKey = primaryKey - this.keySequence = keySequence - this.syncCheckpointRecords = syncCheckpointRecords - this.syncCheckpointDuration = syncCheckpointDuration - this.tableName = tableName - this.cursorColumnName = cursorColumnName - this.cursorSqlType = cursorSqlType - this.cause = cause - this.tableSize = tableSize - this.avgRowLength = avgRowLength this.schema = schema this.isCompressed = isCompressed } @@ -56,21 +48,31 @@ class AirbyteSchemaHistoryStorage(private val path: Path, private val compressSc fun read(): SchemaHistory { val fileSizeMB = path.toFile().length().toDouble() / (ONE_MB) if ((fileSizeMB > SIZE_LIMIT_TO_COMPRESS_MB) && compressSchemaHistoryForState) { - LOGGER.info("File Size {} MB is greater than the size limit of {} MB, compressing the content of the file.", fileSizeMB, - SIZE_LIMIT_TO_COMPRESS_MB) + LOGGER.info( + "File Size {} MB is greater than the size limit of {} MB, compressing the content of the file.", + fileSizeMB, + SIZE_LIMIT_TO_COMPRESS_MB + ) val schemaHistory = readCompressed() val compressedSizeMB = calculateSizeOfStringInMB(schemaHistory) if (fileSizeMB > compressedSizeMB) { LOGGER.info("Content Size post compression is {} MB ", compressedSizeMB) } else { - throw RuntimeException("Compressing increased the size of the content. Size before compression " + fileSizeMB + ", after compression " - + compressedSizeMB) + throw RuntimeException( + "Compressing increased the size of the content. Size before compression " + + fileSizeMB + + ", after compression " + + compressedSizeMB + ) } return SchemaHistory(schemaHistory, true) } if (compressSchemaHistoryForState) { - LOGGER.info("File Size {} MB is less than the size limit of {} MB, reading the content of the file without compression.", fileSizeMB, - SIZE_LIMIT_TO_COMPRESS_MB) + LOGGER.info( + "File Size {} MB is less than the size limit of {} MB, reading the content of the file without compression.", + fileSizeMB, + SIZE_LIMIT_TO_COMPRESS_MB + ) } else { LOGGER.info("File Size {} MB.", fileSizeMB) } @@ -108,8 +110,12 @@ class AirbyteSchemaHistoryStorage(private val path: Path, private val compressSc if (!line.isEmpty()) { val record = reader.read(line) val recordAsString = writer.write(record) - gzipOutputStream.write(recordAsString.toByteArray(StandardCharsets.UTF_8)) - gzipOutputStream.write(lineSeparator.toByteArray(StandardCharsets.UTF_8)) + gzipOutputStream.write( + recordAsString.toByteArray(StandardCharsets.UTF_8) + ) + gzipOutputStream.write( + lineSeparator.toByteArray(StandardCharsets.UTF_8) + ) } } } @@ -136,11 +142,13 @@ class AirbyteSchemaHistoryStorage(private val path: Path, private val compressSc } } catch (e: IOException) { throw IllegalStateException( - "Unable to check or create history file at " + path + ": " + e.message, e) + "Unable to check or create history file at " + path + ": " + e.message, + e + ) } } - private fun persist(schemaHistory: SchemaHistory?>?) { + private fun persist(schemaHistory: SchemaHistory>?) { if (schemaHistory!!.schema!!.isEmpty) { return } @@ -164,20 +172,23 @@ class AirbyteSchemaHistoryStorage(private val path: Path, private val compressSc */ private fun writeToFile(fileAsString: String) { try { - val split = fileAsString.split(System.lineSeparator().toRegex()).dropLastWhile { it.isEmpty() }.toTypedArray() + val split = + fileAsString + .split(System.lineSeparator().toRegex()) + .dropLastWhile { it.isEmpty() } + .toTypedArray() for (element in split) { val read = reader.read(element) val line = writer.write(read) - Files - .newBufferedWriter(path, StandardOpenOption.APPEND).use { historyWriter -> - try { - historyWriter.append(line) - historyWriter.newLine() - } catch (e: IOException) { - throw RuntimeException(e) - } - } + Files.newBufferedWriter(path, StandardOpenOption.APPEND).use { historyWriter -> + try { + historyWriter.append(line) + historyWriter.newLine() + } catch (e: IOException) { + throw RuntimeException(e) + } + } } } catch (e: IOException) { throw RuntimeException(e) @@ -186,7 +197,8 @@ class AirbyteSchemaHistoryStorage(private val path: Path, private val compressSc private fun writeCompressedStringToFile(compressedString: String) { try { - ByteArrayInputStream(Jsons.deserialize(compressedString, ByteArray::class.java)).use { inputStream -> + ByteArrayInputStream(Jsons.deserialize(compressedString, ByteArray::class.java)).use { + inputStream -> GZIPInputStream(inputStream).use { gzipInputStream -> FileOutputStream(path.toFile()).use { fileOutputStream -> val buffer = ByteArray(1024) @@ -205,15 +217,20 @@ class AirbyteSchemaHistoryStorage(private val path: Path, private val compressSc fun setDebeziumProperties(props: Properties) { // https://debezium.io/documentation/reference/2.2/operations/debezium-server.html#debezium-source-database-history-class // https://debezium.io/documentation/reference/development/engine.html#_in_the_code - // As mentioned in the documents above, debezium connector for MySQL needs to track the schema + // As mentioned in the documents above, debezium connector for MySQL needs to track the + // schema // changes. If we don't do this, we can't fetch records for the table. - props.setProperty("schema.history.internal", "io.debezium.storage.file.history.FileSchemaHistory") + props.setProperty( + "schema.history.internal", + "io.debezium.storage.file.history.FileSchemaHistory" + ) props.setProperty("schema.history.internal.file.filename", path.toString()) props.setProperty("schema.history.internal.store.only.captured.databases.ddl", "true") } companion object { - private val LOGGER: Logger = LoggerFactory.getLogger(AirbyteSchemaHistoryStorage::class.java) + private val LOGGER: Logger = + LoggerFactory.getLogger(AirbyteSchemaHistoryStorage::class.java) private const val SIZE_LIMIT_TO_COMPRESS_MB: Long = 1 const val ONE_MB: Int = 1024 * 1024 private val UTF8: Charset = StandardCharsets.UTF_8 @@ -223,8 +240,10 @@ class AirbyteSchemaHistoryStorage(private val path: Path, private val compressSc return string.toByteArray(StandardCharsets.UTF_8).size.toDouble() / (ONE_MB) } - fun initializeDBHistory(schemaHistory: SchemaHistory?>?, - compressSchemaHistoryForState: Boolean): AirbyteSchemaHistoryStorage { + fun initializeDBHistory( + schemaHistory: SchemaHistory>?, + compressSchemaHistoryForState: Boolean + ): AirbyteSchemaHistoryStorage { val dbHistoryWorkingDir: Path try { dbHistoryWorkingDir = Files.createTempDirectory(Path.of("/tmp"), "cdc-db-history") @@ -234,7 +253,7 @@ class AirbyteSchemaHistoryStorage(private val path: Path, private val compressSc val dbHistoryFilePath = dbHistoryWorkingDir.resolve("dbhistory.dat") val schemaHistoryManager = - AirbyteSchemaHistoryStorage(dbHistoryFilePath, compressSchemaHistoryForState) + AirbyteSchemaHistoryStorage(dbHistoryFilePath, compressSchemaHistoryForState) schemaHistoryManager.persist(schemaHistory) return schemaHistoryManager } diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/ChangeEventWithMetadata.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/ChangeEventWithMetadata.kt index e5de64be11a0c..8e0a8985e2ffb 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/ChangeEventWithMetadata.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/ChangeEventWithMetadata.kt @@ -10,7 +10,8 @@ import io.debezium.engine.ChangeEvent class ChangeEventWithMetadata(private val event: ChangeEvent) { private val eventKeyAsJson: JsonNode = Jsons.deserialize(event.key()) private val eventValueAsJson: JsonNode = Jsons.deserialize(event.value()) - private val snapshotMetadata: SnapshotMetadata? = SnapshotMetadata.Companion.fromString(eventValueAsJson["source"]["snapshot"].asText()) + private val snapshotMetadata: SnapshotMetadata? = + SnapshotMetadata.Companion.fromString(eventValueAsJson["source"]["snapshot"].asText()) fun event(): ChangeEvent { return event diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumConverterUtils.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumConverterUtils.kt index 8ef146111392d..6c87fb88a35a4 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumConverterUtils.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumConverterUtils.kt @@ -6,14 +6,14 @@ package io.airbyte.cdk.integrations.debezium.internals import io.airbyte.cdk.db.DataTypeUtils.toISO8601String import io.airbyte.cdk.db.DataTypeUtils.toISO8601StringWithMicroseconds import io.debezium.spi.converter.RelationalColumn -import org.slf4j.Logger -import org.slf4j.LoggerFactory import java.sql.Date import java.sql.Timestamp import java.time.Duration import java.time.LocalDate import java.time.LocalDateTime import java.time.format.DateTimeParseException +import org.slf4j.Logger +import org.slf4j.LoggerFactory class DebeziumConverterUtils private constructor() { init { @@ -23,17 +23,16 @@ class DebeziumConverterUtils private constructor() { companion object { private val LOGGER: Logger = LoggerFactory.getLogger(DebeziumConverterUtils::class.java) - /** - * TODO : Replace usage of this method with [io.airbyte.cdk.db.jdbc.DateTimeConverter] - */ + /** TODO : Replace usage of this method with [io.airbyte.cdk.db.jdbc.DateTimeConverter] */ fun convertDate(input: Any): String { /** - * While building this custom converter we were not sure what type debezium could return cause there - * is no mention of it in the documentation. Secondly if you take a look at + * While building this custom converter we were not sure what type debezium could return + * cause there is no mention of it in the documentation. Secondly if you take a look at * [io.debezium.connector.mysql.converters.TinyIntOneToBooleanConverter.converterFor] - * method, even it is handling multiple data types but its not clear under what circumstances which - * data type would be returned. I just went ahead and handled the data types that made sense. - * Secondly, we use LocalDateTime to handle this cause it represents DATETIME datatype in JAVA + * method, even it is handling multiple data types but its not clear under what + * circumstances which data type would be returned. I just went ahead and handled the + * data types that made sense. Secondly, we use LocalDateTime to handle this cause it + * represents DATETIME datatype in JAVA */ if (input is LocalDateTime) { return toISO8601String(input) @@ -44,8 +43,7 @@ class DebeziumConverterUtils private constructor() { } else if (input is Timestamp) { return toISO8601StringWithMicroseconds((input.toInstant())) } else if (input is Number) { - return toISO8601String( - Timestamp(input.toLong()).toLocalDateTime()) + return toISO8601String(Timestamp(input.toLong()).toLocalDateTime()) } else if (input is Date) { return toISO8601String(input) } else if (input is String) { @@ -56,7 +54,10 @@ class DebeziumConverterUtils private constructor() { return input.toString() } } - LOGGER.warn("Uncovered date class type '{}'. Use default converter", input.javaClass.name) + LOGGER.warn( + "Uncovered date class type '{}'. Use default converter", + input.javaClass.name + ) return input.toString() } diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumEventConverter.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumEventConverter.kt index 4a4d52ede172b..806371b69f749 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumEventConverter.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumEventConverter.kt @@ -15,29 +15,32 @@ interface DebeziumEventConverter { companion object { fun buildAirbyteMessage( - source: JsonNode?, - cdcMetadataInjector: CdcMetadataInjector<*>, - emittedAt: Instant, - data: JsonNode?): AirbyteMessage { + source: JsonNode?, + cdcMetadataInjector: CdcMetadataInjector<*>, + emittedAt: Instant, + data: JsonNode? + ): AirbyteMessage { val streamNamespace = cdcMetadataInjector.namespace(source) val streamName = cdcMetadataInjector.name(source) - val airbyteRecordMessage = AirbyteRecordMessage() + val airbyteRecordMessage = + AirbyteRecordMessage() .withStream(streamName) .withNamespace(streamNamespace) .withEmittedAt(emittedAt.toEpochMilli()) .withData(data) return AirbyteMessage() - .withType(AirbyteMessage.Type.RECORD) - .withRecord(airbyteRecordMessage) + .withType(AirbyteMessage.Type.RECORD) + .withRecord(airbyteRecordMessage) } fun addCdcMetadata( - baseNode: ObjectNode, - source: JsonNode, - cdcMetadataInjector: CdcMetadataInjector<*>, - isDelete: Boolean): JsonNode { + baseNode: ObjectNode, + source: JsonNode, + cdcMetadataInjector: CdcMetadataInjector<*>, + isDelete: Boolean + ): JsonNode { val transactionMillis = source["ts_ms"].asLong() val transactionTimestamp = Instant.ofEpochMilli(transactionMillis).toString() diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumPropertiesManager.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumPropertiesManager.kt index 7045e7020f497..d4787d615bc0a 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumPropertiesManager.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumPropertiesManager.kt @@ -8,16 +8,19 @@ import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog import io.debezium.spi.common.ReplacementFunction import java.util.* -abstract class DebeziumPropertiesManager(private val properties: Properties, - private val config: JsonNode, - private val catalog: ConfiguredAirbyteCatalog) { +abstract class DebeziumPropertiesManager( + private val properties: Properties, + private val config: JsonNode, + private val catalog: ConfiguredAirbyteCatalog +) { fun getDebeziumProperties(offsetManager: AirbyteFileOffsetBackingStore): Properties { return getDebeziumProperties(offsetManager, Optional.empty()) } fun getDebeziumProperties( - offsetManager: AirbyteFileOffsetBackingStore, - schemaHistoryManager: Optional): Properties { + offsetManager: AirbyteFileOffsetBackingStore, + schemaHistoryManager: Optional + ): Properties { val props = Properties() props.putAll(properties) @@ -33,7 +36,9 @@ abstract class DebeziumPropertiesManager(private val properties: Properties, props.setProperty("errors.retry.delay.initial.ms", "299") props.setProperty("errors.retry.delay.max.ms", "300") - schemaHistoryManager.ifPresent { m: AirbyteSchemaHistoryStorage? -> m!!.setDebeziumProperties(props) } + schemaHistoryManager.ifPresent { m: AirbyteSchemaHistoryStorage -> + m.setDebeziumProperties(props) + } // https://debezium.io/documentation/reference/2.2/configuration/avro.html props.setProperty("key.converter.schemas.enable", "false") @@ -45,8 +50,10 @@ abstract class DebeziumPropertiesManager(private val properties: Properties, // connection configuration props.putAll(getConnectionConfiguration(config)) - // By default "decimal.handing.mode=precise" which's caused returning this value as a binary. - // The "double" type may cause a loss of precision, so set Debezium's config to store it as a String + // By default "decimal.handing.mode=precise" which's caused returning this value as a + // binary. + // The "double" type may cause a loss of precision, so set Debezium's config to store it as + // a String // explicitly in its Kafka messages for more details see: // https://debezium.io/documentation/reference/2.2/connectors/postgresql.html#postgresql-decimal-types // https://debezium.io/documentation/faq/#how_to_retrieve_decimal_field_from_binary_representation @@ -55,12 +62,14 @@ abstract class DebeziumPropertiesManager(private val properties: Properties, // https://debezium.io/documentation/reference/2.2/connectors/postgresql.html#postgresql-property-max-queue-size-in-bytes props.setProperty("max.queue.size.in.bytes", BYTE_VALUE_256_MB) - // WARNING : Never change the value of this otherwise all the connectors would start syncing from + // WARNING : Never change the value of this otherwise all the connectors would start syncing + // from // scratch. props.setProperty(TOPIC_PREFIX_KEY, sanitizeTopicPrefix(getName(config))) // https://issues.redhat.com/browse/DBZ-7635 // https://cwiki.apache.org/confluence/display/KAFKA/KIP-581%3A+Value+of+optional+null+field+which+has+default+value - // A null value in a column with default value won't be generated correctly in CDC unless we set the + // A null value in a column with default value won't be generated correctly in CDC unless we + // set the // following props.setProperty("value.converter.replace.null.with.default", "false") // includes @@ -73,7 +82,10 @@ abstract class DebeziumPropertiesManager(private val properties: Properties, protected abstract fun getName(config: JsonNode): String - protected abstract fun getIncludeConfiguration(catalog: ConfiguredAirbyteCatalog, config: JsonNode?): Properties + protected abstract fun getIncludeConfiguration( + catalog: ConfiguredAirbyteCatalog, + config: JsonNode? + ): Properties companion object { private const val BYTE_VALUE_256_MB = (256 * 1024 * 1024).toString() @@ -90,7 +102,9 @@ abstract class DebeziumPropertiesManager(private val properties: Properties, if (isValidCharacter(c)) { sanitizedNameBuilder.append(c) } else { - sanitizedNameBuilder.append(ReplacementFunction.UNDERSCORE_REPLACEMENT.replace(c)) + sanitizedNameBuilder.append( + ReplacementFunction.UNDERSCORE_REPLACEMENT.replace(c) + ) changed = true } } @@ -105,7 +119,12 @@ abstract class DebeziumPropertiesManager(private val properties: Properties, // We need to keep the validation rule the same as debezium engine, which is defined here: // https://github.com/debezium/debezium/blob/c51ef3099a688efb41204702d3aa6d4722bb4825/debezium-core/src/main/java/io/debezium/schema/AbstractTopicNamingStrategy.java#L178 private fun isValidCharacter(c: Char): Boolean { - return c == '.' || c == '_' || c == '-' || c >= 'A' && c <= 'Z' || c >= 'a' && c <= 'z' || c >= '0' && c <= '9' + return c == '.' || + c == '_' || + c == '-' || + c >= 'A' && c <= 'Z' || + c >= 'a' && c <= 'z' || + c >= '0' && c <= '9' } } } diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumRecordIterator.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumRecordIterator.kt index 08614d0747f41..a0b0253e4d684 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumRecordIterator.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumRecordIterator.kt @@ -9,15 +9,15 @@ import io.airbyte.cdk.integrations.debezium.CdcTargetPosition import io.airbyte.commons.lang.MoreBooleans import io.airbyte.commons.util.AutoCloseableIterator import io.debezium.engine.ChangeEvent -import org.apache.kafka.connect.source.SourceRecord -import org.slf4j.Logger -import org.slf4j.LoggerFactory import java.lang.reflect.Field import java.time.Duration import java.time.LocalDateTime import java.util.* import java.util.concurrent.* import java.util.function.Supplier +import org.apache.kafka.connect.source.SourceRecord +import org.slf4j.Logger +import org.slf4j.LoggerFactory /** * The record iterator is the consumer (in the producer / consumer relationship with debezium) @@ -30,19 +30,22 @@ import java.util.function.Supplier * publisher is not closed. Even after the publisher is closed, the consumer will finish processing * any produced records before closing. */ -class DebeziumRecordIterator(private val queue: LinkedBlockingQueue>, - private val targetPosition: CdcTargetPosition, - private val publisherStatusSupplier: Supplier, - private val debeziumShutdownProcedure: DebeziumShutdownProcedure>, - private val firstRecordWaitTime: Duration, - subsequentRecordWaitTime: Duration?) : AbstractIterator(), AutoCloseableIterator { - private val heartbeatEventSourceField: MutableMap?>, Field?> = HashMap(1) +class DebeziumRecordIterator( + private val queue: LinkedBlockingQueue>, + private val targetPosition: CdcTargetPosition, + private val publisherStatusSupplier: Supplier, + private val debeziumShutdownProcedure: DebeziumShutdownProcedure>, + private val firstRecordWaitTime: Duration, + subsequentRecordWaitTime: Duration? +) : AbstractIterator(), AutoCloseableIterator { + private val heartbeatEventSourceField: MutableMap?>, Field?> = + HashMap(1) private val subsequentRecordWaitTime: Duration = firstRecordWaitTime.dividedBy(2) private var receivedFirstRecord = false private var hasSnapshotFinished = true private var tsLastHeartbeat: LocalDateTime? = null - private var lastHeartbeatPosition: T = null + private var lastHeartbeatPosition: T? = null private var maxInstanceOfNoRecordsFound = 0 private var signalledDebeziumEngineShutdown = false @@ -55,24 +58,33 @@ class DebeziumRecordIterator(private val queue: LinkedBlockingQueue? - val waitTime = if (receivedFirstRecord) this.subsequentRecordWaitTime else this.firstRecordWaitTime + val waitTime = + if (receivedFirstRecord) this.subsequentRecordWaitTime else this.firstRecordWaitTime try { next = queue.poll(waitTime.seconds, TimeUnit.SECONDS) } catch (e: InterruptedException) { throw RuntimeException(e) } - // if within the timeout, the consumer could not get a record, it is time to tell the producer to + // if within the timeout, the consumer could not get a record, it is time to tell the + // producer to // shutdown. if (next == null) { - if (!receivedFirstRecord || hasSnapshotFinished || maxInstanceOfNoRecordsFound >= 10) { - requestClose(String.format("No records were returned by Debezium in the timeout seconds %s, closing the engine and iterator", - waitTime.seconds)) + if ( + !receivedFirstRecord || hasSnapshotFinished || maxInstanceOfNoRecordsFound >= 10 + ) { + requestClose( + String.format( + "No records were returned by Debezium in the timeout seconds %s, closing the engine and iterator", + waitTime.seconds + ) + ) } LOGGER.info("no record found. polling again.") maxInstanceOfNoRecordsFound++ @@ -85,11 +97,16 @@ class DebeziumRecordIterator(private val queue: LinkedBlockingQueue(private val queue: LinkedBlockingQueue(private val queue: LinkedBlockingQueue? try { - event = debeziumShutdownProcedure.recordsRemainingAfterShutdown.poll(100, TimeUnit.MILLISECONDS) + event = + debeziumShutdownProcedure.recordsRemainingAfterShutdown.poll( + 100, + TimeUnit.MILLISECONDS + ) } catch (e: InterruptedException) { throw RuntimeException(e) } @@ -138,20 +159,20 @@ class DebeziumRecordIterator(private val queue: LinkedBlockingQueue(private val queue: LinkedBlockingQueue): Boolean { - return targetPosition.isHeartbeatSupported && Objects.nonNull(event) && !event.value()!!.contains("source") + return targetPosition.isHeartbeatSupported && + Objects.nonNull(event) && + !event.value()!!.contains("source") } private fun heartbeatPosNotChanging(): Boolean { if (this.tsLastHeartbeat == null) { return false } - val timeElapsedSinceLastHeartbeatTs = Duration.between(this.tsLastHeartbeat, LocalDateTime.now()) - LOGGER.info("Time since last hb_pos change {}s", timeElapsedSinceLastHeartbeatTs.toSeconds()) + val timeElapsedSinceLastHeartbeatTs = + Duration.between(this.tsLastHeartbeat, LocalDateTime.now()) + LOGGER.info( + "Time since last hb_pos change {}s", + timeElapsedSinceLastHeartbeatTs.toSeconds() + ) // wait time for no change in heartbeat position is half of initial waitTime return timeElapsedSinceLastHeartbeatTs.compareTo(firstRecordWaitTime.dividedBy(2)) > 0 } @@ -192,7 +219,7 @@ class DebeziumRecordIterator(private val queue: LinkedBlockingQueue): T? { + internal fun getHeartbeatPosition(heartbeatEvent: ChangeEvent): T { try { val eventClass: Class?> = heartbeatEvent.javaClass val f: Field? @@ -204,7 +231,10 @@ class DebeziumRecordIterator(private val queue: LinkedBlockingQueue 1) { - LOGGER.warn("Field Cache size growing beyond expected size of 1, size is " + heartbeatEventSourceField.size) + LOGGER.warn( + "Field Cache size growing beyond expected size of 1, size is " + + heartbeatEventSourceField.size + ) } } diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumRecordPublisher.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumRecordPublisher.kt index d382bec64a326..4e0bfc1e14e8f 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumRecordPublisher.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumRecordPublisher.kt @@ -7,18 +7,19 @@ import io.debezium.engine.ChangeEvent import io.debezium.engine.DebeziumEngine import io.debezium.engine.format.Json import io.debezium.engine.spi.OffsetCommitPolicy -import org.slf4j.Logger -import org.slf4j.LoggerFactory import java.util.* import java.util.concurrent.* import java.util.concurrent.atomic.AtomicBoolean import java.util.concurrent.atomic.AtomicReference +import org.slf4j.Logger +import org.slf4j.LoggerFactory /** * The purpose of this class is to initialize and spawn the debezium engine with the right * properties to fetch records */ -class DebeziumRecordPublisher(private val debeziumPropertiesManager: DebeziumPropertiesManager) : AutoCloseable { +class DebeziumRecordPublisher(private val debeziumPropertiesManager: DebeziumPropertiesManager) : + AutoCloseable { private val executor: ExecutorService = Executors.newSingleThreadExecutor() private var engine: DebeziumEngine>? = null private val hasClosed = AtomicBoolean(false) @@ -26,14 +27,23 @@ class DebeziumRecordPublisher(private val debeziumPropertiesManager: DebeziumPro private val thrownError = AtomicReference() private val engineLatch = CountDownLatch(1) - fun start(queue: BlockingQueue>, - offsetManager: AirbyteFileOffsetBackingStore, - schemaHistoryManager: Optional) { - engine = DebeziumEngine.create(Json::class.java) - .using(debeziumPropertiesManager.getDebeziumProperties(offsetManager, schemaHistoryManager)) + fun start( + queue: BlockingQueue>, + offsetManager: AirbyteFileOffsetBackingStore, + schemaHistoryManager: Optional + ) { + engine = + DebeziumEngine.create(Json::class.java) + .using( + debeziumPropertiesManager.getDebeziumProperties( + offsetManager, + schemaHistoryManager + ) + ) .using(OffsetCommitPolicy.AlwaysCommitOffsetPolicy()) .notifying { e: ChangeEvent -> - // debezium outputs a tombstone event that has a value of null. this is an artifact of how it + // debezium outputs a tombstone event that has a value of null. this is an + // artifact of how it // interacts with kafka. we want to ignore it. // more on the tombstone: // https://debezium.io/documentation/reference/2.2/transformations/event-flattening.html @@ -47,13 +57,17 @@ class DebeziumRecordPublisher(private val debeziumPropertiesManager: DebeziumPro } } .using { success: Boolean, message: String?, error: Throwable? -> - LOGGER.info("Debezium engine shutdown. Engine terminated successfully : {}", success) + LOGGER.info( + "Debezium engine shutdown. Engine terminated successfully : {}", + success + ) LOGGER.info(message) if (!success) { if (error != null) { thrownError.set(error) } else { - // There are cases where Debezium doesn't succeed but only fills the message field. + // There are cases where Debezium doesn't succeed but only fills the + // message field. // In that case, we still want to fail loud and clear thrownError.set(RuntimeException(message)) } diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumShutdownProcedure.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumShutdownProcedure.kt index da5c13996f6a3..939303c1cc738 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumShutdownProcedure.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumShutdownProcedure.kt @@ -5,18 +5,20 @@ package io.airbyte.cdk.integrations.debezium.internals import io.airbyte.commons.concurrency.VoidCallable import io.airbyte.commons.lang.MoreBooleans -import org.slf4j.Logger -import org.slf4j.LoggerFactory import java.util.concurrent.* import java.util.function.Supplier +import org.slf4j.Logger +import org.slf4j.LoggerFactory /** * This class has the logic for shutting down Debezium Engine in graceful manner. We made it Generic * to allow us to write tests easily. */ -class DebeziumShutdownProcedure(private val sourceQueue: LinkedBlockingQueue, - private val debeziumThreadRequestClose: VoidCallable, - private val publisherStatusSupplier: Supplier) { +class DebeziumShutdownProcedure( + private val sourceQueue: LinkedBlockingQueue, + private val debeziumThreadRequestClose: VoidCallable, + private val publisherStatusSupplier: Supplier +) { private val targetQueue = LinkedBlockingQueue() private val executorService: ExecutorService private var exception: Throwable? = null @@ -24,13 +26,13 @@ class DebeziumShutdownProcedure(private val sourceQueue: LinkedBlockingQueue< init { this.hasTransferThreadShutdown = false - this.executorService = Executors.newSingleThreadExecutor { r: Runnable? -> - val thread = Thread(r, "queue-data-transfer-thread") - thread.uncaughtExceptionHandler = Thread.UncaughtExceptionHandler { t: Thread?, e: Throwable? -> - exception = e + this.executorService = + Executors.newSingleThreadExecutor { r: Runnable? -> + val thread = Thread(r, "queue-data-transfer-thread") + thread.uncaughtExceptionHandler = + Thread.UncaughtExceptionHandler { t: Thread?, e: Throwable? -> exception = e } + thread } - thread - } } private fun transfer(): Runnable { @@ -60,19 +62,21 @@ class DebeziumShutdownProcedure(private val sourceQueue: LinkedBlockingQueue< val recordsRemainingAfterShutdown: LinkedBlockingQueue get() { if (!hasTransferThreadShutdown) { - LOGGER.warn("Queue transfer thread has not shut down, some records might be missing.") + LOGGER.warn( + "Queue transfer thread has not shut down, some records might be missing." + ) } return targetQueue } /** - * This method triggers the shutdown of Debezium Engine. When we trigger Debezium shutdown, the main - * thread pauses, as a result we stop reading data from the [sourceQueue] and since the queue - * is of fixed size, if it's already at capacity, Debezium won't be able to put remaining records in - * the queue. So before we trigger Debezium shutdown, we initiate a transfer of the records from the - * [sourceQueue] to a new queue i.e. [targetQueue]. This allows Debezium to continue to - * put records in the [sourceQueue] and once done, gracefully shutdown. After the shutdown is - * complete we just have to read the remaining records from the [targetQueue] + * This method triggers the shutdown of Debezium Engine. When we trigger Debezium shutdown, the + * main thread pauses, as a result we stop reading data from the [sourceQueue] and since the + * queue is of fixed size, if it's already at capacity, Debezium won't be able to put remaining + * records in the queue. So before we trigger Debezium shutdown, we initiate a transfer of the + * records from the [sourceQueue] to a new queue i.e. [targetQueue]. This allows Debezium to + * continue to put records in the [sourceQueue] and once done, gracefully shutdown. After the + * shutdown is complete we just have to read the remaining records from the [targetQueue] */ fun initiateShutdownProcedure() { if (hasEngineShutDown()) { diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumStateDecoratingIterator.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumStateDecoratingIterator.kt index 4711a6320c0aa..2c1313f99f326 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumStateDecoratingIterator.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumStateDecoratingIterator.kt @@ -8,42 +8,40 @@ import io.airbyte.cdk.integrations.debezium.CdcStateHandler import io.airbyte.cdk.integrations.debezium.CdcTargetPosition import io.airbyte.protocol.models.v0.AirbyteMessage import io.airbyte.protocol.models.v0.AirbyteStateStats +import java.time.Duration +import java.time.OffsetDateTime import org.apache.kafka.connect.errors.ConnectException import org.slf4j.Logger import org.slf4j.LoggerFactory -import java.time.Duration -import java.time.OffsetDateTime /** * This class encapsulates CDC change events and adds the required functionality to create * checkpoints for CDC replications. That way, if the process fails in the middle of a long sync, it * will be able to recover for any acknowledged checkpoint in the next syncs. */ -class DebeziumStateDecoratingIterator(private val changeEventIterator: Iterator, - private val cdcStateHandler: CdcStateHandler, - private val targetPosition: CdcTargetPosition, - private val eventConverter: DebeziumEventConverter, - offsetManager: AirbyteFileOffsetBackingStore, - private val trackSchemaHistory: Boolean, - private val schemaHistoryManager: AirbyteSchemaHistoryStorage?, - checkpointDuration: Duration, - checkpointRecords: Long) : AbstractIterator(), MutableIterator { +class DebeziumStateDecoratingIterator( + private val changeEventIterator: Iterator, + private val cdcStateHandler: CdcStateHandler, + private val targetPosition: CdcTargetPosition, + private val eventConverter: DebeziumEventConverter, + offsetManager: AirbyteFileOffsetBackingStore, + private val trackSchemaHistory: Boolean, + private val schemaHistoryManager: AirbyteSchemaHistoryStorage?, + checkpointDuration: Duration, + checkpointRecords: Long +) : AbstractIterator(), MutableIterator { private val offsetManager: AirbyteFileOffsetBackingStore? = offsetManager private var isSyncFinished = false /** - * These parameters control when a checkpoint message has to be sent in a CDC integration. We can - * emit a checkpoint when any of the following two conditions are met. - * + * These parameters control when a checkpoint message has to be sent in a CDC integration. We + * can emit a checkpoint when any of the following two conditions are met. * * 1. The amount of records in the current loop (`SYNC_CHECKPOINT_RECORDS`) is higher than a * threshold defined by `SYNC_CHECKPOINT_RECORDS`. * - * - * 2. Time between checkpoints (`dateTimeLastSync`) is higher than a `Duration` defined - * at `SYNC_CHECKPOINT_SECONDS`. - * - * + * 2. Time between checkpoints (`dateTimeLastSync`) is higher than a `Duration` defined at + * `SYNC_CHECKPOINT_SECONDS`. */ private val syncCheckpointDuration = checkpointDuration private val syncCheckpointRecords = checkpointRecords @@ -55,21 +53,22 @@ class DebeziumStateDecoratingIterator(private val changeEventIterator: Iterat /** * `checkpointOffsetToSend` is used as temporal storage for the offset that we want to send as * message. As Debezium is reading records faster that we process them, if we try to send - * `offsetManger.read()` offset, it is possible that the state is behind the record we are currently - * propagating. To avoid that, we store the offset as soon as we reach the checkpoint threshold - * (time or records) and we wait to send it until we are sure that the record we are processing is - * behind the offset to be sent. + * `offsetManger.read()` offset, it is possible that the state is behind the record we are + * currently propagating. To avoid that, we store the offset as soon as we reach the checkpoint + * threshold (time or records) and we wait to send it until we are sure that the record we are + * processing is behind the offset to be sent. */ private val checkpointOffsetToSend = HashMap() /** * `previousCheckpointOffset` is used to make sure we don't send duplicated states with the same - * offset. Is it possible that the offset Debezium report doesn't move for a period of time, and if - * we just rely on the `offsetManger.read()`, there is a chance to sent duplicate states, generating - * an unneeded usage of networking and processing. + * offset. Is it possible that the offset Debezium report doesn't move for a period of time, and + * if we just rely on the `offsetManger.read()`, there is a chance to sent duplicate states, + * generating an unneeded usage of networking and processing. */ private val initialOffset: HashMap - private val previousCheckpointOffset: HashMap? = offsetManager.read() as HashMap + private val previousCheckpointOffset: HashMap? = + offsetManager.read() as HashMap /** * @param changeEventIterator Base iterator that we want to enrich with checkpoint messages @@ -88,15 +87,12 @@ class DebeziumStateDecoratingIterator(private val changeEventIterator: Iterat } /** - * Computes the next record retrieved from Source stream. Emits state messages as checkpoints based - * on number of records or time lapsed. - * - * - * - * If this method throws an exception, it will propagate outward to the `hasNext` or - * `next` invocation that invoked this method. Any further attempts to use the iterator will - * result in an [IllegalStateException]. + * Computes the next record retrieved from Source stream. Emits state messages as checkpoints + * based on number of records or time lapsed. * + * If this method throws an exception, it will propagate outward to the `hasNext` or `next` + * invocation that invoked this method. Any further attempts to use the iterator will result in + * an [IllegalStateException]. * * @return [AirbyteStateMessage] containing CDC data or state checkpoint message. */ @@ -118,27 +114,41 @@ class DebeziumStateDecoratingIterator(private val changeEventIterator: Iterat val event = changeEventIterator.next() if (cdcStateHandler.isCdcCheckpointEnabled) { - if (checkpointOffsetToSend.isEmpty() && + if ( + checkpointOffsetToSend.isEmpty() && (recordsLastSync >= syncCheckpointRecords || - Duration.between(dateTimeLastSync, OffsetDateTime.now()).compareTo(syncCheckpointDuration) > 0)) { - // Using temporal variable to avoid reading teh offset twice, one in the condition and another in + Duration.between(dateTimeLastSync, OffsetDateTime.now()) + .compareTo(syncCheckpointDuration) > 0) + ) { + // Using temporal variable to avoid reading teh offset twice, one in the + // condition and another in // the assignation try { val temporalOffset = offsetManager!!.read() as HashMap - if (!targetPosition.isSameOffset(previousCheckpointOffset, temporalOffset)) { + if ( + !targetPosition.isSameOffset(previousCheckpointOffset, temporalOffset) + ) { checkpointOffsetToSend.putAll(temporalOffset) } } catch (e: ConnectException) { - LOGGER.warn("Offset file is being written by Debezium. Skipping CDC checkpoint in this loop.") + LOGGER.warn( + "Offset file is being written by Debezium. Skipping CDC checkpoint in this loop." + ) } } - if (checkpointOffsetToSend.size == 1 && changeEventIterator.hasNext() - && !event.isSnapshotEvent) { + if ( + checkpointOffsetToSend.size == 1 && + changeEventIterator.hasNext() && + !event.isSnapshotEvent + ) { if (targetPosition.isEventAheadOffset(checkpointOffsetToSend, event)) { sendCheckpointMessage = true } else { - LOGGER.info("Encountered {} records with the same event offset", recordsLastSync) + LOGGER.info( + "Encountered {} records with the same event offset", + recordsLastSync + ) } } } @@ -149,7 +159,9 @@ class DebeziumStateDecoratingIterator(private val changeEventIterator: Iterat isSyncFinished = true val syncFinishedOffset = offsetManager!!.read() as HashMap - if (recordsAllSyncs == 0L && targetPosition.isSameOffset(initialOffset, syncFinishedOffset)) { + if ( + recordsAllSyncs == 0L && targetPosition.isSameOffset(initialOffset, syncFinishedOffset) + ) { // Edge case where no progress has been made: wrap up the // sync by returning the initial offset instead of the // current offset. We do this because we found that @@ -162,9 +174,7 @@ class DebeziumStateDecoratingIterator(private val changeEventIterator: Iterat return createStateMessage(syncFinishedOffset, recordsLastSync) } - /** - * Initialize or reset the checkpoint variables. - */ + /** Initialize or reset the checkpoint variables. */ private fun resetCheckpointValues() { sendCheckpointMessage = false checkpointOffsetToSend.clear() @@ -178,7 +188,10 @@ class DebeziumStateDecoratingIterator(private val changeEventIterator: Iterat * * @return [AirbyteStateMessage] which includes offset and schema history if used. */ - private fun createStateMessage(offset: Map?, recordCount: Long): AirbyteMessage? { + private fun createStateMessage( + offset: Map?, + recordCount: Long + ): AirbyteMessage? { if (trackSchemaHistory && schemaHistoryManager == null) { throw RuntimeException("Schema History Tracking is true but manager is not initialised") } @@ -192,6 +205,7 @@ class DebeziumStateDecoratingIterator(private val changeEventIterator: Iterat } companion object { - private val LOGGER: Logger = LoggerFactory.getLogger(DebeziumStateDecoratingIterator::class.java) + private val LOGGER: Logger = + LoggerFactory.getLogger(DebeziumStateDecoratingIterator::class.java) } } diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumStateUtil.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumStateUtil.kt index 1c5a99e64660b..fbc6534eb0915 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumStateUtil.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumStateUtil.kt @@ -5,19 +5,17 @@ package io.airbyte.cdk.integrations.debezium.internals import io.debezium.config.Configuration import io.debezium.embedded.KafkaConnectUtil +import java.lang.Boolean +import java.util.* +import kotlin.String import org.apache.kafka.connect.json.JsonConverter import org.apache.kafka.connect.json.JsonConverterConfig import org.apache.kafka.connect.runtime.WorkerConfig import org.apache.kafka.connect.runtime.standalone.StandaloneConfig import org.apache.kafka.connect.storage.FileOffsetBackingStore import org.apache.kafka.connect.storage.OffsetStorageReaderImpl -import java.lang.Boolean -import java.util.* -import kotlin.String -/** - * Represents a utility class that assists with the parsing of Debezium offset state. - */ +/** Represents a utility class that assists with the parsing of Debezium offset state. */ interface DebeziumStateUtil { /** * Creates and starts a [FileOffsetBackingStore] that is used to store the tracked Debezium @@ -38,8 +36,8 @@ interface DebeziumStateUtil { val keyConverter: JsonConverter? /** - * Creates and returns a [JsonConverter] that can be used to parse keys in the Debezium offset - * state storage. + * Creates and returns a [JsonConverter] that can be used to parse keys in the Debezium + * offset state storage. * * @return A [JsonConverter] for key conversion. */ @@ -56,12 +54,19 @@ interface DebeziumStateUtil { * @param fileOffsetBackingStore The [FileOffsetBackingStore] that contains the offset state * saved to disk. * @param properties The Debezium configuration properties for the selected Debezium connector. - * @return An [OffsetStorageReaderImpl] instance that can be used to load the offset state - * from the offset file storage. + * @return An [OffsetStorageReaderImpl] instance that can be used to load the offset state from + * the offset file storage. */ - fun getOffsetStorageReader(fileOffsetBackingStore: FileOffsetBackingStore?, properties: Properties): OffsetStorageReaderImpl? { - return OffsetStorageReaderImpl(fileOffsetBackingStore, properties.getProperty(CONNECTOR_NAME_PROPERTY), keyConverter, - valueConverter) + fun getOffsetStorageReader( + fileOffsetBackingStore: FileOffsetBackingStore?, + properties: Properties + ): OffsetStorageReaderImpl? { + return OffsetStorageReaderImpl( + fileOffsetBackingStore, + properties.getProperty(CONNECTOR_NAME_PROPERTY), + keyConverter, + valueConverter + ) } val valueConverter: JsonConverter? @@ -79,13 +84,13 @@ interface DebeziumStateUtil { companion object { /** - * The name of the Debezium property that contains the unique name for the Debezium connector. + * The name of the Debezium property that contains the unique name for the Debezium + * connector. */ const val CONNECTOR_NAME_PROPERTY: String = "name" - /** - * Configuration for offset state key/value converters. - */ - val INTERNAL_CONVERTER_CONFIG: Map = java.util.Map.of(JsonConverterConfig.SCHEMAS_ENABLE_CONFIG, Boolean.FALSE.toString()) + /** Configuration for offset state key/value converters. */ + val INTERNAL_CONVERTER_CONFIG: Map = + java.util.Map.of(JsonConverterConfig.SCHEMAS_ENABLE_CONFIG, Boolean.FALSE.toString()) } } diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/RecordWaitTimeUtil.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/RecordWaitTimeUtil.kt index 37c4e13fbdd52..17bf9e0e1512e 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/RecordWaitTimeUtil.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/RecordWaitTimeUtil.kt @@ -4,10 +4,10 @@ package io.airbyte.cdk.integrations.debezium.internals import com.fasterxml.jackson.databind.JsonNode -import org.slf4j.Logger -import org.slf4j.LoggerFactory import java.time.Duration import java.util.* +import org.slf4j.Logger +import org.slf4j.LoggerFactory object RecordWaitTimeUtil { private val LOGGER: Logger = LoggerFactory.getLogger(RecordWaitTimeUtil::class.java) @@ -28,9 +28,15 @@ object RecordWaitTimeUtil { val firstRecordWaitSeconds = getFirstRecordWaitSeconds(config) if (firstRecordWaitSeconds.isPresent) { val seconds = firstRecordWaitSeconds.get() - require(!(seconds < MIN_FIRST_RECORD_WAIT_TIME.seconds || seconds > MAX_FIRST_RECORD_WAIT_TIME.seconds)) { - String.format("initial_waiting_seconds must be between %d and %d seconds", - MIN_FIRST_RECORD_WAIT_TIME.seconds, MAX_FIRST_RECORD_WAIT_TIME.seconds) + require( + !(seconds < MIN_FIRST_RECORD_WAIT_TIME.seconds || + seconds > MAX_FIRST_RECORD_WAIT_TIME.seconds) + ) { + String.format( + "initial_waiting_seconds must be between %d and %d seconds", + MIN_FIRST_RECORD_WAIT_TIME.seconds, + MAX_FIRST_RECORD_WAIT_TIME.seconds + ) } } } @@ -43,12 +49,16 @@ object RecordWaitTimeUtil { if (firstRecordWaitSeconds.isPresent) { firstRecordWaitTime = Duration.ofSeconds(firstRecordWaitSeconds.get().toLong()) if (!isTest && firstRecordWaitTime.compareTo(MIN_FIRST_RECORD_WAIT_TIME) < 0) { - LOGGER.warn("First record waiting time is overridden to {} minutes, which is the min time allowed for safety.", - MIN_FIRST_RECORD_WAIT_TIME.toMinutes()) + LOGGER.warn( + "First record waiting time is overridden to {} minutes, which is the min time allowed for safety.", + MIN_FIRST_RECORD_WAIT_TIME.toMinutes() + ) firstRecordWaitTime = MIN_FIRST_RECORD_WAIT_TIME } else if (!isTest && firstRecordWaitTime.compareTo(MAX_FIRST_RECORD_WAIT_TIME) > 0) { - LOGGER.warn("First record waiting time is overridden to {} minutes, which is the max time allowed for safety.", - MAX_FIRST_RECORD_WAIT_TIME.toMinutes()) + LOGGER.warn( + "First record waiting time is overridden to {} minutes, which is the max time allowed for safety.", + MAX_FIRST_RECORD_WAIT_TIME.toMinutes() + ) firstRecordWaitTime = MAX_FIRST_RECORD_WAIT_TIME } } diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/RelationalDbDebeziumEventConverter.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/RelationalDbDebeziumEventConverter.kt index 59cc5388a158b..b7e09e7c9b9ed 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/RelationalDbDebeziumEventConverter.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/RelationalDbDebeziumEventConverter.kt @@ -9,7 +9,10 @@ import io.airbyte.cdk.integrations.debezium.CdcMetadataInjector import io.airbyte.protocol.models.v0.AirbyteMessage import java.time.Instant -class RelationalDbDebeziumEventConverter(private val cdcMetadataInjector: CdcMetadataInjector<*>, private val emittedAt: Instant) : DebeziumEventConverter { +class RelationalDbDebeziumEventConverter( + private val cdcMetadataInjector: CdcMetadataInjector<*>, + private val emittedAt: Instant +) : DebeziumEventConverter { override fun toAirbyteMessage(event: ChangeEventWithMetadata): AirbyteMessage { val debeziumEvent = event.eventValueAsJson() val before: JsonNode = debeziumEvent!!.get(DebeziumEventConverter.Companion.BEFORE_EVENT) @@ -17,7 +20,18 @@ class RelationalDbDebeziumEventConverter(private val cdcMetadataInjector: CdcMet val source: JsonNode = debeziumEvent.get(DebeziumEventConverter.Companion.SOURCE_EVENT) val baseNode = (if (after.isNull) before else after) as ObjectNode - val data: JsonNode = DebeziumEventConverter.Companion.addCdcMetadata(baseNode, source, cdcMetadataInjector, after.isNull) - return DebeziumEventConverter.Companion.buildAirbyteMessage(source, cdcMetadataInjector, emittedAt, data) + val data: JsonNode = + DebeziumEventConverter.Companion.addCdcMetadata( + baseNode, + source, + cdcMetadataInjector, + after.isNull + ) + return DebeziumEventConverter.Companion.buildAirbyteMessage( + source, + cdcMetadataInjector, + emittedAt, + data + ) } } diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/RelationalDbDebeziumPropertiesManager.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/RelationalDbDebeziumPropertiesManager.kt index 7bf6843e1322c..c78ead79f77d1 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/RelationalDbDebeziumPropertiesManager.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/RelationalDbDebeziumPropertiesManager.kt @@ -9,15 +9,17 @@ import io.airbyte.protocol.models.v0.AirbyteStream import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog import io.airbyte.protocol.models.v0.ConfiguredAirbyteStream import io.airbyte.protocol.models.v0.SyncMode -import org.codehaus.plexus.util.StringUtils import java.util.* import java.util.regex.Pattern import java.util.stream.Collectors import java.util.stream.StreamSupport +import org.codehaus.plexus.util.StringUtils -class RelationalDbDebeziumPropertiesManager(properties: Properties, - config: JsonNode, - catalog: ConfiguredAirbyteCatalog) : DebeziumPropertiesManager(properties, config, catalog) { +class RelationalDbDebeziumPropertiesManager( + properties: Properties, + config: JsonNode, + catalog: ConfiguredAirbyteCatalog +) : DebeziumPropertiesManager(properties, config, catalog) { override fun getConnectionConfiguration(config: JsonNode): Properties { val properties = Properties() @@ -38,7 +40,10 @@ class RelationalDbDebeziumPropertiesManager(properties: Properties, return config[JdbcUtils.DATABASE_KEY].asText() } - override fun getIncludeConfiguration(catalog: ConfiguredAirbyteCatalog, config: JsonNode?): Properties { + override fun getIncludeConfiguration( + catalog: ConfiguredAirbyteCatalog, + config: JsonNode? + ): Properties { val properties = Properties() // table selection @@ -60,12 +65,15 @@ class RelationalDbDebeziumPropertiesManager(properties: Properties, // "name": "table2 // } -------> info "schema1.table1, schema2.table2" - return catalog.streams.stream() - .filter { s: ConfiguredAirbyteStream -> s.syncMode == SyncMode.INCREMENTAL } - .map { obj: ConfiguredAirbyteStream -> obj.stream } - .map { stream: AirbyteStream -> stream.namespace + "." + stream.name } // debezium needs commas escaped to split properly - .map { x: String? -> StringUtils.escape(Pattern.quote(x), ",".toCharArray(), "\\,") } - .collect(Collectors.joining(",")) + return catalog.streams + .stream() + .filter { s: ConfiguredAirbyteStream -> s.syncMode == SyncMode.INCREMENTAL } + .map { obj: ConfiguredAirbyteStream -> obj.stream } + .map { stream: AirbyteStream -> + stream.namespace + "." + stream.name + } // debezium needs commas escaped to split properly + .map { x: String -> StringUtils.escape(Pattern.quote(x), ",".toCharArray(), "\\,") } + .collect(Collectors.joining(",")) } fun getColumnIncludeList(catalog: ConfiguredAirbyteCatalog): String { @@ -82,15 +90,17 @@ class RelationalDbDebeziumPropertiesManager(properties: Properties, // } // } -------> info "schema1.table1.(column1 | column2)" - return catalog.streams.stream() - .filter { s: ConfiguredAirbyteStream -> s.syncMode == SyncMode.INCREMENTAL } - .map { obj: ConfiguredAirbyteStream -> obj.stream } - .map { s: AirbyteStream -> - val fields = parseFields(s.jsonSchema["properties"].fieldNames()) - Pattern.quote(s.namespace + "." + s.name) + (if (StringUtils.isNotBlank(fields)) "\\.$fields" else "") - } - .map { x: String? -> StringUtils.escape(x, ",".toCharArray(), "\\,") } - .collect(Collectors.joining(",")) + return catalog.streams + .stream() + .filter { s: ConfiguredAirbyteStream -> s.syncMode == SyncMode.INCREMENTAL } + .map { obj: ConfiguredAirbyteStream -> obj.stream } + .map { s: AirbyteStream -> + val fields = parseFields(s.jsonSchema["properties"].fieldNames()) + Pattern.quote(s.namespace + "." + s.name) + + (if (StringUtils.isNotBlank(fields)) "\\.$fields" else "") + } + .map { x: String? -> StringUtils.escape(x, ",".toCharArray(), "\\,") } + .collect(Collectors.joining(",")) } private fun parseFields(fieldNames: Iterator?): String { @@ -99,8 +109,8 @@ class RelationalDbDebeziumPropertiesManager(properties: Properties, } val iter = Iterable { fieldNames } return StreamSupport.stream(iter.spliterator(), false) - .map { f: String? -> Pattern.quote(f) } - .collect(Collectors.joining("|", "(", ")")) + .map { f: String -> Pattern.quote(f) } + .collect(Collectors.joining("|", "(", ")")) } } } diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/SnapshotMetadata.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/SnapshotMetadata.kt index 7bc17918f4c7a..f34141431ca17 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/SnapshotMetadata.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/debezium/internals/SnapshotMetadata.kt @@ -15,7 +15,8 @@ enum class SnapshotMetadata { NULL; companion object { - private val ENTRIES_OF_SNAPSHOT_EVENTS: Set = ImmutableSet.of(TRUE, FIRST, FIRST_IN_DATA_COLLECTION, LAST_IN_DATA_COLLECTION) + private val ENTRIES_OF_SNAPSHOT_EVENTS: Set = + ImmutableSet.of(TRUE, FIRST, FIRST_IN_DATA_COLLECTION, LAST_IN_DATA_COLLECTION) private val STRING_TO_ENUM: MutableMap = HashMap(12) init { diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/jdbc/AbstractJdbcSource.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/jdbc/AbstractJdbcSource.kt index 9ed5df111d046..b595adbe2f19d 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/jdbc/AbstractJdbcSource.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/jdbc/AbstractJdbcSource.kt @@ -9,10 +9,29 @@ import com.google.common.collect.ImmutableList import com.google.common.collect.ImmutableMap import com.google.common.collect.Sets import datadog.trace.api.Trace +import edu.umd.cs.findbugs.annotations.SuppressFBWarnings import io.airbyte.cdk.db.JdbcCompatibleSourceOperations import io.airbyte.cdk.db.SqlDatabase import io.airbyte.cdk.db.factory.DataSourceFactory.close import io.airbyte.cdk.db.factory.DataSourceFactory.create +import io.airbyte.cdk.db.jdbc.JdbcConstants.INTERNAL_COLUMN_NAME +import io.airbyte.cdk.db.jdbc.JdbcConstants.INTERNAL_COLUMN_SIZE +import io.airbyte.cdk.db.jdbc.JdbcConstants.INTERNAL_COLUMN_TYPE +import io.airbyte.cdk.db.jdbc.JdbcConstants.INTERNAL_COLUMN_TYPE_NAME +import io.airbyte.cdk.db.jdbc.JdbcConstants.INTERNAL_DECIMAL_DIGITS +import io.airbyte.cdk.db.jdbc.JdbcConstants.INTERNAL_IS_NULLABLE +import io.airbyte.cdk.db.jdbc.JdbcConstants.INTERNAL_SCHEMA_NAME +import io.airbyte.cdk.db.jdbc.JdbcConstants.INTERNAL_TABLE_NAME +import io.airbyte.cdk.db.jdbc.JdbcConstants.JDBC_COLUMN_COLUMN_NAME +import io.airbyte.cdk.db.jdbc.JdbcConstants.JDBC_COLUMN_DATABASE_NAME +import io.airbyte.cdk.db.jdbc.JdbcConstants.JDBC_COLUMN_DATA_TYPE +import io.airbyte.cdk.db.jdbc.JdbcConstants.JDBC_COLUMN_SCHEMA_NAME +import io.airbyte.cdk.db.jdbc.JdbcConstants.JDBC_COLUMN_SIZE +import io.airbyte.cdk.db.jdbc.JdbcConstants.JDBC_COLUMN_TABLE_NAME +import io.airbyte.cdk.db.jdbc.JdbcConstants.JDBC_COLUMN_TYPE_NAME +import io.airbyte.cdk.db.jdbc.JdbcConstants.JDBC_DECIMAL_DIGITS +import io.airbyte.cdk.db.jdbc.JdbcConstants.JDBC_IS_NULLABLE +import io.airbyte.cdk.db.jdbc.JdbcConstants.KEY_SEQ import io.airbyte.cdk.db.jdbc.JdbcDatabase import io.airbyte.cdk.db.jdbc.JdbcUtils import io.airbyte.cdk.db.jdbc.JdbcUtils.getFullyQualifiedTableName @@ -37,9 +56,6 @@ import io.airbyte.protocol.models.v0.AirbyteStreamNameNamespacePair import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog import io.airbyte.protocol.models.v0.ConfiguredAirbyteStream import io.airbyte.protocol.models.v0.SyncMode -import org.apache.commons.lang3.tuple.ImmutablePair -import org.slf4j.Logger -import org.slf4j.LoggerFactory import java.sql.Connection import java.sql.PreparedStatement import java.sql.ResultSet @@ -51,16 +67,24 @@ import java.util.function.Predicate import java.util.function.Supplier import java.util.stream.Collectors import javax.sql.DataSource +import org.apache.commons.lang3.tuple.ImmutablePair +import org.slf4j.Logger +import org.slf4j.LoggerFactory /** * This class contains helper functions and boilerplate for implementing a source connector for a * relational DB source which can be accessed via JDBC driver. If you are implementing a connector * for a relational DB which has a JDBC driver, make an effort to use this class. */ -abstract class AbstractJdbcSource(driverClass: String?, - protected val streamingQueryConfigProvider: Supplier, - sourceOperations: JdbcCompatibleSourceOperations) : AbstractDbSource(driverClass), Source { - protected val sourceOperations: JdbcCompatibleSourceOperations +// This is onoly here because spotbugs complains about aggregatePrimateKeys and I wasn't able to +// figure out what it's complaining about +@SuppressFBWarnings("NP_PARAMETER_MUST_BE_NONNULL_BUT_MARKED_AS_NULLABLE") +abstract class AbstractJdbcSource( + driverClass: String, + protected val streamingQueryConfigProvider: Supplier, + sourceOperations: JdbcCompatibleSourceOperations +) : AbstractDbSource(driverClass), Source { + protected val sourceOperations: JdbcCompatibleSourceOperations override var quoteString: String? = null protected var dataSources: MutableCollection = ArrayList() @@ -69,28 +93,56 @@ abstract class AbstractJdbcSource(driverClass: String?, this.sourceOperations = sourceOperations } - override fun queryTableFullRefresh(database: JdbcDatabase, - columnNames: List, - schemaName: String?, - tableName: String, - syncMode: SyncMode, - cursorField: Optional): AutoCloseableIterator? { + override fun queryTableFullRefresh( + database: JdbcDatabase, + columnNames: List, + schemaName: String?, + tableName: String, + syncMode: SyncMode, + cursorField: Optional + ): AutoCloseableIterator? { LOGGER.info("Queueing query for table: {}", tableName) - // This corresponds to the initial sync for in INCREMENTAL_MODE, where the ordering of the records + val quoteString = this.quoteString!! + // This corresponds to the initial sync for in INCREMENTAL_MODE, where the ordering of the + // records // matters // as intermediate state messages are emitted (if the connector emits intermediate state). if (syncMode == SyncMode.INCREMENTAL && stateEmissionFrequency > 0) { - val quotedCursorField = RelationalDbQueryUtils.enquoteIdentifier(cursorField.get(), quoteString) - return RelationalDbQueryUtils.queryTable(database, String.format("SELECT %s FROM %s ORDER BY %s ASC", + val quotedCursorField = + RelationalDbQueryUtils.enquoteIdentifier(cursorField.get(), quoteString) + return RelationalDbQueryUtils.queryTable( + database, + String.format( + "SELECT %s FROM %s ORDER BY %s ASC", RelationalDbQueryUtils.enquoteIdentifierList(columnNames, quoteString), - RelationalDbQueryUtils.getFullyQualifiedTableNameWithQuoting(schemaName, tableName, quoteString), quotedCursorField), - tableName, schemaName) + RelationalDbQueryUtils.getFullyQualifiedTableNameWithQuoting( + schemaName, + tableName, + quoteString + ), + quotedCursorField + ), + tableName, + schemaName + ) } else { - // If we are in FULL_REFRESH mode, state messages are never emitted, so we don't care about ordering + // If we are in FULL_REFRESH mode, state messages are never emitted, so we don't care + // about ordering // of the records. - return RelationalDbQueryUtils.queryTable(database, String.format("SELECT %s FROM %s", + return RelationalDbQueryUtils.queryTable( + database, + String.format( + "SELECT %s FROM %s", RelationalDbQueryUtils.enquoteIdentifierList(columnNames, quoteString), - RelationalDbQueryUtils.getFullyQualifiedTableNameWithQuoting(schemaName, tableName, quoteString)), tableName, schemaName) + RelationalDbQueryUtils.getFullyQualifiedTableNameWithQuoting( + schemaName, + tableName, + quoteString + ) + ), + tableName, + schemaName + ) } } @@ -101,73 +153,127 @@ abstract class AbstractJdbcSource(driverClass: String?, */ @Trace(operationName = AbstractDbSource.Companion.CHECK_TRACE_OPERATION_NAME) @Throws(Exception::class) - override fun getCheckOperations(config: JsonNode?): List> { - return ImmutableList.of(CheckedConsumer { database: JdbcDatabase -> - LOGGER.info("Attempting to get metadata from the database to see if we can connect.") - database.bufferedResultSetQuery(CheckedFunction { connection: Connection -> connection.metaData.catalogs }, CheckedFunction { queryResult: ResultSet? -> sourceOperations.rowToJson(queryResult!!) }) - }) + override fun getCheckOperations( + config: JsonNode? + ): List> { + return ImmutableList.of( + CheckedConsumer { database: JdbcDatabase -> + LOGGER.info( + "Attempting to get metadata from the database to see if we can connect." + ) + database.bufferedResultSetQuery( + CheckedFunction { connection: Connection -> connection.metaData.catalogs }, + CheckedFunction { queryResult: ResultSet? -> + sourceOperations.rowToJson(queryResult!!) + } + ) + } + ) } private fun getCatalog(database: SqlDatabase): String? { - return (if (database.sourceConfig!!.has(JdbcUtils.DATABASE_KEY)) database.sourceConfig!![JdbcUtils.DATABASE_KEY].asText() else null) + return (if (database.sourceConfig!!.has(JdbcUtils.DATABASE_KEY)) + database.sourceConfig!![JdbcUtils.DATABASE_KEY].asText() + else null) } @Throws(Exception::class) - override fun discoverInternal(database: JdbcDatabase, schema: String?): List>> { + override fun discoverInternal( + database: JdbcDatabase, + schema: String? + ): List>> { val internalSchemas: Set = HashSet(excludedInternalNameSpaces) LOGGER.info("Internal schemas to exclude: {}", internalSchemas) - val tablesWithSelectGrantPrivilege = getPrivilegesTableForCurrentUser(database, schema) - return database.bufferedResultSetQuery( // retrieve column metadata from the database - { connection: Connection -> connection.metaData.getColumns(getCatalog(database), schema, null, null) }, // store essential column metadata to a Json object from the result set about each column - { resultSet: ResultSet -> this.getColumnMetadata(resultSet) }) - .stream() - .filter(excludeNotAccessibleTables(internalSchemas, tablesWithSelectGrantPrivilege)) // group by schema and table name to handle the case where a table with the same name exists in - // multiple schemas. - .collect(Collectors.groupingBy>(Function> { t: JsonNode -> ImmutablePair.of(t.get(INTERNAL_SCHEMA_NAME).asText(), t.get(INTERNAL_TABLE_NAME).asText()) })) - .values - .stream() - .map>> { fields: List -> - TableInfo.builder>() - .nameSpace(fields[0].get(INTERNAL_SCHEMA_NAME).asText()) - .name(fields[0].get(INTERNAL_TABLE_NAME).asText()) - .fields(fields.stream() // read the column metadata Json object, and determine its type - .map { f: JsonNode -> - val datatype = sourceOperations.getDatabaseFieldType(f) - val jsonType = getAirbyteType(datatype) - LOGGER.debug("Table {} column {} (type {}[{}], nullable {}) -> {}", - fields[0].get(INTERNAL_TABLE_NAME).asText(), - f.get(INTERNAL_COLUMN_NAME).asText(), - f.get(INTERNAL_COLUMN_TYPE_NAME).asText(), - f.get(INTERNAL_COLUMN_SIZE).asInt(), - f.get(INTERNAL_IS_NULLABLE).asBoolean(), - jsonType) - object : CommonField(f.get(INTERNAL_COLUMN_NAME).asText(), datatype) {} - } - .collect(Collectors.toList>())) - .cursorFields(extractCursorFields(fields)) - .build() - } - .collect(Collectors.toList>>()) + val tablesWithSelectGrantPrivilege = + getPrivilegesTableForCurrentUser(database, schema) + return database + .bufferedResultSetQuery( // retrieve column metadata from the database + { connection: Connection -> + connection.metaData.getColumns(getCatalog(database), schema, null, null) + }, // store essential column metadata to a Json object from the result set about + // each column + { resultSet: ResultSet -> this.getColumnMetadata(resultSet) } + ) + .stream() + .filter( + excludeNotAccessibleTables(internalSchemas, tablesWithSelectGrantPrivilege) + ) // group by schema and table name to handle the case where a table with the same name + // exists in + // multiple schemas. + .collect( + Collectors.groupingBy>( + Function> { t: JsonNode -> + ImmutablePair.of( + t.get(INTERNAL_SCHEMA_NAME).asText(), + t.get(INTERNAL_TABLE_NAME).asText() + ) + } + ) + ) + .values + .stream() + .map>> { fields: List -> + TableInfo>( + nameSpace = fields[0].get(INTERNAL_SCHEMA_NAME).asText(), + name = fields[0].get(INTERNAL_TABLE_NAME).asText(), + fields = + fields + .stream() // read the column metadata Json object, and determine its + // type + .map { f: JsonNode -> + val datatype = sourceOperations.getDatabaseFieldType(f) + val jsonType = getAirbyteType(datatype) + LOGGER.debug( + "Table {} column {} (type {}[{}], nullable {}) -> {}", + fields[0].get(INTERNAL_TABLE_NAME).asText(), + f.get(INTERNAL_COLUMN_NAME).asText(), + f.get(INTERNAL_COLUMN_TYPE_NAME).asText(), + f.get(INTERNAL_COLUMN_SIZE).asInt(), + f.get(INTERNAL_IS_NULLABLE).asBoolean(), + jsonType + ) + object : + CommonField( + f.get(INTERNAL_COLUMN_NAME).asText(), + datatype + ) {} + } + .collect(Collectors.toList>()), + cursorFields = extractCursorFields(fields) + ) + } + .collect(Collectors.toList>>()) } private fun extractCursorFields(fields: List): List { - return fields.stream() - .filter { field: JsonNode? -> isCursorType(sourceOperations.getDatabaseFieldType(field!!)) } - .map(Function { field: JsonNode -> field.get(INTERNAL_COLUMN_NAME).asText() }) - .collect(Collectors.toList()) + return fields + .stream() + .filter { field: JsonNode -> + isCursorType(sourceOperations.getDatabaseFieldType(field)) + } + .map( + Function { field: JsonNode -> + field.get(INTERNAL_COLUMN_NAME).asText() + } + ) + .collect(Collectors.toList()) } - protected fun excludeNotAccessibleTables(internalSchemas: Set, - tablesWithSelectGrantPrivilege: Set?): Predicate { + protected fun excludeNotAccessibleTables( + internalSchemas: Set, + tablesWithSelectGrantPrivilege: Set? + ): Predicate { return Predicate { jsonNode: JsonNode -> if (tablesWithSelectGrantPrivilege!!.isEmpty()) { return@Predicate isNotInternalSchema(jsonNode, internalSchemas) } - (tablesWithSelectGrantPrivilege.stream() - .anyMatch { e: JdbcPrivilegeDto? -> e.getSchemaName() == jsonNode.get(INTERNAL_SCHEMA_NAME).asText() } - && tablesWithSelectGrantPrivilege.stream() - .anyMatch { e: JdbcPrivilegeDto? -> e.getTableName() == jsonNode.get(INTERNAL_TABLE_NAME).asText() } - && !internalSchemas.contains(jsonNode.get(INTERNAL_SCHEMA_NAME).asText())) + (tablesWithSelectGrantPrivilege.stream().anyMatch { e: JdbcPrivilegeDto -> + e.schemaName == jsonNode.get(INTERNAL_SCHEMA_NAME).asText() + } && + tablesWithSelectGrantPrivilege.stream().anyMatch { e: JdbcPrivilegeDto -> + e.tableName == jsonNode.get(INTERNAL_TABLE_NAME).asText() + } && + !internalSchemas.contains(jsonNode.get(INTERNAL_SCHEMA_NAME).asText())) } } @@ -179,14 +285,21 @@ abstract class AbstractJdbcSource(driverClass: String?, /** * @param resultSet Description of a column available in the table catalog. - * @return Essential information about a column to determine which table it belongs to and its type. + * @return Essential information about a column to determine which table it belongs to and its + * type. */ @Throws(SQLException::class) private fun getColumnMetadata(resultSet: ResultSet): JsonNode { - val fieldMap = ImmutableMap.builder() // we always want a namespace, if we cannot get a schema, use db name. - .put(INTERNAL_SCHEMA_NAME, - if (resultSet.getObject(JDBC_COLUMN_SCHEMA_NAME) != null) resultSet.getString(JDBC_COLUMN_SCHEMA_NAME) - else resultSet.getObject(JDBC_COLUMN_DATABASE_NAME)) + val fieldMap = + ImmutableMap.builder< + String, Any + >() // we always want a namespace, if we cannot get a schema, use db name. + .put( + INTERNAL_SCHEMA_NAME, + if (resultSet.getObject(JDBC_COLUMN_SCHEMA_NAME) != null) + resultSet.getString(JDBC_COLUMN_SCHEMA_NAME) + else resultSet.getObject(JDBC_COLUMN_DATABASE_NAME) + ) .put(INTERNAL_TABLE_NAME, resultSet.getString(JDBC_COLUMN_TABLE_NAME)) .put(INTERNAL_COLUMN_NAME, resultSet.getString(JDBC_COLUMN_COLUMN_NAME)) .put(INTERNAL_COLUMN_TYPE, resultSet.getString(JDBC_COLUMN_DATA_TYPE)) @@ -200,7 +313,9 @@ abstract class AbstractJdbcSource(driverClass: String?, } @Throws(Exception::class) - public override fun discoverInternal(database: JdbcDatabase): List>> { + public override fun discoverInternal( + database: JdbcDatabase + ): List>> { return discoverInternal(database, null) } @@ -210,144 +325,269 @@ abstract class AbstractJdbcSource(driverClass: String?, @VisibleForTesting @JvmRecord - data class PrimaryKeyAttributesFromDb(val streamName: String, - val primaryKey: String, - val keySequence: Int) - - override fun discoverPrimaryKeys(database: JdbcDatabase, - tableInfos: List>>): Map> { - LOGGER.info("Discover primary keys for tables: " + tableInfos.stream().map { obj: TableInfo> -> obj.name }.collect( - Collectors.toSet())) + data class PrimaryKeyAttributesFromDb( + val streamName: String, + val primaryKey: String, + val keySequence: Int + ) + + override fun discoverPrimaryKeys( + database: JdbcDatabase, + tableInfos: List>> + ): Map> { + LOGGER.info( + "Discover primary keys for tables: " + + tableInfos + .stream() + .map { obj: TableInfo> -> obj.name } + .collect(Collectors.toSet()) + ) try { // Get all primary keys without specifying a table name - val tablePrimaryKeys = aggregatePrimateKeys(database.bufferedResultSetQuery( - { connection: Connection -> connection.metaData.getPrimaryKeys(getCatalog(database), null, null) }, - { r: ResultSet -> - val schemaName: String = - if (r.getObject(JDBC_COLUMN_SCHEMA_NAME) != null) r.getString(JDBC_COLUMN_SCHEMA_NAME) else r.getString(JDBC_COLUMN_DATABASE_NAME) - val streamName = getFullyQualifiedTableName(schemaName, r.getString(JDBC_COLUMN_TABLE_NAME)) - val primaryKey: String = r.getString(JDBC_COLUMN_COLUMN_NAME) - val keySeq: Int = r.getInt(KEY_SEQ) - PrimaryKeyAttributesFromDb(streamName, primaryKey, keySeq) - })) + val tablePrimaryKeys = + aggregatePrimateKeys( + database.bufferedResultSetQuery( + { connection: Connection -> + connection.metaData.getPrimaryKeys(getCatalog(database), null, null) + }, + { r: ResultSet -> + val schemaName: String = + if (r.getObject(JDBC_COLUMN_SCHEMA_NAME) != null) + r.getString(JDBC_COLUMN_SCHEMA_NAME) + else r.getString(JDBC_COLUMN_DATABASE_NAME) + val streamName = + getFullyQualifiedTableName( + schemaName, + r.getString(JDBC_COLUMN_TABLE_NAME) + ) + val primaryKey: String = r.getString(JDBC_COLUMN_COLUMN_NAME) + val keySeq: Int = r.getInt(KEY_SEQ) + PrimaryKeyAttributesFromDb(streamName, primaryKey, keySeq) + } + ) + ) if (!tablePrimaryKeys.isEmpty()) { return tablePrimaryKeys } } catch (e: SQLException) { - LOGGER.debug(String.format("Could not retrieve primary keys without a table name (%s), retrying", e)) + LOGGER.debug( + String.format( + "Could not retrieve primary keys without a table name (%s), retrying", + e + ) + ) } // Get primary keys one table at a time - return tableInfos.stream() - .collect(Collectors.toMap>, String, MutableList>( - Function>, String> { tableInfo: TableInfo> -> getFullyQualifiedTableName(tableInfo.nameSpace, tableInfo.name) }, - Function>, MutableList> { tableInfo: TableInfo> -> - val streamName = getFullyQualifiedTableName(tableInfo.nameSpace, tableInfo.name) - try { - val primaryKeys = aggregatePrimateKeys(database.bufferedResultSetQuery( - { connection: Connection -> connection.metaData.getPrimaryKeys(getCatalog(database), tableInfo.nameSpace, tableInfo.name) }, - { r: ResultSet -> PrimaryKeyAttributesFromDb(streamName, r.getString(JDBC_COLUMN_COLUMN_NAME), r.getInt(KEY_SEQ)) })) - return@toMap primaryKeys.getOrDefault(streamName, emptyList()) - } catch (e: SQLException) { - LOGGER.error(String.format("Could not retrieve primary keys for %s: %s", streamName, e)) - return@toMap emptyList() - } - })) + return tableInfos + .stream() + .collect( + Collectors.toMap>, String, MutableList>( + Function>, String> { + tableInfo: TableInfo> -> + getFullyQualifiedTableName(tableInfo.nameSpace, tableInfo.name) + }, + Function>, MutableList> toMap@{ + tableInfo: TableInfo> -> + val streamName = + getFullyQualifiedTableName(tableInfo.nameSpace, tableInfo.name) + try { + val primaryKeys = + aggregatePrimateKeys( + database.bufferedResultSetQuery( + { connection: Connection -> + connection.metaData.getPrimaryKeys( + getCatalog(database), + tableInfo.nameSpace, + tableInfo.name + ) + }, + { r: ResultSet -> + PrimaryKeyAttributesFromDb( + streamName, + r.getString(JDBC_COLUMN_COLUMN_NAME), + r.getInt(KEY_SEQ) + ) + } + ) + ) + return@toMap primaryKeys.getOrDefault( + streamName, + mutableListOf() + ) + } catch (e: SQLException) { + LOGGER.error( + String.format( + "Could not retrieve primary keys for %s: %s", + streamName, + e + ) + ) + return@toMap mutableListOf() + } + } + ) + ) } public override fun isCursorType(type: Datatype): Boolean { return sourceOperations.isCursorType(type) } - public override fun queryTableIncremental(database: JdbcDatabase, - columnNames: List, - schemaName: String?, - tableName: String, - cursorInfo: CursorInfo, - cursorFieldType: Datatype): AutoCloseableIterator? { + public override fun queryTableIncremental( + database: JdbcDatabase, + columnNames: List, + schemaName: String?, + tableName: String, + cursorInfo: CursorInfo, + cursorFieldType: Datatype + ): AutoCloseableIterator? { LOGGER.info("Queueing query for table: {}", tableName) - val airbyteStream = - AirbyteStreamUtils.convertFromNameAndNamespace(tableName, schemaName) - return AutoCloseableIterators.lazyIterator({ - try { - val stream = database.unsafeQuery( - CheckedFunction { connection: Connection -> - LOGGER.info("Preparing query for table: {}", tableName) - val fullTableName = RelationalDbQueryUtils.getFullyQualifiedTableNameWithQuoting(schemaName, tableName, quoteString) - val quotedCursorField = RelationalDbQueryUtils.enquoteIdentifier(cursorInfo.cursorField, quoteString) - - val operator: String - if (cursorInfo.cursorRecordCount <= 0L) { - operator = ">" - } else { - val actualRecordCount = getActualCursorRecordCount( - connection, fullTableName, quotedCursorField, cursorFieldType, cursorInfo.cursor) - LOGGER.info("Table {} cursor count: expected {}, actual {}", tableName, cursorInfo.cursorRecordCount, actualRecordCount) - operator = if (actualRecordCount == cursorInfo.cursorRecordCount) { - ">" + val airbyteStream = AirbyteStreamUtils.convertFromNameAndNamespace(tableName, schemaName) + return AutoCloseableIterators.lazyIterator( + { + val quoteString = this.quoteString!! + try { + val stream = + database.unsafeQuery( + CheckedFunction { + connection: Connection -> + LOGGER.info("Preparing query for table: {}", tableName) + val fullTableName = + RelationalDbQueryUtils.getFullyQualifiedTableNameWithQuoting( + schemaName, + tableName, + quoteString + ) + val quotedCursorField = + RelationalDbQueryUtils.enquoteIdentifier( + cursorInfo.cursorField, + quoteString + ) + + val operator: String + if (cursorInfo.cursorRecordCount <= 0L) { + operator = ">" } else { - ">=" + val actualRecordCount = + getActualCursorRecordCount( + connection, + fullTableName, + quotedCursorField, + cursorFieldType, + cursorInfo.cursor + ) + LOGGER.info( + "Table {} cursor count: expected {}, actual {}", + tableName, + cursorInfo.cursorRecordCount, + actualRecordCount + ) + operator = + if (actualRecordCount == cursorInfo.cursorRecordCount) { + ">" + } else { + ">=" + } } - } - val wrappedColumnNames = getWrappedColumnNames(database, connection, columnNames, schemaName, tableName) - val sql = StringBuilder(String.format("SELECT %s FROM %s WHERE %s %s ?", - wrappedColumnNames, - fullTableName, - quotedCursorField, - operator)) - // if the connector emits intermediate states, the incremental query must be sorted by the cursor - // field - if (stateEmissionFrequency > 0) { - sql.append(String.format(" ORDER BY %s ASC", quotedCursorField)) - } + val wrappedColumnNames = + getWrappedColumnNames( + database, + connection, + columnNames, + schemaName, + tableName + ) + val sql = + StringBuilder( + String.format( + "SELECT %s FROM %s WHERE %s %s ?", + wrappedColumnNames, + fullTableName, + quotedCursorField, + operator + ) + ) + // if the connector emits intermediate states, the incremental query + // must be sorted by the cursor + // field + if (stateEmissionFrequency > 0) { + sql.append(String.format(" ORDER BY %s ASC", quotedCursorField)) + } - val preparedStatement = connection.prepareStatement(sql.toString()) - LOGGER.info("Executing query for table {}: {}", tableName, preparedStatement) - sourceOperations.setCursorField(preparedStatement, 1, cursorFieldType, cursorInfo.cursor) - preparedStatement - }, - CheckedFunction { queryResult: ResultSet? -> sourceOperations.rowToJson(queryResult!!) }) - return@lazyIterator AutoCloseableIterators.fromStream(stream, airbyteStream) - } catch (e: SQLException) { - throw RuntimeException(e) - } - }, airbyteStream) + val preparedStatement = connection.prepareStatement(sql.toString()) + LOGGER.info( + "Executing query for table {}: {}", + tableName, + preparedStatement + ) + sourceOperations.setCursorField( + preparedStatement, + 1, + cursorFieldType, + cursorInfo.cursor + ) + preparedStatement + }, + CheckedFunction { + queryResult: ResultSet? -> + sourceOperations.rowToJson(queryResult!!) + } + ) + return@lazyIterator AutoCloseableIterators.fromStream( + stream, + airbyteStream + ) + } catch (e: SQLException) { + throw RuntimeException(e) + } + }, + airbyteStream + ) } - /** - * Some databases need special column names in the query. - */ + /** Some databases need special column names in the query. */ @Throws(SQLException::class) - protected fun getWrappedColumnNames(database: JdbcDatabase?, - connection: Connection?, - columnNames: List, - schemaName: String?, - tableName: String?): String? { - return RelationalDbQueryUtils.enquoteIdentifierList(columnNames, quoteString) + protected fun getWrappedColumnNames( + database: JdbcDatabase?, + connection: Connection?, + columnNames: List, + schemaName: String?, + tableName: String? + ): String? { + return RelationalDbQueryUtils.enquoteIdentifierList(columnNames, quoteString!!) } protected val countColumnName: String get() = "record_count" @Throws(SQLException::class) - protected fun getActualCursorRecordCount(connection: Connection, - fullTableName: String?, - quotedCursorField: String?, - cursorFieldType: Datatype, - cursor: String?): Long { + protected fun getActualCursorRecordCount( + connection: Connection, + fullTableName: String?, + quotedCursorField: String?, + cursorFieldType: Datatype, + cursor: String? + ): Long { val columnName = countColumnName val cursorRecordStatement: PreparedStatement if (cursor == null) { - val cursorRecordQuery = String.format("SELECT COUNT(*) AS %s FROM %s WHERE %s IS NULL", + val cursorRecordQuery = + String.format( + "SELECT COUNT(*) AS %s FROM %s WHERE %s IS NULL", columnName, fullTableName, - quotedCursorField) + quotedCursorField + ) cursorRecordStatement = connection.prepareStatement(cursorRecordQuery) } else { - val cursorRecordQuery = String.format("SELECT COUNT(*) AS %s FROM %s WHERE %s = ?", + val cursorRecordQuery = + String.format( + "SELECT COUNT(*) AS %s FROM %s WHERE %s = ?", columnName, fullTableName, - quotedCursorField) + quotedCursorField + ) cursorRecordStatement = connection.prepareStatement(cursorRecordQuery) sourceOperations.setCursorField(cursorRecordStatement, 1, cursorFieldType, cursor) @@ -361,31 +601,37 @@ abstract class AbstractJdbcSource(driverClass: String?, } @Throws(SQLException::class) - public override fun createDatabase(sourceConfig: JsonNode?): JdbcDatabase { + public override fun createDatabase(sourceConfig: JsonNode): JdbcDatabase { return createDatabase(sourceConfig, JdbcDataSourceUtils.DEFAULT_JDBC_PARAMETERS_DELIMITER) } @Throws(SQLException::class) - fun createDatabase(sourceConfig: JsonNode?, delimiter: String?): JdbcDatabase { - val jdbcConfig = toDatabaseConfig(sourceConfig!!) - val connectionProperties = JdbcDataSourceUtils.getConnectionProperties(sourceConfig, delimiter) + fun createDatabase(sourceConfig: JsonNode, delimiter: String): JdbcDatabase { + val jdbcConfig = toDatabaseConfig(sourceConfig) + val connectionProperties = + JdbcDataSourceUtils.getConnectionProperties(sourceConfig, delimiter) // Create the data source - val dataSource = create( - if (jdbcConfig!!.has(JdbcUtils.USERNAME_KEY)) jdbcConfig[JdbcUtils.USERNAME_KEY].asText() else null, - if (jdbcConfig.has(JdbcUtils.PASSWORD_KEY)) jdbcConfig[JdbcUtils.PASSWORD_KEY].asText() else null, + val dataSource = + create( + if (jdbcConfig!!.has(JdbcUtils.USERNAME_KEY)) + jdbcConfig[JdbcUtils.USERNAME_KEY].asText() + else null, + if (jdbcConfig.has(JdbcUtils.PASSWORD_KEY)) + jdbcConfig[JdbcUtils.PASSWORD_KEY].asText() + else null, driverClassName, jdbcConfig[JdbcUtils.JDBC_URL_KEY].asText(), connectionProperties, - getConnectionTimeout(connectionProperties!!)) + getConnectionTimeout(connectionProperties!!) + ) // Record the data source so that it can be closed. dataSources.add(dataSource) - val database: JdbcDatabase = StreamingJdbcDatabase( - dataSource, - sourceOperations, - streamingQueryConfigProvider) + val database: JdbcDatabase = + StreamingJdbcDatabase(dataSource, sourceOperations, streamingQueryConfigProvider) - quoteString = (if (quoteString == null) database.metaData.identifierQuoteString else quoteString) + quoteString = + (if (quoteString == null) database.metaData.identifierQuoteString else quoteString) database.sourceConfig = sourceConfig database.databaseConfig = jdbcConfig return database @@ -400,38 +646,53 @@ abstract class AbstractJdbcSource(driverClass: String?, */ @Throws(SQLException::class) override fun logPreSyncDebugData(database: JdbcDatabase, catalog: ConfiguredAirbyteCatalog?) { - LOGGER.info("Data source product recognized as {}:{}", - database.metaData.databaseProductName, - database.metaData.databaseProductVersion) + LOGGER.info( + "Data source product recognized as {}:{}", + database.metaData.databaseProductName, + database.metaData.databaseProductVersion + ) } override fun close() { - dataSources.forEach(Consumer { d: DataSource? -> - try { - close(d) - } catch (e: Exception) { - LOGGER.warn("Unable to close data source.", e) + dataSources.forEach( + Consumer { d: DataSource? -> + try { + close(d) + } catch (e: Exception) { + LOGGER.warn("Unable to close data source.", e) + } } - }) + ) dataSources.clear() } - protected fun identifyStreamsToSnapshot(catalog: ConfiguredAirbyteCatalog, stateManager: StateManager): List { + protected fun identifyStreamsToSnapshot( + catalog: ConfiguredAirbyteCatalog, + stateManager: StateManager + ): List { val alreadySyncedStreams = stateManager.cdcStateManager.initialStreamsSynced - if (alreadySyncedStreams!!.isEmpty() && (stateManager.cdcStateManager.cdcState == null - || stateManager.cdcStateManager.cdcState.state == null)) { + if ( + alreadySyncedStreams!!.isEmpty() && + (stateManager.cdcStateManager.cdcState?.state == null) + ) { return emptyList() } val allStreams = AirbyteStreamNameNamespacePair.fromConfiguredCatalog(catalog) - val newlyAddedStreams: Set = HashSet(Sets.difference(allStreams, alreadySyncedStreams)) + val newlyAddedStreams: Set = + HashSet(Sets.difference(allStreams, alreadySyncedStreams)) - return catalog.streams.stream() - .filter { c: ConfiguredAirbyteStream -> c.syncMode == SyncMode.INCREMENTAL } - .filter { stream: ConfiguredAirbyteStream -> newlyAddedStreams.contains(AirbyteStreamNameNamespacePair.fromAirbyteStream(stream.stream)) } - .map { `object`: ConfiguredAirbyteStream? -> Jsons.clone(`object`) } - .collect(Collectors.toList()) + return catalog.streams + .stream() + .filter { c: ConfiguredAirbyteStream -> c.syncMode == SyncMode.INCREMENTAL } + .filter { stream: ConfiguredAirbyteStream -> + newlyAddedStreams.contains( + AirbyteStreamNameNamespacePair.fromAirbyteStream(stream.stream) + ) + } + .map { `object`: ConfiguredAirbyteStream -> Jsons.clone(`object`) } + .collect(Collectors.toList()) } companion object { @@ -443,14 +704,19 @@ abstract class AbstractJdbcSource(driverClass: String?, * @return a map by StreamName to associated list of primary keys */ @VisibleForTesting - fun aggregatePrimateKeys(entries: List): Map> { + fun aggregatePrimateKeys( + entries: List + ): Map> { val result: MutableMap> = HashMap() - entries.stream().sorted(Comparator.comparingInt(PrimaryKeyAttributesFromDb::keySequence)).forEach { entry: PrimaryKeyAttributesFromDb -> - if (!result.containsKey(entry.streamName)) { - result[entry.streamName] = ArrayList() + entries + .stream() + .sorted(Comparator.comparingInt(PrimaryKeyAttributesFromDb::keySequence)) + .forEach { entry: PrimaryKeyAttributesFromDb -> + if (!result.containsKey(entry.streamName)) { + result[entry.streamName] = ArrayList() + } + result[entry.streamName]!!.add(entry.primaryKey) } - result[entry.streamName]!!.add(entry.primaryKey) - } return result } } diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/jdbc/JdbcDataSourceUtils.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/jdbc/JdbcDataSourceUtils.kt index 23196cf154409..eda6d797635aa 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/jdbc/JdbcDataSourceUtils.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/jdbc/JdbcDataSourceUtils.kt @@ -18,10 +18,15 @@ object JdbcDataSourceUtils { * @param defaultParameters connection properties map as specified by each Jdbc source * @throws IllegalArgumentException */ - fun assertCustomParametersDontOverwriteDefaultParameters(customParameters: Map, - defaultParameters: Map) { + fun assertCustomParametersDontOverwriteDefaultParameters( + customParameters: Map, + defaultParameters: Map + ) { for (key in defaultParameters.keys) { - require(!(customParameters.containsKey(key) && customParameters[key] != defaultParameters[key])) { "Cannot overwrite default JDBC parameter $key" } + require( + !(customParameters.containsKey(key) && + customParameters[key] != defaultParameters[key]) + ) { "Cannot overwrite default JDBC parameter $key" } } } @@ -32,12 +37,13 @@ object JdbcDataSourceUtils { * @param config A configuration used to check Jdbc connection * @return A mapping of connection properties */ - fun getConnectionProperties(config: JsonNode?): Map { + fun getConnectionProperties(config: JsonNode): Map { return getConnectionProperties(config, DEFAULT_JDBC_PARAMETERS_DELIMITER) } - fun getConnectionProperties(config: JsonNode?, parameterDelimiter: String?): Map { - val customProperties = parseJdbcParameters(config!!, JdbcUtils.JDBC_URL_PARAMS_KEY, parameterDelimiter!!) + fun getConnectionProperties(config: JsonNode, parameterDelimiter: String): Map { + val customProperties = + parseJdbcParameters(config, JdbcUtils.JDBC_URL_PARAMS_KEY, parameterDelimiter) val defaultProperties = getDefaultConnectionProperties(config) assertCustomParametersDontOverwriteDefaultParameters(customProperties, defaultProperties) return MoreMaps.merge(customProperties, defaultProperties) @@ -51,8 +57,12 @@ object JdbcDataSourceUtils { * @param config A configuration used to check Jdbc connection * @return A mapping of the default connection properties */ - fun getDefaultConnectionProperties(config: JsonNode?): Map { + fun getDefaultConnectionProperties(config: JsonNode): Map { // NOTE that Postgres returns an empty map for some reason? - return parseJdbcParameters(config!!, "connection_properties", DEFAULT_JDBC_PARAMETERS_DELIMITER) + return parseJdbcParameters( + config, + "connection_properties", + DEFAULT_JDBC_PARAMETERS_DELIMITER + ) } } diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/jdbc/JdbcSSLConnectionUtils.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/jdbc/JdbcSSLConnectionUtils.kt index e1e3910c8d4b2..8752397229912 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/jdbc/JdbcSSLConnectionUtils.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/jdbc/JdbcSSLConnectionUtils.kt @@ -7,11 +7,6 @@ import com.fasterxml.jackson.databind.JsonNode import io.airbyte.cdk.db.jdbc.JdbcUtils import io.airbyte.cdk.db.util.SSLCertificateUtils.keyStoreFromCertificate import io.airbyte.cdk.db.util.SSLCertificateUtils.keyStoreFromClientCertificate -import org.apache.commons.lang3.RandomStringUtils -import org.apache.commons.lang3.tuple.ImmutablePair -import org.apache.commons.lang3.tuple.Pair -import org.slf4j.Logger -import org.slf4j.LoggerFactory import java.io.IOException import java.net.MalformedURLException import java.net.URI @@ -22,6 +17,11 @@ import java.security.NoSuchAlgorithmException import java.security.cert.CertificateException import java.security.spec.InvalidKeySpecException import java.util.* +import org.apache.commons.lang3.RandomStringUtils +import org.apache.commons.lang3.tuple.ImmutablePair +import org.apache.commons.lang3.tuple.Pair +import org.slf4j.Logger +import org.slf4j.LoggerFactory class JdbcSSLConnectionUtils { var caCertKeyStorePair: Pair? = null @@ -40,8 +40,8 @@ class JdbcSSLConnectionUtils { companion object { fun bySpec(spec: String): Optional { return Arrays.stream(entries.toTypedArray()) - .filter { sslMode: SslMode -> sslMode.spec.contains(spec) } - .findFirst() + .filter { sslMode: SslMode -> sslMode.spec.contains(spec) } + .findFirst() } } } @@ -57,7 +57,8 @@ class JdbcSSLConnectionUtils { const val TRUST_KEY_STORE_TYPE: String = "trustCertificateKeyStoreType" const val KEY_STORE_TYPE_PKCS12: String = "PKCS12" const val PARAM_MODE: String = "mode" - private val LOGGER: Logger = LoggerFactory.getLogger(JdbcSSLConnectionUtils::class.java.javaClass) + private val LOGGER: Logger = + LoggerFactory.getLogger(JdbcSSLConnectionUtils::class.java.javaClass) const val PARAM_CA_CERTIFICATE: String = "ca_certificate" const val PARAM_CLIENT_CERTIFICATE: String = "client_certificate" const val PARAM_CLIENT_KEY: String = "client_key" @@ -67,7 +68,8 @@ class JdbcSSLConnectionUtils { * Parses SSL related configuration and generates keystores to be used by connector * * @param config configuration - * @return map containing relevant parsed values including location of keystore or an empty map + * @return map containing relevant parsed values including location of keystore or an empty + * map */ fun parseSSLConfig(config: JsonNode): Map { LOGGER.debug("source config: {}", config) @@ -79,18 +81,30 @@ class JdbcSSLConnectionUtils { if (!config.has(JdbcUtils.SSL_KEY) || config[JdbcUtils.SSL_KEY].asBoolean()) { if (config.has(JdbcUtils.SSL_MODE_KEY)) { val specMode = config[JdbcUtils.SSL_MODE_KEY][PARAM_MODE].asText() - additionalParameters[SSL_MODE] = SslMode.bySpec(specMode).orElseThrow { IllegalArgumentException("unexpected ssl mode") }.name + additionalParameters[SSL_MODE] = + SslMode.bySpec(specMode) + .orElseThrow { IllegalArgumentException("unexpected ssl mode") } + .name if (Objects.isNull(caCertKeyStorePair)) { caCertKeyStorePair = prepareCACertificateKeyStore(config) } if (Objects.nonNull(caCertKeyStorePair)) { - LOGGER.debug("uri for ca cert keystore: {}", caCertKeyStorePair!!.left.toString()) + LOGGER.debug( + "uri for ca cert keystore: {}", + caCertKeyStorePair!!.left.toString() + ) try { - additionalParameters.putAll(java.util.Map.of( - TRUST_KEY_STORE_URL, caCertKeyStorePair.left.toURL().toString(), - TRUST_KEY_STORE_PASS, caCertKeyStorePair.right, - TRUST_KEY_STORE_TYPE, KEY_STORE_TYPE_PKCS12)) + additionalParameters.putAll( + java.util.Map.of( + TRUST_KEY_STORE_URL, + caCertKeyStorePair.left.toURL().toString(), + TRUST_KEY_STORE_PASS, + caCertKeyStorePair.right, + TRUST_KEY_STORE_TYPE, + KEY_STORE_TYPE_PKCS12 + ) + ) } catch (e: MalformedURLException) { throw RuntimeException("Unable to get a URL for trust key store") } @@ -101,12 +115,22 @@ class JdbcSSLConnectionUtils { } if (Objects.nonNull(clientCertKeyStorePair)) { - LOGGER.debug("uri for client cert keystore: {} / {}", clientCertKeyStorePair!!.left.toString(), clientCertKeyStorePair.right) + LOGGER.debug( + "uri for client cert keystore: {} / {}", + clientCertKeyStorePair!!.left.toString(), + clientCertKeyStorePair.right + ) try { - additionalParameters.putAll(java.util.Map.of( - CLIENT_KEY_STORE_URL, clientCertKeyStorePair.left.toURL().toString(), - CLIENT_KEY_STORE_PASS, clientCertKeyStorePair.right, - CLIENT_KEY_STORE_TYPE, KEY_STORE_TYPE_PKCS12)) + additionalParameters.putAll( + java.util.Map.of( + CLIENT_KEY_STORE_URL, + clientCertKeyStorePair.left.toURL().toString(), + CLIENT_KEY_STORE_PASS, + clientCertKeyStorePair.right, + CLIENT_KEY_STORE_TYPE, + KEY_STORE_TYPE_PKCS12 + ) + ) } catch (e: MalformedURLException) { throw RuntimeException("Unable to get a URL for client key store") } @@ -129,23 +153,40 @@ class JdbcSSLConnectionUtils { if (Objects.nonNull(config)) { if (!config.has(JdbcUtils.SSL_KEY) || config[JdbcUtils.SSL_KEY].asBoolean()) { val encryption = config[JdbcUtils.SSL_MODE_KEY] - if (encryption.has(PARAM_CA_CERTIFICATE) && !encryption[PARAM_CA_CERTIFICATE].asText().isEmpty()) { + if ( + encryption.has(PARAM_CA_CERTIFICATE) && + !encryption[PARAM_CA_CERTIFICATE].asText().isEmpty() + ) { val clientKeyPassword = getOrGeneratePassword(encryption) try { - val caCertKeyStoreUri = keyStoreFromCertificate( + val caCertKeyStoreUri = + keyStoreFromCertificate( encryption[PARAM_CA_CERTIFICATE].asText(), clientKeyPassword, null, - null) + null + ) caCertKeyStorePair = ImmutablePair(caCertKeyStoreUri, clientKeyPassword) } catch (e: CertificateException) { - throw RuntimeException("Failed to create keystore for CA certificate", e) + throw RuntimeException( + "Failed to create keystore for CA certificate", + e + ) } catch (e: IOException) { - throw RuntimeException("Failed to create keystore for CA certificate", e) + throw RuntimeException( + "Failed to create keystore for CA certificate", + e + ) } catch (e: KeyStoreException) { - throw RuntimeException("Failed to create keystore for CA certificate", e) + throw RuntimeException( + "Failed to create keystore for CA certificate", + e + ) } catch (e: NoSuchAlgorithmException) { - throw RuntimeException("Failed to create keystore for CA certificate", e) + throw RuntimeException( + "Failed to create keystore for CA certificate", + e + ) } } } @@ -154,11 +195,15 @@ class JdbcSSLConnectionUtils { } private fun getOrGeneratePassword(sslModeConfig: JsonNode): String { - val clientKeyPassword = if (sslModeConfig.has(PARAM_CLIENT_KEY_PASSWORD) && !sslModeConfig[PARAM_CLIENT_KEY_PASSWORD].asText().isEmpty()) { - sslModeConfig[PARAM_CLIENT_KEY_PASSWORD].asText() - } else { - RandomStringUtils.randomAlphanumeric(10) - } + val clientKeyPassword = + if ( + sslModeConfig.has(PARAM_CLIENT_KEY_PASSWORD) && + !sslModeConfig[PARAM_CLIENT_KEY_PASSWORD].asText().isEmpty() + ) { + sslModeConfig[PARAM_CLIENT_KEY_PASSWORD].asText() + } else { + RandomStringUtils.randomAlphanumeric(10) + } return clientKeyPassword } @@ -167,26 +212,53 @@ class JdbcSSLConnectionUtils { if (Objects.nonNull(config)) { if (!config.has(JdbcUtils.SSL_KEY) || config[JdbcUtils.SSL_KEY].asBoolean()) { val encryption = config[JdbcUtils.SSL_MODE_KEY] - if (encryption.has(PARAM_CLIENT_CERTIFICATE) && !encryption[PARAM_CLIENT_CERTIFICATE].asText().isEmpty() - && encryption.has(PARAM_CLIENT_KEY) && !encryption[PARAM_CLIENT_KEY].asText().isEmpty()) { + if ( + encryption.has(PARAM_CLIENT_CERTIFICATE) && + !encryption[PARAM_CLIENT_CERTIFICATE].asText().isEmpty() && + encryption.has(PARAM_CLIENT_KEY) && + !encryption[PARAM_CLIENT_KEY].asText().isEmpty() + ) { val clientKeyPassword = getOrGeneratePassword(encryption) try { - val clientCertKeyStoreUri = keyStoreFromClientCertificate(encryption[PARAM_CLIENT_CERTIFICATE].asText(), + val clientCertKeyStoreUri = + keyStoreFromClientCertificate( + encryption[PARAM_CLIENT_CERTIFICATE].asText(), encryption[PARAM_CLIENT_KEY].asText(), - clientKeyPassword, null) - clientCertKeyStorePair = ImmutablePair(clientCertKeyStoreUri, clientKeyPassword) + clientKeyPassword, + null + ) + clientCertKeyStorePair = + ImmutablePair(clientCertKeyStoreUri, clientKeyPassword) } catch (e: CertificateException) { - throw RuntimeException("Failed to create keystore for Client certificate", e) + throw RuntimeException( + "Failed to create keystore for Client certificate", + e + ) } catch (e: IOException) { - throw RuntimeException("Failed to create keystore for Client certificate", e) + throw RuntimeException( + "Failed to create keystore for Client certificate", + e + ) } catch (e: KeyStoreException) { - throw RuntimeException("Failed to create keystore for Client certificate", e) + throw RuntimeException( + "Failed to create keystore for Client certificate", + e + ) } catch (e: NoSuchAlgorithmException) { - throw RuntimeException("Failed to create keystore for Client certificate", e) + throw RuntimeException( + "Failed to create keystore for Client certificate", + e + ) } catch (e: InvalidKeySpecException) { - throw RuntimeException("Failed to create keystore for Client certificate", e) + throw RuntimeException( + "Failed to create keystore for Client certificate", + e + ) } catch (e: InterruptedException) { - throw RuntimeException("Failed to create keystore for Client certificate", e) + throw RuntimeException( + "Failed to create keystore for Client certificate", + e + ) } } } diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/jdbc/JdbcSource.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/jdbc/JdbcSource.kt index b08c28c8e12cc..7e1f9b3125344 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/jdbc/JdbcSource.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/jdbc/JdbcSource.kt @@ -9,12 +9,18 @@ import io.airbyte.cdk.db.jdbc.JdbcUtils import io.airbyte.cdk.db.jdbc.streaming.AdaptiveStreamingQueryConfig import io.airbyte.cdk.integrations.base.IntegrationRunner import io.airbyte.cdk.integrations.base.Source -import org.slf4j.Logger -import org.slf4j.LoggerFactory import java.sql.JDBCType import java.util.function.Supplier +import org.slf4j.Logger +import org.slf4j.LoggerFactory -class JdbcSource : AbstractJdbcSource(DatabaseDriver.POSTGRESQL.driverClassName, Supplier { AdaptiveStreamingQueryConfig() }, JdbcUtils.defaultSourceOperations), Source { +class JdbcSource : + AbstractJdbcSource( + DatabaseDriver.POSTGRESQL.driverClassName, + Supplier { AdaptiveStreamingQueryConfig() }, + JdbcUtils.defaultSourceOperations + ), + Source { // no-op for JdbcSource since the config it receives is designed to be use for JDBC. override fun toDatabaseConfig(config: JsonNode): JsonNode { return config diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/jdbc/dto/JdbcPrivilegeDto.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/jdbc/dto/JdbcPrivilegeDto.kt index bb9e4397fb418..0e689e819b3ae 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/jdbc/dto/JdbcPrivilegeDto.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/jdbc/dto/JdbcPrivilegeDto.kt @@ -5,10 +5,13 @@ package io.airbyte.cdk.integrations.source.jdbc.dto import com.google.common.base.Objects -/** - * The class to store values from privileges table - */ -class JdbcPrivilegeDto(val grantee: String?, val tableName: String?, val schemaName: String?, val privilege: String?) { +/** The class to store values from privileges table */ +class JdbcPrivilegeDto( + val grantee: String?, + val tableName: String?, + val schemaName: String?, + val privilege: String? +) { class JdbcPrivilegeDtoBuilder { private var grantee: String? = null private var tableName: String? = null @@ -48,8 +51,10 @@ class JdbcPrivilegeDto(val grantee: String?, val tableName: String?, val schemaN return false } val that = o as JdbcPrivilegeDto - return (Objects.equal(grantee, that.grantee) && Objects.equal(tableName, that.tableName) - && Objects.equal(schemaName, that.schemaName) && Objects.equal(privilege, that.privilege)) + return (Objects.equal(grantee, that.grantee) && + Objects.equal(tableName, that.tableName) && + Objects.equal(schemaName, that.schemaName) && + Objects.equal(privilege, that.privilege)) } override fun hashCode(): Int { @@ -58,11 +63,19 @@ class JdbcPrivilegeDto(val grantee: String?, val tableName: String?, val schemaN override fun toString(): String { return "JdbcPrivilegeDto{" + - "grantee='" + grantee + '\'' + - ", columnName='" + tableName + '\'' + - ", schemaName='" + schemaName + '\'' + - ", privilege='" + privilege + '\'' + - '}' + "grantee='" + + grantee + + '\'' + + ", columnName='" + + tableName + + '\'' + + ", schemaName='" + + schemaName + + '\'' + + ", privilege='" + + privilege + + '\'' + + '}' } companion object { diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/AbstractDbSource.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/AbstractDbSource.kt index 7bb9dd92dc129..718b74359fb81 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/AbstractDbSource.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/AbstractDbSource.kt @@ -4,7 +4,6 @@ package io.airbyte.cdk.integrations.source.relationaldb import com.fasterxml.jackson.databind.JsonNode -import com.google.common.annotations.VisibleForTesting import com.google.common.base.Preconditions import datadog.trace.api.Trace import io.airbyte.cdk.db.AbstractDatabase @@ -13,6 +12,7 @@ import io.airbyte.cdk.db.IncrementalUtils.getCursorFieldOptional import io.airbyte.cdk.db.IncrementalUtils.getCursorType import io.airbyte.cdk.db.jdbc.JdbcDatabase import io.airbyte.cdk.integrations.JdbcConnector +import io.airbyte.cdk.integrations.base.AirbyteTraceMessageUtility import io.airbyte.cdk.integrations.base.AirbyteTraceMessageUtility.emitConfigErrorTrace import io.airbyte.cdk.integrations.base.Source import io.airbyte.cdk.integrations.base.errors.messages.ErrorMessage.getErrorMessage @@ -25,7 +25,6 @@ import io.airbyte.commons.features.EnvVariableFeatureFlags import io.airbyte.commons.features.FeatureFlags import io.airbyte.commons.functional.CheckedConsumer import io.airbyte.commons.lang.Exceptions -import io.airbyte.commons.stream.AirbyteStreamStatusHolder import io.airbyte.commons.stream.AirbyteStreamUtils import io.airbyte.commons.util.AutoCloseableIterator import io.airbyte.commons.util.AutoCloseableIterators @@ -33,8 +32,6 @@ import io.airbyte.protocol.models.AirbyteStreamNameNamespacePair import io.airbyte.protocol.models.CommonField import io.airbyte.protocol.models.JsonSchemaType import io.airbyte.protocol.models.v0.* -import org.slf4j.Logger -import org.slf4j.LoggerFactory import java.sql.SQLException import java.time.Duration import java.time.Instant @@ -43,20 +40,19 @@ import java.util.concurrent.atomic.AtomicLong import java.util.function.Function import java.util.stream.Collectors import java.util.stream.Stream +import org.slf4j.Logger +import org.slf4j.LoggerFactory /** * This class contains helper functions and boilerplate for implementing a source connector for a DB * source of both non-relational and relational type */ -abstract class AbstractDbSource protected constructor(driverClassName: String?) : JdbcConnector(driverClassName!!), Source, AutoCloseable { +abstract class AbstractDbSource +protected constructor(driverClassName: String) : + JdbcConnector(driverClassName), Source, AutoCloseable { // TODO: Remove when the flag is not use anymore protected var featureFlags: FeatureFlags = EnvVariableFeatureFlags() - @VisibleForTesting - fun setFeatureFlags(featureFlags: FeatureFlags) { - this.featureFlags = featureFlags - } - @Trace(operationName = CHECK_TRACE_OPERATION_NAME) @Throws(Exception::class) override fun check(config: JsonNode): AirbyteConnectionStatus? { @@ -69,18 +65,22 @@ abstract class AbstractDbSource protecte return AirbyteConnectionStatus().withStatus(AirbyteConnectionStatus.Status.SUCCEEDED) } catch (ex: ConnectionErrorException) { addExceptionToTrace(ex) - val message = getErrorMessage(ex.stateCode, ex.errorCode, - ex.exceptionMessage, ex) + val message = getErrorMessage(ex.stateCode, ex.errorCode, ex.exceptionMessage, ex) emitConfigErrorTrace(ex, message) return AirbyteConnectionStatus() - .withStatus(AirbyteConnectionStatus.Status.FAILED) - .withMessage(message) + .withStatus(AirbyteConnectionStatus.Status.FAILED) + .withMessage(message) } catch (e: Exception) { addExceptionToTrace(e) LOGGER.info("Exception while checking connection: ", e) return AirbyteConnectionStatus() - .withStatus(AirbyteConnectionStatus.Status.FAILED) - .withMessage(String.format(ConnectorExceptionUtil.COMMON_EXCEPTION_MESSAGE_TEMPLATE, e.message)) + .withStatus(AirbyteConnectionStatus.Status.FAILED) + .withMessage( + String.format( + ConnectorExceptionUtil.COMMON_EXCEPTION_MESSAGE_TEMPLATE, + e.message + ) + ) } finally { close() } @@ -92,32 +92,43 @@ abstract class AbstractDbSource protecte try { val database = createDatabase(config) val tableInfos = discoverWithoutSystemTables(database) - val fullyQualifiedTableNameToPrimaryKeys = discoverPrimaryKeys( - database, tableInfos) - return DbSourceDiscoverUtil.convertTableInfosToAirbyteCatalog(tableInfos, fullyQualifiedTableNameToPrimaryKeys) { columnType: DataType -> this.getAirbyteType(columnType) } + val fullyQualifiedTableNameToPrimaryKeys = discoverPrimaryKeys(database, tableInfos) + return DbSourceDiscoverUtil.convertTableInfosToAirbyteCatalog( + tableInfos, + fullyQualifiedTableNameToPrimaryKeys + ) { columnType: DataType -> this.getAirbyteType(columnType) } } finally { close() } } /** - * Creates a list of AirbyteMessageIterators with all the streams selected in a configured catalog + * Creates a list of AirbyteMessageIterators with all the streams selected in a configured + * catalog * - * @param config - integration-specific configuration object as json. e.g. { "username": "airbyte", + * @param config + * - integration-specific configuration object as json. e.g. { "username": "airbyte", * "password": "super secure" } - * @param catalog - schema of the incoming messages. - * @param state - state of the incoming messages. + * @param catalog + * - schema of the incoming messages. + * @param state + * - state of the incoming messages. * @return AirbyteMessageIterator with all the streams that are to be synced * @throws Exception */ @Throws(Exception::class) - override fun read(config: JsonNode, - catalog: ConfiguredAirbyteCatalog?, - state: JsonNode?): AutoCloseableIterator { + override fun read( + config: JsonNode, + catalog: ConfiguredAirbyteCatalog, + state: JsonNode? + ): AutoCloseableIterator { val supportedStateType = getSupportedStateType(config) val stateManager = - StateManagerFactory.createStateManager(supportedStateType, - StateGeneratorUtils.deserializeInitialState(state, supportedStateType), catalog) + StateManagerFactory.createStateManager( + supportedStateType, + StateGeneratorUtils.deserializeInitialState(state, supportedStateType), + catalog + ) val emittedAt = Instant.now() val database = createDatabase(config) @@ -125,49 +136,77 @@ abstract class AbstractDbSource protecte logPreSyncDebugData(database, catalog) val fullyQualifiedTableNameToInfo = - discoverWithoutSystemTables(database) - .stream() - .collect(Collectors.toMap(Function { t: TableInfo> -> String.format("%s.%s", t.nameSpace, t.name) }, - Function - .identity())) + discoverWithoutSystemTables(database) + .stream() + .collect( + Collectors.toMap( + Function { t: TableInfo> -> + String.format("%s.%s", t.nameSpace, t.name) + }, + Function.identity() + ) + ) validateCursorFieldForIncrementalTables(fullyQualifiedTableNameToInfo, catalog, database) - DbSourceDiscoverUtil.logSourceSchemaChange(fullyQualifiedTableNameToInfo, catalog) { columnType: DataType -> this.getAirbyteType(columnType) } + DbSourceDiscoverUtil.logSourceSchemaChange(fullyQualifiedTableNameToInfo, catalog) { + columnType: DataType -> + this.getAirbyteType(columnType) + } val incrementalIterators = - getIncrementalIterators(database, catalog, fullyQualifiedTableNameToInfo, stateManager, - emittedAt) + getIncrementalIterators( + database, + catalog, + fullyQualifiedTableNameToInfo, + stateManager, + emittedAt + ) val fullRefreshIterators = - getFullRefreshIterators(database, catalog, fullyQualifiedTableNameToInfo, stateManager, - emittedAt) - val iteratorList = Stream - .of(incrementalIterators, fullRefreshIterators) - .flatMap(Function>, Stream>> { obj: List> -> obj.stream() }) + getFullRefreshIterators( + database, + catalog, + fullyQualifiedTableNameToInfo, + stateManager, + emittedAt + ) + val iteratorList = + Stream.of(incrementalIterators, fullRefreshIterators) + .flatMap(Collection>::stream) .collect(Collectors.toList()) - return AutoCloseableIterators - .appendOnClose(AutoCloseableIterators.concatWithEagerClose(iteratorList) { obj: AirbyteStreamStatusHolder -> obj.emitStreamStatusTrace() }) { - LOGGER.info("Closing database connection pool.") - Exceptions.toRuntime { this.close() } - LOGGER.info("Closed database connection pool.") - } + return AutoCloseableIterators.appendOnClose( + AutoCloseableIterators.concatWithEagerClose( + iteratorList, + AirbyteTraceMessageUtility::emitStreamStatusTrace + ) + ) { + LOGGER.info("Closing database connection pool.") + Exceptions.toRuntime { this.close() } + LOGGER.info("Closed database connection pool.") + } } @Throws(SQLException::class) protected fun validateCursorFieldForIncrementalTables( - tableNameToTable: Map>>, - catalog: ConfiguredAirbyteCatalog?, - database: Database) { - val tablesWithInvalidCursor: MutableList = ArrayList() - for (airbyteStream in catalog!!.streams) { + tableNameToTable: Map>>, + catalog: ConfiguredAirbyteCatalog, + database: Database + ) { + val tablesWithInvalidCursor: MutableList = + ArrayList() + for (airbyteStream in catalog.streams) { val stream = airbyteStream.stream - val fullyQualifiedTableName = DbSourceDiscoverUtil.getFullyQualifiedTableName(stream.namespace, - stream.name) - val hasSourceDefinedCursor = ( - !Objects.isNull(airbyteStream.stream.sourceDefinedCursor) - && airbyteStream.stream.sourceDefinedCursor) - if (!tableNameToTable.containsKey(fullyQualifiedTableName) || airbyteStream.syncMode != SyncMode.INCREMENTAL || hasSourceDefinedCursor) { + val fullyQualifiedTableName = + DbSourceDiscoverUtil.getFullyQualifiedTableName(stream.namespace, stream.name) + val hasSourceDefinedCursor = + (!Objects.isNull(airbyteStream.stream.sourceDefinedCursor) && + airbyteStream.stream.sourceDefinedCursor) + if ( + !tableNameToTable.containsKey(fullyQualifiedTableName) || + airbyteStream.syncMode != SyncMode.INCREMENTAL || + hasSourceDefinedCursor + ) { continue } @@ -176,7 +215,9 @@ abstract class AbstractDbSource protecte if (cursorField.isEmpty) { continue } - val cursorType = table.fields.stream() + val cursorType = + table.fields!! + .stream() .filter { info: CommonField -> info.name == cursorField.get() } .map { obj: CommonField -> obj.type } .findFirst() @@ -184,21 +225,39 @@ abstract class AbstractDbSource protecte if (!isCursorType(cursorType)) { tablesWithInvalidCursor.add( - InvalidCursorInfoUtil.InvalidCursorInfo(fullyQualifiedTableName, cursorField.get(), - cursorType.toString(), "Unsupported cursor type")) + InvalidCursorInfoUtil.InvalidCursorInfo( + fullyQualifiedTableName, + cursorField.get(), + cursorType.toString(), + "Unsupported cursor type" + ) + ) continue } - if (!verifyCursorColumnValues(database, stream.namespace, stream.name, cursorField.get())) { + if ( + !verifyCursorColumnValues( + database, + stream.namespace, + stream.name, + cursorField.get() + ) + ) { tablesWithInvalidCursor.add( - InvalidCursorInfoUtil.InvalidCursorInfo(fullyQualifiedTableName, cursorField.get(), - cursorType.toString(), "Cursor column contains NULL value")) + InvalidCursorInfoUtil.InvalidCursorInfo( + fullyQualifiedTableName, + cursorField.get(), + cursorType.toString(), + "Cursor column contains NULL value" + ) + ) } } if (!tablesWithInvalidCursor.isEmpty()) { throw ConfigErrorException( - InvalidCursorInfoUtil.getInvalidCursorConfigMessage(tablesWithInvalidCursor)) + InvalidCursorInfoUtil.getInvalidCursorConfigMessage(tablesWithInvalidCursor) + ) } } @@ -210,68 +269,86 @@ abstract class AbstractDbSource protecte * @throws SQLException exception */ @Throws(SQLException::class) - protected fun verifyCursorColumnValues(database: Database, schema: String?, tableName: String?, columnName: String?): Boolean { + protected fun verifyCursorColumnValues( + database: Database, + schema: String?, + tableName: String?, + columnName: String? + ): Boolean { /* no-op */ return true } /** - * Estimates the total volume (rows and bytes) to sync and emits a - * [AirbyteEstimateTraceMessage] associated with the full refresh stream. + * Estimates the total volume (rows and bytes) to sync and emits a [AirbyteEstimateTraceMessage] + * associated with the full refresh stream. * * @param database database */ - protected fun estimateFullRefreshSyncSize(database: Database, - configuredAirbyteStream: ConfiguredAirbyteStream?) { + protected fun estimateFullRefreshSyncSize( + database: Database, + configuredAirbyteStream: ConfiguredAirbyteStream? + ) { /* no-op */ } @Throws(Exception::class) - protected fun discoverWithoutSystemTables(database: Database): List>> { + protected fun discoverWithoutSystemTables( + database: Database + ): List>> { val systemNameSpaces = excludedInternalNameSpaces val systemViews = excludedViews val discoveredTables = discoverInternal(database) return (if (systemNameSpaces == null || systemNameSpaces.isEmpty()) discoveredTables - else discoveredTables.stream() - .filter { table: TableInfo> -> !systemNameSpaces.contains(table.nameSpace) && !systemViews.contains(table.name) }.collect( - Collectors.toList())) + else + discoveredTables + .stream() + .filter { table: TableInfo> -> + !systemNameSpaces.contains(table.nameSpace) && !systemViews.contains(table.name) + } + .collect(Collectors.toList())) } protected fun getFullRefreshIterators( - database: Database, - catalog: ConfiguredAirbyteCatalog?, - tableNameToTable: Map>>, - stateManager: StateManager?, - emittedAt: Instant): List> { + database: Database, + catalog: ConfiguredAirbyteCatalog, + tableNameToTable: Map>>, + stateManager: StateManager?, + emittedAt: Instant + ): List> { return getSelectedIterators( - database, - catalog, - tableNameToTable, - stateManager, - emittedAt, - SyncMode.FULL_REFRESH) + database, + catalog, + tableNameToTable, + stateManager, + emittedAt, + SyncMode.FULL_REFRESH + ) } protected fun getIncrementalIterators( - database: Database, - catalog: ConfiguredAirbyteCatalog?, - tableNameToTable: Map>>, - stateManager: StateManager?, - emittedAt: Instant): List> { + database: Database, + catalog: ConfiguredAirbyteCatalog, + tableNameToTable: Map>>, + stateManager: StateManager?, + emittedAt: Instant + ): List> { return getSelectedIterators( - database, - catalog, - tableNameToTable, - stateManager, - emittedAt, - SyncMode.INCREMENTAL) + database, + catalog, + tableNameToTable, + stateManager, + emittedAt, + SyncMode.INCREMENTAL + ) } /** * Creates a list of read iterators for each stream within an ConfiguredAirbyteCatalog * * @param database Source Database - * @param catalog List of streams (e.g. database tables or API endpoints) with settings on sync mode + * @param catalog List of streams (e.g. database tables or API endpoints) with settings on sync + * mode * @param tableNameToTable Mapping of table name to table * @param stateManager Manager used to track the state of data synced by the connector * @param emittedAt Time when data was emitted from the Source database @@ -279,31 +356,30 @@ abstract class AbstractDbSource protecte * @return List of AirbyteMessageIterators containing all iterators for a catalog */ private fun getSelectedIterators( - database: Database, - catalog: ConfiguredAirbyteCatalog?, - tableNameToTable: Map>>, - stateManager: StateManager?, - emittedAt: Instant, - syncMode: SyncMode): List> { - val iteratorList: MutableList> = ArrayList() + database: Database, + catalog: ConfiguredAirbyteCatalog?, + tableNameToTable: Map>>, + stateManager: StateManager?, + emittedAt: Instant, + syncMode: SyncMode + ): List> { + val iteratorList: MutableList> = ArrayList() for (airbyteStream in catalog!!.streams) { if (airbyteStream.syncMode == syncMode) { val stream = airbyteStream.stream - val fullyQualifiedTableName = DbSourceDiscoverUtil.getFullyQualifiedTableName(stream.namespace, - stream.name) + val fullyQualifiedTableName = + DbSourceDiscoverUtil.getFullyQualifiedTableName(stream.namespace, stream.name) if (!tableNameToTable.containsKey(fullyQualifiedTableName)) { - LOGGER - .info("Skipping stream {} because it is not in the source", fullyQualifiedTableName) + LOGGER.info( + "Skipping stream {} because it is not in the source", + fullyQualifiedTableName + ) continue } val table = tableNameToTable[fullyQualifiedTableName]!! - val tableReadIterator = createReadIterator( - database, - airbyteStream, - table, - stateManager, - emittedAt) + val tableReadIterator = + createReadIterator(database, airbyteStream, table, stateManager, emittedAt) iteratorList.add(tableReadIterator) } } @@ -321,17 +397,20 @@ abstract class AbstractDbSource protecte * @param emittedAt Time when data was emitted from the Source database * @return */ - private fun createReadIterator(database: Database, - airbyteStream: ConfiguredAirbyteStream, - table: TableInfo>, - stateManager: StateManager?, - emittedAt: Instant): AutoCloseableIterator { + private fun createReadIterator( + database: Database, + airbyteStream: ConfiguredAirbyteStream, + table: TableInfo>, + stateManager: StateManager?, + emittedAt: Instant + ): AutoCloseableIterator { val streamName = airbyteStream.stream.name val namespace = airbyteStream.stream.namespace - val pair = io.airbyte.protocol.models.v0.AirbyteStreamNameNamespacePair(streamName, - namespace) + val pair = + io.airbyte.protocol.models.v0.AirbyteStreamNameNamespacePair(streamName, namespace) val selectedFieldsInCatalog = CatalogHelpers.getTopLevelFieldNames(airbyteStream) - val selectedDatabaseFields = table.fields + val selectedDatabaseFields = + table.fields .stream() .map { obj: CommonField -> obj.name } .filter { o: String -> selectedFieldsInCatalog.contains(o) } @@ -345,48 +424,82 @@ abstract class AbstractDbSource protecte val cursorInfo = stateManager!!.getCursorInfo(pair) val airbyteMessageIterator: AutoCloseableIterator - if (cursorInfo!!.map { obj: CursorInfo? -> obj.getCursor() }.isPresent) { - airbyteMessageIterator = getIncrementalStream( + if (cursorInfo!!.map { it.cursor }.isPresent) { + airbyteMessageIterator = + getIncrementalStream( database, airbyteStream, selectedDatabaseFields, table, cursorInfo.get(), - emittedAt) + emittedAt + ) } else { - // if no cursor is present then this is the first read for is the same as doing a full refresh read. + // if no cursor is present then this is the first read for is the same as doing a + // full refresh read. estimateFullRefreshSyncSize(database, airbyteStream) - airbyteMessageIterator = getFullRefreshStream(database, streamName, namespace, - selectedDatabaseFields, table, emittedAt, SyncMode.INCREMENTAL, Optional.of(cursorField)) + airbyteMessageIterator = + getFullRefreshStream( + database, + streamName, + namespace, + selectedDatabaseFields, + table, + emittedAt, + SyncMode.INCREMENTAL, + Optional.of(cursorField) + ) } - val cursorType = getCursorType(airbyteStream, - cursorField) + val cursorType = getCursorType(airbyteStream, cursorField) - val messageProducer = CursorStateMessageProducer( - stateManager, - cursorInfo.map { obj: CursorInfo? -> obj.getCursor() }) + val messageProducer = + CursorStateMessageProducer(stateManager, cursorInfo.map { it.cursor }) - iterator = AutoCloseableIterators.transform( - { autoCloseableIterator: AutoCloseableIterator -> - SourceStateIterator(autoCloseableIterator, airbyteStream, messageProducer, - StateEmitFrequency(stateEmissionFrequency.toLong(), - Duration.ZERO)) + iterator = + AutoCloseableIterators.transform( + { autoCloseableIterator: AutoCloseableIterator -> + SourceStateIterator( + autoCloseableIterator, + airbyteStream, + messageProducer, + StateEmitFrequency(stateEmissionFrequency.toLong(), Duration.ZERO) + ) }, airbyteMessageIterator, - AirbyteStreamUtils.convertFromNameAndNamespace(pair.name, pair.namespace)) + AirbyteStreamUtils.convertFromNameAndNamespace(pair.name, pair.namespace) + ) } else if (airbyteStream.syncMode == SyncMode.FULL_REFRESH) { estimateFullRefreshSyncSize(database, airbyteStream) - iterator = getFullRefreshStream(database, streamName, namespace, selectedDatabaseFields, - table, emittedAt, SyncMode.FULL_REFRESH, Optional.empty()) - } else requireNotNull(airbyteStream.syncMode) { String.format("%s requires a source sync mode", this.javaClass) } - throw IllegalArgumentException( - String.format("%s does not support sync mode: %s.", this.javaClass, - airbyteStream.syncMode)) + iterator = + getFullRefreshStream( + database, + streamName, + namespace, + selectedDatabaseFields, + table, + emittedAt, + SyncMode.FULL_REFRESH, + Optional.empty() + ) + } else if (airbyteStream.syncMode == null) { + throw IllegalArgumentException( + String.format("%s requires a source sync mode", this.javaClass) + ) + } else { + throw IllegalArgumentException( + String.format( + "%s does not support sync mode: %s.", + this.javaClass, + airbyteStream.syncMode + ) + ) + } val recordCount = AtomicLong() - return AutoCloseableIterators.transform(iterator, - AirbyteStreamUtils.convertFromNameAndNamespace(pair.name, pair.namespace) + return AutoCloseableIterators.transform( + iterator, + AirbyteStreamUtils.convertFromNameAndNamespace(pair.name, pair.namespace) ) { r: AirbyteMessage? -> val count = recordCount.incrementAndGet() if (count % 10000 == 0L) { @@ -405,32 +518,39 @@ abstract class AbstractDbSource protecte * @param emittedAt Time when data was emitted from the Source database * @return AirbyteMessage Iterator that */ - private fun getIncrementalStream(database: Database, - airbyteStream: ConfiguredAirbyteStream, - selectedDatabaseFields: List, - table: TableInfo>, - cursorInfo: CursorInfo, - emittedAt: Instant): AutoCloseableIterator { + private fun getIncrementalStream( + database: Database, + airbyteStream: ConfiguredAirbyteStream, + selectedDatabaseFields: List, + table: TableInfo>, + cursorInfo: CursorInfo, + emittedAt: Instant + ): AutoCloseableIterator { val streamName = airbyteStream.stream.name val namespace = airbyteStream.stream.namespace val cursorField = getCursorField(airbyteStream) - val cursorType = table.fields.stream() + val cursorType = + table.fields + .stream() .filter { info: CommonField -> info.name == cursorField } .map { obj: CommonField -> obj.type } .findFirst() .orElseThrow() Preconditions.checkState( - table.fields.stream().anyMatch { f: CommonField -> f.name == cursorField }, - String.format("Could not find cursor field %s in table %s", cursorField, table.name)) + table.fields.stream().anyMatch { f: CommonField -> f.name == cursorField }, + String.format("Could not find cursor field %s in table %s", cursorField, table.name) + ) - val queryIterator = queryTableIncremental( + val queryIterator = + queryTableIncremental( database, selectedDatabaseFields, table.nameSpace, table.name, cursorInfo, - cursorType) + cursorType + ) return getMessageIterator(queryIterator, streamName, namespace, emittedAt.toEpochMilli()) } @@ -439,8 +559,8 @@ abstract class AbstractDbSource protecte * Creates a AirbyteMessageIterator that contains all records for a database source connection * * @param database Source Database - * @param streamName name of an individual stream in which a stream represents a source (e.g. API - * endpoint or database table) + * @param streamName name of an individual stream in which a stream represents a source (e.g. + * API endpoint or database table) * @param namespace Namespace of the database (e.g. public) * @param selectedDatabaseFields List of all interested database column names * @param table information in tabular format @@ -448,37 +568,48 @@ abstract class AbstractDbSource protecte * @param syncMode The sync mode that this full refresh stream should be associated with. * @return AirbyteMessageIterator with all records for a database source */ - private fun getFullRefreshStream(database: Database, - streamName: String, - namespace: String, - selectedDatabaseFields: List, - table: TableInfo>, - emittedAt: Instant, - syncMode: SyncMode, - cursorField: Optional): AutoCloseableIterator { + private fun getFullRefreshStream( + database: Database, + streamName: String, + namespace: String, + selectedDatabaseFields: List, + table: TableInfo>, + emittedAt: Instant, + syncMode: SyncMode, + cursorField: Optional + ): AutoCloseableIterator { val queryStream = - queryTableFullRefresh(database, selectedDatabaseFields, table.nameSpace, - table.name, syncMode, cursorField) + queryTableFullRefresh( + database, + selectedDatabaseFields, + table.nameSpace, + table.name, + syncMode, + cursorField + ) return getMessageIterator(queryStream, streamName, namespace, emittedAt.toEpochMilli()) } /** - * @param database - The database where from privileges for tables will be consumed - * @param schema - The schema where from privileges for tables will be consumed - * @return Set with privileges for tables for current DB-session user The method is responsible for - * SELECT-ing the table with privileges. In some cases such SELECT doesn't require (e.g. in - * Oracle DB - the schema is the user, you cannot REVOKE a privilege on a table from its - * owner). + * @param database + * - The database where from privileges for tables will be consumed + * @param schema + * - The schema where from privileges for tables will be consumed + * @return Set with privileges for tables for current DB-session user The method is responsible + * for SELECT-ing the table with privileges. In some cases such SELECT doesn't require (e.g. in + * Oracle DB - the schema is the user, you cannot REVOKE a privilege on a table from its owner). */ @Throws(SQLException::class) - protected fun getPrivilegesTableForCurrentUser(database: JdbcDatabase?, - schema: String?): Set { + protected fun getPrivilegesTableForCurrentUser( + database: JdbcDatabase?, + schema: String? + ): Set { return emptySet() } /** - * Map a database implementation-specific configuration to json object that adheres to the database - * config spec. See resources/spec.json. + * Map a database implementation-specific configuration to json object that adheres to the + * database config spec. See resources/spec.json. * * @param config database implementation-specific configuration. * @return database spec config @@ -495,19 +626,21 @@ abstract class AbstractDbSource protecte */ @Trace(operationName = DISCOVER_TRACE_OPERATION_NAME) @Throws(Exception::class) - protected abstract fun createDatabase(config: JsonNode?): Database + protected abstract fun createDatabase(config: JsonNode): Database /** - * Gets and logs relevant and useful database metadata such as DB product/version, index names and - * definition. Called before syncing data. Any logged information should be scoped to the configured - * catalog and database. + * Gets and logs relevant and useful database metadata such as DB product/version, index names + * and definition. Called before syncing data. Any logged information should be scoped to the + * configured catalog and database. * * @param database given database instance. * @param catalog configured catalog. */ @Throws(Exception::class) - protected open fun logPreSyncDebugData(database: Database, catalog: ConfiguredAirbyteCatalog?) { - } + protected open fun logPreSyncDebugData( + database: Database, + catalog: ConfiguredAirbyteCatalog? + ) {} /** * Configures a list of operations that can be used to check the connection to the source. @@ -515,7 +648,9 @@ abstract class AbstractDbSource protecte * @return list of consumers that run queries for the check command. */ @Throws(Exception::class) - protected abstract fun getCheckOperations(config: JsonNode?): List> + protected abstract fun getCheckOperations( + config: JsonNode? + ): List> /** * Map source types to Airbyte types @@ -527,7 +662,8 @@ abstract class AbstractDbSource protecte protected abstract val excludedInternalNameSpaces: Set /** - * Get list of system namespaces(schemas) in order to exclude them from the `discover` result list. + * Get list of system namespaces(schemas) in order to exclude them from the `discover` + * result list. * * @return set of system namespaces(schemas) to be excluded */ @@ -551,20 +687,25 @@ abstract class AbstractDbSource protecte @Trace(operationName = DISCOVER_TRACE_OPERATION_NAME) @Throws(Exception::class) protected abstract fun discoverInternal( - database: Database): List>> + database: Database + ): List>> /** * Discovers all available tables within a schema in the source database. * - * @param database - source database - * @param schema - source schema + * @param database + * - source database + * @param schema + * - source schema * @return list of source tables - * @throws Exception - access to the database might lead to exceptions. + * @throws Exception + * - access to the database might lead to exceptions. */ @Throws(Exception::class) protected abstract fun discoverInternal( - database: Database, - schema: String?): List>> + database: Database, + schema: String? + ): List>> /** * Discover Primary keys for each table and @return a map of namespace.table name to their @@ -574,8 +715,10 @@ abstract class AbstractDbSource protecte * @param tableInfos list of tables * @return map of namespace.table and primary key fields. */ - protected abstract fun discoverPrimaryKeys(database: Database, - tableInfos: List>>): Map> + protected abstract fun discoverPrimaryKeys( + database: Database, + tableInfos: List>> + ): Map> protected abstract val quoteString: String? /** @@ -595,41 +738,43 @@ abstract class AbstractDbSource protecte * @param syncMode The sync mode that this full refresh stream should be associated with. * @return iterator with read data */ - protected abstract fun queryTableFullRefresh(database: Database, - columnNames: List, - schemaName: String?, - tableName: String, - syncMode: SyncMode, - cursorField: Optional): AutoCloseableIterator? + protected abstract fun queryTableFullRefresh( + database: Database, + columnNames: List, + schemaName: String?, + tableName: String, + syncMode: SyncMode, + cursorField: Optional + ): AutoCloseableIterator? /** * Read incremental data from a table. Incremental read should return only records where cursor - * column value is bigger than cursor. Note that if the connector needs to emit intermediate state - * (i.e. [AbstractDbSource.getStateEmissionFrequency] > 0), the incremental query must be + * column value is bigger than cursor. Note that if the connector needs to emit intermediate + * state (i.e. [AbstractDbSource.getStateEmissionFrequency] > 0), the incremental query must be * sorted by the cursor field. * * @return iterator with read data */ - protected abstract fun queryTableIncremental(database: Database, - columnNames: List, - schemaName: String?, - tableName: String, - cursorInfo: CursorInfo, - cursorFieldType: DataType): AutoCloseableIterator? + protected abstract fun queryTableIncremental( + database: Database, + columnNames: List, + schemaName: String?, + tableName: String, + cursorInfo: CursorInfo, + cursorFieldType: DataType + ): AutoCloseableIterator? protected val stateEmissionFrequency: Int /** - * When larger than 0, the incremental iterator will emit intermediate state for every N records. - * Please note that if intermediate state emission is enabled, the incremental query must be ordered - * by the cursor field. + * When larger than 0, the incremental iterator will emit intermediate state for every N + * records. Please note that if intermediate state emission is enabled, the incremental + * query must be ordered by the cursor field. * * TODO: Return an optional value instead of 0 to make it easier to understand. */ get() = 0 - /** - * @return list of fields that could be used as cursors - */ + /** @return list of fields that could be used as cursors */ protected abstract fun isCursorType(type: DataType): Boolean /** @@ -638,7 +783,9 @@ abstract class AbstractDbSource protecte * @param config The connector configuration. * @return A [AirbyteStateType] representing the state supported by this connector. */ - protected open fun getSupportedStateType(config: JsonNode?): AirbyteStateMessage.AirbyteStateType { + protected open fun getSupportedStateType( + config: JsonNode? + ): AirbyteStateMessage.AirbyteStateType { return AirbyteStateMessage.AirbyteStateType.STREAM } @@ -650,20 +797,24 @@ abstract class AbstractDbSource protecte private val LOGGER: Logger = LoggerFactory.getLogger(AbstractDbSource::class.java) private fun getMessageIterator( - recordIterator: AutoCloseableIterator?, - streamName: String, - namespace: String, - emittedAt: Long): AutoCloseableIterator { - return AutoCloseableIterators.transform(recordIterator, - AirbyteStreamNameNamespacePair(streamName, namespace) + recordIterator: AutoCloseableIterator?, + streamName: String, + namespace: String, + emittedAt: Long + ): AutoCloseableIterator { + return AutoCloseableIterators.transform( + recordIterator, + AirbyteStreamNameNamespacePair(streamName, namespace) ) { r: JsonNode? -> AirbyteMessage() - .withType(AirbyteMessage.Type.RECORD) - .withRecord(AirbyteRecordMessage() - .withStream(streamName) - .withNamespace(namespace) - .withEmittedAt(emittedAt) - .withData(r)) + .withType(AirbyteMessage.Type.RECORD) + .withRecord( + AirbyteRecordMessage() + .withStream(streamName) + .withNamespace(namespace) + .withEmittedAt(emittedAt) + .withData(r) + ) } } } diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/CdcStateManager.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/CdcStateManager.kt index 2a482c73fde05..7662c0fb9878f 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/CdcStateManager.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/CdcStateManager.kt @@ -7,21 +7,24 @@ import io.airbyte.cdk.integrations.source.relationaldb.models.CdcState import io.airbyte.commons.json.Jsons import io.airbyte.protocol.models.v0.AirbyteStateMessage import io.airbyte.protocol.models.v0.AirbyteStreamNameNamespacePair +import java.util.* import org.slf4j.Logger import org.slf4j.LoggerFactory -import java.util.* -class CdcStateManager(private val initialState: CdcState?, - initialStreamsSynced: Set?, - stateMessage: AirbyteStateMessage?) { - private val initialStreamsSynced: Set? +class CdcStateManager( + private val initialState: CdcState?, + initialStreamsSynced: Set?, + stateMessage: AirbyteStateMessage? +) { + val initialStreamsSynced: Set? val rawStateMessage: AirbyteStateMessage? private var currentState: CdcState? init { this.currentState = initialState - this.initialStreamsSynced = initialStreamsSynced - + this.initialStreamsSynced = + if (initialStreamsSynced != null) Collections.unmodifiableSet(initialStreamsSynced) + else null this.rawStateMessage = stateMessage LOGGER.info("Initialized CDC state") } @@ -32,15 +35,13 @@ class CdcStateManager(private val initialState: CdcState?, this.currentState = state } - fun getInitialStreamsSynced(): Set? { - return if (initialStreamsSynced != null) Collections.unmodifiableSet(initialStreamsSynced) else null - } - override fun toString(): String { return "CdcStateManager{" + - "initialState=" + initialState + - ", currentState=" + currentState + - '}' + "initialState=" + + initialState + + ", currentState=" + + currentState + + '}' } companion object { diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/CursorInfo.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/CursorInfo.kt index 0159966be8617..b4e4721d1bb18 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/CursorInfo.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/CursorInfo.kt @@ -5,16 +5,20 @@ package io.airbyte.cdk.integrations.source.relationaldb import java.util.* -class CursorInfo(val originalCursorField: String?, - val originalCursor: String?, - val originalCursorRecordCount: Long, - val cursorField: String?, - var cursor: String?, - var cursorRecordCount: Long) { - constructor(originalCursorField: String?, - originalCursor: String?, - cursorField: String?, - cursor: String?) : this(originalCursorField, originalCursor, 0L, cursorField, cursor, 0L) +class CursorInfo( + val originalCursorField: String?, + val originalCursor: String?, + val originalCursorRecordCount: Long, + val cursorField: String?, + var cursor: String?, + var cursorRecordCount: Long +) { + constructor( + originalCursorField: String?, + originalCursor: String?, + cursorField: String?, + cursor: String? + ) : this(originalCursorField, originalCursor, 0L, cursorField, cursor, 0L) fun setCursor(cursor: String?): CursorInfo { this.cursor = cursor @@ -34,21 +38,45 @@ class CursorInfo(val originalCursorField: String?, return false } val that = o as CursorInfo - return originalCursorField == that.originalCursorField && originalCursor == that.originalCursor && originalCursorRecordCount == that.originalCursorRecordCount && cursorField == that.cursorField && cursor == that.cursor && cursorRecordCount == that.cursorRecordCount + return originalCursorField == that.originalCursorField && + originalCursor == that.originalCursor && + originalCursorRecordCount == that.originalCursorRecordCount && + cursorField == that.cursorField && + cursor == that.cursor && + cursorRecordCount == that.cursorRecordCount } override fun hashCode(): Int { - return Objects.hash(originalCursorField, originalCursor, originalCursorRecordCount, cursorField, cursor, cursorRecordCount) + return Objects.hash( + originalCursorField, + originalCursor, + originalCursorRecordCount, + cursorField, + cursor, + cursorRecordCount + ) } override fun toString(): String { return "CursorInfo{" + - "originalCursorField='" + originalCursorField + '\'' + - ", originalCursor='" + originalCursor + '\'' + - ", originalCursorRecordCount='" + originalCursorRecordCount + '\'' + - ", cursorField='" + cursorField + '\'' + - ", cursor='" + cursor + '\'' + - ", cursorRecordCount='" + cursorRecordCount + '\'' + - '}' + "originalCursorField='" + + originalCursorField + + '\'' + + ", originalCursor='" + + originalCursor + + '\'' + + ", originalCursorRecordCount='" + + originalCursorRecordCount + + '\'' + + ", cursorField='" + + cursorField + + '\'' + + ", cursor='" + + cursor + + '\'' + + ", cursorRecordCount='" + + cursorRecordCount + + '\'' + + '}' } } diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/DbSourceDiscoverUtil.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/DbSourceDiscoverUtil.kt index ac542b4eec3d6..4bf46677fd3b5 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/DbSourceDiscoverUtil.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/DbSourceDiscoverUtil.kt @@ -11,41 +11,42 @@ import io.airbyte.protocol.models.v0.AirbyteCatalog import io.airbyte.protocol.models.v0.CatalogHelpers import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog import io.airbyte.protocol.models.v0.SyncMode -import org.slf4j.Logger -import org.slf4j.LoggerFactory import java.util.* import java.util.function.Consumer import java.util.function.Function import java.util.stream.Collectors +import org.slf4j.Logger +import org.slf4j.LoggerFactory -/** - * Contains utilities and helper classes for discovering schemas in database sources. - */ +/** Contains utilities and helper classes for discovering schemas in database sources. */ object DbSourceDiscoverUtil { private val LOGGER: Logger = LoggerFactory.getLogger(DbSourceDiscoverUtil::class.java) - private val AIRBYTE_METADATA: List = mutableListOf("_ab_cdc_lsn", - "_ab_cdc_updated_at", - "_ab_cdc_deleted_at") + private val AIRBYTE_METADATA: List = + mutableListOf("_ab_cdc_lsn", "_ab_cdc_updated_at", "_ab_cdc_deleted_at") /* - * This method logs schema drift between source table and the catalog. This can happen if (i) - * underlying table schema changed between syncs (ii) The source connector's mapping of datatypes to - * Airbyte types changed between runs - */ - fun logSourceSchemaChange(fullyQualifiedTableNameToInfo: Map>>, - catalog: ConfiguredAirbyteCatalog?, - airbyteTypeConverter: Function) { - for (airbyteStream in catalog!!.streams) { + * This method logs schema drift between source table and the catalog. This can happen if (i) + * underlying table schema changed between syncs (ii) The source connector's mapping of datatypes to + * Airbyte types changed between runs + */ + fun logSourceSchemaChange( + fullyQualifiedTableNameToInfo: Map>>, + catalog: ConfiguredAirbyteCatalog, + airbyteTypeConverter: Function + ) { + for (airbyteStream in catalog.streams) { val stream = airbyteStream.stream - val fullyQualifiedTableName = getFullyQualifiedTableName(stream.namespace, - stream.name) + val fullyQualifiedTableName = getFullyQualifiedTableName(stream.namespace, stream.name) if (!fullyQualifiedTableNameToInfo.containsKey(fullyQualifiedTableName)) { continue } val table = fullyQualifiedTableNameToInfo[fullyQualifiedTableName]!! - val fields = table.fields + val fields = + table.fields .stream() - .map { commonField: CommonField -> toField(commonField, airbyteTypeConverter) } + .map { commonField: CommonField -> + toField(commonField, airbyteTypeConverter) + } .distinct() .collect(Collectors.toList()) val currentJsonSchema = CatalogHelpers.fieldsToJsonSchema(fields) @@ -54,67 +55,93 @@ object DbSourceDiscoverUtil { val catalogProperties = catalogSchema["properties"] val mismatchedFields: MutableList = ArrayList() catalogProperties.fieldNames().forEachRemaining { fieldName: String -> - // Ignoring metadata fields since those are automatically added onto the catalog schema by Airbyte + // Ignoring metadata fields since those are automatically added onto the catalog + // schema by Airbyte // and don't exist in the source schema. They should not be considered a change if (AIRBYTE_METADATA.contains(fieldName)) { return@forEachRemaining } - if (!currentSchemaProperties.has(fieldName) || - currentSchemaProperties[fieldName] != catalogProperties[fieldName]) { + if ( + !currentSchemaProperties.has(fieldName) || + currentSchemaProperties[fieldName] != catalogProperties[fieldName] + ) { mismatchedFields.add(fieldName) } } if (!mismatchedFields.isEmpty()) { LOGGER.warn( - "Source schema changed for table {}! Potential mismatches: {}. Actual schema: {}. Catalog schema: {}", - fullyQualifiedTableName, - java.lang.String.join(", ", mismatchedFields.toString()), - currentJsonSchema, - catalogSchema) + "Source schema changed for table {}! Potential mismatches: {}. Actual schema: {}. Catalog schema: {}", + fullyQualifiedTableName, + java.lang.String.join(", ", mismatchedFields.toString()), + currentJsonSchema, + catalogSchema + ) } } } - fun convertTableInfosToAirbyteCatalog(tableInfos: List>>, - fullyQualifiedTableNameToPrimaryKeys: Map>, - airbyteTypeConverter: Function): AirbyteCatalog { - val tableInfoFieldList = tableInfos.stream() + fun convertTableInfosToAirbyteCatalog( + tableInfos: List>>, + fullyQualifiedTableNameToPrimaryKeys: Map>, + airbyteTypeConverter: Function + ): AirbyteCatalog { + val tableInfoFieldList = + tableInfos + .stream() .map { t: TableInfo> -> - // some databases return multiple copies of the same record for a column (e.g. redshift) because - // they have at least once delivery guarantees. we want to dedupe these, but first we check that the - // records are actually the same and provide a good error message if they are not. + // some databases return multiple copies of the same record for a column (e.g. + // redshift) because + // they have at least once delivery guarantees. we want to dedupe these, but + // first we check that the + // records are actually the same and provide a good error message if they are + // not. assertColumnsWithSameNameAreSame(t.nameSpace, t.name, t.fields) - val fields = t.fields + val fields = + t.fields .stream() - .map { commonField: CommonField -> toField(commonField, airbyteTypeConverter) } + .map { commonField: CommonField -> + toField(commonField, airbyteTypeConverter) + } .distinct() .collect(Collectors.toList()) - val fullyQualifiedTableName = getFullyQualifiedTableName(t.nameSpace, - t.name) - val primaryKeys = fullyQualifiedTableNameToPrimaryKeys.getOrDefault( - fullyQualifiedTableName, emptyList()) - TableInfo.builder().nameSpace(t.nameSpace).name(t.name) - .fields(fields).primaryKeys(primaryKeys) - .cursorFields(t.cursorFields) - .build() + val fullyQualifiedTableName = getFullyQualifiedTableName(t.nameSpace, t.name) + val primaryKeys = + fullyQualifiedTableNameToPrimaryKeys.getOrDefault( + fullyQualifiedTableName, + emptyList() + ) + TableInfo( + nameSpace = t.nameSpace, + name = t.name, + fields = fields, + primaryKeys = primaryKeys, + cursorFields = t.cursorFields + ) } .collect(Collectors.toList()) - val streams = tableInfoFieldList.stream() + val streams = + tableInfoFieldList + .stream() .map { tableInfo: TableInfo -> - val primaryKeys = tableInfo.primaryKeys.stream() + val primaryKeys = + tableInfo.primaryKeys + .stream() .filter { obj: String? -> Objects.nonNull(obj) } - .map(Function> { o: String? -> listOf(o) }) - .collect(Collectors.toList()) - CatalogHelpers - .createAirbyteStream(tableInfo.name, tableInfo.nameSpace, - tableInfo.fields) - .withSupportedSyncModes( - if (tableInfo.cursorFields != null && tableInfo.cursorFields.isEmpty() - ) Lists.newArrayList(SyncMode.FULL_REFRESH) - else Lists.newArrayList(SyncMode.FULL_REFRESH, SyncMode.INCREMENTAL)) - .withSourceDefinedPrimaryKey(primaryKeys) + .map { listOf(it) } + .toList() + CatalogHelpers.createAirbyteStream( + tableInfo.name, + tableInfo.nameSpace, + tableInfo.fields + ) + .withSupportedSyncModes( + if (tableInfo.cursorFields != null && tableInfo.cursorFields.isEmpty()) + Lists.newArrayList(SyncMode.FULL_REFRESH) + else Lists.newArrayList(SyncMode.FULL_REFRESH, SyncMode.INCREMENTAL) + ) + .withSourceDefinedPrimaryKey(primaryKeys) } .collect(Collectors.toList()) return AirbyteCatalog().withStreams(streams) @@ -124,31 +151,60 @@ object DbSourceDiscoverUtil { return if (nameSpace != null) "$nameSpace.$tableName" else tableName } - private fun toField(commonField: CommonField, airbyteTypeConverter: Function): Field { - if (airbyteTypeConverter.apply(commonField.type) === JsonSchemaType.OBJECT && commonField.properties != null && !commonField.properties.isEmpty()) { - val properties = commonField.properties.stream().map { commField: CommonField -> toField(commField, airbyteTypeConverter) }.toList() - return Field.of(commonField.name, airbyteTypeConverter.apply(commonField.type), properties) + private fun toField( + commonField: CommonField, + airbyteTypeConverter: Function + ): Field { + if ( + airbyteTypeConverter.apply(commonField.type) === JsonSchemaType.OBJECT && + commonField.properties != null && + !commonField.properties.isEmpty() + ) { + val properties = + commonField.properties + .stream() + .map { commField: CommonField -> + toField(commField, airbyteTypeConverter) + } + .toList() + return Field.of( + commonField.name, + airbyteTypeConverter.apply(commonField.type), + properties + ) } else { return Field.of(commonField.name, airbyteTypeConverter.apply(commonField.type)) } } - private fun assertColumnsWithSameNameAreSame(nameSpace: String, - tableName: String, - columns: List>) { - columns.stream() - .collect(Collectors.groupingBy(Function { obj: CommonField -> obj.name })) - .values - .forEach(Consumer { columnsWithSameName: List> -> + private fun assertColumnsWithSameNameAreSame( + nameSpace: String, + tableName: String, + columns: List> + ) { + columns + .stream() + .collect(Collectors.groupingBy(Function { obj: CommonField -> obj.name })) + .values + .forEach( + Consumer { columnsWithSameName: List> -> val comparisonColumn = columnsWithSameName[0] - columnsWithSameName.forEach(Consumer { column: CommonField -> - if (column != comparisonColumn) { - throw RuntimeException( + columnsWithSameName.forEach( + Consumer { column: CommonField -> + if (column != comparisonColumn) { + throw RuntimeException( String.format( - "Found multiple columns with same name: %s in table: %s.%s but the columns are not the same. columns: %s", - comparisonColumn.name, nameSpace, tableName, columns)) + "Found multiple columns with same name: %s in table: %s.%s but the columns are not the same. columns: %s", + comparisonColumn.name, + nameSpace, + tableName, + columns + ) + ) + } } - }) - }) + ) + } + ) } } diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/InvalidCursorInfoUtil.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/InvalidCursorInfoUtil.kt index 5caab6e864719..d2c8e2b5ee016 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/InvalidCursorInfoUtil.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/InvalidCursorInfoUtil.kt @@ -7,19 +7,32 @@ import java.util.stream.Collectors object InvalidCursorInfoUtil { fun getInvalidCursorConfigMessage(tablesWithInvalidCursor: List): String { - return ("The following tables have invalid columns selected as cursor, please select a column with a well-defined ordering with no null values as a cursor. " - + tablesWithInvalidCursor.stream().map { obj: InvalidCursorInfo -> obj.toString() } + return ("The following tables have invalid columns selected as cursor, please select a column with a well-defined ordering with no null values as a cursor. " + + tablesWithInvalidCursor + .stream() + .map { obj: InvalidCursorInfo -> obj.toString() } .collect(Collectors.joining(","))) } - class InvalidCursorInfo(tableName: String?, cursorColumnName: String, cursorSqlType: String, cause: String) { + class InvalidCursorInfo( + tableName: String?, + cursorColumnName: String, + cursorSqlType: String, + cause: String + ) { override fun toString(): String { return "{" + - "tableName='" + tableName + '\'' + - ", cursorColumnName='" + cursorColumnName + '\'' + - ", cursorSqlType=" + cursorSqlType + - ", cause=" + cause + - '}' + "tableName='" + + tableName + + '\'' + + ", cursorColumnName='" + + cursorColumnName + + '\'' + + ", cursorSqlType=" + + cursorSqlType + + ", cause=" + + cause + + '}' } val tableName: String? @@ -28,11 +41,6 @@ object InvalidCursorInfoUtil { val cause: String init { - this.streamName = streamName - this.primaryKey = primaryKey - this.keySequence = keySequence - this.syncCheckpointRecords = syncCheckpointRecords - this.syncCheckpointDuration = syncCheckpointDuration this.tableName = tableName this.cursorColumnName = cursorColumnName this.cursorSqlType = cursorSqlType diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/RelationalDbQueryUtils.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/RelationalDbQueryUtils.kt index 31e46ad3ba57b..bd164a44486a2 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/RelationalDbQueryUtils.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/RelationalDbQueryUtils.kt @@ -9,28 +9,26 @@ import io.airbyte.commons.stream.AirbyteStreamUtils import io.airbyte.commons.util.AutoCloseableIterator import io.airbyte.commons.util.AutoCloseableIterators import io.airbyte.protocol.models.v0.ConfiguredAirbyteStream -import org.slf4j.Logger -import org.slf4j.LoggerFactory import java.util.* import java.util.stream.Collectors +import org.slf4j.Logger +import org.slf4j.LoggerFactory -/** - * Utility class for methods to query a relational db. - */ +/** Utility class for methods to query a relational db. */ object RelationalDbQueryUtils { private val LOGGER: Logger = LoggerFactory.getLogger(RelationalDbQueryUtils::class.java) - fun getIdentifierWithQuoting(identifier: String, quoteString: String?): String { + fun getIdentifierWithQuoting(identifier: String, quoteString: String): String { // double-quoted values within a database name or column name should be wrapped with extra // quoteString - return if (identifier.startsWith(quoteString!!) && identifier.endsWith(quoteString)) { + return if (identifier.startsWith(quoteString) && identifier.endsWith(quoteString)) { quoteString + quoteString + identifier + quoteString + quoteString } else { quoteString + identifier + quoteString } } - fun enquoteIdentifierList(identifiers: List, quoteString: String?): String { + fun enquoteIdentifierList(identifiers: List, quoteString: String): String { val joiner = StringJoiner(",") for (identifier in identifiers) { joiner.add(getIdentifierWithQuoting(identifier, quoteString)) @@ -38,42 +36,53 @@ object RelationalDbQueryUtils { return joiner.toString() } - /** - * @return fully qualified table name with the schema (if a schema exists) in quotes. - */ - fun getFullyQualifiedTableNameWithQuoting(nameSpace: String?, tableName: String, quoteString: String?): String { - return (if (nameSpace == null || nameSpace.isEmpty()) getIdentifierWithQuoting(tableName, quoteString) - else getIdentifierWithQuoting(nameSpace, quoteString) + "." + getIdentifierWithQuoting(tableName, quoteString)) + /** @return fully qualified table name with the schema (if a schema exists) in quotes. */ + fun getFullyQualifiedTableNameWithQuoting( + nameSpace: String?, + tableName: String, + quoteString: String + ): String { + return (if (nameSpace == null || nameSpace.isEmpty()) + getIdentifierWithQuoting(tableName, quoteString) + else + getIdentifierWithQuoting(nameSpace, quoteString) + + "." + + getIdentifierWithQuoting(tableName, quoteString)) } - /** - * @return fully qualified table name with the schema (if a schema exists) without quotes. - */ + /** @return fully qualified table name with the schema (if a schema exists) without quotes. */ fun getFullyQualifiedTableName(schemaName: String?, tableName: String): String { return if (schemaName != null) "$schemaName.$tableName" else tableName } - /** - * @return the input identifier with quotes. - */ + /** @return the input identifier with quotes. */ fun enquoteIdentifier(identifier: String?, quoteString: String?): String { return quoteString + identifier + quoteString } - fun queryTable(database: Database, - sqlQuery: String?, - tableName: String?, - schemaName: String?): AutoCloseableIterator { - val airbyteStreamNameNamespacePair = AirbyteStreamUtils.convertFromNameAndNamespace(tableName, schemaName) - return AutoCloseableIterators.lazyIterator({ - try { - LOGGER.info("Queueing query: {}", sqlQuery) - val stream = database!!.unsafeQuery(sqlQuery) - return@lazyIterator AutoCloseableIterators.fromStream(stream, airbyteStreamNameNamespacePair) - } catch (e: Exception) { - throw RuntimeException(e) - } - }, airbyteStreamNameNamespacePair) + fun queryTable( + database: Database, + sqlQuery: String?, + tableName: String?, + schemaName: String? + ): AutoCloseableIterator { + val airbyteStreamNameNamespacePair = + AirbyteStreamUtils.convertFromNameAndNamespace(tableName, schemaName) + return AutoCloseableIterators.lazyIterator( + { + try { + LOGGER.info("Queueing query: {}", sqlQuery) + val stream = database!!.unsafeQuery(sqlQuery) + return@lazyIterator AutoCloseableIterators.fromStream( + stream, + airbyteStreamNameNamespacePair + ) + } catch (e: Exception) { + throw RuntimeException(e) + } + }, + airbyteStreamNameNamespacePair + ) } fun logStreamSyncStatus(streams: List, syncType: String?) { @@ -86,7 +95,12 @@ object RelationalDbQueryUtils { } fun prettyPrintConfiguredAirbyteStreamList(streamList: List): String { - return streamList.stream().map { s: ConfiguredAirbyteStream -> "%s.%s".formatted(s.stream.namespace, s.stream.name) }.collect(Collectors.joining(", ")) + return streamList + .stream() + .map { s: ConfiguredAirbyteStream -> + "%s.%s".formatted(s.stream.namespace, s.stream.name) + } + .collect(Collectors.joining(", ")) } class TableSizeInfo(tableSize: Long, avgRowLength: Long) { @@ -94,15 +108,6 @@ object RelationalDbQueryUtils { val avgRowLength: Long init { - this.streamName = streamName - this.primaryKey = primaryKey - this.keySequence = keySequence - this.syncCheckpointRecords = syncCheckpointRecords - this.syncCheckpointDuration = syncCheckpointDuration - this.tableName = tableName - this.cursorColumnName = cursorColumnName - this.cursorSqlType = cursorSqlType - this.cause = cause this.tableSize = tableSize this.avgRowLength = avgRowLength } diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/RelationalDbReadUtil.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/RelationalDbReadUtil.kt index f5eac21e370ab..492878a63851b 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/RelationalDbReadUtil.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/RelationalDbReadUtil.kt @@ -12,31 +12,54 @@ import io.airbyte.protocol.models.v0.SyncMode import java.util.stream.Collectors object RelationalDbReadUtil { - fun identifyStreamsToSnapshot(catalog: ConfiguredAirbyteCatalog, - alreadySyncedStreams: Set?): List { + fun identifyStreamsToSnapshot( + catalog: ConfiguredAirbyteCatalog, + alreadySyncedStreams: Set + ): List { val allStreams = AirbyteStreamNameNamespacePair.fromConfiguredCatalog(catalog) - val newlyAddedStreams: Set = HashSet(Sets.difference(allStreams, alreadySyncedStreams)) - return catalog.streams.stream() - .filter { c: ConfiguredAirbyteStream -> c.syncMode == SyncMode.INCREMENTAL } - .filter { stream: ConfiguredAirbyteStream -> newlyAddedStreams.contains(AirbyteStreamNameNamespacePair.fromAirbyteStream(stream.stream)) } - .map { `object`: ConfiguredAirbyteStream? -> Jsons.clone(`object`) } - .collect(Collectors.toList()) + val newlyAddedStreams: Set = + HashSet(Sets.difference(allStreams, alreadySyncedStreams)) + return catalog.streams + .stream() + .filter { c: ConfiguredAirbyteStream -> c.syncMode == SyncMode.INCREMENTAL } + .filter { stream: ConfiguredAirbyteStream -> + newlyAddedStreams.contains( + AirbyteStreamNameNamespacePair.fromAirbyteStream(stream.stream) + ) + } + .map { `object`: ConfiguredAirbyteStream -> Jsons.clone(`object`) } + .collect(Collectors.toList()) } - fun identifyStreamsForCursorBased(catalog: ConfiguredAirbyteCatalog, - streamsForInitialLoad: List): List { + fun identifyStreamsForCursorBased( + catalog: ConfiguredAirbyteCatalog, + streamsForInitialLoad: List + ): List { val initialLoadStreamsNamespacePairs = - streamsForInitialLoad.stream().map { stream: ConfiguredAirbyteStream -> AirbyteStreamNameNamespacePair.fromAirbyteStream(stream.stream) } - .collect( - Collectors.toSet()) - return catalog.streams.stream() - .filter { c: ConfiguredAirbyteStream -> c.syncMode == SyncMode.INCREMENTAL } - .filter { stream: ConfiguredAirbyteStream -> !initialLoadStreamsNamespacePairs.contains(AirbyteStreamNameNamespacePair.fromAirbyteStream(stream.stream)) } - .map { `object`: ConfiguredAirbyteStream? -> Jsons.clone(`object`) } - .collect(Collectors.toList()) + streamsForInitialLoad + .stream() + .map { stream: ConfiguredAirbyteStream -> + AirbyteStreamNameNamespacePair.fromAirbyteStream(stream.stream) + } + .collect(Collectors.toSet()) + return catalog.streams + .stream() + .filter { c: ConfiguredAirbyteStream -> c.syncMode == SyncMode.INCREMENTAL } + .filter { stream: ConfiguredAirbyteStream -> + !initialLoadStreamsNamespacePairs.contains( + AirbyteStreamNameNamespacePair.fromAirbyteStream(stream.stream) + ) + } + .map { `object`: ConfiguredAirbyteStream -> Jsons.clone(`object`) } + .collect(Collectors.toList()) } - fun convertNameNamespacePairFromV0(v1NameNamespacePair: io.airbyte.protocol.models.AirbyteStreamNameNamespacePair): AirbyteStreamNameNamespacePair { - return AirbyteStreamNameNamespacePair(v1NameNamespacePair.name, v1NameNamespacePair.namespace) + fun convertNameNamespacePairFromV0( + v1NameNamespacePair: io.airbyte.protocol.models.AirbyteStreamNameNamespacePair + ): AirbyteStreamNameNamespacePair { + return AirbyteStreamNameNamespacePair( + v1NameNamespacePair.name, + v1NameNamespacePair.namespace + ) } } diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/StateDecoratingIterator.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/StateDecoratingIterator.kt index 7a0db7d92413b..7d7bc4498cded 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/StateDecoratingIterator.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/StateDecoratingIterator.kt @@ -10,42 +10,42 @@ import io.airbyte.protocol.models.JsonSchemaPrimitiveUtil import io.airbyte.protocol.models.v0.AirbyteMessage import io.airbyte.protocol.models.v0.AirbyteStateStats import io.airbyte.protocol.models.v0.AirbyteStreamNameNamespacePair +import java.util.* import org.slf4j.Logger import org.slf4j.LoggerFactory -import java.util.* @Deprecated("") -class StateDecoratingIterator(private val messageIterator: Iterator, - private val stateManager: StateManager, - private val pair: AirbyteStreamNameNamespacePair, - private val cursorField: String, - private val initialCursor: String, - private val cursorType: JsonSchemaPrimitiveUtil.JsonSchemaPrimitive, - stateEmissionFrequency: Int) : AbstractIterator(), MutableIterator { +class StateDecoratingIterator( + private val messageIterator: Iterator, + private val stateManager: StateManager, + private val pair: AirbyteStreamNameNamespacePair, + private val cursorField: String, + private val initialCursor: String, + private val cursorType: JsonSchemaPrimitiveUtil.JsonSchemaPrimitive, + stateEmissionFrequency: Int +) : AbstractIterator(), MutableIterator { private var currentMaxCursor: String? private var currentMaxCursorRecordCount = 0L private var hasEmittedFinalState = false /** - * These parameters are for intermediate state message emission. We can emit an intermediate state - * when the following two conditions are met. - * + * These parameters are for intermediate state message emission. We can emit an intermediate + * state when the following two conditions are met. * * 1. The records are sorted by the cursor field. This is true when `stateEmissionFrequency` > - * 0. This logic is guaranteed in `AbstractJdbcSource#queryTableIncremental`, in which an - * "ORDER BY" clause is appended to the SQL query if `stateEmissionFrequency` > 0. - * - * - * 2. There is a cursor value that is ready for emission. A cursor value is "ready" if there is no - * more record with the same value. We cannot emit a cursor at will, because there may be multiple - * records with the same cursor value. If we emit a cursor ignoring this condition, should the sync - * fail right after the emission, the next sync may skip some records with the same cursor value due - * to "WHERE cursor_field > cursor" in `AbstractJdbcSource#queryTableIncremental`. + * 0. This logic is guaranteed in `AbstractJdbcSource#queryTableIncremental`, in which an "ORDER + * BY" clause is appended to the SQL query if `stateEmissionFrequency` > 0. * + * 2. There is a cursor value that is ready for emission. A cursor value is "ready" if there is + * no more record with the same value. We cannot emit a cursor at will, because there may be + * multiple records with the same cursor value. If we emit a cursor ignoring this condition, + * should the sync fail right after the emission, the next sync may skip some records with the + * same cursor value due to "WHERE cursor_field > cursor" in + * `AbstractJdbcSource#queryTableIncremental`. * - * The `intermediateStateMessage` is set to the latest state message that is ready for - * emission. For every `stateEmissionFrequency` messages, `emitIntermediateState` is set - * to true and the latest "ready" state will be emitted in the next `computeNext` call. + * The `intermediateStateMessage` is set to the latest state message that is ready for emission. + * For every `stateEmissionFrequency` messages, `emitIntermediateState` is set to true and the + * latest "ready" state will be emitted in the next `computeNext` call. */ private val stateEmissionFrequency: Int private var totalRecordCount = 0 @@ -61,12 +61,12 @@ class StateDecoratingIterator(private val messageIterator: Iterator 0. + * @param cursorType ENUM type of primitive values that can be used as a cursor for + * checkpointing + * @param stateEmissionFrequency If larger than 0, the records are sorted by the cursor field, + * and intermediate states will be emitted for every `stateEmissionFrequency` records. The order + * of the records is guaranteed in `AbstractJdbcSource#queryTableIncremental`, in which an + * "ORDER BY" clause is appended to the SQL query if `stateEmissionFrequency` > 0. */ init { this.currentMaxCursor = initialCursor @@ -86,15 +86,12 @@ class StateDecoratingIterator(private val messageIterator: Iterator 0 && currentMaxCursor != initialCursor && messageIterator.hasNext()) { - // Only create an intermediate state when it is not the first or last record message. + // Update the current max cursor only when current max cursor < cursor + // candidate from the message + if ( + stateEmissionFrequency > 0 && + currentMaxCursor != initialCursor && + messageIterator.hasNext() + ) { + // Only create an intermediate state when it is not the first or last + // record message. // The last state message will be processed seperately. - intermediateStateMessage = createStateMessage(false, recordCountInStateMessage) + intermediateStateMessage = + createStateMessage(false, recordCountInStateMessage) } currentMaxCursor = cursorCandidate currentMaxCursorRecordCount = 1L } else if (cursorComparison == 0) { currentMaxCursorRecordCount++ } else if (cursorComparison > 0 && stateEmissionFrequency > 0) { - LOGGER.warn("Intermediate state emission feature requires records to be processed in order according to the cursor value. Otherwise, " - + "data loss can occur.") + LOGGER.warn( + "Intermediate state emission feature requires records to be processed in order according to the cursor value. Otherwise, " + + "data loss can occur." + ) } } @@ -158,24 +167,23 @@ class StateDecoratingIterator(private val messageIterator: Iterator + protected val intermediateMessage: Optional /** - * Returns AirbyteStateMessage when in a ready state, a ready state means that it has satifies the - * conditions of: - * + * Returns AirbyteStateMessage when in a ready state, a ready state means that it has + * satifies the conditions of: * * cursorField has changed (e.g. 08-22-2022 -> 08-23-2022) and there have been at least * stateEmissionFrequency number of records since the last emission * - * - * @return AirbyteStateMessage if one exists, otherwise Optional indicating state was not ready to - * be emitted + * @return AirbyteStateMessage if one exists, otherwise Optional indicating state was not + * ready to be emitted */ get() { - if (emitIntermediateState && intermediateStateMessage != null) { - val message: AirbyteMessage = intermediateStateMessage + val message: AirbyteMessage? = intermediateStateMessage + if (emitIntermediateState && message != null) { if (message.state != null) { - message.state.sourceStats = AirbyteStateStats().withRecordCount(recordCountInStateMessage.toDouble()) + message.state.sourceStats = + AirbyteStateStats().withRecordCount(recordCountInStateMessage.toDouble()) } intermediateStateMessage = null @@ -195,26 +203,32 @@ class StateDecoratingIterator(private val messageIterator: Iterator latest: {} = {} (count {})", - pair, - cursorInfo!!.map { obj: CursorInfo? -> obj.getOriginalCursorField() }.orElse(null), - cursorInfo.map { obj: CursorInfo? -> obj.getOriginalCursor() }.orElse(null), - cursorInfo.map { obj: CursorInfo? -> obj.getOriginalCursorRecordCount() }.orElse(null), - cursorInfo.map { obj: CursorInfo? -> obj.getCursorField() }.orElse(null), - cursorInfo.map { obj: CursorInfo? -> obj.getCursor() }.orElse(null), - cursorInfo.map { obj: CursorInfo? -> obj.getCursorRecordCount() }.orElse(null)) + LOGGER.info( + "State report for stream {} - original: {} = {} (count {}) -> latest: {} = {} (count {})", + pair, + cursorInfo.map { obj: CursorInfo -> obj.originalCursorField }.orElse(null), + cursorInfo.map { obj: CursorInfo -> obj.originalCursor }.orElse(null), + cursorInfo.map { obj: CursorInfo -> obj.originalCursorRecordCount }.orElse(null), + cursorInfo.map { obj: CursorInfo -> obj.cursorField }.orElse(null), + cursorInfo.map { obj: CursorInfo -> obj.cursor }.orElse(null), + cursorInfo.map { obj: CursorInfo -> obj.cursorRecordCount }.orElse(null) + ) } stateMessage?.withSourceStats(AirbyteStateStats().withRecordCount(recordCount.toDouble())) if (isFinalState) { hasEmittedFinalState = true if (stateManager.getCursor(pair).isEmpty) { - LOGGER.warn("Cursor for stream {} was null. This stream will replicate all records on the next run", pair) + LOGGER.warn( + "Cursor for stream {} was null. This stream will replicate all records on the next run", + pair + ) } } diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/TableInfo.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/TableInfo.kt index 8d900fd477b21..46ebe3bd96d86 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/TableInfo.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/TableInfo.kt @@ -3,18 +3,11 @@ */ package io.airbyte.cdk.integrations.source.relationaldb -import lombok.Builder -import lombok.Getter - -/** - * This class encapsulates all externally relevant Table information. - */ -@Getter -@Builder -class TableInfo { - private val nameSpace: String? = null - private val name: String? = null - private val fields: List? = null - private val primaryKeys: List? = null - private val cursorFields: List? = null -} +/** This class encapsulates all externally relevant Table information. */ +data class TableInfo( + val nameSpace: String, + val name: String, + val fields: List, + val primaryKeys: List = emptyList(), + val cursorFields: List +) diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/AbstractStateManager.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/AbstractStateManager.kt index 627695d2d8f00..935f8c6d008de 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/AbstractStateManager.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/AbstractStateManager.kt @@ -12,28 +12,43 @@ import java.util.function.Function import java.util.function.Supplier /** - * Abstract implementation of the [StateManager] interface that provides common functionality - * for state manager implementations. + * Abstract implementation of the [StateManager] interface that provides common functionality for + * state manager implementations. * * @param The type associated with the state object managed by this manager. * @param The type associated with the state object stored in the state managed by this manager. - */ -abstract class AbstractStateManager @JvmOverloads constructor(catalog: ConfiguredAirbyteCatalog?, - streamSupplier: Supplier>, - cursorFunction: Function?, - cursorFieldFunction: Function?>?, - cursorRecordCountFunction: Function?, - namespacePairFunction: Function?, - onlyIncludeIncrementalStreams: Boolean = false) : StateManager { + * + */ +abstract class AbstractStateManager +@JvmOverloads +constructor( + catalog: ConfiguredAirbyteCatalog, + streamSupplier: Supplier>, + cursorFunction: Function?, + cursorFieldFunction: Function>?, + cursorRecordCountFunction: Function?, + namespacePairFunction: Function?, + onlyIncludeIncrementalStreams: Boolean = false +) : StateManager { /** - * The [CursorManager] responsible for keeping track of the current cursor value for each - * stream managed by this state manager. + * The [CursorManager] responsible for keeping track of the current cursor value for each stream + * managed by this state manager. */ - private val cursorManager: CursorManager<*> = CursorManager(catalog, streamSupplier, cursorFunction, cursorFieldFunction, cursorRecordCountFunction, namespacePairFunction, - onlyIncludeIncrementalStreams) + private val cursorManager: CursorManager<*> = + CursorManager( + catalog, + streamSupplier, + cursorFunction, + cursorFieldFunction, + cursorRecordCountFunction, + namespacePairFunction, + onlyIncludeIncrementalStreams + ) - override val pairToCursorInfoMap: Map? + override val pairToCursorInfoMap: Map get() = cursorManager.pairToCursorInfo - abstract override fun toState(pair: Optional): AirbyteStateMessage? + abstract override fun toState( + pair: Optional + ): AirbyteStateMessage } diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/CursorManager.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/CursorManager.kt index c17be70b9f4c8..657e9437c603d 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/CursorManager.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/CursorManager.kt @@ -6,31 +6,33 @@ package io.airbyte.cdk.integrations.source.relationaldb.state import com.google.common.annotations.VisibleForTesting import io.airbyte.cdk.integrations.source.relationaldb.CursorInfo import io.airbyte.protocol.models.v0.* -import org.slf4j.Logger -import org.slf4j.LoggerFactory import java.util.* import java.util.concurrent.* import java.util.function.Function import java.util.function.Supplier import java.util.stream.Collectors +import org.slf4j.Logger +import org.slf4j.LoggerFactory /** * Manages the map of streams to current cursor values for state management. * * @param The type that represents the stream object which holds the current cursor information - * in the state. - */ -class CursorManager(catalog: ConfiguredAirbyteCatalog?, - streamSupplier: Supplier>, - cursorFunction: Function?, - cursorFieldFunction: Function?>?, - cursorRecordCountFunction: Function?, - namespacePairFunction: Function?, - onlyIncludeIncrementalStreams: Boolean) { + * in the state. + */ +class CursorManager( + catalog: ConfiguredAirbyteCatalog, + streamSupplier: Supplier>, + cursorFunction: Function?, + cursorFieldFunction: Function>?, + cursorRecordCountFunction: Function?, + namespacePairFunction: Function?, + onlyIncludeIncrementalStreams: Boolean +) { /** * Map of streams (name/namespace tuple) to the current cursor information stored in the state. */ - private val pairToCursorInfo: Map + val pairToCursorInfo: Map /** * Constructs a new [CursorManager] based on the configured connector and current state @@ -45,19 +47,25 @@ class CursorManager(catalog: ConfiguredAirbyteCatalog?, * stored in the connector's state. * @param cursorRecordCountFunction A [Function] that extracts the cursor record count for a * stream stored in the connector's state. - * @param namespacePairFunction A [Function] that generates a - * [AirbyteStreamNameNamespacePair] that identifies each stream in the connector's - * state. + * @param namespacePairFunction A [Function] that generates a [AirbyteStreamNameNamespacePair] + * that identifies each stream in the connector's state. */ init { - pairToCursorInfo = createCursorInfoMap( - catalog, streamSupplier, cursorFunction, cursorFieldFunction, cursorRecordCountFunction, namespacePairFunction, - onlyIncludeIncrementalStreams) + pairToCursorInfo = + createCursorInfoMap( + catalog, + streamSupplier, + cursorFunction, + cursorFieldFunction, + cursorRecordCountFunction, + namespacePairFunction, + onlyIncludeIncrementalStreams + ) } /** - * Creates the cursor information map that associates stream name/namespace tuples with the current - * cursor information for that stream as stored in the connector's state. + * Creates the cursor information map that associates stream name/namespace tuples with the + * current cursor information for that stream as stored in the connector's state. * * @param catalog The connector's configured catalog. * @param streamSupplier A [Supplier] that provides the cursor manager with the collection of @@ -68,21 +76,22 @@ class CursorManager(catalog: ConfiguredAirbyteCatalog?, * stored in the connector's state. * @param cursorRecordCountFunction A [Function] that extracts the cursor record count for a * stream stored in the connector's state. - * @param namespacePairFunction A [Function] that generates a - * [AirbyteStreamNameNamespacePair] that identifies each stream in the connector's - * state. + * @param namespacePairFunction A [Function] that generates a [AirbyteStreamNameNamespacePair] + * that identifies each stream in the connector's state. * @return A map of streams to current cursor information for the stream. */ @VisibleForTesting protected fun createCursorInfoMap( - catalog: ConfiguredAirbyteCatalog?, - streamSupplier: Supplier>, - cursorFunction: Function?, - cursorFieldFunction: Function?>?, - cursorRecordCountFunction: Function?, - namespacePairFunction: Function?, - onlyIncludeIncrementalStreams: Boolean): Map { - val allStreamNames = catalog!!.streams + catalog: ConfiguredAirbyteCatalog, + streamSupplier: Supplier>, + cursorFunction: Function?, + cursorFieldFunction: Function>?, + cursorRecordCountFunction: Function?, + namespacePairFunction: Function?, + onlyIncludeIncrementalStreams: Boolean + ): Map { + val allStreamNames = + catalog.streams .stream() .filter { c: ConfiguredAirbyteStream -> if (onlyIncludeIncrementalStreams) { @@ -91,24 +100,52 @@ class CursorManager(catalog: ConfiguredAirbyteCatalog?, true } .map { obj: ConfiguredAirbyteStream -> obj.stream } - .map { stream: AirbyteStream? -> AirbyteStreamNameNamespacePair.fromAirbyteStream(stream) } + .map { stream: AirbyteStream? -> + AirbyteStreamNameNamespacePair.fromAirbyteStream(stream) + } + .collect(Collectors.toSet()) + allStreamNames.addAll( + streamSupplier + .get() + .stream() + .map(namespacePairFunction) + .filter { obj: AirbyteStreamNameNamespacePair? -> Objects.nonNull(obj) } .collect(Collectors.toSet()) - allStreamNames.addAll(streamSupplier.get().stream().map(namespacePairFunction).filter { obj: AirbyteStreamNameNamespacePair? -> Objects.nonNull(obj) }.collect(Collectors.toSet())) + ) - val localMap: MutableMap = ConcurrentHashMap() - val pairToState = streamSupplier.get() + val localMap: MutableMap = ConcurrentHashMap() + val pairToState = + streamSupplier + .get() .stream() .collect(Collectors.toMap(namespacePairFunction, Function.identity())) - val pairToConfiguredAirbyteStream = catalog.streams.stream() - .collect(Collectors.toMap(Function { stream: ConfiguredAirbyteStream? -> AirbyteStreamNameNamespacePair.fromConfiguredAirbyteSteam(stream) }, Function.identity())) + val pairToConfiguredAirbyteStream = + catalog.streams + .stream() + .collect( + Collectors.toMap( + Function { stream: ConfiguredAirbyteStream? -> + AirbyteStreamNameNamespacePair.fromConfiguredAirbyteSteam(stream) + }, + Function.identity() + ) + ) for (pair in allStreamNames) { val stateOptional: Optional = Optional.ofNullable(pairToState[pair]) val streamOptional = Optional.ofNullable(pairToConfiguredAirbyteStream[pair]) - localMap[pair] = createCursorInfoForStream(pair, stateOptional, streamOptional, cursorFunction, cursorFieldFunction, cursorRecordCountFunction) + localMap[pair] = + createCursorInfoForStream( + pair, + stateOptional, + streamOptional, + cursorFunction, + cursorFieldFunction, + cursorRecordCountFunction + ) } - return localMap + return localMap.toMap() } /** @@ -118,27 +155,31 @@ class CursorManager(catalog: ConfiguredAirbyteCatalog?, * @param pair A [AirbyteStreamNameNamespacePair] that identifies a specific stream managed by * the connector. * @param stateOptional [Optional] containing the current state associated with the stream. - * @param streamOptional [Optional] containing the [ConfiguredAirbyteStream] associated + * @param streamOptional [Optional] containing the [ConfiguredAirbyteStream] associated with the + * stream. + * @param cursorFunction A [Function] that provides the current cursor from the state associated * with the stream. - * @param cursorFunction A [Function] that provides the current cursor from the state - * associated with the stream. * @param cursorFieldFunction A [Function] that provides the cursor field name for the cursor * stored in the state associated with the stream. * @param cursorRecordCountFunction A [Function] that extracts the cursor record count for a * stream stored in the connector's state. - * @return A [CursorInfo] object based on the data currently stored in the connector's state - * for the given stream. + * @return A [CursorInfo] object based on the data currently stored in the connector's state for + * the given stream. */ - @VisibleForTesting - protected fun createCursorInfoForStream(pair: AirbyteStreamNameNamespacePair?, - stateOptional: Optional, - streamOptional: Optional, - cursorFunction: Function?, - cursorFieldFunction: Function?>?, - cursorRecordCountFunction: Function?): CursorInfo { - val originalCursorField = stateOptional + internal fun createCursorInfoForStream( + pair: AirbyteStreamNameNamespacePair?, + stateOptional: Optional, + streamOptional: Optional, + cursorFunction: Function?, + cursorFieldFunction: Function>?, + cursorRecordCountFunction: Function? + ): CursorInfo { + val originalCursorField = + stateOptional .map(cursorFieldFunction) - .flatMap { f: List? -> if (f!!.size > 0) Optional.of(f[0]!!) else Optional.empty() } + .flatMap { f: List -> + if (f.isNotEmpty()) Optional.of(f[0]) else Optional.empty() + } .orElse(null) val originalCursor = stateOptional.map(cursorFunction).orElse(null) val originalCursorRecordCount = stateOptional.map(cursorRecordCountFunction).orElse(0L) @@ -148,75 +189,101 @@ class CursorManager(catalog: ConfiguredAirbyteCatalog?, val cursorRecordCount: Long // if cursor field is set in catalog. - if (streamOptional.map> { obj: ConfiguredAirbyteStream -> obj.cursorField }.isPresent) { - cursorField = streamOptional + if ( + streamOptional + .map> { obj: ConfiguredAirbyteStream -> obj.cursorField } + .isPresent + ) { + cursorField = + streamOptional .map { obj: ConfiguredAirbyteStream -> obj.cursorField } - .flatMap { f: List -> if (f.size > 0) Optional.of(f[0]) else Optional.empty() } + .flatMap { f: List -> + if (f.size > 0) Optional.of(f[0]) else Optional.empty() + } .orElse(null) // if cursor field is set in state. if (stateOptional.map?>(cursorFieldFunction).isPresent) { // if cursor field in catalog and state are the same. - if (stateOptional.map?>(cursorFieldFunction) == streamOptional.map> { obj: ConfiguredAirbyteStream -> obj.cursorField }) { + if ( + stateOptional.map?>(cursorFieldFunction) == + streamOptional.map> { obj: ConfiguredAirbyteStream -> + obj.cursorField + } + ) { cursor = stateOptional.map(cursorFunction).orElse(null) cursorRecordCount = stateOptional.map(cursorRecordCountFunction).orElse(0L) - // If a matching cursor is found in the state, and it's value is null - this indicates a CDC stream + // If a matching cursor is found in the state, and it's value is null - this + // indicates a CDC stream // and we shouldn't log anything. if (cursor != null) { - LOGGER.info("Found matching cursor in state. Stream: {}. Cursor Field: {} Value: {} Count: {}", - pair, cursorField, cursor, cursorRecordCount) + LOGGER.info( + "Found matching cursor in state. Stream: {}. Cursor Field: {} Value: {} Count: {}", + pair, + cursorField, + cursor, + cursorRecordCount + ) } // if cursor field in catalog and state are different. } else { cursor = null cursorRecordCount = 0L LOGGER.info( - "Found cursor field. Does not match previous cursor field. Stream: {}. Original Cursor Field: {} (count {}). New Cursor Field: {}. Resetting cursor value.", - pair, originalCursorField, originalCursorRecordCount, cursorField) + "Found cursor field. Does not match previous cursor field. Stream: {}. Original Cursor Field: {} (count {}). New Cursor Field: {}. Resetting cursor value.", + pair, + originalCursorField, + originalCursorRecordCount, + cursorField + ) } // if cursor field is not set in state but is set in catalog. } else { - LOGGER.info("No cursor field set in catalog but not present in state. Stream: {}, New Cursor Field: {}. Resetting cursor value", pair, - cursorField) + LOGGER.info( + "No cursor field set in catalog but not present in state. Stream: {}, New Cursor Field: {}. Resetting cursor value", + pair, + cursorField + ) cursor = null cursorRecordCount = 0L } // if cursor field is not set in catalog. } else { LOGGER.info( - "Cursor field set in state but not present in catalog. Stream: {}. Original Cursor Field: {}. Original value: {}. Resetting cursor.", - pair, originalCursorField, originalCursor) + "Cursor field set in state but not present in catalog. Stream: {}. Original Cursor Field: {}. Original value: {}. Resetting cursor.", + pair, + originalCursorField, + originalCursor + ) cursorField = null cursor = null cursorRecordCount = 0L } - return CursorInfo(originalCursorField, originalCursor, originalCursorRecordCount, cursorField, cursor, cursorRecordCount) - } - - /** - * Retrieves a copy of the stream name/namespace tuple to current cursor information map. - * - * @return A copy of the stream name/namespace tuple to current cursor information map. - */ - fun getPairToCursorInfo(): Map { - return java.util.Map.copyOf(pairToCursorInfo) + return CursorInfo( + originalCursorField, + originalCursor, + originalCursorRecordCount, + cursorField, + cursor, + cursorRecordCount + ) } /** - * Retrieves an [Optional] possibly containing the current [CursorInfo] associated with - * the provided stream name/namespace tuple. + * Retrieves an [Optional] possibly containing the current [CursorInfo] associated with the + * provided stream name/namespace tuple. * * @param pair The [AirbyteStreamNameNamespacePair] which identifies a stream. - * @return An [Optional] possibly containing the current [CursorInfo] associated with - * the provided stream name/namespace tuple. + * @return An [Optional] possibly containing the current [CursorInfo] associated with the + * provided stream name/namespace tuple. */ fun getCursorInfo(pair: AirbyteStreamNameNamespacePair?): Optional { return Optional.ofNullable(pairToCursorInfo[pair]) } /** - * Retrieves an [Optional] possibly containing the cursor field name associated with the - * cursor tracked in the state associated with the provided stream name/namespace tuple. + * Retrieves an [Optional] possibly containing the cursor field name associated with the cursor + * tracked in the state associated with the provided stream name/namespace tuple. * * @param pair The [AirbyteStreamNameNamespacePair] which identifies a stream. * @return An [Optional] possibly containing the cursor field name associated with the cursor @@ -227,8 +294,8 @@ class CursorManager(catalog: ConfiguredAirbyteCatalog?, } /** - * Retrieves an [Optional] possibly containing the cursor value tracked in the state - * associated with the provided stream name/namespace tuple. + * Retrieves an [Optional] possibly containing the cursor value tracked in the state associated + * with the provided stream name/namespace tuple. * * @param pair The [AirbyteStreamNameNamespacePair] which identifies a stream. * @return An [Optional] possibly containing the cursor value tracked in the state associated diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/CursorStateMessageProducer.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/CursorStateMessageProducer.kt index 2955110bacf8b..d52305abb105d 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/CursorStateMessageProducer.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/CursorStateMessageProducer.kt @@ -10,13 +10,15 @@ import io.airbyte.protocol.models.v0.AirbyteMessage import io.airbyte.protocol.models.v0.AirbyteStateMessage import io.airbyte.protocol.models.v0.AirbyteStreamNameNamespacePair import io.airbyte.protocol.models.v0.ConfiguredAirbyteStream +import java.util.* import org.slf4j.Logger import org.slf4j.LoggerFactory -import java.util.* -class CursorStateMessageProducer(private val stateManager: StateManager?, - private val initialCursor: Optional) : SourceStateMessageProducer { - private var currentMaxCursor: Optional +class CursorStateMessageProducer( + private val stateManager: StateManager?, + private val initialCursor: Optional +) : SourceStateMessageProducer { + private var currentMaxCursor: Optional // We keep this field to mark `cursor_record_count` and also to control logging frequency. private var currentCursorRecordCount = 0 @@ -28,33 +30,43 @@ class CursorStateMessageProducer(private val stateManager: StateManager?, this.currentMaxCursor = initialCursor } - override fun generateStateMessageAtCheckpoint(stream: ConfiguredAirbyteStream?): AirbyteStateMessage? { - // At this stage intermediate state message should never be null; otherwise it would have been + override fun generateStateMessageAtCheckpoint( + stream: ConfiguredAirbyteStream? + ): AirbyteStateMessage? { + // At this stage intermediate state message should never be null; otherwise it would have + // been // blocked by shouldEmitStateMessage check. val message = intermediateStateMessage intermediateStateMessage = null if (cursorOutOfOrderDetected) { - LOGGER.warn("Intermediate state emission feature requires records to be processed in order according to the cursor value. Otherwise, " - + "data loss can occur.") + LOGGER.warn( + "Intermediate state emission feature requires records to be processed in order according to the cursor value. Otherwise, " + + "data loss can occur." + ) } return message } /** - * Note: We do not try to catch exception here. If error/exception happens, we should fail the sync, - * and since we have saved state message before, we should be able to resume it in next sync if we - * have fixed the underlying issue, of if the issue is transient. + * Note: We do not try to catch exception here. If error/exception happens, we should fail the + * sync, and since we have saved state message before, we should be able to resume it in next + * sync if we have fixed the underlying issue, of if the issue is transient. */ - override fun processRecordMessage(stream: ConfiguredAirbyteStream, message: AirbyteMessage): AirbyteMessage { + override fun processRecordMessage( + stream: ConfiguredAirbyteStream, + message: AirbyteMessage + ): AirbyteMessage { val cursorField = getCursorField(stream) if (message.record.data.hasNonNull(cursorField)) { val cursorCandidate = getCursorCandidate(cursorField, message) - val cursorType = getCursorType(stream, - cursorField) - val cursorComparison = compareCursors(currentMaxCursor.orElse(null), cursorCandidate, cursorType) + val cursorType = getCursorType(stream, cursorField) + val cursorComparison = + compareCursors(currentMaxCursor.orElse(null), cursorCandidate, cursorType) if (cursorComparison < 0) { - // Reset cursor but include current record message. This value will be used to create state message. - // Update the current max cursor only when current max cursor < cursor candidate from the message + // Reset cursor but include current record message. This value will be used to + // create state message. + // Update the current max cursor only when current max cursor < cursor candidate + // from the message if (currentMaxCursor != initialCursor) { // Only create an intermediate state when it is not the first record. intermediateStateMessage = createStateMessage(stream) @@ -75,9 +87,7 @@ class CursorStateMessageProducer(private val stateManager: StateManager?, return createStateMessage(stream) } - /** - * Only sends out state message when there is a state message to be sent out. - */ + /** Only sends out state message when there is a state message to be sent out. */ override fun shouldEmitStateMessage(stream: ConfiguredAirbyteStream?): Boolean { return intermediateStateMessage != null } @@ -90,8 +100,20 @@ class CursorStateMessageProducer(private val stateManager: StateManager?, */ private fun createStateMessage(stream: ConfiguredAirbyteStream): AirbyteStateMessage? { val pair = AirbyteStreamNameNamespacePair(stream.stream.name, stream.stream.namespace) - println("state message creation: " + pair + " " + currentMaxCursor.orElse(null) + " " + currentCursorRecordCount) - val stateMessage = stateManager!!.updateAndEmit(pair, currentMaxCursor.orElse(null), currentCursorRecordCount.toLong()) + println( + "state message creation: " + + pair + + " " + + currentMaxCursor.orElse(null) + + " " + + currentCursorRecordCount + ) + val stateMessage = + stateManager!!.updateAndEmit( + pair, + currentMaxCursor.orElse(null), + currentCursorRecordCount.toLong() + ) val cursorInfo = stateManager.getCursorInfo(pair) // logging once every 100 messages to reduce log verbosity diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/GlobalStateManager.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/GlobalStateManager.kt index aed51f4be4ff9..9329d6d665540 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/GlobalStateManager.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/GlobalStateManager.kt @@ -19,61 +19,76 @@ import java.util.stream.Collectors * This implementation generates a single, global state object for the state tracked by this * manager. */ -class GlobalStateManager(airbyteStateMessage: AirbyteStateMessage?, catalog: ConfiguredAirbyteCatalog?) : AbstractStateManager(catalog, +class GlobalStateManager( + airbyteStateMessage: AirbyteStateMessage, + catalog: ConfiguredAirbyteCatalog +) : + AbstractStateManager( + catalog, getStreamsSupplier(airbyteStateMessage), StateGeneratorUtils.CURSOR_FUNCTION, StateGeneratorUtils.CURSOR_FIELD_FUNCTION, StateGeneratorUtils.CURSOR_RECORD_COUNT_FUNCTION, StateGeneratorUtils.NAME_NAMESPACE_PAIR_FUNCTION, - true) { + true + ) { /** - * Legacy [CdcStateManager] used to manage state for connectors that support Change Data - * Capture (CDC). + * Legacy [CdcStateManager] used to manage state for connectors that support Change Data Capture + * (CDC). */ override val cdcStateManager: CdcStateManager /** - * Constructs a new [GlobalStateManager] that is seeded with the provided - * [AirbyteStateMessage]. + * Constructs a new [GlobalStateManager] that is seeded with the provided [AirbyteStateMessage]. * * @param airbyteStateMessage The initial state represented as an [AirbyteStateMessage]. * @param catalog The [ConfiguredAirbyteCatalog] for the connector associated with this state * manager. */ init { - this.cdcStateManager = CdcStateManager(extractCdcState(airbyteStateMessage), extractStreams(airbyteStateMessage), airbyteStateMessage) + this.cdcStateManager = + CdcStateManager( + extractCdcState(airbyteStateMessage), + extractStreams(airbyteStateMessage), + airbyteStateMessage + ) } override val rawStateMessages: List? get() { - throw UnsupportedOperationException("Raw state retrieval not supported by global state manager.") + throw UnsupportedOperationException( + "Raw state retrieval not supported by global state manager." + ) } - override fun toState(pair: Optional): AirbyteStateMessage? { + override fun toState(pair: Optional): AirbyteStateMessage { // Populate global state val globalState = AirbyteGlobalState() globalState.sharedState = Jsons.jsonNode(cdcStateManager.cdcState) globalState.streamStates = StateGeneratorUtils.generateStreamStateList(pairToCursorInfoMap) // Generate the legacy state for backwards compatibility - val dbState = StateGeneratorUtils.generateDbState(pairToCursorInfoMap) + val dbState = + StateGeneratorUtils.generateDbState(pairToCursorInfoMap) .withCdc(true) .withCdcState(cdcStateManager.cdcState) return AirbyteStateMessage() - .withType(AirbyteStateMessage.AirbyteStateType.GLOBAL) // Temporarily include legacy state for backwards compatibility with the platform - .withData(Jsons.jsonNode(dbState)) - .withGlobal(globalState) + .withType( + AirbyteStateMessage.AirbyteStateType.GLOBAL + ) // Temporarily include legacy state for backwards compatibility with the platform + .withData(Jsons.jsonNode(dbState)) + .withGlobal(globalState) } /** - * Extracts the Change Data Capture (CDC) state stored in the initial state provided to this state - * manager. + * Extracts the Change Data Capture (CDC) state stored in the initial state provided to this + * state manager. * - * @param airbyteStateMessage The [AirbyteStateMessage] that contains the initial state - * provided to the state manager. - * @return The [CdcState] stored in the state, if any. Note that this will not be `null` - * but may be empty. + * @param airbyteStateMessage The [AirbyteStateMessage] that contains the initial state provided + * to the state manager. + * @return The [CdcState] stored in the state, if any. Note that this will not be `null` but may + * be empty. */ private fun extractCdcState(airbyteStateMessage: AirbyteStateMessage?): CdcState? { if (airbyteStateMessage!!.type == AirbyteStateMessage.AirbyteStateType.GLOBAL) { @@ -84,24 +99,38 @@ class GlobalStateManager(airbyteStateMessage: AirbyteStateMessage?, catalog: Con } } - private fun extractStreams(airbyteStateMessage: AirbyteStateMessage?): Set { + private fun extractStreams( + airbyteStateMessage: AirbyteStateMessage? + ): Set { if (airbyteStateMessage!!.type == AirbyteStateMessage.AirbyteStateType.GLOBAL) { - return airbyteStateMessage.global.streamStates.stream() - .map { streamState: AirbyteStreamState -> - val cloned = Jsons.clone(streamState) - AirbyteStreamNameNamespacePair(cloned.streamDescriptor.name, cloned.streamDescriptor.namespace) - }.collect(Collectors.toSet()) + return airbyteStateMessage.global.streamStates + .stream() + .map { streamState: AirbyteStreamState -> + val cloned = Jsons.clone(streamState) + AirbyteStreamNameNamespacePair( + cloned.streamDescriptor.name, + cloned.streamDescriptor.namespace + ) + } + .collect(Collectors.toSet()) } else { val legacyState = Jsons.`object`(airbyteStateMessage.data, DbState::class.java) - return if (legacyState != null) extractNamespacePairsFromDbStreamState(legacyState.streams) else emptySet() + return if (legacyState != null) + extractNamespacePairsFromDbStreamState(legacyState.streams) + else emptySet() } } - private fun extractNamespacePairsFromDbStreamState(streams: List): Set { - return streams.stream().map { stream: DbStreamState -> - val cloned = Jsons.clone(stream) - AirbyteStreamNameNamespacePair(cloned.streamName, cloned.streamNamespace) - }.collect(Collectors.toSet()) + private fun extractNamespacePairsFromDbStreamState( + streams: List + ): Set { + return streams + .stream() + .map { stream: DbStreamState -> + val cloned = Jsons.clone(stream) + AirbyteStreamNameNamespacePair(cloned.streamName, cloned.streamNamespace) + } + .collect(Collectors.toSet()) } companion object { @@ -113,23 +142,34 @@ class GlobalStateManager(airbyteStateMessage: AirbyteStateMessage?, catalog: Con * the initial state. * @return A [Supplier] that will be used to fetch the streams present in the initial state. */ - private fun getStreamsSupplier(airbyteStateMessage: AirbyteStateMessage?): Supplier> { + private fun getStreamsSupplier( + airbyteStateMessage: AirbyteStateMessage? + ): Supplier> { /* - * If the incoming message has the state type set to GLOBAL, it is using the new format. Therefore, - * we can look for streams in the "global" field of the message. Otherwise, the message is still - * storing state in the legacy "data" field. - */ + * If the incoming message has the state type set to GLOBAL, it is using the new format. Therefore, + * we can look for streams in the "global" field of the message. Otherwise, the message is still + * storing state in the legacy "data" field. + */ return Supplier { if (airbyteStateMessage!!.type == AirbyteStateMessage.AirbyteStateType.GLOBAL) { return@Supplier airbyteStateMessage.global.streamStates } else if (airbyteStateMessage.data != null) { - return@Supplier Jsons.`object`(airbyteStateMessage.data, DbState::class.java).streams.stream() - .map { s: DbStreamState -> - AirbyteStreamState().withStreamState(Jsons.jsonNode(s)) - .withStreamDescriptor(StreamDescriptor().withNamespace(s.streamNamespace).withName(s.streamName)) - } - .collect( - Collectors.toList()) + return@Supplier Jsons.`object`( + airbyteStateMessage.data, + DbState::class.java + ) + .streams + .stream() + .map { s: DbStreamState -> + AirbyteStreamState() + .withStreamState(Jsons.jsonNode(s)) + .withStreamDescriptor( + StreamDescriptor() + .withNamespace(s.streamNamespace) + .withName(s.streamName) + ) + } + .collect(Collectors.toList()) } else { return@Supplier listOf() } diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/LegacyStateManager.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/LegacyStateManager.kt index 4083cf9dd4d28..c379f25a9d1e2 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/LegacyStateManager.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/LegacyStateManager.kt @@ -10,71 +10,84 @@ import io.airbyte.commons.json.Jsons import io.airbyte.protocol.models.v0.AirbyteStateMessage import io.airbyte.protocol.models.v0.AirbyteStreamNameNamespacePair import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog -import org.slf4j.Logger -import org.slf4j.LoggerFactory import java.util.* import java.util.function.Function import java.util.function.Supplier +import org.slf4j.Logger +import org.slf4j.LoggerFactory /** * Legacy implementation (pre-per-stream state support) of the [StateManager] interface. * - * This implementation assumes that the state matches the [DbState] object and effectively - * tracks state as global across the streams managed by a connector. - * + * This implementation assumes that the state matches the [DbState] object and effectively tracks + * state as global across the streams managed by a connector. */ -@Deprecated("""This manager may be removed in the future if/once all connectors support per-stream - state management.""") -class LegacyStateManager(dbState: DbState, catalog: ConfiguredAirbyteCatalog?) : AbstractStateManager(catalog, +@Deprecated( + """This manager may be removed in the future if/once all connectors support per-stream + state management.""" +) +class LegacyStateManager(dbState: DbState, catalog: ConfiguredAirbyteCatalog) : + AbstractStateManager( + catalog, Supplier { dbState.streams }, CURSOR_FUNCTION, CURSOR_FIELD_FUNCTION, CURSOR_RECORD_COUNT_FUNCTION, - NAME_NAMESPACE_PAIR_FUNCTION) { - /** - * Tracks whether the connector associated with this state manager supports CDC. - */ + NAME_NAMESPACE_PAIR_FUNCTION + ) { + /** Tracks whether the connector associated with this state manager supports CDC. */ private var isCdc: Boolean - /** - * [CdcStateManager] used to manage state for connectors that support CDC. - */ - override val cdcStateManager: CdcStateManager = CdcStateManager(dbState.cdcState, AirbyteStreamNameNamespacePair.fromConfiguredCatalog(catalog), null) + /** [CdcStateManager] used to manage state for connectors that support CDC. */ + override val cdcStateManager: CdcStateManager = + CdcStateManager( + dbState.cdcState, + AirbyteStreamNameNamespacePair.fromConfiguredCatalog(catalog), + null + ) /** - * Constructs a new [LegacyStateManager] that is seeded with the provided [DbState] - * instance. + * Constructs a new [LegacyStateManager] that is seeded with the provided [DbState] instance. * * @param dbState The initial state represented as an [DbState] instance. * @param catalog The [ConfiguredAirbyteCatalog] for the connector associated with this state * manager. */ init { - this.isCdc = dbState.cdc - if (dbState.cdc == null) { - this.isCdc = false - } + this.isCdc = dbState.cdc ?: false } override val rawStateMessages: List? get() { - throw UnsupportedOperationException("Raw state retrieval not supported by global state manager.") + throw UnsupportedOperationException( + "Raw state retrieval not supported by global state manager." + ) } - override fun toState(pair: Optional): AirbyteStateMessage? { - val dbState = StateGeneratorUtils.generateDbState(pairToCursorInfoMap) + override fun toState(pair: Optional): AirbyteStateMessage { + val dbState = + StateGeneratorUtils.generateDbState(pairToCursorInfoMap) .withCdc(isCdc) .withCdcState(cdcStateManager.cdcState) LOGGER.debug("Generated legacy state for {} streams", dbState.streams.size) - return AirbyteStateMessage().withType(AirbyteStateMessage.AirbyteStateType.LEGACY).withData(Jsons.jsonNode(dbState)) + return AirbyteStateMessage() + .withType(AirbyteStateMessage.AirbyteStateType.LEGACY) + .withData(Jsons.jsonNode(dbState)) } - override fun updateAndEmit(pair: AirbyteStreamNameNamespacePair, cursor: String?): AirbyteStateMessage? { + override fun updateAndEmit( + pair: AirbyteStreamNameNamespacePair, + cursor: String? + ): AirbyteStateMessage? { return updateAndEmit(pair, cursor, 0L) } - override fun updateAndEmit(pair: AirbyteStreamNameNamespacePair, cursor: String?, cursorRecordCount: Long): AirbyteStateMessage? { + override fun updateAndEmit( + pair: AirbyteStreamNameNamespacePair, + cursor: String?, + cursorRecordCount: Long + ): AirbyteStateMessage? { // cdc file gets updated by debezium so the "update" part is a no op. if (!isCdc) { return super.updateAndEmit(pair, cursor, cursorRecordCount) @@ -86,21 +99,20 @@ class LegacyStateManager(dbState: DbState, catalog: ConfiguredAirbyteCatalog?) : companion object { private val LOGGER: Logger = LoggerFactory.getLogger(LegacyStateManager::class.java) - /** - * [Function] that extracts the cursor from the stream state. - */ - private val CURSOR_FUNCTION = Function { obj: DbStreamState? -> obj!!.cursor } + /** [Function] that extracts the cursor from the stream state. */ + private val CURSOR_FUNCTION = DbStreamState::getCursor - /** - * [Function] that extracts the cursor field(s) from the stream state. - */ - private val CURSOR_FIELD_FUNCTION = Function { obj: DbStreamState? -> obj!!.cursorField } + /** [Function] that extracts the cursor field(s) from the stream state. */ + private val CURSOR_FIELD_FUNCTION = DbStreamState::getCursorField - private val CURSOR_RECORD_COUNT_FUNCTION = Function { stream: DbStreamState? -> Objects.requireNonNullElse(stream!!.cursorRecordCount, 0L) } + private val CURSOR_RECORD_COUNT_FUNCTION = Function { stream: DbStreamState -> + Objects.requireNonNullElse(stream.cursorRecordCount, 0L) + } - /** - * [Function] that creates an [AirbyteStreamNameNamespacePair] from the stream state. - */ - private val NAME_NAMESPACE_PAIR_FUNCTION = Function { s: DbStreamState? -> AirbyteStreamNameNamespacePair(s!!.streamName, s.streamNamespace) } + /** [Function] that creates an [AirbyteStreamNameNamespacePair] from the stream state. */ + private val NAME_NAMESPACE_PAIR_FUNCTION = + Function { s: DbStreamState -> + AirbyteStreamNameNamespacePair(s!!.streamName, s.streamNamespace) + } } } diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/SourceStateIterator.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/SourceStateIterator.kt index f8240ca153cea..ef5758562854b 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/SourceStateIterator.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/SourceStateIterator.kt @@ -7,16 +7,18 @@ import com.google.common.collect.AbstractIterator import io.airbyte.protocol.models.v0.AirbyteMessage import io.airbyte.protocol.models.v0.AirbyteStateStats import io.airbyte.protocol.models.v0.ConfiguredAirbyteStream -import org.slf4j.Logger -import org.slf4j.LoggerFactory import java.time.Duration import java.time.Instant import java.time.OffsetDateTime +import org.slf4j.Logger +import org.slf4j.LoggerFactory -class SourceStateIterator(private val messageIterator: Iterator, - private val stream: ConfiguredAirbyteStream, - private val sourceStateMessageProducer: SourceStateMessageProducer<*>, - private val stateEmitFrequency: StateEmitFrequency) : AbstractIterator(), MutableIterator { +open class SourceStateIterator( + private val messageIterator: Iterator, + private val stream: ConfiguredAirbyteStream, + private val sourceStateMessageProducer: SourceStateMessageProducer, + private val stateEmitFrequency: StateEmitFrequency +) : AbstractIterator(), MutableIterator { private var hasEmittedFinalState = false private var recordCount = 0L private var lastCheckpoint: Instant = Instant.now() @@ -26,26 +28,33 @@ class SourceStateIterator(private val messageIterator: Iterator, try { iteratorHasNextValue = messageIterator.hasNext() } catch (ex: Exception) { - // If the underlying iterator throws an exception, we want to fail the sync, expecting sync/attempt + // If the underlying iterator throws an exception, we want to fail the sync, expecting + // sync/attempt // will be restarted and // sync will resume from the last state message. throw FailedRecordIteratorException(ex) } if (iteratorHasNextValue) { - if (shouldEmitStateMessage() && sourceStateMessageProducer.shouldEmitStateMessage(stream)) { - val stateMessage = sourceStateMessageProducer.generateStateMessageAtCheckpoint(stream) - stateMessage!!.withSourceStats(AirbyteStateStats().withRecordCount(recordCount.toDouble())) + if ( + shouldEmitStateMessage() && + sourceStateMessageProducer.shouldEmitStateMessage(stream) + ) { + val stateMessage = + sourceStateMessageProducer.generateStateMessageAtCheckpoint(stream) + stateMessage!!.withSourceStats( + AirbyteStateStats().withRecordCount(recordCount.toDouble()) + ) recordCount = 0L lastCheckpoint = Instant.now() - return AirbyteMessage() - .withType(AirbyteMessage.Type.STATE) - .withState(stateMessage) + return AirbyteMessage().withType(AirbyteMessage.Type.STATE).withState(stateMessage) } - // Use try-catch to catch Exception that could occur when connection to the database fails + // Use try-catch to catch Exception that could occur when connection to the database + // fails try { val message = messageIterator.next() - val processedMessage = sourceStateMessageProducer.processRecordMessage(stream, message) + val processedMessage = + sourceStateMessageProducer.processRecordMessage(stream, message) recordCount++ return processedMessage } catch (e: Exception) { @@ -53,18 +62,22 @@ class SourceStateIterator(private val messageIterator: Iterator, } } else if (!hasEmittedFinalState) { hasEmittedFinalState = true - val finalStateMessageForStream = sourceStateMessageProducer.createFinalStateMessage(stream) - finalStateMessageForStream!!.withSourceStats(AirbyteStateStats().withRecordCount(recordCount.toDouble())) + val finalStateMessageForStream = + sourceStateMessageProducer.createFinalStateMessage(stream) + finalStateMessageForStream!!.withSourceStats( + AirbyteStateStats().withRecordCount(recordCount.toDouble()) + ) recordCount = 0L return AirbyteMessage() - .withType(AirbyteMessage.Type.STATE) - .withState(finalStateMessageForStream) + .withType(AirbyteMessage.Type.STATE) + .withState(finalStateMessageForStream) } else { return endOfData() } } - // This method is used to check if we should emit a state message. If the record count is set to 0, + // This method is used to check if we should emit a state message. If the record count is set to + // 0, // we should not emit a state message. // If the frequency is set to be zero, we should not use it. private fun shouldEmitStateMessage(): Boolean { @@ -75,7 +88,8 @@ class SourceStateIterator(private val messageIterator: Iterator, return true } if (!stateEmitFrequency.syncCheckpointDuration.isZero) { - return Duration.between(lastCheckpoint, OffsetDateTime.now()).compareTo(stateEmitFrequency.syncCheckpointDuration) > 0 + return Duration.between(lastCheckpoint, OffsetDateTime.now()) + .compareTo(stateEmitFrequency.syncCheckpointDuration) > 0 } return false } diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/SourceStateMessageProducer.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/SourceStateMessageProducer.kt index a68bb1395b75a..4c70a0b0a2b7b 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/SourceStateMessageProducer.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/SourceStateMessageProducer.kt @@ -12,17 +12,13 @@ import io.airbyte.protocol.models.v0.ConfiguredAirbyteStream * generate state messages when needed. This interface defines how would those state messages be * generated, and how the incoming record messages will be processed. * - * @param - */ + * @param + */ interface SourceStateMessageProducer { - /** - * Returns a state message that should be emitted at checkpoint. - */ + /** Returns a state message that should be emitted at checkpoint. */ fun generateStateMessageAtCheckpoint(stream: ConfiguredAirbyteStream?): AirbyteStateMessage? - /** - * For the incoming record message, this method defines how the connector will consume it. - */ + /** For the incoming record message, this method defines how the connector will consume it. */ fun processRecordMessage(stream: ConfiguredAirbyteStream, message: T): AirbyteMessage /** @@ -34,9 +30,9 @@ interface SourceStateMessageProducer { fun createFinalStateMessage(stream: ConfiguredAirbyteStream): AirbyteStateMessage? /** - * Determines if the iterator has reached checkpoint or not per connector's definition. By default - * iterator will check if the number of records processed is greater than the checkpoint interval or - * last state message has already passed syncCheckpointDuration. + * Determines if the iterator has reached checkpoint or not per connector's definition. By + * default iterator will check if the number of records processed is greater than the checkpoint + * interval or last state message has already passed syncCheckpointDuration. */ fun shouldEmitStateMessage(stream: ConfiguredAirbyteStream?): Boolean } diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/StateEmitFrequency.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/StateEmitFrequency.kt index 242ea0e104f8f..6c2d0120cc6f1 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/StateEmitFrequency.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/StateEmitFrequency.kt @@ -10,9 +10,6 @@ class StateEmitFrequency(syncCheckpointRecords: Long, syncCheckpointDuration: Du val syncCheckpointDuration: Duration init { - this.streamName = streamName - this.primaryKey = primaryKey - this.keySequence = keySequence this.syncCheckpointRecords = syncCheckpointRecords this.syncCheckpointDuration = syncCheckpointDuration } diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/StateGeneratorUtils.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/StateGeneratorUtils.kt index 8528da0bdeeb7..a9b61c9da642b 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/StateGeneratorUtils.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/StateGeneratorUtils.kt @@ -14,51 +14,51 @@ import io.airbyte.configoss.StateType import io.airbyte.configoss.StateWrapper import io.airbyte.configoss.helpers.StateMessageHelper import io.airbyte.protocol.models.v0.* -import org.slf4j.Logger -import org.slf4j.LoggerFactory import java.util.* import java.util.function.Function import java.util.stream.Collectors +import org.slf4j.Logger +import org.slf4j.LoggerFactory -/** - * Collection of utilities that facilitate the generation of state objects. - */ +/** Collection of utilities that facilitate the generation of state objects. */ object StateGeneratorUtils { private val LOGGER: Logger = LoggerFactory.getLogger(StateGeneratorUtils::class.java) - /** - * [Function] that extracts the cursor from the stream state. - */ - val CURSOR_FUNCTION: Function = Function { stream: AirbyteStreamState -> - val dbStreamState = extractState(stream) - dbStreamState.map { obj: DbStreamState -> obj.cursor }.orElse(null) - } + /** [Function] that extracts the cursor from the stream state. */ + val CURSOR_FUNCTION: Function = + Function { stream: AirbyteStreamState -> + val dbStreamState = extractState(stream) + dbStreamState.map { obj: DbStreamState -> obj.cursor }.orElse(null) + } - /** - * [Function] that extracts the cursor field(s) from the stream state. - */ - val CURSOR_FIELD_FUNCTION: Function> = Function { stream: AirbyteStreamState -> - val dbStreamState = extractState(stream) - if (dbStreamState.isPresent) { - return@Function dbStreamState.get().cursorField - } else { - return@Function listOf() + /** [Function] that extracts the cursor field(s) from the stream state. */ + val CURSOR_FIELD_FUNCTION: Function> = + Function { stream: AirbyteStreamState -> + val dbStreamState = extractState(stream) + if (dbStreamState.isPresent) { + return@Function dbStreamState.get().cursorField + } else { + return@Function listOf() + } } - } - val CURSOR_RECORD_COUNT_FUNCTION: Function = Function { stream: AirbyteStreamState -> - val dbStreamState = extractState(stream) - dbStreamState.map { obj: DbStreamState -> obj.cursorRecordCount }.orElse(0L) - } + val CURSOR_RECORD_COUNT_FUNCTION: Function = + Function { stream: AirbyteStreamState -> + val dbStreamState = extractState(stream) + dbStreamState.map { obj: DbStreamState -> obj.cursorRecordCount }.orElse(0L) + } - /** - * [Function] that creates an [AirbyteStreamNameNamespacePair] from the stream state. - */ - val NAME_NAMESPACE_PAIR_FUNCTION: Function = Function { s: AirbyteStreamState -> - if (isValidStreamDescriptor(s.streamDescriptor) - ) AirbyteStreamNameNamespacePair(s.streamDescriptor.name, s.streamDescriptor.namespace) - else null - } + /** [Function] that creates an [AirbyteStreamNameNamespacePair] from the stream state. */ + val NAME_NAMESPACE_PAIR_FUNCTION: + Function = + Function { s: AirbyteStreamState -> + if (isValidStreamDescriptor(s.streamDescriptor)) + AirbyteStreamNameNamespacePair( + s.streamDescriptor.name, + s.streamDescriptor.namespace + ) + else null + } /** * Generates the stream state for the given stream and cursor information. @@ -67,30 +67,42 @@ object StateGeneratorUtils { * @param cursorInfo The current cursor. * @return The [AirbyteStreamState] representing the current state of the stream. */ - fun generateStreamState(airbyteStreamNameNamespacePair: AirbyteStreamNameNamespacePair?, - cursorInfo: CursorInfo?): AirbyteStreamState { + fun generateStreamState( + airbyteStreamNameNamespacePair: AirbyteStreamNameNamespacePair, + cursorInfo: CursorInfo + ): AirbyteStreamState { return AirbyteStreamState() - .withStreamDescriptor( - StreamDescriptor().withName(airbyteStreamNameNamespacePair!!.name).withNamespace(airbyteStreamNameNamespacePair.namespace)) - .withStreamState(Jsons.jsonNode(generateDbStreamState(airbyteStreamNameNamespacePair, cursorInfo))) + .withStreamDescriptor( + StreamDescriptor() + .withName(airbyteStreamNameNamespacePair.name) + .withNamespace(airbyteStreamNameNamespacePair.namespace) + ) + .withStreamState( + Jsons.jsonNode(generateDbStreamState(airbyteStreamNameNamespacePair, cursorInfo)) + ) } /** - * Generates a list of valid stream states from the provided stream and cursor information. A stream - * state is considered to be valid if the stream has a valid descriptor (see + * Generates a list of valid stream states from the provided stream and cursor information. A + * stream state is considered to be valid if the stream has a valid descriptor (see * [.isValidStreamDescriptor] for more details). * * @param pairToCursorInfoMap The map of stream name/namespace tuple to the current cursor * information for that stream - * @return The list of stream states derived from the state information extracted from the provided - * map. + * @return The list of stream states derived from the state information extracted from the + * provided map. */ - fun generateStreamStateList(pairToCursorInfoMap: Map?): List { - return pairToCursorInfoMap!!.entries.stream() - .sorted(java.util.Map.Entry.comparingByKey()) - .map { e: Map.Entry -> generateStreamState(e.key, e.value) } - .filter { s: AirbyteStreamState -> isValidStreamDescriptor(s.streamDescriptor) } - .collect(Collectors.toList()) + fun generateStreamStateList( + pairToCursorInfoMap: Map + ): List { + return pairToCursorInfoMap.entries + .stream() + .sorted(java.util.Map.Entry.comparingByKey()) + .map { e: Map.Entry -> + generateStreamState(e.key, e.value) + } + .filter { s: AirbyteStreamState -> isValidStreamDescriptor(s.streamDescriptor) } + .collect(Collectors.toList()) } /** @@ -100,13 +112,22 @@ object StateGeneratorUtils { * information for that stream * @return The legacy [DbState]. */ - fun generateDbState(pairToCursorInfoMap: Map?): DbState { + fun generateDbState( + pairToCursorInfoMap: Map + ): DbState { return DbState() - .withCdc(false) - .withStreams(pairToCursorInfoMap!!.entries.stream() - .sorted(java.util.Map.Entry.comparingByKey()) // sort by stream name then namespace for sanity. - .map { e: Map.Entry -> generateDbStreamState(e.key, e.value) } - .collect(Collectors.toList())) + .withCdc(false) + .withStreams( + pairToCursorInfoMap.entries + .stream() + .sorted( + java.util.Map.Entry.comparingByKey() + ) // sort by stream name then namespace for sanity. + .map { e: Map.Entry -> + generateDbStreamState(e.key, e.value) + } + .collect(Collectors.toList()) + ) } /** @@ -116,15 +137,21 @@ object StateGeneratorUtils { * @param cursorInfo The current cursor. * @return The [DbStreamState]. */ - fun generateDbStreamState(airbyteStreamNameNamespacePair: AirbyteStreamNameNamespacePair?, - cursorInfo: CursorInfo?): DbStreamState { - val state = DbStreamState() - .withStreamName(airbyteStreamNameNamespacePair!!.name) + fun generateDbStreamState( + airbyteStreamNameNamespacePair: AirbyteStreamNameNamespacePair, + cursorInfo: CursorInfo + ): DbStreamState { + val state = + DbStreamState() + .withStreamName(airbyteStreamNameNamespacePair.name) .withStreamNamespace(airbyteStreamNameNamespacePair.namespace) - .withCursorField(if (cursorInfo.getCursorField() == null) emptyList() else Lists.newArrayList(cursorInfo.getCursorField())) - .withCursor(cursorInfo.getCursor()) - if (cursorInfo.getCursorRecordCount() > 0L) { - state.cursorRecordCount = cursorInfo.getCursorRecordCount() + .withCursorField( + if (cursorInfo.cursorField == null) emptyList() + else Lists.newArrayList(cursorInfo.cursorField) + ) + .withCursor(cursorInfo.cursor) + if (cursorInfo.cursorRecordCount > 0L) { + state.cursorRecordCount = cursorInfo.cursorRecordCount } return state } @@ -133,9 +160,8 @@ object StateGeneratorUtils { * Extracts the actual state from the [AirbyteStreamState] object. * * @param state The [AirbyteStreamState] that contains the actual stream state as JSON. - * @return An [Optional] possibly containing the deserialized representation of the stream - * state or an empty [Optional] if the state is not present or could not be - * deserialized. + * @return An [Optional] possibly containing the deserialized representation of the stream state + * or an empty [Optional] if the state is not present or could not be deserialized. */ fun extractState(state: AirbyteStreamState): Optional { try { @@ -147,16 +173,15 @@ object StateGeneratorUtils { } /** - * Tests whether the provided [StreamDescriptor] is valid. A valid descriptor is defined as - * one that has a non-`null` name. + * Tests whether the provided [StreamDescriptor] is valid. A valid descriptor is defined as one + * that has a non-`null` name. * * See * https://github.com/airbytehq/airbyte/blob/e63458fabb067978beb5eaa74d2bc130919b419f/docs/understanding-airbyte/airbyte-protocol.md * for more details * * @param streamDescriptor A [StreamDescriptor] to be validated. - * @return `true` if the provided [StreamDescriptor] is valid or `false` if it is - * invalid. + * @return `true` if the provided [StreamDescriptor] is valid or `false` if it is invalid. */ fun isValidStreamDescriptor(streamDescriptor: StreamDescriptor?): Boolean { return if (streamDescriptor != null) { @@ -167,46 +192,69 @@ object StateGeneratorUtils { } /** - * Converts a [AirbyteStateType.LEGACY] state message into a [AirbyteStateType.GLOBAL] - * message. + * Converts a [AirbyteStateType.LEGACY] state message into a [AirbyteStateType.GLOBAL] message. * * @param airbyteStateMessage A [AirbyteStateType.LEGACY] state message. * @return A [AirbyteStateType.GLOBAL] state message. */ - fun convertLegacyStateToGlobalState(airbyteStateMessage: AirbyteStateMessage?): AirbyteStateMessage { - val dbState = Jsons.`object`(airbyteStateMessage!!.data, DbState::class.java) - val globalState = AirbyteGlobalState() + fun convertLegacyStateToGlobalState( + airbyteStateMessage: AirbyteStateMessage + ): AirbyteStateMessage { + val dbState = Jsons.`object`(airbyteStateMessage.data, DbState::class.java) + val globalState = + AirbyteGlobalState() .withSharedState(Jsons.jsonNode(dbState.cdcState)) - .withStreamStates(dbState.streams.stream() + .withStreamStates( + dbState.streams + .stream() .map { s: DbStreamState -> AirbyteStreamState() - .withStreamDescriptor(StreamDescriptor().withName(s.streamName).withNamespace(s.streamNamespace)) - .withStreamState(Jsons.jsonNode(s)) + .withStreamDescriptor( + StreamDescriptor() + .withName(s.streamName) + .withNamespace(s.streamNamespace) + ) + .withStreamState(Jsons.jsonNode(s)) } - .collect( - Collectors.toList())) - return AirbyteStateMessage().withType(AirbyteStateMessage.AirbyteStateType.GLOBAL).withGlobal(globalState) + .collect(Collectors.toList()) + ) + return AirbyteStateMessage() + .withType(AirbyteStateMessage.AirbyteStateType.GLOBAL) + .withGlobal(globalState) } /** - * Converts a [AirbyteStateType.LEGACY] state message into a list of - * [AirbyteStateType.STREAM] messages. + * Converts a [AirbyteStateType.LEGACY] state message into a list of [AirbyteStateType.STREAM] + * messages. * * @param airbyteStateMessage A [AirbyteStateType.LEGACY] state message. * @return A list [AirbyteStateType.STREAM] state messages. */ - fun convertLegacyStateToStreamState(airbyteStateMessage: AirbyteStateMessage?): List { - return Jsons.`object`(airbyteStateMessage!!.data, DbState::class.java).streams.stream() - .map { s: DbStreamState -> - AirbyteStateMessage().withType(AirbyteStateMessage.AirbyteStateType.STREAM) - .withStream(AirbyteStreamState() - .withStreamDescriptor(StreamDescriptor().withNamespace(s.streamNamespace).withName(s.streamName)) - .withStreamState(Jsons.jsonNode(s))) - } - .collect(Collectors.toList()) + fun convertLegacyStateToStreamState( + airbyteStateMessage: AirbyteStateMessage + ): List { + return Jsons.`object`(airbyteStateMessage.data, DbState::class.java) + .streams + .stream() + .map { s: DbStreamState -> + AirbyteStateMessage() + .withType(AirbyteStateMessage.AirbyteStateType.STREAM) + .withStream( + AirbyteStreamState() + .withStreamDescriptor( + StreamDescriptor() + .withNamespace(s.streamNamespace) + .withName(s.streamName) + ) + .withStreamState(Jsons.jsonNode(s)) + ) + } + .collect(Collectors.toList()) } - fun convertStateMessage(state: io.airbyte.protocol.models.AirbyteStateMessage): AirbyteStateMessage { + fun convertStateMessage( + state: io.airbyte.protocol.models.AirbyteStateMessage + ): AirbyteStateMessage { return Jsons.`object`(Jsons.jsonNode(state), AirbyteStateMessage::class.java) } @@ -217,22 +265,25 @@ object StateGeneratorUtils { * @Param supportedStateType the [AirbyteStateType] supported by this connector. * @return The deserialized object representation of the state. */ - fun deserializeInitialState(initialStateJson: JsonNode?, - supportedStateType: AirbyteStateMessage.AirbyteStateType): List { + fun deserializeInitialState( + initialStateJson: JsonNode?, + supportedStateType: AirbyteStateMessage.AirbyteStateType + ): List { val typedState = StateMessageHelper.getTypedState(initialStateJson) return typedState - .map { state: StateWrapper -> - when (state.stateType) { - StateType.GLOBAL -> java.util.List.of(convertStateMessage(state.global)) - StateType.STREAM -> state.stateMessages - .stream() - .map { obj: io.airbyte.protocol.models.AirbyteStateMessage? -> convertStateMessage() }.toList() - - else -> java.util.List.of(AirbyteStateMessage().withType(AirbyteStateMessage.AirbyteStateType.LEGACY) - .withData(state.legacyState)) - } + .map { state: StateWrapper -> + when (state.stateType) { + StateType.GLOBAL -> java.util.List.of(convertStateMessage(state.global)) + StateType.STREAM -> state.stateMessages.map { convertStateMessage(it) } + else -> + java.util.List.of( + AirbyteStateMessage() + .withType(AirbyteStateMessage.AirbyteStateType.LEGACY) + .withData(state.legacyState) + ) } - .orElse(generateEmptyInitialState(supportedStateType)) + } + .orElse(generateEmptyInitialState(supportedStateType)) } /** @@ -241,21 +292,32 @@ object StateGeneratorUtils { * @Param supportedStateType the [AirbyteStateType] supported by this connector. * @return The empty, initial state. */ - private fun generateEmptyInitialState(supportedStateType: AirbyteStateMessage.AirbyteStateType): List { + private fun generateEmptyInitialState( + supportedStateType: AirbyteStateMessage.AirbyteStateType + ): List { // For backwards compatibility with existing connectors if (supportedStateType == AirbyteStateMessage.AirbyteStateType.LEGACY) { - return java.util.List.of(AirbyteStateMessage() + return java.util.List.of( + AirbyteStateMessage() .withType(AirbyteStateMessage.AirbyteStateType.LEGACY) - .withData(Jsons.jsonNode(DbState()))) + .withData(Jsons.jsonNode(DbState())) + ) } else if (supportedStateType == AirbyteStateMessage.AirbyteStateType.GLOBAL) { - val globalState = AirbyteGlobalState() + val globalState = + AirbyteGlobalState() .withSharedState(Jsons.jsonNode(CdcState())) .withStreamStates(listOf()) - return java.util.List.of(AirbyteStateMessage().withType(AirbyteStateMessage.AirbyteStateType.GLOBAL).withGlobal(globalState)) + return java.util.List.of( + AirbyteStateMessage() + .withType(AirbyteStateMessage.AirbyteStateType.GLOBAL) + .withGlobal(globalState) + ) } else { - return java.util.List.of(AirbyteStateMessage() + return java.util.List.of( + AirbyteStateMessage() .withType(AirbyteStateMessage.AirbyteStateType.STREAM) - .withStream(AirbyteStreamState())) + .withStream(AirbyteStreamState()) + ) } } } diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/StateManager.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/StateManager.kt index 41d83e58370ae..9588478c6ac51 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/StateManager.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/StateManager.kt @@ -8,24 +8,24 @@ import io.airbyte.cdk.integrations.source.relationaldb.CdcStateManager import io.airbyte.cdk.integrations.source.relationaldb.CursorInfo import io.airbyte.protocol.models.v0.AirbyteStateMessage import io.airbyte.protocol.models.v0.AirbyteStreamNameNamespacePair +import java.util.* import org.slf4j.Logger import org.slf4j.LoggerFactory -import java.util.* /** * Defines a manager that manages connector state. Connector state is used to keep track of the data * synced by the connector. * * @param The type of the state maintained by the manager. - * @param The type of the stream(s) stored within the state maintained by the manager. - */ + * @param The type of the stream(s) stored within the state maintained by the manager. + */ interface StateManager { /** * Retrieves the [CdcStateManager] associated with the state manager. * * @return The [CdcStateManager] - * @throws UnsupportedOperationException if the state manager does not support tracking change data - * capture (CDC) state. + * @throws UnsupportedOperationException if the state manager does not support tracking change + * data capture (CDC) state. */ val cdcStateManager: CdcStateManager @@ -34,7 +34,8 @@ interface StateManager { * database-specific sync modes (e.g. Xmin) that would want to handle and parse their own state * * @return the list of airbyte state messages - * @throws UnsupportedOperationException if the state manager does not support retrieving raw state. + * @throws UnsupportedOperationException if the state manager does not support retrieving raw + * state. */ val rawStateMessages: List? @@ -42,10 +43,10 @@ interface StateManager { * Retrieves the map of stream name/namespace tuple to the current cursor information for that * stream. * - * @return The map of stream name/namespace tuple to the current cursor information for that stream - * as maintained by this state manager. + * @return The map of stream name/namespace tuple to the current cursor information for that + * stream as maintained by this state manager. */ - val pairToCursorInfoMap: Map? + val pairToCursorInfoMap: Map /** * Generates an [AirbyteStateMessage] that represents the current state contained in the state @@ -56,11 +57,11 @@ interface StateManager { * @return The [AirbyteStateMessage] that represents the current state contained in the state * manager. */ - fun toState(pair: Optional): AirbyteStateMessage? + fun toState(pair: Optional): AirbyteStateMessage /** - * Retrieves an [Optional] possibly containing the cursor value tracked in the state - * associated with the provided stream name/namespace tuple. + * Retrieves an [Optional] possibly containing the cursor value tracked in the state associated + * with the provided stream name/namespace tuple. * * @param pair The [AirbyteStreamNameNamespacePair] which identifies a stream. * @return An [Optional] possibly containing the cursor value tracked in the state associated @@ -71,8 +72,8 @@ interface StateManager { } /** - * Retrieves an [Optional] possibly containing the cursor field name associated with the - * cursor tracked in the state associated with the provided stream name/namespace tuple. + * Retrieves an [Optional] possibly containing the cursor field name associated with the cursor + * tracked in the state associated with the provided stream name/namespace tuple. * * @param pair The [AirbyteStreamNameNamespacePair] which identifies a stream. * @return An [Optional] possibly containing the cursor field name associated with the cursor @@ -99,16 +100,16 @@ interface StateManager { * the cursor tracked in the state associated with the provided stream name/namespace tuple. * * @param pair The [AirbyteStreamNameNamespacePair] which identifies a stream. - * @return An [Optional] possibly containing the original cursor field name associated with - * the cursor tracked in the state associated with the provided stream name/namespace tuple. + * @return An [Optional] possibly containing the original cursor field name associated with the + * cursor tracked in the state associated with the provided stream name/namespace tuple. */ fun getOriginalCursorField(pair: AirbyteStreamNameNamespacePair?): Optional? { return getCursorInfo(pair).map { obj: CursorInfo -> obj.originalCursorField } } /** - * Retrieves the current cursor information stored in the state manager for the steam name/namespace - * tuple. + * Retrieves the current cursor information stored in the state manager for the steam + * name/namespace tuple. * * @param pair The [AirbyteStreamNameNamespacePair] that represents a stream managed by the * state manager. @@ -127,19 +128,18 @@ interface StateManager { * @return An [AirbyteStateMessage] that represents the current state maintained by the state * manager. */ - fun emit(pair: Optional): AirbyteStateMessage? { + fun emit(pair: Optional): AirbyteStateMessage? { return toState(pair) } /** - * Updates the cursor associated with the provided stream name/namespace pair and emits the current - * state maintained by the state manager. + * Updates the cursor associated with the provided stream name/namespace pair and emits the + * current state maintained by the state manager. * * @param pair The [AirbyteStreamNameNamespacePair] that represents a stream managed by the * state manager. * @param cursor The new value for the cursor associated with the - * [AirbyteStreamNameNamespacePair] that represents a stream managed by the state - * manager. + * [AirbyteStreamNameNamespacePair] that represents a stream managed by the state manager. * @return An [AirbyteStateMessage] that represents the current state maintained by the state * manager. */ @@ -147,14 +147,26 @@ interface StateManager { return updateAndEmit(pair, cursor, 0L) } - fun updateAndEmit(pair: AirbyteStreamNameNamespacePair, cursor: String?, cursorRecordCount: Long): AirbyteStateMessage? { + fun updateAndEmit( + pair: AirbyteStreamNameNamespacePair, + cursor: String?, + cursorRecordCount: Long + ): AirbyteStateMessage? { val cursorInfo = getCursorInfo(pair) - Preconditions.checkState(cursorInfo.isPresent, "Could not find cursor information for stream: $pair") + Preconditions.checkState( + cursorInfo.isPresent, + "Could not find cursor information for stream: $pair" + ) cursorInfo.get().setCursor(cursor) if (cursorRecordCount > 0L) { cursorInfo.get().setCursorRecordCount(cursorRecordCount) } - LOGGER.debug("Updating cursor value for {} to {} (count {})...", pair, cursor, cursorRecordCount) + LOGGER.debug( + "Updating cursor value for {} to {} (count {})...", + pair, + cursor, + cursorRecordCount + ) return emit(Optional.ofNullable(pair)) } diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/StateManagerFactory.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/StateManagerFactory.kt index ecdd997801c50..2d34be63c3b87 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/StateManagerFactory.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/StateManagerFactory.kt @@ -10,9 +10,7 @@ import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog import org.slf4j.Logger import org.slf4j.LoggerFactory -/** - * Factory class that creates [StateManager] instances based on the provided state. - */ +/** Factory class that creates [StateManager] instances based on the provided state. */ object StateManagerFactory { private val LOGGER: Logger = LoggerFactory.getLogger(StateManagerFactory::class.java) @@ -28,61 +26,82 @@ object StateManagerFactory { * manager. * @return A newly created [StateManager] implementation based on the provided state. */ - fun createStateManager(supportedStateType: AirbyteStateMessage.AirbyteStateType?, - initialState: List?, - catalog: ConfiguredAirbyteCatalog?): StateManager { + fun createStateManager( + supportedStateType: AirbyteStateMessage.AirbyteStateType?, + initialState: List?, + catalog: ConfiguredAirbyteCatalog + ): StateManager { if (initialState != null && !initialState.isEmpty()) { val airbyteStateMessage = initialState[0] when (supportedStateType) { AirbyteStateMessage.AirbyteStateType.LEGACY -> { - LOGGER.info("Legacy state manager selected to manage state object with type {}.", airbyteStateMessage!!.type) - @Suppress("deprecation") val retVal: StateManager = LegacyStateManager(Jsons.`object`(airbyteStateMessage.data, DbState::class.java), catalog) + LOGGER.info( + "Legacy state manager selected to manage state object with type {}.", + airbyteStateMessage!!.type + ) + @Suppress("deprecation") + val retVal: StateManager = + LegacyStateManager( + Jsons.`object`(airbyteStateMessage.data, DbState::class.java), + catalog + ) return retVal } - AirbyteStateMessage.AirbyteStateType.GLOBAL -> { - LOGGER.info("Global state manager selected to manage state object with type {}.", airbyteStateMessage!!.type) + LOGGER.info( + "Global state manager selected to manage state object with type {}.", + airbyteStateMessage!!.type + ) return GlobalStateManager(generateGlobalState(airbyteStateMessage), catalog) } - AirbyteStateMessage.AirbyteStateType.STREAM -> { - LOGGER.info("Stream state manager selected to manage state object with type {}.", airbyteStateMessage!!.type) + LOGGER.info( + "Stream state manager selected to manage state object with type {}.", + airbyteStateMessage!!.type + ) return StreamStateManager(generateStreamState(initialState), catalog) } - else -> { - LOGGER.info("Stream state manager selected to manage state object with type {}.", airbyteStateMessage!!.type) + LOGGER.info( + "Stream state manager selected to manage state object with type {}.", + airbyteStateMessage!!.type + ) return StreamStateManager(generateStreamState(initialState), catalog) } } } else { - throw IllegalArgumentException("Failed to create state manager due to empty state list.") + throw IllegalArgumentException( + "Failed to create state manager due to empty state list." + ) } } /** - * Handles the conversion between a different state type and the global state. This method handles - * the following transitions: - * - * * Stream -> Global (not supported, results in [IllegalArgumentException] - * * Legacy -> Global (supported) - * * Global -> Global (supported/no conversion required) + * Handles the conversion between a different state type and the global state. This method + * handles the following transitions: * + * * Stream -> Global (not supported, results in [IllegalArgumentException] + * * Legacy -> Global (supported) + * * Global -> Global (supported/no conversion required) * * @param airbyteStateMessage The current state that is to be converted to global state. * @return The converted state message. - * @throws IllegalArgumentException if unable to convert between the given state type and global. + * @throws IllegalArgumentException if unable to convert between the given state type and + * global. */ - private fun generateGlobalState(airbyteStateMessage: AirbyteStateMessage?): AirbyteStateMessage? { + private fun generateGlobalState(airbyteStateMessage: AirbyteStateMessage): AirbyteStateMessage { var globalStateMessage = airbyteStateMessage when (airbyteStateMessage!!.type) { - AirbyteStateMessage.AirbyteStateType.STREAM -> throw IllegalArgumentException("Unable to convert connector state from stream to global. Please reset the connection to continue.") + AirbyteStateMessage.AirbyteStateType.STREAM -> + throw IllegalArgumentException( + "Unable to convert connector state from stream to global. Please reset the connection to continue." + ) AirbyteStateMessage.AirbyteStateType.LEGACY -> { - globalStateMessage = StateGeneratorUtils.convertLegacyStateToGlobalState(airbyteStateMessage) + globalStateMessage = + StateGeneratorUtils.convertLegacyStateToGlobalState(airbyteStateMessage) LOGGER.info("Legacy state converted to global state.", airbyteStateMessage.type) } - AirbyteStateMessage.AirbyteStateType.GLOBAL -> {} else -> {} } @@ -90,24 +109,30 @@ object StateManagerFactory { } /** - * Handles the conversion between a different state type and the stream state. This method handles - * the following transitions: - * - * * Global -> Stream (not supported, results in [IllegalArgumentException] - * * Legacy -> Stream (supported) - * * Stream -> Stream (supported/no conversion required) + * Handles the conversion between a different state type and the stream state. This method + * handles the following transitions: * + * * Global -> Stream (not supported, results in [IllegalArgumentException] + * * Legacy -> Stream (supported) + * * Stream -> Stream (supported/no conversion required) * * @param states The list of current states. * @return The converted state messages. - * @throws IllegalArgumentException if unable to convert between the given state type and stream. + * @throws IllegalArgumentException if unable to convert between the given state type and + * stream. */ - private fun generateStreamState(states: List): List { + private fun generateStreamState(states: List): List { val airbyteStateMessage = states[0] - val streamStates: MutableList = ArrayList() + val streamStates: MutableList = ArrayList() when (airbyteStateMessage!!.type) { - AirbyteStateMessage.AirbyteStateType.GLOBAL -> throw IllegalArgumentException("Unable to convert connector state from global to stream. Please reset the connection to continue.") - AirbyteStateMessage.AirbyteStateType.LEGACY -> streamStates.addAll(StateGeneratorUtils.convertLegacyStateToStreamState(airbyteStateMessage)) + AirbyteStateMessage.AirbyteStateType.GLOBAL -> + throw IllegalArgumentException( + "Unable to convert connector state from global to stream. Please reset the connection to continue." + ) + AirbyteStateMessage.AirbyteStateType.LEGACY -> + streamStates.addAll( + StateGeneratorUtils.convertLegacyStateToStreamState(airbyteStateMessage) + ) AirbyteStateMessage.AirbyteStateType.STREAM -> streamStates.addAll(states) else -> streamStates.addAll(states) } diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/StreamStateManager.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/StreamStateManager.kt index f599d6cec87b6..e09c7d90d03bf 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/StreamStateManager.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/main/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/StreamStateManager.kt @@ -9,43 +9,51 @@ import io.airbyte.protocol.models.v0.AirbyteStateMessage import io.airbyte.protocol.models.v0.AirbyteStreamNameNamespacePair import io.airbyte.protocol.models.v0.AirbyteStreamState import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog -import org.slf4j.Logger -import org.slf4j.LoggerFactory import java.util.* import java.util.function.Supplier import java.util.stream.Collectors +import org.slf4j.Logger +import org.slf4j.LoggerFactory /** * Per-stream implementation of the [StateManager] interface. * - * * This implementation generates a state object for each stream detected in catalog/map of known * streams to cursor information stored in this manager. */ class StreamStateManager /** - * Constructs a new [StreamStateManager] that is seeded with the provided - * [AirbyteStateMessage]. + * Constructs a new [StreamStateManager] that is seeded with the provided [AirbyteStateMessage]. * - * @param airbyteStateMessages The initial state represented as a list of - * [AirbyteStateMessage]s. + * @param airbyteStateMessages The initial state represented as a list of [AirbyteStateMessage]s. * @param catalog The [ConfiguredAirbyteCatalog] for the connector associated with this state * manager. - */(private val rawAirbyteStateMessages: List, catalog: ConfiguredAirbyteCatalog?) : AbstractStateManager(catalog, - Supplier { rawAirbyteStateMessages.stream().map { obj: AirbyteStateMessage? -> obj!!.stream }.collect(Collectors.toList()) }, + */ +( + private val rawAirbyteStateMessages: List, + catalog: ConfiguredAirbyteCatalog +) : + AbstractStateManager( + catalog, + Supplier { + rawAirbyteStateMessages.stream().map { it.stream }.collect(Collectors.toList()) + }, StateGeneratorUtils.CURSOR_FUNCTION, StateGeneratorUtils.CURSOR_FIELD_FUNCTION, StateGeneratorUtils.CURSOR_RECORD_COUNT_FUNCTION, - StateGeneratorUtils.NAME_NAMESPACE_PAIR_FUNCTION) { + StateGeneratorUtils.NAME_NAMESPACE_PAIR_FUNCTION + ) { override val cdcStateManager: CdcStateManager get() { - throw UnsupportedOperationException("CDC state management not supported by stream state manager.") + throw UnsupportedOperationException( + "CDC state management not supported by stream state manager." + ) } override val rawStateMessages: List? get() = rawAirbyteStateMessages - override fun toState(pair: Optional): AirbyteStateMessage? { + override fun toState(pair: Optional): AirbyteStateMessage { if (pair.isPresent) { val pairToCursorInfoMap = pairToCursorInfoMap val cursorInfo = Optional.ofNullable(pairToCursorInfoMap!![pair.get()]) @@ -53,16 +61,30 @@ class StreamStateManager if (cursorInfo.isPresent) { LOGGER.debug("Generating state message for {}...", pair) return AirbyteStateMessage() - .withType(AirbyteStateMessage.AirbyteStateType.STREAM) // Temporarily include legacy state for backwards compatibility with the platform - .withData(Jsons.jsonNode(StateGeneratorUtils.generateDbState(pairToCursorInfoMap))) - .withStream(StateGeneratorUtils.generateStreamState(pair.get(), cursorInfo.get())) + .withType( + AirbyteStateMessage.AirbyteStateType.STREAM + ) // Temporarily include legacy state for backwards compatibility with the + // platform + .withData( + Jsons.jsonNode(StateGeneratorUtils.generateDbState(pairToCursorInfoMap)) + ) + .withStream( + StateGeneratorUtils.generateStreamState(pair.get(), cursorInfo.get()) + ) } else { - LOGGER.warn("Cursor information could not be located in state for stream {}. Returning a new, empty state message...", pair) - return AirbyteStateMessage().withType(AirbyteStateMessage.AirbyteStateType.STREAM).withStream(AirbyteStreamState()) + LOGGER.warn( + "Cursor information could not be located in state for stream {}. Returning a new, empty state message...", + pair + ) + return AirbyteStateMessage() + .withType(AirbyteStateMessage.AirbyteStateType.STREAM) + .withStream(AirbyteStreamState()) } } else { LOGGER.warn("Stream not provided. Returning a new, empty state message...") - return AirbyteStateMessage().withType(AirbyteStateMessage.AirbyteStateType.STREAM).withStream(AirbyteStreamState()) + return AirbyteStateMessage() + .withType(AirbyteStateMessage.AirbyteStateType.STREAM) + .withStream(AirbyteStreamState()) } } diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/debezium/AirbyteDebeziumHandlerTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/debezium/AirbyteDebeziumHandlerTest.kt index 2a5ac25296c93..8732a0a6546e7 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/debezium/AirbyteDebeziumHandlerTest.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/debezium/AirbyteDebeziumHandlerTest.kt @@ -10,45 +10,65 @@ import io.airbyte.protocol.models.v0.AirbyteCatalog import io.airbyte.protocol.models.v0.CatalogHelpers import io.airbyte.protocol.models.v0.ConfiguredAirbyteStream import io.airbyte.protocol.models.v0.SyncMode -import org.junit.jupiter.api.Assertions -import org.junit.jupiter.api.Test import java.util.List import java.util.function.Consumer +import org.junit.jupiter.api.Assertions +import org.junit.jupiter.api.Test class AirbyteDebeziumHandlerTest { @Test fun shouldUseCdcTestShouldReturnTrue() { - val catalog = AirbyteCatalog().withStreams(List.of( - CatalogHelpers.createAirbyteStream( - "MODELS_STREAM_NAME", - "MODELS_SCHEMA", - Field.of("COL_ID", JsonSchemaType.NUMBER), - Field.of("COL_MAKE_ID", JsonSchemaType.NUMBER), - Field.of("COL_MODEL", JsonSchemaType.STRING)) - .withSupportedSyncModes(Lists.newArrayList(SyncMode.FULL_REFRESH, SyncMode.INCREMENTAL)) - .withSourceDefinedPrimaryKey(List.of(listOf("COL_ID"))))) - val configuredCatalog = CatalogHelpers - .toDefaultConfiguredCatalog(catalog) + val catalog = + AirbyteCatalog() + .withStreams( + List.of( + CatalogHelpers.createAirbyteStream( + "MODELS_STREAM_NAME", + "MODELS_SCHEMA", + Field.of("COL_ID", JsonSchemaType.NUMBER), + Field.of("COL_MAKE_ID", JsonSchemaType.NUMBER), + Field.of("COL_MODEL", JsonSchemaType.STRING) + ) + .withSupportedSyncModes( + Lists.newArrayList(SyncMode.FULL_REFRESH, SyncMode.INCREMENTAL) + ) + .withSourceDefinedPrimaryKey(List.of(listOf("COL_ID"))) + ) + ) + val configuredCatalog = CatalogHelpers.toDefaultConfiguredCatalog(catalog) // set all streams to incremental. - configuredCatalog.streams.forEach(Consumer { s: ConfiguredAirbyteStream -> s.syncMode = SyncMode.INCREMENTAL }) + configuredCatalog.streams.forEach( + Consumer { s: ConfiguredAirbyteStream -> s.syncMode = SyncMode.INCREMENTAL } + ) - Assertions.assertTrue(AirbyteDebeziumHandler.isAnyStreamIncrementalSyncMode(configuredCatalog)) + Assertions.assertTrue( + AirbyteDebeziumHandler.isAnyStreamIncrementalSyncMode(configuredCatalog) + ) } @Test fun shouldUseCdcTestShouldReturnFalse() { - val catalog = AirbyteCatalog().withStreams(List.of( - CatalogHelpers.createAirbyteStream( - "MODELS_STREAM_NAME", - "MODELS_SCHEMA", - Field.of("COL_ID", JsonSchemaType.NUMBER), - Field.of("COL_MAKE_ID", JsonSchemaType.NUMBER), - Field.of("COL_MODEL", JsonSchemaType.STRING)) - .withSupportedSyncModes(Lists.newArrayList(SyncMode.FULL_REFRESH, SyncMode.INCREMENTAL)) - .withSourceDefinedPrimaryKey(List.of(listOf("COL_ID"))))) - val configuredCatalog = CatalogHelpers - .toDefaultConfiguredCatalog(catalog) + val catalog = + AirbyteCatalog() + .withStreams( + List.of( + CatalogHelpers.createAirbyteStream( + "MODELS_STREAM_NAME", + "MODELS_SCHEMA", + Field.of("COL_ID", JsonSchemaType.NUMBER), + Field.of("COL_MAKE_ID", JsonSchemaType.NUMBER), + Field.of("COL_MODEL", JsonSchemaType.STRING) + ) + .withSupportedSyncModes( + Lists.newArrayList(SyncMode.FULL_REFRESH, SyncMode.INCREMENTAL) + ) + .withSourceDefinedPrimaryKey(List.of(listOf("COL_ID"))) + ) + ) + val configuredCatalog = CatalogHelpers.toDefaultConfiguredCatalog(catalog) - Assertions.assertFalse(AirbyteDebeziumHandler.isAnyStreamIncrementalSyncMode(configuredCatalog)) + Assertions.assertFalse( + AirbyteDebeziumHandler.isAnyStreamIncrementalSyncMode(configuredCatalog) + ) } } diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/debezium/AirbyteFileOffsetBackingStoreTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/debezium/AirbyteFileOffsetBackingStoreTest.kt index 89668f3bed40d..aeba71586adb6 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/debezium/AirbyteFileOffsetBackingStoreTest.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/debezium/AirbyteFileOffsetBackingStoreTest.kt @@ -7,12 +7,12 @@ import io.airbyte.cdk.integrations.debezium.internals.AirbyteFileOffsetBackingSt import io.airbyte.commons.io.IOs import io.airbyte.commons.json.Jsons import io.airbyte.commons.resources.MoreResources -import org.junit.jupiter.api.Assertions -import org.junit.jupiter.api.Test import java.io.IOException import java.nio.file.Files import java.nio.file.Path import java.util.* +import org.junit.jupiter.api.Assertions +import org.junit.jupiter.api.Test internal class AirbyteFileOffsetBackingStoreTest { @Test @@ -41,7 +41,9 @@ internal class AirbyteFileOffsetBackingStoreTest { // verify that, after a round trip through the offset store, we get back the same data. Assertions.assertEquals(stateFromOffsetStore2, stateFromOffsetStore3) // verify that the file written by the offset store is identical to the template file. - Assertions.assertTrue(com.google.common.io.Files.equal(secondWriteFilePath.toFile(), writeFilePath.toFile())) + Assertions.assertTrue( + com.google.common.io.Files.equal(secondWriteFilePath.toFile(), writeFilePath.toFile()) + ) } @Test @@ -70,6 +72,8 @@ internal class AirbyteFileOffsetBackingStoreTest { // verify that, after a round trip through the offset store, we get back the same data. Assertions.assertEquals(stateFromOffsetStore2, stateFromOffsetStore3) // verify that the file written by the offset store is identical to the template file. - Assertions.assertTrue(com.google.common.io.Files.equal(secondWriteFilePath.toFile(), writeFilePath.toFile())) + Assertions.assertTrue( + com.google.common.io.Files.equal(secondWriteFilePath.toFile(), writeFilePath.toFile()) + ) } } diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/debezium/DebeziumRecordPublisherTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/debezium/DebeziumRecordPublisherTest.kt index 5c697e3b5f3b3..8a23f58e748bf 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/debezium/DebeziumRecordPublisherTest.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/debezium/DebeziumRecordPublisherTest.kt @@ -10,19 +10,28 @@ import io.airbyte.protocol.models.JsonSchemaType import io.airbyte.protocol.models.v0.CatalogHelpers import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog import io.airbyte.protocol.models.v0.SyncMode +import java.util.regex.Pattern import org.junit.jupiter.api.Assertions import org.junit.jupiter.api.Test -import java.util.regex.Pattern internal class DebeziumRecordPublisherTest { @Test fun testTableIncludelistCreation() { - val catalog = ConfiguredAirbyteCatalog().withStreams(ImmutableList.of( - CatalogHelpers.createConfiguredAirbyteStream("id_and_name", "public").withSyncMode(SyncMode.INCREMENTAL), - CatalogHelpers.createConfiguredAirbyteStream("id_,something", "public").withSyncMode(SyncMode.INCREMENTAL), - CatalogHelpers.createConfiguredAirbyteStream("n\"aMéS", "public").withSyncMode(SyncMode.INCREMENTAL))) + val catalog = + ConfiguredAirbyteCatalog() + .withStreams( + ImmutableList.of( + CatalogHelpers.createConfiguredAirbyteStream("id_and_name", "public") + .withSyncMode(SyncMode.INCREMENTAL), + CatalogHelpers.createConfiguredAirbyteStream("id_,something", "public") + .withSyncMode(SyncMode.INCREMENTAL), + CatalogHelpers.createConfiguredAirbyteStream("n\"aMéS", "public") + .withSyncMode(SyncMode.INCREMENTAL) + ) + ) - val expectedWhitelist = "\\Qpublic.id_and_name\\E,\\Qpublic.id_\\,something\\E,\\Qpublic.n\"aMéS\\E" + val expectedWhitelist = + "\\Qpublic.id_and_name\\E,\\Qpublic.id_\\,something\\E,\\Qpublic.n\"aMéS\\E" val actualWhitelist = RelationalDbDebeziumPropertiesManager.getTableIncludelist(catalog) Assertions.assertEquals(expectedWhitelist, actualWhitelist) @@ -30,9 +39,16 @@ internal class DebeziumRecordPublisherTest { @Test fun testTableIncludelistFiltersFullRefresh() { - val catalog = ConfiguredAirbyteCatalog().withStreams(ImmutableList.of( - CatalogHelpers.createConfiguredAirbyteStream("id_and_name", "public").withSyncMode(SyncMode.INCREMENTAL), - CatalogHelpers.createConfiguredAirbyteStream("id_and_name2", "public").withSyncMode(SyncMode.FULL_REFRESH))) + val catalog = + ConfiguredAirbyteCatalog() + .withStreams( + ImmutableList.of( + CatalogHelpers.createConfiguredAirbyteStream("id_and_name", "public") + .withSyncMode(SyncMode.INCREMENTAL), + CatalogHelpers.createConfiguredAirbyteStream("id_and_name2", "public") + .withSyncMode(SyncMode.FULL_REFRESH) + ) + ) val expectedWhitelist = "\\Qpublic.id_and_name\\E" val actualWhitelist = RelationalDbDebeziumPropertiesManager.getTableIncludelist(catalog) @@ -42,16 +58,28 @@ internal class DebeziumRecordPublisherTest { @Test fun testColumnIncludelistFiltersFullRefresh() { - val catalog = ConfiguredAirbyteCatalog().withStreams(ImmutableList.of( - CatalogHelpers.createConfiguredAirbyteStream( - "id_and_name", - "public", - Field.of("fld1", JsonSchemaType.NUMBER), Field.of("fld2", JsonSchemaType.STRING)).withSyncMode(SyncMode.INCREMENTAL), - CatalogHelpers.createConfiguredAirbyteStream("id_,something", "public").withSyncMode(SyncMode.INCREMENTAL), - CatalogHelpers.createConfiguredAirbyteStream("id_and_name2", "public").withSyncMode(SyncMode.FULL_REFRESH), - CatalogHelpers.createConfiguredAirbyteStream("n\"aMéS", "public").withSyncMode(SyncMode.INCREMENTAL))) + val catalog = + ConfiguredAirbyteCatalog() + .withStreams( + ImmutableList.of( + CatalogHelpers.createConfiguredAirbyteStream( + "id_and_name", + "public", + Field.of("fld1", JsonSchemaType.NUMBER), + Field.of("fld2", JsonSchemaType.STRING) + ) + .withSyncMode(SyncMode.INCREMENTAL), + CatalogHelpers.createConfiguredAirbyteStream("id_,something", "public") + .withSyncMode(SyncMode.INCREMENTAL), + CatalogHelpers.createConfiguredAirbyteStream("id_and_name2", "public") + .withSyncMode(SyncMode.FULL_REFRESH), + CatalogHelpers.createConfiguredAirbyteStream("n\"aMéS", "public") + .withSyncMode(SyncMode.INCREMENTAL) + ) + ) - val expectedWhitelist = "\\Qpublic.id_and_name\\E\\.(\\Qfld2\\E|\\Qfld1\\E),\\Qpublic.id_\\,something\\E,\\Qpublic.n\"aMéS\\E" + val expectedWhitelist = + "\\Qpublic.id_and_name\\E\\.(\\Qfld2\\E|\\Qfld1\\E),\\Qpublic.id_\\,something\\E,\\Qpublic.n\"aMéS\\E" val actualWhitelist = RelationalDbDebeziumPropertiesManager.getColumnIncludeList(catalog) Assertions.assertEquals(expectedWhitelist, actualWhitelist) @@ -65,13 +93,22 @@ internal class DebeziumRecordPublisherTest { // assertTrue(p.matcher(b).find()); // assertTrue(Pattern.compile(Pattern.quote(b)).matcher(b).find()); - val catalog = ConfiguredAirbyteCatalog().withStreams(ImmutableList.of( - CatalogHelpers.createConfiguredAirbyteStream( - "id_and_name", - "public", - Field.of("fld1", JsonSchemaType.NUMBER), Field.of("fld2", JsonSchemaType.STRING)).withSyncMode(SyncMode.INCREMENTAL))) + val catalog = + ConfiguredAirbyteCatalog() + .withStreams( + ImmutableList.of( + CatalogHelpers.createConfiguredAirbyteStream( + "id_and_name", + "public", + Field.of("fld1", JsonSchemaType.NUMBER), + Field.of("fld2", JsonSchemaType.STRING) + ) + .withSyncMode(SyncMode.INCREMENTAL) + ) + ) - val anchored = "^" + RelationalDbDebeziumPropertiesManager.getColumnIncludeList(catalog) + "$" + val anchored = + "^" + RelationalDbDebeziumPropertiesManager.getColumnIncludeList(catalog) + "$" val pattern = Pattern.compile(anchored) Assertions.assertTrue(pattern.matcher("public.id_and_name.fld1").find()) diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/debezium/internals/AirbyteSchemaHistoryStorageTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/debezium/internals/AirbyteSchemaHistoryStorageTest.kt index e9d1ac18cb194..217f4d0dffcab 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/debezium/internals/AirbyteSchemaHistoryStorageTest.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/debezium/internals/AirbyteSchemaHistoryStorageTest.kt @@ -5,45 +5,69 @@ package io.airbyte.cdk.integrations.debezium.internals import io.airbyte.commons.json.Jsons import io.airbyte.commons.resources.MoreResources -import org.junit.jupiter.api.Assertions -import org.junit.jupiter.api.Test import java.io.IOException import java.util.* +import org.junit.jupiter.api.Assertions +import org.junit.jupiter.api.Test class AirbyteSchemaHistoryStorageTest { @Test @Throws(IOException::class) fun testForContentBiggerThan1MBLimit() { - val contentReadDirectlyFromFile = MoreResources.readResource("dbhistory_greater_than_1_mb.dat") + val contentReadDirectlyFromFile = + MoreResources.readResource("dbhistory_greater_than_1_mb.dat") - val schemaHistoryStorageFromUncompressedContent = AirbyteSchemaHistoryStorage.initializeDBHistory( - AirbyteSchemaHistoryStorage.SchemaHistory(Optional.of(Jsons.jsonNode(contentReadDirectlyFromFile)), - false), - true) - val schemaHistoryFromUncompressedContent = schemaHistoryStorageFromUncompressedContent.read() + val schemaHistoryStorageFromUncompressedContent = + AirbyteSchemaHistoryStorage.initializeDBHistory( + AirbyteSchemaHistoryStorage.SchemaHistory( + Optional.of(Jsons.jsonNode(contentReadDirectlyFromFile)), + false + ), + true + ) + val schemaHistoryFromUncompressedContent = + schemaHistoryStorageFromUncompressedContent.read() Assertions.assertTrue(schemaHistoryFromUncompressedContent.isCompressed) Assertions.assertNotNull(schemaHistoryFromUncompressedContent.schema) - Assertions.assertEquals(contentReadDirectlyFromFile, schemaHistoryStorageFromUncompressedContent.readUncompressed()) + Assertions.assertEquals( + contentReadDirectlyFromFile, + schemaHistoryStorageFromUncompressedContent.readUncompressed() + ) - val schemaHistoryStorageFromCompressedContent = AirbyteSchemaHistoryStorage.initializeDBHistory( - AirbyteSchemaHistoryStorage.SchemaHistory(Optional.of(Jsons.jsonNode(schemaHistoryFromUncompressedContent.schema)), - true), - true) + val schemaHistoryStorageFromCompressedContent = + AirbyteSchemaHistoryStorage.initializeDBHistory( + AirbyteSchemaHistoryStorage.SchemaHistory( + Optional.of(Jsons.jsonNode(schemaHistoryFromUncompressedContent.schema)), + true + ), + true + ) val schemaHistoryFromCompressedContent = schemaHistoryStorageFromCompressedContent.read() Assertions.assertTrue(schemaHistoryFromCompressedContent.isCompressed) Assertions.assertNotNull(schemaHistoryFromCompressedContent.schema) - Assertions.assertEquals(schemaHistoryFromUncompressedContent.schema, schemaHistoryFromCompressedContent.schema) + Assertions.assertEquals( + schemaHistoryFromUncompressedContent.schema, + schemaHistoryFromCompressedContent.schema + ) } @Test @Throws(IOException::class) fun sizeTest() { - Assertions.assertEquals(5.881045341491699, - AirbyteSchemaHistoryStorage.calculateSizeOfStringInMB(MoreResources.readResource("dbhistory_greater_than_1_mb.dat"))) - Assertions.assertEquals(0.0038671493530273438, - AirbyteSchemaHistoryStorage.calculateSizeOfStringInMB(MoreResources.readResource("dbhistory_less_than_1_mb.dat"))) + Assertions.assertEquals( + 5.881045341491699, + AirbyteSchemaHistoryStorage.calculateSizeOfStringInMB( + MoreResources.readResource("dbhistory_greater_than_1_mb.dat") + ) + ) + Assertions.assertEquals( + 0.0038671493530273438, + AirbyteSchemaHistoryStorage.calculateSizeOfStringInMB( + MoreResources.readResource("dbhistory_less_than_1_mb.dat") + ) + ) } @Test @@ -51,24 +75,39 @@ class AirbyteSchemaHistoryStorageTest { fun testForContentLessThan1MBLimit() { val contentReadDirectlyFromFile = MoreResources.readResource("dbhistory_less_than_1_mb.dat") - val schemaHistoryStorageFromUncompressedContent = AirbyteSchemaHistoryStorage.initializeDBHistory( - AirbyteSchemaHistoryStorage.SchemaHistory(Optional.of(Jsons.jsonNode(contentReadDirectlyFromFile)), - false), - true) - val schemaHistoryFromUncompressedContent = schemaHistoryStorageFromUncompressedContent.read() + val schemaHistoryStorageFromUncompressedContent = + AirbyteSchemaHistoryStorage.initializeDBHistory( + AirbyteSchemaHistoryStorage.SchemaHistory( + Optional.of(Jsons.jsonNode(contentReadDirectlyFromFile)), + false + ), + true + ) + val schemaHistoryFromUncompressedContent = + schemaHistoryStorageFromUncompressedContent.read() Assertions.assertFalse(schemaHistoryFromUncompressedContent.isCompressed) Assertions.assertNotNull(schemaHistoryFromUncompressedContent.schema) - Assertions.assertEquals(contentReadDirectlyFromFile, schemaHistoryFromUncompressedContent.schema) + Assertions.assertEquals( + contentReadDirectlyFromFile, + schemaHistoryFromUncompressedContent.schema + ) - val schemaHistoryStorageFromCompressedContent = AirbyteSchemaHistoryStorage.initializeDBHistory( - AirbyteSchemaHistoryStorage.SchemaHistory(Optional.of(Jsons.jsonNode(schemaHistoryFromUncompressedContent.schema)), - false), - true) + val schemaHistoryStorageFromCompressedContent = + AirbyteSchemaHistoryStorage.initializeDBHistory( + AirbyteSchemaHistoryStorage.SchemaHistory( + Optional.of(Jsons.jsonNode(schemaHistoryFromUncompressedContent.schema)), + false + ), + true + ) val schemaHistoryFromCompressedContent = schemaHistoryStorageFromCompressedContent.read() Assertions.assertFalse(schemaHistoryFromCompressedContent.isCompressed) Assertions.assertNotNull(schemaHistoryFromCompressedContent.schema) - Assertions.assertEquals(schemaHistoryFromUncompressedContent.schema, schemaHistoryFromCompressedContent.schema) + Assertions.assertEquals( + schemaHistoryFromUncompressedContent.schema, + schemaHistoryFromCompressedContent.schema + ) } } diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumConverterUtilsTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumConverterUtilsTest.kt index ead457f6a1fb0..0b288c96d8f5b 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumConverterUtilsTest.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumConverterUtilsTest.kt @@ -4,15 +4,15 @@ package io.airbyte.cdk.integrations.debezium.internals import io.debezium.spi.converter.RelationalColumn -import org.junit.jupiter.api.Assertions -import org.junit.jupiter.api.Disabled -import org.junit.jupiter.api.Test -import org.mockito.Mockito import java.sql.Timestamp import java.time.Duration import java.time.LocalDate import java.time.LocalDateTime import java.time.LocalTime +import org.junit.jupiter.api.Assertions +import org.junit.jupiter.api.Disabled +import org.junit.jupiter.api.Test +import org.mockito.Mockito internal class DebeziumConverterUtilsTest { @Test @@ -21,7 +21,10 @@ internal class DebeziumConverterUtilsTest { Mockito.`when`(relationalColumn.isOptional).thenReturn(true) var actualColumnDefaultValue = DebeziumConverterUtils.convertDefaultValue(relationalColumn) - Assertions.assertNull(actualColumnDefaultValue, "Default value for optional relational column should be null") + Assertions.assertNull( + actualColumnDefaultValue, + "Default value for optional relational column should be null" + ) Mockito.`when`(relationalColumn.isOptional).thenReturn(false) Mockito.`when`(relationalColumn.hasDefaultValue()).thenReturn(false) diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumRecordIteratorTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumRecordIteratorTest.kt index c946f582d13f4..7e55c7241b179 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumRecordIteratorTest.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumRecordIteratorTest.kt @@ -5,53 +5,67 @@ package io.airbyte.cdk.integrations.debezium.internals import io.airbyte.cdk.integrations.debezium.CdcTargetPosition import io.debezium.engine.ChangeEvent +import java.time.Duration +import java.util.* import org.apache.kafka.connect.source.SourceRecord import org.junit.jupiter.api.Assertions import org.junit.jupiter.api.Test -import org.mockito.Mockito -import java.time.Duration -import java.util.* -import java.util.concurrent.LinkedBlockingQueue -import java.util.function.Supplier +import org.mockito.Mockito.mock class DebeziumRecordIteratorTest { - @get:Test - val heartbeatPositionTest: Unit - get() { - val debeziumRecordIterator = DebeziumRecordIterator(Mockito.mock(LinkedBlockingQueue::class.java), - object : CdcTargetPosition { - override fun reachedTargetPosition(changeEventWithMetadata: ChangeEventWithMetadata): Boolean { - return false - } - - override fun extractPositionFromHeartbeatOffset(sourceOffset: Map): Long { - return sourceOffset["lsn"] as Long - } - }, - Supplier { false }, - Mockito.mock(DebeziumShutdownProcedure::class.java), - Duration.ZERO, - Duration.ZERO) - val lsn = debeziumRecordIterator.getHeartbeatPosition(object : ChangeEvent { - private val sourceRecord = SourceRecord(null, Collections.singletonMap("lsn", 358824993496L), null, null, null) - - override fun key(): String? { - return null - } + @Test + fun getHeartbeatPositionTest() { + val debeziumRecordIterator = + DebeziumRecordIterator( + mock(), + object : CdcTargetPosition { + override fun reachedTargetPosition( + changeEventWithMetadata: ChangeEventWithMetadata? + ): Boolean { + return false + } - override fun value(): String { - return "{\"ts_ms\":1667616934701}" - } + override fun extractPositionFromHeartbeatOffset( + sourceOffset: Map? + ): Long { + return sourceOffset!!["lsn"] as Long + } + }, + { false }, + mock(), + Duration.ZERO, + Duration.ZERO + ) + val lsn = + debeziumRecordIterator.getHeartbeatPosition( + object : ChangeEvent { + private val sourceRecord = + SourceRecord( + null, + Collections.singletonMap("lsn", 358824993496L), + null, + null, + null + ) - override fun destination(): String { - return null - } + override fun key(): String? { + return null + } + + override fun value(): String { + return "{\"ts_ms\":1667616934701}" + } + + override fun destination(): String? { + return null + } - fun sourceRecord(): SourceRecord { - return sourceRecord + fun sourceRecord(): SourceRecord { + return sourceRecord + } } - }) + ) - Assertions.assertEquals(lsn, 358824993496L) - } + Assertions.assertEquals(lsn, 358824993496L) + } } diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumShutdownProcedureTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumShutdownProcedureTest.kt index 25d104467d82a..df7eb675bcc8a 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumShutdownProcedureTest.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/debezium/internals/DebeziumShutdownProcedureTest.kt @@ -3,11 +3,11 @@ */ package io.airbyte.cdk.integrations.debezium.internals -import org.junit.jupiter.api.Assertions -import org.junit.jupiter.api.Test import java.util.concurrent.Executors import java.util.concurrent.LinkedBlockingQueue import java.util.concurrent.atomic.AtomicInteger +import org.junit.jupiter.api.Assertions +import org.junit.jupiter.api.Test class DebeziumShutdownProcedureTest { @Test @@ -16,8 +16,12 @@ class DebeziumShutdownProcedureTest { val sourceQueue = LinkedBlockingQueue(10) val recordsInserted = AtomicInteger() val executorService = Executors.newSingleThreadExecutor() - val debeziumShutdownProcedure = DebeziumShutdownProcedure(sourceQueue, - { executorService.shutdown() }, { recordsInserted.get() >= 99 }) + val debeziumShutdownProcedure = + DebeziumShutdownProcedure( + sourceQueue, + { executorService.shutdown() }, + { recordsInserted.get() >= 99 } + ) executorService.execute { for (i in 0..99) { try { @@ -37,7 +41,10 @@ class DebeziumShutdownProcedureTest { Assertions.assertEquals(100, debeziumShutdownProcedure.recordsRemainingAfterShutdown.size) for (i in 0..99) { - Assertions.assertEquals(i, debeziumShutdownProcedure.recordsRemainingAfterShutdown.poll()) + Assertions.assertEquals( + i, + debeziumShutdownProcedure.recordsRemainingAfterShutdown.poll() + ) } } } diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/debezium/internals/RecordWaitTimeUtilTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/debezium/internals/RecordWaitTimeUtilTest.kt index 6213deca349c4..19aa9ece08af6 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/debezium/internals/RecordWaitTimeUtilTest.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/debezium/internals/RecordWaitTimeUtilTest.kt @@ -4,38 +4,81 @@ package io.airbyte.cdk.integrations.debezium.internals import io.airbyte.commons.json.Jsons -import org.junit.jupiter.api.Assertions -import org.junit.jupiter.api.Test import java.time.Duration import java.util.* import java.util.Map +import org.junit.jupiter.api.Assertions +import org.junit.jupiter.api.Test class RecordWaitTimeUtilTest { @Test fun testGetFirstRecordWaitTime() { val emptyConfig = Jsons.jsonNode(emptyMap()) Assertions.assertDoesNotThrow { RecordWaitTimeUtil.checkFirstRecordWaitTime(emptyConfig) } - Assertions.assertEquals(Optional.empty(), RecordWaitTimeUtil.getFirstRecordWaitSeconds(emptyConfig)) - Assertions.assertEquals(RecordWaitTimeUtil.DEFAULT_FIRST_RECORD_WAIT_TIME, RecordWaitTimeUtil.getFirstRecordWaitTime(emptyConfig)) + Assertions.assertEquals( + Optional.empty(), + RecordWaitTimeUtil.getFirstRecordWaitSeconds(emptyConfig) + ) + Assertions.assertEquals( + RecordWaitTimeUtil.DEFAULT_FIRST_RECORD_WAIT_TIME, + RecordWaitTimeUtil.getFirstRecordWaitTime(emptyConfig) + ) - val normalConfig = Jsons.jsonNode(Map.of("replication_method", - Map.of("method", "CDC", "initial_waiting_seconds", 500))) + val normalConfig = + Jsons.jsonNode( + Map.of( + "replication_method", + Map.of("method", "CDC", "initial_waiting_seconds", 500) + ) + ) Assertions.assertDoesNotThrow { RecordWaitTimeUtil.checkFirstRecordWaitTime(normalConfig) } - Assertions.assertEquals(Optional.of(500), RecordWaitTimeUtil.getFirstRecordWaitSeconds(normalConfig)) - Assertions.assertEquals(Duration.ofSeconds(500), RecordWaitTimeUtil.getFirstRecordWaitTime(normalConfig)) + Assertions.assertEquals( + Optional.of(500), + RecordWaitTimeUtil.getFirstRecordWaitSeconds(normalConfig) + ) + Assertions.assertEquals( + Duration.ofSeconds(500), + RecordWaitTimeUtil.getFirstRecordWaitTime(normalConfig) + ) val tooShortTimeout = RecordWaitTimeUtil.MIN_FIRST_RECORD_WAIT_TIME.seconds.toInt() - 1 - val tooShortConfig = Jsons.jsonNode(Map.of("replication_method", - Map.of("method", "CDC", "initial_waiting_seconds", tooShortTimeout))) - Assertions.assertThrows(IllegalArgumentException::class.java) { RecordWaitTimeUtil.checkFirstRecordWaitTime(tooShortConfig) } - Assertions.assertEquals(Optional.of(tooShortTimeout), RecordWaitTimeUtil.getFirstRecordWaitSeconds(tooShortConfig)) - Assertions.assertEquals(RecordWaitTimeUtil.MIN_FIRST_RECORD_WAIT_TIME, RecordWaitTimeUtil.getFirstRecordWaitTime(tooShortConfig)) + val tooShortConfig = + Jsons.jsonNode( + Map.of( + "replication_method", + Map.of("method", "CDC", "initial_waiting_seconds", tooShortTimeout) + ) + ) + Assertions.assertThrows(IllegalArgumentException::class.java) { + RecordWaitTimeUtil.checkFirstRecordWaitTime(tooShortConfig) + } + Assertions.assertEquals( + Optional.of(tooShortTimeout), + RecordWaitTimeUtil.getFirstRecordWaitSeconds(tooShortConfig) + ) + Assertions.assertEquals( + RecordWaitTimeUtil.MIN_FIRST_RECORD_WAIT_TIME, + RecordWaitTimeUtil.getFirstRecordWaitTime(tooShortConfig) + ) val tooLongTimeout = RecordWaitTimeUtil.MAX_FIRST_RECORD_WAIT_TIME.seconds.toInt() + 1 - val tooLongConfig = Jsons.jsonNode(Map.of("replication_method", - Map.of("method", "CDC", "initial_waiting_seconds", tooLongTimeout))) - Assertions.assertThrows(IllegalArgumentException::class.java) { RecordWaitTimeUtil.checkFirstRecordWaitTime(tooLongConfig) } - Assertions.assertEquals(Optional.of(tooLongTimeout), RecordWaitTimeUtil.getFirstRecordWaitSeconds(tooLongConfig)) - Assertions.assertEquals(RecordWaitTimeUtil.MAX_FIRST_RECORD_WAIT_TIME, RecordWaitTimeUtil.getFirstRecordWaitTime(tooLongConfig)) + val tooLongConfig = + Jsons.jsonNode( + Map.of( + "replication_method", + Map.of("method", "CDC", "initial_waiting_seconds", tooLongTimeout) + ) + ) + Assertions.assertThrows(IllegalArgumentException::class.java) { + RecordWaitTimeUtil.checkFirstRecordWaitTime(tooLongConfig) + } + Assertions.assertEquals( + Optional.of(tooLongTimeout), + RecordWaitTimeUtil.getFirstRecordWaitSeconds(tooLongConfig) + ) + Assertions.assertEquals( + RecordWaitTimeUtil.MAX_FIRST_RECORD_WAIT_TIME, + RecordWaitTimeUtil.getFirstRecordWaitTime(tooLongConfig) + ) } } diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/jdbc/AbstractDbSourceForTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/jdbc/AbstractDbSourceForTest.kt new file mode 100644 index 0000000000000..4a049fd570c33 --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/jdbc/AbstractDbSourceForTest.kt @@ -0,0 +1,20 @@ +/* + * Copyright (c) 2024 Airbyte, Inc., all rights reserved. + */ + +package io.airbyte.cdk.integrations.source.jdbc + +import com.fasterxml.jackson.databind.JsonNode +import io.airbyte.cdk.db.AbstractDatabase +import io.airbyte.cdk.integrations.source.relationaldb.AbstractDbSource +import io.airbyte.protocol.models.v0.AirbyteStateMessage + +abstract class AbstractDbSourceForTest( + driverClassName: String +) : AbstractDbSource(driverClassName) { + public override fun getSupportedStateType( + config: JsonNode? + ): AirbyteStateMessage.AirbyteStateType { + return super.getSupportedStateType(config) + } +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/jdbc/DefaultJdbcSourceAcceptanceTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/jdbc/DefaultJdbcSourceAcceptanceTest.kt index fb08fbe5caeb6..dd104bc6c1107 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/jdbc/DefaultJdbcSourceAcceptanceTest.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/jdbc/DefaultJdbcSourceAcceptanceTest.kt @@ -19,6 +19,11 @@ import io.airbyte.cdk.integrations.util.HostPortResolver.resolvePort import io.airbyte.cdk.testutils.TestDatabase import io.airbyte.commons.json.Jsons import io.airbyte.protocol.models.v0.AirbyteStateMessage +import java.sql.JDBCType +import java.util.List +import java.util.Map +import java.util.function.Supplier +import java.util.stream.Stream import org.jooq.SQLDialect import org.junit.jupiter.api.AfterAll import org.junit.jupiter.api.Assertions @@ -27,20 +32,15 @@ import org.junit.jupiter.api.Test import org.slf4j.Logger import org.slf4j.LoggerFactory import org.testcontainers.containers.PostgreSQLContainer -import java.sql.JDBCType -import java.util.List -import java.util.Map -import java.util.function.Supplier -import java.util.stream.Stream /** * Runs the acceptance tests in the source-jdbc test module. We want this module to run these tests * itself as a sanity check. The trade off here is that this class is duplicated from the one used * in source-postgres. */ -internal class DefaultJdbcSourceAcceptanceTest - - : JdbcSourceAcceptanceTest() { +internal class DefaultJdbcSourceAcceptanceTest : + JdbcSourceAcceptanceTest< + DefaultJdbcSourceAcceptanceTest.PostgresTestSource, BareBonesTestDatabase>() { override fun config(): JsonNode { return testdb!!.testConfigBuilder()!!.build() } @@ -57,8 +57,13 @@ internal class DefaultJdbcSourceAcceptanceTest return true } - fun getConfigWithConnectionProperties(psqlDb: PostgreSQLContainer<*>?, dbName: String?, additionalParameters: String?): JsonNode { - return Jsons.jsonNode(ImmutableMap.builder() + fun getConfigWithConnectionProperties( + psqlDb: PostgreSQLContainer<*>?, + dbName: String?, + additionalParameters: String? + ): JsonNode { + return Jsons.jsonNode( + ImmutableMap.builder() .put(JdbcUtils.HOST_KEY, resolveHost(psqlDb)) .put(JdbcUtils.PORT_KEY, resolvePort(psqlDb)) .put(JdbcUtils.DATABASE_KEY, dbName) @@ -66,17 +71,30 @@ internal class DefaultJdbcSourceAcceptanceTest .put(JdbcUtils.USERNAME_KEY, psqlDb!!.username) .put(JdbcUtils.PASSWORD_KEY, psqlDb.password) .put(JdbcUtils.CONNECTION_PROPERTIES_KEY, additionalParameters) - .build()) + .build() + ) } - class PostgresTestSource : AbstractJdbcSource(DRIVER_CLASS, Supplier { AdaptiveStreamingQueryConfig() }, JdbcUtils.defaultSourceOperations), Source { + class PostgresTestSource : + AbstractJdbcSource( + DRIVER_CLASS, + Supplier { AdaptiveStreamingQueryConfig() }, + JdbcUtils.defaultSourceOperations + ), + Source { override fun toDatabaseConfig(config: JsonNode): JsonNode { - val configBuilder = ImmutableMap.builder() + val configBuilder = + ImmutableMap.builder() .put(JdbcUtils.USERNAME_KEY, config[JdbcUtils.USERNAME_KEY].asText()) - .put(JdbcUtils.JDBC_URL_KEY, String.format(DatabaseDriver.POSTGRESQL.urlFormatString, + .put( + JdbcUtils.JDBC_URL_KEY, + String.format( + DatabaseDriver.POSTGRESQL.urlFormatString, config[JdbcUtils.HOST_KEY].asText(), config[JdbcUtils.PORT_KEY].asInt(), - config[JdbcUtils.DATABASE_KEY].asText())) + config[JdbcUtils.DATABASE_KEY].asText() + ) + ) if (config.has(JdbcUtils.PASSWORD_KEY)) { configBuilder.put(JdbcUtils.PASSWORD_KEY, config[JdbcUtils.PASSWORD_KEY].asText()) @@ -85,11 +103,12 @@ internal class DefaultJdbcSourceAcceptanceTest return Jsons.jsonNode(configBuilder.build()) } - public override fun getExcludedInternalNameSpaces(): Set { - return setOf("information_schema", "pg_catalog", "pg_internal", "catalog_history") - } + override val excludedInternalNameSpaces = + setOf("information_schema", "pg_catalog", "pg_internal", "catalog_history") - override fun getSupportedStateType(config: JsonNode): AirbyteStateMessage.AirbyteStateType { + override fun getSupportedStateType( + config: JsonNode? + ): AirbyteStateMessage.AirbyteStateType { return AirbyteStateMessage.AirbyteStateType.STREAM } @@ -109,21 +128,37 @@ internal class DefaultJdbcSourceAcceptanceTest } } - class BareBonesTestDatabase - (container: PostgreSQLContainer<*>?) : TestDatabase?, BareBonesTestDatabase?, BareBonesConfigBuilder?>(container) { + class BareBonesTestDatabase(container: PostgreSQLContainer<*>) : + TestDatabase, BareBonesTestDatabase, BareBonesConfigBuilder>( + container + ) { override fun inContainerBootstrapCmd(): Stream?>? { - val sql = Stream.of( + val sql = + Stream.of( String.format("CREATE DATABASE %s", databaseName), String.format("CREATE USER %s PASSWORD '%s'", userName, password), - String.format("GRANT ALL PRIVILEGES ON DATABASE %s TO %s", databaseName, userName), - String.format("ALTER USER %s WITH SUPERUSER", userName)) - return Stream.of(Stream.concat( - Stream.of("psql", - "-d", container!!.databaseName, - "-U", container.username, - "-v", "ON_ERROR_STOP=1", - "-a"), - sql.flatMap { stmt: String? -> Stream.of("-c", stmt) })) + String.format( + "GRANT ALL PRIVILEGES ON DATABASE %s TO %s", + databaseName, + userName + ), + String.format("ALTER USER %s WITH SUPERUSER", userName) + ) + return Stream.of( + Stream.concat( + Stream.of( + "psql", + "-d", + container!!.databaseName, + "-U", + container.username, + "-v", + "ON_ERROR_STOP=1", + "-a" + ), + sql.flatMap { stmt: String? -> Stream.of("-c", stmt) } + ) + ) } override fun inContainerUndoBootstrapCmd(): Stream? { @@ -140,35 +175,44 @@ internal class DefaultJdbcSourceAcceptanceTest return BareBonesConfigBuilder(this) } - class BareBonesConfigBuilder(testDatabase: BareBonesTestDatabase) : ConfigBuilder(testDatabase) + class BareBonesConfigBuilder(testDatabase: BareBonesTestDatabase) : + ConfigBuilder(testDatabase) } @Test fun testCustomParametersOverwriteDefaultParametersExpectException() { val connectionPropertiesUrl = "ssl=false" - val config = getConfigWithConnectionProperties(PSQL_CONTAINER, testdb!!.databaseName, connectionPropertiesUrl) + val config = + getConfigWithConnectionProperties( + PSQL_CONTAINER, + testdb!!.databaseName, + connectionPropertiesUrl + ) val customParameters = parseJdbcParameters(config, JdbcUtils.CONNECTION_PROPERTIES_KEY, "&") - val defaultParameters = Map.of( - "ssl", "true", - "sslmode", "require") + val defaultParameters = Map.of("ssl", "true", "sslmode", "require") Assertions.assertThrows(IllegalArgumentException::class.java) { - JdbcDataSourceUtils.assertCustomParametersDontOverwriteDefaultParameters(customParameters, defaultParameters) + JdbcDataSourceUtils.assertCustomParametersDontOverwriteDefaultParameters( + customParameters, + defaultParameters + ) } } companion object { - private var PSQL_CONTAINER: PostgreSQLContainer<*>? = null + private lateinit var PSQL_CONTAINER: PostgreSQLContainer<*> + @JvmStatic @BeforeAll - fun init() { - PSQL_CONTAINER = PostgreSQLContainer("postgres:13-alpine") + fun init(): Unit { + PSQL_CONTAINER = PostgreSQLContainer("postgres:13-alpine") PSQL_CONTAINER!!.start() CREATE_TABLE_WITHOUT_CURSOR_TYPE_QUERY = "CREATE TABLE %s (%s BIT(3) NOT NULL);" INSERT_TABLE_WITHOUT_CURSOR_TYPE_QUERY = "INSERT INTO %s VALUES(B'101');" } + @JvmStatic @AfterAll - fun cleanUp() { + fun cleanUp(): Unit { PSQL_CONTAINER!!.close() } } diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/jdbc/DefaultJdbcStressTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/jdbc/DefaultJdbcStressTest.kt index bbedc8d651c4c..4f90caf8489d1 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/jdbc/DefaultJdbcStressTest.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/jdbc/DefaultJdbcStressTest.kt @@ -15,6 +15,9 @@ import io.airbyte.cdk.testutils.PostgreSQLContainerHelper.runSqlScript import io.airbyte.commons.io.IOs import io.airbyte.commons.json.Jsons import io.airbyte.commons.string.Strings +import java.sql.JDBCType +import java.util.* +import java.util.function.Supplier import org.junit.jupiter.api.AfterAll import org.junit.jupiter.api.BeforeAll import org.junit.jupiter.api.BeforeEach @@ -23,9 +26,6 @@ import org.slf4j.Logger import org.slf4j.LoggerFactory import org.testcontainers.containers.PostgreSQLContainer import org.testcontainers.utility.MountableFile -import java.sql.JDBCType -import java.util.* -import java.util.function.Supplier /** * Runs the stress tests in the source-jdbc test module. We want this module to run these tests @@ -41,19 +41,32 @@ internal class DefaultJdbcStressTest : JdbcStressTest() { override fun setup() { val dbName = Strings.addRandomSuffix("db", "_", 10) - config = Jsons.jsonNode(ImmutableMap.of(JdbcUtils.HOST_KEY, "localhost", - JdbcUtils.PORT_KEY, 5432, - JdbcUtils.DATABASE_KEY, "charles", - JdbcUtils.USERNAME_KEY, "postgres", - JdbcUtils.PASSWORD_KEY, "")) - - config = Jsons.jsonNode(ImmutableMap.builder() - .put(JdbcUtils.HOST_KEY, PSQL_DB!!.host) - .put(JdbcUtils.PORT_KEY, PSQL_DB!!.firstMappedPort) - .put(JdbcUtils.DATABASE_KEY, dbName) - .put(JdbcUtils.USERNAME_KEY, PSQL_DB!!.username) - .put(JdbcUtils.PASSWORD_KEY, PSQL_DB!!.password) - .build()) + config = + Jsons.jsonNode( + ImmutableMap.of( + JdbcUtils.HOST_KEY, + "localhost", + JdbcUtils.PORT_KEY, + 5432, + JdbcUtils.DATABASE_KEY, + "charles", + JdbcUtils.USERNAME_KEY, + "postgres", + JdbcUtils.PASSWORD_KEY, + "" + ) + ) + + config = + Jsons.jsonNode( + ImmutableMap.builder() + .put(JdbcUtils.HOST_KEY, PSQL_DB!!.host) + .put(JdbcUtils.PORT_KEY, PSQL_DB!!.firstMappedPort) + .put(JdbcUtils.DATABASE_KEY, dbName) + .put(JdbcUtils.USERNAME_KEY, PSQL_DB!!.username) + .put(JdbcUtils.PASSWORD_KEY, PSQL_DB!!.password) + .build() + ) val initScriptName = "init_$dbName.sql" val tmpFilePath = IOs.writeFileToRandomTmpDir(initScriptName, "CREATE DATABASE $dbName;") @@ -62,9 +75,7 @@ internal class DefaultJdbcStressTest : JdbcStressTest() { super.setup() } - override fun getDefaultSchemaName(): Optional { - return Optional.of("public") - } + override val defaultSchemaName = Optional.of("public") override fun getSource(): AbstractJdbcSource { return PostgresTestSource() @@ -74,18 +85,28 @@ internal class DefaultJdbcStressTest : JdbcStressTest() { return config!! } - override fun getDriverClass(): String { - return PostgresTestSource.DRIVER_CLASS - } + override val driverClass = PostgresTestSource.DRIVER_CLASS - private class PostgresTestSource : AbstractJdbcSource(DRIVER_CLASS, Supplier { AdaptiveStreamingQueryConfig() }, JdbcUtils.defaultSourceOperations), Source { + private class PostgresTestSource : + AbstractJdbcSource( + DRIVER_CLASS, + Supplier { AdaptiveStreamingQueryConfig() }, + JdbcUtils.defaultSourceOperations + ), + Source { override fun toDatabaseConfig(config: JsonNode): JsonNode { - val configBuilder = ImmutableMap.builder() + val configBuilder = + ImmutableMap.builder() .put(JdbcUtils.USERNAME_KEY, config[JdbcUtils.USERNAME_KEY].asText()) - .put(JdbcUtils.JDBC_URL_KEY, String.format(DatabaseDriver.POSTGRESQL.urlFormatString, + .put( + JdbcUtils.JDBC_URL_KEY, + String.format( + DatabaseDriver.POSTGRESQL.urlFormatString, config[JdbcUtils.HOST_KEY].asText(), config[JdbcUtils.PORT_KEY].asInt(), - config[JdbcUtils.DATABASE_KEY].asText())) + config[JdbcUtils.DATABASE_KEY].asText() + ) + ) if (config.has(JdbcUtils.PASSWORD_KEY)) { configBuilder.put(JdbcUtils.PASSWORD_KEY, config[JdbcUtils.PASSWORD_KEY].asText()) @@ -94,9 +115,8 @@ internal class DefaultJdbcStressTest : JdbcStressTest() { return Jsons.jsonNode(configBuilder.build()) } - public override fun getExcludedInternalNameSpaces(): Set { - return setOf("information_schema", "pg_catalog", "pg_internal", "catalog_history") - } + public override val excludedInternalNameSpaces = + setOf("information_schema", "pg_catalog", "pg_internal", "catalog_history") companion object { private val LOGGER: Logger = LoggerFactory.getLogger(PostgresTestSource::class.java) @@ -118,12 +138,14 @@ internal class DefaultJdbcStressTest : JdbcStressTest() { private var PSQL_DB: PostgreSQLContainer<*>? = null @BeforeAll + @JvmStatic fun init() { - PSQL_DB = PostgreSQLContainer("postgres:13-alpine") + PSQL_DB = PostgreSQLContainer("postgres:13-alpine") PSQL_DB!!.start() } @AfterAll + @JvmStatic fun cleanUp() { PSQL_DB!!.close() } diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/jdbc/JdbcDataSourceUtilsTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/jdbc/JdbcDataSourceUtilsTest.kt index 4c142d1722aac..6a8dc1ab3d8b4 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/jdbc/JdbcDataSourceUtilsTest.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/jdbc/JdbcDataSourceUtilsTest.kt @@ -4,28 +4,34 @@ package io.airbyte.cdk.integrations.source.jdbc import io.airbyte.commons.json.Jsons +import java.util.function.Consumer import org.assertj.core.api.AssertionsForClassTypes import org.junit.Assert import org.junit.jupiter.api.Assertions import org.junit.jupiter.api.Test -import java.util.function.Consumer class JdbcDataSourceUtilsTest { @Test fun test() { - val validConfigString = "{\"jdbc_url_params\":\"key1=val1&key3=key3\",\"connection_properties\":\"key1=val1&key2=val2\"}" + val validConfigString = + "{\"jdbc_url_params\":\"key1=val1&key3=key3\",\"connection_properties\":\"key1=val1&key2=val2\"}" val validConfig = Jsons.deserialize(validConfigString) val connectionProperties = JdbcDataSourceUtils.getConnectionProperties(validConfig) val validKeys = listOf("key1", "key2", "key3") - validKeys.forEach(Consumer { key: String -> Assert.assertTrue(connectionProperties.containsKey(key)) }) + validKeys.forEach( + Consumer { key: String -> Assert.assertTrue(connectionProperties.containsKey(key)) } + ) - // For an invalid config, there is a conflict betweeen the values of keys in jdbc_url_params and + // For an invalid config, there is a conflict betweeen the values of keys in jdbc_url_params + // and // connection_properties - val invalidConfigString = "{\"jdbc_url_params\":\"key1=val2&key3=key3\",\"connection_properties\":\"key1=val1&key2=val2\"}" + val invalidConfigString = + "{\"jdbc_url_params\":\"key1=val2&key3=key3\",\"connection_properties\":\"key1=val1&key2=val2\"}" val invalidConfig = Jsons.deserialize(invalidConfigString) - val exception: Exception = Assertions.assertThrows(IllegalArgumentException::class.java) { - JdbcDataSourceUtils.getConnectionProperties(invalidConfig) - } + val exception: Exception = + Assertions.assertThrows(IllegalArgumentException::class.java) { + JdbcDataSourceUtils.getConnectionProperties(invalidConfig) + } val expectedMessage = "Cannot overwrite default JDBC parameter key1" AssertionsForClassTypes.assertThat(expectedMessage == exception.message) diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/jdbc/JdbcSourceStressTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/jdbc/JdbcSourceStressTest.kt index 0ccf8ff0d84d5..5d6041c459913 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/jdbc/JdbcSourceStressTest.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/jdbc/JdbcSourceStressTest.kt @@ -15,6 +15,9 @@ import io.airbyte.cdk.testutils.PostgreSQLContainerHelper.runSqlScript import io.airbyte.commons.io.IOs import io.airbyte.commons.json.Jsons import io.airbyte.commons.string.Strings +import java.sql.JDBCType +import java.util.* +import java.util.function.Supplier import org.junit.jupiter.api.AfterAll import org.junit.jupiter.api.BeforeAll import org.junit.jupiter.api.BeforeEach @@ -23,9 +26,6 @@ import org.slf4j.Logger import org.slf4j.LoggerFactory import org.testcontainers.containers.PostgreSQLContainer import org.testcontainers.utility.MountableFile -import java.sql.JDBCType -import java.util.* -import java.util.function.Supplier /** * Runs the stress tests in the source-jdbc test module. We want this module to run these tests @@ -41,26 +41,26 @@ internal class JdbcSourceStressTest : JdbcStressTest() { override fun setup() { val schemaName = Strings.addRandomSuffix("db", "_", 10) - - - config = Jsons.jsonNode(ImmutableMap.builder() - .put(JdbcUtils.HOST_KEY, PSQL_DB!!.host) - .put(JdbcUtils.PORT_KEY, PSQL_DB!!.firstMappedPort) - .put(JdbcUtils.DATABASE_KEY, schemaName) - .put(JdbcUtils.USERNAME_KEY, PSQL_DB!!.username) - .put(JdbcUtils.PASSWORD_KEY, PSQL_DB!!.password) - .build()) + config = + Jsons.jsonNode( + ImmutableMap.builder() + .put(JdbcUtils.HOST_KEY, PSQL_DB!!.host) + .put(JdbcUtils.PORT_KEY, PSQL_DB!!.firstMappedPort) + .put(JdbcUtils.DATABASE_KEY, schemaName) + .put(JdbcUtils.USERNAME_KEY, PSQL_DB!!.username) + .put(JdbcUtils.PASSWORD_KEY, PSQL_DB!!.password) + .build() + ) val initScriptName = "init_$schemaName.sql" - val tmpFilePath = IOs.writeFileToRandomTmpDir(initScriptName, "CREATE DATABASE $schemaName;") + val tmpFilePath = + IOs.writeFileToRandomTmpDir(initScriptName, "CREATE DATABASE $schemaName;") runSqlScript(MountableFile.forHostPath(tmpFilePath), PSQL_DB!!) super.setup() } - override fun getDefaultSchemaName(): Optional { - return Optional.of("public") - } + override val defaultSchemaName = Optional.of("public") override fun getSource(): AbstractJdbcSource { return PostgresTestSource() @@ -70,18 +70,28 @@ internal class JdbcSourceStressTest : JdbcStressTest() { return config!! } - override fun getDriverClass(): String { - return PostgresTestSource.DRIVER_CLASS - } + override val driverClass = PostgresTestSource.DRIVER_CLASS - private class PostgresTestSource : AbstractJdbcSource(DRIVER_CLASS, Supplier { AdaptiveStreamingQueryConfig() }, JdbcUtils.defaultSourceOperations), Source { + private class PostgresTestSource : + AbstractJdbcSource( + DRIVER_CLASS, + Supplier { AdaptiveStreamingQueryConfig() }, + JdbcUtils.defaultSourceOperations + ), + Source { override fun toDatabaseConfig(config: JsonNode): JsonNode { - val configBuilder = ImmutableMap.builder() + val configBuilder = + ImmutableMap.builder() .put(JdbcUtils.USERNAME_KEY, config[JdbcUtils.USERNAME_KEY].asText()) - .put(JdbcUtils.JDBC_URL_KEY, String.format(DatabaseDriver.POSTGRESQL.urlFormatString, + .put( + JdbcUtils.JDBC_URL_KEY, + String.format( + DatabaseDriver.POSTGRESQL.urlFormatString, config[JdbcUtils.HOST_KEY].asText(), config[JdbcUtils.PORT_KEY].asInt(), - config[JdbcUtils.DATABASE_KEY].asText())) + config[JdbcUtils.DATABASE_KEY].asText() + ) + ) if (config.has(JdbcUtils.PASSWORD_KEY)) { configBuilder.put(JdbcUtils.PASSWORD_KEY, config[JdbcUtils.PASSWORD_KEY].asText()) @@ -90,9 +100,8 @@ internal class JdbcSourceStressTest : JdbcStressTest() { return Jsons.jsonNode(configBuilder.build()) } - public override fun getExcludedInternalNameSpaces(): Set { - return setOf("information_schema", "pg_catalog", "pg_internal", "catalog_history") - } + override val excludedInternalNameSpaces = + setOf("information_schema", "pg_catalog", "pg_internal", "catalog_history") companion object { private val LOGGER: Logger = LoggerFactory.getLogger(PostgresTestSource::class.java) @@ -114,12 +123,14 @@ internal class JdbcSourceStressTest : JdbcStressTest() { private var PSQL_DB: PostgreSQLContainer<*>? = null @BeforeAll + @JvmStatic fun init() { - PSQL_DB = PostgreSQLContainer("postgres:13-alpine") + PSQL_DB = PostgreSQLContainer("postgres:13-alpine") PSQL_DB!!.start() } @AfterAll + @JvmStatic fun cleanUp() { PSQL_DB!!.close() } diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/AbstractDbSourceTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/AbstractDbSourceTest.kt index 4e66391a6b685..a292255725f07 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/AbstractDbSourceTest.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/AbstractDbSourceTest.kt @@ -4,10 +4,12 @@ package io.airbyte.cdk.integrations.source.relationaldb import com.fasterxml.jackson.databind.JsonNode +import io.airbyte.cdk.integrations.source.jdbc.AbstractDbSourceForTest import io.airbyte.cdk.integrations.source.relationaldb.state.* import io.airbyte.commons.json.Jsons import io.airbyte.commons.resources.MoreResources import io.airbyte.protocol.models.v0.AirbyteStateMessage +import java.io.IOException import org.junit.jupiter.api.Assertions import org.junit.jupiter.api.Test import org.junit.jupiter.api.extension.ExtendWith @@ -15,27 +17,30 @@ import org.mockito.Mockito import uk.org.webcompere.systemstubs.environment.EnvironmentVariables import uk.org.webcompere.systemstubs.jupiter.SystemStub import uk.org.webcompere.systemstubs.jupiter.SystemStubsExtension -import java.io.IOException -/** - * Test suite for the [AbstractDbSource] class. - */ +/** Test suite for the [AbstractDbSource] class. */ @ExtendWith(SystemStubsExtension::class) class AbstractDbSourceTest { - @SystemStub - private val environmentVariables: EnvironmentVariables? = null + @SystemStub private val environmentVariables: EnvironmentVariables? = null @Test @Throws(IOException::class) fun testDeserializationOfLegacyState() { - val dbSource = Mockito.mock(AbstractDbSource::class.java, Mockito.withSettings().useConstructor("").defaultAnswer(Mockito.CALLS_REAL_METHODS)) + val dbSource = + Mockito.mock( + AbstractDbSourceForTest::class.java, + Mockito.withSettings().useConstructor("").defaultAnswer(Mockito.CALLS_REAL_METHODS) + ) val config = Mockito.mock(JsonNode::class.java) val legacyStateJson = MoreResources.readResource("states/legacy.json") val legacyState = Jsons.deserialize(legacyStateJson) - val result = StateGeneratorUtils.deserializeInitialState(legacyState, - dbSource.getSupportedStateType(config)) + val result = + StateGeneratorUtils.deserializeInitialState( + legacyState, + dbSource.getSupportedStateType(config) + ) Assertions.assertEquals(1, result.size) Assertions.assertEquals(AirbyteStateMessage.AirbyteStateType.LEGACY, result[0].type) } @@ -43,14 +48,21 @@ class AbstractDbSourceTest { @Test @Throws(IOException::class) fun testDeserializationOfGlobalState() { - val dbSource = Mockito.mock(AbstractDbSource::class.java, Mockito.withSettings().useConstructor("").defaultAnswer(Mockito.CALLS_REAL_METHODS)) + val dbSource = + Mockito.mock( + AbstractDbSourceForTest::class.java, + Mockito.withSettings().useConstructor("").defaultAnswer(Mockito.CALLS_REAL_METHODS) + ) val config = Mockito.mock(JsonNode::class.java) val globalStateJson = MoreResources.readResource("states/global.json") val globalState = Jsons.deserialize(globalStateJson) val result = - StateGeneratorUtils.deserializeInitialState(globalState, dbSource.getSupportedStateType(config)) + StateGeneratorUtils.deserializeInitialState( + globalState, + dbSource.getSupportedStateType(config) + ) Assertions.assertEquals(1, result.size) Assertions.assertEquals(AirbyteStateMessage.AirbyteStateType.GLOBAL, result[0].type) } @@ -58,14 +70,21 @@ class AbstractDbSourceTest { @Test @Throws(IOException::class) fun testDeserializationOfStreamState() { - val dbSource = Mockito.mock(AbstractDbSource::class.java, Mockito.withSettings().useConstructor("").defaultAnswer(Mockito.CALLS_REAL_METHODS)) + val dbSource = + Mockito.mock( + AbstractDbSourceForTest::class.java, + Mockito.withSettings().useConstructor("").defaultAnswer(Mockito.CALLS_REAL_METHODS) + ) val config = Mockito.mock(JsonNode::class.java) val streamStateJson = MoreResources.readResource("states/per_stream.json") val streamState = Jsons.deserialize(streamStateJson) val result = - StateGeneratorUtils.deserializeInitialState(streamState, dbSource.getSupportedStateType(config)) + StateGeneratorUtils.deserializeInitialState( + streamState, + dbSource.getSupportedStateType(config) + ) Assertions.assertEquals(2, result.size) Assertions.assertEquals(AirbyteStateMessage.AirbyteStateType.STREAM, result[0].type) } @@ -73,10 +92,18 @@ class AbstractDbSourceTest { @Test @Throws(IOException::class) fun testDeserializationOfNullState() { - val dbSource = Mockito.mock(AbstractDbSource::class.java, Mockito.withSettings().useConstructor("").defaultAnswer(Mockito.CALLS_REAL_METHODS)) + val dbSource = + Mockito.mock( + AbstractDbSourceForTest::class.java, + Mockito.withSettings().useConstructor("").defaultAnswer(Mockito.CALLS_REAL_METHODS) + ) val config = Mockito.mock(JsonNode::class.java) - val result = StateGeneratorUtils.deserializeInitialState(null, dbSource.getSupportedStateType(config)) + val result = + StateGeneratorUtils.deserializeInitialState( + null, + dbSource.getSupportedStateType(config) + ) Assertions.assertEquals(1, result.size) Assertions.assertEquals(dbSource.getSupportedStateType(config), result[0].type) } diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/CursorManagerTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/CursorManagerTest.kt index f1a8e293eff35..c3905e5043ea1 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/CursorManagerTest.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/CursorManagerTest.kt @@ -6,137 +6,259 @@ package io.airbyte.cdk.integrations.source.relationaldb.state import io.airbyte.cdk.integrations.source.relationaldb.CursorInfo import io.airbyte.cdk.integrations.source.relationaldb.models.DbStreamState import io.airbyte.protocol.models.v0.AirbyteStreamNameNamespacePair -import org.junit.jupiter.api.Assertions -import org.junit.jupiter.api.Test import java.util.* import java.util.function.Function +import org.junit.jupiter.api.Assertions +import org.junit.jupiter.api.Test -/** - * Test suite for the [CursorManager] class. - */ +/** Test suite for the [CursorManager] class. */ class CursorManagerTest { @Test fun testCreateCursorInfoCatalogAndStateSameCursorField() { - val cursorManager = createCursorManager(StateTestConstants.CURSOR_FIELD1, StateTestConstants.CURSOR, StateTestConstants.NAME_NAMESPACE_PAIR1) - val actual = cursorManager.createCursorInfoForStream( + val cursorManager = + createCursorManager( + StateTestConstants.CURSOR_FIELD1, + StateTestConstants.CURSOR, + StateTestConstants.NAME_NAMESPACE_PAIR1 + ) + val actual = + cursorManager.createCursorInfoForStream( StateTestConstants.NAME_NAMESPACE_PAIR1, - StateTestConstants.getState(StateTestConstants.CURSOR_FIELD1, StateTestConstants.CURSOR, StateTestConstants.CURSOR_RECORD_COUNT), + StateTestConstants.getState( + StateTestConstants.CURSOR_FIELD1, + StateTestConstants.CURSOR, + StateTestConstants.CURSOR_RECORD_COUNT + ), StateTestConstants.getStream(StateTestConstants.CURSOR_FIELD1), { obj: DbStreamState? -> obj!!.cursor }, { obj: DbStreamState? -> obj!!.cursorField }, - CURSOR_RECORD_COUNT_FUNCTION) - Assertions.assertEquals(CursorInfo(StateTestConstants.CURSOR_FIELD1, StateTestConstants.CURSOR, StateTestConstants.CURSOR_RECORD_COUNT, StateTestConstants.CURSOR_FIELD1, StateTestConstants.CURSOR, StateTestConstants.CURSOR_RECORD_COUNT), actual) + CURSOR_RECORD_COUNT_FUNCTION + ) + Assertions.assertEquals( + CursorInfo( + StateTestConstants.CURSOR_FIELD1, + StateTestConstants.CURSOR, + StateTestConstants.CURSOR_RECORD_COUNT, + StateTestConstants.CURSOR_FIELD1, + StateTestConstants.CURSOR, + StateTestConstants.CURSOR_RECORD_COUNT + ), + actual + ) } @Test fun testCreateCursorInfoCatalogAndStateSameCursorFieldButNoCursor() { - val cursorManager = createCursorManager(StateTestConstants.CURSOR_FIELD1, null, StateTestConstants.NAME_NAMESPACE_PAIR1) - val actual = cursorManager.createCursorInfoForStream( + val cursorManager = + createCursorManager( + StateTestConstants.CURSOR_FIELD1, + null, + StateTestConstants.NAME_NAMESPACE_PAIR1 + ) + val actual = + cursorManager.createCursorInfoForStream( StateTestConstants.NAME_NAMESPACE_PAIR1, StateTestConstants.getState(StateTestConstants.CURSOR_FIELD1, null), StateTestConstants.getStream(StateTestConstants.CURSOR_FIELD1), { obj: DbStreamState? -> obj!!.cursor }, { obj: DbStreamState? -> obj!!.cursorField }, - CURSOR_RECORD_COUNT_FUNCTION) - Assertions.assertEquals(CursorInfo(StateTestConstants.CURSOR_FIELD1, null, StateTestConstants.CURSOR_FIELD1, null), actual) + CURSOR_RECORD_COUNT_FUNCTION + ) + Assertions.assertEquals( + CursorInfo( + StateTestConstants.CURSOR_FIELD1, + null, + StateTestConstants.CURSOR_FIELD1, + null + ), + actual + ) } @Test fun testCreateCursorInfoCatalogAndStateChangeInCursorFieldName() { - val cursorManager = createCursorManager(StateTestConstants.CURSOR_FIELD1, StateTestConstants.CURSOR, StateTestConstants.NAME_NAMESPACE_PAIR1) - val actual = cursorManager.createCursorInfoForStream( + val cursorManager = + createCursorManager( + StateTestConstants.CURSOR_FIELD1, + StateTestConstants.CURSOR, + StateTestConstants.NAME_NAMESPACE_PAIR1 + ) + val actual = + cursorManager.createCursorInfoForStream( StateTestConstants.NAME_NAMESPACE_PAIR1, - StateTestConstants.getState(StateTestConstants.CURSOR_FIELD1, StateTestConstants.CURSOR), + StateTestConstants.getState( + StateTestConstants.CURSOR_FIELD1, + StateTestConstants.CURSOR + ), StateTestConstants.getStream(StateTestConstants.CURSOR_FIELD2), { obj: DbStreamState? -> obj!!.cursor }, { obj: DbStreamState? -> obj!!.cursorField }, - CURSOR_RECORD_COUNT_FUNCTION) - Assertions.assertEquals(CursorInfo(StateTestConstants.CURSOR_FIELD1, StateTestConstants.CURSOR, StateTestConstants.CURSOR_FIELD2, null), actual) + CURSOR_RECORD_COUNT_FUNCTION + ) + Assertions.assertEquals( + CursorInfo( + StateTestConstants.CURSOR_FIELD1, + StateTestConstants.CURSOR, + StateTestConstants.CURSOR_FIELD2, + null + ), + actual + ) } @Test fun testCreateCursorInfoCatalogAndNoState() { - val cursorManager = createCursorManager(StateTestConstants.CURSOR_FIELD1, StateTestConstants.CURSOR, StateTestConstants.NAME_NAMESPACE_PAIR1) - val actual = cursorManager.createCursorInfoForStream( + val cursorManager = + createCursorManager( + StateTestConstants.CURSOR_FIELD1, + StateTestConstants.CURSOR, + StateTestConstants.NAME_NAMESPACE_PAIR1 + ) + val actual = + cursorManager.createCursorInfoForStream( StateTestConstants.NAME_NAMESPACE_PAIR1, Optional.empty(), StateTestConstants.getStream(StateTestConstants.CURSOR_FIELD1), Function { obj: DbStreamState? -> obj!!.cursor }, Function { obj: DbStreamState? -> obj!!.cursorField }, - CURSOR_RECORD_COUNT_FUNCTION) - Assertions.assertEquals(CursorInfo(null, null, StateTestConstants.CURSOR_FIELD1, null), actual) + CURSOR_RECORD_COUNT_FUNCTION + ) + Assertions.assertEquals( + CursorInfo(null, null, StateTestConstants.CURSOR_FIELD1, null), + actual + ) } @Test fun testCreateCursorInfoStateAndNoCatalog() { - val cursorManager = createCursorManager(StateTestConstants.CURSOR_FIELD1, StateTestConstants.CURSOR, StateTestConstants.NAME_NAMESPACE_PAIR1) - val actual = cursorManager.createCursorInfoForStream( + val cursorManager = + createCursorManager( + StateTestConstants.CURSOR_FIELD1, + StateTestConstants.CURSOR, + StateTestConstants.NAME_NAMESPACE_PAIR1 + ) + val actual = + cursorManager.createCursorInfoForStream( StateTestConstants.NAME_NAMESPACE_PAIR1, - StateTestConstants.getState(StateTestConstants.CURSOR_FIELD1, StateTestConstants.CURSOR), + StateTestConstants.getState( + StateTestConstants.CURSOR_FIELD1, + StateTestConstants.CURSOR + ), Optional.empty(), { obj: DbStreamState? -> obj!!.cursor }, { obj: DbStreamState? -> obj!!.cursorField }, - CURSOR_RECORD_COUNT_FUNCTION) - Assertions.assertEquals(CursorInfo(StateTestConstants.CURSOR_FIELD1, StateTestConstants.CURSOR, null, null), actual) + CURSOR_RECORD_COUNT_FUNCTION + ) + Assertions.assertEquals( + CursorInfo(StateTestConstants.CURSOR_FIELD1, StateTestConstants.CURSOR, null, null), + actual + ) } // this is what full refresh looks like. @Test fun testCreateCursorInfoNoCatalogAndNoState() { - val cursorManager = createCursorManager(StateTestConstants.CURSOR_FIELD1, StateTestConstants.CURSOR, StateTestConstants.NAME_NAMESPACE_PAIR1) - val actual = cursorManager.createCursorInfoForStream( + val cursorManager = + createCursorManager( + StateTestConstants.CURSOR_FIELD1, + StateTestConstants.CURSOR, + StateTestConstants.NAME_NAMESPACE_PAIR1 + ) + val actual = + cursorManager.createCursorInfoForStream( StateTestConstants.NAME_NAMESPACE_PAIR1, Optional.empty(), Optional.empty(), Function { obj: DbStreamState? -> obj!!.cursor }, Function { obj: DbStreamState? -> obj!!.cursorField }, - CURSOR_RECORD_COUNT_FUNCTION) + CURSOR_RECORD_COUNT_FUNCTION + ) Assertions.assertEquals(CursorInfo(null, null, null, null), actual) } @Test fun testCreateCursorInfoStateAndCatalogButNoCursorField() { - val cursorManager = createCursorManager(StateTestConstants.CURSOR_FIELD1, StateTestConstants.CURSOR, StateTestConstants.NAME_NAMESPACE_PAIR1) - val actual = cursorManager.createCursorInfoForStream( + val cursorManager = + createCursorManager( + StateTestConstants.CURSOR_FIELD1, + StateTestConstants.CURSOR, + StateTestConstants.NAME_NAMESPACE_PAIR1 + ) + val actual = + cursorManager.createCursorInfoForStream( StateTestConstants.NAME_NAMESPACE_PAIR1, - StateTestConstants.getState(StateTestConstants.CURSOR_FIELD1, StateTestConstants.CURSOR), + StateTestConstants.getState( + StateTestConstants.CURSOR_FIELD1, + StateTestConstants.CURSOR + ), StateTestConstants.getStream(null), { obj: DbStreamState? -> obj!!.cursor }, { obj: DbStreamState? -> obj!!.cursorField }, - CURSOR_RECORD_COUNT_FUNCTION) - Assertions.assertEquals(CursorInfo(StateTestConstants.CURSOR_FIELD1, StateTestConstants.CURSOR, null, null), actual) + CURSOR_RECORD_COUNT_FUNCTION + ) + Assertions.assertEquals( + CursorInfo(StateTestConstants.CURSOR_FIELD1, StateTestConstants.CURSOR, null, null), + actual + ) } @Test fun testGetters() { - val cursorManager: CursorManager<*> = createCursorManager(StateTestConstants.CURSOR_FIELD1, StateTestConstants.CURSOR, StateTestConstants.NAME_NAMESPACE_PAIR1) - val actualCursorInfo = CursorInfo(StateTestConstants.CURSOR_FIELD1, StateTestConstants.CURSOR, null, null) + val cursorManager: CursorManager<*> = + createCursorManager( + StateTestConstants.CURSOR_FIELD1, + StateTestConstants.CURSOR, + StateTestConstants.NAME_NAMESPACE_PAIR1 + ) + val actualCursorInfo = + CursorInfo(StateTestConstants.CURSOR_FIELD1, StateTestConstants.CURSOR, null, null) - Assertions.assertEquals(Optional.of(actualCursorInfo), cursorManager.getCursorInfo(StateTestConstants.NAME_NAMESPACE_PAIR1)) - Assertions.assertEquals(Optional.empty(), cursorManager.getCursorField(StateTestConstants.NAME_NAMESPACE_PAIR1)) - Assertions.assertEquals(Optional.empty(), cursorManager.getCursor(StateTestConstants.NAME_NAMESPACE_PAIR1)) + Assertions.assertEquals( + Optional.of(actualCursorInfo), + cursorManager.getCursorInfo(StateTestConstants.NAME_NAMESPACE_PAIR1) + ) + Assertions.assertEquals( + Optional.empty(), + cursorManager.getCursorField(StateTestConstants.NAME_NAMESPACE_PAIR1) + ) + Assertions.assertEquals( + Optional.empty(), + cursorManager.getCursor(StateTestConstants.NAME_NAMESPACE_PAIR1) + ) - Assertions.assertEquals(Optional.empty(), cursorManager.getCursorInfo(StateTestConstants.NAME_NAMESPACE_PAIR2)) - Assertions.assertEquals(Optional.empty(), cursorManager.getCursorField(StateTestConstants.NAME_NAMESPACE_PAIR2)) - Assertions.assertEquals(Optional.empty(), cursorManager.getCursor(StateTestConstants.NAME_NAMESPACE_PAIR2)) + Assertions.assertEquals( + Optional.empty(), + cursorManager.getCursorInfo(StateTestConstants.NAME_NAMESPACE_PAIR2) + ) + Assertions.assertEquals( + Optional.empty(), + cursorManager.getCursorField(StateTestConstants.NAME_NAMESPACE_PAIR2) + ) + Assertions.assertEquals( + Optional.empty(), + cursorManager.getCursor(StateTestConstants.NAME_NAMESPACE_PAIR2) + ) } - private fun createCursorManager(cursorField: String?, - cursor: String?, - nameNamespacePair: AirbyteStreamNameNamespacePair?): CursorManager { + private fun createCursorManager( + cursorField: String?, + cursor: String?, + nameNamespacePair: AirbyteStreamNameNamespacePair? + ): CursorManager { val dbStreamState = StateTestConstants.getState(cursorField, cursor).get() return CursorManager( - StateTestConstants.getCatalog(cursorField).orElse(null), - { setOf(dbStreamState) }, - { obj: DbStreamState? -> obj!!.cursor }, - { obj: DbStreamState? -> obj!!.cursorField }, - CURSOR_RECORD_COUNT_FUNCTION, - { s: DbStreamState? -> nameNamespacePair }, - false) + StateTestConstants.getCatalog(cursorField).orElse(null), + { setOf(dbStreamState) }, + { obj: DbStreamState? -> obj!!.cursor }, + { obj: DbStreamState? -> obj!!.cursorField }, + CURSOR_RECORD_COUNT_FUNCTION, + { s: DbStreamState? -> nameNamespacePair }, + false + ) } companion object { - private val CURSOR_RECORD_COUNT_FUNCTION = Function { stream: DbStreamState? -> + private val CURSOR_RECORD_COUNT_FUNCTION = Function { stream: DbStreamState -> if (stream!!.cursorRecordCount != null) { return@Function stream.cursorRecordCount } else { diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/CursorStateMessageProducerTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/CursorStateMessageProducerTest.kt index 84fb1130ee998..996b5e02c5196 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/CursorStateMessageProducerTest.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/CursorStateMessageProducerTest.kt @@ -11,20 +11,25 @@ import io.airbyte.commons.util.MoreIterators import io.airbyte.protocol.models.Field import io.airbyte.protocol.models.JsonSchemaType import io.airbyte.protocol.models.v0.* -import org.junit.jupiter.api.Assertions -import org.junit.jupiter.api.BeforeEach -import org.junit.jupiter.api.Test -import org.testcontainers.shaded.com.google.common.collect.ImmutableMap import java.sql.SQLException import java.time.Duration import java.util.* import java.util.List +import org.junit.jupiter.api.Assertions +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import org.testcontainers.shaded.com.google.common.collect.ImmutableMap internal class CursorStateMessageProducerTest { private fun createExceptionIterator(): Iterator { - return object : MutableIterator { - val internalMessageIterator: Iterator = MoreIterators.of(RECORD_MESSAGE_1, RECORD_MESSAGE_2, - RECORD_MESSAGE_2, RECORD_MESSAGE_3) + return object : Iterator { + val internalMessageIterator: Iterator = + MoreIterators.of( + RECORD_MESSAGE_1, + RECORD_MESSAGE_2, + RECORD_MESSAGE_2, + RECORD_MESSAGE_3 + ) override fun hasNext(): Boolean { return true @@ -34,10 +39,16 @@ internal class CursorStateMessageProducerTest { if (internalMessageIterator.hasNext()) { return internalMessageIterator.next() } else { - // this line throws a RunTimeException wrapped around a SQLException to mimic the flow of when a + // this line throws a RunTimeException wrapped around a SQLException to mimic + // the flow of when a // SQLException is thrown and wrapped in // StreamingJdbcDatabase#tryAdvance - throw RuntimeException(SQLException("Connection marked broken because of SQLSTATE(080006)", "08006")) + throw RuntimeException( + SQLException( + "Connection marked broken because of SQLSTATE(080006)", + "08006" + ) + ) } } } @@ -48,23 +59,31 @@ internal class CursorStateMessageProducerTest { @BeforeEach fun setup() { val airbyteStream = AirbyteStream().withNamespace(NAMESPACE).withName(STREAM_NAME) - val configuredAirbyteStream = ConfiguredAirbyteStream() + val configuredAirbyteStream = + ConfiguredAirbyteStream() .withStream(airbyteStream) .withCursorField(listOf(UUID_FIELD_NAME)) - stateManager = StreamStateManager(emptyList(), - ConfiguredAirbyteCatalog().withStreams(listOf(configuredAirbyteStream))) + stateManager = + StreamStateManager( + emptyList(), + ConfiguredAirbyteCatalog().withStreams(listOf(configuredAirbyteStream)) + ) } @Test fun testWithoutInitialCursor() { messageIterator = MoreIterators.of(RECORD_MESSAGE_1, RECORD_MESSAGE_2) - val producer = CursorStateMessageProducer( - stateManager, - Optional.empty()) + val producer = CursorStateMessageProducer(stateManager, Optional.empty()) - val iterator: SourceStateIterator<*> = SourceStateIterator(messageIterator, STREAM, producer, StateEmitFrequency(0, Duration.ZERO)) + val iterator: SourceStateIterator<*> = + SourceStateIterator( + messageIterator, + STREAM, + producer, + StateEmitFrequency(0, Duration.ZERO) + ) Assertions.assertEquals(RECORD_MESSAGE_1, iterator.next()) Assertions.assertEquals(RECORD_MESSAGE_2, iterator.next()) @@ -74,16 +93,21 @@ internal class CursorStateMessageProducerTest { @Test fun testWithInitialCursor() { - // record 1 and 2 has smaller cursor value, so at the end, the initial cursor is emitted with 0 + // record 1 and 2 has smaller cursor value, so at the end, the initial cursor is emitted + // with 0 // record count messageIterator = MoreIterators.of(RECORD_MESSAGE_1, RECORD_MESSAGE_2) - val producer = CursorStateMessageProducer( - stateManager, - Optional.of(RECORD_VALUE_5)) + val producer = CursorStateMessageProducer(stateManager, Optional.of(RECORD_VALUE_5)) - val iterator: SourceStateIterator<*> = SourceStateIterator(messageIterator, STREAM, producer, StateEmitFrequency(0, Duration.ZERO)) + val iterator: SourceStateIterator<*> = + SourceStateIterator( + messageIterator, + STREAM, + producer, + StateEmitFrequency(0, Duration.ZERO) + ) Assertions.assertEquals(RECORD_MESSAGE_1, iterator.next()) Assertions.assertEquals(RECORD_MESSAGE_2, iterator.next()) @@ -97,11 +121,15 @@ internal class CursorStateMessageProducerTest { (recordMessage.record.data as ObjectNode).remove(UUID_FIELD_NAME) val messageStream = MoreIterators.of(recordMessage) - val producer = CursorStateMessageProducer( - stateManager, - Optional.empty()) + val producer = CursorStateMessageProducer(stateManager, Optional.empty()) - val iterator: SourceStateIterator<*> = SourceStateIterator(messageStream, STREAM, producer, StateEmitFrequency(0, Duration.ZERO)) + val iterator: SourceStateIterator<*> = + SourceStateIterator( + messageStream, + STREAM, + producer, + StateEmitFrequency(0, Duration.ZERO) + ) Assertions.assertEquals(recordMessage, iterator.next()) // null because no records with a cursor field were replicated for the stream. @@ -113,22 +141,29 @@ internal class CursorStateMessageProducerTest { fun testIteratorCatchesExceptionWhenEmissionFrequencyNonZero() { val exceptionIterator = createExceptionIterator() - val producer = CursorStateMessageProducer( - stateManager, - Optional.of(RECORD_VALUE_1)) + val producer = CursorStateMessageProducer(stateManager, Optional.of(RECORD_VALUE_1)) - val iterator: SourceStateIterator<*> = SourceStateIterator(exceptionIterator, STREAM, producer, StateEmitFrequency(1, Duration.ZERO)) + val iterator: SourceStateIterator<*> = + SourceStateIterator( + exceptionIterator, + STREAM, + producer, + StateEmitFrequency(1, Duration.ZERO) + ) Assertions.assertEquals(RECORD_MESSAGE_1, iterator.next()) Assertions.assertEquals(RECORD_MESSAGE_2, iterator.next()) - // continues to emit RECORD_MESSAGE_2 since cursorField has not changed thus not satisfying the + // continues to emit RECORD_MESSAGE_2 since cursorField has not changed thus not satisfying + // the // condition of "ready" Assertions.assertEquals(RECORD_MESSAGE_2, iterator.next()) Assertions.assertEquals(RECORD_MESSAGE_3, iterator.next()) - // emits the first state message since the iterator has changed cursorFields (2 -> 3) and met the + // emits the first state message since the iterator has changed cursorFields (2 -> 3) and + // met the // frequency minimum of 1 record Assertions.assertEquals(createStateMessage(RECORD_VALUE_2, 2, 4.0), iterator.next()) - // no further records to read since Exception was caught above and marked iterator as endOfData() + // no further records to read since Exception was caught above and marked iterator as + // endOfData() Assertions.assertThrows(FailedRecordIteratorException::class.java) { iterator.hasNext() } } @@ -136,11 +171,15 @@ internal class CursorStateMessageProducerTest { fun testIteratorCatchesExceptionWhenEmissionFrequencyZero() { val exceptionIterator = createExceptionIterator() - val producer = CursorStateMessageProducer( - stateManager, - Optional.of(RECORD_VALUE_1)) + val producer = CursorStateMessageProducer(stateManager, Optional.of(RECORD_VALUE_1)) - val iterator: SourceStateIterator<*> = SourceStateIterator(exceptionIterator, STREAM, producer, StateEmitFrequency(0, Duration.ZERO)) + val iterator: SourceStateIterator<*> = + SourceStateIterator( + exceptionIterator, + STREAM, + producer, + StateEmitFrequency(0, Duration.ZERO) + ) Assertions.assertEquals(RECORD_MESSAGE_1, iterator.next()) Assertions.assertEquals(RECORD_MESSAGE_2, iterator.next()) @@ -152,12 +191,15 @@ internal class CursorStateMessageProducerTest { @Test fun testEmptyStream() { - val producer = CursorStateMessageProducer( - stateManager, - Optional.empty()) + val producer = CursorStateMessageProducer(stateManager, Optional.empty()) val iterator: SourceStateIterator<*> = - SourceStateIterator(Collections.emptyIterator(), STREAM, producer, StateEmitFrequency(1, Duration.ZERO)) + SourceStateIterator( + Collections.emptyIterator(), + STREAM, + producer, + StateEmitFrequency(1, Duration.ZERO) + ) Assertions.assertEquals(EMPTY_STATE_MESSAGE, iterator.next()) Assertions.assertFalse(iterator.hasNext()) @@ -171,11 +213,15 @@ internal class CursorStateMessageProducerTest { // UTF8 null \u0000 is removed from the cursor value in the state message messageIterator = MoreIterators.of(recordMessageWithNull) - val producer = CursorStateMessageProducer( - stateManager, - Optional.empty()) + val producer = CursorStateMessageProducer(stateManager, Optional.empty()) - val iterator: SourceStateIterator<*> = SourceStateIterator(messageIterator, STREAM, producer, StateEmitFrequency(0, Duration.ZERO)) + val iterator: SourceStateIterator<*> = + SourceStateIterator( + messageIterator, + STREAM, + producer, + StateEmitFrequency(0, Duration.ZERO) + ) Assertions.assertEquals(recordMessageWithNull, iterator.next()) Assertions.assertEquals(createStateMessage(RECORD_VALUE_1, 1, 1.0), iterator.next()) @@ -184,13 +230,24 @@ internal class CursorStateMessageProducerTest { @Test fun testStateEmissionFrequency1() { - messageIterator = MoreIterators.of(RECORD_MESSAGE_1, RECORD_MESSAGE_2, RECORD_MESSAGE_3, RECORD_MESSAGE_4, RECORD_MESSAGE_5) + messageIterator = + MoreIterators.of( + RECORD_MESSAGE_1, + RECORD_MESSAGE_2, + RECORD_MESSAGE_3, + RECORD_MESSAGE_4, + RECORD_MESSAGE_5 + ) - val producer = CursorStateMessageProducer( - stateManager, - Optional.empty()) + val producer = CursorStateMessageProducer(stateManager, Optional.empty()) - val iterator: SourceStateIterator<*> = SourceStateIterator(messageIterator, STREAM, producer, StateEmitFrequency(1, Duration.ZERO)) + val iterator: SourceStateIterator<*> = + SourceStateIterator( + messageIterator, + STREAM, + producer, + StateEmitFrequency(1, Duration.ZERO) + ) Assertions.assertEquals(RECORD_MESSAGE_1, iterator.next()) // should emit state 1, but it is unclear whether there will be more @@ -212,13 +269,24 @@ internal class CursorStateMessageProducerTest { @Test fun testStateEmissionFrequency2() { - messageIterator = MoreIterators.of(RECORD_MESSAGE_1, RECORD_MESSAGE_2, RECORD_MESSAGE_3, RECORD_MESSAGE_4, RECORD_MESSAGE_5) + messageIterator = + MoreIterators.of( + RECORD_MESSAGE_1, + RECORD_MESSAGE_2, + RECORD_MESSAGE_3, + RECORD_MESSAGE_4, + RECORD_MESSAGE_5 + ) - val producer = CursorStateMessageProducer( - stateManager, - Optional.empty()) + val producer = CursorStateMessageProducer(stateManager, Optional.empty()) - val iterator: SourceStateIterator<*> = SourceStateIterator(messageIterator, STREAM, producer, StateEmitFrequency(2, Duration.ZERO)) + val iterator: SourceStateIterator<*> = + SourceStateIterator( + messageIterator, + STREAM, + producer, + StateEmitFrequency(2, Duration.ZERO) + ) Assertions.assertEquals(RECORD_MESSAGE_1, iterator.next()) Assertions.assertEquals(RECORD_MESSAGE_2, iterator.next()) @@ -235,13 +303,18 @@ internal class CursorStateMessageProducerTest { @Test fun testStateEmissionWhenInitialCursorIsNotNull() { - messageIterator = MoreIterators.of(RECORD_MESSAGE_2, RECORD_MESSAGE_3, RECORD_MESSAGE_4, RECORD_MESSAGE_5) + messageIterator = + MoreIterators.of(RECORD_MESSAGE_2, RECORD_MESSAGE_3, RECORD_MESSAGE_4, RECORD_MESSAGE_5) - val producer = CursorStateMessageProducer( - stateManager, - Optional.of(RECORD_VALUE_1)) + val producer = CursorStateMessageProducer(stateManager, Optional.of(RECORD_VALUE_1)) - val iterator: SourceStateIterator<*> = SourceStateIterator(messageIterator, STREAM, producer, StateEmitFrequency(1, Duration.ZERO)) + val iterator: SourceStateIterator<*> = + SourceStateIterator( + messageIterator, + STREAM, + producer, + StateEmitFrequency(1, Duration.ZERO) + ) Assertions.assertEquals(RECORD_MESSAGE_2, iterator.next()) Assertions.assertEquals(RECORD_MESSAGE_3, iterator.next()) @@ -254,47 +327,51 @@ internal class CursorStateMessageProducerTest { } /** - * Incremental syncs will sort the table with the cursor field, and emit the max cursor for every N - * records. The purpose is to emit the states frequently, so that if any transient failure occurs - * during a long sync, the next run does not need to start from the beginning, but can resume from - * the last successful intermediate state committed on the destination. The next run will start with - * `cursorField > cursor`. However, it is possible that there are multiple records with the same - * cursor value. If the intermediate state is emitted before all these records have been synced to - * the destination, some of these records may be lost. - * + * Incremental syncs will sort the table with the cursor field, and emit the max cursor for + * every N records. The purpose is to emit the states frequently, so that if any transient + * failure occurs during a long sync, the next run does not need to start from the beginning, + * but can resume from the last successful intermediate state committed on the destination. The + * next run will start with `cursorField > cursor`. However, it is possible that there are + * multiple records with the same cursor value. If the intermediate state is emitted before all + * these records have been synced to the destination, some of these records may be lost. * * Here is an example: * - *
-     * | Record ID | Cursor Field | Other Field | Note                          |
-     * | --------- | ------------ | ----------- | ----------------------------- |
-     * | 1         | F1=16        | F2="abc"    |                               |
-     * | 2         | F1=16        | F2="def"    | <- state emission and failure |
-     * | 3         | F1=16        | F2="ghi"    |                               |
-    
* - * - * If the intermediate state is emitted for record 2 and the sync fails immediately such that the - * cursor value `16` is committed, but only record 1 and 2 are actually synced, the next run will - * start with `F1 > 16` and skip record 3. + *
 | Record ID | Cursor Field | Other Field | Note | | --------- | ------------ |
+     * ----------- | ----------------------------- | | 1 | F1=16 | F2="abc" | | | 2 | F1=16 |
+     * F2="def" | <- state emission and failure | | 3 | F1=16 | F2="ghi" | | 
* * + * If the intermediate state is emitted for record 2 and the sync fails immediately such that + * the cursor value `16` is committed, but only record 1 and 2 are actually synced, the next run + * will start with `F1 > 16` and skip record 3. * - * So intermediate state emission should only happen when all records with the same cursor value has - * been synced to destination. Reference: + * So intermediate state emission should only happen when all records with the same cursor value + * has been synced to destination. Reference: * [link](https://github.com/airbytehq/airbyte/issues/15427) */ @Test fun testStateEmissionForRecordsSharingSameCursorValue() { - messageIterator = MoreIterators.of( - RECORD_MESSAGE_2, RECORD_MESSAGE_2, - RECORD_MESSAGE_3, RECORD_MESSAGE_3, RECORD_MESSAGE_3, + messageIterator = + MoreIterators.of( + RECORD_MESSAGE_2, + RECORD_MESSAGE_2, + RECORD_MESSAGE_3, + RECORD_MESSAGE_3, + RECORD_MESSAGE_3, RECORD_MESSAGE_4, - RECORD_MESSAGE_5, RECORD_MESSAGE_5) + RECORD_MESSAGE_5, + RECORD_MESSAGE_5 + ) - val producer = CursorStateMessageProducer( - stateManager, - Optional.of(RECORD_VALUE_1)) + val producer = CursorStateMessageProducer(stateManager, Optional.of(RECORD_VALUE_1)) - val iterator: SourceStateIterator<*> = SourceStateIterator(messageIterator, STREAM, producer, StateEmitFrequency(1, Duration.ZERO)) + val iterator: SourceStateIterator<*> = + SourceStateIterator( + messageIterator, + STREAM, + producer, + StateEmitFrequency(1, Duration.ZERO) + ) Assertions.assertEquals(RECORD_MESSAGE_2, iterator.next()) Assertions.assertEquals(RECORD_MESSAGE_2, iterator.next()) @@ -315,18 +392,30 @@ internal class CursorStateMessageProducerTest { @Test fun testStateEmissionForRecordsSharingSameCursorValueButDifferentStatsCount() { - messageIterator = MoreIterators.of( - RECORD_MESSAGE_2, RECORD_MESSAGE_2, - RECORD_MESSAGE_2, RECORD_MESSAGE_2, - RECORD_MESSAGE_3, RECORD_MESSAGE_3, RECORD_MESSAGE_3, + messageIterator = + MoreIterators.of( + RECORD_MESSAGE_2, + RECORD_MESSAGE_2, + RECORD_MESSAGE_2, + RECORD_MESSAGE_2, + RECORD_MESSAGE_3, + RECORD_MESSAGE_3, RECORD_MESSAGE_3, - RECORD_MESSAGE_3, RECORD_MESSAGE_3, RECORD_MESSAGE_3) + RECORD_MESSAGE_3, + RECORD_MESSAGE_3, + RECORD_MESSAGE_3, + RECORD_MESSAGE_3 + ) - val producer = CursorStateMessageProducer( - stateManager, - Optional.of(RECORD_VALUE_1)) + val producer = CursorStateMessageProducer(stateManager, Optional.of(RECORD_VALUE_1)) - val iterator: SourceStateIterator<*> = SourceStateIterator(messageIterator, STREAM, producer, StateEmitFrequency(10, Duration.ZERO)) + val iterator: SourceStateIterator<*> = + SourceStateIterator( + messageIterator, + STREAM, + producer, + StateEmitFrequency(10, Duration.ZERO) + ) Assertions.assertEquals(RECORD_MESSAGE_2, iterator.next()) Assertions.assertEquals(RECORD_MESSAGE_2, iterator.next()) @@ -351,10 +440,12 @@ internal class CursorStateMessageProducerTest { private const val STREAM_NAME = "shoes" private const val UUID_FIELD_NAME = "ascending_inventory_uuid" - private val STREAM: ConfiguredAirbyteStream = CatalogHelpers.createConfiguredAirbyteStream( - STREAM_NAME, - NAMESPACE, - Field.of(UUID_FIELD_NAME, JsonSchemaType.STRING)) + private val STREAM: ConfiguredAirbyteStream = + CatalogHelpers.createConfiguredAirbyteStream( + STREAM_NAME, + NAMESPACE, + Field.of(UUID_FIELD_NAME, JsonSchemaType.STRING) + ) .withCursorField(List.of(UUID_FIELD_NAME)) private val EMPTY_STATE_MESSAGE = createEmptyStateMessage(0.0) @@ -376,13 +467,20 @@ internal class CursorStateMessageProducerTest { private fun createRecordMessage(recordValue: String): AirbyteMessage { return AirbyteMessage() - .withType(AirbyteMessage.Type.RECORD) - .withRecord(AirbyteRecordMessage() - .withData(Jsons.jsonNode(ImmutableMap.of(UUID_FIELD_NAME, recordValue)))) + .withType(AirbyteMessage.Type.RECORD) + .withRecord( + AirbyteRecordMessage() + .withData(Jsons.jsonNode(ImmutableMap.of(UUID_FIELD_NAME, recordValue))) + ) } - private fun createStateMessage(recordValue: String, cursorRecordCount: Long, statsRecordCount: Double): AirbyteMessage { - val dbStreamState = DbStreamState() + private fun createStateMessage( + recordValue: String, + cursorRecordCount: Long, + statsRecordCount: Double + ): AirbyteMessage { + val dbStreamState = + DbStreamState() .withCursorField(listOf(UUID_FIELD_NAME)) .withCursor(recordValue) .withStreamName(STREAM_NAME) @@ -392,34 +490,51 @@ internal class CursorStateMessageProducerTest { } val dbState = DbState().withCdc(false).withStreams(listOf(dbStreamState)) return AirbyteMessage() - .withType(AirbyteMessage.Type.STATE) - .withState(AirbyteStateMessage() - .withType(AirbyteStateMessage.AirbyteStateType.STREAM) - .withStream(AirbyteStreamState() - .withStreamDescriptor(StreamDescriptor().withName(STREAM_NAME).withNamespace(NAMESPACE)) - .withStreamState(Jsons.jsonNode(dbStreamState))) - .withData(Jsons.jsonNode(dbState)) - .withSourceStats(AirbyteStateStats().withRecordCount(statsRecordCount))) + .withType(AirbyteMessage.Type.STATE) + .withState( + AirbyteStateMessage() + .withType(AirbyteStateMessage.AirbyteStateType.STREAM) + .withStream( + AirbyteStreamState() + .withStreamDescriptor( + StreamDescriptor() + .withName(STREAM_NAME) + .withNamespace(NAMESPACE) + ) + .withStreamState(Jsons.jsonNode(dbStreamState)) + ) + .withData(Jsons.jsonNode(dbState)) + .withSourceStats(AirbyteStateStats().withRecordCount(statsRecordCount)) + ) } private fun createEmptyStateMessage(statsRecordCount: Double): AirbyteMessage { - val dbStreamState = DbStreamState() + val dbStreamState = + DbStreamState() .withCursorField(listOf(UUID_FIELD_NAME)) .withStreamName(STREAM_NAME) .withStreamNamespace(NAMESPACE) val dbState = DbState().withCdc(false).withStreams(listOf(dbStreamState)) return AirbyteMessage() - .withType(AirbyteMessage.Type.STATE) - .withState(AirbyteStateMessage() - .withType(AirbyteStateMessage.AirbyteStateType.STREAM) - .withStream(AirbyteStreamState() - .withStreamDescriptor(StreamDescriptor().withName(STREAM_NAME).withNamespace(NAMESPACE)) - .withStreamState(Jsons.jsonNode(dbStreamState))) - .withData(Jsons.jsonNode(dbState)) - .withSourceStats(AirbyteStateStats().withRecordCount(statsRecordCount))) + .withType(AirbyteMessage.Type.STATE) + .withState( + AirbyteStateMessage() + .withType(AirbyteStateMessage.AirbyteStateType.STREAM) + .withStream( + AirbyteStreamState() + .withStreamDescriptor( + StreamDescriptor() + .withName(STREAM_NAME) + .withNamespace(NAMESPACE) + ) + .withStreamState(Jsons.jsonNode(dbStreamState)) + ) + .withData(Jsons.jsonNode(dbState)) + .withSourceStats(AirbyteStateStats().withRecordCount(statsRecordCount)) + ) } - private var messageIterator: Iterator? = null + private lateinit var messageIterator: Iterator } } diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/GlobalStateManagerTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/GlobalStateManagerTest.kt index 374f11da880de..ec7521360f37d 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/GlobalStateManagerTest.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/GlobalStateManagerTest.kt @@ -8,116 +8,205 @@ import io.airbyte.cdk.integrations.source.relationaldb.models.DbState import io.airbyte.cdk.integrations.source.relationaldb.models.DbStreamState import io.airbyte.commons.json.Jsons import io.airbyte.protocol.models.v0.* -import org.junit.jupiter.api.Assertions -import org.junit.jupiter.api.Disabled -import org.junit.jupiter.api.Test -import org.mockito.Mockito import java.util.* import java.util.List import java.util.Map import java.util.stream.Collectors +import org.junit.jupiter.api.Assertions +import org.junit.jupiter.api.Disabled +import org.junit.jupiter.api.Test +import org.mockito.Mockito -/** - * Test suite for the [GlobalStateManager] class. - */ +/** Test suite for the [GlobalStateManager] class. */ class GlobalStateManagerTest { @Test fun testCdcStateManager() { val catalog = Mockito.mock(ConfiguredAirbyteCatalog::class.java) val cdcState = CdcState().withState(Jsons.jsonNode(Map.of("foo", "bar", "baz", 5))) - val globalState = AirbyteGlobalState().withSharedState(Jsons.jsonNode(cdcState)) - .withStreamStates(List.of(AirbyteStreamState().withStreamDescriptor(StreamDescriptor().withNamespace("namespace").withName("name")) - .withStreamState(Jsons.jsonNode(DbStreamState())))) + val globalState = + AirbyteGlobalState() + .withSharedState(Jsons.jsonNode(cdcState)) + .withStreamStates( + List.of( + AirbyteStreamState() + .withStreamDescriptor( + StreamDescriptor().withNamespace("namespace").withName("name") + ) + .withStreamState(Jsons.jsonNode(DbStreamState())) + ) + ) val stateManager: StateManager = - GlobalStateManager(AirbyteStateMessage().withType(AirbyteStateMessage.AirbyteStateType.GLOBAL).withGlobal(globalState), catalog) + GlobalStateManager( + AirbyteStateMessage() + .withType(AirbyteStateMessage.AirbyteStateType.GLOBAL) + .withGlobal(globalState), + catalog + ) Assertions.assertNotNull(stateManager.cdcStateManager) Assertions.assertEquals(cdcState, stateManager.cdcStateManager.cdcState) - Assertions.assertEquals(1, stateManager.cdcStateManager.initialStreamsSynced.size) - Assertions.assertTrue(stateManager.cdcStateManager.initialStreamsSynced.contains(AirbyteStreamNameNamespacePair("name", "namespace"))) + Assertions.assertEquals(1, stateManager.cdcStateManager.initialStreamsSynced!!.size) + Assertions.assertTrue( + stateManager.cdcStateManager.initialStreamsSynced!!.contains( + AirbyteStreamNameNamespacePair("name", "namespace") + ) + ) } @Test fun testToStateFromLegacyState() { - val catalog = ConfiguredAirbyteCatalog() - .withStreams(List.of( + val catalog = + ConfiguredAirbyteCatalog() + .withStreams( + List.of( ConfiguredAirbyteStream() - .withStream(AirbyteStream().withName(StateTestConstants.STREAM_NAME1).withNamespace(StateTestConstants.NAMESPACE)) - .withCursorField(List.of(StateTestConstants.CURSOR_FIELD1)), + .withStream( + AirbyteStream() + .withName(StateTestConstants.STREAM_NAME1) + .withNamespace(StateTestConstants.NAMESPACE) + ) + .withCursorField(List.of(StateTestConstants.CURSOR_FIELD1)), ConfiguredAirbyteStream() - .withStream(AirbyteStream().withName(StateTestConstants.STREAM_NAME2).withNamespace(StateTestConstants.NAMESPACE)) - .withCursorField(List.of(StateTestConstants.CURSOR_FIELD2)), + .withStream( + AirbyteStream() + .withName(StateTestConstants.STREAM_NAME2) + .withNamespace(StateTestConstants.NAMESPACE) + ) + .withCursorField(List.of(StateTestConstants.CURSOR_FIELD2)), ConfiguredAirbyteStream() - .withStream(AirbyteStream().withName(StateTestConstants.STREAM_NAME3).withNamespace(StateTestConstants.NAMESPACE)))) + .withStream( + AirbyteStream() + .withName(StateTestConstants.STREAM_NAME3) + .withNamespace(StateTestConstants.NAMESPACE) + ) + ) + ) val cdcState = CdcState().withState(Jsons.jsonNode(Map.of("foo", "bar", "baz", 5))) - val dbState = DbState() + val dbState = + DbState() .withCdc(true) .withCdcState(cdcState) - .withStreams(List.of( - DbStreamState() + .withStreams( + List.of( + DbStreamState() .withStreamName(StateTestConstants.STREAM_NAME1) .withStreamNamespace(StateTestConstants.NAMESPACE) .withCursorField(List.of(StateTestConstants.CURSOR_FIELD1)) .withCursor("a"), - DbStreamState() + DbStreamState() .withStreamName(StateTestConstants.STREAM_NAME2) .withStreamNamespace(StateTestConstants.NAMESPACE) .withCursorField(List.of(StateTestConstants.CURSOR_FIELD2)), - DbStreamState() + DbStreamState() .withStreamName(StateTestConstants.STREAM_NAME3) - .withStreamNamespace(StateTestConstants.NAMESPACE)) - .stream().sorted(Comparator.comparing { obj: DbStreamState -> obj.streamName }).collect(Collectors.toList())) - val stateManager: StateManager = GlobalStateManager(AirbyteStateMessage().withData(Jsons.jsonNode(dbState)), catalog) + .withStreamNamespace(StateTestConstants.NAMESPACE) + ) + .stream() + .sorted(Comparator.comparing { obj: DbStreamState -> obj.streamName }) + .collect(Collectors.toList()) + ) + val stateManager: StateManager = + GlobalStateManager(AirbyteStateMessage().withData(Jsons.jsonNode(dbState)), catalog) val expectedRecordCount = 19L - val expectedDbState = DbState() + val expectedDbState = + DbState() .withCdc(true) .withCdcState(cdcState) - .withStreams(List.of( - DbStreamState() + .withStreams( + List.of( + DbStreamState() .withStreamName(StateTestConstants.STREAM_NAME1) .withStreamNamespace(StateTestConstants.NAMESPACE) .withCursorField(List.of(StateTestConstants.CURSOR_FIELD1)) .withCursor("a") .withCursorRecordCount(expectedRecordCount), - DbStreamState() + DbStreamState() .withStreamName(StateTestConstants.STREAM_NAME2) .withStreamNamespace(StateTestConstants.NAMESPACE) .withCursorField(List.of(StateTestConstants.CURSOR_FIELD2)), - DbStreamState() + DbStreamState() .withStreamName(StateTestConstants.STREAM_NAME3) - .withStreamNamespace(StateTestConstants.NAMESPACE)) - .stream().sorted(Comparator.comparing { obj: DbStreamState -> obj.streamName }).collect(Collectors.toList())) + .withStreamNamespace(StateTestConstants.NAMESPACE) + ) + .stream() + .sorted(Comparator.comparing { obj: DbStreamState -> obj.streamName }) + .collect(Collectors.toList()) + ) - val expectedGlobalState = AirbyteGlobalState() + val expectedGlobalState = + AirbyteGlobalState() .withSharedState(Jsons.jsonNode(cdcState)) - .withStreamStates(List.of( - AirbyteStreamState() - .withStreamDescriptor(StreamDescriptor().withName(StateTestConstants.STREAM_NAME1).withNamespace(StateTestConstants.NAMESPACE)) - .withStreamState(Jsons.jsonNode(DbStreamState() - .withStreamName(StateTestConstants.STREAM_NAME1) - .withStreamNamespace(StateTestConstants.NAMESPACE) - .withCursorField(List.of(StateTestConstants.CURSOR_FIELD1)) - .withCursor("a") - .withCursorRecordCount(expectedRecordCount))), - AirbyteStreamState() - .withStreamDescriptor(StreamDescriptor().withName(StateTestConstants.STREAM_NAME2).withNamespace(StateTestConstants.NAMESPACE)) - .withStreamState(Jsons.jsonNode(DbStreamState() - .withStreamName(StateTestConstants.STREAM_NAME2) - .withStreamNamespace(StateTestConstants.NAMESPACE) - .withCursorField(List.of(StateTestConstants.CURSOR_FIELD2)))), - AirbyteStreamState() - .withStreamDescriptor(StreamDescriptor().withName(StateTestConstants.STREAM_NAME3).withNamespace(StateTestConstants.NAMESPACE)) - .withStreamState(Jsons.jsonNode(DbStreamState() - .withStreamName(StateTestConstants.STREAM_NAME3) - .withStreamNamespace(StateTestConstants.NAMESPACE)))) - .stream().sorted(Comparator.comparing { o: AirbyteStreamState -> o.streamDescriptor.name }).collect(Collectors.toList())) - val expected = AirbyteStateMessage() + .withStreamStates( + List.of( + AirbyteStreamState() + .withStreamDescriptor( + StreamDescriptor() + .withName(StateTestConstants.STREAM_NAME1) + .withNamespace(StateTestConstants.NAMESPACE) + ) + .withStreamState( + Jsons.jsonNode( + DbStreamState() + .withStreamName(StateTestConstants.STREAM_NAME1) + .withStreamNamespace(StateTestConstants.NAMESPACE) + .withCursorField( + List.of(StateTestConstants.CURSOR_FIELD1) + ) + .withCursor("a") + .withCursorRecordCount(expectedRecordCount) + ) + ), + AirbyteStreamState() + .withStreamDescriptor( + StreamDescriptor() + .withName(StateTestConstants.STREAM_NAME2) + .withNamespace(StateTestConstants.NAMESPACE) + ) + .withStreamState( + Jsons.jsonNode( + DbStreamState() + .withStreamName(StateTestConstants.STREAM_NAME2) + .withStreamNamespace(StateTestConstants.NAMESPACE) + .withCursorField( + List.of(StateTestConstants.CURSOR_FIELD2) + ) + ) + ), + AirbyteStreamState() + .withStreamDescriptor( + StreamDescriptor() + .withName(StateTestConstants.STREAM_NAME3) + .withNamespace(StateTestConstants.NAMESPACE) + ) + .withStreamState( + Jsons.jsonNode( + DbStreamState() + .withStreamName(StateTestConstants.STREAM_NAME3) + .withStreamNamespace(StateTestConstants.NAMESPACE) + ) + ) + ) + .stream() + .sorted( + Comparator.comparing { o: AirbyteStreamState -> + o.streamDescriptor.name + } + ) + .collect(Collectors.toList()) + ) + val expected = + AirbyteStateMessage() .withData(Jsons.jsonNode(expectedDbState)) .withGlobal(expectedGlobalState) .withType(AirbyteStateMessage.AirbyteStateType.GLOBAL) - val actualFirstEmission = stateManager.updateAndEmit(StateTestConstants.NAME_NAMESPACE_PAIR1, "a", expectedRecordCount) + val actualFirstEmission = + stateManager.updateAndEmit( + StateTestConstants.NAME_NAMESPACE_PAIR1, + "a", + expectedRecordCount + ) Assertions.assertEquals(expected, actualFirstEmission) } @@ -126,84 +215,161 @@ class GlobalStateManagerTest { @Disabled("Failing test.") @Test fun testToState() { - val catalog = ConfiguredAirbyteCatalog() - .withStreams(List.of( + val catalog = + ConfiguredAirbyteCatalog() + .withStreams( + List.of( ConfiguredAirbyteStream() - .withStream(AirbyteStream().withName(StateTestConstants.STREAM_NAME1).withNamespace(StateTestConstants.NAMESPACE)) - .withCursorField(List.of(StateTestConstants.CURSOR_FIELD1)), + .withStream( + AirbyteStream() + .withName(StateTestConstants.STREAM_NAME1) + .withNamespace(StateTestConstants.NAMESPACE) + ) + .withCursorField(List.of(StateTestConstants.CURSOR_FIELD1)), ConfiguredAirbyteStream() - .withStream(AirbyteStream().withName(StateTestConstants.STREAM_NAME2).withNamespace(StateTestConstants.NAMESPACE)) - .withCursorField(List.of(StateTestConstants.CURSOR_FIELD2)), + .withStream( + AirbyteStream() + .withName(StateTestConstants.STREAM_NAME2) + .withNamespace(StateTestConstants.NAMESPACE) + ) + .withCursorField(List.of(StateTestConstants.CURSOR_FIELD2)), ConfiguredAirbyteStream() - .withStream(AirbyteStream().withName(StateTestConstants.STREAM_NAME3).withNamespace(StateTestConstants.NAMESPACE)))) + .withStream( + AirbyteStream() + .withName(StateTestConstants.STREAM_NAME3) + .withNamespace(StateTestConstants.NAMESPACE) + ) + ) + ) val cdcState = CdcState().withState(Jsons.jsonNode(Map.of("foo", "bar", "baz", 5))) - val globalState = AirbyteGlobalState().withSharedState(Jsons.jsonNode(DbState())).withStreamStates( - List.of(AirbyteStreamState().withStreamDescriptor(StreamDescriptor()).withStreamState(Jsons.jsonNode(DbStreamState())))) + val globalState = + AirbyteGlobalState() + .withSharedState(Jsons.jsonNode(DbState())) + .withStreamStates( + List.of( + AirbyteStreamState() + .withStreamDescriptor(StreamDescriptor()) + .withStreamState(Jsons.jsonNode(DbStreamState())) + ) + ) val stateManager: StateManager = - GlobalStateManager(AirbyteStateMessage().withType(AirbyteStateMessage.AirbyteStateType.GLOBAL).withGlobal(globalState), catalog) + GlobalStateManager( + AirbyteStateMessage() + .withType(AirbyteStateMessage.AirbyteStateType.GLOBAL) + .withGlobal(globalState), + catalog + ) stateManager.cdcStateManager.cdcState = cdcState - val expectedDbState = DbState() + val expectedDbState = + DbState() .withCdc(true) .withCdcState(cdcState) - .withStreams(List.of( - DbStreamState() + .withStreams( + List.of( + DbStreamState() .withStreamName(StateTestConstants.STREAM_NAME1) .withStreamNamespace(StateTestConstants.NAMESPACE) .withCursorField(List.of(StateTestConstants.CURSOR_FIELD1)) .withCursor("a") .withCursorRecordCount(1L), - DbStreamState() + DbStreamState() .withStreamName(StateTestConstants.STREAM_NAME2) .withStreamNamespace(StateTestConstants.NAMESPACE) .withCursorField(List.of(StateTestConstants.CURSOR_FIELD2)), - DbStreamState() + DbStreamState() .withStreamName(StateTestConstants.STREAM_NAME3) - .withStreamNamespace(StateTestConstants.NAMESPACE)) - .stream().sorted(Comparator.comparing { obj: DbStreamState -> obj.streamName }).collect(Collectors.toList())) + .withStreamNamespace(StateTestConstants.NAMESPACE) + ) + .stream() + .sorted(Comparator.comparing { obj: DbStreamState -> obj.streamName }) + .collect(Collectors.toList()) + ) - val expectedGlobalState = AirbyteGlobalState() + val expectedGlobalState = + AirbyteGlobalState() .withSharedState(Jsons.jsonNode(cdcState)) - .withStreamStates(List.of( - AirbyteStreamState() - .withStreamDescriptor(StreamDescriptor().withName(StateTestConstants.STREAM_NAME1).withNamespace(StateTestConstants.NAMESPACE)) - .withStreamState(Jsons.jsonNode(DbStreamState() - .withStreamName(StateTestConstants.STREAM_NAME1) - .withStreamNamespace(StateTestConstants.NAMESPACE) - .withCursorField(List.of(StateTestConstants.CURSOR_FIELD1)) - .withCursor("a") - .withCursorRecordCount(1L))), - AirbyteStreamState() - .withStreamDescriptor(StreamDescriptor().withName(StateTestConstants.STREAM_NAME2).withNamespace(StateTestConstants.NAMESPACE)) - .withStreamState(Jsons.jsonNode(DbStreamState() - .withStreamName(StateTestConstants.STREAM_NAME2) - .withStreamNamespace(StateTestConstants.NAMESPACE) - .withCursorField(List.of(StateTestConstants.CURSOR_FIELD2)))), - AirbyteStreamState() - .withStreamDescriptor(StreamDescriptor().withName(StateTestConstants.STREAM_NAME3).withNamespace(StateTestConstants.NAMESPACE)) - .withStreamState(Jsons.jsonNode(DbStreamState() - .withStreamName(StateTestConstants.STREAM_NAME3) - .withStreamNamespace(StateTestConstants.NAMESPACE)))) - .stream().sorted(Comparator.comparing { o: AirbyteStreamState -> o.streamDescriptor.name }).collect(Collectors.toList())) - val expected = AirbyteStateMessage() + .withStreamStates( + List.of( + AirbyteStreamState() + .withStreamDescriptor( + StreamDescriptor() + .withName(StateTestConstants.STREAM_NAME1) + .withNamespace(StateTestConstants.NAMESPACE) + ) + .withStreamState( + Jsons.jsonNode( + DbStreamState() + .withStreamName(StateTestConstants.STREAM_NAME1) + .withStreamNamespace(StateTestConstants.NAMESPACE) + .withCursorField( + List.of(StateTestConstants.CURSOR_FIELD1) + ) + .withCursor("a") + .withCursorRecordCount(1L) + ) + ), + AirbyteStreamState() + .withStreamDescriptor( + StreamDescriptor() + .withName(StateTestConstants.STREAM_NAME2) + .withNamespace(StateTestConstants.NAMESPACE) + ) + .withStreamState( + Jsons.jsonNode( + DbStreamState() + .withStreamName(StateTestConstants.STREAM_NAME2) + .withStreamNamespace(StateTestConstants.NAMESPACE) + .withCursorField( + List.of(StateTestConstants.CURSOR_FIELD2) + ) + ) + ), + AirbyteStreamState() + .withStreamDescriptor( + StreamDescriptor() + .withName(StateTestConstants.STREAM_NAME3) + .withNamespace(StateTestConstants.NAMESPACE) + ) + .withStreamState( + Jsons.jsonNode( + DbStreamState() + .withStreamName(StateTestConstants.STREAM_NAME3) + .withStreamNamespace(StateTestConstants.NAMESPACE) + ) + ) + ) + .stream() + .sorted( + Comparator.comparing { o: AirbyteStreamState -> + o.streamDescriptor.name + } + ) + .collect(Collectors.toList()) + ) + val expected = + AirbyteStateMessage() .withData(Jsons.jsonNode(expectedDbState)) .withGlobal(expectedGlobalState) .withType(AirbyteStateMessage.AirbyteStateType.GLOBAL) - val actualFirstEmission = stateManager.updateAndEmit(StateTestConstants.NAME_NAMESPACE_PAIR1, "a", 1L) + val actualFirstEmission = + stateManager.updateAndEmit(StateTestConstants.NAME_NAMESPACE_PAIR1, "a", 1L) Assertions.assertEquals(expected, actualFirstEmission) } @Test fun testToStateWithNoState() { val catalog = ConfiguredAirbyteCatalog() - val stateManager: StateManager = - GlobalStateManager(AirbyteStateMessage(), catalog) + val stateManager: StateManager = GlobalStateManager(AirbyteStateMessage(), catalog) val airbyteStateMessage = stateManager.toState(Optional.empty()) Assertions.assertNotNull(airbyteStateMessage) - Assertions.assertEquals(AirbyteStateMessage.AirbyteStateType.GLOBAL, airbyteStateMessage.type) + Assertions.assertEquals( + AirbyteStateMessage.AirbyteStateType.GLOBAL, + airbyteStateMessage!!.type + ) Assertions.assertEquals(0, airbyteStateMessage.global.streamStates.size) } @@ -211,14 +377,32 @@ class GlobalStateManagerTest { fun testCdcStateManagerLegacyState() { val catalog = Mockito.mock(ConfiguredAirbyteCatalog::class.java) val cdcState = CdcState().withState(Jsons.jsonNode(Map.of("foo", "bar", "baz", 5))) - val dbState = DbState().withCdcState(CdcState().withState(Jsons.jsonNode(cdcState))) - .withStreams(List - .of(DbStreamState().withStreamName("name").withStreamNamespace("namespace").withCursor("").withCursorField(emptyList()))) + val dbState = + DbState() + .withCdcState(CdcState().withState(Jsons.jsonNode(cdcState))) + .withStreams( + List.of( + DbStreamState() + .withStreamName("name") + .withStreamNamespace("namespace") + .withCursor("") + .withCursorField(emptyList()) + ) + ) .withCdc(true) val stateManager: StateManager = - GlobalStateManager(AirbyteStateMessage().withType(AirbyteStateMessage.AirbyteStateType.LEGACY).withData(Jsons.jsonNode(dbState)), catalog) + GlobalStateManager( + AirbyteStateMessage() + .withType(AirbyteStateMessage.AirbyteStateType.LEGACY) + .withData(Jsons.jsonNode(dbState)), + catalog + ) Assertions.assertNotNull(stateManager.cdcStateManager) - Assertions.assertEquals(1, stateManager.cdcStateManager.initialStreamsSynced.size) - Assertions.assertTrue(stateManager.cdcStateManager.initialStreamsSynced.contains(AirbyteStreamNameNamespacePair("name", "namespace"))) + Assertions.assertEquals(1, stateManager.cdcStateManager.initialStreamsSynced!!.size) + Assertions.assertTrue( + stateManager.cdcStateManager.initialStreamsSynced!!.contains( + AirbyteStreamNameNamespacePair("name", "namespace") + ) + ) } } diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/LegacyStateManagerTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/LegacyStateManagerTest.kt index e4aac20651381..b6a585713b956 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/LegacyStateManagerTest.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/LegacyStateManagerTest.kt @@ -11,144 +11,355 @@ import io.airbyte.protocol.models.v0.AirbyteStateMessage import io.airbyte.protocol.models.v0.AirbyteStream import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog import io.airbyte.protocol.models.v0.ConfiguredAirbyteStream -import org.junit.jupiter.api.Assertions -import org.junit.jupiter.api.Test -import org.mockito.Mockito import java.util.* import java.util.List import java.util.Map import java.util.stream.Collectors +import org.junit.jupiter.api.Assertions +import org.junit.jupiter.api.Test +import org.mockito.Mockito -/** - * Test suite for the [LegacyStateManager] class. - */ +/** Test suite for the [LegacyStateManager] class. */ class LegacyStateManagerTest { @Test fun testGetters() { - val state = DbState().withStreams(List.of( - DbStreamState().withStreamName(StateTestConstants.STREAM_NAME1).withStreamNamespace(StateTestConstants.NAMESPACE).withCursorField(List.of(StateTestConstants.CURSOR_FIELD1)) - .withCursor(StateTestConstants.CURSOR), - DbStreamState().withStreamName(StateTestConstants.STREAM_NAME2).withStreamNamespace(StateTestConstants.NAMESPACE))) + val state = + DbState() + .withStreams( + List.of( + DbStreamState() + .withStreamName(StateTestConstants.STREAM_NAME1) + .withStreamNamespace(StateTestConstants.NAMESPACE) + .withCursorField(List.of(StateTestConstants.CURSOR_FIELD1)) + .withCursor(StateTestConstants.CURSOR), + DbStreamState() + .withStreamName(StateTestConstants.STREAM_NAME2) + .withStreamNamespace(StateTestConstants.NAMESPACE) + ) + ) - val catalog = ConfiguredAirbyteCatalog() - .withStreams(List.of( + val catalog = + ConfiguredAirbyteCatalog() + .withStreams( + List.of( ConfiguredAirbyteStream() - .withStream(AirbyteStream().withName(StateTestConstants.STREAM_NAME1).withNamespace(StateTestConstants.NAMESPACE)) - .withCursorField(List.of(StateTestConstants.CURSOR_FIELD1)), + .withStream( + AirbyteStream() + .withName(StateTestConstants.STREAM_NAME1) + .withNamespace(StateTestConstants.NAMESPACE) + ) + .withCursorField(List.of(StateTestConstants.CURSOR_FIELD1)), ConfiguredAirbyteStream() - .withStream(AirbyteStream().withName(StateTestConstants.STREAM_NAME2).withNamespace(StateTestConstants.NAMESPACE)))) + .withStream( + AirbyteStream() + .withName(StateTestConstants.STREAM_NAME2) + .withNamespace(StateTestConstants.NAMESPACE) + ) + ) + ) val stateManager: StateManager = LegacyStateManager(state, catalog) - Assertions.assertEquals(Optional.of(StateTestConstants.CURSOR_FIELD1), stateManager.getOriginalCursorField(StateTestConstants.NAME_NAMESPACE_PAIR1)) - Assertions.assertEquals(Optional.of(StateTestConstants.CURSOR), stateManager.getOriginalCursor(StateTestConstants.NAME_NAMESPACE_PAIR1)) - Assertions.assertEquals(Optional.of(StateTestConstants.CURSOR_FIELD1), stateManager.getCursorField(StateTestConstants.NAME_NAMESPACE_PAIR1)) - Assertions.assertEquals(Optional.of(StateTestConstants.CURSOR), stateManager.getCursor(StateTestConstants.NAME_NAMESPACE_PAIR1)) + Assertions.assertEquals( + Optional.of(StateTestConstants.CURSOR_FIELD1), + stateManager.getOriginalCursorField(StateTestConstants.NAME_NAMESPACE_PAIR1) + ) + Assertions.assertEquals( + Optional.of(StateTestConstants.CURSOR), + stateManager.getOriginalCursor(StateTestConstants.NAME_NAMESPACE_PAIR1) + ) + Assertions.assertEquals( + Optional.of(StateTestConstants.CURSOR_FIELD1), + stateManager.getCursorField(StateTestConstants.NAME_NAMESPACE_PAIR1) + ) + Assertions.assertEquals( + Optional.of(StateTestConstants.CURSOR), + stateManager.getCursor(StateTestConstants.NAME_NAMESPACE_PAIR1) + ) - Assertions.assertEquals(Optional.empty(), stateManager.getOriginalCursorField(StateTestConstants.NAME_NAMESPACE_PAIR2)) - Assertions.assertEquals(Optional.empty(), stateManager.getOriginalCursor(StateTestConstants.NAME_NAMESPACE_PAIR2)) - Assertions.assertEquals(Optional.empty(), stateManager.getCursorField(StateTestConstants.NAME_NAMESPACE_PAIR2)) - Assertions.assertEquals(Optional.empty(), stateManager.getCursor(StateTestConstants.NAME_NAMESPACE_PAIR2)) + Assertions.assertEquals( + Optional.empty(), + stateManager.getOriginalCursorField(StateTestConstants.NAME_NAMESPACE_PAIR2) + ) + Assertions.assertEquals( + Optional.empty(), + stateManager.getOriginalCursor(StateTestConstants.NAME_NAMESPACE_PAIR2) + ) + Assertions.assertEquals( + Optional.empty(), + stateManager.getCursorField(StateTestConstants.NAME_NAMESPACE_PAIR2) + ) + Assertions.assertEquals( + Optional.empty(), + stateManager.getCursor(StateTestConstants.NAME_NAMESPACE_PAIR2) + ) } @Test fun testToState() { - val catalog = ConfiguredAirbyteCatalog() - .withStreams(List.of( + val catalog = + ConfiguredAirbyteCatalog() + .withStreams( + List.of( ConfiguredAirbyteStream() - .withStream(AirbyteStream().withName(StateTestConstants.STREAM_NAME1).withNamespace(StateTestConstants.NAMESPACE)) - .withCursorField(List.of(StateTestConstants.CURSOR_FIELD1)), + .withStream( + AirbyteStream() + .withName(StateTestConstants.STREAM_NAME1) + .withNamespace(StateTestConstants.NAMESPACE) + ) + .withCursorField(List.of(StateTestConstants.CURSOR_FIELD1)), ConfiguredAirbyteStream() - .withStream(AirbyteStream().withName(StateTestConstants.STREAM_NAME2).withNamespace(StateTestConstants.NAMESPACE)) - .withCursorField(List.of(StateTestConstants.CURSOR_FIELD2)), + .withStream( + AirbyteStream() + .withName(StateTestConstants.STREAM_NAME2) + .withNamespace(StateTestConstants.NAMESPACE) + ) + .withCursorField(List.of(StateTestConstants.CURSOR_FIELD2)), ConfiguredAirbyteStream() - .withStream(AirbyteStream().withName(StateTestConstants.STREAM_NAME3).withNamespace(StateTestConstants.NAMESPACE)))) + .withStream( + AirbyteStream() + .withName(StateTestConstants.STREAM_NAME3) + .withNamespace(StateTestConstants.NAMESPACE) + ) + ) + ) val stateManager: StateManager = LegacyStateManager(DbState(), catalog) - val expectedFirstEmission = AirbyteStateMessage() + val expectedFirstEmission = + AirbyteStateMessage() .withType(AirbyteStateMessage.AirbyteStateType.LEGACY) - .withData(Jsons.jsonNode(DbState().withStreams(List.of( - DbStreamState().withStreamName(StateTestConstants.STREAM_NAME1).withStreamNamespace(StateTestConstants.NAMESPACE).withCursorField(List.of(StateTestConstants.CURSOR_FIELD1)) - .withCursor("a"), - DbStreamState().withStreamName(StateTestConstants.STREAM_NAME2).withStreamNamespace(StateTestConstants.NAMESPACE).withCursorField(List.of(StateTestConstants.CURSOR_FIELD2)), - DbStreamState().withStreamName(StateTestConstants.STREAM_NAME3).withStreamNamespace(StateTestConstants.NAMESPACE)) - .stream().sorted(Comparator.comparing { obj: DbStreamState -> obj.streamName }).collect(Collectors.toList())) - .withCdc(false))) - val actualFirstEmission = stateManager.updateAndEmit(StateTestConstants.NAME_NAMESPACE_PAIR1, "a") + .withData( + Jsons.jsonNode( + DbState() + .withStreams( + List.of( + DbStreamState() + .withStreamName(StateTestConstants.STREAM_NAME1) + .withStreamNamespace(StateTestConstants.NAMESPACE) + .withCursorField( + List.of(StateTestConstants.CURSOR_FIELD1) + ) + .withCursor("a"), + DbStreamState() + .withStreamName(StateTestConstants.STREAM_NAME2) + .withStreamNamespace(StateTestConstants.NAMESPACE) + .withCursorField( + List.of(StateTestConstants.CURSOR_FIELD2) + ), + DbStreamState() + .withStreamName(StateTestConstants.STREAM_NAME3) + .withStreamNamespace(StateTestConstants.NAMESPACE) + ) + .stream() + .sorted( + Comparator.comparing { obj: DbStreamState -> + obj.streamName + } + ) + .collect(Collectors.toList()) + ) + .withCdc(false) + ) + ) + val actualFirstEmission = + stateManager.updateAndEmit(StateTestConstants.NAME_NAMESPACE_PAIR1, "a") Assertions.assertEquals(expectedFirstEmission, actualFirstEmission) - val expectedSecondEmission = AirbyteStateMessage() + val expectedSecondEmission = + AirbyteStateMessage() .withType(AirbyteStateMessage.AirbyteStateType.LEGACY) - .withData(Jsons.jsonNode(DbState().withStreams(List.of( - DbStreamState().withStreamName(StateTestConstants.STREAM_NAME1).withStreamNamespace(StateTestConstants.NAMESPACE).withCursorField(List.of(StateTestConstants.CURSOR_FIELD1)) - .withCursor("a"), - DbStreamState().withStreamName(StateTestConstants.STREAM_NAME2).withStreamNamespace(StateTestConstants.NAMESPACE).withCursorField(List.of(StateTestConstants.CURSOR_FIELD2)) - .withCursor("b"), - DbStreamState().withStreamName(StateTestConstants.STREAM_NAME3).withStreamNamespace(StateTestConstants.NAMESPACE)) - .stream().sorted(Comparator.comparing { obj: DbStreamState -> obj.streamName }).collect(Collectors.toList())) - .withCdc(false))) - val actualSecondEmission = stateManager.updateAndEmit(StateTestConstants.NAME_NAMESPACE_PAIR2, "b") + .withData( + Jsons.jsonNode( + DbState() + .withStreams( + List.of( + DbStreamState() + .withStreamName(StateTestConstants.STREAM_NAME1) + .withStreamNamespace(StateTestConstants.NAMESPACE) + .withCursorField( + List.of(StateTestConstants.CURSOR_FIELD1) + ) + .withCursor("a"), + DbStreamState() + .withStreamName(StateTestConstants.STREAM_NAME2) + .withStreamNamespace(StateTestConstants.NAMESPACE) + .withCursorField( + List.of(StateTestConstants.CURSOR_FIELD2) + ) + .withCursor("b"), + DbStreamState() + .withStreamName(StateTestConstants.STREAM_NAME3) + .withStreamNamespace(StateTestConstants.NAMESPACE) + ) + .stream() + .sorted( + Comparator.comparing { obj: DbStreamState -> + obj.streamName + } + ) + .collect(Collectors.toList()) + ) + .withCdc(false) + ) + ) + val actualSecondEmission = + stateManager.updateAndEmit(StateTestConstants.NAME_NAMESPACE_PAIR2, "b") Assertions.assertEquals(expectedSecondEmission, actualSecondEmission) } @Test fun testToStateNullCursorField() { - val catalog = ConfiguredAirbyteCatalog() - .withStreams(List.of( + val catalog = + ConfiguredAirbyteCatalog() + .withStreams( + List.of( ConfiguredAirbyteStream() - .withStream(AirbyteStream().withName(StateTestConstants.STREAM_NAME1).withNamespace(StateTestConstants.NAMESPACE)) - .withCursorField(List.of(StateTestConstants.CURSOR_FIELD1)), + .withStream( + AirbyteStream() + .withName(StateTestConstants.STREAM_NAME1) + .withNamespace(StateTestConstants.NAMESPACE) + ) + .withCursorField(List.of(StateTestConstants.CURSOR_FIELD1)), ConfiguredAirbyteStream() - .withStream(AirbyteStream().withName(StateTestConstants.STREAM_NAME2).withNamespace(StateTestConstants.NAMESPACE)))) + .withStream( + AirbyteStream() + .withName(StateTestConstants.STREAM_NAME2) + .withNamespace(StateTestConstants.NAMESPACE) + ) + ) + ) val stateManager: StateManager = LegacyStateManager(DbState(), catalog) - val expectedFirstEmission = AirbyteStateMessage() + val expectedFirstEmission = + AirbyteStateMessage() .withType(AirbyteStateMessage.AirbyteStateType.LEGACY) - .withData(Jsons.jsonNode(DbState().withStreams(List.of( - DbStreamState().withStreamName(StateTestConstants.STREAM_NAME1).withStreamNamespace(StateTestConstants.NAMESPACE).withCursorField(List.of(StateTestConstants.CURSOR_FIELD1)) - .withCursor("a"), - DbStreamState().withStreamName(StateTestConstants.STREAM_NAME2).withStreamNamespace(StateTestConstants.NAMESPACE)) - .stream().sorted(Comparator.comparing { obj: DbStreamState -> obj.streamName }).collect(Collectors.toList())) - .withCdc(false))) - - val actualFirstEmission = stateManager.updateAndEmit(StateTestConstants.NAME_NAMESPACE_PAIR1, "a") + .withData( + Jsons.jsonNode( + DbState() + .withStreams( + List.of( + DbStreamState() + .withStreamName(StateTestConstants.STREAM_NAME1) + .withStreamNamespace(StateTestConstants.NAMESPACE) + .withCursorField( + List.of(StateTestConstants.CURSOR_FIELD1) + ) + .withCursor("a"), + DbStreamState() + .withStreamName(StateTestConstants.STREAM_NAME2) + .withStreamNamespace(StateTestConstants.NAMESPACE) + ) + .stream() + .sorted( + Comparator.comparing { obj: DbStreamState -> + obj.streamName + } + ) + .collect(Collectors.toList()) + ) + .withCdc(false) + ) + ) + + val actualFirstEmission = + stateManager.updateAndEmit(StateTestConstants.NAME_NAMESPACE_PAIR1, "a") Assertions.assertEquals(expectedFirstEmission, actualFirstEmission) } @Test fun testCursorNotUpdatedForCdc() { - val catalog = ConfiguredAirbyteCatalog() - .withStreams(List.of( + val catalog = + ConfiguredAirbyteCatalog() + .withStreams( + List.of( ConfiguredAirbyteStream() - .withStream(AirbyteStream().withName(StateTestConstants.STREAM_NAME1).withNamespace(StateTestConstants.NAMESPACE)) - .withCursorField(List.of(StateTestConstants.CURSOR_FIELD1)), + .withStream( + AirbyteStream() + .withName(StateTestConstants.STREAM_NAME1) + .withNamespace(StateTestConstants.NAMESPACE) + ) + .withCursorField(List.of(StateTestConstants.CURSOR_FIELD1)), ConfiguredAirbyteStream() - .withStream(AirbyteStream().withName(StateTestConstants.STREAM_NAME2).withNamespace(StateTestConstants.NAMESPACE)))) + .withStream( + AirbyteStream() + .withName(StateTestConstants.STREAM_NAME2) + .withNamespace(StateTestConstants.NAMESPACE) + ) + ) + ) val state = DbState() state.cdc = true val stateManager: StateManager = LegacyStateManager(state, catalog) - val expectedFirstEmission = AirbyteStateMessage() + val expectedFirstEmission = + AirbyteStateMessage() .withType(AirbyteStateMessage.AirbyteStateType.LEGACY) - .withData(Jsons.jsonNode(DbState().withStreams(List.of( - DbStreamState().withStreamName(StateTestConstants.STREAM_NAME1).withStreamNamespace(StateTestConstants.NAMESPACE).withCursorField(List.of(StateTestConstants.CURSOR_FIELD1)) - .withCursor(null), - DbStreamState().withStreamName(StateTestConstants.STREAM_NAME2).withStreamNamespace(StateTestConstants.NAMESPACE).withCursorField(listOf())) - .stream().sorted(Comparator.comparing { obj: DbStreamState -> obj.streamName }).collect(Collectors.toList())) - .withCdc(true))) - val actualFirstEmission = stateManager.updateAndEmit(StateTestConstants.NAME_NAMESPACE_PAIR1, "a") + .withData( + Jsons.jsonNode( + DbState() + .withStreams( + List.of( + DbStreamState() + .withStreamName(StateTestConstants.STREAM_NAME1) + .withStreamNamespace(StateTestConstants.NAMESPACE) + .withCursorField( + List.of(StateTestConstants.CURSOR_FIELD1) + ) + .withCursor(null), + DbStreamState() + .withStreamName(StateTestConstants.STREAM_NAME2) + .withStreamNamespace(StateTestConstants.NAMESPACE) + .withCursorField(listOf()) + ) + .stream() + .sorted( + Comparator.comparing { obj: DbStreamState -> + obj.streamName + } + ) + .collect(Collectors.toList()) + ) + .withCdc(true) + ) + ) + val actualFirstEmission = + stateManager.updateAndEmit(StateTestConstants.NAME_NAMESPACE_PAIR1, "a") Assertions.assertEquals(expectedFirstEmission, actualFirstEmission) - val expectedSecondEmission = AirbyteStateMessage() + val expectedSecondEmission = + AirbyteStateMessage() .withType(AirbyteStateMessage.AirbyteStateType.LEGACY) - .withData(Jsons.jsonNode(DbState().withStreams(List.of( - DbStreamState().withStreamName(StateTestConstants.STREAM_NAME1).withStreamNamespace(StateTestConstants.NAMESPACE).withCursorField(List.of(StateTestConstants.CURSOR_FIELD1)) - .withCursor(null), - DbStreamState().withStreamName(StateTestConstants.STREAM_NAME2).withStreamNamespace(StateTestConstants.NAMESPACE).withCursorField(listOf()) - .withCursor(null)) - .stream().sorted(Comparator.comparing { obj: DbStreamState -> obj.streamName }).collect(Collectors.toList())) - .withCdc(true))) - val actualSecondEmission = stateManager.updateAndEmit(StateTestConstants.NAME_NAMESPACE_PAIR2, "b") + .withData( + Jsons.jsonNode( + DbState() + .withStreams( + List.of( + DbStreamState() + .withStreamName(StateTestConstants.STREAM_NAME1) + .withStreamNamespace(StateTestConstants.NAMESPACE) + .withCursorField( + List.of(StateTestConstants.CURSOR_FIELD1) + ) + .withCursor(null), + DbStreamState() + .withStreamName(StateTestConstants.STREAM_NAME2) + .withStreamNamespace(StateTestConstants.NAMESPACE) + .withCursorField(listOf()) + .withCursor(null) + ) + .stream() + .sorted( + Comparator.comparing { obj: DbStreamState -> + obj.streamName + } + ) + .collect(Collectors.toList()) + ) + .withCdc(true) + ) + ) + val actualSecondEmission = + stateManager.updateAndEmit(StateTestConstants.NAME_NAMESPACE_PAIR2, "b") Assertions.assertEquals(expectedSecondEmission, actualSecondEmission) } @@ -156,8 +367,16 @@ class LegacyStateManagerTest { fun testCdcStateManager() { val catalog = Mockito.mock(ConfiguredAirbyteCatalog::class.java) val cdcState = CdcState().withState(Jsons.jsonNode(Map.of("foo", "bar", "baz", 5))) - val dbState = DbState().withCdcState(cdcState).withStreams(List.of( - DbStreamState().withStreamNamespace(StateTestConstants.NAMESPACE).withStreamName(StateTestConstants.STREAM_NAME1))) + val dbState = + DbState() + .withCdcState(cdcState) + .withStreams( + List.of( + DbStreamState() + .withStreamNamespace(StateTestConstants.NAMESPACE) + .withStreamName(StateTestConstants.STREAM_NAME1) + ) + ) val stateManager: StateManager = LegacyStateManager(dbState, catalog) Assertions.assertNotNull(stateManager.cdcStateManager) Assertions.assertEquals(cdcState, stateManager.cdcStateManager.cdcState) diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/SourceStateIteratorForTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/SourceStateIteratorForTest.kt new file mode 100644 index 0000000000000..2097439769900 --- /dev/null +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/SourceStateIteratorForTest.kt @@ -0,0 +1,23 @@ +/* + * Copyright (c) 2024 Airbyte, Inc., all rights reserved. + */ + +package io.airbyte.cdk.integrations.source.relationaldb.state + +import io.airbyte.protocol.models.v0.AirbyteMessage +import io.airbyte.protocol.models.v0.ConfiguredAirbyteStream + +class SourceStateIteratorForTest( + messageIterator: Iterator, + stream: ConfiguredAirbyteStream, + sourceStateMessageProducer: SourceStateMessageProducer, + stateEmitFrequency: StateEmitFrequency +) : + SourceStateIterator( + messageIterator, + stream, + sourceStateMessageProducer, + stateEmitFrequency + ) { + public override fun computeNext(): AirbyteMessage? = super.computeNext() +} diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/SourceStateIteratorTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/SourceStateIteratorTest.kt index 8607359ec4c84..fb34ea35822b1 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/SourceStateIteratorTest.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/SourceStateIteratorTest.kt @@ -4,40 +4,49 @@ package io.airbyte.cdk.integrations.source.relationaldb.state import io.airbyte.protocol.models.v0.* +import java.time.Duration import org.junit.Assert import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Test import org.mockito.ArgumentMatchers import org.mockito.Mockito -import java.time.Duration +import org.mockito.Mockito.mock +import org.mockito.kotlin.any +import org.mockito.kotlin.eq class SourceStateIteratorTest { - var mockProducer: SourceStateMessageProducer<*>? = null - var messageIterator: Iterator? = null - var stream: ConfiguredAirbyteStream? = null + lateinit var mockProducer: SourceStateMessageProducer + lateinit var messageIterator: Iterator + lateinit var stream: ConfiguredAirbyteStream - var sourceStateIterator: SourceStateIterator<*>? = null + var sourceStateIterator: SourceStateIteratorForTest<*>? = null @BeforeEach fun setup() { - mockProducer = Mockito.mock(SourceStateMessageProducer::class.java) - stream = Mockito.mock(ConfiguredAirbyteStream::class.java) - messageIterator = Mockito.mock>(MutableIterator::class.java) + mockProducer = mock() + stream = mock() + messageIterator = mock() val stateEmitFrequency = StateEmitFrequency(1L, Duration.ofSeconds(100L)) - sourceStateIterator = SourceStateIterator(messageIterator, stream, mockProducer, stateEmitFrequency) + sourceStateIterator = + SourceStateIteratorForTest(messageIterator, stream, mockProducer, stateEmitFrequency) } - // Provides a way to generate a record message and will verify corresponding spied functions have + // Provides a way to generate a record message and will verify corresponding spied functions + // have // been called. fun processRecordMessage() { Mockito.doReturn(true).`when`(messageIterator).hasNext() - Mockito.doReturn(false).`when`(mockProducer).shouldEmitStateMessage(ArgumentMatchers.eq(stream)) - val message = AirbyteMessage().withType(AirbyteMessage.Type.RECORD).withRecord(AirbyteRecordMessage()) - Mockito.doReturn(message).`when`(mockProducer).processRecordMessage(ArgumentMatchers.eq(stream), ArgumentMatchers.any()) + Mockito.doReturn(false) + .`when`(mockProducer) + .shouldEmitStateMessage(ArgumentMatchers.eq(stream)) + val message = + AirbyteMessage().withType(AirbyteMessage.Type.RECORD).withRecord(AirbyteRecordMessage()) + Mockito.doReturn(message).`when`(mockProducer).processRecordMessage(eq(stream), any()) Mockito.doReturn(message).`when`(messageIterator).next() Assert.assertEquals(message, sourceStateIterator!!.computeNext()) - Mockito.verify(mockProducer, Mockito.atLeastOnce()).processRecordMessage(ArgumentMatchers.eq(stream), ArgumentMatchers.eq(message)) + Mockito.verify(mockProducer, Mockito.atLeastOnce()) + .processRecordMessage(eq(stream), eq(message)) } @Test @@ -48,10 +57,13 @@ class SourceStateIteratorTest { @Test fun testShouldEmitStateMessage() { processRecordMessage() - Mockito.doReturn(true).`when`(mockProducer).shouldEmitStateMessage(ArgumentMatchers.eq(stream)) + Mockito.doReturn(true) + .`when`(mockProducer) + .shouldEmitStateMessage(ArgumentMatchers.eq(stream)) val stateMessage = AirbyteStateMessage() Mockito.doReturn(stateMessage).`when`(mockProducer).generateStateMessageAtCheckpoint(stream) - val expectedMessage = AirbyteMessage().withType(AirbyteMessage.Type.STATE).withState(stateMessage) + val expectedMessage = + AirbyteMessage().withType(AirbyteMessage.Type.STATE).withState(stateMessage) expectedMessage.state.withSourceStats(AirbyteStateStats().withRecordCount(1.0)) Assert.assertEquals(expectedMessage, sourceStateIterator!!.computeNext()) } @@ -63,7 +75,8 @@ class SourceStateIteratorTest { Mockito.doReturn(false).`when`(messageIterator).hasNext() val stateMessage = AirbyteStateMessage() Mockito.doReturn(stateMessage).`when`(mockProducer).createFinalStateMessage(stream) - val expectedMessage = AirbyteMessage().withType(AirbyteMessage.Type.STATE).withState(stateMessage) + val expectedMessage = + AirbyteMessage().withType(AirbyteMessage.Type.STATE).withState(stateMessage) expectedMessage.state.withSourceStats(AirbyteStateStats().withRecordCount(2.0)) Assert.assertEquals(expectedMessage, sourceStateIterator!!.computeNext()) } @@ -82,7 +95,9 @@ class SourceStateIteratorTest { @Test fun testShouldRethrowExceptions() { processRecordMessage() - Mockito.doThrow(ArrayIndexOutOfBoundsException("unexpected error")).`when`(messageIterator).hasNext() + Mockito.doThrow(ArrayIndexOutOfBoundsException("unexpected error")) + .`when`(messageIterator) + .hasNext() Assert.assertThrows(RuntimeException::class.java) { sourceStateIterator!!.computeNext() } } } diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/StateGeneratorUtilsTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/StateGeneratorUtilsTest.kt index 8a4cc603eabca..e9334ff081f34 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/StateGeneratorUtilsTest.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/StateGeneratorUtilsTest.kt @@ -7,9 +7,7 @@ import io.airbyte.protocol.models.v0.StreamDescriptor import org.junit.jupiter.api.Assertions import org.junit.jupiter.api.Test -/** - * Test suite for the [StateGeneratorUtils] class. - */ +/** Test suite for the [StateGeneratorUtils] class. */ class StateGeneratorUtilsTest { @Test fun testValidStreamDescriptor() { diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/StateManagerFactoryTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/StateManagerFactoryTest.kt index 74b6b45e96e5b..ca8c76753b0c2 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/StateManagerFactoryTest.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/StateManagerFactoryTest.kt @@ -8,41 +8,63 @@ import io.airbyte.cdk.integrations.source.relationaldb.models.DbState import io.airbyte.cdk.integrations.source.relationaldb.models.DbStreamState import io.airbyte.commons.json.Jsons import io.airbyte.protocol.models.v0.* +import java.util.List import org.junit.jupiter.api.Assertions import org.junit.jupiter.api.Test import org.mockito.Mockito -import java.util.List -/** - * Test suite for the [StateManagerFactory] class. - */ +/** Test suite for the [StateManagerFactory] class. */ class StateManagerFactoryTest { @Test fun testNullOrEmptyState() { val catalog = Mockito.mock(ConfiguredAirbyteCatalog::class.java) Assertions.assertThrows(IllegalArgumentException::class.java) { - StateManagerFactory.createStateManager(AirbyteStateMessage.AirbyteStateType.GLOBAL, null, catalog) + StateManagerFactory.createStateManager( + AirbyteStateMessage.AirbyteStateType.GLOBAL, + null, + catalog + ) } Assertions.assertThrows(IllegalArgumentException::class.java) { - StateManagerFactory.createStateManager(AirbyteStateMessage.AirbyteStateType.GLOBAL, listOf(), catalog) + StateManagerFactory.createStateManager( + AirbyteStateMessage.AirbyteStateType.GLOBAL, + listOf(), + catalog + ) } Assertions.assertThrows(IllegalArgumentException::class.java) { - StateManagerFactory.createStateManager(AirbyteStateMessage.AirbyteStateType.LEGACY, null, catalog) + StateManagerFactory.createStateManager( + AirbyteStateMessage.AirbyteStateType.LEGACY, + null, + catalog + ) } Assertions.assertThrows(IllegalArgumentException::class.java) { - StateManagerFactory.createStateManager(AirbyteStateMessage.AirbyteStateType.LEGACY, listOf(), catalog) + StateManagerFactory.createStateManager( + AirbyteStateMessage.AirbyteStateType.LEGACY, + listOf(), + catalog + ) } Assertions.assertThrows(IllegalArgumentException::class.java) { - StateManagerFactory.createStateManager(AirbyteStateMessage.AirbyteStateType.STREAM, null, catalog) + StateManagerFactory.createStateManager( + AirbyteStateMessage.AirbyteStateType.STREAM, + null, + catalog + ) } Assertions.assertThrows(IllegalArgumentException::class.java) { - StateManagerFactory.createStateManager(AirbyteStateMessage.AirbyteStateType.STREAM, listOf(), catalog) + StateManagerFactory.createStateManager( + AirbyteStateMessage.AirbyteStateType.STREAM, + listOf(), + catalog + ) } } @@ -52,7 +74,12 @@ class StateManagerFactoryTest { val airbyteStateMessage = Mockito.mock(AirbyteStateMessage::class.java) Mockito.`when`(airbyteStateMessage.data).thenReturn(Jsons.jsonNode(DbState())) - val stateManager = StateManagerFactory.createStateManager(AirbyteStateMessage.AirbyteStateType.LEGACY, List.of(airbyteStateMessage), catalog) + val stateManager = + StateManagerFactory.createStateManager( + AirbyteStateMessage.AirbyteStateType.LEGACY, + List.of(airbyteStateMessage), + catalog + ) Assertions.assertNotNull(stateManager) Assertions.assertEquals(LegacyStateManager::class.java, stateManager.javaClass) @@ -62,12 +89,32 @@ class StateManagerFactoryTest { fun testGlobalStateManagerCreation() { val catalog = Mockito.mock(ConfiguredAirbyteCatalog::class.java) val globalState = - AirbyteGlobalState().withSharedState(Jsons.jsonNode(DbState().withCdcState(CdcState().withState(Jsons.jsonNode(DbState()))))) - .withStreamStates(List.of(AirbyteStreamState().withStreamDescriptor(StreamDescriptor().withNamespace(NAMESPACE).withName(NAME)) - .withStreamState(Jsons.jsonNode(DbStreamState())))) - val airbyteStateMessage = AirbyteStateMessage().withType(AirbyteStateMessage.AirbyteStateType.GLOBAL).withGlobal(globalState) + AirbyteGlobalState() + .withSharedState( + Jsons.jsonNode( + DbState().withCdcState(CdcState().withState(Jsons.jsonNode(DbState()))) + ) + ) + .withStreamStates( + List.of( + AirbyteStreamState() + .withStreamDescriptor( + StreamDescriptor().withNamespace(NAMESPACE).withName(NAME) + ) + .withStreamState(Jsons.jsonNode(DbStreamState())) + ) + ) + val airbyteStateMessage = + AirbyteStateMessage() + .withType(AirbyteStateMessage.AirbyteStateType.GLOBAL) + .withGlobal(globalState) - val stateManager = StateManagerFactory.createStateManager(AirbyteStateMessage.AirbyteStateType.GLOBAL, List.of(airbyteStateMessage), catalog) + val stateManager = + StateManagerFactory.createStateManager( + AirbyteStateMessage.AirbyteStateType.GLOBAL, + List.of(airbyteStateMessage), + catalog + ) Assertions.assertNotNull(stateManager) Assertions.assertEquals(GlobalStateManager::class.java, stateManager.javaClass) @@ -77,13 +124,23 @@ class StateManagerFactoryTest { fun testGlobalStateManagerCreationFromLegacyState() { val catalog = Mockito.mock(ConfiguredAirbyteCatalog::class.java) val cdcState = CdcState() - val dbState = DbState() + val dbState = + DbState() .withCdcState(cdcState) - .withStreams(List.of(DbStreamState().withStreamName(NAME).withStreamNamespace(NAMESPACE))) + .withStreams( + List.of(DbStreamState().withStreamName(NAME).withStreamNamespace(NAMESPACE)) + ) val airbyteStateMessage = - AirbyteStateMessage().withType(AirbyteStateMessage.AirbyteStateType.LEGACY).withData(Jsons.jsonNode(dbState)) + AirbyteStateMessage() + .withType(AirbyteStateMessage.AirbyteStateType.LEGACY) + .withData(Jsons.jsonNode(dbState)) - val stateManager = StateManagerFactory.createStateManager(AirbyteStateMessage.AirbyteStateType.GLOBAL, List.of(airbyteStateMessage), catalog) + val stateManager = + StateManagerFactory.createStateManager( + AirbyteStateMessage.AirbyteStateType.GLOBAL, + List.of(airbyteStateMessage), + catalog + ) Assertions.assertNotNull(stateManager) Assertions.assertEquals(GlobalStateManager::class.java, stateManager.javaClass) @@ -92,25 +149,57 @@ class StateManagerFactoryTest { @Test fun testGlobalStateManagerCreationFromStreamState() { val catalog = Mockito.mock(ConfiguredAirbyteCatalog::class.java) - val airbyteStateMessage = AirbyteStateMessage().withType(AirbyteStateMessage.AirbyteStateType.STREAM) - .withStream(AirbyteStreamState().withStreamDescriptor(StreamDescriptor().withName(NAME).withNamespace( - NAMESPACE)).withStreamState(Jsons.jsonNode(DbStreamState()))) + val airbyteStateMessage = + AirbyteStateMessage() + .withType(AirbyteStateMessage.AirbyteStateType.STREAM) + .withStream( + AirbyteStreamState() + .withStreamDescriptor( + StreamDescriptor().withName(NAME).withNamespace(NAMESPACE) + ) + .withStreamState(Jsons.jsonNode(DbStreamState())) + ) - Assertions.assertThrows(IllegalArgumentException::class.java - ) { StateManagerFactory.createStateManager(AirbyteStateMessage.AirbyteStateType.GLOBAL, List.of(airbyteStateMessage), catalog) } + Assertions.assertThrows(IllegalArgumentException::class.java) { + StateManagerFactory.createStateManager( + AirbyteStateMessage.AirbyteStateType.GLOBAL, + List.of(airbyteStateMessage), + catalog + ) + } } @Test fun testGlobalStateManagerCreationWithLegacyDataPresent() { val catalog = Mockito.mock(ConfiguredAirbyteCatalog::class.java) val globalState = - AirbyteGlobalState().withSharedState(Jsons.jsonNode(DbState().withCdcState(CdcState().withState(Jsons.jsonNode(DbState()))))) - .withStreamStates(List.of(AirbyteStreamState().withStreamDescriptor(StreamDescriptor().withNamespace(NAMESPACE).withName(NAME)) - .withStreamState(Jsons.jsonNode(DbStreamState())))) + AirbyteGlobalState() + .withSharedState( + Jsons.jsonNode( + DbState().withCdcState(CdcState().withState(Jsons.jsonNode(DbState()))) + ) + ) + .withStreamStates( + List.of( + AirbyteStreamState() + .withStreamDescriptor( + StreamDescriptor().withNamespace(NAMESPACE).withName(NAME) + ) + .withStreamState(Jsons.jsonNode(DbStreamState())) + ) + ) val airbyteStateMessage = - AirbyteStateMessage().withType(AirbyteStateMessage.AirbyteStateType.GLOBAL).withGlobal(globalState).withData(Jsons.jsonNode(DbState())) + AirbyteStateMessage() + .withType(AirbyteStateMessage.AirbyteStateType.GLOBAL) + .withGlobal(globalState) + .withData(Jsons.jsonNode(DbState())) - val stateManager = StateManagerFactory.createStateManager(AirbyteStateMessage.AirbyteStateType.GLOBAL, List.of(airbyteStateMessage), catalog) + val stateManager = + StateManagerFactory.createStateManager( + AirbyteStateMessage.AirbyteStateType.GLOBAL, + List.of(airbyteStateMessage), + catalog + ) Assertions.assertNotNull(stateManager) Assertions.assertEquals(GlobalStateManager::class.java, stateManager.javaClass) @@ -119,11 +208,23 @@ class StateManagerFactoryTest { @Test fun testStreamStateManagerCreation() { val catalog = Mockito.mock(ConfiguredAirbyteCatalog::class.java) - val airbyteStateMessage = AirbyteStateMessage().withType(AirbyteStateMessage.AirbyteStateType.STREAM) - .withStream(AirbyteStreamState().withStreamDescriptor(StreamDescriptor().withName(NAME).withNamespace( - NAMESPACE)).withStreamState(Jsons.jsonNode(DbStreamState()))) - - val stateManager = StateManagerFactory.createStateManager(AirbyteStateMessage.AirbyteStateType.STREAM, List.of(airbyteStateMessage), catalog) + val airbyteStateMessage = + AirbyteStateMessage() + .withType(AirbyteStateMessage.AirbyteStateType.STREAM) + .withStream( + AirbyteStreamState() + .withStreamDescriptor( + StreamDescriptor().withName(NAME).withNamespace(NAMESPACE) + ) + .withStreamState(Jsons.jsonNode(DbStreamState())) + ) + + val stateManager = + StateManagerFactory.createStateManager( + AirbyteStateMessage.AirbyteStateType.STREAM, + List.of(airbyteStateMessage), + catalog + ) Assertions.assertNotNull(stateManager) Assertions.assertEquals(StreamStateManager::class.java, stateManager.javaClass) @@ -133,13 +234,23 @@ class StateManagerFactoryTest { fun testStreamStateManagerCreationFromLegacy() { val catalog = Mockito.mock(ConfiguredAirbyteCatalog::class.java) val cdcState = CdcState() - val dbState = DbState() + val dbState = + DbState() .withCdcState(cdcState) - .withStreams(List.of(DbStreamState().withStreamName(NAME).withStreamNamespace(NAMESPACE))) + .withStreams( + List.of(DbStreamState().withStreamName(NAME).withStreamNamespace(NAMESPACE)) + ) val airbyteStateMessage = - AirbyteStateMessage().withType(AirbyteStateMessage.AirbyteStateType.LEGACY).withData(Jsons.jsonNode(dbState)) + AirbyteStateMessage() + .withType(AirbyteStateMessage.AirbyteStateType.LEGACY) + .withData(Jsons.jsonNode(dbState)) - val stateManager = StateManagerFactory.createStateManager(AirbyteStateMessage.AirbyteStateType.STREAM, List.of(airbyteStateMessage), catalog) + val stateManager = + StateManagerFactory.createStateManager( + AirbyteStateMessage.AirbyteStateType.STREAM, + List.of(airbyteStateMessage), + catalog + ) Assertions.assertNotNull(stateManager) Assertions.assertEquals(StreamStateManager::class.java, stateManager.javaClass) @@ -149,24 +260,56 @@ class StateManagerFactoryTest { fun testStreamStateManagerCreationFromGlobal() { val catalog = Mockito.mock(ConfiguredAirbyteCatalog::class.java) val globalState = - AirbyteGlobalState().withSharedState(Jsons.jsonNode(DbState().withCdcState(CdcState().withState(Jsons.jsonNode(DbState()))))) - .withStreamStates(List.of(AirbyteStreamState().withStreamDescriptor(StreamDescriptor().withNamespace(NAMESPACE).withName(NAME)) - .withStreamState(Jsons.jsonNode(DbStreamState())))) - val airbyteStateMessage = AirbyteStateMessage().withType(AirbyteStateMessage.AirbyteStateType.GLOBAL).withGlobal(globalState) + AirbyteGlobalState() + .withSharedState( + Jsons.jsonNode( + DbState().withCdcState(CdcState().withState(Jsons.jsonNode(DbState()))) + ) + ) + .withStreamStates( + List.of( + AirbyteStreamState() + .withStreamDescriptor( + StreamDescriptor().withNamespace(NAMESPACE).withName(NAME) + ) + .withStreamState(Jsons.jsonNode(DbStreamState())) + ) + ) + val airbyteStateMessage = + AirbyteStateMessage() + .withType(AirbyteStateMessage.AirbyteStateType.GLOBAL) + .withGlobal(globalState) - Assertions.assertThrows(IllegalArgumentException::class.java - ) { StateManagerFactory.createStateManager(AirbyteStateMessage.AirbyteStateType.STREAM, List.of(airbyteStateMessage), catalog) } + Assertions.assertThrows(IllegalArgumentException::class.java) { + StateManagerFactory.createStateManager( + AirbyteStateMessage.AirbyteStateType.STREAM, + List.of(airbyteStateMessage), + catalog + ) + } } @Test fun testStreamStateManagerCreationWithLegacyDataPresent() { val catalog = Mockito.mock(ConfiguredAirbyteCatalog::class.java) - val airbyteStateMessage = AirbyteStateMessage().withType(AirbyteStateMessage.AirbyteStateType.STREAM) - .withStream(AirbyteStreamState().withStreamDescriptor(StreamDescriptor().withName(NAME).withNamespace( - NAMESPACE)).withStreamState(Jsons.jsonNode(DbStreamState()))) + val airbyteStateMessage = + AirbyteStateMessage() + .withType(AirbyteStateMessage.AirbyteStateType.STREAM) + .withStream( + AirbyteStreamState() + .withStreamDescriptor( + StreamDescriptor().withName(NAME).withNamespace(NAMESPACE) + ) + .withStreamState(Jsons.jsonNode(DbStreamState())) + ) .withData(Jsons.jsonNode(DbState())) - val stateManager = StateManagerFactory.createStateManager(AirbyteStateMessage.AirbyteStateType.STREAM, List.of(airbyteStateMessage), catalog) + val stateManager = + StateManagerFactory.createStateManager( + AirbyteStateMessage.AirbyteStateType.STREAM, + List.of(airbyteStateMessage), + catalog + ) Assertions.assertNotNull(stateManager) Assertions.assertEquals(StreamStateManager::class.java, stateManager.javaClass) diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/StateTestConstants.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/StateTestConstants.kt index 90eb8d487540e..3ffd9781e7607 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/StateTestConstants.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/StateTestConstants.kt @@ -8,48 +8,61 @@ import io.airbyte.protocol.models.v0.AirbyteStream import io.airbyte.protocol.models.v0.AirbyteStreamNameNamespacePair import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog import io.airbyte.protocol.models.v0.ConfiguredAirbyteStream -import org.testcontainers.shaded.com.google.common.collect.Lists import java.util.* import java.util.List +import org.testcontainers.shaded.com.google.common.collect.Lists -/** - * Collection of constants for use in state management-related tests. - */ +/** Collection of constants for use in state management-related tests. */ object StateTestConstants { const val NAMESPACE: String = "public" const val STREAM_NAME1: String = "cars" - val NAME_NAMESPACE_PAIR1: AirbyteStreamNameNamespacePair = AirbyteStreamNameNamespacePair(STREAM_NAME1, NAMESPACE) + val NAME_NAMESPACE_PAIR1: AirbyteStreamNameNamespacePair = + AirbyteStreamNameNamespacePair(STREAM_NAME1, NAMESPACE) const val STREAM_NAME2: String = "bicycles" - val NAME_NAMESPACE_PAIR2: AirbyteStreamNameNamespacePair = AirbyteStreamNameNamespacePair(STREAM_NAME2, NAMESPACE) + val NAME_NAMESPACE_PAIR2: AirbyteStreamNameNamespacePair = + AirbyteStreamNameNamespacePair(STREAM_NAME2, NAMESPACE) const val STREAM_NAME3: String = "stationary_bicycles" const val CURSOR_FIELD1: String = "year" const val CURSOR_FIELD2: String = "generation" const val CURSOR: String = "2000" const val CURSOR_RECORD_COUNT: Long = 19L - fun getState(cursorField: String?, cursor: String?): Optional { - return Optional.of(DbStreamState() + fun getState(cursorField: String?, cursor: String?): Optional { + return Optional.of( + DbStreamState() .withStreamName(STREAM_NAME1) .withCursorField(Lists.newArrayList(cursorField)) - .withCursor(cursor)) + .withCursor(cursor) + ) } - fun getState(cursorField: String?, cursor: String?, cursorRecordCount: Long): Optional { - return Optional.of(DbStreamState() + fun getState( + cursorField: String?, + cursor: String?, + cursorRecordCount: Long + ): Optional { + return Optional.of( + DbStreamState() .withStreamName(STREAM_NAME1) .withCursorField(Lists.newArrayList(cursorField)) .withCursor(cursor) - .withCursorRecordCount(cursorRecordCount)) + .withCursorRecordCount(cursorRecordCount) + ) } - fun getCatalog(cursorField: String?): Optional { - return Optional.of(ConfiguredAirbyteCatalog() - .withStreams(List.of(getStream(cursorField).orElse(null)))) + fun getCatalog(cursorField: String?): Optional { + return Optional.of( + ConfiguredAirbyteCatalog().withStreams(List.of(getStream(cursorField).orElse(null))) + ) } - fun getStream(cursorField: String?): Optional { - return Optional.of(ConfiguredAirbyteStream() + fun getStream(cursorField: String?): Optional { + return Optional.of( + ConfiguredAirbyteStream() .withStream(AirbyteStream().withName(STREAM_NAME1)) - .withCursorField(if (cursorField == null) emptyList() else Lists.newArrayList(cursorField))) + .withCursorField( + if (cursorField == null) emptyList() else Lists.newArrayList(cursorField) + ) + ) } } diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/StreamStateManagerTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/StreamStateManagerTest.kt index 11ad5a98ac8f5..6fba4dda3a85e 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/StreamStateManagerTest.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/integrations/source/relationaldb/state/StreamStateManagerTest.kt @@ -8,27 +8,33 @@ import io.airbyte.cdk.integrations.source.relationaldb.models.DbState import io.airbyte.cdk.integrations.source.relationaldb.models.DbStreamState import io.airbyte.commons.json.Jsons import io.airbyte.protocol.models.v0.* +import java.util.* +import java.util.stream.Collectors import org.junit.jupiter.api.Assertions import org.junit.jupiter.api.Test import org.mockito.Mockito -import java.util.* -import java.util.stream.Collectors -/** - * Test suite for the [StreamStateManager] class. - */ +/** Test suite for the [StreamStateManager] class. */ class StreamStateManagerTest { @Test fun testCreationFromInvalidState() { - val airbyteStateMessage = AirbyteStateMessage() + val airbyteStateMessage = + AirbyteStateMessage() .withType(AirbyteStateMessage.AirbyteStateType.STREAM) - .withStream(AirbyteStreamState() - .withStreamDescriptor(StreamDescriptor().withName(StateTestConstants.STREAM_NAME1).withNamespace(StateTestConstants.NAMESPACE)) - .withStreamState(Jsons.jsonNode("Not a state object"))) + .withStream( + AirbyteStreamState() + .withStreamDescriptor( + StreamDescriptor() + .withName(StateTestConstants.STREAM_NAME1) + .withNamespace(StateTestConstants.NAMESPACE) + ) + .withStreamState(Jsons.jsonNode("Not a state object")) + ) val catalog = Mockito.mock(ConfiguredAirbyteCatalog::class.java) Assertions.assertDoesNotThrow { - val stateManager: StateManager = StreamStateManager(java.util.List.of(airbyteStateMessage), catalog) + val stateManager: StateManager = + StreamStateManager(java.util.List.of(airbyteStateMessage), catalog) Assertions.assertNotNull(stateManager) } } @@ -36,198 +42,411 @@ class StreamStateManagerTest { @Test fun testGetters() { val state: MutableList = ArrayList() - state.add(createStreamState(StateTestConstants.STREAM_NAME1, StateTestConstants.NAMESPACE, java.util.List.of(StateTestConstants.CURSOR_FIELD1), StateTestConstants.CURSOR, 0L)) - state.add(createStreamState(StateTestConstants.STREAM_NAME2, StateTestConstants.NAMESPACE, listOf(), null, 0L)) + state.add( + createStreamState( + StateTestConstants.STREAM_NAME1, + StateTestConstants.NAMESPACE, + java.util.List.of(StateTestConstants.CURSOR_FIELD1), + StateTestConstants.CURSOR, + 0L + ) + ) + state.add( + createStreamState( + StateTestConstants.STREAM_NAME2, + StateTestConstants.NAMESPACE, + listOf(), + null, + 0L + ) + ) - val catalog = ConfiguredAirbyteCatalog() - .withStreams(java.util.List.of( + val catalog = + ConfiguredAirbyteCatalog() + .withStreams( + java.util.List.of( ConfiguredAirbyteStream() - .withStream(AirbyteStream().withName(StateTestConstants.STREAM_NAME1).withNamespace(StateTestConstants.NAMESPACE) - .withSupportedSyncModes(Lists.newArrayList(SyncMode.FULL_REFRESH))) - .withCursorField(java.util.List.of(StateTestConstants.CURSOR_FIELD1)), + .withStream( + AirbyteStream() + .withName(StateTestConstants.STREAM_NAME1) + .withNamespace(StateTestConstants.NAMESPACE) + .withSupportedSyncModes( + Lists.newArrayList(SyncMode.FULL_REFRESH) + ) + ) + .withCursorField(java.util.List.of(StateTestConstants.CURSOR_FIELD1)), ConfiguredAirbyteStream() - .withStream(AirbyteStream().withName(StateTestConstants.STREAM_NAME2).withNamespace(StateTestConstants.NAMESPACE) - .withSupportedSyncModes(Lists.newArrayList(SyncMode.FULL_REFRESH))))) + .withStream( + AirbyteStream() + .withName(StateTestConstants.STREAM_NAME2) + .withNamespace(StateTestConstants.NAMESPACE) + .withSupportedSyncModes( + Lists.newArrayList(SyncMode.FULL_REFRESH) + ) + ) + ) + ) val stateManager: StateManager = StreamStateManager(state, catalog) - Assertions.assertEquals(Optional.of(StateTestConstants.CURSOR_FIELD1), stateManager.getOriginalCursorField(StateTestConstants.NAME_NAMESPACE_PAIR1)) - Assertions.assertEquals(Optional.of(StateTestConstants.CURSOR), stateManager.getOriginalCursor(StateTestConstants.NAME_NAMESPACE_PAIR1)) - Assertions.assertEquals(Optional.of(StateTestConstants.CURSOR_FIELD1), stateManager.getCursorField(StateTestConstants.NAME_NAMESPACE_PAIR1)) - Assertions.assertEquals(Optional.of(StateTestConstants.CURSOR), stateManager.getCursor(StateTestConstants.NAME_NAMESPACE_PAIR1)) + Assertions.assertEquals( + Optional.of(StateTestConstants.CURSOR_FIELD1), + stateManager.getOriginalCursorField(StateTestConstants.NAME_NAMESPACE_PAIR1) + ) + Assertions.assertEquals( + Optional.of(StateTestConstants.CURSOR), + stateManager.getOriginalCursor(StateTestConstants.NAME_NAMESPACE_PAIR1) + ) + Assertions.assertEquals( + Optional.of(StateTestConstants.CURSOR_FIELD1), + stateManager.getCursorField(StateTestConstants.NAME_NAMESPACE_PAIR1) + ) + Assertions.assertEquals( + Optional.of(StateTestConstants.CURSOR), + stateManager.getCursor(StateTestConstants.NAME_NAMESPACE_PAIR1) + ) - Assertions.assertEquals(Optional.empty(), stateManager.getOriginalCursorField(StateTestConstants.NAME_NAMESPACE_PAIR2)) - Assertions.assertEquals(Optional.empty(), stateManager.getOriginalCursor(StateTestConstants.NAME_NAMESPACE_PAIR2)) - Assertions.assertEquals(Optional.empty(), stateManager.getCursorField(StateTestConstants.NAME_NAMESPACE_PAIR2)) - Assertions.assertEquals(Optional.empty(), stateManager.getCursor(StateTestConstants.NAME_NAMESPACE_PAIR2)) + Assertions.assertEquals( + Optional.empty(), + stateManager.getOriginalCursorField(StateTestConstants.NAME_NAMESPACE_PAIR2) + ) + Assertions.assertEquals( + Optional.empty(), + stateManager.getOriginalCursor(StateTestConstants.NAME_NAMESPACE_PAIR2) + ) + Assertions.assertEquals( + Optional.empty(), + stateManager.getCursorField(StateTestConstants.NAME_NAMESPACE_PAIR2) + ) + Assertions.assertEquals( + Optional.empty(), + stateManager.getCursor(StateTestConstants.NAME_NAMESPACE_PAIR2) + ) } @Test fun testToState() { - val catalog = ConfiguredAirbyteCatalog() - .withStreams(java.util.List.of( + val catalog = + ConfiguredAirbyteCatalog() + .withStreams( + java.util.List.of( ConfiguredAirbyteStream() - .withStream(AirbyteStream().withName(StateTestConstants.STREAM_NAME1).withNamespace(StateTestConstants.NAMESPACE) - .withSupportedSyncModes(Lists.newArrayList(SyncMode.FULL_REFRESH))) - .withCursorField(java.util.List.of(StateTestConstants.CURSOR_FIELD1)), + .withStream( + AirbyteStream() + .withName(StateTestConstants.STREAM_NAME1) + .withNamespace(StateTestConstants.NAMESPACE) + .withSupportedSyncModes( + Lists.newArrayList(SyncMode.FULL_REFRESH) + ) + ) + .withCursorField(java.util.List.of(StateTestConstants.CURSOR_FIELD1)), ConfiguredAirbyteStream() - .withStream(AirbyteStream().withName(StateTestConstants.STREAM_NAME2).withNamespace(StateTestConstants.NAMESPACE) - .withSupportedSyncModes(Lists.newArrayList(SyncMode.FULL_REFRESH))) - .withCursorField(java.util.List.of(StateTestConstants.CURSOR_FIELD2)), + .withStream( + AirbyteStream() + .withName(StateTestConstants.STREAM_NAME2) + .withNamespace(StateTestConstants.NAMESPACE) + .withSupportedSyncModes( + Lists.newArrayList(SyncMode.FULL_REFRESH) + ) + ) + .withCursorField(java.util.List.of(StateTestConstants.CURSOR_FIELD2)), ConfiguredAirbyteStream() - .withStream(AirbyteStream().withName(StateTestConstants.STREAM_NAME3).withNamespace(StateTestConstants.NAMESPACE) - .withSupportedSyncModes(Lists.newArrayList(SyncMode.FULL_REFRESH))))) + .withStream( + AirbyteStream() + .withName(StateTestConstants.STREAM_NAME3) + .withNamespace(StateTestConstants.NAMESPACE) + .withSupportedSyncModes( + Lists.newArrayList(SyncMode.FULL_REFRESH) + ) + ) + ) + ) val stateManager: StateManager = StreamStateManager(createDefaultState(), catalog) - val expectedFirstDbState = DbState() + val expectedFirstDbState = + DbState() .withCdc(false) - .withStreams(java.util.List.of( - DbStreamState() + .withStreams( + java.util.List.of( + DbStreamState() .withStreamName(StateTestConstants.STREAM_NAME1) .withStreamNamespace(StateTestConstants.NAMESPACE) - .withCursorField(java.util.List.of(StateTestConstants.CURSOR_FIELD1)) + .withCursorField( + java.util.List.of(StateTestConstants.CURSOR_FIELD1) + ) .withCursor("a"), - DbStreamState() + DbStreamState() .withStreamName(StateTestConstants.STREAM_NAME2) .withStreamNamespace(StateTestConstants.NAMESPACE) - .withCursorField(java.util.List.of(StateTestConstants.CURSOR_FIELD2)), - DbStreamState() + .withCursorField( + java.util.List.of(StateTestConstants.CURSOR_FIELD2) + ), + DbStreamState() .withStreamName(StateTestConstants.STREAM_NAME3) - .withStreamNamespace(StateTestConstants.NAMESPACE)) - .stream().sorted(Comparator.comparing { obj: DbStreamState -> obj.streamName }).collect(Collectors.toList())) + .withStreamNamespace(StateTestConstants.NAMESPACE) + ) + .stream() + .sorted(Comparator.comparing { obj: DbStreamState -> obj.streamName }) + .collect(Collectors.toList()) + ) val expectedFirstEmission = - createStreamState(StateTestConstants.STREAM_NAME1, StateTestConstants.NAMESPACE, java.util.List.of(StateTestConstants.CURSOR_FIELD1), "a", 0L).withData(Jsons.jsonNode(expectedFirstDbState)) + createStreamState( + StateTestConstants.STREAM_NAME1, + StateTestConstants.NAMESPACE, + java.util.List.of(StateTestConstants.CURSOR_FIELD1), + "a", + 0L + ) + .withData(Jsons.jsonNode(expectedFirstDbState)) - val actualFirstEmission = stateManager.updateAndEmit(StateTestConstants.NAME_NAMESPACE_PAIR1, "a") + val actualFirstEmission = + stateManager.updateAndEmit(StateTestConstants.NAME_NAMESPACE_PAIR1, "a") Assertions.assertEquals(expectedFirstEmission, actualFirstEmission) val expectedRecordCount = 17L - val expectedSecondDbState = DbState() + val expectedSecondDbState = + DbState() .withCdc(false) - .withStreams(java.util.List.of( - DbStreamState() + .withStreams( + java.util.List.of( + DbStreamState() .withStreamName(StateTestConstants.STREAM_NAME1) .withStreamNamespace(StateTestConstants.NAMESPACE) - .withCursorField(java.util.List.of(StateTestConstants.CURSOR_FIELD1)) + .withCursorField( + java.util.List.of(StateTestConstants.CURSOR_FIELD1) + ) .withCursor("a"), - DbStreamState() + DbStreamState() .withStreamName(StateTestConstants.STREAM_NAME2) .withStreamNamespace(StateTestConstants.NAMESPACE) - .withCursorField(java.util.List.of(StateTestConstants.CURSOR_FIELD2)) + .withCursorField( + java.util.List.of(StateTestConstants.CURSOR_FIELD2) + ) .withCursor("b") .withCursorRecordCount(expectedRecordCount), - DbStreamState() + DbStreamState() .withStreamName(StateTestConstants.STREAM_NAME3) - .withStreamNamespace(StateTestConstants.NAMESPACE)) - .stream().sorted(Comparator.comparing { obj: DbStreamState -> obj.streamName }).collect(Collectors.toList())) + .withStreamNamespace(StateTestConstants.NAMESPACE) + ) + .stream() + .sorted(Comparator.comparing { obj: DbStreamState -> obj.streamName }) + .collect(Collectors.toList()) + ) val expectedSecondEmission = - createStreamState(StateTestConstants.STREAM_NAME2, StateTestConstants.NAMESPACE, java.util.List.of(StateTestConstants.CURSOR_FIELD2), "b", expectedRecordCount).withData(Jsons.jsonNode(expectedSecondDbState)) + createStreamState( + StateTestConstants.STREAM_NAME2, + StateTestConstants.NAMESPACE, + java.util.List.of(StateTestConstants.CURSOR_FIELD2), + "b", + expectedRecordCount + ) + .withData(Jsons.jsonNode(expectedSecondDbState)) - val actualSecondEmission = stateManager.updateAndEmit(StateTestConstants.NAME_NAMESPACE_PAIR2, "b", expectedRecordCount) + val actualSecondEmission = + stateManager.updateAndEmit( + StateTestConstants.NAME_NAMESPACE_PAIR2, + "b", + expectedRecordCount + ) Assertions.assertEquals(expectedSecondEmission, actualSecondEmission) } @Test fun testToStateWithoutCursorInfo() { - val catalog = ConfiguredAirbyteCatalog() - .withStreams(java.util.List.of( + val catalog = + ConfiguredAirbyteCatalog() + .withStreams( + java.util.List.of( ConfiguredAirbyteStream() - .withStream(AirbyteStream().withName(StateTestConstants.STREAM_NAME1).withNamespace(StateTestConstants.NAMESPACE) - .withSupportedSyncModes(Lists.newArrayList(SyncMode.FULL_REFRESH))) - .withCursorField(java.util.List.of(StateTestConstants.CURSOR_FIELD1)), + .withStream( + AirbyteStream() + .withName(StateTestConstants.STREAM_NAME1) + .withNamespace(StateTestConstants.NAMESPACE) + .withSupportedSyncModes( + Lists.newArrayList(SyncMode.FULL_REFRESH) + ) + ) + .withCursorField(java.util.List.of(StateTestConstants.CURSOR_FIELD1)), ConfiguredAirbyteStream() - .withStream(AirbyteStream().withName(StateTestConstants.STREAM_NAME2).withNamespace(StateTestConstants.NAMESPACE) - .withSupportedSyncModes(Lists.newArrayList(SyncMode.FULL_REFRESH))) - .withCursorField(java.util.List.of(StateTestConstants.CURSOR_FIELD2)), + .withStream( + AirbyteStream() + .withName(StateTestConstants.STREAM_NAME2) + .withNamespace(StateTestConstants.NAMESPACE) + .withSupportedSyncModes( + Lists.newArrayList(SyncMode.FULL_REFRESH) + ) + ) + .withCursorField(java.util.List.of(StateTestConstants.CURSOR_FIELD2)), ConfiguredAirbyteStream() - .withStream(AirbyteStream().withName(StateTestConstants.STREAM_NAME3).withNamespace(StateTestConstants.NAMESPACE) - .withSupportedSyncModes(Lists.newArrayList(SyncMode.FULL_REFRESH))))) + .withStream( + AirbyteStream() + .withName(StateTestConstants.STREAM_NAME3) + .withNamespace(StateTestConstants.NAMESPACE) + .withSupportedSyncModes( + Lists.newArrayList(SyncMode.FULL_REFRESH) + ) + ) + ) + ) val airbyteStreamNameNamespacePair = AirbyteStreamNameNamespacePair("other", "other") val stateManager: StateManager = StreamStateManager(createDefaultState(), catalog) val airbyteStateMessage = stateManager.toState(Optional.of(airbyteStreamNameNamespacePair)) Assertions.assertNotNull(airbyteStateMessage) - Assertions.assertEquals(AirbyteStateMessage.AirbyteStateType.STREAM, airbyteStateMessage.type) + Assertions.assertEquals( + AirbyteStateMessage.AirbyteStateType.STREAM, + airbyteStateMessage.type + ) Assertions.assertNotNull(airbyteStateMessage.stream) } @Test fun testToStateWithoutStreamPair() { - val catalog = ConfiguredAirbyteCatalog() - .withStreams(java.util.List.of( + val catalog = + ConfiguredAirbyteCatalog() + .withStreams( + java.util.List.of( ConfiguredAirbyteStream() - .withStream(AirbyteStream().withName(StateTestConstants.STREAM_NAME1).withNamespace(StateTestConstants.NAMESPACE) - .withSupportedSyncModes(Lists.newArrayList(SyncMode.FULL_REFRESH))) - .withCursorField(java.util.List.of(StateTestConstants.CURSOR_FIELD1)), + .withStream( + AirbyteStream() + .withName(StateTestConstants.STREAM_NAME1) + .withNamespace(StateTestConstants.NAMESPACE) + .withSupportedSyncModes( + Lists.newArrayList(SyncMode.FULL_REFRESH) + ) + ) + .withCursorField(java.util.List.of(StateTestConstants.CURSOR_FIELD1)), ConfiguredAirbyteStream() - .withStream(AirbyteStream().withName(StateTestConstants.STREAM_NAME2).withNamespace(StateTestConstants.NAMESPACE) - .withSupportedSyncModes(Lists.newArrayList(SyncMode.FULL_REFRESH))) - .withCursorField(java.util.List.of(StateTestConstants.CURSOR_FIELD2)), + .withStream( + AirbyteStream() + .withName(StateTestConstants.STREAM_NAME2) + .withNamespace(StateTestConstants.NAMESPACE) + .withSupportedSyncModes( + Lists.newArrayList(SyncMode.FULL_REFRESH) + ) + ) + .withCursorField(java.util.List.of(StateTestConstants.CURSOR_FIELD2)), ConfiguredAirbyteStream() - .withStream(AirbyteStream().withName(StateTestConstants.STREAM_NAME3).withNamespace(StateTestConstants.NAMESPACE) - .withSupportedSyncModes(Lists.newArrayList(SyncMode.FULL_REFRESH))))) + .withStream( + AirbyteStream() + .withName(StateTestConstants.STREAM_NAME3) + .withNamespace(StateTestConstants.NAMESPACE) + .withSupportedSyncModes( + Lists.newArrayList(SyncMode.FULL_REFRESH) + ) + ) + ) + ) val stateManager: StateManager = StreamStateManager(createDefaultState(), catalog) val airbyteStateMessage = stateManager.toState(Optional.empty()) Assertions.assertNotNull(airbyteStateMessage) - Assertions.assertEquals(AirbyteStateMessage.AirbyteStateType.STREAM, airbyteStateMessage.type) + Assertions.assertEquals( + AirbyteStateMessage.AirbyteStateType.STREAM, + airbyteStateMessage.type + ) Assertions.assertNotNull(airbyteStateMessage.stream) Assertions.assertNull(airbyteStateMessage.stream.streamState) } @Test fun testToStateNullCursorField() { - val catalog = ConfiguredAirbyteCatalog() - .withStreams(java.util.List.of( + val catalog = + ConfiguredAirbyteCatalog() + .withStreams( + java.util.List.of( ConfiguredAirbyteStream() - .withStream(AirbyteStream().withName(StateTestConstants.STREAM_NAME1).withNamespace(StateTestConstants.NAMESPACE) - .withSupportedSyncModes(Lists.newArrayList(SyncMode.FULL_REFRESH))) - .withCursorField(java.util.List.of(StateTestConstants.CURSOR_FIELD1)), + .withStream( + AirbyteStream() + .withName(StateTestConstants.STREAM_NAME1) + .withNamespace(StateTestConstants.NAMESPACE) + .withSupportedSyncModes( + Lists.newArrayList(SyncMode.FULL_REFRESH) + ) + ) + .withCursorField(java.util.List.of(StateTestConstants.CURSOR_FIELD1)), ConfiguredAirbyteStream() - .withStream(AirbyteStream().withName(StateTestConstants.STREAM_NAME2).withNamespace(StateTestConstants.NAMESPACE) - .withSupportedSyncModes(Lists.newArrayList(SyncMode.FULL_REFRESH))))) + .withStream( + AirbyteStream() + .withName(StateTestConstants.STREAM_NAME2) + .withNamespace(StateTestConstants.NAMESPACE) + .withSupportedSyncModes( + Lists.newArrayList(SyncMode.FULL_REFRESH) + ) + ) + ) + ) val stateManager: StateManager = StreamStateManager(createDefaultState(), catalog) - val expectedFirstDbState = DbState() + val expectedFirstDbState = + DbState() .withCdc(false) - .withStreams(java.util.List.of( - DbStreamState() + .withStreams( + java.util.List.of( + DbStreamState() .withStreamName(StateTestConstants.STREAM_NAME1) .withStreamNamespace(StateTestConstants.NAMESPACE) - .withCursorField(java.util.List.of(StateTestConstants.CURSOR_FIELD1)) + .withCursorField( + java.util.List.of(StateTestConstants.CURSOR_FIELD1) + ) .withCursor("a"), - DbStreamState() + DbStreamState() .withStreamName(StateTestConstants.STREAM_NAME2) - .withStreamNamespace(StateTestConstants.NAMESPACE)) - .stream().sorted(Comparator.comparing { obj: DbStreamState -> obj.streamName }).collect(Collectors.toList())) + .withStreamNamespace(StateTestConstants.NAMESPACE) + ) + .stream() + .sorted(Comparator.comparing { obj: DbStreamState -> obj.streamName }) + .collect(Collectors.toList()) + ) val expectedFirstEmission = - createStreamState(StateTestConstants.STREAM_NAME1, StateTestConstants.NAMESPACE, java.util.List.of(StateTestConstants.CURSOR_FIELD1), "a", 0L).withData(Jsons.jsonNode(expectedFirstDbState)) - val actualFirstEmission = stateManager.updateAndEmit(StateTestConstants.NAME_NAMESPACE_PAIR1, "a") + createStreamState( + StateTestConstants.STREAM_NAME1, + StateTestConstants.NAMESPACE, + java.util.List.of(StateTestConstants.CURSOR_FIELD1), + "a", + 0L + ) + .withData(Jsons.jsonNode(expectedFirstDbState)) + val actualFirstEmission = + stateManager.updateAndEmit(StateTestConstants.NAME_NAMESPACE_PAIR1, "a") Assertions.assertEquals(expectedFirstEmission, actualFirstEmission) } @Test fun testCdcStateManager() { val catalog = Mockito.mock(ConfiguredAirbyteCatalog::class.java) - val stateManager: StateManager = StreamStateManager( - java.util.List.of(AirbyteStateMessage().withType(AirbyteStateMessage.AirbyteStateType.STREAM).withStream(AirbyteStreamState())), catalog) - Assertions.assertThrows(UnsupportedOperationException::class.java) { stateManager.cdcStateManager } + val stateManager: StateManager = + StreamStateManager( + java.util.List.of( + AirbyteStateMessage() + .withType(AirbyteStateMessage.AirbyteStateType.STREAM) + .withStream(AirbyteStreamState()) + ), + catalog + ) + Assertions.assertThrows(UnsupportedOperationException::class.java) { + stateManager.cdcStateManager + } } private fun createDefaultState(): List { - return java.util.List.of(AirbyteStateMessage().withType(AirbyteStateMessage.AirbyteStateType.STREAM).withStream(AirbyteStreamState())) + return java.util.List.of( + AirbyteStateMessage() + .withType(AirbyteStateMessage.AirbyteStateType.STREAM) + .withStream(AirbyteStreamState()) + ) } - private fun createStreamState(name: String?, - namespace: String?, - cursorFields: List?, - cursorValue: String?, - cursorRecordCount: Long): AirbyteStateMessage { - val dbStreamState = DbStreamState() - .withStreamName(name) - .withStreamNamespace(namespace) + private fun createStreamState( + name: String?, + namespace: String?, + cursorFields: List?, + cursorValue: String?, + cursorRecordCount: Long + ): AirbyteStateMessage { + val dbStreamState = DbStreamState().withStreamName(name).withStreamNamespace(namespace) if (cursorFields != null && !cursorFields.isEmpty()) { dbStreamState.withCursorField(cursorFields) @@ -242,9 +461,13 @@ class StreamStateManagerTest { } return AirbyteStateMessage() - .withType(AirbyteStateMessage.AirbyteStateType.STREAM) - .withStream(AirbyteStreamState() - .withStreamDescriptor(StreamDescriptor().withName(name).withNamespace(namespace)) - .withStreamState(Jsons.jsonNode(dbStreamState))) + .withType(AirbyteStateMessage.AirbyteStateType.STREAM) + .withStream( + AirbyteStreamState() + .withStreamDescriptor( + StreamDescriptor().withName(name).withNamespace(namespace) + ) + .withStreamState(Jsons.jsonNode(dbStreamState)) + ) } } diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/test/utils/DatabaseConnectionHelperTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/test/utils/DatabaseConnectionHelperTest.kt index f9d93767af0d8..e0759df756098 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/test/utils/DatabaseConnectionHelperTest.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/test/kotlin/io/airbyte/cdk/test/utils/DatabaseConnectionHelperTest.kt @@ -19,7 +19,10 @@ internal class DatabaseConnectionHelperTest { val dataSource = createDataSource(container) Assertions.assertNotNull(dataSource) Assertions.assertEquals(HikariDataSource::class.java, dataSource!!.javaClass) - Assertions.assertEquals(10, (dataSource as HikariDataSource?)!!.hikariConfigMXBean.maximumPoolSize) + Assertions.assertEquals( + 10, + (dataSource as HikariDataSource?)!!.hikariConfigMXBean.maximumPoolSize + ) } @Test @@ -36,8 +39,10 @@ internal class DatabaseConnectionHelperTest { protected var container: PostgreSQLContainer<*>? = null @BeforeAll + @JvmStatic fun dbSetup() { - container = PostgreSQLContainer("postgres:13-alpine") + container = + PostgreSQLContainer("postgres:13-alpine") .withDatabaseName(DATABASE_NAME) .withUsername("docker") .withPassword("docker") @@ -45,6 +50,7 @@ internal class DatabaseConnectionHelperTest { } @AfterAll + @JvmStatic fun dbDown() { container!!.close() } diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/debezium/CdcSourceTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/debezium/CdcSourceTest.kt index 539539453cec9..9afef4bfcd9fd 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/debezium/CdcSourceTest.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/debezium/CdcSourceTest.kt @@ -13,19 +13,19 @@ import io.airbyte.commons.util.AutoCloseableIterators import io.airbyte.protocol.models.Field import io.airbyte.protocol.models.JsonSchemaType import io.airbyte.protocol.models.v0.* +import java.util.* +import java.util.function.Consumer +import java.util.stream.Collectors +import java.util.stream.Stream import org.junit.jupiter.api.AfterEach import org.junit.jupiter.api.Assertions import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Test import org.slf4j.Logger import org.slf4j.LoggerFactory -import java.util.* -import java.util.function.Consumer -import java.util.stream.Collectors -import java.util.stream.Stream -abstract class CdcSourceTest?> { - protected var testdb: T? = null +abstract class CdcSourceTest> { + protected lateinit var testdb: T protected fun createTableSqlFmt(): String { return "CREATE TABLE %s.%s(%s);" @@ -39,28 +39,38 @@ abstract class CdcSourceTest?> { return "models_schema" } - /** - * The schema of a random table which is used as a new table in snapshot test - */ + /** The schema of a random table which is used as a new table in snapshot test */ protected fun randomSchema(): String { return "models_schema_random" } protected val catalog: AirbyteCatalog - get() = AirbyteCatalog().withStreams(java.util.List.of( - CatalogHelpers.createAirbyteStream( - MODELS_STREAM_NAME, - modelsSchema(), - Field.of(COL_ID, JsonSchemaType.INTEGER), - Field.of(COL_MAKE_ID, JsonSchemaType.INTEGER), - Field.of(COL_MODEL, JsonSchemaType.STRING)) - .withSupportedSyncModes(Lists.newArrayList(SyncMode.FULL_REFRESH, SyncMode.INCREMENTAL)) - .withSourceDefinedPrimaryKey(java.util.List.of(java.util.List.of(COL_ID))))) + get() = + AirbyteCatalog() + .withStreams( + java.util.List.of( + CatalogHelpers.createAirbyteStream( + MODELS_STREAM_NAME, + modelsSchema(), + Field.of(COL_ID, JsonSchemaType.INTEGER), + Field.of(COL_MAKE_ID, JsonSchemaType.INTEGER), + Field.of(COL_MODEL, JsonSchemaType.STRING) + ) + .withSupportedSyncModes( + Lists.newArrayList(SyncMode.FULL_REFRESH, SyncMode.INCREMENTAL) + ) + .withSourceDefinedPrimaryKey( + java.util.List.of(java.util.List.of(COL_ID)) + ) + ) + ) protected val configuredCatalog: ConfiguredAirbyteCatalog get() { val configuredCatalog = CatalogHelpers.toDefaultConfiguredCatalog(catalog) - configuredCatalog.streams.forEach(Consumer { s: ConfiguredAirbyteStream -> s.syncMode = SyncMode.INCREMENTAL }) + configuredCatalog.streams.forEach( + Consumer { s: ConfiguredAirbyteStream -> s.syncMode = SyncMode.INCREMENTAL } + ) return configuredCatalog } @@ -89,7 +99,10 @@ abstract class CdcSourceTest?> { // TODO: this assertion should be added into test cases in this class, we will need to implement // corresponding iterator for other connectors before // doing so. - protected fun assertExpectedStateMessageCountMatches(stateMessages: List?, totalCount: Long) { + protected fun assertExpectedStateMessageCountMatches( + stateMessages: List?, + totalCount: Long + ) { // Do nothing. } @@ -102,25 +115,39 @@ abstract class CdcSourceTest?> { protected fun createTables() { // create and populate actual table - val actualColumns = ImmutableMap.of( - COL_ID, "INTEGER", - COL_MAKE_ID, "INTEGER", - COL_MODEL, "VARCHAR(200)") + val actualColumns = + ImmutableMap.of(COL_ID, "INTEGER", COL_MAKE_ID, "INTEGER", COL_MODEL, "VARCHAR(200)") testdb - .with(createSchemaSqlFmt(), modelsSchema()) - .with(createTableSqlFmt(), modelsSchema(), MODELS_STREAM_NAME, columnClause(actualColumns, Optional.of(COL_ID))) + .with(createSchemaSqlFmt(), modelsSchema()) + .with( + createTableSqlFmt(), + modelsSchema(), + MODELS_STREAM_NAME, + columnClause(actualColumns, Optional.of(COL_ID)) + ) // Create random table. - // This table is not part of Airbyte sync. It is being created just to make sure the schemas not + // This table is not part of Airbyte sync. It is being created just to make sure the schemas + // not // being synced by Airbyte are not causing issues with our debezium logic. - val randomColumns = ImmutableMap.of( - COL_ID + "_random", "INTEGER", - COL_MAKE_ID + "_random", "INTEGER", - COL_MODEL + "_random", "VARCHAR(200)") + val randomColumns = + ImmutableMap.of( + COL_ID + "_random", + "INTEGER", + COL_MAKE_ID + "_random", + "INTEGER", + COL_MODEL + "_random", + "VARCHAR(200)" + ) if (randomSchema() != modelsSchema()) { testdb!!.with(createSchemaSqlFmt(), randomSchema()) } - testdb!!.with(createTableSqlFmt(), randomSchema(), RANDOM_TABLE_NAME, columnClause(randomColumns, Optional.of(COL_ID + "_random"))) + testdb!!.with( + createTableSqlFmt(), + randomSchema(), + RANDOM_TABLE_NAME, + columnClause(randomColumns, Optional.of(COL_ID + "_random")) + ) } protected fun populateTables() { @@ -129,8 +156,14 @@ abstract class CdcSourceTest?> { } for (recordJson in MODEL_RECORDS_RANDOM) { - writeRecords(recordJson, randomSchema(), RANDOM_TABLE_NAME, - COL_ID + "_random", COL_MAKE_ID + "_random", COL_MODEL + "_random") + writeRecords( + recordJson, + randomSchema(), + RANDOM_TABLE_NAME, + COL_ID + "_random", + COL_MAKE_ID + "_random", + COL_MODEL + "_random" + ) } } @@ -143,7 +176,10 @@ abstract class CdcSourceTest?> { } } - protected fun columnClause(columnsWithDataType: Map, primaryKey: Optional): String { + protected fun columnClause( + columnsWithDataType: Map, + primaryKey: Optional + ): String { val columnClause = StringBuilder() var i = 0 for ((key, value) in columnsWithDataType) { @@ -156,7 +192,9 @@ abstract class CdcSourceTest?> { } i++ } - primaryKey.ifPresent { s: String? -> columnClause.append(", PRIMARY KEY (").append(s).append(")") } + primaryKey.ifPresent { s: String? -> + columnClause.append(", PRIMARY KEY (").append(s).append(")") + } return columnClause.toString() } @@ -166,16 +204,24 @@ abstract class CdcSourceTest?> { } protected fun writeRecords( - recordJson: JsonNode, - dbName: String?, - streamName: String?, - idCol: String?, - makeIdCol: String?, - modelCol: String?) { - testdb!!.with("INSERT INTO %s.%s (%s, %s, %s) VALUES (%s, %s, '%s');", dbName, streamName, - idCol, makeIdCol, modelCol, - recordJson[idCol].asInt(), recordJson[makeIdCol].asInt(), - recordJson[modelCol].asText()) + recordJson: JsonNode, + dbName: String?, + streamName: String?, + idCol: String?, + makeIdCol: String?, + modelCol: String? + ) { + testdb!!.with( + "INSERT INTO %s.%s (%s, %s, %s) VALUES (%s, %s, '%s');", + dbName, + streamName, + idCol, + makeIdCol, + modelCol, + recordJson[idCol].asInt(), + recordJson[makeIdCol].asInt(), + recordJson[modelCol].asText() + ) } protected fun deleteMessageOnIdCol(streamName: String?, idCol: String?, idValue: Int) { @@ -186,32 +232,55 @@ abstract class CdcSourceTest?> { testdb!!.with("DELETE FROM %s.%s", modelsSchema(), streamName) } - protected fun updateCommand(streamName: String?, modelCol: String?, modelVal: String?, idCol: String?, idValue: Int) { - testdb!!.with("UPDATE %s.%s SET %s = '%s' WHERE %s = %s", modelsSchema(), streamName, - modelCol, modelVal, COL_ID, 11) + protected fun updateCommand( + streamName: String?, + modelCol: String?, + modelVal: String?, + idCol: String?, + idValue: Int + ) { + testdb!!.with( + "UPDATE %s.%s SET %s = '%s' WHERE %s = %s", + modelsSchema(), + streamName, + modelCol, + modelVal, + COL_ID, + 11 + ) } protected fun extractRecordMessages(messages: List): Set { val recordsPerStream = extractRecordMessagesStreamWise(messages) val consolidatedRecords: MutableSet = HashSet() - recordsPerStream.values.forEach(Consumer { c: Set? -> consolidatedRecords.addAll(c!!) }) + recordsPerStream.values.forEach( + Consumer { c: Set? -> consolidatedRecords.addAll(c!!) } + ) return consolidatedRecords } - protected fun extractRecordMessagesStreamWise(messages: List): Map> { + protected fun extractRecordMessagesStreamWise( + messages: List + ): Map> { val recordsPerStream: MutableMap> = HashMap() for (message in messages) { if (message.type == AirbyteMessage.Type.RECORD) { val recordMessage = message.record - recordsPerStream.computeIfAbsent(recordMessage.stream) { c: String? -> ArrayList() }.add(recordMessage) + recordsPerStream + .computeIfAbsent(recordMessage.stream) { c: String? -> ArrayList() } + .add(recordMessage) } } - val recordsPerStreamWithNoDuplicates: MutableMap> = HashMap() + val recordsPerStreamWithNoDuplicates: MutableMap> = + HashMap() for ((streamName, records) in recordsPerStream) { val recordMessageSet: Set = HashSet(records) - Assertions.assertEquals(records.size, recordMessageSet.size, - "Expected no duplicates in airbyte record message output for a single sync.") + Assertions.assertEquals( + records.size, + recordMessageSet.size, + "Expected no duplicates in airbyte record message output for a single sync." + ) recordsPerStreamWithNoDuplicates[streamName] = recordMessageSet } @@ -219,27 +288,51 @@ abstract class CdcSourceTest?> { } protected fun extractStateMessages(messages: List): List { - return messages.stream().filter { r: AirbyteMessage -> r.type == AirbyteMessage.Type.STATE }.map { obj: AirbyteMessage -> obj.state } - .collect(Collectors.toList()) + return messages + .stream() + .filter { r: AirbyteMessage -> r.type == AirbyteMessage.Type.STATE } + .map { obj: AirbyteMessage -> obj.state } + .collect(Collectors.toList()) } - protected fun assertExpectedRecords(expectedRecords: Set, actualRecords: Set) { + protected fun assertExpectedRecords( + expectedRecords: Set, + actualRecords: Set + ) { // assume all streams are cdc. - assertExpectedRecords(expectedRecords, actualRecords, actualRecords.stream().map { obj: AirbyteRecordMessage -> obj.stream }.collect(Collectors.toSet())) + assertExpectedRecords( + expectedRecords, + actualRecords, + actualRecords + .stream() + .map { obj: AirbyteRecordMessage -> obj.stream } + .collect(Collectors.toSet()) + ) } - private fun assertExpectedRecords(expectedRecords: Set, - actualRecords: Set, - cdcStreams: Set) { - assertExpectedRecords(expectedRecords, actualRecords, cdcStreams, STREAM_NAMES, modelsSchema()) + private fun assertExpectedRecords( + expectedRecords: Set, + actualRecords: Set, + cdcStreams: Set + ) { + assertExpectedRecords( + expectedRecords, + actualRecords, + cdcStreams, + STREAM_NAMES, + modelsSchema() + ) } - protected fun assertExpectedRecords(expectedRecords: Set?, - actualRecords: Set, - cdcStreams: Set, - streamNames: Set, - namespace: String?) { - val actualData = actualRecords + protected fun assertExpectedRecords( + expectedRecords: Set?, + actualRecords: Set, + cdcStreams: Set, + streamNames: Set, + namespace: String? + ) { + val actualData = + actualRecords .stream() .map { recordMessage: AirbyteRecordMessage -> Assertions.assertTrue(streamNames.contains(recordMessage.stream)) @@ -274,25 +367,31 @@ abstract class CdcSourceTest?> { val stateMessages = extractStateMessages(actualRecords) Assertions.assertNotNull(targetPosition) - recordMessages.forEach(Consumer { record: AirbyteRecordMessage -> - compareTargetPositionFromTheRecordsWithTargetPostionGeneratedBeforeSync(targetPosition, record) - }) + recordMessages.forEach( + Consumer { record: AirbyteRecordMessage -> + compareTargetPositionFromTheRecordsWithTargetPostionGeneratedBeforeSync( + targetPosition, + record + ) + } + ) assertExpectedRecords(HashSet(MODEL_RECORDS), recordMessages) assertExpectedStateMessages(stateMessages) assertExpectedStateMessageCountMatches(stateMessages, MODEL_RECORDS.size.toLong()) } - protected fun compareTargetPositionFromTheRecordsWithTargetPostionGeneratedBeforeSync(targetPosition: CdcTargetPosition<*>?, - record: AirbyteRecordMessage) { + protected fun compareTargetPositionFromTheRecordsWithTargetPostionGeneratedBeforeSync( + targetPosition: CdcTargetPosition<*>?, + record: AirbyteRecordMessage + ) { Assertions.assertEquals(extractPosition(record.data), targetPosition) } @Test // When a record is deleted, produces a deletion record. @Throws(Exception::class) fun testDelete() { - val read1 = source() - .read(config()!!, configuredCatalog, null) + val read1 = source().read(config()!!, configuredCatalog, null) val actualRecords1 = AutoCloseableIterators.toListAndClose(read1) val stateMessages1 = extractStateMessages(actualRecords1) assertExpectedStateMessages(stateMessages1) @@ -301,11 +400,10 @@ abstract class CdcSourceTest?> { waitForCdcRecords(modelsSchema(), MODELS_STREAM_NAME, 1) val state = Jsons.jsonNode(listOf(stateMessages1[stateMessages1.size - 1])) - val read2 = source() - .read(config()!!, configuredCatalog, state) + val read2 = source().read(config()!!, configuredCatalog, state) val actualRecords2 = AutoCloseableIterators.toListAndClose(read2) - val recordMessages2: List = ArrayList( - extractRecordMessages(actualRecords2)) + val recordMessages2: List = + ArrayList(extractRecordMessages(actualRecords2)) val stateMessages2 = extractStateMessages(actualRecords2) assertExpectedStateMessagesFromIncrementalSync(stateMessages2) assertExpectedStateMessageCountMatches(stateMessages2, 1) @@ -314,7 +412,9 @@ abstract class CdcSourceTest?> { assertCdcMetaData(recordMessages2[0].data, false) } - protected fun assertExpectedStateMessagesFromIncrementalSync(stateMessages: List?) { + protected fun assertExpectedStateMessagesFromIncrementalSync( + stateMessages: List? + ) { assertExpectedStateMessages(stateMessages) } @@ -322,8 +422,7 @@ abstract class CdcSourceTest?> { @Throws(Exception::class) fun testUpdate() { val updatedModel = "Explorer" - val read1 = source() - .read(config()!!, configuredCatalog, null) + val read1 = source().read(config()!!, configuredCatalog, null) val actualRecords1 = AutoCloseableIterators.toListAndClose(read1) val stateMessages1 = extractStateMessages(actualRecords1) assertExpectedStateMessages(stateMessages1) @@ -332,11 +431,10 @@ abstract class CdcSourceTest?> { waitForCdcRecords(modelsSchema(), MODELS_STREAM_NAME, 1) val state = Jsons.jsonNode(listOf(stateMessages1[stateMessages1.size - 1])) - val read2 = source() - .read(config()!!, configuredCatalog, state) + val read2 = source().read(config()!!, configuredCatalog, state) val actualRecords2 = AutoCloseableIterators.toListAndClose(read2) - val recordMessages2: List = ArrayList( - extractRecordMessages(actualRecords2)) + val recordMessages2: List = + ArrayList(extractRecordMessages(actualRecords2)) val stateMessages2 = extractStateMessages(actualRecords2) assertExpectedStateMessagesFromIncrementalSync(stateMessages2) Assertions.assertEquals(1, recordMessages2.size) @@ -346,7 +444,8 @@ abstract class CdcSourceTest?> { assertExpectedStateMessageCountMatches(stateMessages2, 1) } - @Test // Verify that when data is inserted into the database while a sync is happening and after the first + @Test // Verify that when data is inserted into the database while a sync is happening and after + // the first // sync, it all gets replicated. @Throws(Exception::class) protected fun testRecordsProducedDuringAndAfterSync() { @@ -357,31 +456,42 @@ abstract class CdcSourceTest?> { // first batch of records. 20 created here and 6 created in setup method. for (recordsCreated in 0 until recordsToCreate) { val record = - Jsons.jsonNode(ImmutableMap - .of(COL_ID, 100 + recordsCreated, COL_MAKE_ID, 1, COL_MODEL, - "F-$recordsCreated")) + Jsons.jsonNode( + ImmutableMap.of( + COL_ID, + 100 + recordsCreated, + COL_MAKE_ID, + 1, + COL_MODEL, + "F-$recordsCreated" + ) + ) writeModelRecord(record) expectedRecords++ expectedRecordsInCdc++ } waitForCdcRecords(modelsSchema(), MODELS_STREAM_NAME, expectedRecordsInCdc) - val firstBatchIterator = source() - .read(config()!!, configuredCatalog, null) - val dataFromFirstBatch = AutoCloseableIterators - .toListAndClose(firstBatchIterator) + val firstBatchIterator = source().read(config()!!, configuredCatalog, null) + val dataFromFirstBatch = AutoCloseableIterators.toListAndClose(firstBatchIterator) val stateAfterFirstBatch = extractStateMessages(dataFromFirstBatch) assertExpectedStateMessagesForRecordsProducedDuringAndAfterSync(stateAfterFirstBatch) - val recordsFromFirstBatch = extractRecordMessages( - dataFromFirstBatch) + val recordsFromFirstBatch = extractRecordMessages(dataFromFirstBatch) Assertions.assertEquals(expectedRecords, recordsFromFirstBatch.size) // second batch of records again 20 being created for (recordsCreated in 0 until recordsToCreate) { val record = - Jsons.jsonNode(ImmutableMap - .of(COL_ID, 200 + recordsCreated, COL_MAKE_ID, 1, COL_MODEL, - "F-$recordsCreated")) + Jsons.jsonNode( + ImmutableMap.of( + COL_ID, + 200 + recordsCreated, + COL_MAKE_ID, + 1, + COL_MODEL, + "F-$recordsCreated" + ) + ) writeModelRecord(record) expectedRecords++ expectedRecordsInCdc++ @@ -389,76 +499,101 @@ abstract class CdcSourceTest?> { waitForCdcRecords(modelsSchema(), MODELS_STREAM_NAME, expectedRecordsInCdc) val state = Jsons.jsonNode(listOf(stateAfterFirstBatch[stateAfterFirstBatch.size - 1])) - val secondBatchIterator = source() - .read(config()!!, configuredCatalog, state) - val dataFromSecondBatch = AutoCloseableIterators - .toListAndClose(secondBatchIterator) + val secondBatchIterator = source().read(config()!!, configuredCatalog, state) + val dataFromSecondBatch = AutoCloseableIterators.toListAndClose(secondBatchIterator) val stateAfterSecondBatch = extractStateMessages(dataFromSecondBatch) assertExpectedStateMessagesFromIncrementalSync(stateAfterSecondBatch) - val recordsFromSecondBatch = extractRecordMessages( - dataFromSecondBatch) - Assertions.assertEquals(recordsToCreate, recordsFromSecondBatch.size, - "Expected 20 records to be replicated in the second sync.") + val recordsFromSecondBatch = extractRecordMessages(dataFromSecondBatch) + Assertions.assertEquals( + recordsToCreate, + recordsFromSecondBatch.size, + "Expected 20 records to be replicated in the second sync." + ) - // sometimes there can be more than one of these at the end of the snapshot and just before the + // sometimes there can be more than one of these at the end of the snapshot and just before + // the // first incremental. - val recordsFromFirstBatchWithoutDuplicates = removeDuplicates( - recordsFromFirstBatch) - val recordsFromSecondBatchWithoutDuplicates = removeDuplicates( - recordsFromSecondBatch) + val recordsFromFirstBatchWithoutDuplicates = removeDuplicates(recordsFromFirstBatch) + val recordsFromSecondBatchWithoutDuplicates = removeDuplicates(recordsFromSecondBatch) - Assertions.assertTrue(recordsCreatedBeforeTestCount < recordsFromFirstBatchWithoutDuplicates.size, - "Expected first sync to include records created while the test was running.") - Assertions.assertEquals(expectedRecords, - recordsFromFirstBatchWithoutDuplicates.size + recordsFromSecondBatchWithoutDuplicates - .size) + Assertions.assertTrue( + recordsCreatedBeforeTestCount < recordsFromFirstBatchWithoutDuplicates.size, + "Expected first sync to include records created while the test was running." + ) + Assertions.assertEquals( + expectedRecords, + recordsFromFirstBatchWithoutDuplicates.size + + recordsFromSecondBatchWithoutDuplicates.size + ) } - protected fun assertExpectedStateMessagesForRecordsProducedDuringAndAfterSync(stateAfterFirstBatch: List?) { + protected fun assertExpectedStateMessagesForRecordsProducedDuringAndAfterSync( + stateAfterFirstBatch: List? + ) { assertExpectedStateMessages(stateAfterFirstBatch) } - @Test // When both incremental CDC and full refresh are configured for different streams in a sync, the + @Test // When both incremental CDC and full refresh are configured for different streams in a + // sync, the // data is replicated as expected. @Throws(Exception::class) fun testCdcAndFullRefreshInSameSync() { val configuredCatalog = Jsons.clone(configuredCatalog) - val MODEL_RECORDS_2: List = ImmutableList.of( + val MODEL_RECORDS_2: List = + ImmutableList.of( Jsons.jsonNode(ImmutableMap.of(COL_ID, 110, COL_MAKE_ID, 1, COL_MODEL, "Fiesta-2")), Jsons.jsonNode(ImmutableMap.of(COL_ID, 120, COL_MAKE_ID, 1, COL_MODEL, "Focus-2")), Jsons.jsonNode(ImmutableMap.of(COL_ID, 130, COL_MAKE_ID, 1, COL_MODEL, "Ranger-2")), Jsons.jsonNode(ImmutableMap.of(COL_ID, 140, COL_MAKE_ID, 2, COL_MODEL, "GLA-2")), Jsons.jsonNode(ImmutableMap.of(COL_ID, 150, COL_MAKE_ID, 2, COL_MODEL, "A 220-2")), - Jsons.jsonNode(ImmutableMap.of(COL_ID, 160, COL_MAKE_ID, 2, COL_MODEL, "E 350-2"))) - - val columns = ImmutableMap.of(COL_ID, "INTEGER", COL_MAKE_ID, "INTEGER", COL_MODEL, "VARCHAR(200)") - testdb!!.with(createTableSqlFmt(), modelsSchema(), MODELS_STREAM_NAME + "_2", columnClause(columns, Optional.of(COL_ID))) + Jsons.jsonNode(ImmutableMap.of(COL_ID, 160, COL_MAKE_ID, 2, COL_MODEL, "E 350-2")) + ) + + val columns = + ImmutableMap.of(COL_ID, "INTEGER", COL_MAKE_ID, "INTEGER", COL_MODEL, "VARCHAR(200)") + testdb!!.with( + createTableSqlFmt(), + modelsSchema(), + MODELS_STREAM_NAME + "_2", + columnClause(columns, Optional.of(COL_ID)) + ) for (recordJson in MODEL_RECORDS_2) { - writeRecords(recordJson, modelsSchema(), MODELS_STREAM_NAME + "_2", COL_ID, COL_MAKE_ID, COL_MODEL) + writeRecords( + recordJson, + modelsSchema(), + MODELS_STREAM_NAME + "_2", + COL_ID, + COL_MAKE_ID, + COL_MODEL + ) } - val airbyteStream = ConfiguredAirbyteStream() - .withStream(CatalogHelpers.createAirbyteStream( - MODELS_STREAM_NAME + "_2", - modelsSchema(), - Field.of(COL_ID, JsonSchemaType.INTEGER), - Field.of(COL_MAKE_ID, JsonSchemaType.INTEGER), - Field.of(COL_MODEL, JsonSchemaType.STRING)) + val airbyteStream = + ConfiguredAirbyteStream() + .withStream( + CatalogHelpers.createAirbyteStream( + MODELS_STREAM_NAME + "_2", + modelsSchema(), + Field.of(COL_ID, JsonSchemaType.INTEGER), + Field.of(COL_MAKE_ID, JsonSchemaType.INTEGER), + Field.of(COL_MODEL, JsonSchemaType.STRING) + ) .withSupportedSyncModes( - Lists.newArrayList(SyncMode.FULL_REFRESH, SyncMode.INCREMENTAL)) - .withSourceDefinedPrimaryKey(java.util.List.of(java.util.List.of(COL_ID)))) + Lists.newArrayList(SyncMode.FULL_REFRESH, SyncMode.INCREMENTAL) + ) + .withSourceDefinedPrimaryKey(java.util.List.of(java.util.List.of(COL_ID))) + ) airbyteStream.syncMode = SyncMode.FULL_REFRESH val streams = configuredCatalog.streams streams.add(airbyteStream) configuredCatalog.withStreams(streams) - val read1 = source() - .read(config()!!, configuredCatalog, null) + val read1 = source().read(config()!!, configuredCatalog, null) val actualRecords1 = AutoCloseableIterators.toListAndClose(read1) val recordMessages1 = extractRecordMessages(actualRecords1) @@ -468,21 +603,22 @@ abstract class CdcSourceTest?> { assertExpectedStateMessages(stateMessages1) // Full refresh does not get any state messages. assertExpectedStateMessageCountMatches(stateMessages1, MODEL_RECORDS_2.size.toLong()) - assertExpectedRecords(Streams.concat(MODEL_RECORDS_2.stream(), MODEL_RECORDS.stream()) + assertExpectedRecords( + Streams.concat(MODEL_RECORDS_2.stream(), MODEL_RECORDS.stream()) .collect(Collectors.toSet()), - recordMessages1, - setOf(MODELS_STREAM_NAME), - names, - modelsSchema()) - - val puntoRecord = Jsons - .jsonNode(ImmutableMap.of(COL_ID, 100, COL_MAKE_ID, 3, COL_MODEL, "Punto")) + recordMessages1, + setOf(MODELS_STREAM_NAME), + names, + modelsSchema() + ) + + val puntoRecord = + Jsons.jsonNode(ImmutableMap.of(COL_ID, 100, COL_MAKE_ID, 3, COL_MODEL, "Punto")) writeModelRecord(puntoRecord) waitForCdcRecords(modelsSchema(), MODELS_STREAM_NAME, 1) val state = Jsons.jsonNode(listOf(stateMessages1[stateMessages1.size - 1])) - val read2 = source() - .read(config()!!, configuredCatalog, state) + val read2 = source().read(config()!!, configuredCatalog, state) val actualRecords2 = AutoCloseableIterators.toListAndClose(read2) val recordMessages2 = extractRecordMessages(actualRecords2) @@ -490,12 +626,13 @@ abstract class CdcSourceTest?> { assertExpectedStateMessagesFromIncrementalSync(stateMessages2) assertExpectedStateMessageCountMatches(stateMessages2, 1) assertExpectedRecords( - Streams.concat(MODEL_RECORDS_2.stream(), Stream.of(puntoRecord)) - .collect(Collectors.toSet()), - recordMessages2, - setOf(MODELS_STREAM_NAME), - names, - modelsSchema()) + Streams.concat(MODEL_RECORDS_2.stream(), Stream.of(puntoRecord)) + .collect(Collectors.toSet()), + recordMessages2, + setOf(MODELS_STREAM_NAME), + names, + modelsSchema() + ) } @Test // When no records exist, no records are returned. @@ -517,17 +654,17 @@ abstract class CdcSourceTest?> { assertExpectedStateMessages(stateMessages) } - @Test // When no changes have been made to the database since the previous sync, no records are returned. + @Test // When no changes have been made to the database since the previous sync, no records are + // returned. @Throws(Exception::class) fun testNoDataOnSecondSync() { - val read1 = source() - .read(config()!!, configuredCatalog, null) + val read1 = source().read(config()!!, configuredCatalog, null) val actualRecords1 = AutoCloseableIterators.toListAndClose(read1) val stateMessagesFromFirstSync = extractStateMessages(actualRecords1) - val state = Jsons.jsonNode(listOf(stateMessagesFromFirstSync[stateMessagesFromFirstSync.size - 1])) + val state = + Jsons.jsonNode(listOf(stateMessagesFromFirstSync[stateMessagesFromFirstSync.size - 1])) - val read2 = source() - .read(config()!!, configuredCatalog, state) + val read2 = source().read(config()!!, configuredCatalog, state) val actualRecords2 = AutoCloseableIterators.toListAndClose(read2) val recordMessages2 = extractRecordMessages(actualRecords2) @@ -552,34 +689,46 @@ abstract class CdcSourceTest?> { val actualCatalog = source()!!.discover(config()!!) Assertions.assertEquals( - expectedCatalog.streams.stream().sorted(Comparator.comparing { obj: AirbyteStream -> obj.name }) - .collect(Collectors.toList()), - actualCatalog!!.streams.stream().sorted(Comparator.comparing { obj: AirbyteStream -> obj.name }) - .collect(Collectors.toList())) + expectedCatalog.streams + .stream() + .sorted(Comparator.comparing { obj: AirbyteStream -> obj.name }) + .collect(Collectors.toList()), + actualCatalog!! + .streams + .stream() + .sorted(Comparator.comparing { obj: AirbyteStream -> obj.name }) + .collect(Collectors.toList()) + ) } @Test @Throws(Exception::class) fun newTableSnapshotTest() { - val firstBatchIterator = source() - .read(config()!!, configuredCatalog, null) - val dataFromFirstBatch = AutoCloseableIterators - .toListAndClose(firstBatchIterator) - val recordsFromFirstBatch = extractRecordMessages( - dataFromFirstBatch) + val firstBatchIterator = source().read(config()!!, configuredCatalog, null) + val dataFromFirstBatch = AutoCloseableIterators.toListAndClose(firstBatchIterator) + val recordsFromFirstBatch = extractRecordMessages(dataFromFirstBatch) val stateAfterFirstBatch = extractStateMessages(dataFromFirstBatch) assertExpectedStateMessages(stateAfterFirstBatch) assertExpectedStateMessageCountMatches(stateAfterFirstBatch, MODEL_RECORDS.size.toLong()) - val stateMessageEmittedAfterFirstSyncCompletion = stateAfterFirstBatch[stateAfterFirstBatch.size - 1] - Assertions.assertEquals(AirbyteStateMessage.AirbyteStateType.GLOBAL, stateMessageEmittedAfterFirstSyncCompletion.type) + val stateMessageEmittedAfterFirstSyncCompletion = + stateAfterFirstBatch[stateAfterFirstBatch.size - 1] + Assertions.assertEquals( + AirbyteStateMessage.AirbyteStateType.GLOBAL, + stateMessageEmittedAfterFirstSyncCompletion.type + ) Assertions.assertNotNull(stateMessageEmittedAfterFirstSyncCompletion.global.sharedState) - val streamsInStateAfterFirstSyncCompletion = stateMessageEmittedAfterFirstSyncCompletion.global.streamStates + val streamsInStateAfterFirstSyncCompletion = + stateMessageEmittedAfterFirstSyncCompletion.global.streamStates .stream() .map { obj: AirbyteStreamState -> obj.streamDescriptor } .collect(Collectors.toSet()) Assertions.assertEquals(1, streamsInStateAfterFirstSyncCompletion.size) - Assertions.assertTrue(streamsInStateAfterFirstSyncCompletion.contains(StreamDescriptor().withName(MODELS_STREAM_NAME).withNamespace(modelsSchema()))) + Assertions.assertTrue( + streamsInStateAfterFirstSyncCompletion.contains( + StreamDescriptor().withName(MODELS_STREAM_NAME).withNamespace(modelsSchema()) + ) + ) Assertions.assertNotNull(stateMessageEmittedAfterFirstSyncCompletion.data) Assertions.assertEquals((MODEL_RECORDS.size), recordsFromFirstBatch.size) @@ -587,18 +736,31 @@ abstract class CdcSourceTest?> { val state = stateAfterFirstBatch[stateAfterFirstBatch.size - 1].data - val newTables = CatalogHelpers - .toDefaultConfiguredCatalog(AirbyteCatalog().withStreams(java.util.List.of( - CatalogHelpers.createAirbyteStream( - RANDOM_TABLE_NAME, - randomSchema(), - Field.of(COL_ID + "_random", JsonSchemaType.NUMBER), - Field.of(COL_MAKE_ID + "_random", JsonSchemaType.NUMBER), - Field.of(COL_MODEL + "_random", JsonSchemaType.STRING)) - .withSupportedSyncModes(Lists.newArrayList(SyncMode.FULL_REFRESH, SyncMode.INCREMENTAL)) - .withSourceDefinedPrimaryKey(java.util.List.of(java.util.List.of(COL_ID + "_random")))))) - - newTables.streams.forEach(Consumer { s: ConfiguredAirbyteStream -> s.syncMode = SyncMode.INCREMENTAL }) + val newTables = + CatalogHelpers.toDefaultConfiguredCatalog( + AirbyteCatalog() + .withStreams( + java.util.List.of( + CatalogHelpers.createAirbyteStream( + RANDOM_TABLE_NAME, + randomSchema(), + Field.of(COL_ID + "_random", JsonSchemaType.NUMBER), + Field.of(COL_MAKE_ID + "_random", JsonSchemaType.NUMBER), + Field.of(COL_MODEL + "_random", JsonSchemaType.STRING) + ) + .withSupportedSyncModes( + Lists.newArrayList(SyncMode.FULL_REFRESH, SyncMode.INCREMENTAL) + ) + .withSourceDefinedPrimaryKey( + java.util.List.of(java.util.List.of(COL_ID + "_random")) + ) + ) + ) + ) + + newTables.streams.forEach( + Consumer { s: ConfiguredAirbyteStream -> s.syncMode = SyncMode.INCREMENTAL } + ) val combinedStreams: MutableList = ArrayList() combinedStreams.addAll(configuredCatalog.streams) combinedStreams.addAll(newTables.streams) @@ -606,25 +768,33 @@ abstract class CdcSourceTest?> { val updatedCatalog = ConfiguredAirbyteCatalog().withStreams(combinedStreams) /* - * Write 20 records to the existing table - */ + * Write 20 records to the existing table + */ val recordsWritten: MutableSet = HashSet() for (recordsCreated in 0..19) { val record = - Jsons.jsonNode(ImmutableMap - .of(COL_ID, 100 + recordsCreated, COL_MAKE_ID, 1, COL_MODEL, - "F-$recordsCreated")) + Jsons.jsonNode( + ImmutableMap.of( + COL_ID, + 100 + recordsCreated, + COL_MAKE_ID, + 1, + COL_MODEL, + "F-$recordsCreated" + ) + ) recordsWritten.add(record) writeModelRecord(record) } - val secondBatchIterator = source() - .read(config()!!, updatedCatalog, state) - val dataFromSecondBatch = AutoCloseableIterators - .toListAndClose(secondBatchIterator) + val secondBatchIterator = source().read(config()!!, updatedCatalog, state) + val dataFromSecondBatch = AutoCloseableIterators.toListAndClose(secondBatchIterator) val stateAfterSecondBatch = extractStateMessages(dataFromSecondBatch) - assertStateMessagesForNewTableSnapshotTest(stateAfterSecondBatch, stateMessageEmittedAfterFirstSyncCompletion) + assertStateMessagesForNewTableSnapshotTest( + stateAfterSecondBatch, + stateMessageEmittedAfterFirstSyncCompletion + ) val recordsStreamWise = extractRecordMessagesStreamWise(dataFromSecondBatch) Assertions.assertTrue(recordsStreamWise.containsKey(MODELS_STREAM_NAME)) @@ -633,118 +803,194 @@ abstract class CdcSourceTest?> { val recordsForModelsStreamFromSecondBatch = recordsStreamWise[MODELS_STREAM_NAME]!! val recordsForModelsRandomStreamFromSecondBatch = recordsStreamWise[RANDOM_TABLE_NAME]!! - Assertions.assertEquals((MODEL_RECORDS_RANDOM.size), recordsForModelsRandomStreamFromSecondBatch.size) + Assertions.assertEquals( + (MODEL_RECORDS_RANDOM.size), + recordsForModelsRandomStreamFromSecondBatch.size + ) Assertions.assertEquals(20, recordsForModelsStreamFromSecondBatch.size) - assertExpectedRecords(HashSet(MODEL_RECORDS_RANDOM), recordsForModelsRandomStreamFromSecondBatch, - recordsForModelsRandomStreamFromSecondBatch.stream().map { obj: AirbyteRecordMessage -> obj.stream }.collect( - Collectors.toSet()), - Sets - .newHashSet(RANDOM_TABLE_NAME), - randomSchema()) + assertExpectedRecords( + HashSet(MODEL_RECORDS_RANDOM), + recordsForModelsRandomStreamFromSecondBatch, + recordsForModelsRandomStreamFromSecondBatch + .stream() + .map { obj: AirbyteRecordMessage -> obj.stream } + .collect(Collectors.toSet()), + Sets.newHashSet(RANDOM_TABLE_NAME), + randomSchema() + ) assertExpectedRecords(recordsWritten, recordsForModelsStreamFromSecondBatch) /* - * Write 20 records to both the tables - */ + * Write 20 records to both the tables + */ val recordsWrittenInRandomTable: MutableSet = HashSet() recordsWritten.clear() for (recordsCreated in 30..49) { val record = - Jsons.jsonNode(ImmutableMap - .of(COL_ID, 100 + recordsCreated, COL_MAKE_ID, 1, COL_MODEL, - "F-$recordsCreated")) + Jsons.jsonNode( + ImmutableMap.of( + COL_ID, + 100 + recordsCreated, + COL_MAKE_ID, + 1, + COL_MODEL, + "F-$recordsCreated" + ) + ) writeModelRecord(record) recordsWritten.add(record) - val record2 = Jsons - .jsonNode(ImmutableMap - .of(COL_ID + "_random", 11000 + recordsCreated, COL_MAKE_ID + "_random", 1 + recordsCreated, COL_MODEL + "_random", - "Fiesta-random$recordsCreated")) - writeRecords(record2, randomSchema(), RANDOM_TABLE_NAME, - COL_ID + "_random", COL_MAKE_ID + "_random", COL_MODEL + "_random") + val record2 = + Jsons.jsonNode( + ImmutableMap.of( + COL_ID + "_random", + 11000 + recordsCreated, + COL_MAKE_ID + "_random", + 1 + recordsCreated, + COL_MODEL + "_random", + "Fiesta-random$recordsCreated" + ) + ) + writeRecords( + record2, + randomSchema(), + RANDOM_TABLE_NAME, + COL_ID + "_random", + COL_MAKE_ID + "_random", + COL_MODEL + "_random" + ) recordsWrittenInRandomTable.add(record2) } val state2 = stateAfterSecondBatch[stateAfterSecondBatch.size - 1].data - val thirdBatchIterator = source() - .read(config()!!, updatedCatalog, state2) - val dataFromThirdBatch = AutoCloseableIterators - .toListAndClose(thirdBatchIterator) + val thirdBatchIterator = source().read(config()!!, updatedCatalog, state2) + val dataFromThirdBatch = AutoCloseableIterators.toListAndClose(thirdBatchIterator) val stateAfterThirdBatch = extractStateMessages(dataFromThirdBatch) Assertions.assertTrue(stateAfterThirdBatch.size >= 1) - val stateMessageEmittedAfterThirdSyncCompletion = stateAfterThirdBatch[stateAfterThirdBatch.size - 1] - Assertions.assertEquals(AirbyteStateMessage.AirbyteStateType.GLOBAL, stateMessageEmittedAfterThirdSyncCompletion.type) - Assertions.assertNotEquals(stateMessageEmittedAfterThirdSyncCompletion.global.sharedState, - stateAfterSecondBatch[stateAfterSecondBatch.size - 1].global.sharedState) - val streamsInSyncCompletionStateAfterThirdSync = stateMessageEmittedAfterThirdSyncCompletion.global.streamStates + val stateMessageEmittedAfterThirdSyncCompletion = + stateAfterThirdBatch[stateAfterThirdBatch.size - 1] + Assertions.assertEquals( + AirbyteStateMessage.AirbyteStateType.GLOBAL, + stateMessageEmittedAfterThirdSyncCompletion.type + ) + Assertions.assertNotEquals( + stateMessageEmittedAfterThirdSyncCompletion.global.sharedState, + stateAfterSecondBatch[stateAfterSecondBatch.size - 1].global.sharedState + ) + val streamsInSyncCompletionStateAfterThirdSync = + stateMessageEmittedAfterThirdSyncCompletion.global.streamStates .stream() .map { obj: AirbyteStreamState -> obj.streamDescriptor } .collect(Collectors.toSet()) Assertions.assertTrue( - streamsInSyncCompletionStateAfterThirdSync.contains( - StreamDescriptor().withName(RANDOM_TABLE_NAME).withNamespace(randomSchema()))) + streamsInSyncCompletionStateAfterThirdSync.contains( + StreamDescriptor().withName(RANDOM_TABLE_NAME).withNamespace(randomSchema()) + ) + ) Assertions.assertTrue( - streamsInSyncCompletionStateAfterThirdSync.contains(StreamDescriptor().withName(MODELS_STREAM_NAME).withNamespace(modelsSchema()))) + streamsInSyncCompletionStateAfterThirdSync.contains( + StreamDescriptor().withName(MODELS_STREAM_NAME).withNamespace(modelsSchema()) + ) + ) Assertions.assertNotNull(stateMessageEmittedAfterThirdSyncCompletion.data) val recordsStreamWiseFromThirdBatch = extractRecordMessagesStreamWise(dataFromThirdBatch) Assertions.assertTrue(recordsStreamWiseFromThirdBatch.containsKey(MODELS_STREAM_NAME)) Assertions.assertTrue(recordsStreamWiseFromThirdBatch.containsKey(RANDOM_TABLE_NAME)) - val recordsForModelsStreamFromThirdBatch = recordsStreamWiseFromThirdBatch[MODELS_STREAM_NAME]!! - val recordsForModelsRandomStreamFromThirdBatch = recordsStreamWiseFromThirdBatch[RANDOM_TABLE_NAME]!! + val recordsForModelsStreamFromThirdBatch = + recordsStreamWiseFromThirdBatch[MODELS_STREAM_NAME]!! + val recordsForModelsRandomStreamFromThirdBatch = + recordsStreamWiseFromThirdBatch[RANDOM_TABLE_NAME]!! Assertions.assertEquals(20, recordsForModelsStreamFromThirdBatch.size) Assertions.assertEquals(20, recordsForModelsRandomStreamFromThirdBatch.size) assertExpectedRecords(recordsWritten, recordsForModelsStreamFromThirdBatch) - assertExpectedRecords(recordsWrittenInRandomTable, recordsForModelsRandomStreamFromThirdBatch, - recordsForModelsRandomStreamFromThirdBatch.stream().map { obj: AirbyteRecordMessage -> obj.stream }.collect( - Collectors.toSet()), - Sets - .newHashSet(RANDOM_TABLE_NAME), - randomSchema()) + assertExpectedRecords( + recordsWrittenInRandomTable, + recordsForModelsRandomStreamFromThirdBatch, + recordsForModelsRandomStreamFromThirdBatch + .stream() + .map { obj: AirbyteRecordMessage -> obj.stream } + .collect(Collectors.toSet()), + Sets.newHashSet(RANDOM_TABLE_NAME), + randomSchema() + ) } - protected fun assertStateMessagesForNewTableSnapshotTest(stateMessages: List, - stateMessageEmittedAfterFirstSyncCompletion: AirbyteStateMessage) { + protected fun assertStateMessagesForNewTableSnapshotTest( + stateMessages: List, + stateMessageEmittedAfterFirstSyncCompletion: AirbyteStateMessage + ) { Assertions.assertEquals(2, stateMessages.size) val stateMessageEmittedAfterSnapshotCompletionInSecondSync = stateMessages[0] - Assertions.assertEquals(AirbyteStateMessage.AirbyteStateType.GLOBAL, stateMessageEmittedAfterSnapshotCompletionInSecondSync.type) - Assertions.assertEquals(stateMessageEmittedAfterFirstSyncCompletion.global.sharedState, - stateMessageEmittedAfterSnapshotCompletionInSecondSync.global.sharedState) - val streamsInSnapshotState = stateMessageEmittedAfterSnapshotCompletionInSecondSync.global.streamStates + Assertions.assertEquals( + AirbyteStateMessage.AirbyteStateType.GLOBAL, + stateMessageEmittedAfterSnapshotCompletionInSecondSync.type + ) + Assertions.assertEquals( + stateMessageEmittedAfterFirstSyncCompletion.global.sharedState, + stateMessageEmittedAfterSnapshotCompletionInSecondSync.global.sharedState + ) + val streamsInSnapshotState = + stateMessageEmittedAfterSnapshotCompletionInSecondSync.global.streamStates .stream() .map { obj: AirbyteStreamState -> obj.streamDescriptor } .collect(Collectors.toSet()) Assertions.assertEquals(2, streamsInSnapshotState.size) Assertions.assertTrue( - streamsInSnapshotState.contains(StreamDescriptor().withName(RANDOM_TABLE_NAME).withNamespace(randomSchema()))) - Assertions.assertTrue(streamsInSnapshotState.contains(StreamDescriptor().withName(MODELS_STREAM_NAME).withNamespace(modelsSchema()))) + streamsInSnapshotState.contains( + StreamDescriptor().withName(RANDOM_TABLE_NAME).withNamespace(randomSchema()) + ) + ) + Assertions.assertTrue( + streamsInSnapshotState.contains( + StreamDescriptor().withName(MODELS_STREAM_NAME).withNamespace(modelsSchema()) + ) + ) Assertions.assertNotNull(stateMessageEmittedAfterSnapshotCompletionInSecondSync.data) val stateMessageEmittedAfterSecondSyncCompletion = stateMessages[1] - Assertions.assertEquals(AirbyteStateMessage.AirbyteStateType.GLOBAL, stateMessageEmittedAfterSecondSyncCompletion.type) - Assertions.assertNotEquals(stateMessageEmittedAfterFirstSyncCompletion.global.sharedState, - stateMessageEmittedAfterSecondSyncCompletion.global.sharedState) - val streamsInSyncCompletionState = stateMessageEmittedAfterSecondSyncCompletion.global.streamStates + Assertions.assertEquals( + AirbyteStateMessage.AirbyteStateType.GLOBAL, + stateMessageEmittedAfterSecondSyncCompletion.type + ) + Assertions.assertNotEquals( + stateMessageEmittedAfterFirstSyncCompletion.global.sharedState, + stateMessageEmittedAfterSecondSyncCompletion.global.sharedState + ) + val streamsInSyncCompletionState = + stateMessageEmittedAfterSecondSyncCompletion.global.streamStates .stream() .map { obj: AirbyteStreamState -> obj.streamDescriptor } .collect(Collectors.toSet()) Assertions.assertEquals(2, streamsInSnapshotState.size) Assertions.assertTrue( - streamsInSyncCompletionState.contains( - StreamDescriptor().withName(RANDOM_TABLE_NAME).withNamespace(randomSchema()))) - Assertions.assertTrue(streamsInSyncCompletionState.contains(StreamDescriptor().withName(MODELS_STREAM_NAME).withNamespace(modelsSchema()))) + streamsInSyncCompletionState.contains( + StreamDescriptor().withName(RANDOM_TABLE_NAME).withNamespace(randomSchema()) + ) + ) + Assertions.assertTrue( + streamsInSyncCompletionState.contains( + StreamDescriptor().withName(MODELS_STREAM_NAME).withNamespace(modelsSchema()) + ) + ) Assertions.assertNotNull(stateMessageEmittedAfterSecondSyncCompletion.data) } protected fun expectedCatalogForDiscover(): AirbyteCatalog { val expectedCatalog = Jsons.clone(catalog) - val columns = ImmutableMap.of(COL_ID, "INTEGER", COL_MAKE_ID, "INTEGER", COL_MODEL, "VARCHAR(200)") - testdb!!.with(createTableSqlFmt(), modelsSchema(), MODELS_STREAM_NAME + "_2", columnClause(columns, Optional.empty())) + val columns = + ImmutableMap.of(COL_ID, "INTEGER", COL_MAKE_ID, "INTEGER", COL_MODEL, "VARCHAR(200)") + testdb!!.with( + createTableSqlFmt(), + modelsSchema(), + MODELS_STREAM_NAME + "_2", + columnClause(columns, Optional.empty()) + ) val streams = expectedCatalog.streams // stream with PK @@ -752,26 +998,34 @@ abstract class CdcSourceTest?> { addCdcMetadataColumns(streams[0]) addCdcDefaultCursorField(streams[0]) - val streamWithoutPK = CatalogHelpers.createAirbyteStream( + val streamWithoutPK = + CatalogHelpers.createAirbyteStream( MODELS_STREAM_NAME + "_2", modelsSchema(), Field.of(COL_ID, JsonSchemaType.INTEGER), Field.of(COL_MAKE_ID, JsonSchemaType.INTEGER), - Field.of(COL_MODEL, JsonSchemaType.STRING)) + Field.of(COL_MODEL, JsonSchemaType.STRING) + ) streamWithoutPK.sourceDefinedPrimaryKey = emptyList() streamWithoutPK.supportedSyncModes = java.util.List.of(SyncMode.FULL_REFRESH) addCdcDefaultCursorField(streamWithoutPK) addCdcMetadataColumns(streamWithoutPK) - val randomStream = CatalogHelpers.createAirbyteStream( - RANDOM_TABLE_NAME, - randomSchema(), - Field.of(COL_ID + "_random", JsonSchemaType.INTEGER), - Field.of(COL_MAKE_ID + "_random", JsonSchemaType.INTEGER), - Field.of(COL_MODEL + "_random", JsonSchemaType.STRING)) + val randomStream = + CatalogHelpers.createAirbyteStream( + RANDOM_TABLE_NAME, + randomSchema(), + Field.of(COL_ID + "_random", JsonSchemaType.INTEGER), + Field.of(COL_MAKE_ID + "_random", JsonSchemaType.INTEGER), + Field.of(COL_MODEL + "_random", JsonSchemaType.STRING) + ) .withSourceDefinedCursor(true) - .withSupportedSyncModes(Lists.newArrayList(SyncMode.FULL_REFRESH, SyncMode.INCREMENTAL)) - .withSourceDefinedPrimaryKey(java.util.List.of(java.util.List.of(COL_ID + "_random"))) + .withSupportedSyncModes( + Lists.newArrayList(SyncMode.FULL_REFRESH, SyncMode.INCREMENTAL) + ) + .withSourceDefinedPrimaryKey( + java.util.List.of(java.util.List.of(COL_ID + "_random")) + ) addCdcDefaultCursorField(randomStream) addCdcMetadataColumns(randomStream) @@ -783,8 +1037,7 @@ abstract class CdcSourceTest?> { } @Throws(Exception::class) - protected fun waitForCdcRecords(schemaName: String?, tableName: String?, recordCount: Int) { - } + protected fun waitForCdcRecords(schemaName: String?, tableName: String?, recordCount: Int) {} companion object { private val LOGGER: Logger = LoggerFactory.getLogger(CdcSourceTest::class.java) @@ -795,26 +1048,37 @@ abstract class CdcSourceTest?> { protected const val COL_MAKE_ID: String = "make_id" protected const val COL_MODEL: String = "model" - protected val MODEL_RECORDS: List = ImmutableList.of( + protected val MODEL_RECORDS: List = + ImmutableList.of( Jsons.jsonNode(ImmutableMap.of(COL_ID, 11, COL_MAKE_ID, 1, COL_MODEL, "Fiesta")), Jsons.jsonNode(ImmutableMap.of(COL_ID, 12, COL_MAKE_ID, 1, COL_MODEL, "Focus")), Jsons.jsonNode(ImmutableMap.of(COL_ID, 13, COL_MAKE_ID, 1, COL_MODEL, "Ranger")), Jsons.jsonNode(ImmutableMap.of(COL_ID, 14, COL_MAKE_ID, 2, COL_MODEL, "GLA")), Jsons.jsonNode(ImmutableMap.of(COL_ID, 15, COL_MAKE_ID, 2, COL_MODEL, "A 220")), - Jsons.jsonNode(ImmutableMap.of(COL_ID, 16, COL_MAKE_ID, 2, COL_MODEL, "E 350"))) + Jsons.jsonNode(ImmutableMap.of(COL_ID, 16, COL_MAKE_ID, 2, COL_MODEL, "E 350")) + ) protected const val RANDOM_TABLE_NAME: String = MODELS_STREAM_NAME + "_random" - protected val MODEL_RECORDS_RANDOM: List = MODEL_RECORDS.stream() + protected val MODEL_RECORDS_RANDOM: List = + MODEL_RECORDS.stream() .map { r: JsonNode -> - Jsons.jsonNode(ImmutableMap.of( - COL_ID + "_random", r[COL_ID].asInt() * 1000, - COL_MAKE_ID + "_random", r[COL_MAKE_ID], - COL_MODEL + "_random", r[COL_MODEL].asText() + "-random")) + Jsons.jsonNode( + ImmutableMap.of( + COL_ID + "_random", + r[COL_ID].asInt() * 1000, + COL_MAKE_ID + "_random", + r[COL_MAKE_ID], + COL_MODEL + "_random", + r[COL_MODEL].asText() + "-random" + ) + ) } .toList() - protected fun removeDuplicates(messages: Set): Set { + protected fun removeDuplicates( + messages: Set + ): Set { val existingDataRecordsWithoutUpdated: MutableSet = HashSet() val output: MutableSet = HashSet() diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/debug/DebugUtil.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/debug/DebugUtil.kt index e9d1c955534f3..d04a8ea5e0146 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/debug/DebugUtil.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/debug/DebugUtil.kt @@ -23,11 +23,12 @@ object DebugUtil { fun debug(debugSource: Source) { val debugConfig = config val configuredAirbyteCatalog = catalog - var state = try { - state - } catch (e: Exception) { - null - } + var state = + try { + state + } catch (e: Exception) { + null + } debugSource.check(debugConfig) debugSource.discover(debugConfig) @@ -39,8 +40,10 @@ object DebugUtil { @get:Throws(Exception::class) private val config: JsonNode get() { - val originalConfig = ObjectMapper().readTree(MoreResources.readResource("debug_resources/config.json")) - val debugConfig: JsonNode = (originalConfig.deepCopy() as ObjectNode).put("debug_mode", true) + val originalConfig = + ObjectMapper().readTree(MoreResources.readResource("debug_resources/config.json")) + val debugConfig: JsonNode = + (originalConfig.deepCopy() as ObjectNode).put("debug_mode", true) return debugConfig } @@ -54,7 +57,11 @@ object DebugUtil { @get:Throws(Exception::class) private val state: JsonNode get() { - val message = Jsons.deserialize(MoreResources.readResource("debug_resources/state.json"), AirbyteStateMessage::class.java) + val message = + Jsons.deserialize( + MoreResources.readResource("debug_resources/state.json"), + AirbyteStateMessage::class.java + ) return Jsons.jsonNode(listOf(message)) } } diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/source/jdbc/test/JdbcSourceAcceptanceTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/source/jdbc/test/JdbcSourceAcceptanceTest.kt index cbaf441c3b9ae..d8444667cfe2e 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/source/jdbc/test/JdbcSourceAcceptanceTest.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/source/jdbc/test/JdbcSourceAcceptanceTest.kt @@ -19,6 +19,11 @@ import io.airbyte.commons.util.MoreIterators import io.airbyte.protocol.models.Field import io.airbyte.protocol.models.JsonSchemaType import io.airbyte.protocol.models.v0.* +import java.math.BigDecimal +import java.sql.SQLException +import java.util.* +import java.util.function.Consumer +import java.util.stream.Collectors import org.hamcrest.MatcherAssert import org.hamcrest.Matchers import org.junit.jupiter.api.AfterEach @@ -26,18 +31,15 @@ import org.junit.jupiter.api.Assertions import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Test import org.mockito.Mockito -import java.math.BigDecimal -import java.sql.SQLException -import java.util.* -import java.util.function.Consumer -import java.util.stream.Collectors -/** - * Tests that should be run on all Sources that extend the AbstractJdbcSource. - */ -@SuppressFBWarnings(value = ["MS_SHOULD_BE_FINAL"], justification = "The static variables are updated in subclasses for convenience, and cannot be final.") -abstract class JdbcSourceAcceptanceTest?> { - protected var testdb: T? = null +/** Tests that should be run on all Sources that extend the AbstractJdbcSource. */ +@SuppressFBWarnings( + value = ["MS_SHOULD_BE_FINAL"], + justification = + "The static variables are updated in subclasses for convenience, and cannot be final." +) +abstract class JdbcSourceAcceptanceTest> { + protected lateinit var testdb: T protected fun streamName(): String { return TABLE_NAME @@ -65,19 +67,28 @@ abstract class JdbcSourceAcceptanceTest?> protected abstract fun createTestDatabase(): T /** - * These tests write records without specifying a namespace (schema name). They will be written into - * whatever the default schema is for the database. When they are discovered they will be namespaced - * by the schema name (e.g. .). Thus the source needs to tell the - * tests what that default schema name is. If the database does not support schemas, then database - * name should used instead. + * These tests write records without specifying a namespace (schema name). They will be written + * into whatever the default schema is for the database. When they are discovered they will be + * namespaced by the schema name (e.g. .). Thus the source + * needs to tell the tests what that default schema name is. If the database does not support + * schemas, then database name should used instead. * - * @return name that will be used to namespace the record. - */ + * @return name that will be used to namespace the record. + */ protected abstract fun supportsSchemas(): Boolean - protected fun createTableQuery(tableName: String?, columnClause: String?, primaryKeyClause: String): String { - return String.format("CREATE TABLE %s(%s %s %s)", - tableName, columnClause, if (primaryKeyClause == "") "" else ",", primaryKeyClause) + protected fun createTableQuery( + tableName: String?, + columnClause: String?, + primaryKeyClause: String + ): String { + return String.format( + "CREATE TABLE %s(%s %s %s)", + tableName, + columnClause, + if (primaryKeyClause == "") "" else ",", + primaryKeyClause + ) } protected fun primaryKeyClause(columns: List): String { @@ -108,26 +119,68 @@ abstract class JdbcSourceAcceptanceTest?> testdb!!.with("ALTER SESSION SET NLS_DATE_FORMAT = 'YYYY-MM-DD'") } testdb - .with(createTableQuery(getFullyQualifiedTableName(TABLE_NAME), COLUMN_CLAUSE_WITH_PK, primaryKeyClause(listOf("id")))) - .with("INSERT INTO %s(id, name, updated_at) VALUES (1, 'picard', '2004-10-19')", getFullyQualifiedTableName(TABLE_NAME)) - .with("INSERT INTO %s(id, name, updated_at) VALUES (2, 'crusher', '2005-10-19')", getFullyQualifiedTableName(TABLE_NAME)) - .with("INSERT INTO %s(id, name, updated_at) VALUES (3, 'vash', '2006-10-19')", getFullyQualifiedTableName(TABLE_NAME)) - .with(createTableQuery(getFullyQualifiedTableName(TABLE_NAME_WITHOUT_PK), COLUMN_CLAUSE_WITHOUT_PK, "")) - .with("INSERT INTO %s(id, name, updated_at) VALUES (1, 'picard', '2004-10-19')", getFullyQualifiedTableName(TABLE_NAME_WITHOUT_PK)) - .with("INSERT INTO %s(id, name, updated_at) VALUES (2, 'crusher', '2005-10-19')", getFullyQualifiedTableName(TABLE_NAME_WITHOUT_PK)) - .with("INSERT INTO %s(id, name, updated_at) VALUES (3, 'vash', '2006-10-19')", getFullyQualifiedTableName(TABLE_NAME_WITHOUT_PK)) - .with(createTableQuery(getFullyQualifiedTableName(TABLE_NAME_COMPOSITE_PK), COLUMN_CLAUSE_WITH_COMPOSITE_PK, - primaryKeyClause(listOf("first_name", "last_name")))) - .with("INSERT INTO %s(first_name, last_name, updated_at) VALUES ('first', 'picard', '2004-10-19')", - getFullyQualifiedTableName(TABLE_NAME_COMPOSITE_PK)) - .with("INSERT INTO %s(first_name, last_name, updated_at) VALUES ('second', 'crusher', '2005-10-19')", - getFullyQualifiedTableName(TABLE_NAME_COMPOSITE_PK)) - .with("INSERT INTO %s(first_name, last_name, updated_at) VALUES ('third', 'vash', '2006-10-19')", - getFullyQualifiedTableName(TABLE_NAME_COMPOSITE_PK)) + .with( + createTableQuery( + getFullyQualifiedTableName(TABLE_NAME), + COLUMN_CLAUSE_WITH_PK, + primaryKeyClause(listOf("id")) + ) + ) + .with( + "INSERT INTO %s(id, name, updated_at) VALUES (1, 'picard', '2004-10-19')", + getFullyQualifiedTableName(TABLE_NAME) + ) + .with( + "INSERT INTO %s(id, name, updated_at) VALUES (2, 'crusher', '2005-10-19')", + getFullyQualifiedTableName(TABLE_NAME) + ) + .with( + "INSERT INTO %s(id, name, updated_at) VALUES (3, 'vash', '2006-10-19')", + getFullyQualifiedTableName(TABLE_NAME) + ) + .with( + createTableQuery( + getFullyQualifiedTableName(TABLE_NAME_WITHOUT_PK), + COLUMN_CLAUSE_WITHOUT_PK, + "" + ) + ) + .with( + "INSERT INTO %s(id, name, updated_at) VALUES (1, 'picard', '2004-10-19')", + getFullyQualifiedTableName(TABLE_NAME_WITHOUT_PK) + ) + .with( + "INSERT INTO %s(id, name, updated_at) VALUES (2, 'crusher', '2005-10-19')", + getFullyQualifiedTableName(TABLE_NAME_WITHOUT_PK) + ) + .with( + "INSERT INTO %s(id, name, updated_at) VALUES (3, 'vash', '2006-10-19')", + getFullyQualifiedTableName(TABLE_NAME_WITHOUT_PK) + ) + .with( + createTableQuery( + getFullyQualifiedTableName(TABLE_NAME_COMPOSITE_PK), + COLUMN_CLAUSE_WITH_COMPOSITE_PK, + primaryKeyClause(listOf("first_name", "last_name")) + ) + ) + .with( + "INSERT INTO %s(first_name, last_name, updated_at) VALUES ('first', 'picard', '2004-10-19')", + getFullyQualifiedTableName(TABLE_NAME_COMPOSITE_PK) + ) + .with( + "INSERT INTO %s(first_name, last_name, updated_at) VALUES ('second', 'crusher', '2005-10-19')", + getFullyQualifiedTableName(TABLE_NAME_COMPOSITE_PK) + ) + .with( + "INSERT INTO %s(first_name, last_name, updated_at) VALUES ('third', 'vash', '2006-10-19')", + getFullyQualifiedTableName(TABLE_NAME_COMPOSITE_PK) + ) } protected fun maybeSetShorterConnectionTimeout(config: JsonNode?) { - // Optionally implement this to speed up test cases which will result in a connection timeout. + // Optionally implement this to speed up test cases which will result in a connection + // timeout. } @AfterEach @@ -149,7 +202,8 @@ abstract class JdbcSourceAcceptanceTest?> @Throws(Exception::class) fun testCheckSuccess() { val actual = source()!!.check(config()) - val expected = AirbyteConnectionStatus().withStatus(AirbyteConnectionStatus.Status.SUCCEEDED) + val expected = + AirbyteConnectionStatus().withStatus(AirbyteConnectionStatus.Status.SUCCEEDED) Assertions.assertEquals(expected, actual) } @@ -169,26 +223,53 @@ abstract class JdbcSourceAcceptanceTest?> val actual = filterOutOtherSchemas(source()!!.discover(config())) val expected = getCatalog(defaultNamespace) Assertions.assertEquals(expected.streams.size, actual!!.streams.size) - actual.streams.forEach(Consumer { actualStream: AirbyteStream -> - val expectedStream = - expected.streams.stream() - .filter { stream: AirbyteStream -> stream.namespace == actualStream.namespace && stream.name == actualStream.name } - .findAny() - Assertions.assertTrue(expectedStream.isPresent, String.format("Unexpected stream %s", actualStream.name)) - Assertions.assertEquals(expectedStream.get(), actualStream) - }) + actual.streams.forEach( + Consumer { actualStream: AirbyteStream -> + val expectedStream = + expected.streams + .stream() + .filter { stream: AirbyteStream -> + stream.namespace == actualStream.namespace && + stream.name == actualStream.name + } + .findAny() + Assertions.assertTrue( + expectedStream.isPresent, + String.format("Unexpected stream %s", actualStream.name) + ) + Assertions.assertEquals(expectedStream.get(), actualStream) + } + ) } @Test @Throws(Exception::class) protected fun testDiscoverWithNonCursorFields() { - testdb!!.with(CREATE_TABLE_WITHOUT_CURSOR_TYPE_QUERY, getFullyQualifiedTableName(TABLE_NAME_WITHOUT_CURSOR_TYPE), COL_CURSOR) - .with(INSERT_TABLE_WITHOUT_CURSOR_TYPE_QUERY, getFullyQualifiedTableName(TABLE_NAME_WITHOUT_CURSOR_TYPE)) + testdb!! + .with( + CREATE_TABLE_WITHOUT_CURSOR_TYPE_QUERY, + getFullyQualifiedTableName(TABLE_NAME_WITHOUT_CURSOR_TYPE), + COL_CURSOR + ) + .with( + INSERT_TABLE_WITHOUT_CURSOR_TYPE_QUERY, + getFullyQualifiedTableName(TABLE_NAME_WITHOUT_CURSOR_TYPE) + ) val actual = filterOutOtherSchemas(source()!!.discover(config())) val stream = - actual!!.streams.stream().filter { s: AirbyteStream -> s.name.equals(TABLE_NAME_WITHOUT_CURSOR_TYPE, ignoreCase = true) }.findFirst().orElse(null) + actual!! + .streams + .stream() + .filter { s: AirbyteStream -> + s.name.equals(TABLE_NAME_WITHOUT_CURSOR_TYPE, ignoreCase = true) + } + .findFirst() + .orElse(null) Assertions.assertNotNull(stream) - Assertions.assertEquals(TABLE_NAME_WITHOUT_CURSOR_TYPE.lowercase(Locale.getDefault()), stream.name.lowercase(Locale.getDefault())) + Assertions.assertEquals( + TABLE_NAME_WITHOUT_CURSOR_TYPE.lowercase(Locale.getDefault()), + stream.name.lowercase(Locale.getDefault()) + ) Assertions.assertEquals(1, stream.supportedSyncModes.size) Assertions.assertEquals(SyncMode.FULL_REFRESH, stream.supportedSyncModes[0]) } @@ -196,13 +277,31 @@ abstract class JdbcSourceAcceptanceTest?> @Test @Throws(Exception::class) protected fun testDiscoverWithNullableCursorFields() { - testdb!!.with(CREATE_TABLE_WITH_NULLABLE_CURSOR_TYPE_QUERY, getFullyQualifiedTableName(TABLE_NAME_WITH_NULLABLE_CURSOR_TYPE), COL_CURSOR) - .with(INSERT_TABLE_WITH_NULLABLE_CURSOR_TYPE_QUERY, getFullyQualifiedTableName(TABLE_NAME_WITH_NULLABLE_CURSOR_TYPE)) + testdb!! + .with( + CREATE_TABLE_WITH_NULLABLE_CURSOR_TYPE_QUERY, + getFullyQualifiedTableName(TABLE_NAME_WITH_NULLABLE_CURSOR_TYPE), + COL_CURSOR + ) + .with( + INSERT_TABLE_WITH_NULLABLE_CURSOR_TYPE_QUERY, + getFullyQualifiedTableName(TABLE_NAME_WITH_NULLABLE_CURSOR_TYPE) + ) val actual = filterOutOtherSchemas(source()!!.discover(config())) val stream = - actual!!.streams.stream().filter { s: AirbyteStream -> s.name.equals(TABLE_NAME_WITH_NULLABLE_CURSOR_TYPE, ignoreCase = true) }.findFirst().orElse(null) + actual!! + .streams + .stream() + .filter { s: AirbyteStream -> + s.name.equals(TABLE_NAME_WITH_NULLABLE_CURSOR_TYPE, ignoreCase = true) + } + .findFirst() + .orElse(null) Assertions.assertNotNull(stream) - Assertions.assertEquals(TABLE_NAME_WITH_NULLABLE_CURSOR_TYPE.lowercase(Locale.getDefault()), stream.name.lowercase(Locale.getDefault())) + Assertions.assertEquals( + TABLE_NAME_WITH_NULLABLE_CURSOR_TYPE.lowercase(Locale.getDefault()), + stream.name.lowercase(Locale.getDefault()) + ) Assertions.assertEquals(2, stream.supportedSyncModes.size) Assertions.assertTrue(stream.supportedSyncModes.contains(SyncMode.FULL_REFRESH)) Assertions.assertTrue(stream.supportedSyncModes.contains(SyncMode.INCREMENTAL)) @@ -211,9 +310,14 @@ abstract class JdbcSourceAcceptanceTest?> protected fun filterOutOtherSchemas(catalog: AirbyteCatalog?): AirbyteCatalog? { if (supportsSchemas()) { val filteredCatalog = Jsons.clone(catalog) - filteredCatalog!!.streams = filteredCatalog.streams + filteredCatalog!!.streams = + filteredCatalog.streams .stream() - .filter { stream: AirbyteStream -> TEST_SCHEMAS.stream().anyMatch { schemaName: String? -> stream.namespace.startsWith(schemaName!!) } } + .filter { stream: AirbyteStream -> + TEST_SCHEMAS.stream().anyMatch { schemaName: String? -> + stream.namespace.startsWith(schemaName!!) + } + } .collect(Collectors.toList()) return filteredCatalog } else { @@ -224,36 +328,55 @@ abstract class JdbcSourceAcceptanceTest?> @Test @Throws(Exception::class) protected fun testDiscoverWithMultipleSchemas() { - // clickhouse and mysql do not have a concept of schemas, so this test does not make sense for them. + // clickhouse and mysql do not have a concept of schemas, so this test does not make sense + // for them. when (testdb!!.databaseDriver) { - DatabaseDriver.MYSQL, DatabaseDriver.CLICKHOUSE, DatabaseDriver.TERADATA -> return + DatabaseDriver.MYSQL, + DatabaseDriver.CLICKHOUSE, + DatabaseDriver.TERADATA -> return + else -> {} } // add table and data to a separate schema. - testdb!!.with("CREATE TABLE %s(id VARCHAR(200) NOT NULL, name VARCHAR(200) NOT NULL)", - RelationalDbQueryUtils.getFullyQualifiedTableName(SCHEMA_NAME2, TABLE_NAME)) - .with("INSERT INTO %s(id, name) VALUES ('1','picard')", - RelationalDbQueryUtils.getFullyQualifiedTableName(SCHEMA_NAME2, TABLE_NAME)) - .with("INSERT INTO %s(id, name) VALUES ('2', 'crusher')", - RelationalDbQueryUtils.getFullyQualifiedTableName(SCHEMA_NAME2, TABLE_NAME)) - .with("INSERT INTO %s(id, name) VALUES ('3', 'vash')", - RelationalDbQueryUtils.getFullyQualifiedTableName(SCHEMA_NAME2, TABLE_NAME)) + testdb!! + .with( + "CREATE TABLE %s(id VARCHAR(200) NOT NULL, name VARCHAR(200) NOT NULL)", + RelationalDbQueryUtils.getFullyQualifiedTableName(SCHEMA_NAME2, TABLE_NAME) + ) + .with( + "INSERT INTO %s(id, name) VALUES ('1','picard')", + RelationalDbQueryUtils.getFullyQualifiedTableName(SCHEMA_NAME2, TABLE_NAME) + ) + .with( + "INSERT INTO %s(id, name) VALUES ('2', 'crusher')", + RelationalDbQueryUtils.getFullyQualifiedTableName(SCHEMA_NAME2, TABLE_NAME) + ) + .with( + "INSERT INTO %s(id, name) VALUES ('3', 'vash')", + RelationalDbQueryUtils.getFullyQualifiedTableName(SCHEMA_NAME2, TABLE_NAME) + ) val actual = source()!!.discover(config()) val expected = getCatalog(defaultNamespace) val catalogStreams: MutableList = ArrayList() catalogStreams.addAll(expected.streams) - catalogStreams.add(CatalogHelpers - .createAirbyteStream(TABLE_NAME, - SCHEMA_NAME2, - Field.of(COL_ID, JsonSchemaType.STRING), - Field.of(COL_NAME, JsonSchemaType.STRING)) - .withSupportedSyncModes(java.util.List.of(SyncMode.FULL_REFRESH, SyncMode.INCREMENTAL))) + catalogStreams.add( + CatalogHelpers.createAirbyteStream( + TABLE_NAME, + SCHEMA_NAME2, + Field.of(COL_ID, JsonSchemaType.STRING), + Field.of(COL_NAME, JsonSchemaType.STRING) + ) + .withSupportedSyncModes( + java.util.List.of(SyncMode.FULL_REFRESH, SyncMode.INCREMENTAL) + ) + ) expected.streams = catalogStreams // sort streams by name so that we are comparing lists with the same order. - val schemaTableCompare = Comparator.comparing { stream: AirbyteStream -> stream.namespace + "." + stream.name } - expected.streams.sort(schemaTableCompare) - actual!!.streams.sort(schemaTableCompare) + val schemaTableCompare = + Comparator.comparing { stream: AirbyteStream -> stream.namespace + "." + stream.name } + expected.streams.sortWith(schemaTableCompare) + actual!!.streams.sortWith(schemaTableCompare) Assertions.assertEquals(expected, filterOutOtherSchemas(actual)) } @@ -261,22 +384,32 @@ abstract class JdbcSourceAcceptanceTest?> @Throws(Exception::class) fun testReadSuccess() { val actualMessages = - MoreIterators.toList( - source()!!.read(config(), getConfiguredCatalogWithOneStream(defaultNamespace), null)) + MoreIterators.toList( + source()!!.read(config(), getConfiguredCatalogWithOneStream(defaultNamespace), null) + ) setEmittedAtToNull(actualMessages) val expectedMessages = testMessages - MatcherAssert.assertThat(expectedMessages, Matchers.containsInAnyOrder(*actualMessages.toTypedArray())) - MatcherAssert.assertThat(actualMessages, Matchers.containsInAnyOrder(*expectedMessages.toTypedArray())) + MatcherAssert.assertThat( + expectedMessages, + Matchers.containsInAnyOrder(*actualMessages.toTypedArray()) + ) + MatcherAssert.assertThat( + actualMessages, + Matchers.containsInAnyOrder(*expectedMessages.toTypedArray()) + ) } @Test @Throws(Exception::class) protected fun testReadOneColumn() { - val catalog = CatalogHelpers - .createConfiguredAirbyteCatalog(streamName(), defaultNamespace, Field.of(COL_ID, JsonSchemaType.NUMBER)) - val actualMessages = MoreIterators - .toList(source()!!.read(config(), catalog, null)) + val catalog = + CatalogHelpers.createConfiguredAirbyteCatalog( + streamName(), + defaultNamespace, + Field.of(COL_ID, JsonSchemaType.NUMBER) + ) + val actualMessages = MoreIterators.toList(source()!!.read(config(), catalog, null)) setEmittedAtToNull(actualMessages) @@ -288,13 +421,17 @@ abstract class JdbcSourceAcceptanceTest?> protected val airbyteMessagesReadOneColumn: List get() { - val expectedMessages = testMessages.stream() - .map { `object`: AirbyteMessage? -> Jsons.clone(`object`) } + val expectedMessages = + testMessages + .stream() + .map { `object`: AirbyteMessage -> Jsons.clone(`object`) } .peek { m: AirbyteMessage -> (m.record.data as ObjectNode).remove(COL_NAME) (m.record.data as ObjectNode).remove(COL_UPDATED_AT) - (m.record.data as ObjectNode).replace(COL_ID, - convertIdBasedOnDatabase(m.record.data[COL_ID].asInt())) + (m.record.data as ObjectNode).replace( + COL_ID, + convertIdBasedOnDatabase(m.record.data[COL_ID].asInt()) + ) } .collect(Collectors.toList()) return expectedMessages @@ -303,28 +440,30 @@ abstract class JdbcSourceAcceptanceTest?> @Test @Throws(Exception::class) protected fun testReadMultipleTables() { - val catalog = getConfiguredCatalogWithOneStream( - defaultNamespace) + val catalog = getConfiguredCatalogWithOneStream(defaultNamespace) val expectedMessages: MutableList = ArrayList(testMessages) for (i in 2..9) { val streamName2 = streamName() + i val tableName = getFullyQualifiedTableName(TABLE_NAME + i) - testdb!!.with(createTableQuery(tableName, "id INTEGER, name VARCHAR(200)", "")) - .with("INSERT INTO %s(id, name) VALUES (1,'picard')", tableName) - .with("INSERT INTO %s(id, name) VALUES (2, 'crusher')", tableName) - .with("INSERT INTO %s(id, name) VALUES (3, 'vash')", tableName) - catalog.streams.add(CatalogHelpers.createConfiguredAirbyteStream( + testdb!! + .with(createTableQuery(tableName, "id INTEGER, name VARCHAR(200)", "")) + .with("INSERT INTO %s(id, name) VALUES (1,'picard')", tableName) + .with("INSERT INTO %s(id, name) VALUES (2, 'crusher')", tableName) + .with("INSERT INTO %s(id, name) VALUES (3, 'vash')", tableName) + catalog.streams.add( + CatalogHelpers.createConfiguredAirbyteStream( streamName2, defaultNamespace, Field.of(COL_ID, JsonSchemaType.NUMBER), - Field.of(COL_NAME, JsonSchemaType.STRING))) + Field.of(COL_NAME, JsonSchemaType.STRING) + ) + ) expectedMessages.addAll(getAirbyteMessagesSecondSync(streamName2)) } - val actualMessages = MoreIterators - .toList(source()!!.read(config(), catalog, null)) + val actualMessages = MoreIterators.toList(source()!!.read(config(), catalog, null)) setEmittedAtToNull(actualMessages) @@ -335,16 +474,18 @@ abstract class JdbcSourceAcceptanceTest?> protected fun getAirbyteMessagesSecondSync(streamName: String?): List { return testMessages - .stream() - .map { `object`: AirbyteMessage? -> Jsons.clone(`object`) } - .peek { m: AirbyteMessage -> - m.record.stream = streamName - m.record.namespace = defaultNamespace - (m.record.data as ObjectNode).remove(COL_UPDATED_AT) - (m.record.data as ObjectNode).replace(COL_ID, - convertIdBasedOnDatabase(m.record.data[COL_ID].asInt())) - } - .collect(Collectors.toList()) + .stream() + .map { `object`: AirbyteMessage -> Jsons.clone(`object`) } + .peek { m: AirbyteMessage -> + m.record.stream = streamName + m.record.namespace = defaultNamespace + (m.record.data as ObjectNode).remove(COL_UPDATED_AT) + (m.record.data as ObjectNode).replace( + COL_ID, + convertIdBasedOnDatabase(m.record.data[COL_ID].asInt()) + ) + } + .collect(Collectors.toList()) } @Test @@ -352,12 +493,15 @@ abstract class JdbcSourceAcceptanceTest?> protected fun testTablesWithQuoting() { val streamForTableWithSpaces = createTableWithSpaces() - val catalog = ConfiguredAirbyteCatalog() - .withStreams(java.util.List.of( + val catalog = + ConfiguredAirbyteCatalog() + .withStreams( + java.util.List.of( getConfiguredCatalogWithOneStream(defaultNamespace).streams[0], - streamForTableWithSpaces)) - val actualMessages = MoreIterators - .toList(source()!!.read(config(), catalog, null)) + streamForTableWithSpaces + ) + ) + val actualMessages = MoreIterators.toList(source()!!.read(config(), catalog, null)) setEmittedAtToNull(actualMessages) @@ -369,60 +513,60 @@ abstract class JdbcSourceAcceptanceTest?> Assertions.assertTrue(actualMessages.containsAll(expectedMessages)) } - protected fun getAirbyteMessagesForTablesWithQuoting(streamForTableWithSpaces: ConfiguredAirbyteStream): List { + protected fun getAirbyteMessagesForTablesWithQuoting( + streamForTableWithSpaces: ConfiguredAirbyteStream + ): List { return testMessages - .stream() - .map { `object`: AirbyteMessage? -> Jsons.clone(`object`) } - .peek { m: AirbyteMessage -> - m.record.stream = streamForTableWithSpaces.stream.name - (m.record.data as ObjectNode).set(COL_LAST_NAME_WITH_SPACE, - (m.record.data as ObjectNode).remove(COL_NAME)) - (m.record.data as ObjectNode).remove(COL_UPDATED_AT) - (m.record.data as ObjectNode).replace(COL_ID, - convertIdBasedOnDatabase(m.record.data[COL_ID].asInt())) - } - .collect(Collectors.toList()) + .stream() + .map { `object`: AirbyteMessage -> Jsons.clone(`object`) } + .peek { m: AirbyteMessage -> + m.record.stream = streamForTableWithSpaces.stream.name + (m.record.data as ObjectNode).set( + COL_LAST_NAME_WITH_SPACE, + (m.record.data as ObjectNode).remove(COL_NAME) + ) + (m.record.data as ObjectNode).remove(COL_UPDATED_AT) + (m.record.data as ObjectNode).replace( + COL_ID, + convertIdBasedOnDatabase(m.record.data[COL_ID].asInt()) + ) + } + .collect(Collectors.toList()) } @Test fun testReadFailure() { - val spiedAbStream = Mockito.spy( - getConfiguredCatalogWithOneStream(defaultNamespace).streams[0]) - val catalog = ConfiguredAirbyteCatalog() - .withStreams(java.util.List.of(spiedAbStream)) + val spiedAbStream = + Mockito.spy(getConfiguredCatalogWithOneStream(defaultNamespace).streams[0]) + val catalog = ConfiguredAirbyteCatalog().withStreams(java.util.List.of(spiedAbStream)) Mockito.doCallRealMethod().doThrow(RuntimeException()).`when`(spiedAbStream).stream - Assertions.assertThrows(RuntimeException::class.java) { source()!!.read(config(), catalog, null) } + Assertions.assertThrows(RuntimeException::class.java) { + source()!!.read(config(), catalog, null) + } } @Test @Throws(Exception::class) fun testIncrementalNoPreviousState() { - incrementalCursorCheck( - COL_ID, - null, - "3", - testMessages) + incrementalCursorCheck(COL_ID, null, "3", testMessages) } @Test @Throws(Exception::class) fun testIncrementalIntCheckCursor() { - incrementalCursorCheck( - COL_ID, - "2", - "3", - java.util.List.of(testMessages[2])) + incrementalCursorCheck(COL_ID, "2", "3", java.util.List.of(testMessages[2])) } @Test @Throws(Exception::class) fun testIncrementalStringCheckCursor() { incrementalCursorCheck( - COL_NAME, - "patent", - "vash", - java.util.List.of(testMessages[0], testMessages[2])) + COL_NAME, + "patent", + "vash", + java.util.List.of(testMessages[0], testMessages[2]) + ) } @Test @@ -430,28 +574,36 @@ abstract class JdbcSourceAcceptanceTest?> fun testIncrementalStringCheckCursorSpaceInColumnName() { val streamWithSpaces = createTableWithSpaces() - val expectedRecordMessages = getAirbyteMessagesCheckCursorSpaceInColumnName(streamWithSpaces) + val expectedRecordMessages = + getAirbyteMessagesCheckCursorSpaceInColumnName(streamWithSpaces) incrementalCursorCheck( - COL_LAST_NAME_WITH_SPACE, - COL_LAST_NAME_WITH_SPACE, - "patent", - "vash", - expectedRecordMessages, - streamWithSpaces) + COL_LAST_NAME_WITH_SPACE, + COL_LAST_NAME_WITH_SPACE, + "patent", + "vash", + expectedRecordMessages, + streamWithSpaces + ) } - protected fun getAirbyteMessagesCheckCursorSpaceInColumnName(streamWithSpaces: ConfiguredAirbyteStream): List { + protected fun getAirbyteMessagesCheckCursorSpaceInColumnName( + streamWithSpaces: ConfiguredAirbyteStream + ): List { val firstMessage = testMessages[0] firstMessage.record.stream = streamWithSpaces.stream.name (firstMessage.record.data as ObjectNode).remove(COL_UPDATED_AT) - (firstMessage.record.data as ObjectNode).set(COL_LAST_NAME_WITH_SPACE, - (firstMessage.record.data as ObjectNode).remove(COL_NAME)) + (firstMessage.record.data as ObjectNode).set( + COL_LAST_NAME_WITH_SPACE, + (firstMessage.record.data as ObjectNode).remove(COL_NAME) + ) val secondMessage = testMessages[2] secondMessage.record.stream = streamWithSpaces.stream.name (secondMessage.record.data as ObjectNode).remove(COL_UPDATED_AT) - (secondMessage.record.data as ObjectNode).set(COL_LAST_NAME_WITH_SPACE, - (secondMessage.record.data as ObjectNode).remove(COL_NAME)) + (secondMessage.record.data as ObjectNode).set( + COL_LAST_NAME_WITH_SPACE, + (secondMessage.record.data as ObjectNode).remove(COL_NAME) + ) return java.util.List.of(firstMessage, secondMessage) } @@ -465,23 +617,27 @@ abstract class JdbcSourceAcceptanceTest?> @Throws(Exception::class) protected fun incrementalDateCheck() { incrementalCursorCheck( - COL_UPDATED_AT, - "2005-10-18", - "2006-10-19", - java.util.List.of(testMessages[1], testMessages[2])) + COL_UPDATED_AT, + "2005-10-18", + "2006-10-19", + java.util.List.of(testMessages[1], testMessages[2]) + ) } @Test @Throws(Exception::class) fun testIncrementalCursorChanges() { incrementalCursorCheck( - COL_ID, - COL_NAME, // cheesing this value a little bit. in the correct implementation this initial cursor value should - // be ignored because the cursor field changed. setting it to a value that if used, will cause - // records to (incorrectly) be filtered out. - "data", - "vash", - testMessages) + COL_ID, + COL_NAME, // cheesing this value a little bit. in the correct implementation this + // initial cursor value should + // be ignored because the cursor field changed. setting it to a value that if used, will + // cause + // records to (incorrectly) be filtered out. + "data", + "vash", + testMessages + ) } @Test @@ -490,26 +646,49 @@ abstract class JdbcSourceAcceptanceTest?> val config = config() val namespace = defaultNamespace val configuredCatalog = getConfiguredCatalogWithOneStream(namespace) - configuredCatalog.streams.forEach(Consumer { airbyteStream: ConfiguredAirbyteStream -> - airbyteStream.syncMode = SyncMode.INCREMENTAL - airbyteStream.cursorField = java.util.List.of(COL_ID) - airbyteStream.destinationSyncMode = DestinationSyncMode.APPEND - }) - - val actualMessagesFirstSync = MoreIterators - .toList(source()!!.read(config, configuredCatalog, createEmptyState(streamName(), namespace))) - - val stateAfterFirstSyncOptional = actualMessagesFirstSync.stream() - .filter { r: AirbyteMessage -> r.type == AirbyteMessage.Type.STATE }.findFirst() + configuredCatalog.streams.forEach( + Consumer { airbyteStream: ConfiguredAirbyteStream -> + airbyteStream.syncMode = SyncMode.INCREMENTAL + airbyteStream.cursorField = java.util.List.of(COL_ID) + airbyteStream.destinationSyncMode = DestinationSyncMode.APPEND + } + ) + + val actualMessagesFirstSync = + MoreIterators.toList( + source()!!.read( + config, + configuredCatalog, + createEmptyState(streamName(), namespace) + ) + ) + + val stateAfterFirstSyncOptional = + actualMessagesFirstSync + .stream() + .filter { r: AirbyteMessage -> r.type == AirbyteMessage.Type.STATE } + .findFirst() Assertions.assertTrue(stateAfterFirstSyncOptional.isPresent) executeStatementReadIncrementallyTwice() - val actualMessagesSecondSync = MoreIterators - .toList(source()!!.read(config, configuredCatalog, extractState(stateAfterFirstSyncOptional.get()))) - - Assertions.assertEquals(2, - actualMessagesSecondSync.stream().filter { r: AirbyteMessage -> r.type == AirbyteMessage.Type.RECORD }.count().toInt()) + val actualMessagesSecondSync = + MoreIterators.toList( + source()!!.read( + config, + configuredCatalog, + extractState(stateAfterFirstSyncOptional.get()) + ) + ) + + Assertions.assertEquals( + 2, + actualMessagesSecondSync + .stream() + .filter { r: AirbyteMessage -> r.type == AirbyteMessage.Type.RECORD } + .count() + .toInt() + ) val expectedMessages = getExpectedAirbyteMessagesSecondSync(namespace) setEmittedAtToNull(actualMessagesSecondSync) @@ -521,25 +700,62 @@ abstract class JdbcSourceAcceptanceTest?> protected fun executeStatementReadIncrementallyTwice() { testdb - .with("INSERT INTO %s (id, name, updated_at) VALUES (4, 'riker', '2006-10-19')", getFullyQualifiedTableName(TABLE_NAME)) - .with("INSERT INTO %s (id, name, updated_at) VALUES (5, 'data', '2006-10-19')", getFullyQualifiedTableName(TABLE_NAME)) + .with( + "INSERT INTO %s (id, name, updated_at) VALUES (4, 'riker', '2006-10-19')", + getFullyQualifiedTableName(TABLE_NAME) + ) + .with( + "INSERT INTO %s (id, name, updated_at) VALUES (5, 'data', '2006-10-19')", + getFullyQualifiedTableName(TABLE_NAME) + ) } protected fun getExpectedAirbyteMessagesSecondSync(namespace: String?): List { val expectedMessages: MutableList = ArrayList() - expectedMessages.add(AirbyteMessage().withType(AirbyteMessage.Type.RECORD) - .withRecord(AirbyteRecordMessage().withStream(streamName()).withNamespace(namespace) - .withData(Jsons.jsonNode(java.util.Map - .of(COL_ID, ID_VALUE_4, - COL_NAME, "riker", - COL_UPDATED_AT, "2006-10-19"))))) - expectedMessages.add(AirbyteMessage().withType(AirbyteMessage.Type.RECORD) - .withRecord(AirbyteRecordMessage().withStream(streamName()).withNamespace(namespace) - .withData(Jsons.jsonNode(java.util.Map - .of(COL_ID, ID_VALUE_5, - COL_NAME, "data", - COL_UPDATED_AT, "2006-10-19"))))) - val state = DbStreamState() + expectedMessages.add( + AirbyteMessage() + .withType(AirbyteMessage.Type.RECORD) + .withRecord( + AirbyteRecordMessage() + .withStream(streamName()) + .withNamespace(namespace) + .withData( + Jsons.jsonNode( + java.util.Map.of( + COL_ID, + ID_VALUE_4, + COL_NAME, + "riker", + COL_UPDATED_AT, + "2006-10-19" + ) + ) + ) + ) + ) + expectedMessages.add( + AirbyteMessage() + .withType(AirbyteMessage.Type.RECORD) + .withRecord( + AirbyteRecordMessage() + .withStream(streamName()) + .withNamespace(namespace) + .withData( + Jsons.jsonNode( + java.util.Map.of( + COL_ID, + ID_VALUE_5, + COL_NAME, + "data", + COL_UPDATED_AT, + "2006-10-19" + ) + ) + ) + ) + ) + val state = + DbStreamState() .withStreamName(streamName()) .withStreamNamespace(namespace) .withCursorField(java.util.List.of(COL_ID)) @@ -555,70 +771,92 @@ abstract class JdbcSourceAcceptanceTest?> val tableName2 = TABLE_NAME + 2 val streamName2 = streamName() + 2 val fqTableName2 = getFullyQualifiedTableName(tableName2) - testdb!!.with(createTableQuery(fqTableName2, "id INTEGER, name VARCHAR(200)", "")) - .with("INSERT INTO %s(id, name) VALUES (1,'picard')", fqTableName2) - .with("INSERT INTO %s(id, name) VALUES (2, 'crusher')", fqTableName2) - .with("INSERT INTO %s(id, name) VALUES (3, 'vash')", fqTableName2) + testdb!! + .with(createTableQuery(fqTableName2, "id INTEGER, name VARCHAR(200)", "")) + .with("INSERT INTO %s(id, name) VALUES (1,'picard')", fqTableName2) + .with("INSERT INTO %s(id, name) VALUES (2, 'crusher')", fqTableName2) + .with("INSERT INTO %s(id, name) VALUES (3, 'vash')", fqTableName2) val namespace = defaultNamespace - val configuredCatalog = getConfiguredCatalogWithOneStream( - namespace) - configuredCatalog.streams.add(CatalogHelpers.createConfiguredAirbyteStream( + val configuredCatalog = getConfiguredCatalogWithOneStream(namespace) + configuredCatalog.streams.add( + CatalogHelpers.createConfiguredAirbyteStream( streamName2, namespace, Field.of(COL_ID, JsonSchemaType.NUMBER), - Field.of(COL_NAME, JsonSchemaType.STRING))) - configuredCatalog.streams.forEach(Consumer { airbyteStream: ConfiguredAirbyteStream -> - airbyteStream.syncMode = SyncMode.INCREMENTAL - airbyteStream.cursorField = java.util.List.of(COL_ID) - airbyteStream.destinationSyncMode = DestinationSyncMode.APPEND - }) + Field.of(COL_NAME, JsonSchemaType.STRING) + ) + ) + configuredCatalog.streams.forEach( + Consumer { airbyteStream: ConfiguredAirbyteStream -> + airbyteStream.syncMode = SyncMode.INCREMENTAL + airbyteStream.cursorField = java.util.List.of(COL_ID) + airbyteStream.destinationSyncMode = DestinationSyncMode.APPEND + } + ) - val actualMessagesFirstSync = MoreIterators - .toList(source()!!.read(config(), configuredCatalog, createEmptyState(streamName(), namespace))) + val actualMessagesFirstSync = + MoreIterators.toList( + source()!!.read( + config(), + configuredCatalog, + createEmptyState(streamName(), namespace) + ) + ) // get last state message. - val stateAfterFirstSyncOptional = actualMessagesFirstSync.stream() + val stateAfterFirstSyncOptional = + actualMessagesFirstSync + .stream() .filter { r: AirbyteMessage -> r.type == AirbyteMessage.Type.STATE } .reduce { first: AirbyteMessage?, second: AirbyteMessage -> second } Assertions.assertTrue(stateAfterFirstSyncOptional.isPresent) - // we know the second streams messages are the same as the first minus the updated at column. so we + // we know the second streams messages are the same as the first minus the updated at + // column. so we // cheat and generate the expected messages off of the first expected messages. val secondStreamExpectedMessages = getAirbyteMessagesSecondStreamWithNamespace(streamName2) // Represents the state after the first stream has been updated - val expectedStateStreams1 = java.util.List.of( + val expectedStateStreams1 = + java.util.List.of( DbStreamState() - .withStreamName(streamName()) - .withStreamNamespace(namespace) - .withCursorField(java.util.List.of(COL_ID)) - .withCursor("3") - .withCursorRecordCount(1L), + .withStreamName(streamName()) + .withStreamNamespace(namespace) + .withCursorField(java.util.List.of(COL_ID)) + .withCursor("3") + .withCursorRecordCount(1L), DbStreamState() - .withStreamName(streamName2) - .withStreamNamespace(namespace) - .withCursorField(java.util.List.of(COL_ID))) + .withStreamName(streamName2) + .withStreamNamespace(namespace) + .withCursorField(java.util.List.of(COL_ID)) + ) // Represents the state after both streams have been updated - val expectedStateStreams2 = java.util.List.of( + val expectedStateStreams2 = + java.util.List.of( DbStreamState() - .withStreamName(streamName()) - .withStreamNamespace(namespace) - .withCursorField(java.util.List.of(COL_ID)) - .withCursor("3") - .withCursorRecordCount(1L), + .withStreamName(streamName()) + .withStreamNamespace(namespace) + .withCursorField(java.util.List.of(COL_ID)) + .withCursor("3") + .withCursorRecordCount(1L), DbStreamState() - .withStreamName(streamName2) - .withStreamNamespace(namespace) - .withCursorField(java.util.List.of(COL_ID)) - .withCursor("3") - .withCursorRecordCount(1L)) + .withStreamName(streamName2) + .withStreamNamespace(namespace) + .withCursorField(java.util.List.of(COL_ID)) + .withCursor("3") + .withCursorRecordCount(1L) + ) val expectedMessagesFirstSync: MutableList = ArrayList(testMessages) - expectedMessagesFirstSync.add(createStateMessage(expectedStateStreams1[0], expectedStateStreams1, 3L)) + expectedMessagesFirstSync.add( + createStateMessage(expectedStateStreams1[0], expectedStateStreams1, 3L) + ) expectedMessagesFirstSync.addAll(secondStreamExpectedMessages) - expectedMessagesFirstSync.add(createStateMessage(expectedStateStreams2[1], expectedStateStreams2, 3L)) + expectedMessagesFirstSync.add( + createStateMessage(expectedStateStreams2[1], expectedStateStreams2, 3L) + ) setEmittedAtToNull(actualMessagesFirstSync) @@ -627,28 +865,38 @@ abstract class JdbcSourceAcceptanceTest?> Assertions.assertTrue(actualMessagesFirstSync.containsAll(expectedMessagesFirstSync)) } - protected fun getAirbyteMessagesSecondStreamWithNamespace(streamName2: String?): List { + protected fun getAirbyteMessagesSecondStreamWithNamespace( + streamName2: String? + ): List { return testMessages - .stream() - .map { `object`: AirbyteMessage? -> Jsons.clone(`object`) } - .peek { m: AirbyteMessage -> - m.record.stream = streamName2 - (m.record.data as ObjectNode).remove(COL_UPDATED_AT) - (m.record.data as ObjectNode).replace(COL_ID, - convertIdBasedOnDatabase(m.record.data[COL_ID].asInt())) - } - .collect(Collectors.toList()) + .stream() + .map { `object`: AirbyteMessage -> Jsons.clone(`object`) } + .peek { m: AirbyteMessage -> + m.record.stream = streamName2 + (m.record.data as ObjectNode).remove(COL_UPDATED_AT) + (m.record.data as ObjectNode).replace( + COL_ID, + convertIdBasedOnDatabase(m.record.data[COL_ID].asInt()) + ) + } + .collect(Collectors.toList()) } // when initial and final cursor fields are the same. @Throws(Exception::class) protected fun incrementalCursorCheck( - cursorField: String, - initialCursorValue: String?, - endCursorValue: String, - expectedRecordMessages: List) { - incrementalCursorCheck(cursorField, cursorField, initialCursorValue, endCursorValue, - expectedRecordMessages) + cursorField: String, + initialCursorValue: String?, + endCursorValue: String, + expectedRecordMessages: List + ) { + incrementalCursorCheck( + cursorField, + cursorField, + initialCursorValue, + endCursorValue, + expectedRecordMessages + ) } // See https://github.com/airbytehq/airbyte/issues/14732 for rationale and details. @@ -657,98 +905,201 @@ abstract class JdbcSourceAcceptanceTest?> fun testIncrementalWithConcurrentInsertion() { val namespace = defaultNamespace val fullyQualifiedTableName = getFullyQualifiedTableName(TABLE_NAME_AND_TIMESTAMP) - val columnDefinition = String.format("name VARCHAR(200) NOT NULL, %s %s NOT NULL", COL_TIMESTAMP, COL_TIMESTAMP_TYPE) + val columnDefinition = + String.format( + "name VARCHAR(200) NOT NULL, %s %s NOT NULL", + COL_TIMESTAMP, + COL_TIMESTAMP_TYPE + ) // 1st sync - testdb!!.with(createTableQuery(fullyQualifiedTableName, columnDefinition, "")) - .with(INSERT_TABLE_NAME_AND_TIMESTAMP_QUERY, fullyQualifiedTableName, "a", "2021-01-01 00:00:00") - .with(INSERT_TABLE_NAME_AND_TIMESTAMP_QUERY, fullyQualifiedTableName, "b", "2021-01-01 00:00:00") - - val configuredCatalog = CatalogHelpers.toDefaultConfiguredCatalog( - AirbyteCatalog().withStreams(java.util.List.of( - CatalogHelpers.createAirbyteStream( + testdb!! + .with(createTableQuery(fullyQualifiedTableName, columnDefinition, "")) + .with( + INSERT_TABLE_NAME_AND_TIMESTAMP_QUERY, + fullyQualifiedTableName, + "a", + "2021-01-01 00:00:00" + ) + .with( + INSERT_TABLE_NAME_AND_TIMESTAMP_QUERY, + fullyQualifiedTableName, + "b", + "2021-01-01 00:00:00" + ) + + val configuredCatalog = + CatalogHelpers.toDefaultConfiguredCatalog( + AirbyteCatalog() + .withStreams( + java.util.List.of( + CatalogHelpers.createAirbyteStream( TABLE_NAME_AND_TIMESTAMP, namespace, Field.of(COL_NAME, JsonSchemaType.STRING), - Field.of(COL_TIMESTAMP, JsonSchemaType.STRING_TIMESTAMP_WITHOUT_TIMEZONE))))) - - configuredCatalog.streams.forEach(Consumer { airbyteStream: ConfiguredAirbyteStream -> - airbyteStream.syncMode = SyncMode.INCREMENTAL - airbyteStream.cursorField = java.util.List.of(COL_TIMESTAMP) - airbyteStream.destinationSyncMode = DestinationSyncMode.APPEND - }) + Field.of( + COL_TIMESTAMP, + JsonSchemaType.STRING_TIMESTAMP_WITHOUT_TIMEZONE + ) + ) + ) + ) + ) + + configuredCatalog.streams.forEach( + Consumer { airbyteStream: ConfiguredAirbyteStream -> + airbyteStream.syncMode = SyncMode.INCREMENTAL + airbyteStream.cursorField = java.util.List.of(COL_TIMESTAMP) + airbyteStream.destinationSyncMode = DestinationSyncMode.APPEND + } + ) - val firstSyncActualMessages = MoreIterators.toList( - source()!!.read(config(), configuredCatalog, createEmptyState(TABLE_NAME_AND_TIMESTAMP, namespace))) + val firstSyncActualMessages = + MoreIterators.toList( + source()!!.read( + config(), + configuredCatalog, + createEmptyState(TABLE_NAME_AND_TIMESTAMP, namespace) + ) + ) // cursor after 1st sync: 2021-01-01 00:00:00, count 2 - val firstSyncStateOptional = firstSyncActualMessages.stream().filter { r: AirbyteMessage -> r.type == AirbyteMessage.Type.STATE }.findFirst() + val firstSyncStateOptional = + firstSyncActualMessages + .stream() + .filter { r: AirbyteMessage -> r.type == AirbyteMessage.Type.STATE } + .findFirst() Assertions.assertTrue(firstSyncStateOptional.isPresent) val firstSyncState = getStateData(firstSyncStateOptional.get(), TABLE_NAME_AND_TIMESTAMP) - Assertions.assertEquals(firstSyncState["cursor_field"].elements().next().asText(), COL_TIMESTAMP) + Assertions.assertEquals( + firstSyncState["cursor_field"].elements().next().asText(), + COL_TIMESTAMP + ) Assertions.assertTrue(firstSyncState["cursor"].asText().contains("2021-01-01")) Assertions.assertTrue(firstSyncState["cursor"].asText().contains("00:00:00")) Assertions.assertEquals(2L, firstSyncState["cursor_record_count"].asLong()) - val firstSyncNames = firstSyncActualMessages.stream() + val firstSyncNames = + firstSyncActualMessages + .stream() .filter { r: AirbyteMessage -> r.type == AirbyteMessage.Type.RECORD } .map { r: AirbyteMessage -> r.record.data[COL_NAME].asText() } .toList() // some databases don't make insertion order guarantee when equal ordering value - if (testdb!!.databaseDriver == DatabaseDriver.TERADATA || testdb!!.databaseDriver == DatabaseDriver.ORACLE) { - MatcherAssert.assertThat(listOf("a", "b"), Matchers.containsInAnyOrder(*firstSyncNames.toTypedArray())) + if ( + testdb!!.databaseDriver == DatabaseDriver.TERADATA || + testdb!!.databaseDriver == DatabaseDriver.ORACLE + ) { + MatcherAssert.assertThat( + listOf("a", "b"), + Matchers.containsInAnyOrder(*firstSyncNames.toTypedArray()) + ) } else { Assertions.assertEquals(listOf("a", "b"), firstSyncNames) } // 2nd sync - testdb!!.with(INSERT_TABLE_NAME_AND_TIMESTAMP_QUERY, fullyQualifiedTableName, "c", "2021-01-02 00:00:00") - - val secondSyncActualMessages = MoreIterators.toList( - source()!!.read(config(), configuredCatalog, createState(TABLE_NAME_AND_TIMESTAMP, namespace, firstSyncState))) + testdb!!.with( + INSERT_TABLE_NAME_AND_TIMESTAMP_QUERY, + fullyQualifiedTableName, + "c", + "2021-01-02 00:00:00" + ) + + val secondSyncActualMessages = + MoreIterators.toList( + source()!!.read( + config(), + configuredCatalog, + createState(TABLE_NAME_AND_TIMESTAMP, namespace, firstSyncState) + ) + ) // cursor after 2nd sync: 2021-01-02 00:00:00, count 1 - val secondSyncStateOptional = secondSyncActualMessages.stream().filter { r: AirbyteMessage -> r.type == AirbyteMessage.Type.STATE }.findFirst() + val secondSyncStateOptional = + secondSyncActualMessages + .stream() + .filter { r: AirbyteMessage -> r.type == AirbyteMessage.Type.STATE } + .findFirst() Assertions.assertTrue(secondSyncStateOptional.isPresent) val secondSyncState = getStateData(secondSyncStateOptional.get(), TABLE_NAME_AND_TIMESTAMP) - Assertions.assertEquals(secondSyncState["cursor_field"].elements().next().asText(), COL_TIMESTAMP) + Assertions.assertEquals( + secondSyncState["cursor_field"].elements().next().asText(), + COL_TIMESTAMP + ) Assertions.assertTrue(secondSyncState["cursor"].asText().contains("2021-01-02")) Assertions.assertTrue(secondSyncState["cursor"].asText().contains("00:00:00")) Assertions.assertEquals(1L, secondSyncState["cursor_record_count"].asLong()) - val secondSyncNames = secondSyncActualMessages.stream() + val secondSyncNames = + secondSyncActualMessages + .stream() .filter { r: AirbyteMessage -> r.type == AirbyteMessage.Type.RECORD } .map { r: AirbyteMessage -> r.record.data[COL_NAME].asText() } .toList() Assertions.assertEquals(listOf("c"), secondSyncNames) // 3rd sync has records with duplicated cursors - testdb!!.with(INSERT_TABLE_NAME_AND_TIMESTAMP_QUERY, fullyQualifiedTableName, "d", "2021-01-02 00:00:00") - .with(INSERT_TABLE_NAME_AND_TIMESTAMP_QUERY, fullyQualifiedTableName, "e", "2021-01-02 00:00:00") - .with(INSERT_TABLE_NAME_AND_TIMESTAMP_QUERY, fullyQualifiedTableName, "f", "2021-01-03 00:00:00") - - val thirdSyncActualMessages = MoreIterators.toList( - source()!!.read(config(), configuredCatalog, createState(TABLE_NAME_AND_TIMESTAMP, namespace, secondSyncState))) + testdb!! + .with( + INSERT_TABLE_NAME_AND_TIMESTAMP_QUERY, + fullyQualifiedTableName, + "d", + "2021-01-02 00:00:00" + ) + .with( + INSERT_TABLE_NAME_AND_TIMESTAMP_QUERY, + fullyQualifiedTableName, + "e", + "2021-01-02 00:00:00" + ) + .with( + INSERT_TABLE_NAME_AND_TIMESTAMP_QUERY, + fullyQualifiedTableName, + "f", + "2021-01-03 00:00:00" + ) + + val thirdSyncActualMessages = + MoreIterators.toList( + source()!!.read( + config(), + configuredCatalog, + createState(TABLE_NAME_AND_TIMESTAMP, namespace, secondSyncState) + ) + ) // Cursor after 3rd sync is: 2021-01-03 00:00:00, count 1. - val thirdSyncStateOptional = thirdSyncActualMessages.stream().filter { r: AirbyteMessage -> r.type == AirbyteMessage.Type.STATE }.findFirst() + val thirdSyncStateOptional = + thirdSyncActualMessages + .stream() + .filter { r: AirbyteMessage -> r.type == AirbyteMessage.Type.STATE } + .findFirst() Assertions.assertTrue(thirdSyncStateOptional.isPresent) val thirdSyncState = getStateData(thirdSyncStateOptional.get(), TABLE_NAME_AND_TIMESTAMP) - Assertions.assertEquals(thirdSyncState["cursor_field"].elements().next().asText(), COL_TIMESTAMP) + Assertions.assertEquals( + thirdSyncState["cursor_field"].elements().next().asText(), + COL_TIMESTAMP + ) Assertions.assertTrue(thirdSyncState["cursor"].asText().contains("2021-01-03")) Assertions.assertTrue(thirdSyncState["cursor"].asText().contains("00:00:00")) Assertions.assertEquals(1L, thirdSyncState["cursor_record_count"].asLong()) // The c, d, e, f are duplicated records from this sync, because the cursor // record count in the database is different from that in the state. - val thirdSyncExpectedNames = thirdSyncActualMessages.stream() + val thirdSyncExpectedNames = + thirdSyncActualMessages + .stream() .filter { r: AirbyteMessage -> r.type == AirbyteMessage.Type.RECORD } .map { r: AirbyteMessage -> r.record.data[COL_NAME].asText() } .toList() // teradata doesn't make insertion order guarantee when equal ordering value if (testdb!!.databaseDriver == DatabaseDriver.TERADATA) { - MatcherAssert.assertThat(listOf("c", "d", "e", "f"), Matchers.containsInAnyOrder(*thirdSyncExpectedNames.toTypedArray())) + MatcherAssert.assertThat( + listOf("c", "d", "e", "f"), + Matchers.containsInAnyOrder(*thirdSyncExpectedNames.toTypedArray()) + ) } else { Assertions.assertEquals(listOf("c", "d", "e", "f"), thirdSyncExpectedNames) } @@ -765,146 +1116,247 @@ abstract class JdbcSourceAcceptanceTest?> @Throws(Exception::class) private fun incrementalCursorCheck( - initialCursorField: String, - cursorField: String, - initialCursorValue: String?, - endCursorValue: String, - expectedRecordMessages: List) { - incrementalCursorCheck(initialCursorField, cursorField, initialCursorValue, endCursorValue, - expectedRecordMessages, - getConfiguredCatalogWithOneStream(defaultNamespace).streams[0]) + initialCursorField: String, + cursorField: String, + initialCursorValue: String?, + endCursorValue: String, + expectedRecordMessages: List + ) { + incrementalCursorCheck( + initialCursorField, + cursorField, + initialCursorValue, + endCursorValue, + expectedRecordMessages, + getConfiguredCatalogWithOneStream(defaultNamespace).streams[0] + ) } @Throws(Exception::class) protected fun incrementalCursorCheck( - initialCursorField: String?, - cursorField: String, - initialCursorValue: String?, - endCursorValue: String?, - expectedRecordMessages: List, - airbyteStream: ConfiguredAirbyteStream) { + initialCursorField: String?, + cursorField: String, + initialCursorValue: String?, + endCursorValue: String?, + expectedRecordMessages: List, + airbyteStream: ConfiguredAirbyteStream + ) { airbyteStream.syncMode = SyncMode.INCREMENTAL airbyteStream.cursorField = java.util.List.of(cursorField) airbyteStream.destinationSyncMode = DestinationSyncMode.APPEND - val configuredCatalog = ConfiguredAirbyteCatalog() - .withStreams(java.util.List.of(airbyteStream)) + val configuredCatalog = + ConfiguredAirbyteCatalog().withStreams(java.util.List.of(airbyteStream)) val dbStreamState = buildStreamState(airbyteStream, initialCursorField, initialCursorValue) - val actualMessages = MoreIterators - .toList(source()!!.read(config(), configuredCatalog, Jsons.jsonNode(createState(java.util.List.of(dbStreamState))))) + val actualMessages = + MoreIterators.toList( + source()!!.read( + config(), + configuredCatalog, + Jsons.jsonNode(createState(java.util.List.of(dbStreamState))) + ) + ) setEmittedAtToNull(actualMessages) - val expectedStreams = java.util.List.of(buildStreamState(airbyteStream, cursorField, endCursorValue)) + val expectedStreams = + java.util.List.of(buildStreamState(airbyteStream, cursorField, endCursorValue)) val expectedMessages: MutableList = ArrayList(expectedRecordMessages) - expectedMessages.addAll(createExpectedTestMessages(expectedStreams, expectedRecordMessages.size.toLong())) + expectedMessages.addAll( + createExpectedTestMessages(expectedStreams, expectedRecordMessages.size.toLong()) + ) Assertions.assertEquals(expectedMessages.size, actualMessages.size) Assertions.assertTrue(expectedMessages.containsAll(actualMessages)) Assertions.assertTrue(actualMessages.containsAll(expectedMessages)) } - protected fun buildStreamState(configuredAirbyteStream: ConfiguredAirbyteStream, - cursorField: String?, - cursorValue: String?): DbStreamState { + protected fun buildStreamState( + configuredAirbyteStream: ConfiguredAirbyteStream, + cursorField: String?, + cursorValue: String? + ): DbStreamState { return DbStreamState() - .withStreamName(configuredAirbyteStream.stream.name) - .withStreamNamespace(configuredAirbyteStream.stream.namespace) - .withCursorField(java.util.List.of(cursorField)) - .withCursor(cursorValue) - .withCursorRecordCount(1L) + .withStreamName(configuredAirbyteStream.stream.name) + .withStreamNamespace(configuredAirbyteStream.stream.namespace) + .withCursorField(java.util.List.of(cursorField)) + .withCursor(cursorValue) + .withCursorRecordCount(1L) } // get catalog and perform a defensive copy. - protected fun getConfiguredCatalogWithOneStream(defaultNamespace: String?): ConfiguredAirbyteCatalog { + protected fun getConfiguredCatalogWithOneStream( + defaultNamespace: String? + ): ConfiguredAirbyteCatalog { val catalog = CatalogHelpers.toDefaultConfiguredCatalog(getCatalog(defaultNamespace)) // Filter to only keep the main stream name as configured stream catalog.withStreams( - catalog.streams.stream().filter { s: ConfiguredAirbyteStream -> s.stream.name == streamName() } - .collect(Collectors.toList())) + catalog.streams + .stream() + .filter { s: ConfiguredAirbyteStream -> s.stream.name == streamName() } + .collect(Collectors.toList()) + ) return catalog } protected fun getCatalog(defaultNamespace: String?): AirbyteCatalog { - return AirbyteCatalog().withStreams(java.util.List.of( - CatalogHelpers.createAirbyteStream( - TABLE_NAME, - defaultNamespace, - Field.of(COL_ID, JsonSchemaType.INTEGER), - Field.of(COL_NAME, JsonSchemaType.STRING), - Field.of(COL_UPDATED_AT, JsonSchemaType.STRING)) - .withSupportedSyncModes(java.util.List.of(SyncMode.FULL_REFRESH, SyncMode.INCREMENTAL)) + return AirbyteCatalog() + .withStreams( + java.util.List.of( + CatalogHelpers.createAirbyteStream( + TABLE_NAME, + defaultNamespace, + Field.of(COL_ID, JsonSchemaType.INTEGER), + Field.of(COL_NAME, JsonSchemaType.STRING), + Field.of(COL_UPDATED_AT, JsonSchemaType.STRING) + ) + .withSupportedSyncModes( + java.util.List.of(SyncMode.FULL_REFRESH, SyncMode.INCREMENTAL) + ) .withSourceDefinedPrimaryKey(java.util.List.of(java.util.List.of(COL_ID))), - CatalogHelpers.createAirbyteStream( - TABLE_NAME_WITHOUT_PK, - defaultNamespace, - Field.of(COL_ID, JsonSchemaType.INTEGER), - Field.of(COL_NAME, JsonSchemaType.STRING), - Field.of(COL_UPDATED_AT, JsonSchemaType.STRING)) - .withSupportedSyncModes(java.util.List.of(SyncMode.FULL_REFRESH, SyncMode.INCREMENTAL)) + CatalogHelpers.createAirbyteStream( + TABLE_NAME_WITHOUT_PK, + defaultNamespace, + Field.of(COL_ID, JsonSchemaType.INTEGER), + Field.of(COL_NAME, JsonSchemaType.STRING), + Field.of(COL_UPDATED_AT, JsonSchemaType.STRING) + ) + .withSupportedSyncModes( + java.util.List.of(SyncMode.FULL_REFRESH, SyncMode.INCREMENTAL) + ) .withSourceDefinedPrimaryKey(emptyList()), - CatalogHelpers.createAirbyteStream( - TABLE_NAME_COMPOSITE_PK, - defaultNamespace, - Field.of(COL_FIRST_NAME, JsonSchemaType.STRING), - Field.of(COL_LAST_NAME, JsonSchemaType.STRING), - Field.of(COL_UPDATED_AT, JsonSchemaType.STRING)) - .withSupportedSyncModes(java.util.List.of(SyncMode.FULL_REFRESH, SyncMode.INCREMENTAL)) + CatalogHelpers.createAirbyteStream( + TABLE_NAME_COMPOSITE_PK, + defaultNamespace, + Field.of(COL_FIRST_NAME, JsonSchemaType.STRING), + Field.of(COL_LAST_NAME, JsonSchemaType.STRING), + Field.of(COL_UPDATED_AT, JsonSchemaType.STRING) + ) + .withSupportedSyncModes( + java.util.List.of(SyncMode.FULL_REFRESH, SyncMode.INCREMENTAL) + ) .withSourceDefinedPrimaryKey( - java.util.List.of(java.util.List.of(COL_FIRST_NAME), java.util.List.of(COL_LAST_NAME))))) + java.util.List.of( + java.util.List.of(COL_FIRST_NAME), + java.util.List.of(COL_LAST_NAME) + ) + ) + ) + ) } protected val testMessages: List - get() = java.util.List.of( - AirbyteMessage().withType(AirbyteMessage.Type.RECORD) - .withRecord(AirbyteRecordMessage().withStream(streamName()).withNamespace(defaultNamespace) - .withData(Jsons.jsonNode(java.util.Map - .of(COL_ID, ID_VALUE_1, - COL_NAME, "picard", - COL_UPDATED_AT, "2004-10-19")))), - AirbyteMessage().withType(AirbyteMessage.Type.RECORD) - .withRecord(AirbyteRecordMessage().withStream(streamName()).withNamespace(defaultNamespace) - .withData(Jsons.jsonNode(java.util.Map - .of(COL_ID, ID_VALUE_2, - COL_NAME, "crusher", - COL_UPDATED_AT, - "2005-10-19")))), - AirbyteMessage().withType(AirbyteMessage.Type.RECORD) - .withRecord(AirbyteRecordMessage().withStream(streamName()).withNamespace(defaultNamespace) - .withData(Jsons.jsonNode(java.util.Map - .of(COL_ID, ID_VALUE_3, - COL_NAME, "vash", - COL_UPDATED_AT, "2006-10-19"))))) - - protected fun createExpectedTestMessages(states: List, numRecords: Long): List { - return states.stream() - .map { s: DbStreamState -> - AirbyteMessage().withType(AirbyteMessage.Type.STATE) - .withState( - AirbyteStateMessage().withType(AirbyteStateMessage.AirbyteStateType.STREAM) - .withStream(AirbyteStreamState() - .withStreamDescriptor(StreamDescriptor().withNamespace(s.streamNamespace).withName(s.streamName)) - .withStreamState(Jsons.jsonNode(s))) - .withData(Jsons.jsonNode(DbState().withCdc(false).withStreams(states))) - .withSourceStats(AirbyteStateStats().withRecordCount(numRecords.toDouble()))) - } - .collect( - Collectors.toList()) + get() = + java.util.List.of( + AirbyteMessage() + .withType(AirbyteMessage.Type.RECORD) + .withRecord( + AirbyteRecordMessage() + .withStream(streamName()) + .withNamespace(defaultNamespace) + .withData( + Jsons.jsonNode( + java.util.Map.of( + COL_ID, + ID_VALUE_1, + COL_NAME, + "picard", + COL_UPDATED_AT, + "2004-10-19" + ) + ) + ) + ), + AirbyteMessage() + .withType(AirbyteMessage.Type.RECORD) + .withRecord( + AirbyteRecordMessage() + .withStream(streamName()) + .withNamespace(defaultNamespace) + .withData( + Jsons.jsonNode( + java.util.Map.of( + COL_ID, + ID_VALUE_2, + COL_NAME, + "crusher", + COL_UPDATED_AT, + "2005-10-19" + ) + ) + ) + ), + AirbyteMessage() + .withType(AirbyteMessage.Type.RECORD) + .withRecord( + AirbyteRecordMessage() + .withStream(streamName()) + .withNamespace(defaultNamespace) + .withData( + Jsons.jsonNode( + java.util.Map.of( + COL_ID, + ID_VALUE_3, + COL_NAME, + "vash", + COL_UPDATED_AT, + "2006-10-19" + ) + ) + ) + ) + ) + + protected fun createExpectedTestMessages( + states: List, + numRecords: Long + ): List { + return states + .stream() + .map { s: DbStreamState -> + AirbyteMessage() + .withType(AirbyteMessage.Type.STATE) + .withState( + AirbyteStateMessage() + .withType(AirbyteStateMessage.AirbyteStateType.STREAM) + .withStream( + AirbyteStreamState() + .withStreamDescriptor( + StreamDescriptor() + .withNamespace(s.streamNamespace) + .withName(s.streamName) + ) + .withStreamState(Jsons.jsonNode(s)) + ) + .withData(Jsons.jsonNode(DbState().withCdc(false).withStreams(states))) + .withSourceStats( + AirbyteStateStats().withRecordCount(numRecords.toDouble()) + ) + ) + } + .collect(Collectors.toList()) } protected fun createState(states: List): List { - return states.stream() - .map { s: DbStreamState -> - AirbyteStateMessage().withType(AirbyteStateMessage.AirbyteStateType.STREAM) - .withStream(AirbyteStreamState() - .withStreamDescriptor(StreamDescriptor().withNamespace(s.streamNamespace).withName(s.streamName)) - .withStreamState(Jsons.jsonNode(s))) - } - .collect( - Collectors.toList()) + return states + .stream() + .map { s: DbStreamState -> + AirbyteStateMessage() + .withType(AirbyteStateMessage.AirbyteStateType.STREAM) + .withStream( + AirbyteStreamState() + .withStreamDescriptor( + StreamDescriptor() + .withNamespace(s.streamNamespace) + .withName(s.streamName) + ) + .withStreamState(Jsons.jsonNode(s)) + ) + } + .collect(Collectors.toList()) } @Throws(SQLException::class) @@ -914,37 +1366,86 @@ abstract class JdbcSourceAcceptanceTest?> testdb!!.getDataSource()!!.connection.use { connection -> val identifierQuoteString = connection.metaData.identifierQuoteString - connection.createStatement() - .execute( - createTableQuery(getFullyQualifiedTableName( - RelationalDbQueryUtils.enquoteIdentifier(tableNameWithSpaces, identifierQuoteString)), - "id INTEGER, " + RelationalDbQueryUtils.enquoteIdentifier(COL_LAST_NAME_WITH_SPACE, identifierQuoteString) - + " VARCHAR(200)", - "")) - connection.createStatement() - .execute(String.format("INSERT INTO %s(id, %s) VALUES (1,'picard')", - getFullyQualifiedTableName( - RelationalDbQueryUtils.enquoteIdentifier(tableNameWithSpaces, identifierQuoteString)), - RelationalDbQueryUtils.enquoteIdentifier(COL_LAST_NAME_WITH_SPACE, identifierQuoteString))) - connection.createStatement() - .execute(String.format("INSERT INTO %s(id, %s) VALUES (2, 'crusher')", - getFullyQualifiedTableName( - RelationalDbQueryUtils.enquoteIdentifier(tableNameWithSpaces, identifierQuoteString)), - RelationalDbQueryUtils.enquoteIdentifier(COL_LAST_NAME_WITH_SPACE, identifierQuoteString))) - connection.createStatement() - .execute(String.format("INSERT INTO %s(id, %s) VALUES (3, 'vash')", - getFullyQualifiedTableName( - RelationalDbQueryUtils.enquoteIdentifier(tableNameWithSpaces, identifierQuoteString)), - RelationalDbQueryUtils.enquoteIdentifier(COL_LAST_NAME_WITH_SPACE, identifierQuoteString))) + connection + .createStatement() + .execute( + createTableQuery( + getFullyQualifiedTableName( + RelationalDbQueryUtils.enquoteIdentifier( + tableNameWithSpaces, + identifierQuoteString + ) + ), + "id INTEGER, " + + RelationalDbQueryUtils.enquoteIdentifier( + COL_LAST_NAME_WITH_SPACE, + identifierQuoteString + ) + + " VARCHAR(200)", + "" + ) + ) + connection + .createStatement() + .execute( + String.format( + "INSERT INTO %s(id, %s) VALUES (1,'picard')", + getFullyQualifiedTableName( + RelationalDbQueryUtils.enquoteIdentifier( + tableNameWithSpaces, + identifierQuoteString + ) + ), + RelationalDbQueryUtils.enquoteIdentifier( + COL_LAST_NAME_WITH_SPACE, + identifierQuoteString + ) + ) + ) + connection + .createStatement() + .execute( + String.format( + "INSERT INTO %s(id, %s) VALUES (2, 'crusher')", + getFullyQualifiedTableName( + RelationalDbQueryUtils.enquoteIdentifier( + tableNameWithSpaces, + identifierQuoteString + ) + ), + RelationalDbQueryUtils.enquoteIdentifier( + COL_LAST_NAME_WITH_SPACE, + identifierQuoteString + ) + ) + ) + connection + .createStatement() + .execute( + String.format( + "INSERT INTO %s(id, %s) VALUES (3, 'vash')", + getFullyQualifiedTableName( + RelationalDbQueryUtils.enquoteIdentifier( + tableNameWithSpaces, + identifierQuoteString + ) + ), + RelationalDbQueryUtils.enquoteIdentifier( + COL_LAST_NAME_WITH_SPACE, + identifierQuoteString + ) + ) + ) } return CatalogHelpers.createConfiguredAirbyteStream( - streamName2, - defaultNamespace, - Field.of(COL_ID, JsonSchemaType.NUMBER), - Field.of(COL_LAST_NAME_WITH_SPACE, JsonSchemaType.STRING)) + streamName2, + defaultNamespace, + Field.of(COL_ID, JsonSchemaType.NUMBER), + Field.of(COL_LAST_NAME_WITH_SPACE, JsonSchemaType.STRING) + ) } - fun getFullyQualifiedTableName(tableName: String?): String { + fun getFullyQualifiedTableName(tableName: String): String { return RelationalDbQueryUtils.getFullyQualifiedTableName(defaultSchemaName, tableName) } @@ -958,7 +1459,8 @@ abstract class JdbcSourceAcceptanceTest?> private fun convertIdBasedOnDatabase(idValue: Int): JsonNode { return when (testdb!!.databaseDriver) { - DatabaseDriver.ORACLE, DatabaseDriver.SNOWFLAKE -> Jsons.jsonNode(BigDecimal.valueOf(idValue.toLong())) + DatabaseDriver.ORACLE, + DatabaseDriver.SNOWFLAKE -> Jsons.jsonNode(BigDecimal.valueOf(idValue.toLong())) else -> Jsons.jsonNode(idValue) } } @@ -967,10 +1469,13 @@ abstract class JdbcSourceAcceptanceTest?> get() = if (supportsSchemas()) SCHEMA_NAME else null protected val defaultNamespace: String - get() = when (testdb!!.databaseDriver) { - DatabaseDriver.MYSQL, DatabaseDriver.CLICKHOUSE, DatabaseDriver.TERADATA -> testdb!!.databaseName!! - else -> SCHEMA_NAME - } + get() = + when (testdb!!.databaseDriver) { + DatabaseDriver.MYSQL, + DatabaseDriver.CLICKHOUSE, + DatabaseDriver.TERADATA -> testdb!!.databaseName!! + else -> SCHEMA_NAME + } /** * Creates empty state with the provided stream name and namespace. @@ -980,19 +1485,33 @@ abstract class JdbcSourceAcceptanceTest?> * @return [JsonNode] representation of the generated empty state. */ protected fun createEmptyState(streamName: String?, streamNamespace: String?): JsonNode { - val airbyteStateMessage = AirbyteStateMessage() + val airbyteStateMessage = + AirbyteStateMessage() .withType(AirbyteStateMessage.AirbyteStateType.STREAM) - .withStream(AirbyteStreamState().withStreamDescriptor(StreamDescriptor().withName(streamName).withNamespace(streamNamespace))) + .withStream( + AirbyteStreamState() + .withStreamDescriptor( + StreamDescriptor().withName(streamName).withNamespace(streamNamespace) + ) + ) return Jsons.jsonNode(java.util.List.of(airbyteStateMessage)) } - protected fun createState(streamName: String?, streamNamespace: String?, stateData: JsonNode?): JsonNode { - val airbyteStateMessage = AirbyteStateMessage() + protected fun createState( + streamName: String?, + streamNamespace: String?, + stateData: JsonNode? + ): JsonNode { + val airbyteStateMessage = + AirbyteStateMessage() .withType(AirbyteStateMessage.AirbyteStateType.STREAM) .withStream( - AirbyteStreamState() - .withStreamDescriptor(StreamDescriptor().withName(streamName).withNamespace(streamNamespace)) - .withStreamState(stateData)) + AirbyteStreamState() + .withStreamDescriptor( + StreamDescriptor().withName(streamName).withNamespace(streamNamespace) + ) + .withStreamState(stateData) + ) return Jsons.jsonNode(java.util.List.of(airbyteStateMessage)) } @@ -1000,48 +1519,91 @@ abstract class JdbcSourceAcceptanceTest?> return Jsons.jsonNode(java.util.List.of(airbyteMessage.state)) } - protected fun createStateMessage(dbStreamState: DbStreamState, legacyStates: List?, recordCount: Long): AirbyteMessage { - return AirbyteMessage().withType(AirbyteMessage.Type.STATE) - .withState( - AirbyteStateMessage().withType(AirbyteStateMessage.AirbyteStateType.STREAM) - .withStream(AirbyteStreamState() - .withStreamDescriptor(StreamDescriptor().withNamespace(dbStreamState.streamNamespace) - .withName(dbStreamState.streamName)) - .withStreamState(Jsons.jsonNode(dbStreamState))) - .withData(Jsons.jsonNode(DbState().withCdc(false).withStreams(legacyStates))) - .withSourceStats(AirbyteStateStats().withRecordCount(recordCount.toDouble()))) + protected fun createStateMessage( + dbStreamState: DbStreamState, + legacyStates: List?, + recordCount: Long + ): AirbyteMessage { + return AirbyteMessage() + .withType(AirbyteMessage.Type.STATE) + .withState( + AirbyteStateMessage() + .withType(AirbyteStateMessage.AirbyteStateType.STREAM) + .withStream( + AirbyteStreamState() + .withStreamDescriptor( + StreamDescriptor() + .withNamespace(dbStreamState.streamNamespace) + .withName(dbStreamState.streamName) + ) + .withStreamState(Jsons.jsonNode(dbStreamState)) + ) + .withData(Jsons.jsonNode(DbState().withCdc(false).withStreams(legacyStates))) + .withSourceStats(AirbyteStateStats().withRecordCount(recordCount.toDouble())) + ) } - protected fun extractSpecificFieldFromCombinedMessages(messages: List, - streamName: String, - field: String?): List { - return extractStateMessage(messages).stream() - .filter { s: AirbyteStateMessage -> s.stream.streamDescriptor.name == streamName } - .map { s: AirbyteStateMessage -> if (s.stream.streamState[field] != null) s.stream.streamState[field].asText() else "" }.toList() + protected fun extractSpecificFieldFromCombinedMessages( + messages: List, + streamName: String, + field: String? + ): List { + return extractStateMessage(messages) + .stream() + .filter { s: AirbyteStateMessage -> s.stream.streamDescriptor.name == streamName } + .map { s: AirbyteStateMessage -> + if (s.stream.streamState[field] != null) s.stream.streamState[field].asText() + else "" + } + .toList() } protected fun filterRecords(messages: List): List { - return messages.stream().filter { r: AirbyteMessage -> r.type == AirbyteMessage.Type.RECORD } - .collect(Collectors.toList()) + return messages + .stream() + .filter { r: AirbyteMessage -> r.type == AirbyteMessage.Type.RECORD } + .collect(Collectors.toList()) } protected fun extractStateMessage(messages: List): List { - return messages.stream().filter { r: AirbyteMessage -> r.type == AirbyteMessage.Type.STATE }.map { obj: AirbyteMessage -> obj.state } - .collect(Collectors.toList()) + return messages + .stream() + .filter { r: AirbyteMessage -> r.type == AirbyteMessage.Type.STATE } + .map { obj: AirbyteMessage -> obj.state } + .collect(Collectors.toList()) } - protected fun extractStateMessage(messages: List, streamName: String): List { - return messages.stream().filter { r: AirbyteMessage -> r.type == AirbyteMessage.Type.STATE && r.state.stream.streamDescriptor.name == streamName }.map { obj: AirbyteMessage -> obj.state } - .collect(Collectors.toList()) + protected fun extractStateMessage( + messages: List, + streamName: String + ): List { + return messages + .stream() + .filter { r: AirbyteMessage -> + r.type == AirbyteMessage.Type.STATE && + r.state.stream.streamDescriptor.name == streamName + } + .map { obj: AirbyteMessage -> obj.state } + .collect(Collectors.toList()) } - protected fun createRecord(stream: String?, namespace: String?, data: Map): AirbyteMessage { - return AirbyteMessage().withType(AirbyteMessage.Type.RECORD) - .withRecord(AirbyteRecordMessage().withData(Jsons.jsonNode(data)).withStream(stream).withNamespace(namespace)) + protected fun createRecord( + stream: String?, + namespace: String?, + data: Map + ): AirbyteMessage { + return AirbyteMessage() + .withType(AirbyteMessage.Type.RECORD) + .withRecord( + AirbyteRecordMessage() + .withData(Jsons.jsonNode(data)) + .withStream(stream) + .withNamespace(namespace) + ) } companion object { - protected var SCHEMA_NAME: String = "jdbc_integration_test1" + @JvmStatic protected var SCHEMA_NAME: String = "jdbc_integration_test1" protected var SCHEMA_NAME2: String = "jdbc_integration_test2" protected var TEST_SCHEMAS: Set = java.util.Set.of(SCHEMA_NAME, SCHEMA_NAME2) @@ -1071,15 +1633,22 @@ abstract class JdbcSourceAcceptanceTest?> protected var ID_VALUE_5: Number = 5 protected var DROP_SCHEMA_QUERY: String = "DROP SCHEMA IF EXISTS %s CASCADE" - protected var COLUMN_CLAUSE_WITH_PK: String = "id INTEGER, name VARCHAR(200) NOT NULL, updated_at DATE NOT NULL" - protected var COLUMN_CLAUSE_WITHOUT_PK: String = "id INTEGER, name VARCHAR(200) NOT NULL, updated_at DATE NOT NULL" - protected var COLUMN_CLAUSE_WITH_COMPOSITE_PK: String = "first_name VARCHAR(200) NOT NULL, last_name VARCHAR(200) NOT NULL, updated_at DATE NOT NULL" - - protected var CREATE_TABLE_WITHOUT_CURSOR_TYPE_QUERY: String = "CREATE TABLE %s (%s bit NOT NULL);" - protected var INSERT_TABLE_WITHOUT_CURSOR_TYPE_QUERY: String = "INSERT INTO %s VALUES(0);" - protected var CREATE_TABLE_WITH_NULLABLE_CURSOR_TYPE_QUERY: String = "CREATE TABLE %s (%s VARCHAR(20));" - protected var INSERT_TABLE_WITH_NULLABLE_CURSOR_TYPE_QUERY: String = "INSERT INTO %s VALUES('Hello world :)');" - protected var INSERT_TABLE_NAME_AND_TIMESTAMP_QUERY: String = "INSERT INTO %s (name, timestamp) VALUES ('%s', '%s')" + protected var COLUMN_CLAUSE_WITH_PK: String = + "id INTEGER, name VARCHAR(200) NOT NULL, updated_at DATE NOT NULL" + protected var COLUMN_CLAUSE_WITHOUT_PK: String = + "id INTEGER, name VARCHAR(200) NOT NULL, updated_at DATE NOT NULL" + protected var COLUMN_CLAUSE_WITH_COMPOSITE_PK: String = + "first_name VARCHAR(200) NOT NULL, last_name VARCHAR(200) NOT NULL, updated_at DATE NOT NULL" + + @JvmField + var CREATE_TABLE_WITHOUT_CURSOR_TYPE_QUERY: String = "CREATE TABLE %s (%s bit NOT NULL);" + @JvmField var INSERT_TABLE_WITHOUT_CURSOR_TYPE_QUERY: String = "INSERT INTO %s VALUES(0);" + protected var CREATE_TABLE_WITH_NULLABLE_CURSOR_TYPE_QUERY: String = + "CREATE TABLE %s (%s VARCHAR(20));" + protected var INSERT_TABLE_WITH_NULLABLE_CURSOR_TYPE_QUERY: String = + "INSERT INTO %s VALUES('Hello world :)');" + protected var INSERT_TABLE_NAME_AND_TIMESTAMP_QUERY: String = + "INSERT INTO %s (name, timestamp) VALUES ('%s', '%s')" protected fun setEmittedAtToNull(messages: Iterable) { for (actualMessage in messages) { diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/source/jdbc/test/JdbcStressTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/source/jdbc/test/JdbcStressTest.kt index 5395e454a7621..c6b3e45735f28 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/source/jdbc/test/JdbcStressTest.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/source/jdbc/test/JdbcStressTest.kt @@ -19,14 +19,14 @@ import io.airbyte.commons.string.Strings import io.airbyte.protocol.models.Field import io.airbyte.protocol.models.JsonSchemaType import io.airbyte.protocol.models.v0.* -import org.junit.jupiter.api.Assertions -import org.junit.jupiter.api.Test -import org.slf4j.Logger -import org.slf4j.LoggerFactory import java.math.BigDecimal import java.nio.ByteBuffer import java.sql.Connection import java.util.* +import org.junit.jupiter.api.Assertions +import org.junit.jupiter.api.Test +import org.slf4j.Logger +import org.slf4j.LoggerFactory /** * Runs a "large" amount of data through a JdbcSource to ensure that it streams / chunks records. @@ -34,21 +34,25 @@ import java.util.* // todo (cgardens) - this needs more love and thought. we should be able to test this without having // to rewrite so much data. it is enough for now to sanity check that our JdbcSources can actually // handle more data than fits in memory. -@SuppressFBWarnings(value = ["MS_SHOULD_BE_FINAL"], justification = "The static variables are updated in sub classes for convenience, and cannot be final.") +@SuppressFBWarnings( + value = ["MS_SHOULD_BE_FINAL"], + justification = + "The static variables are updated in sub classes for convenience, and cannot be final." +) abstract class JdbcStressTest { private var bitSet: BitSet? = null - private var config: JsonNode? = null + private lateinit var config: JsonNode private var source: AbstractJdbcSource<*>? = null /** - * These tests write records without specifying a namespace (schema name). They will be written into - * whatever the default schema is for the database. When they are discovered they will be namespaced - * by the schema name (e.g. .). Thus the source needs to tell the - * tests what that default schema name is. If the database does not support schemas, then database - * name should used instead. + * These tests write records without specifying a namespace (schema name). They will be written + * into whatever the default schema is for the database. When they are discovered they will be + * namespaced by the schema name (e.g. .). Thus the source + * needs to tell the tests what that default schema name is. If the database does not support + * schemas, then database name should used instead. * - * @return name that will be used to namespace the record. - */ + * @return name that will be used to namespace the record. + */ abstract val defaultSchemaName: Optional /** @@ -56,7 +60,7 @@ abstract class JdbcStressTest { * * @return config */ - abstract fun getConfig(): JsonNode? + abstract fun getConfig(): JsonNode /** * Full qualified class name of the JDBC driver for the database. @@ -73,8 +77,7 @@ abstract class JdbcStressTest { abstract fun getSource(): AbstractJdbcSource<*>? protected fun createTableQuery(tableName: String?, columnClause: String?): String { - return String.format("CREATE TABLE %s(%s)", - tableName, columnClause) + return String.format("CREATE TABLE %s(%s)", tableName, columnClause) } @Throws(Exception::class) @@ -83,25 +86,39 @@ abstract class JdbcStressTest { bitSet = BitSet(TOTAL_RECORDS.toInt()) source = getSource() - streamName = defaultSchemaName.map { `val`: String -> `val` + "." + TABLE_NAME }.orElse(TABLE_NAME) + streamName = + defaultSchemaName.map { `val`: String -> `val` + "." + TABLE_NAME }.orElse(TABLE_NAME) config = getConfig() val jdbcConfig = source!!.toDatabaseConfig(config) - val database: JdbcDatabase = DefaultJdbcDatabase( + val database: JdbcDatabase = + DefaultJdbcDatabase( create( - jdbcConfig[JdbcUtils.USERNAME_KEY].asText(), - if (jdbcConfig.has(JdbcUtils.PASSWORD_KEY)) jdbcConfig[JdbcUtils.PASSWORD_KEY].asText() else null, - driverClass, - jdbcConfig[JdbcUtils.JDBC_URL_KEY].asText())) - - database.execute(CheckedConsumer { connection: Connection -> - connection.createStatement().execute( - createTableQuery("id_and_name", String.format("id %s, name VARCHAR(200)", COL_ID_TYPE))) - }) + jdbcConfig[JdbcUtils.USERNAME_KEY].asText(), + if (jdbcConfig.has(JdbcUtils.PASSWORD_KEY)) + jdbcConfig[JdbcUtils.PASSWORD_KEY].asText() + else null, + driverClass, + jdbcConfig[JdbcUtils.JDBC_URL_KEY].asText() + ) + ) + + database.execute( + CheckedConsumer { connection: Connection -> + connection + .createStatement() + .execute( + createTableQuery( + "id_and_name", + String.format("id %s, name VARCHAR(200)", COL_ID_TYPE) + ) + ) + } + ) val batchCount = TOTAL_RECORDS / BATCH_SIZE LOGGER.info("writing {} batches of {}", batchCount, BATCH_SIZE) for (i in 0 until batchCount) { - if (i % 1000 == 0) LOGGER.info("writing batch: $i") + if (i % 1000 == 0L) LOGGER.info("writing batch: $i") val insert: MutableList = ArrayList() for (j in 0 until BATCH_SIZE) { val recordNumber = (i * BATCH_SIZE + j).toInt() @@ -109,12 +126,18 @@ abstract class JdbcStressTest { } val sql = prepareInsertStatement(insert) - database.execute(CheckedConsumer { connection: Connection -> connection.createStatement().execute(sql) }) + database.execute( + CheckedConsumer { connection: Connection -> + connection.createStatement().execute(sql) + } + ) } } - // todo (cgardens) - restructure these tests so that testFullRefresh() and testIncremental() can be - // separate tests. current constrained by only wanting to setup the fixture in the database once, + // todo (cgardens) - restructure these tests so that testFullRefresh() and testIncremental() can + // be + // separate tests. current constrained by only wanting to setup the fixture in the database + // once, // but it is not trivial to move them to @BeforeAll because it is static and we are doing // inheritance. Not impossible, just needs to be done thoughtfully and for all JdbcSources. @Test @@ -137,8 +160,10 @@ abstract class JdbcStressTest { @Throws(Exception::class) private fun runTest(configuredCatalog: ConfiguredAirbyteCatalog, testName: String) { LOGGER.info("running stress test for: $testName") - val read: Iterator = source!!.read(config!!, configuredCatalog, Jsons.jsonNode(emptyMap())) - val actualCount = MoreStreams.toStream(read) + val read: Iterator = + source!!.read(config!!, configuredCatalog, Jsons.jsonNode(emptyMap())) + val actualCount = + MoreStreams.toStream(read) .filter { m: AirbyteMessage -> m.type == AirbyteMessage.Type.RECORD } .peek { m: AirbyteMessage -> if (m.record.data[COL_ID].asLong() % 100000 == 0L) { @@ -152,7 +177,11 @@ abstract class JdbcStressTest { LOGGER.info("expected records count: " + TOTAL_RECORDS) LOGGER.info("actual records count: $actualCount") Assertions.assertEquals(expectedRoundedRecordsCount, actualCount, "testing: $testName") - Assertions.assertEquals(expectedRoundedRecordsCount, bitSet!!.cardinality().toLong(), "testing: $testName") + Assertions.assertEquals( + expectedRoundedRecordsCount, + bitSet!!.cardinality().toLong(), + "testing: $testName" + ) } // each is roughly 106 bytes. @@ -162,13 +191,27 @@ abstract class JdbcStressTest { actualMessage.record.emittedAt = null val expectedRecordNumber: Number = - if (driverClass.lowercase(Locale.getDefault()).contains("oracle")) BigDecimal(recordNumber) - else recordNumber + if (driverClass.lowercase(Locale.getDefault()).contains("oracle")) + BigDecimal(recordNumber) + else recordNumber - val expectedMessage = AirbyteMessage().withType(AirbyteMessage.Type.RECORD) - .withRecord(AirbyteRecordMessage().withStream(streamName) - .withData(Jsons.jsonNode( - ImmutableMap.of(COL_ID, expectedRecordNumber, COL_NAME, "picard-$recordNumber")))) + val expectedMessage = + AirbyteMessage() + .withType(AirbyteMessage.Type.RECORD) + .withRecord( + AirbyteRecordMessage() + .withStream(streamName) + .withData( + Jsons.jsonNode( + ImmutableMap.of( + COL_ID, + expectedRecordNumber, + COL_NAME, + "picard-$recordNumber" + ) + ) + ) + ) Assertions.assertEquals(expectedMessage, actualMessage) } @@ -176,7 +219,10 @@ abstract class JdbcStressTest { if (driverClass.lowercase(Locale.getDefault()).contains("oracle")) { return String.format("INSERT ALL %s SELECT * FROM dual", Strings.join(inserts, " ")) } - return String.format("INSERT INTO id_and_name (id, name) VALUES %s", Strings.join(inserts, ", ")) + return String.format( + "INSERT INTO id_and_name (id, name) VALUES %s", + Strings.join(inserts, ", ") + ) } companion object { @@ -197,17 +243,32 @@ abstract class JdbcStressTest { get() = CatalogHelpers.toDefaultConfiguredCatalog(catalog) private val configuredCatalogIncremental: ConfiguredAirbyteCatalog - get() = ConfiguredAirbyteCatalog() - .withStreams(listOf(ConfiguredAirbyteStream().withStream(catalog.streams[0]) - .withCursorField(listOf(COL_ID)) - .withSyncMode(SyncMode.INCREMENTAL) - .withDestinationSyncMode(DestinationSyncMode.APPEND))) + get() = + ConfiguredAirbyteCatalog() + .withStreams( + listOf( + ConfiguredAirbyteStream() + .withStream(catalog.streams[0]) + .withCursorField(listOf(COL_ID)) + .withSyncMode(SyncMode.INCREMENTAL) + .withDestinationSyncMode(DestinationSyncMode.APPEND) + ) + ) private val catalog: AirbyteCatalog - get() = AirbyteCatalog().withStreams(Lists.newArrayList(CatalogHelpers.createAirbyteStream( - streamName, - Field.of(COL_ID, JsonSchemaType.NUMBER), - Field.of(COL_NAME, JsonSchemaType.STRING)) - .withSupportedSyncModes(Lists.newArrayList(SyncMode.FULL_REFRESH, SyncMode.INCREMENTAL)))) + get() = + AirbyteCatalog() + .withStreams( + Lists.newArrayList( + CatalogHelpers.createAirbyteStream( + streamName, + Field.of(COL_ID, JsonSchemaType.NUMBER), + Field.of(COL_NAME, JsonSchemaType.STRING) + ) + .withSupportedSyncModes( + Lists.newArrayList(SyncMode.FULL_REFRESH, SyncMode.INCREMENTAL) + ) + ) + ) } } diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/AbstractSourceConnectorTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/AbstractSourceConnectorTest.kt index 28f1fc794ae2b..136596b3dd43b 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/AbstractSourceConnectorTest.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/AbstractSourceConnectorTest.kt @@ -28,6 +28,9 @@ import io.airbyte.workers.internal.DefaultAirbyteSource import io.airbyte.workers.process.AirbyteIntegrationLauncher import io.airbyte.workers.process.DockerProcessFactory import io.airbyte.workers.process.ProcessFactory +import java.nio.file.Files +import java.nio.file.Path +import java.util.* import org.junit.jupiter.api.AfterEach import org.junit.jupiter.api.Assertions import org.junit.jupiter.api.BeforeEach @@ -36,9 +39,6 @@ import org.mockito.ArgumentMatchers import org.mockito.Mockito import org.slf4j.Logger import org.slf4j.LoggerFactory -import java.nio.file.Files -import java.nio.file.Path -import java.util.* /** * This abstract class contains helpful functionality and boilerplate for testing a source @@ -50,56 +50,58 @@ abstract class AbstractSourceConnectorTest { protected var localRoot: Path? = null private var processFactory: ProcessFactory? = null - protected abstract val imageName: String? - /** - * Name of the docker image that the tests will run against. - * - * @return docker image name - */ - get + /** Name of the docker image that the tests will run against. */ + protected abstract val imageName: String @get:Throws(Exception::class) protected abstract val config: JsonNode? /** - * Configuration specific to the integration. Will be passed to integration where appropriate in - * each test. Should be valid. + * Configuration specific to the integration. Will be passed to integration where + * appropriate in each test. Should be valid. * * @return integration-specific configuration */ get /** - * Function that performs any setup of external resources required for the test. e.g. instantiate a - * postgres database. This function will be called before EACH test. + * Function that performs any setup of external resources required for the test. e.g. + * instantiate a postgres database. This function will be called before EACH test. * - * @param environment - information about the test environment. - * @throws Exception - can throw any exception, test framework will handle. + * @param environment + * - information about the test environment. + * @throws Exception + * - can throw any exception, test framework will handle. */ @Throws(Exception::class) protected abstract fun setupEnvironment(environment: TestDestinationEnv?) /** - * Function that performs any clean up of external resources required for the test. e.g. delete a - * postgres database. This function will be called after EACH test. It MUST remove all data in the - * destination so that there is no contamination across tests. + * Function that performs any clean up of external resources required for the test. e.g. delete + * a postgres database. This function will be called after EACH test. It MUST remove all data in + * the destination so that there is no contamination across tests. * - * @param testEnv - information about the test environment. - * @throws Exception - can throw any exception, test framework will handle. + * @param testEnv + * - information about the test environment. + * @throws Exception + * - can throw any exception, test framework will handle. */ - @Throws(Exception::class) - protected abstract fun tearDown(testEnv: TestDestinationEnv?) + @Throws(Exception::class) protected abstract fun tearDown(testEnv: TestDestinationEnv?) - private var mAirbyteApiClient: AirbyteApiClient? = null + private lateinit var mAirbyteApiClient: AirbyteApiClient - private var mSourceApi: SourceApi? = null + private lateinit var mSourceApi: SourceApi private var mConnectorConfigUpdater: ConnectorConfigUpdater? = null protected val lastPersistedCatalog: AirbyteCatalog - get() = convertProtocolObject( - CatalogClientConverters.toAirbyteProtocol(discoverWriteRequest.value.catalog), AirbyteCatalog::class.java) + get() = + convertProtocolObject( + CatalogClientConverters.toAirbyteProtocol(discoverWriteRequest.value.catalog), + AirbyteCatalog::class.java + ) - private val discoverWriteRequest: ArgumentCaptor = ArgumentCaptor.forClass(SourceDiscoverSchemaWriteRequestBody::class.java) + private val discoverWriteRequest: ArgumentCaptor = + ArgumentCaptor.forClass(SourceDiscoverSchemaWriteRequestBody::class.java) @BeforeEach @Throws(Exception::class) @@ -115,16 +117,18 @@ abstract class AbstractSourceConnectorTest { mSourceApi = Mockito.mock(SourceApi::class.java) Mockito.`when`(mAirbyteApiClient.getSourceApi()).thenReturn(mSourceApi) Mockito.`when`(mSourceApi.writeDiscoverCatalogResult(ArgumentMatchers.any())) - .thenReturn(DiscoverCatalogResult().catalogId(CATALOG_ID)) + .thenReturn(DiscoverCatalogResult().catalogId(CATALOG_ID)) mConnectorConfigUpdater = Mockito.mock(ConnectorConfigUpdater::class.java) val envMap = HashMap(TestEnvConfigs().jobDefaultEnvMap) envMap[EnvVariableFeatureFlags.DEPLOYMENT_MODE] = featureFlags().deploymentMode() - processFactory = DockerProcessFactory( + processFactory = + DockerProcessFactory( workspaceRoot, workspaceRoot.toString(), localRoot.toString(), "host", - envMap) + envMap + ) postSetup() } @@ -133,9 +137,7 @@ abstract class AbstractSourceConnectorTest { * Override this method if you want to do any per-test setup that depends on being able to e.g. * [.runRead]. */ - @Throws(Exception::class) - protected fun postSetup() { - } + @Throws(Exception::class) protected fun postSetup() {} @AfterEach @Throws(Exception::class) @@ -149,39 +151,87 @@ abstract class AbstractSourceConnectorTest { @Throws(TestHarnessException::class) protected fun runSpec(): ConnectorSpecification { - val spec = DefaultGetSpecTestHarness( - AirbyteIntegrationLauncher(JOB_ID, JOB_ATTEMPT, imageName, processFactory, null, null, false, - featureFlags())) - .run(JobGetSpecConfig().withDockerImage(imageName), jobRoot).spec + val spec = + DefaultGetSpecTestHarness( + AirbyteIntegrationLauncher( + JOB_ID, + JOB_ATTEMPT, + imageName, + processFactory, + null, + null, + false, + featureFlags() + ) + ) + .run(JobGetSpecConfig().withDockerImage(imageName), jobRoot) + .spec return convertProtocolObject(spec, ConnectorSpecification::class.java) } @Throws(Exception::class) protected fun runCheck(): StandardCheckConnectionOutput { return DefaultCheckConnectionTestHarness( - AirbyteIntegrationLauncher(JOB_ID, JOB_ATTEMPT, imageName, processFactory, null, null, false, - featureFlags()), - mConnectorConfigUpdater) - .run(StandardCheckConnectionInput().withConnectionConfiguration(config), jobRoot).checkConnection + AirbyteIntegrationLauncher( + JOB_ID, + JOB_ATTEMPT, + imageName, + processFactory, + null, + null, + false, + featureFlags() + ), + mConnectorConfigUpdater + ) + .run(StandardCheckConnectionInput().withConnectionConfiguration(config), jobRoot) + .checkConnection } @Throws(Exception::class) protected fun runCheckAndGetStatusAsString(config: JsonNode?): String { return DefaultCheckConnectionTestHarness( - AirbyteIntegrationLauncher(JOB_ID, JOB_ATTEMPT, imageName, processFactory, null, null, false, - featureFlags()), - mConnectorConfigUpdater) - .run(StandardCheckConnectionInput().withConnectionConfiguration(config), jobRoot).checkConnection.status.toString() + AirbyteIntegrationLauncher( + JOB_ID, + JOB_ATTEMPT, + imageName, + processFactory, + null, + null, + false, + featureFlags() + ), + mConnectorConfigUpdater + ) + .run(StandardCheckConnectionInput().withConnectionConfiguration(config), jobRoot) + .checkConnection + .status + .toString() } @Throws(Exception::class) protected fun runDiscover(): UUID { - val toReturn = DefaultDiscoverCatalogTestHarness( - mAirbyteApiClient, - AirbyteIntegrationLauncher(JOB_ID, JOB_ATTEMPT, imageName, processFactory, null, null, false, - featureFlags()), - mConnectorConfigUpdater) - .run(StandardDiscoverCatalogInput().withSourceId(SOURCE_ID.toString()).withConnectionConfiguration(config), jobRoot) + val toReturn = + DefaultDiscoverCatalogTestHarness( + mAirbyteApiClient, + AirbyteIntegrationLauncher( + JOB_ID, + JOB_ATTEMPT, + imageName, + processFactory, + null, + null, + false, + featureFlags() + ), + mConnectorConfigUpdater + ) + .run( + StandardDiscoverCatalogInput() + .withSourceId(SOURCE_ID.toString()) + .withConnectionConfiguration(config), + jobRoot + ) .discoverCatalogId Mockito.verify(mSourceApi).writeDiscoverCatalogResult(discoverWriteRequest.capture()) return toReturn @@ -189,12 +239,14 @@ abstract class AbstractSourceConnectorTest { @Throws(Exception::class) protected fun checkEntrypointEnvVariable() { - val entrypoint = EntrypointEnvChecker.getEntrypointEnvVariable( + val entrypoint = + EntrypointEnvChecker.getEntrypointEnvVariable( processFactory, JOB_ID, JOB_ATTEMPT, jobRoot, - imageName) + imageName + ) Assertions.assertNotNull(entrypoint) Assertions.assertFalse(entrypoint.isBlank()) @@ -207,20 +259,41 @@ abstract class AbstractSourceConnectorTest { // todo (cgardens) - assume no state since we are all full refresh right now. @Throws(Exception::class) - protected fun runRead(catalog: ConfiguredAirbyteCatalog?, state: JsonNode?): List { - val sourceConfig = WorkerSourceConfig() + protected fun runRead( + catalog: ConfiguredAirbyteCatalog?, + state: JsonNode? + ): List { + val sourceConfig = + WorkerSourceConfig() .withSourceConnectionConfiguration(config) .withState(if (state == null) null else State().withState(state)) - .withCatalog(convertProtocolObject(catalog, io.airbyte.protocol.models.ConfiguredAirbyteCatalog::class.java)) - - val source: AirbyteSource = DefaultAirbyteSource( - AirbyteIntegrationLauncher(JOB_ID, JOB_ATTEMPT, imageName, processFactory, null, null, false, - featureFlags()), - featureFlags()) + .withCatalog( + convertProtocolObject( + catalog, + io.airbyte.protocol.models.ConfiguredAirbyteCatalog::class.java + ) + ) + + val source: AirbyteSource = + DefaultAirbyteSource( + AirbyteIntegrationLauncher( + JOB_ID, + JOB_ATTEMPT, + imageName, + processFactory, + null, + null, + false, + featureFlags() + ), + featureFlags() + ) val messages: MutableList = ArrayList() source.start(sourceConfig, jobRoot) while (!source.isFinished) { - source.attemptRead().ifPresent { m: io.airbyte.protocol.models.AirbyteMessage -> messages.add(convertProtocolObject(m, AirbyteMessage::class.java)) } + source.attemptRead().ifPresent { m: io.airbyte.protocol.models.AirbyteMessage -> + messages.add(convertProtocolObject(m, AirbyteMessage::class.java)) + } } source.close() @@ -228,20 +301,34 @@ abstract class AbstractSourceConnectorTest { } @Throws(Exception::class) - protected fun runReadVerifyNumberOfReceivedMsgs(catalog: ConfiguredAirbyteCatalog, - state: JsonNode?, - mapOfExpectedRecordsCount: MutableMap): Map { - val sourceConfig = WorkerSourceConfig() + protected fun runReadVerifyNumberOfReceivedMsgs( + catalog: ConfiguredAirbyteCatalog, + state: JsonNode?, + mapOfExpectedRecordsCount: MutableMap + ): Map { + val sourceConfig = + WorkerSourceConfig() .withSourceConnectionConfiguration(config) .withState(if (state == null) null else State().withState(state)) - .withCatalog(convertProtocolObject(catalog, io.airbyte.protocol.models.ConfiguredAirbyteCatalog::class.java)) + .withCatalog( + convertProtocolObject( + catalog, + io.airbyte.protocol.models.ConfiguredAirbyteCatalog::class.java + ) + ) val source = prepareAirbyteSource() source.start(sourceConfig, jobRoot) while (!source.isFinished) { - val airbyteMessageOptional = source.attemptRead().map { m: io.airbyte.protocol.models.AirbyteMessage -> convertProtocolObject(m, AirbyteMessage::class.java) } - if (airbyteMessageOptional.isPresent && airbyteMessageOptional.get().type == AirbyteMessage.Type.RECORD) { + val airbyteMessageOptional = + source.attemptRead().map { m: io.airbyte.protocol.models.AirbyteMessage -> + convertProtocolObject(m, AirbyteMessage::class.java) + } + if ( + airbyteMessageOptional.isPresent && + airbyteMessageOptional.get().type == AirbyteMessage.Type.RECORD + ) { val airbyteMessage = airbyteMessageOptional.get() val record = airbyteMessage.record @@ -254,7 +341,8 @@ abstract class AbstractSourceConnectorTest { } private fun prepareAirbyteSource(): AirbyteSource { - val integrationLauncher = AirbyteIntegrationLauncher( + val integrationLauncher = + AirbyteIntegrationLauncher( JOB_ID, JOB_ATTEMPT, imageName, @@ -262,12 +350,14 @@ abstract class AbstractSourceConnectorTest { null, null, false, - featureFlags()) + featureFlags() + ) return DefaultAirbyteSource(integrationLauncher, featureFlags()) } companion object { - protected val LOGGER: Logger = LoggerFactory.getLogger(AbstractSourceConnectorTest::class.java) + protected val LOGGER: Logger = + LoggerFactory.getLogger(AbstractSourceConnectorTest::class.java) private const val JOB_ID = 0L.toString() private const val JOB_ATTEMPT = 0 diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/AbstractSourceDatabaseTypeTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/AbstractSourceDatabaseTypeTest.kt index c00a381877141..573b8852d4fa1 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/AbstractSourceDatabaseTypeTest.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/AbstractSourceDatabaseTypeTest.kt @@ -10,17 +10,17 @@ import io.airbyte.commons.json.Jsons import io.airbyte.protocol.models.Field import io.airbyte.protocol.models.JsonSchemaType import io.airbyte.protocol.models.v0.* +import java.io.IOException +import java.sql.SQLException +import java.util.function.Consumer +import java.util.function.Function +import java.util.stream.Collectors import org.apache.commons.lang3.StringUtils import org.jooq.DSLContext import org.junit.jupiter.api.Assertions import org.junit.jupiter.api.Test import org.slf4j.Logger import org.slf4j.LoggerFactory -import java.io.IOException -import java.sql.SQLException -import java.util.function.Consumer -import java.util.function.Function -import java.util.stream.Collectors /** * This abstract class contains common helpers and boilerplate for comprehensively testing that all @@ -33,8 +33,8 @@ abstract class AbstractSourceDatabaseTypeTest : AbstractSourceConnectorTest() { protected val idColumnName: String /** - * The column name will be used for a PK column in the test tables. Override it if default name is - * not valid for your source. + * The column name will be used for a PK column in the test tables. Override it if default + * name is not valid for your source. * * @return Id column name */ @@ -42,25 +42,24 @@ abstract class AbstractSourceDatabaseTypeTest : AbstractSourceConnectorTest() { protected val testColumnName: String /** - * The column name will be used for a test column in the test tables. Override it if default name is - * not valid for your source. + * The column name will be used for a test column in the test tables. Override it if default + * name is not valid for your source. * * @return Test column name */ get() = "test_column" /** - * Setup the test database. All tables and data described in the registered tests will be put there. + * Setup the test database. All tables and data described in the registered tests will be put + * there. * * @return configured test database - * @throws Exception - might throw any exception during initialization. + * @throws Exception + * - might throw any exception during initialization. */ - @Throws(Exception::class) - protected abstract fun setupDatabase(): Database? + @Throws(Exception::class) protected abstract fun setupDatabase(): Database? - /** - * Put all required tests here using method [.addDataTypeTestData] - */ + /** Put all required tests here using method [.addDataTypeTestData] */ protected abstract fun initTests() @Throws(Exception::class) @@ -73,78 +72,103 @@ abstract class AbstractSourceDatabaseTypeTest : AbstractSourceConnectorTest() { protected abstract val nameSpace: String /** - * Provide a source namespace. It's allocated place for table creation. It also known ask "Database - * Schema" or "Dataset" + * Provide a source namespace. It's allocated place for table creation. It also known ask + * "Database Schema" or "Dataset" * * @return source name space */ get /** - * Test the 'discover' command. TODO (liren): Some existing databases may fail testDataTypes(), so - * it is turned off by default. It should be enabled for all databases eventually. + * Test the 'discover' command. TODO (liren): Some existing databases may fail testDataTypes(), + * so it is turned off by default. It should be enabled for all databases eventually. */ protected fun testCatalog(): Boolean { return false } /** - * The test checks that the types from the catalog matches the ones discovered from the source. This - * test is disabled by default. To enable it you need to overwrite testCatalog() function. + * The test checks that the types from the catalog matches the ones discovered from the source. + * This test is disabled by default. To enable it you need to overwrite testCatalog() function. */ @Test @Throws(Exception::class) fun testDataTypes() { if (testCatalog()) { runDiscover() - val streams = lastPersistedCatalog.streams.stream() - .collect(Collectors.toMap(Function { obj: AirbyteStream -> obj.name }, Function { s: AirbyteStream? -> s })) + val streams = + lastPersistedCatalog.streams + .stream() + .collect( + Collectors.toMap( + Function { obj: AirbyteStream -> obj.name }, + Function { s: AirbyteStream? -> s } + ) + ) // testDataHolders should be initialized using the `addDataTypeTestData` function - testDataHolders.forEach(Consumer { testDataHolder: TestDataHolder -> - val airbyteStream = streams[testDataHolder.nameWithTestPrefix] - val jsonSchemaTypeMap = Jsons.deserialize>( - airbyteStream!!.jsonSchema["properties"][testColumnName].toString(), MutableMap::class.java) as Map - Assertions.assertEquals(testDataHolder.airbyteType.jsonSchemaTypeMap, jsonSchemaTypeMap, - "Expected column type for " + testDataHolder.nameWithTestPrefix) - }) + testDataHolders.forEach( + Consumer { testDataHolder: TestDataHolder -> + val airbyteStream = streams[testDataHolder.nameWithTestPrefix] + val jsonSchemaTypeMap = + Jsons.deserialize( + airbyteStream!!.jsonSchema["properties"][testColumnName].toString(), + MutableMap::class.java + ) as Map + Assertions.assertEquals( + testDataHolder.airbyteType.jsonSchemaTypeMap, + jsonSchemaTypeMap, + "Expected column type for " + testDataHolder.nameWithTestPrefix + ) + } + ) } } /** * The test checks that connector can fetch prepared data without failure. It uses a prepared - * catalog and read the source using that catalog. Then makes sure that the expected values are the - * ones inserted in the source. + * catalog and read the source using that catalog. Then makes sure that the expected values are + * the ones inserted in the source. */ @Test @Throws(Exception::class) fun testDataContent() { // Class used to make easier the error reporting - class MissedRecords(// Stream that is missing any value - var streamName: String?, // Which are the values that has not being gathered from the source - var missedValues: List?) + class MissedRecords( // Stream that is missing any value + var streamName: + String?, // Which are the values that has not being gathered from the source + var missedValues: List? + ) class UnexpectedRecord(val streamName: String, val unexpectedValue: String?) val catalog = configuredCatalog val allMessages = runRead(catalog) - val recordMessages = allMessages!!.stream().filter { m: AirbyteMessage? -> m!!.type == AirbyteMessage.Type.RECORD }.toList() + val recordMessages = + allMessages!! + .stream() + .filter { m: AirbyteMessage? -> m!!.type == AirbyteMessage.Type.RECORD } + .toList() val expectedValues: MutableMap?> = HashMap() val missedValuesByStream: MutableMap> = HashMap() val unexpectedValuesByStream: MutableMap> = HashMap() val testByName: MutableMap = HashMap() - // If there is no expected value in the test set we don't include it in the list to be asserted + // If there is no expected value in the test set we don't include it in the list to be + // asserted // (even if the table contains records) - testDataHolders.forEach(Consumer { testDataHolder: TestDataHolder -> - if (!testDataHolder.expectedValues.isEmpty()) { - expectedValues[testDataHolder.nameWithTestPrefix] = testDataHolder.expectedValues - testByName[testDataHolder.nameWithTestPrefix] = testDataHolder - } else { - LOGGER.warn("Missing expected values for type: " + testDataHolder.sourceType) + testDataHolders.forEach( + Consumer { testDataHolder: TestDataHolder -> + if (!testDataHolder.expectedValues.isEmpty()) { + expectedValues[testDataHolder.nameWithTestPrefix] = + testDataHolder.expectedValues + testByName[testDataHolder.nameWithTestPrefix] = testDataHolder + } else { + LOGGER.warn("Missing expected values for type: " + testDataHolder.sourceType) + } } - }) + ) for (message in recordMessages) { val streamName = message!!.record.stream @@ -171,23 +195,33 @@ abstract class AbstractSourceDatabaseTypeTest : AbstractSourceConnectorTest() { val errorsByStream: MutableMap> = HashMap() for (streamName in unexpectedValuesByStream.keys) { errorsByStream.putIfAbsent(streamName, ArrayList()) - val test = testByName[streamName] + val test = testByName.getValue(streamName) val unexpectedValues: List = unexpectedValuesByStream[streamName]!! for (unexpectedValue in unexpectedValues) { errorsByStream[streamName]!!.add( - "The stream '%s' checking type '%s' initialized at %s got unexpected values: %s".formatted(streamName, test.getSourceType(), - test!!.declarationLocation, unexpectedValue)) + "The stream '%s' checking type '%s' initialized at %s got unexpected values: %s".formatted( + streamName, + test.sourceType, + test!!.declarationLocation, + unexpectedValue + ) + ) } } for (streamName in missedValuesByStream.keys) { errorsByStream.putIfAbsent(streamName, ArrayList()) - val test = testByName[streamName] + val test = testByName.getValue(streamName) val missedValues: List = missedValuesByStream[streamName]!! for (missedValue in missedValues) { errorsByStream[streamName]!!.add( - "The stream '%s' checking type '%s' initialized at %s is missing values: %s".formatted(streamName, test.getSourceType(), - test!!.declarationLocation, missedValue)) + "The stream '%s' checking type '%s' initialized at %s is missing values: %s".formatted( + streamName, + test.sourceType, + test!!.declarationLocation, + missedValue + ) + ) } } @@ -206,7 +240,9 @@ abstract class AbstractSourceDatabaseTypeTest : AbstractSourceConnectorTest() { return jsonNode.toString() } - var value = (if (jsonNode.isBinary) jsonNode.binaryValue().contentToString() else jsonNode.asText()) + var value = + (if (jsonNode.isBinary) jsonNode.binaryValue().contentToString() + else jsonNode.asText()) value = (if (value != null && value == "null") null else value) return value } @@ -223,7 +259,7 @@ abstract class AbstractSourceDatabaseTypeTest : AbstractSourceConnectorTest() { protected fun createTables() { for (test in testDataHolders) { database!!.query { ctx: DSLContext? -> - ctx.fetch(test.createSqlQuery) + ctx!!.fetch(test.createSqlQuery) LOGGER.info("Table {} is created.", test.nameWithTestPrefix) null } @@ -235,7 +271,11 @@ abstract class AbstractSourceDatabaseTypeTest : AbstractSourceConnectorTest() { for (test in testDataHolders) { database!!.query { ctx: DSLContext? -> test.insertSqlQueries.forEach(Consumer { sql: String? -> ctx!!.fetch(sql) }) - LOGGER.info("Inserted {} rows in Ttable {}", test.insertSqlQueries.size, test.nameWithTestPrefix) + LOGGER.info( + "Inserted {} rows in Ttable {}", + test.insertSqlQueries.size, + test.nameWithTestPrefix + ) null } } @@ -247,36 +287,53 @@ abstract class AbstractSourceDatabaseTypeTest : AbstractSourceConnectorTest() { * * @return configured catalog */ - get() = ConfiguredAirbyteCatalog().withStreams( - testDataHolders + get() = + ConfiguredAirbyteCatalog() + .withStreams( + testDataHolders .stream() .map { test: TestDataHolder -> ConfiguredAirbyteStream() - .withSyncMode(SyncMode.INCREMENTAL) - .withCursorField(Lists.newArrayList(idColumnName)) - .withDestinationSyncMode(DestinationSyncMode.APPEND) - .withStream(CatalogHelpers.createAirbyteStream( + .withSyncMode(SyncMode.INCREMENTAL) + .withCursorField(Lists.newArrayList(idColumnName)) + .withDestinationSyncMode(DestinationSyncMode.APPEND) + .withStream( + CatalogHelpers.createAirbyteStream( String.format("%s", test.nameWithTestPrefix), String.format("%s", nameSpace), Field.of(idColumnName, JsonSchemaType.INTEGER), - Field.of(testColumnName, test.airbyteType)) - .withSourceDefinedCursor(true) - .withSourceDefinedPrimaryKey(java.util.List.of(java.util.List.of(idColumnName))) - .withSupportedSyncModes( - Lists.newArrayList(SyncMode.FULL_REFRESH, SyncMode.INCREMENTAL))) + Field.of(testColumnName, test.airbyteType) + ) + .withSourceDefinedCursor(true) + .withSourceDefinedPrimaryKey( + java.util.List.of(java.util.List.of(idColumnName)) + ) + .withSupportedSyncModes( + Lists.newArrayList( + SyncMode.FULL_REFRESH, + SyncMode.INCREMENTAL + ) + ) + ) } - .collect(Collectors.toList())) + .collect(Collectors.toList()) + ) /** * Register your test in the run scope. For each test will be created a table with one column of - * specified type. Note! If you register more than one test with the same type name, they will be - * run as independent tests with own streams. + * specified type. Note! If you register more than one test with the same type name, they will + * be run as independent tests with own streams. * * @param test comprehensive data type test */ fun addDataTypeTestData(test: TestDataHolder) { testDataHolders.add(test) - test.setTestNumber(testDataHolders.stream().filter { t: TestDataHolder -> t.sourceType == test.sourceType }.count()) + test.setTestNumber( + testDataHolders + .stream() + .filter { t: TestDataHolder -> t.sourceType == test.sourceType } + .count() + ) test.nameSpace = nameSpace test.setIdColumnName(idColumnName) test.setTestColumnName(testColumnName) @@ -289,24 +346,33 @@ abstract class AbstractSourceDatabaseTypeTest : AbstractSourceConnectorTest() { val markdownTestTable: String /** - * Builds a table with all registered test cases with values using Markdown syntax (can be used in - * the github). + * Builds a table with all registered test cases with values using Markdown syntax (can be + * used in the github). * * @return formatted list of test cases */ get() { - val table = StringBuilder() - .append("|**Data Type**|**Insert values**|**Expected values**|**Comment**|**Common test result**|\n") + val table = + StringBuilder() + .append( + "|**Data Type**|**Insert values**|**Expected values**|**Comment**|**Common test result**|\n" + ) .append("|----|----|----|----|----|\n") - testDataHolders.forEach(Consumer { test: TestDataHolder -> - table.append(String.format("| %s | %s | %s | %s | %s |\n", - test.sourceType, - formatCollection(test.values), - formatCollection(test.expectedValues), - "", - "Ok")) - }) + testDataHolders.forEach( + Consumer { test: TestDataHolder -> + table.append( + String.format( + "| %s | %s | %s | %s | %s |\n", + test.sourceType, + formatCollection(test.values), + formatCollection(test.expectedValues), + "", + "Ok" + ) + ) + } + ) return table.toString() } @@ -317,30 +383,44 @@ abstract class AbstractSourceDatabaseTypeTest : AbstractSourceConnectorTest() { @Throws(SQLException::class) protected fun createDummyTableWithData(database: Database): ConfiguredAirbyteStream { database.query { ctx: DSLContext? -> - ctx!!.fetch("CREATE TABLE " + nameSpace + ".random_dummy_table(id INTEGER PRIMARY KEY, test_column VARCHAR(63));") + ctx!!.fetch( + "CREATE TABLE " + + nameSpace + + ".random_dummy_table(id INTEGER PRIMARY KEY, test_column VARCHAR(63));" + ) ctx.fetch("INSERT INTO " + nameSpace + ".random_dummy_table VALUES (2, 'Random Data');") null } - return ConfiguredAirbyteStream().withSyncMode(SyncMode.INCREMENTAL) - .withCursorField(Lists.newArrayList("id")) - .withDestinationSyncMode(DestinationSyncMode.APPEND) - .withStream(CatalogHelpers.createAirbyteStream( + return ConfiguredAirbyteStream() + .withSyncMode(SyncMode.INCREMENTAL) + .withCursorField(Lists.newArrayList("id")) + .withDestinationSyncMode(DestinationSyncMode.APPEND) + .withStream( + CatalogHelpers.createAirbyteStream( "random_dummy_table", nameSpace, Field.of("id", JsonSchemaType.INTEGER), - Field.of("test_column", JsonSchemaType.STRING)) - .withSourceDefinedCursor(true) - .withSupportedSyncModes(Lists.newArrayList(SyncMode.FULL_REFRESH, SyncMode.INCREMENTAL)) - .withSourceDefinedPrimaryKey(java.util.List.of(listOf("id")))) + Field.of("test_column", JsonSchemaType.STRING) + ) + .withSourceDefinedCursor(true) + .withSupportedSyncModes( + Lists.newArrayList(SyncMode.FULL_REFRESH, SyncMode.INCREMENTAL) + ) + .withSourceDefinedPrimaryKey(java.util.List.of(listOf("id"))) + ) } protected fun extractStateMessages(messages: List): List { - return messages.stream().filter { r: AirbyteMessage -> r.type == AirbyteMessage.Type.STATE }.map { obj: AirbyteMessage -> obj.state } - .collect(Collectors.toList()) + return messages + .stream() + .filter { r: AirbyteMessage -> r.type == AirbyteMessage.Type.STATE } + .map { obj: AirbyteMessage -> obj.state } + .collect(Collectors.toList()) } companion object { - private val LOGGER: Logger = LoggerFactory.getLogger(AbstractSourceDatabaseTypeTest::class.java) + private val LOGGER: Logger = + LoggerFactory.getLogger(AbstractSourceDatabaseTypeTest::class.java) } } diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/PythonSourceAcceptanceTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/PythonSourceAcceptanceTest.kt index c0d7764c6334e..0ebdc2addeabf 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/PythonSourceAcceptanceTest.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/PythonSourceAcceptanceTest.kt @@ -13,15 +13,15 @@ import io.airbyte.protocol.models.v0.AirbyteMessage import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog import io.airbyte.protocol.models.v0.ConnectorSpecification import io.airbyte.workers.TestHarnessUtils -import org.junit.jupiter.api.Assertions -import org.slf4j.Logger -import org.slf4j.LoggerFactory import java.io.IOException import java.nio.file.Files import java.nio.file.Path import java.util.* import java.util.concurrent.TimeUnit import java.util.function.Consumer +import org.junit.jupiter.api.Assertions +import org.slf4j.Logger +import org.slf4j.LoggerFactory /** * Extends TestSource such that it can be called using resources pulled from the file system. Will @@ -48,19 +48,41 @@ class PythonSourceAcceptanceTest : SourceAcceptanceTest() { @Throws(IOException::class) override fun assertFullRefreshMessages(allMessages: List?) { - val regexTests = Streams.stream(runExecutable(Command.GET_REGEX_TESTS).withArray("tests").elements()) - .map { obj: JsonNode -> obj.textValue() }.toList() - val stringMessages = allMessages!!.stream().map { `object`: AirbyteMessage? -> Jsons.serialize(`object`) }.toList() + val regexTests = + Streams.stream( + runExecutable(Command.GET_REGEX_TESTS).withArray("tests").elements() + ) + .map { obj: JsonNode -> obj.textValue() } + .toList() + val stringMessages = + allMessages!! + .stream() + .map { `object`: AirbyteMessage? -> Jsons.serialize(`object`) } + .toList() LOGGER.info("Running " + regexTests.size + " regex tests...") - regexTests.forEach(Consumer { regex: String -> - LOGGER.info("Looking for [$regex]") - Assertions.assertTrue(stringMessages.stream().anyMatch { line: String -> line.matches(regex.toRegex()) }, "Failed to find regex: $regex") - }) + regexTests.forEach( + Consumer { regex: String -> + LOGGER.info("Looking for [$regex]") + Assertions.assertTrue( + stringMessages.stream().anyMatch { line: String -> + line.matches(regex.toRegex()) + }, + "Failed to find regex: $regex" + ) + } + ) } + override val imageName: String + get() = IMAGE_NAME + @Throws(Exception::class) override fun setupEnvironment(environment: TestDestinationEnv?) { - testRoot = Files.createTempDirectory(Files.createDirectories(Path.of("/tmp/standard_test")), "pytest") + testRoot = + Files.createTempDirectory( + Files.createDirectories(Path.of("/tmp/standard_test")), + "pytest" + ) runExecutableVoid(Command.SETUP) } @@ -98,21 +120,22 @@ class PythonSourceAcceptanceTest : SourceAcceptanceTest() { private fun runExecutableInternal(cmd: Command): Path? { LOGGER.info("testRoot = $testRoot") val dockerCmd: List = - Lists.newArrayList( - "docker", - "run", - "--rm", - "-i", - "-v", - String.format("%s:%s", testRoot, "/test_root"), - "-w", - testRoot.toString(), - "--network", - "host", - PYTHON_CONTAINER_NAME, - cmd.toString().lowercase(Locale.getDefault()), - "--out", - "/test_root") + Lists.newArrayList( + "docker", + "run", + "--rm", + "-i", + "-v", + String.format("%s:%s", testRoot, "/test_root"), + "-w", + testRoot.toString(), + "--network", + "host", + PYTHON_CONTAINER_NAME, + cmd.toString().lowercase(Locale.getDefault()), + "--out", + "/test_root" + ) val process = ProcessBuilder(dockerCmd).start() LineGobbler.gobble(process.errorStream) { msg: String? -> LOGGER.error(msg) } @@ -132,8 +155,7 @@ class PythonSourceAcceptanceTest : SourceAcceptanceTest() { private val LOGGER: Logger = LoggerFactory.getLogger(PythonSourceAcceptanceTest::class.java) private const val OUTPUT_FILENAME = "output.json" - protected var imageName: String? = null - get() = Companion.field + lateinit var IMAGE_NAME: String var PYTHON_CONTAINER_NAME: String? = null } } diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/SourceAcceptanceTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/SourceAcceptanceTest.kt index 9f512f590ae0b..8045d5377a097 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/SourceAcceptanceTest.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/SourceAcceptanceTest.kt @@ -10,24 +10,25 @@ import com.google.common.collect.Sets import io.airbyte.commons.json.Jsons import io.airbyte.configoss.StandardCheckConnectionOutput import io.airbyte.protocol.models.v0.* +import java.util.* +import java.util.stream.Collectors import org.junit.jupiter.api.Assertions import org.junit.jupiter.api.Test import org.slf4j.Logger import org.slf4j.LoggerFactory -import java.util.* -import java.util.stream.Collectors abstract class SourceAcceptanceTest : AbstractSourceConnectorTest() { /** * TODO hack: Various Singer integrations use cursor fields inclusively i.e: they output records * whose cursor field >= the provided cursor value. This leads to the last record in a sync to - * always be the first record in the next sync. This is a fine assumption from a product POV since - * we offer at-least-once delivery. But for simplicity, the incremental test suite currently assumes - * that the second incremental read should output no records when provided the state from the first - * sync. This works for many integrations but not some Singer ones, so we hardcode the list of - * integrations to skip over when performing those tests. + * always be the first record in the next sync. This is a fine assumption from a product POV + * since we offer at-least-once delivery. But for simplicity, the incremental test suite + * currently assumes that the second incremental read should output no records when provided the + * state from the first sync. This works for many integrations but not some Singer ones, so we + * hardcode the list of integrations to skip over when performing those tests. */ - private val IMAGES_TO_SKIP_SECOND_INCREMENTAL_READ: Set = Sets.newHashSet( + private val IMAGES_TO_SKIP_SECOND_INCREMENTAL_READ: Set = + Sets.newHashSet( "airbyte/source-intercom-singer", "airbyte/source-exchangeratesapi-singer", "airbyte/source-hubspot", @@ -46,19 +47,20 @@ abstract class SourceAcceptanceTest : AbstractSourceConnectorTest() { "airbyte/source-zendesk-talk", "airbyte/source-zendesk-support-singer", "airbyte/source-quickbooks-singer", - "airbyte/source-jira") + "airbyte/source-jira" + ) /** * FIXME: Some sources can't guarantee that there will be no events between two sequential sync */ - private val IMAGES_TO_SKIP_IDENTICAL_FULL_REFRESHES: Set = Sets.newHashSet( - "airbyte/source-google-workspace-admin-reports", "airbyte/source-kafka") + private val IMAGES_TO_SKIP_IDENTICAL_FULL_REFRESHES: Set = + Sets.newHashSet("airbyte/source-google-workspace-admin-reports", "airbyte/source-kafka") @get:Throws(Exception::class) protected abstract val spec: ConnectorSpecification /** - * Specification for integration. Will be passed to integration where appropriate in each test. - * Should be valid. + * Specification for integration. Will be passed to integration where appropriate in each + * test. Should be valid. * * @return integration-specific configuration */ @@ -67,13 +69,14 @@ abstract class SourceAcceptanceTest : AbstractSourceConnectorTest() { @get:Throws(Exception::class) protected abstract val configuredCatalog: ConfiguredAirbyteCatalog /** - * The catalog to use to validate the output of read operations. This will be used as follows: - * + * The catalog to use to validate the output of read operations. This will be used as + * follows: * - * Full Refresh syncs will be tested on all the input streams which support it Incremental syncs: - - * if the stream declares a source-defined cursor, it will be tested with an incremental sync using - * the default cursor. - if the stream requires a user-defined cursor, it will be tested with the - * input cursor in both cases, the input [.getState] will be used as the input state. + * Full Refresh syncs will be tested on all the input streams which support it Incremental + * syncs: - if the stream declares a source-defined cursor, it will be tested with an + * incremental sync using the default cursor. - if the stream requires a user-defined + * cursor, it will be tested with the input cursor in both cases, the input [.getState] will + * be used as the input state. * * @return * @throws Exception @@ -82,18 +85,18 @@ abstract class SourceAcceptanceTest : AbstractSourceConnectorTest() { @get:Throws(Exception::class) protected abstract val state: JsonNode? - /** - * @return a JSON file representing the state file to use when testing incremental syncs - */ + /** @return a JSON file representing the state file to use when testing incremental syncs */ get - /** - * Verify that a spec operation issued to the connector returns a valid spec. - */ + /** Verify that a spec operation issued to the connector returns a valid spec. */ @Test @Throws(Exception::class) fun testGetSpec() { - Assertions.assertEquals(spec, runSpec(), "Expected spec output by integration to be equal to spec provided by test runner") + Assertions.assertEquals( + spec, + runSpec(), + "Expected spec output by integration to be equal to spec provided by test runner" + ) } /** @@ -103,11 +106,16 @@ abstract class SourceAcceptanceTest : AbstractSourceConnectorTest() { @Test @Throws(Exception::class) fun testCheckConnection() { - Assertions.assertEquals(StandardCheckConnectionOutput.Status.SUCCEEDED, runCheck().status, "Expected check connection operation to succeed") + Assertions.assertEquals( + StandardCheckConnectionOutput.Status.SUCCEEDED, + runCheck().status, + "Expected check connection operation to succeed" + ) } // /** - // * Verify that when given invalid credentials, that check connection returns a failed response. + // * Verify that when given invalid credentials, that check connection returns a failed + // response. // * Assume that the {@link TestSource#getFailCheckConfig()} is invalid. // */ // @Test @@ -117,8 +125,8 @@ abstract class SourceAcceptanceTest : AbstractSourceConnectorTest() { // assertEquals(Status.FAILED, output.getOutput().get().getStatus()); // } /** - * Verifies when a discover operation is run on the connector using the given config file, a valid - * catalog is output by the connector. + * Verifies when a discover operation is run on the connector using the given config file, a + * valid catalog is output by the connector. */ @Test @Throws(Exception::class) @@ -129,17 +137,15 @@ abstract class SourceAcceptanceTest : AbstractSourceConnectorTest() { verifyCatalog(discoveredCatalog) } - /** - * Override this method to check the actual catalog. - */ + /** Override this method to check the actual catalog. */ @Throws(Exception::class) protected fun verifyCatalog(catalog: AirbyteCatalog?) { // do nothing by default } /** - * Configuring all streams in the input catalog to full refresh mode, verifies that a read operation - * produces some RECORD messages. + * Configuring all streams in the input catalog to full refresh mode, verifies that a read + * operation produces some RECORD messages. */ @Test @Throws(Exception::class) @@ -152,22 +158,23 @@ abstract class SourceAcceptanceTest : AbstractSourceConnectorTest() { val catalog = withFullRefreshSyncModes(configuredCatalog) val allMessages = runRead(catalog) - Assertions.assertFalse(filterRecords(allMessages).isEmpty(), "Expected a full refresh sync to produce records") + Assertions.assertFalse( + filterRecords(allMessages).isEmpty(), + "Expected a full refresh sync to produce records" + ) assertFullRefreshMessages(allMessages) } - /** - * Override this method to perform more specific assertion on the messages. - */ + /** Override this method to perform more specific assertion on the messages. */ @Throws(Exception::class) protected open fun assertFullRefreshMessages(allMessages: List?) { // do nothing by default } /** - * Configuring all streams in the input catalog to full refresh mode, performs two read operations - * on all streams which support full refresh syncs. It then verifies that the RECORD messages output - * from both were identical. + * Configuring all streams in the input catalog to full refresh mode, performs two read + * operations on all streams which support full refresh syncs. It then verifies that the RECORD + * messages output from both were identical. */ @Test @Throws(Exception::class) @@ -177,36 +184,49 @@ abstract class SourceAcceptanceTest : AbstractSourceConnectorTest() { return } - if (IMAGES_TO_SKIP_IDENTICAL_FULL_REFRESHES.contains(imageName.split(":".toRegex()).dropLastWhile { it.isEmpty() }.toTypedArray()[0])) { + if ( + IMAGES_TO_SKIP_IDENTICAL_FULL_REFRESHES.contains( + imageName.split(":".toRegex()).dropLastWhile { it.isEmpty() }.toTypedArray()[0] + ) + ) { return } val configuredCatalog = withFullRefreshSyncModes(configuredCatalog) val recordMessagesFirstRun = filterRecords(runRead(configuredCatalog)) val recordMessagesSecondRun = filterRecords(runRead(configuredCatalog)) - // the worker validates the messages, so we just validate the message, so we do not need to validate + // the worker validates the messages, so we just validate the message, so we do not need to + // validate // again (as long as we use the worker, which we will not want to do long term). - Assertions.assertFalse(recordMessagesFirstRun.isEmpty(), "Expected first full refresh to produce records") - Assertions.assertFalse(recordMessagesSecondRun.isEmpty(), "Expected second full refresh to produce records") - - assertSameRecords(recordMessagesFirstRun, recordMessagesSecondRun, "Expected two full refresh syncs to produce the same records") + Assertions.assertFalse( + recordMessagesFirstRun.isEmpty(), + "Expected first full refresh to produce records" + ) + Assertions.assertFalse( + recordMessagesSecondRun.isEmpty(), + "Expected second full refresh to produce records" + ) + + assertSameRecords( + recordMessagesFirstRun, + recordMessagesSecondRun, + "Expected two full refresh syncs to produce the same records" + ) } /** - * This test verifies that all streams in the input catalog which support incremental sync can do so - * correctly. It does this by running two read operations on the connector's Docker image: the first - * takes the configured catalog and config provided to this test as input. It then verifies that the - * sync produced a non-zero number of RECORD and STATE messages. - * + * This test verifies that all streams in the input catalog which support incremental sync can + * do so correctly. It does this by running two read operations on the connector's Docker image: + * the first takes the configured catalog and config provided to this test as input. It then + * verifies that the sync produced a non-zero number of RECORD and STATE messages. * * The second read takes the same catalog and config used in the first test, plus the last STATE - * message output by the first read operation as the input state file. It verifies that no records - * are produced (since we read all records in the first sync). + * message output by the first read operation as the input state file. It verifies that no + * records are produced (since we read all records in the first sync). * - * - * This test is performed only for streams which support incremental. Streams which do not support - * incremental sync are ignored. If no streams in the input catalog support incremental sync, this - * test is skipped. + * This test is performed only for streams which support incremental. Streams which do not + * support incremental sync are ignored. If no streams in the input catalog support incremental + * sync, this test is skipped. */ @Test @Throws(Exception::class) @@ -217,24 +237,40 @@ abstract class SourceAcceptanceTest : AbstractSourceConnectorTest() { val configuredCatalog = withSourceDefinedCursors(configuredCatalog) // only sync incremental streams - configuredCatalog.streams = configuredCatalog.streams.stream().filter { s: ConfiguredAirbyteStream -> s.syncMode == SyncMode.INCREMENTAL }.collect(Collectors.toList()) + configuredCatalog.streams = + configuredCatalog.streams + .stream() + .filter { s: ConfiguredAirbyteStream -> s.syncMode == SyncMode.INCREMENTAL } + .collect(Collectors.toList()) val airbyteMessages = runRead(configuredCatalog, state) val recordMessages = filterRecords(airbyteMessages) - val stateMessages = airbyteMessages + val stateMessages = + airbyteMessages .stream() .filter { m: AirbyteMessage? -> m!!.type == AirbyteMessage.Type.STATE } .map { obj: AirbyteMessage? -> obj!!.state } .collect(Collectors.toList()) - Assertions.assertFalse(recordMessages.isEmpty(), "Expected the first incremental sync to produce records") - Assertions.assertFalse(stateMessages.isEmpty(), "Expected incremental sync to produce STATE messages") + Assertions.assertFalse( + recordMessages.isEmpty(), + "Expected the first incremental sync to produce records" + ) + Assertions.assertFalse( + stateMessages.isEmpty(), + "Expected incremental sync to produce STATE messages" + ) // TODO validate exact records - if (IMAGES_TO_SKIP_SECOND_INCREMENTAL_READ.contains(imageName.split(":".toRegex()).dropLastWhile { it.isEmpty() }.toTypedArray()[0])) { + if ( + IMAGES_TO_SKIP_SECOND_INCREMENTAL_READ.contains( + imageName.split(":".toRegex()).dropLastWhile { it.isEmpty() }.toTypedArray()[0] + ) + ) { return } - // when we run incremental sync again there should be no new records. Run a sync with the latest + // when we run incremental sync again there should be no new records. Run a sync with the + // latest // state message and assert no records were emitted. var latestState: JsonNode? = null for (stateMessage in stateMessages) { @@ -252,19 +288,19 @@ abstract class SourceAcceptanceTest : AbstractSourceConnectorTest() { assert(Objects.nonNull(latestState)) val secondSyncRecords = filterRecords(runRead(configuredCatalog, latestState)) Assertions.assertTrue( - secondSyncRecords.isEmpty(), - "Expected the second incremental sync to produce no records when given the first sync's output state.") + secondSyncRecords.isEmpty(), + "Expected the second incremental sync to produce no records when given the first sync's output state." + ) } /** * If the source does not support incremental sync, this test is skipped. * - * - * Otherwise, this test runs two syncs: one where all streams provided in the input catalog sync in - * full refresh mode, and another where all the streams which in the input catalog which support - * incremental, sync in incremental mode (streams which don't support incremental sync in full - * refresh mode). Then, the test asserts that the two syncs produced the same RECORD messages. Any - * other type of message is disregarded. + * Otherwise, this test runs two syncs: one where all streams provided in the input catalog sync + * in full refresh mode, and another where all the streams which in the input catalog which + * support incremental, sync in incremental mode (streams which don't support incremental sync + * in full refresh mode). Then, the test asserts that the two syncs produced the same RECORD + * messages. Any other type of message is disregarded. */ @Test @Throws(Exception::class) @@ -282,17 +318,27 @@ abstract class SourceAcceptanceTest : AbstractSourceConnectorTest() { val fullRefreshCatalog = withFullRefreshSyncModes(configuredCatalog) val fullRefreshRecords = filterRecords(runRead(fullRefreshCatalog)) - val emptyStateRecords = filterRecords(runRead(configuredCatalog, Jsons.jsonNode(HashMap()))) - Assertions.assertFalse(fullRefreshRecords.isEmpty(), "Expected a full refresh sync to produce records") - Assertions.assertFalse(emptyStateRecords.isEmpty(), "Expected state records to not be empty") - assertSameRecords(fullRefreshRecords, emptyStateRecords, - "Expected a full refresh sync and incremental sync with no input state to produce identical records") + val emptyStateRecords = + filterRecords(runRead(configuredCatalog, Jsons.jsonNode(HashMap()))) + Assertions.assertFalse( + fullRefreshRecords.isEmpty(), + "Expected a full refresh sync to produce records" + ) + Assertions.assertFalse( + emptyStateRecords.isEmpty(), + "Expected state records to not be empty" + ) + assertSameRecords( + fullRefreshRecords, + emptyStateRecords, + "Expected a full refresh sync and incremental sync with no input state to produce identical records" + ) } /** - * In order to launch a source on Kubernetes in a pod, we need to be able to wrap the entrypoint. - * The source connector must specify its entrypoint in the AIRBYTE_ENTRYPOINT variable. This test - * ensures that the entrypoint environment variable is set. + * In order to launch a source on Kubernetes in a pod, we need to be able to wrap the + * entrypoint. The source connector must specify its entrypoint in the AIRBYTE_ENTRYPOINT + * variable. This test ensures that the entrypoint environment variable is set. */ @Test @Throws(Exception::class) @@ -300,17 +346,25 @@ abstract class SourceAcceptanceTest : AbstractSourceConnectorTest() { checkEntrypointEnvVariable() } - protected fun withSourceDefinedCursors(catalog: ConfiguredAirbyteCatalog): ConfiguredAirbyteCatalog { + protected fun withSourceDefinedCursors( + catalog: ConfiguredAirbyteCatalog + ): ConfiguredAirbyteCatalog { val clone = Jsons.clone(catalog) for (configuredStream in clone.streams) { - if (configuredStream.syncMode == SyncMode.INCREMENTAL && configuredStream.stream.sourceDefinedCursor != null && configuredStream.stream.sourceDefinedCursor) { + if ( + configuredStream.syncMode == SyncMode.INCREMENTAL && + configuredStream.stream.sourceDefinedCursor != null && + configuredStream.stream.sourceDefinedCursor + ) { configuredStream.cursorField = configuredStream.stream.defaultCursorField } } return clone } - protected fun withFullRefreshSyncModes(catalog: ConfiguredAirbyteCatalog): ConfiguredAirbyteCatalog { + protected fun withFullRefreshSyncModes( + catalog: ConfiguredAirbyteCatalog + ): ConfiguredAirbyteCatalog { val clone = Jsons.clone(catalog) for (configuredStream in clone.streams) { if (configuredStream.stream.supportedSyncModes.contains(SyncMode.FULL_REFRESH)) { @@ -342,9 +396,18 @@ abstract class SourceAcceptanceTest : AbstractSourceConnectorTest() { return false } - private fun assertSameRecords(expected: List, actual: List, message: String) { - val prunedExpected = expected.stream().map { m: AirbyteRecordMessage -> this.pruneEmittedAt(m) }.collect(Collectors.toList()) - val prunedActual = actual + private fun assertSameRecords( + expected: List, + actual: List, + message: String + ) { + val prunedExpected = + expected + .stream() + .map { m: AirbyteRecordMessage -> this.pruneEmittedAt(m) } + .collect(Collectors.toList()) + val prunedActual = + actual .stream() .map { m: AirbyteRecordMessage -> this.pruneEmittedAt(m) } .map { m: AirbyteRecordMessage -> this.pruneCdcMetadata(m) } @@ -381,11 +444,14 @@ abstract class SourceAcceptanceTest : AbstractSourceConnectorTest() { private val LOGGER: Logger = LoggerFactory.getLogger(SourceAcceptanceTest::class.java) - protected fun filterRecords(messages: Collection?): List { - return messages!!.stream() - .filter { m: AirbyteMessage? -> m!!.type == AirbyteMessage.Type.RECORD } - .map { obj: AirbyteMessage? -> obj!!.record } - .collect(Collectors.toList()) + protected fun filterRecords( + messages: Collection? + ): List { + return messages!! + .stream() + .filter { m: AirbyteMessage? -> m!!.type == AirbyteMessage.Type.RECORD } + .map { obj: AirbyteMessage? -> obj!!.record } + .collect(Collectors.toList()) } } } diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/TestDataHolder.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/TestDataHolder.kt index 57d6be2758191..c14e9f7e33a47 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/TestDataHolder.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/TestDataHolder.kt @@ -6,23 +6,27 @@ package io.airbyte.cdk.integrations.standardtest.source import io.airbyte.protocol.models.JsonSchemaType import java.util.* -class TestDataHolder internal constructor(val sourceType: String?, - val airbyteType: JsonSchemaType?, - val values: List, - val expectedValues: MutableList, - private val createTablePatternSql: String, - private val insertPatternSql: String, - private val fullSourceDataType: String?) { +class TestDataHolder +internal constructor( + val sourceType: String?, + val airbyteType: JsonSchemaType, + val values: List, + val expectedValues: MutableList, + private val createTablePatternSql: String, + private val insertPatternSql: String, + private val fullSourceDataType: String? +) { var nameSpace: String? = null private var testNumber: Long = 0 private var idColumnName: String? = null private var testColumnName: String? = null - private var declarationLocation: Array + var declarationLocation: String = "" + private set class TestDataHolderBuilder internal constructor() { private var sourceType: String? = null - private var airbyteType: JsonSchemaType? = null + private lateinit var airbyteType: JsonSchemaType private val values: MutableList = ArrayList() private val expectedValues: MutableList = ArrayList() private var createTablePatternSql: String @@ -35,10 +39,11 @@ class TestDataHolder internal constructor(val sourceType: String?, } /** - * The name of the source data type. Duplicates by name will be tested independently from each - * others. Note that this name will be used for connector setup and table creation. If source syntax - * requires more details (E.g. "varchar" type requires length "varchar(50)"), you can additionally - * set custom data type syntax by [TestDataHolderBuilder.fullSourceDataType] method. + * The name of the source data type. Duplicates by name will be tested independently from + * each others. Note that this name will be used for connector setup and table creation. If + * source syntax requires more details (E.g. "varchar" type requires length "varchar(50)"), + * you can additionally set custom data type syntax by + * [TestDataHolderBuilder.fullSourceDataType] method. * * @param sourceType source data type name * @return builder @@ -56,16 +61,16 @@ class TestDataHolder internal constructor(val sourceType: String?, * @param airbyteType Airbyte data type * @return builder */ - fun airbyteType(airbyteType: JsonSchemaType?): TestDataHolderBuilder { + fun airbyteType(airbyteType: JsonSchemaType): TestDataHolderBuilder { this.airbyteType = airbyteType return this } /** - * Set custom the create table script pattern. Use it if you source uses untypical table creation - * sql. Default patter described [.DEFAULT_CREATE_TABLE_SQL] Note! The patter should contain - * four String place holders for the: - namespace.table name (as one placeholder together) - id - * column name - test column name - test column data type + * Set custom the create table script pattern. Use it if you source uses untypical table + * creation sql. Default patter described [.DEFAULT_CREATE_TABLE_SQL] Note! The patter + * should contain four String place holders for the: - namespace.table name (as one + * placeholder together) - id column name - test column name - test column data type * * @param createTablePatternSql creation table sql pattern * @return builder @@ -76,9 +81,9 @@ class TestDataHolder internal constructor(val sourceType: String?, } /** - * Set custom the insert record script pattern. Use it if you source uses untypical insert record - * sql. Default patter described [.DEFAULT_INSERT_SQL] Note! The patter should contains two - * String place holders for the table name and value. + * Set custom the insert record script pattern. Use it if you source uses untypical insert + * record sql. Default patter described [.DEFAULT_INSERT_SQL] Note! The patter should + * contains two String place holders for the table name and value. * * @param insertPatternSql creation table sql pattern * @return builder @@ -89,8 +94,8 @@ class TestDataHolder internal constructor(val sourceType: String?, } /** - * Allows to set extended data type for the table creation. E.g. The "varchar" type requires in - * MySQL requires length. In this case fullSourceDataType will be "varchar(50)". + * Allows to set extended data type for the table creation. E.g. The "varchar" type requires + * in MySQL requires length. In this case fullSourceDataType will be "varchar(50)". * * @param fullSourceDataType actual string for the column data type description * @return builder @@ -101,21 +106,21 @@ class TestDataHolder internal constructor(val sourceType: String?, } /** - * Adds value(s) to the scope of a corresponding test. The values will be inserted into the created - * table. Note! The value will be inserted into the insert script without any transformations. Make - * sure that the value is in line with the source syntax. + * Adds value(s) to the scope of a corresponding test. The values will be inserted into the + * created table. Note! The value will be inserted into the insert script without any + * transformations. Make sure that the value is in line with the source syntax. * * @param insertValue test value * @return builder */ - fun addInsertValues(vararg insertValue: String?): TestDataHolderBuilder { + fun addInsertValues(vararg insertValue: String): TestDataHolderBuilder { values.addAll(Arrays.asList(*insertValue)) return this } /** - * Adds expected value(s) to the test scope. If you add at least one value, it will check that all - * values are provided by corresponding streamer. + * Adds expected value(s) to the test scope. If you add at least one value, it will check + * that all values are provided by corresponding streamer. * * @param expectedValue value which should be provided by a streamer * @return builder @@ -126,8 +131,8 @@ class TestDataHolder internal constructor(val sourceType: String?, } /** - * Add NULL value to the expected value list. If you need to add only one value and it's NULL, you - * have to use this method instead of [.addExpectedValues] + * Add NULL value to the expected value list. If you need to add only one value and it's + * NULL, you have to use this method instead of [.addExpectedValues] * * @return builder */ @@ -137,7 +142,15 @@ class TestDataHolder internal constructor(val sourceType: String?, } fun build(): TestDataHolder { - return TestDataHolder(sourceType, airbyteType, values, expectedValues, createTablePatternSql, insertPatternSql, fullSourceDataType) + return TestDataHolder( + sourceType, + airbyteType, + values, + expectedValues, + createTablePatternSql, + insertPatternSql, + fullSourceDataType + ) } } @@ -154,19 +167,21 @@ class TestDataHolder internal constructor(val sourceType: String?, } val nameWithTestPrefix: String - get() =// source type may include space (e.g. "character varying") - nameSpace + "_" + testNumber + "_" + sourceType!!.replace("\\s".toRegex(), "_") + get() = // source type may include space (e.g. "character varying") + nameSpace + "_" + testNumber + "_" + sourceType!!.replace("\\s".toRegex(), "_") val createSqlQuery: String - get() = String.format(createTablePatternSql, (if (nameSpace != null) "$nameSpace." else "") + this.nameWithTestPrefix, idColumnName, testColumnName, - fullSourceDataType) + get() = + String.format( + createTablePatternSql, + (if (nameSpace != null) "$nameSpace." else "") + this.nameWithTestPrefix, + idColumnName, + testColumnName, + fullSourceDataType + ) fun setDeclarationLocation(declarationLocation: Array) { - this.declarationLocation = declarationLocation - } - - fun getDeclarationLocation(): String { - return Arrays.asList(*declarationLocation).subList(2, 3).toString() + this.declarationLocation = Arrays.asList(*declarationLocation).subList(2, 3).toString() } val insertSqlQueries: List @@ -174,13 +189,21 @@ class TestDataHolder internal constructor(val sourceType: String?, val insertSqls: MutableList = ArrayList() var rowId = 1 for (value in values) { - insertSqls.add(String.format(insertPatternSql, (if (nameSpace != null) "$nameSpace." else "") + this.nameWithTestPrefix, rowId++, value)) + insertSqls.add( + String.format( + insertPatternSql, + (if (nameSpace != null) "$nameSpace." else "") + this.nameWithTestPrefix, + rowId++, + value + ) + ) } return insertSqls } companion object { - private const val DEFAULT_CREATE_TABLE_SQL = "CREATE TABLE %1\$s(%2\$s INTEGER PRIMARY KEY, %3\$s %4\$s)" + private const val DEFAULT_CREATE_TABLE_SQL = + "CREATE TABLE %1\$s(%2\$s INTEGER PRIMARY KEY, %3\$s %4\$s)" private const val DEFAULT_INSERT_SQL = "INSERT INTO %1\$s VALUES (%2\$s, %3\$s)" /** diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/TestEnvConfigs.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/TestEnvConfigs.kt index 9fa926a2539bf..73c05019e02a5 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/TestEnvConfigs.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/TestEnvConfigs.kt @@ -7,12 +7,12 @@ import com.google.common.base.Preconditions import io.airbyte.commons.lang.Exceptions import io.airbyte.commons.map.MoreMaps import io.airbyte.commons.version.AirbyteVersion -import org.slf4j.Logger -import org.slf4j.LoggerFactory import java.util.* import java.util.function.Function import java.util.function.Supplier import java.util.stream.Collectors +import org.slf4j.Logger +import org.slf4j.LoggerFactory /** * This class passes environment variable to the DockerProcessFactory that runs the source in the @@ -43,36 +43,61 @@ class TestEnvConfigs private constructor(envMap: Map) { get() = AirbyteVersion(getEnsureEnv(AIRBYTE_VERSION)) val deploymentMode: DeploymentMode - get() = getEnvOrDefault(DEPLOYMENT_MODE, DeploymentMode.OSS) { s: String -> - try { - return@getEnvOrDefault DeploymentMode.valueOf(s) - } catch (e: IllegalArgumentException) { - LOGGER.info(s + " not recognized, defaulting to " + DeploymentMode.OSS) - return@getEnvOrDefault DeploymentMode.OSS + get() = + getEnvOrDefault(DEPLOYMENT_MODE, DeploymentMode.OSS) { s: String -> + try { + return@getEnvOrDefault DeploymentMode.valueOf(s) + } catch (e: IllegalArgumentException) { + LOGGER.info(s + " not recognized, defaulting to " + DeploymentMode.OSS) + return@getEnvOrDefault DeploymentMode.OSS + } } - } val workerEnvironment: WorkerEnvironment - get() = getEnvOrDefault(WORKER_ENVIRONMENT, WorkerEnvironment.DOCKER) { s: String -> WorkerEnvironment.valueOf(s.uppercase(Locale.getDefault())) } + get() = + getEnvOrDefault(WORKER_ENVIRONMENT, WorkerEnvironment.DOCKER) { s: String -> + WorkerEnvironment.valueOf(s.uppercase(Locale.getDefault())) + } val jobDefaultEnvMap: Map /** * There are two types of environment variables available to the job container: * - * * Exclusive variables prefixed with JOB_DEFAULT_ENV_PREFIX - * * Shared variables defined in JOB_SHARED_ENVS - * + * * Exclusive variables prefixed with JOB_DEFAULT_ENV_PREFIX + * * Shared variables defined in JOB_SHARED_ENVS */ get() { - val jobPrefixedEnvMap = getAllEnvKeys.get().stream() + val jobPrefixedEnvMap = + getAllEnvKeys + .get() + .stream() .filter { key: String -> key.startsWith(JOB_DEFAULT_ENV_PREFIX) } - .collect(Collectors.toMap(Function { key: String -> key.replace(JOB_DEFAULT_ENV_PREFIX, "") }, getEnv)) + .collect( + Collectors.toMap( + Function { key: String -> key.replace(JOB_DEFAULT_ENV_PREFIX, "") }, + getEnv + ) + ) // This method assumes that these shared env variables are not critical to the execution // of the jobs, and only serve as metadata. So any exception is swallowed and default to // an empty string. Change this logic if this assumption no longer holds. - val jobSharedEnvMap = JOB_SHARED_ENVS.entries.stream().collect(Collectors.toMap( - Function { obj: Map.Entry> -> obj.key }, - Function { entry: Map.Entry> -> Exceptions.swallowWithDefault({ Objects.requireNonNullElse(entry.value.apply(this), "") }, "") })) + val jobSharedEnvMap = + JOB_SHARED_ENVS.entries + .stream() + .collect( + Collectors.toMap( + Function { obj: Map.Entry> -> + obj.key + }, + Function { entry: Map.Entry> + -> + Exceptions.swallowWithDefault( + { Objects.requireNonNullElse(entry.value.apply(this), "") }, + "" + ) + } + ) + ) return MoreMaps.merge(jobPrefixedEnvMap, jobSharedEnvMap) } @@ -80,12 +105,21 @@ class TestEnvConfigs private constructor(envMap: Map) { return getEnvOrDefault(key, defaultValue, parser, false) } - fun getEnvOrDefault(key: String, defaultValue: T, parser: Function, isSecret: Boolean): T { + fun getEnvOrDefault( + key: String, + defaultValue: T, + parser: Function, + isSecret: Boolean + ): T { val value = getEnv.apply(key) if (value != null && !value.isEmpty()) { return parser.apply(value) } else { - LOGGER.info("Using default value for environment variable {}: '{}'", key, if (isSecret) "*****" else defaultValue) + LOGGER.info( + "Using default value for environment variable {}: '{}'", + key, + if (isSecret) "*****" else defaultValue + ) return defaultValue } } @@ -111,10 +145,16 @@ class TestEnvConfigs private constructor(envMap: Map) { const val DEPLOYMENT_MODE: String = "DEPLOYMENT_MODE" const val JOB_DEFAULT_ENV_PREFIX: String = "JOB_DEFAULT_ENV_" - val JOB_SHARED_ENVS: Map> = java.util.Map.of( - AIRBYTE_VERSION, Function { instance: TestEnvConfigs -> instance.airbyteVersion.serialize() }, - AIRBYTE_ROLE, Function { obj: TestEnvConfigs -> obj.airbyteRole }, - DEPLOYMENT_MODE, Function { instance: TestEnvConfigs -> instance.deploymentMode.name }, - WORKER_ENVIRONMENT, Function { instance: TestEnvConfigs -> instance.workerEnvironment.name }) + val JOB_SHARED_ENVS: Map> = + java.util.Map.of( + AIRBYTE_VERSION, + Function { instance: TestEnvConfigs -> instance.airbyteVersion.serialize() }, + AIRBYTE_ROLE, + Function { obj: TestEnvConfigs -> obj.airbyteRole }, + DEPLOYMENT_MODE, + Function { instance: TestEnvConfigs -> instance.deploymentMode.name }, + WORKER_ENVIRONMENT, + Function { instance: TestEnvConfigs -> instance.workerEnvironment.name } + ) } } diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/TestPythonSourceMain.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/TestPythonSourceMain.kt index 357e8aa2a8e93..9a9c0b90ee52e 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/TestPythonSourceMain.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/TestPythonSourceMain.kt @@ -14,15 +14,15 @@ import net.sourceforge.argparse4j.inf.Namespace object TestPythonSourceMain { @JvmStatic fun main(args: Array) { - val parser = ArgumentParsers.newFor(TestPythonSourceMain::class.java.name).build() + val parser = + ArgumentParsers.newFor(TestPythonSourceMain::class.java.name) + .build() .defaultHelp(true) .description("Run standard source tests") - parser.addArgument("--imageName") - .help("Name of the integration image") + parser.addArgument("--imageName").help("Name of the integration image") - parser.addArgument("--pythonContainerName") - .help("Name of the python integration image") + parser.addArgument("--pythonContainerName").help("Name of the python integration image") var ns: Namespace? = null try { diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/TestRunner.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/TestRunner.kt index 79ca70620cfde..c28e8c07b99e2 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/TestRunner.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/TestRunner.kt @@ -3,16 +3,17 @@ */ package io.airbyte.cdk.integrations.standardtest.source +import java.io.PrintWriter +import java.nio.charset.StandardCharsets import org.junit.platform.engine.discovery.DiscoverySelectors import org.junit.platform.launcher.core.LauncherDiscoveryRequestBuilder import org.junit.platform.launcher.core.LauncherFactory import org.junit.platform.launcher.listeners.SummaryGeneratingListener -import java.io.PrintWriter -import java.nio.charset.StandardCharsets object TestRunner { fun runTestClass(testClass: Class<*>?) { - val request = LauncherDiscoveryRequestBuilder.request() + val request = + LauncherDiscoveryRequestBuilder.request() .selectors(DiscoverySelectors.selectClass(testClass)) .build() @@ -29,8 +30,9 @@ object TestRunner { if (listener.summary.testsFailedCount > 0) { println( - "There are failing tests. See https://docs.airbyte.io/contributing-to-airbyte/building-new-connector/standard-source-tests " + - "for more information about the standard source test suite.") + "There are failing tests. See https://docs.airbyte.io/contributing-to-airbyte/building-new-connector/standard-source-tests " + + "for more information about the standard source test suite." + ) System.exit(1) } } diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/fs/ExecutableTestSource.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/fs/ExecutableTestSource.kt index 1d1b538e7646d..6b79f0863bc56 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/fs/ExecutableTestSource.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/fs/ExecutableTestSource.kt @@ -17,26 +17,41 @@ import java.nio.file.Path * also add the ability to execute arbitrary scripts in the next version. */ class ExecutableTestSource : SourceAcceptanceTest() { - class TestConfig(val imageName: String, val specPath: Path, val configPath: Path, val catalogPath: Path, val statePath: Path?) + class TestConfig( + val imageName: String, + val specPath: Path, + val configPath: Path, + val catalogPath: Path, + val statePath: Path? + ) override val spec: ConnectorSpecification - get() = Jsons.deserialize(IOs.readFile(TEST_CONFIG!!.specPath), ConnectorSpecification::class.java) + get() = + Jsons.deserialize( + IOs.readFile(TEST_CONFIG!!.specPath), + ConnectorSpecification::class.java + ) - override val imageName: String? + override val imageName: String get() = TEST_CONFIG!!.imageName override val config: JsonNode? get() = Jsons.deserialize(IOs.readFile(TEST_CONFIG!!.configPath)) override val configuredCatalog: ConfiguredAirbyteCatalog - get() = Jsons.deserialize(IOs.readFile(TEST_CONFIG!!.catalogPath), ConfiguredAirbyteCatalog::class.java) + get() = + Jsons.deserialize( + IOs.readFile(TEST_CONFIG!!.catalogPath), + ConfiguredAirbyteCatalog::class.java + ) override val state: JsonNode? - get() = if (TEST_CONFIG!!.statePath != null) { - Jsons.deserialize(IOs.readFile(TEST_CONFIG!!.statePath)) - } else { - Jsons.deserialize("{}") - } + get() = + if (TEST_CONFIG!!.statePath != null) { + Jsons.deserialize(IOs.readFile(TEST_CONFIG!!.statePath)) + } else { + Jsons.deserialize("{}") + } @Throws(Exception::class) override fun setupEnvironment(environment: TestDestinationEnv?) { diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/fs/TestSourceMain.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/fs/TestSourceMain.kt index c9898037c89a4..b1552e38d7c2c 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/fs/TestSourceMain.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/fs/TestSourceMain.kt @@ -4,12 +4,12 @@ package io.airbyte.cdk.integrations.standardtest.source.fs import io.airbyte.cdk.integrations.standardtest.source.TestRunner +import java.nio.file.Path import net.sourceforge.argparse4j.ArgumentParsers import net.sourceforge.argparse4j.inf.ArgumentParserException import net.sourceforge.argparse4j.inf.Namespace import org.slf4j.Logger import org.slf4j.LoggerFactory -import java.nio.file.Path /** * Parse command line arguments and inject them into the test class before running the test. Then @@ -20,29 +20,27 @@ object TestSourceMain { @JvmStatic fun main(args: Array) { - val parser = ArgumentParsers.newFor(TestSourceMain::class.java.name).build() + val parser = + ArgumentParsers.newFor(TestSourceMain::class.java.name) + .build() .defaultHelp(true) .description("Run standard source tests") - parser.addArgument("--imageName") - .required(true) - .help("Name of the source connector image e.g: airbyte/source-mailchimp") + parser + .addArgument("--imageName") + .required(true) + .help("Name of the source connector image e.g: airbyte/source-mailchimp") - parser.addArgument("--spec") - .required(true) - .help("Path to file that contains spec json") + parser.addArgument("--spec").required(true).help("Path to file that contains spec json") - parser.addArgument("--config") - .required(true) - .help("Path to file that contains config json") + parser.addArgument("--config").required(true).help("Path to file that contains config json") - parser.addArgument("--catalog") - .required(true) - .help("Path to file that contains catalog json") + parser + .addArgument("--catalog") + .required(true) + .help("Path to file that contains catalog json") - parser.addArgument("--state") - .required(false) - .help("Path to the file containing state") + parser.addArgument("--state").required(false).help("Path to the file containing state") var ns: Namespace? = null try { @@ -58,12 +56,14 @@ object TestSourceMain { val catalogFile = ns.getString("catalog") val stateFile = ns.getString("state") - ExecutableTestSource.Companion.TEST_CONFIG = ExecutableTestSource.TestConfig( + ExecutableTestSource.Companion.TEST_CONFIG = + ExecutableTestSource.TestConfig( imageName, Path.of(specFile), Path.of(configFile), Path.of(catalogFile), - if (stateFile != null) Path.of(stateFile) else null) + if (stateFile != null) Path.of(stateFile) else null + ) TestRunner.runTestClass(ExecutableTestSource::class.java) } diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/performancetest/AbstractSourceBasePerformanceTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/performancetest/AbstractSourceBasePerformanceTest.kt index a89538094b006..6f94ccff21e84 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/performancetest/AbstractSourceBasePerformanceTest.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/performancetest/AbstractSourceBasePerformanceTest.kt @@ -11,27 +11,25 @@ import io.airbyte.cdk.integrations.standardtest.source.TestDestinationEnv * tests. */ abstract class AbstractSourceBasePerformanceTest : AbstractSourceConnectorTest() { + /** + * The column name will be used for a test column in the test tables. Override it if default + * name is not valid for your source. + */ + protected val testColumnName + get() = TEST_COLUMN_NAME + /** + * The stream name template will be used for a test tables. Override it if default name is not + * valid for your source. + */ + protected val testStreamNameTemplate + get() = TEST_STREAM_NAME_TEMPLATE @Throws(Exception::class) override fun setupEnvironment(environment: TestDestinationEnv?) { // DO NOTHING. Mandatory to override. DB will be setup as part of each test } companion object { - protected val testColumnName: String = "test_column" - /** - * The column name will be used for a test column in the test tables. Override it if default name is - * not valid for your source. - * - * @return Test column name - */ - get() = Companion.field - protected val testStreamNameTemplate: String = "test_%S" - /** - * The stream name template will be used for a test tables. Override it if default name is not valid - * for your source. - * - * @return Test steam name template - */ - get() = Companion.field + protected const val TEST_COLUMN_NAME: String = "test_column" + protected const val TEST_STREAM_NAME_TEMPLATE: String = "test_%S" } } diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/performancetest/AbstractSourceFillDbWithTestData.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/performancetest/AbstractSourceFillDbWithTestData.kt index 5fdf8418dd767..37e88b9e02693 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/performancetest/AbstractSourceFillDbWithTestData.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/performancetest/AbstractSourceFillDbWithTestData.kt @@ -4,43 +4,46 @@ package io.airbyte.cdk.integrations.standardtest.source.performancetest import io.airbyte.cdk.db.Database +import java.util.* +import java.util.stream.Stream import org.jooq.DSLContext import org.junit.jupiter.api.Disabled +import org.junit.jupiter.api.TestInstance import org.junit.jupiter.params.ParameterizedTest import org.junit.jupiter.params.provider.Arguments import org.junit.jupiter.params.provider.MethodSource import org.slf4j.Logger import org.slf4j.LoggerFactory -import java.util.* -import java.util.stream.Stream -/** - * This abstract class contains common methods for Fill Db scripts. - */ +/** This abstract class contains common methods for Fill Db scripts. */ +@TestInstance(TestInstance.Lifecycle.PER_CLASS) abstract class AbstractSourceFillDbWithTestData : AbstractSourceBasePerformanceTest() { /** - * Setup the test database. All tables and data described in the registered tests will be put there. + * Setup the test database. All tables and data described in the registered tests will be put + * there. * * @return configured test database - * @throws Exception - might throw any exception during initialization. + * @throws Exception + * - might throw any exception during initialization. */ - @Throws(Exception::class) - protected abstract fun setupDatabase(dbName: String?): Database + @Throws(Exception::class) protected abstract fun setupDatabase(dbName: String?): Database /** - * The test added test data to a new DB. 1. Set DB creds in static variables above 2. Set desired - * number for streams, coolumns and records 3. Run the test + * The test added test data to a new DB. 1. Set DB creds in static variables above 2. Set + * desired number for streams, coolumns and records 3. Run the test */ @Disabled @ParameterizedTest @MethodSource("provideParameters") @Throws(Exception::class) - fun addTestData(dbName: String?, - schemaName: String?, - numberOfDummyRecords: Int, - numberOfBatches: Int, - numberOfColumns: Int, - numberOfStreams: Int) { + fun addTestData( + dbName: String?, + schemaName: String?, + numberOfDummyRecords: Int, + numberOfBatches: Int, + numberOfColumns: Int, + numberOfStreams: Int + ) { val database = setupDatabase(dbName) database.query { ctx: DSLContext? -> @@ -49,9 +52,13 @@ abstract class AbstractSourceFillDbWithTestData : AbstractSourceBasePerformanceT ctx!!.fetch(prepareCreateTableQuery(schemaName, numberOfColumns, currentTableName)) for (i in 0 until numberOfBatches) { - val insertQueryTemplate = prepareInsertQueryTemplate(schemaName, i, + val insertQueryTemplate = + prepareInsertQueryTemplate( + schemaName, + i, numberOfColumns, - numberOfDummyRecords) + numberOfDummyRecords + ) ctx.fetch(String.format(insertQueryTemplate, currentTableName)) } @@ -65,31 +72,41 @@ abstract class AbstractSourceFillDbWithTestData : AbstractSourceBasePerformanceT * This is a data provider for fill DB script,, Each argument's group would be ran as a separate * test. Set the "testArgs" in test class of your DB in @BeforeTest method. * - * 1st arg - a name of DB that will be used in jdbc connection string. 2nd arg - a schemaName that - * will be ised as a NameSpace in Configured Airbyte Catalog. 3rd arg - a number of expected records - * retrieved in each stream. 4th arg - a number of columns in each stream\table that will be use for - * Airbyte Cataloq configuration 5th arg - a number of streams to read in configured airbyte - * Catalog. Each stream\table in DB should be names like "test_0", "test_1",..., test_n. + * 1st arg - a name of DB that will be used in jdbc connection string. 2nd arg - a schemaName + * that will be ised as a NameSpace in Configured Airbyte Catalog. 3rd arg - a number of + * expected records retrieved in each stream. 4th arg - a number of columns in each stream\table + * that will be use for Airbyte Cataloq configuration 5th arg - a number of streams to read in + * configured airbyte Catalog. Each stream\table in DB should be names like "test_0", + * "test_1",..., test_n. * * Stream.of( Arguments.of("your_db_name", "your_schema_name", 100, 2, 240, 1000) ); */ protected abstract fun provideParameters(): Stream? - protected fun prepareCreateTableQuery(dbSchemaName: String?, - numberOfColumns: Int, - currentTableName: String?): String { + protected fun prepareCreateTableQuery( + dbSchemaName: String?, + numberOfColumns: Int, + currentTableName: String? + ): String { val sj = StringJoiner(",") for (i in 0 until numberOfColumns) { sj.add(String.format(" %s%s %s", testColumnName, i, TEST_DB_FIELD_TYPE)) } - return String.format(CREATE_DB_TABLE_TEMPLATE, dbSchemaName, currentTableName, sj.toString()) + return String.format( + CREATE_DB_TABLE_TEMPLATE, + dbSchemaName, + currentTableName, + sj.toString() + ) } - protected fun prepareInsertQueryTemplate(dbSchemaName: String?, - batchNumber: Int, - numberOfColumns: Int, - recordsNumber: Int): String { + protected fun prepareInsertQueryTemplate( + dbSchemaName: String?, + batchNumber: Int, + numberOfColumns: Int, + recordsNumber: Int + ): String { val fieldsNames = StringJoiner(",") fieldsNames.add("id") @@ -106,21 +123,32 @@ abstract class AbstractSourceFillDbWithTestData : AbstractSourceBasePerformanceT val batchMessages = batchNumber * 100 for (currentRecordNumber in batchMessages until recordsNumber + batchMessages) { - insertGroupValuesJoiner - .add("(" + baseInsertQuery.toString() - .replace("id_placeholder".toRegex(), currentRecordNumber.toString()) + ")") + insertGroupValuesJoiner.add( + "(" + + baseInsertQuery + .toString() + .replace("id_placeholder".toRegex(), currentRecordNumber.toString()) + + ")" + ) } - return String.format(INSERT_INTO_DB_TABLE_QUERY_TEMPLATE, dbSchemaName, "%s", fieldsNames.toString(), - insertGroupValuesJoiner.toString()) + return String.format( + INSERT_INTO_DB_TABLE_QUERY_TEMPLATE, + dbSchemaName, + "%s", + fieldsNames.toString(), + insertGroupValuesJoiner.toString() + ) } companion object { - private const val CREATE_DB_TABLE_TEMPLATE = "CREATE TABLE %s.%s(id INTEGER PRIMARY KEY, %s)" + private const val CREATE_DB_TABLE_TEMPLATE = + "CREATE TABLE %s.%s(id INTEGER PRIMARY KEY, %s)" private const val INSERT_INTO_DB_TABLE_QUERY_TEMPLATE = "INSERT INTO %s.%s (%s) VALUES %s" private const val TEST_DB_FIELD_TYPE = "varchar(10)" - protected val c: Logger = LoggerFactory.getLogger(AbstractSourceFillDbWithTestData::class.java) + protected val c: Logger = + LoggerFactory.getLogger(AbstractSourceFillDbWithTestData::class.java) private const val TEST_VALUE_TEMPLATE_POSTGRES = "\'Value id_placeholder\'" } } diff --git a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/performancetest/AbstractSourcePerformanceTest.kt b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/performancetest/AbstractSourcePerformanceTest.kt index cde09b7315210..4980d3ced7cb3 100644 --- a/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/performancetest/AbstractSourcePerformanceTest.kt +++ b/airbyte-cdk/java/airbyte-cdk/db-sources/src/testFixtures/kotlin/io/airbyte/cdk/integrations/standardtest/source/performancetest/AbstractSourcePerformanceTest.kt @@ -9,44 +9,51 @@ import io.airbyte.cdk.integrations.standardtest.source.TestDestinationEnv import io.airbyte.protocol.models.Field import io.airbyte.protocol.models.JsonSchemaType import io.airbyte.protocol.models.v0.* +import java.util.function.Function +import java.util.stream.Collectors +import java.util.stream.Stream import org.junit.jupiter.api.Assertions +import org.junit.jupiter.api.TestInstance import org.junit.jupiter.params.ParameterizedTest import org.junit.jupiter.params.provider.Arguments import org.junit.jupiter.params.provider.MethodSource import org.slf4j.Logger import org.slf4j.LoggerFactory -import java.util.function.Function -import java.util.stream.Collectors -import java.util.stream.Stream -/** - * This abstract class contains common methods for Performance tests. - */ +/** This abstract class contains common methods for Performance tests. */ +@TestInstance(TestInstance.Lifecycle.PER_CLASS) abstract class AbstractSourcePerformanceTest : AbstractSourceBasePerformanceTest() { override var config: JsonNode? = null + /** + * The column name will be used for a PK column in the test tables. Override it if default name + * is not valid for your source. + */ + protected val idColumnName: String = "id" /** - * Setup the test database. All tables and data described in the registered tests will be put there. + * Setup the test database. All tables and data described in the registered tests will be put + * there. * - * @throws Exception - might throw any exception during initialization. + * @throws Exception + * - might throw any exception during initialization. */ - @Throws(Exception::class) - protected abstract fun setupDatabase(dbName: String?) + @Throws(Exception::class) protected abstract fun setupDatabase(dbName: String?) override fun tearDown(testEnv: TestDestinationEnv?) {} /** - * This is a data provider for performance tests, Each argument's group would be ran as a separate - * test. Set the "testArgs" in test class of your DB in @BeforeTest method. + * This is a data provider for performance tests, Each argument's group would be ran as a + * separate test. Set the "testArgs" in test class of your DB in @BeforeTest method. * - * 1st arg - a name of DB that will be used in jdbc connection string. 2nd arg - a schemaName that - * will be ised as a NameSpace in Configured Airbyte Catalog. 3rd arg - a number of expected records - * retrieved in each stream. 4th arg - a number of columns in each stream\table that will be used - * for Airbyte Cataloq configuration 5th arg - a number of streams to read in configured airbyte - * Catalog. Each stream\table in DB should be names like "test_0", "test_1",..., test_n. + * 1st arg - a name of DB that will be used in jdbc connection string. 2nd arg - a schemaName + * that will be ised as a NameSpace in Configured Airbyte Catalog. 3rd arg - a number of + * expected records retrieved in each stream. 4th arg - a number of columns in each stream\table + * that will be used for Airbyte Cataloq configuration 5th arg - a number of streams to read in + * configured airbyte Catalog. Each stream\table in DB should be names like "test_0", + * "test_1",..., test_n. * - * Example: Stream.of( Arguments.of("test1000tables240columns200recordsDb", "dbo", 200, 240, 1000), - * Arguments.of("test5000tables240columns200recordsDb", "dbo", 200, 240, 1000), + * Example: Stream.of( Arguments.of("test1000tables240columns200recordsDb", "dbo", 200, 240, + * 1000), Arguments.of("test5000tables240columns200recordsDb", "dbo", 200, 240, 1000), * Arguments.of("newregular25tables50000records", "dbo", 50052, 8, 25), * Arguments.of("newsmall1000tableswith10000rows", "dbo", 10011, 8, 1000) ); */ @@ -55,26 +62,36 @@ abstract class AbstractSourcePerformanceTest : AbstractSourceBasePerformanceTest @ParameterizedTest @MethodSource("provideParameters") @Throws(Exception::class) - fun testPerformance(dbName: String?, - schemaName: String?, - numberOfDummyRecords: Int, - numberOfColumns: Int, - numberOfStreams: Int) { + fun testPerformance( + dbName: String?, + schemaName: String?, + numberOfDummyRecords: Int, + numberOfColumns: Int, + numberOfStreams: Int + ) { setupDatabase(dbName) - val catalog = getConfiguredCatalog(schemaName, numberOfStreams, - numberOfColumns) - val mapOfExpectedRecordsCount = prepareMapWithExpectedRecords( - numberOfStreams, numberOfDummyRecords) - val checkStatusMap = runReadVerifyNumberOfReceivedMsgs(catalog, null, - mapOfExpectedRecordsCount) + val catalog = getConfiguredCatalog(schemaName, numberOfStreams, numberOfColumns) + val mapOfExpectedRecordsCount = + prepareMapWithExpectedRecords(numberOfStreams, numberOfDummyRecords) + val checkStatusMap = + runReadVerifyNumberOfReceivedMsgs(catalog, null, mapOfExpectedRecordsCount) validateNumberOfReceivedMsgs(checkStatusMap) } protected fun validateNumberOfReceivedMsgs(checkStatusMap: Map?) { // Iterate through all streams map and check for streams where - val failedStreamsMap = checkStatusMap!!.entries.stream() - .filter { el: Map.Entry -> el.value != 0 }.collect(Collectors.toMap(Function { obj: Map.Entry -> obj.key }, Function { obj: Map.Entry -> obj.value })) + val failedStreamsMap = + checkStatusMap!! + .entries + .stream() + .filter { el: Map.Entry -> el.value != 0 } + .collect( + Collectors.toMap( + Function { obj: Map.Entry -> obj.key }, + Function { obj: Map.Entry -> obj.value } + ) + ) if (!failedStreamsMap.isEmpty()) { Assertions.fail("Non all messages were delivered. $failedStreamsMap") @@ -82,9 +99,11 @@ abstract class AbstractSourcePerformanceTest : AbstractSourceBasePerformanceTest c.info("Finished all checks, no issues found for {} of streams", checkStatusMap.size) } - protected fun prepareMapWithExpectedRecords(streamNumber: Int, - expectedRecordsNumberInEachStream: Int): Map { - val resultMap: MutableMap = HashMap() // streamName&expected records in stream + protected fun prepareMapWithExpectedRecords( + streamNumber: Int, + expectedRecordsNumberInEachStream: Int + ): MutableMap { + val resultMap: MutableMap = HashMap() // streamName&expected records in stream for (currentStream in 0 until streamNumber) { val streamName = String.format(testStreamNameTemplate, currentStream) @@ -98,9 +117,11 @@ abstract class AbstractSourcePerformanceTest : AbstractSourceBasePerformanceTest * * @return configured catalog */ - protected fun getConfiguredCatalog(nameSpace: String?, - numberOfStreams: Int, - numberOfColumns: Int): ConfiguredAirbyteCatalog { + protected fun getConfiguredCatalog( + nameSpace: String?, + numberOfStreams: Int, + numberOfColumns: Int + ): ConfiguredAirbyteCatalog { val streams: MutableList = ArrayList() for (currentStream in 0 until numberOfStreams) { @@ -113,15 +134,24 @@ abstract class AbstractSourcePerformanceTest : AbstractSourceBasePerformanceTest fields.add(Field.of(testColumnName + currentColumnNumber, JsonSchemaType.STRING)) } - val airbyteStream = CatalogHelpers - .createAirbyteStream(String.format(testStreamNameTemplate, currentStream), - nameSpace, fields) + val airbyteStream = + CatalogHelpers.createAirbyteStream( + String.format(testStreamNameTemplate, currentStream), + nameSpace, + fields + ) .withSourceDefinedCursor(true) - .withSourceDefinedPrimaryKey(java.util.List.of>(java.util.List.of(this.idColumnName))) + .withSourceDefinedPrimaryKey( + java.util.List.of>( + java.util.List.of(this.idColumnName) + ) + ) .withSupportedSyncModes( - Lists.newArrayList(SyncMode.FULL_REFRESH, SyncMode.INCREMENTAL)) + Lists.newArrayList(SyncMode.FULL_REFRESH, SyncMode.INCREMENTAL) + ) - val configuredAirbyteStream = ConfiguredAirbyteStream() + val configuredAirbyteStream = + ConfiguredAirbyteStream() .withSyncMode(SyncMode.INCREMENTAL) .withCursorField(Lists.newArrayList(this.idColumnName)) .withDestinationSyncMode(DestinationSyncMode.APPEND) @@ -135,13 +165,5 @@ abstract class AbstractSourcePerformanceTest : AbstractSourceBasePerformanceTest companion object { protected val c: Logger = LoggerFactory.getLogger(AbstractSourcePerformanceTest::class.java) - protected val idColumnName: String = "id" - /** - * The column name will be used for a PK column in the test tables. Override it if default name is - * not valid for your source. - * - * @return Id column name - */ - get() = Companion.field } }