Skip to content

Commit

Permalink
Destination CDK: Simplify AsyncStreamConsumer constructors (#37106)
Browse files Browse the repository at this point in the history
  • Loading branch information
gisripa committed Apr 12, 2024
1 parent faad484 commit 6d5ecca
Show file tree
Hide file tree
Showing 8 changed files with 62 additions and 146 deletions.
1 change: 1 addition & 0 deletions airbyte-cdk/java/airbyte-cdk/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ Maven and Gradle will automatically reference the correct (pinned) version of th

| Version | Date | Pull Request | Subject |
|:--------|:-----------|:-----------------------------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------|
| 0.30.3 | 2024-04-12 | [\#37106](https://github.com/airbytehq/airbyte/pull/37106) | Destinations: Simplify constructors in `AsyncStreamConsumer` |
| 0.30.2 | 2024-04-12 | [\#36926](https://github.com/airbytehq/airbyte/pull/36926) | Destinations: Remove `JdbcSqlOperations#formatData`; misc changes for java interop |
| 0.30.1 | 2024-04-11 | [\#36919](https://github.com/airbytehq/airbyte/pull/36919) | Fix regression in sources conversion of null values |
| 0.30.0 | 2024-04-11 | [\#36974](https://github.com/airbytehq/airbyte/pull/36974) | Destinations: Pass config to jdbc sqlgenerator; allow cascade drop |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,13 @@

package io.airbyte.cdk.integrations.destination.async

import com.google.common.annotations.VisibleForTesting
import com.google.common.base.Preconditions
import com.google.common.base.Strings
import io.airbyte.cdk.integrations.base.SerializedAirbyteMessageConsumer
import io.airbyte.cdk.integrations.destination.StreamSyncSummary
import io.airbyte.cdk.integrations.destination.async.buffers.BufferEnqueue
import io.airbyte.cdk.integrations.destination.async.buffers.BufferManager
import io.airbyte.cdk.integrations.destination.async.deser.DeserializationUtil
import io.airbyte.cdk.integrations.destination.async.deser.IdentityDataTransformer
import io.airbyte.cdk.integrations.destination.async.deser.StreamAwareDataTransformer
import io.airbyte.cdk.integrations.destination.async.deser.AirbyteMessageDeserializer
import io.airbyte.cdk.integrations.destination.async.function.DestinationFlushFunction
import io.airbyte.cdk.integrations.destination.async.model.PartialAirbyteMessage
import io.airbyte.cdk.integrations.destination.async.state.FlushFailure
Expand Down Expand Up @@ -44,26 +41,23 @@ private val logger = KotlinLogging.logger {}
* memory limit governed by [GlobalMemoryManager]. Record writing is decoupled via [FlushWorkers].
* See the other linked class for more detail.
*/
class AsyncStreamConsumer
@VisibleForTesting
constructor(
class AsyncStreamConsumer(
outputRecordCollector: Consumer<AirbyteMessage>,
private val onStart: OnStartFunction,
private val onClose: OnCloseFunction,
flusher: DestinationFlushFunction,
onFlush: DestinationFlushFunction,
private val catalog: ConfiguredAirbyteCatalog,
private val bufferManager: BufferManager,
private val flushFailure: FlushFailure,
private val defaultNamespace: Optional<String>,
workerPool: ExecutorService,
private val dataTransformer: StreamAwareDataTransformer,
private val deserializationUtil: DeserializationUtil,
private val flushFailure: FlushFailure = FlushFailure(),
workerPool: ExecutorService = Executors.newFixedThreadPool(5),
private val airbyteMessageDeserializer: AirbyteMessageDeserializer,
) : SerializedAirbyteMessageConsumer {
private val bufferEnqueue: BufferEnqueue = bufferManager.bufferEnqueue
private val flushWorkers: FlushWorkers =
FlushWorkers(
bufferManager.bufferDequeue,
flusher,
onFlush,
outputRecordCollector,
flushFailure,
bufferManager.stateManager,
Expand All @@ -81,73 +75,7 @@ constructor(
private var hasClosed = false
private var hasFailed = false

constructor(
outputRecordCollector: Consumer<AirbyteMessage>,
onStart: OnStartFunction,
onClose: OnCloseFunction,
flusher: DestinationFlushFunction,
catalog: ConfiguredAirbyteCatalog,
bufferManager: BufferManager,
defaultNamespace: Optional<String>,
) : this(
outputRecordCollector,
onStart,
onClose,
flusher,
catalog,
bufferManager,
FlushFailure(),
defaultNamespace,
)

constructor(
outputRecordCollector: Consumer<AirbyteMessage>,
onStart: OnStartFunction,
onClose: OnCloseFunction,
flusher: DestinationFlushFunction,
catalog: ConfiguredAirbyteCatalog,
bufferManager: BufferManager,
defaultNamespace: Optional<String>,
dataTransformer: StreamAwareDataTransformer,
) : this(
outputRecordCollector,
onStart,
onClose,
flusher,
catalog,
bufferManager,
FlushFailure(),
defaultNamespace,
Executors.newFixedThreadPool(5),
dataTransformer,
DeserializationUtil(),
)

constructor(
outputRecordCollector: Consumer<AirbyteMessage>,
onStart: OnStartFunction,
onClose: OnCloseFunction,
flusher: DestinationFlushFunction,
catalog: ConfiguredAirbyteCatalog,
bufferManager: BufferManager,
defaultNamespace: Optional<String>,
workerPool: ExecutorService,
) : this(
outputRecordCollector,
onStart,
onClose,
flusher,
catalog,
bufferManager,
FlushFailure(),
defaultNamespace,
workerPool,
IdentityDataTransformer(),
DeserializationUtil(),
)

@VisibleForTesting
constructor(
internal constructor(
outputRecordCollector: Consumer<AirbyteMessage>,
onStart: OnStartFunction,
onClose: OnCloseFunction,
Expand All @@ -163,11 +91,10 @@ constructor(
flusher,
catalog,
bufferManager,
flushFailure,
defaultNamespace,
flushFailure,
Executors.newFixedThreadPool(5),
IdentityDataTransformer(),
DeserializationUtil(),
AirbyteMessageDeserializer(),
)

@Throws(Exception::class)
Expand All @@ -183,7 +110,7 @@ constructor(

@Throws(Exception::class)
override fun accept(
messageString: String,
message: String,
sizeInBytes: Int,
) {
Preconditions.checkState(hasStarted, "Cannot accept records until consumer has started")
Expand All @@ -193,21 +120,22 @@ constructor(
* to try to use a thread pool to partially deserialize to get record type and stream name, we can
* do it without touching buffer manager.
*/
val message =
deserializationUtil.deserializeAirbyteMessage(
messageString,
dataTransformer,
val partialAirbyteMessage =
airbyteMessageDeserializer.deserializeAirbyteMessage(
message,
)
if (AirbyteMessage.Type.RECORD == message.type) {
if (Strings.isNullOrEmpty(message.record?.namespace)) {
message.record?.namespace = defaultNamespace.getOrNull()
if (AirbyteMessage.Type.RECORD == partialAirbyteMessage.type) {
if (Strings.isNullOrEmpty(partialAirbyteMessage.record?.namespace)) {
partialAirbyteMessage.record?.namespace = defaultNamespace.getOrNull()
}
validateRecord(message)
validateRecord(partialAirbyteMessage)

message.record?.streamDescriptor?.let { getRecordCounter(it).incrementAndGet() }
partialAirbyteMessage.record?.streamDescriptor?.let {
getRecordCounter(it).incrementAndGet()
}
}
bufferEnqueue.addRecord(
message,
partialAirbyteMessage,
sizeInBytes + PARTIAL_DESERIALIZE_REF_BYTES,
defaultNamespace,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ import io.airbyte.cdk.integrations.destination.async.model.PartialAirbyteMessage
import io.airbyte.commons.json.Jsons
import io.airbyte.protocol.models.v0.AirbyteMessage

class DeserializationUtil {
class AirbyteMessageDeserializer(
private val dataTransformer: StreamAwareDataTransformer = IdentityDataTransformer(),
) {
/**
* Deserializes to a [PartialAirbyteMessage] which can represent both a Record or a State
* Message
Expand All @@ -16,20 +18,20 @@ class DeserializationUtil {
* * entire serialized message string when message is a valid State Message
* * serialized AirbyteRecordMessage when message is a valid Record Message
*
* @param messageString the string to deserialize
* @param message the string to deserialize
* @return the [PartialAirbyteMessage] if the message is valid; a [RuntimeException] is thrown if it cannot be deserialized or has an unsupported type
*/
fun deserializeAirbyteMessage(
messageString: String?,
dataTransformer: StreamAwareDataTransformer,
message: String?,
): PartialAirbyteMessage {
// TODO: This makes some sketchy assumptions by deserializing either the whole or the
// partial message based on type.
// Use JsonSubTypes and extend StdDeserializer to properly handle this.
// Make immutability a first class citizen in the PartialAirbyteMessage class.
val partial =
Jsons.tryDeserializeExact(messageString, PartialAirbyteMessage::class.java)
.orElseThrow { RuntimeException("Unable to deserialize PartialAirbyteMessage.") }
Jsons.tryDeserializeExact(message, PartialAirbyteMessage::class.java).orElseThrow {
RuntimeException("Unable to deserialize PartialAirbyteMessage.")
}

val msgType = partial.type
if (AirbyteMessage.Type.RECORD == msgType && partial.record?.data != null) {
Expand All @@ -50,7 +52,7 @@ class DeserializationUtil {
// usage.
partial.record?.data = null
} else if (AirbyteMessage.Type.STATE == msgType) {
partial.withSerialized(messageString)
partial.withSerialized(message)
} else {
throw RuntimeException(String.format("Unsupported message type: %s", msgType))
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1 +1 @@
version=0.30.2
version=0.30.3
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@ package io.airbyte.cdk.integrations.destination.async
import com.fasterxml.jackson.databind.JsonNode
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings
import io.airbyte.cdk.integrations.destination.async.buffers.BufferManager
import io.airbyte.cdk.integrations.destination.async.deser.DeserializationUtil
import io.airbyte.cdk.integrations.destination.async.deser.IdentityDataTransformer
import io.airbyte.cdk.integrations.destination.async.deser.AirbyteMessageDeserializer
import io.airbyte.cdk.integrations.destination.async.deser.StreamAwareDataTransformer
import io.airbyte.cdk.integrations.destination.async.function.DestinationFlushFunction
import io.airbyte.cdk.integrations.destination.async.model.PartialAirbyteMessage
Expand Down Expand Up @@ -125,7 +124,7 @@ class AsyncStreamConsumerTest {
private lateinit var outputRecordCollector: Consumer<AirbyteMessage>
private lateinit var flushFailure: FlushFailure
private lateinit var streamAwareDataTransformer: StreamAwareDataTransformer
private lateinit var deserializationUtil: DeserializationUtil
private lateinit var airbyteMessageDeserializer: AirbyteMessageDeserializer

@BeforeEach
@Suppress("UNCHECKED_CAST")
Expand All @@ -139,20 +138,18 @@ class AsyncStreamConsumerTest {
flushFunction = Mockito.mock(DestinationFlushFunction::class.java)
outputRecordCollector = Mockito.mock(Consumer::class.java) as Consumer<AirbyteMessage>
flushFailure = Mockito.mock(FlushFailure::class.java)
deserializationUtil = DeserializationUtil()
streamAwareDataTransformer = IdentityDataTransformer()
airbyteMessageDeserializer = AirbyteMessageDeserializer()
consumer =
AsyncStreamConsumer(
outputRecordCollector = outputRecordCollector,
onStart = onStart,
onClose = onClose,
flusher = flushFunction,
onFlush = flushFunction,
catalog = CATALOG,
bufferManager = BufferManager(),
flushFailure = flushFailure,
defaultNamespace = Optional.of("default_ns"),
dataTransformer = streamAwareDataTransformer,
deserializationUtil = deserializationUtil,
airbyteMessageDeserializer = airbyteMessageDeserializer,
workerPool = Executors.newFixedThreadPool(5),
)

Expand Down Expand Up @@ -330,9 +327,8 @@ class AsyncStreamConsumerTest {
val serializedAirbyteMessage = Jsons.serialize(airbyteMessage)
val airbyteRecordString = Jsons.serialize(PAYLOAD)
val partial =
deserializationUtil.deserializeAirbyteMessage(
airbyteMessageDeserializer.deserializeAirbyteMessage(
serializedAirbyteMessage,
streamAwareDataTransformer,
)
assertEquals(airbyteRecordString, partial.serialized)
}
Expand All @@ -357,9 +353,8 @@ class AsyncStreamConsumerTest {
val serializedAirbyteMessage = Jsons.serialize(airbyteMessage)
val airbyteRecordString = Jsons.serialize(payload)
val partial =
deserializationUtil.deserializeAirbyteMessage(
airbyteMessageDeserializer.deserializeAirbyteMessage(
serializedAirbyteMessage,
streamAwareDataTransformer,
)
assertEquals(airbyteRecordString, partial.serialized)
}
Expand All @@ -378,9 +373,8 @@ class AsyncStreamConsumerTest {
)
val serializedAirbyteMessage = Jsons.serialize(airbyteMessage)
val partial =
deserializationUtil.deserializeAirbyteMessage(
airbyteMessageDeserializer.deserializeAirbyteMessage(
serializedAirbyteMessage,
streamAwareDataTransformer,
)
assertEquals(emptyMap.toString(), partial.serialized)
}
Expand All @@ -393,9 +387,8 @@ class AsyncStreamConsumerTest {
assertThrows(
RuntimeException::class.java,
) {
deserializationUtil.deserializeAirbyteMessage(
airbyteMessageDeserializer.deserializeAirbyteMessage(
serializedAirbyteMessage,
streamAwareDataTransformer,
)
}
}
Expand All @@ -404,9 +397,8 @@ class AsyncStreamConsumerTest {
internal fun deserializeAirbyteMessageWithAirbyteState() {
val serializedAirbyteMessage = Jsons.serialize(STATE_MESSAGE1)
val partial =
deserializationUtil.deserializeAirbyteMessage(
airbyteMessageDeserializer.deserializeAirbyteMessage(
serializedAirbyteMessage,
streamAwareDataTransformer,
)
assertEquals(serializedAirbyteMessage, partial.serialized)
}
Expand All @@ -430,9 +422,8 @@ class AsyncStreamConsumerTest {
assertThrows(
RuntimeException::class.java,
) {
deserializationUtil.deserializeAirbyteMessage(
airbyteMessageDeserializer.deserializeAirbyteMessage(
serializedAirbyteMessage,
streamAwareDataTransformer,
)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import io.airbyte.cdk.integrations.destination.NamingConventionTransformer
import io.airbyte.cdk.integrations.destination.StreamSyncSummary
import io.airbyte.cdk.integrations.destination.async.AsyncStreamConsumer
import io.airbyte.cdk.integrations.destination.async.buffers.BufferManager
import io.airbyte.cdk.integrations.destination.async.deser.DeserializationUtil
import io.airbyte.cdk.integrations.destination.async.deser.AirbyteMessageDeserializer
import io.airbyte.cdk.integrations.destination.async.deser.IdentityDataTransformer
import io.airbyte.cdk.integrations.destination.async.deser.StreamAwareDataTransformer
import io.airbyte.cdk.integrations.destination.async.model.PartialAirbyteMessage
Expand Down Expand Up @@ -78,11 +78,10 @@ object JdbcBufferedConsumerFactory {
),
catalog,
BufferManager((Runtime.getRuntime().maxMemory() * 0.2).toLong()),
FlushFailure(),
Optional.ofNullable(defaultNamespace),
FlushFailure(),
Executors.newFixedThreadPool(2),
dataTransformer,
DeserializationUtil()
AirbyteMessageDeserializer(dataTransformer)
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import org.apache.commons.io.FileUtils
private val logger = KotlinLogging.logger {}

internal class AsyncFlush(
streamDescToWriteConfig: Map<StreamDescriptor, WriteConfig>,
private val streamDescToWriteConfig: Map<StreamDescriptor, WriteConfig>,
private val stagingOperations: StagingOperations?,
private val database: JdbcDatabase?,
private val catalog: ConfiguredAirbyteCatalog?,
Expand All @@ -41,8 +41,6 @@ internal class AsyncFlush(
override val optimalBatchSizeBytes: Long,
private val useDestinationsV2Columns: Boolean
) : DestinationFlushFunction {
private val streamDescToWriteConfig: Map<StreamDescriptor, WriteConfig> =
streamDescToWriteConfig

@Throws(Exception::class)
override fun flush(decs: StreamDescriptor, stream: Stream<PartialAirbyteMessage>) {
Expand Down
Loading

0 comments on commit 6d5ecca

Please sign in to comment.