Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Destination CDK: Simplify AsyncStreamConsumer constructors #37106

Merged
merged 1 commit into from
Apr 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions airbyte-cdk/java/airbyte-cdk/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ Maven and Gradle will automatically reference the correct (pinned) version of th

| Version | Date | Pull Request | Subject |
|:--------|:-----------|:-----------------------------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------|
| 0.30.3 | 2024-04-12 | [\#37106](https://github.com/airbytehq/airbyte/pull/37106) | Destinations: Simplify constructors in `AsyncStreamConsumer` |
| 0.30.2 | 2024-04-12 | [\#36926](https://github.com/airbytehq/airbyte/pull/36926) | Destinations: Remove `JdbcSqlOperations#formatData`; misc changes for java interop |
| 0.30.1 | 2024-04-11 | [\#36919](https://github.com/airbytehq/airbyte/pull/36919) | Fix regression in sources conversion of null values |
| 0.30.0 | 2024-04-11 | [\#36974](https://github.com/airbytehq/airbyte/pull/36974) | Destinations: Pass config to jdbc sqlgenerator; allow cascade drop |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,13 @@

package io.airbyte.cdk.integrations.destination.async

import com.google.common.annotations.VisibleForTesting
import com.google.common.base.Preconditions
import com.google.common.base.Strings
import io.airbyte.cdk.integrations.base.SerializedAirbyteMessageConsumer
import io.airbyte.cdk.integrations.destination.StreamSyncSummary
import io.airbyte.cdk.integrations.destination.async.buffers.BufferEnqueue
import io.airbyte.cdk.integrations.destination.async.buffers.BufferManager
import io.airbyte.cdk.integrations.destination.async.deser.DeserializationUtil
import io.airbyte.cdk.integrations.destination.async.deser.IdentityDataTransformer
import io.airbyte.cdk.integrations.destination.async.deser.StreamAwareDataTransformer
import io.airbyte.cdk.integrations.destination.async.deser.AirbyteMessageDeserializer
import io.airbyte.cdk.integrations.destination.async.function.DestinationFlushFunction
import io.airbyte.cdk.integrations.destination.async.model.PartialAirbyteMessage
import io.airbyte.cdk.integrations.destination.async.state.FlushFailure
Expand Down Expand Up @@ -44,26 +41,23 @@ private val logger = KotlinLogging.logger {}
* memory limit governed by [GlobalMemoryManager]. Record writing is decoupled via [FlushWorkers].
* See the other linked class for more detail.
*/
class AsyncStreamConsumer
@VisibleForTesting
constructor(
class AsyncStreamConsumer(
outputRecordCollector: Consumer<AirbyteMessage>,
private val onStart: OnStartFunction,
private val onClose: OnCloseFunction,
flusher: DestinationFlushFunction,
onFlush: DestinationFlushFunction,
private val catalog: ConfiguredAirbyteCatalog,
private val bufferManager: BufferManager,
private val flushFailure: FlushFailure,
private val defaultNamespace: Optional<String>,
workerPool: ExecutorService,
private val dataTransformer: StreamAwareDataTransformer,
private val deserializationUtil: DeserializationUtil,
private val flushFailure: FlushFailure = FlushFailure(),
workerPool: ExecutorService = Executors.newFixedThreadPool(5),
private val airbyteMessageDeserializer: AirbyteMessageDeserializer,
) : SerializedAirbyteMessageConsumer {
private val bufferEnqueue: BufferEnqueue = bufferManager.bufferEnqueue
private val flushWorkers: FlushWorkers =
FlushWorkers(
bufferManager.bufferDequeue,
flusher,
onFlush,
outputRecordCollector,
flushFailure,
bufferManager.stateManager,
Expand All @@ -81,73 +75,7 @@ constructor(
private var hasClosed = false
private var hasFailed = false

constructor(
outputRecordCollector: Consumer<AirbyteMessage>,
onStart: OnStartFunction,
onClose: OnCloseFunction,
flusher: DestinationFlushFunction,
catalog: ConfiguredAirbyteCatalog,
bufferManager: BufferManager,
defaultNamespace: Optional<String>,
) : this(
outputRecordCollector,
onStart,
onClose,
flusher,
catalog,
bufferManager,
FlushFailure(),
defaultNamespace,
)

constructor(
outputRecordCollector: Consumer<AirbyteMessage>,
onStart: OnStartFunction,
onClose: OnCloseFunction,
flusher: DestinationFlushFunction,
catalog: ConfiguredAirbyteCatalog,
bufferManager: BufferManager,
defaultNamespace: Optional<String>,
dataTransformer: StreamAwareDataTransformer,
) : this(
outputRecordCollector,
onStart,
onClose,
flusher,
catalog,
bufferManager,
FlushFailure(),
defaultNamespace,
Executors.newFixedThreadPool(5),
dataTransformer,
DeserializationUtil(),
)

constructor(
outputRecordCollector: Consumer<AirbyteMessage>,
onStart: OnStartFunction,
onClose: OnCloseFunction,
flusher: DestinationFlushFunction,
catalog: ConfiguredAirbyteCatalog,
bufferManager: BufferManager,
defaultNamespace: Optional<String>,
workerPool: ExecutorService,
) : this(
outputRecordCollector,
onStart,
onClose,
flusher,
catalog,
bufferManager,
FlushFailure(),
defaultNamespace,
workerPool,
IdentityDataTransformer(),
DeserializationUtil(),
)

@VisibleForTesting
constructor(
internal constructor(
outputRecordCollector: Consumer<AirbyteMessage>,
onStart: OnStartFunction,
onClose: OnCloseFunction,
Expand All @@ -163,11 +91,10 @@ constructor(
flusher,
catalog,
bufferManager,
flushFailure,
defaultNamespace,
flushFailure,
Executors.newFixedThreadPool(5),
IdentityDataTransformer(),
DeserializationUtil(),
AirbyteMessageDeserializer(),
)

@Throws(Exception::class)
Expand All @@ -183,7 +110,7 @@ constructor(

@Throws(Exception::class)
override fun accept(
messageString: String,
message: String,
sizeInBytes: Int,
) {
Preconditions.checkState(hasStarted, "Cannot accept records until consumer has started")
Expand All @@ -193,21 +120,22 @@ constructor(
* to try to use a thread pool to partially deserialize to get record type and stream name, we can
* do it without touching buffer manager.
*/
val message =
deserializationUtil.deserializeAirbyteMessage(
messageString,
dataTransformer,
val partialAirbyteMessage =
airbyteMessageDeserializer.deserializeAirbyteMessage(
message,
)
if (AirbyteMessage.Type.RECORD == message.type) {
if (Strings.isNullOrEmpty(message.record?.namespace)) {
message.record?.namespace = defaultNamespace.getOrNull()
if (AirbyteMessage.Type.RECORD == partialAirbyteMessage.type) {
if (Strings.isNullOrEmpty(partialAirbyteMessage.record?.namespace)) {
partialAirbyteMessage.record?.namespace = defaultNamespace.getOrNull()
}
validateRecord(message)
validateRecord(partialAirbyteMessage)

message.record?.streamDescriptor?.let { getRecordCounter(it).incrementAndGet() }
partialAirbyteMessage.record?.streamDescriptor?.let {
getRecordCounter(it).incrementAndGet()
}
}
bufferEnqueue.addRecord(
message,
partialAirbyteMessage,
sizeInBytes + PARTIAL_DESERIALIZE_REF_BYTES,
defaultNamespace,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ import io.airbyte.cdk.integrations.destination.async.model.PartialAirbyteMessage
import io.airbyte.commons.json.Jsons
import io.airbyte.protocol.models.v0.AirbyteMessage

class DeserializationUtil {
class AirbyteMessageDeserializer(
private val dataTransformer: StreamAwareDataTransformer = IdentityDataTransformer(),
) {
/**
* Deserializes to a [PartialAirbyteMessage] which can represent both a Record or a State
* Message
Expand All @@ -16,20 +18,20 @@ class DeserializationUtil {
* * entire serialized message string when message is a valid State Message
* * serialized AirbyteRecordMessage when message is a valid Record Message
*
* @param messageString the string to deserialize
* @param message the string to deserialize
* @return the [PartialAirbyteMessage] if the message is valid; throws a [RuntimeException] otherwise
*/
fun deserializeAirbyteMessage(
messageString: String?,
dataTransformer: StreamAwareDataTransformer,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🎉 love to see this

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

love how the doc string didn't even match the parameters...

message: String?,
): PartialAirbyteMessage {
// TODO: This makes some sketchy assumptions by deserializing either the whole message or only
// part of it, depending on the message type.
// Use JsonSubTypes and extend StdDeserializer to properly handle this.
// Make immutability a first class citizen in the PartialAirbyteMessage class.
val partial =
Jsons.tryDeserializeExact(messageString, PartialAirbyteMessage::class.java)
.orElseThrow { RuntimeException("Unable to deserialize PartialAirbyteMessage.") }
Jsons.tryDeserializeExact(message, PartialAirbyteMessage::class.java).orElseThrow {
RuntimeException("Unable to deserialize PartialAirbyteMessage.")
}

val msgType = partial.type
if (AirbyteMessage.Type.RECORD == msgType && partial.record?.data != null) {
Expand All @@ -50,7 +52,7 @@ class DeserializationUtil {
// usage.
partial.record?.data = null
} else if (AirbyteMessage.Type.STATE == msgType) {
partial.withSerialized(messageString)
partial.withSerialized(message)
} else {
throw RuntimeException(String.format("Unsupported message type: %s", msgType))
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1 +1 @@
version=0.30.2
version=0.30.3
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@ package io.airbyte.cdk.integrations.destination.async
import com.fasterxml.jackson.databind.JsonNode
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings
import io.airbyte.cdk.integrations.destination.async.buffers.BufferManager
import io.airbyte.cdk.integrations.destination.async.deser.DeserializationUtil
import io.airbyte.cdk.integrations.destination.async.deser.IdentityDataTransformer
import io.airbyte.cdk.integrations.destination.async.deser.AirbyteMessageDeserializer
import io.airbyte.cdk.integrations.destination.async.deser.StreamAwareDataTransformer
import io.airbyte.cdk.integrations.destination.async.function.DestinationFlushFunction
import io.airbyte.cdk.integrations.destination.async.model.PartialAirbyteMessage
Expand Down Expand Up @@ -125,7 +124,7 @@ class AsyncStreamConsumerTest {
private lateinit var outputRecordCollector: Consumer<AirbyteMessage>
private lateinit var flushFailure: FlushFailure
private lateinit var streamAwareDataTransformer: StreamAwareDataTransformer
private lateinit var deserializationUtil: DeserializationUtil
private lateinit var airbyteMessageDeserializer: AirbyteMessageDeserializer

@BeforeEach
@Suppress("UNCHECKED_CAST")
Expand All @@ -139,20 +138,18 @@ class AsyncStreamConsumerTest {
flushFunction = Mockito.mock(DestinationFlushFunction::class.java)
outputRecordCollector = Mockito.mock(Consumer::class.java) as Consumer<AirbyteMessage>
flushFailure = Mockito.mock(FlushFailure::class.java)
deserializationUtil = DeserializationUtil()
streamAwareDataTransformer = IdentityDataTransformer()
airbyteMessageDeserializer = AirbyteMessageDeserializer()
consumer =
AsyncStreamConsumer(
outputRecordCollector = outputRecordCollector,
onStart = onStart,
onClose = onClose,
flusher = flushFunction,
onFlush = flushFunction,
catalog = CATALOG,
bufferManager = BufferManager(),
flushFailure = flushFailure,
defaultNamespace = Optional.of("default_ns"),
dataTransformer = streamAwareDataTransformer,
deserializationUtil = deserializationUtil,
airbyteMessageDeserializer = airbyteMessageDeserializer,
workerPool = Executors.newFixedThreadPool(5),
)

Expand Down Expand Up @@ -330,9 +327,8 @@ class AsyncStreamConsumerTest {
val serializedAirbyteMessage = Jsons.serialize(airbyteMessage)
val airbyteRecordString = Jsons.serialize(PAYLOAD)
val partial =
deserializationUtil.deserializeAirbyteMessage(
airbyteMessageDeserializer.deserializeAirbyteMessage(
serializedAirbyteMessage,
streamAwareDataTransformer,
)
assertEquals(airbyteRecordString, partial.serialized)
}
Expand All @@ -357,9 +353,8 @@ class AsyncStreamConsumerTest {
val serializedAirbyteMessage = Jsons.serialize(airbyteMessage)
val airbyteRecordString = Jsons.serialize(payload)
val partial =
deserializationUtil.deserializeAirbyteMessage(
airbyteMessageDeserializer.deserializeAirbyteMessage(
serializedAirbyteMessage,
streamAwareDataTransformer,
)
assertEquals(airbyteRecordString, partial.serialized)
}
Expand All @@ -378,9 +373,8 @@ class AsyncStreamConsumerTest {
)
val serializedAirbyteMessage = Jsons.serialize(airbyteMessage)
val partial =
deserializationUtil.deserializeAirbyteMessage(
airbyteMessageDeserializer.deserializeAirbyteMessage(
serializedAirbyteMessage,
streamAwareDataTransformer,
)
assertEquals(emptyMap.toString(), partial.serialized)
}
Expand All @@ -393,9 +387,8 @@ class AsyncStreamConsumerTest {
assertThrows(
RuntimeException::class.java,
) {
deserializationUtil.deserializeAirbyteMessage(
airbyteMessageDeserializer.deserializeAirbyteMessage(
serializedAirbyteMessage,
streamAwareDataTransformer,
)
}
}
Expand All @@ -404,9 +397,8 @@ class AsyncStreamConsumerTest {
internal fun deserializeAirbyteMessageWithAirbyteState() {
val serializedAirbyteMessage = Jsons.serialize(STATE_MESSAGE1)
val partial =
deserializationUtil.deserializeAirbyteMessage(
airbyteMessageDeserializer.deserializeAirbyteMessage(
serializedAirbyteMessage,
streamAwareDataTransformer,
)
assertEquals(serializedAirbyteMessage, partial.serialized)
}
Expand All @@ -430,9 +422,8 @@ class AsyncStreamConsumerTest {
assertThrows(
RuntimeException::class.java,
) {
deserializationUtil.deserializeAirbyteMessage(
airbyteMessageDeserializer.deserializeAirbyteMessage(
serializedAirbyteMessage,
streamAwareDataTransformer,
)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import io.airbyte.cdk.integrations.destination.NamingConventionTransformer
import io.airbyte.cdk.integrations.destination.StreamSyncSummary
import io.airbyte.cdk.integrations.destination.async.AsyncStreamConsumer
import io.airbyte.cdk.integrations.destination.async.buffers.BufferManager
import io.airbyte.cdk.integrations.destination.async.deser.DeserializationUtil
import io.airbyte.cdk.integrations.destination.async.deser.AirbyteMessageDeserializer
import io.airbyte.cdk.integrations.destination.async.deser.IdentityDataTransformer
import io.airbyte.cdk.integrations.destination.async.deser.StreamAwareDataTransformer
import io.airbyte.cdk.integrations.destination.async.model.PartialAirbyteMessage
Expand Down Expand Up @@ -78,11 +78,10 @@ object JdbcBufferedConsumerFactory {
),
catalog,
BufferManager((Runtime.getRuntime().maxMemory() * 0.2).toLong()),
FlushFailure(),
Optional.ofNullable(defaultNamespace),
FlushFailure(),
Executors.newFixedThreadPool(2),
dataTransformer,
DeserializationUtil()
AirbyteMessageDeserializer(dataTransformer)
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import org.apache.commons.io.FileUtils
private val logger = KotlinLogging.logger {}

internal class AsyncFlush(
streamDescToWriteConfig: Map<StreamDescriptor, WriteConfig>,
private val streamDescToWriteConfig: Map<StreamDescriptor, WriteConfig>,
private val stagingOperations: StagingOperations?,
private val database: JdbcDatabase?,
private val catalog: ConfiguredAirbyteCatalog?,
Expand All @@ -41,8 +41,6 @@ internal class AsyncFlush(
override val optimalBatchSizeBytes: Long,
private val useDestinationsV2Columns: Boolean
) : DestinationFlushFunction {
private val streamDescToWriteConfig: Map<StreamDescriptor, WriteConfig> =
streamDescToWriteConfig

@Throws(Exception::class)
override fun flush(decs: StreamDescriptor, stream: Stream<PartialAirbyteMessage>) {
Expand Down
Loading
Loading