Skip to content

Commit

Permalink
Wrap default namespace in optional to avoid NPE (#36207)
Browse files: browse the repository at this point in the history
  • Loading branch information
jdpgrailsdev committed Mar 15, 2024
1 parent 0755321 commit cec938f
Show file tree
Hide file tree
Showing 7 changed files with 76 additions and 31 deletions.
Expand Up @@ -31,6 +31,7 @@ import java.util.concurrent.Executors
import java.util.concurrent.atomic.AtomicLong
import java.util.function.Consumer
import java.util.stream.Collectors
import kotlin.jvm.optionals.getOrNull
import org.slf4j.Logger
import org.slf4j.LoggerFactory

Expand All @@ -52,7 +53,7 @@ constructor(
private val catalog: ConfiguredAirbyteCatalog,
private val bufferManager: BufferManager,
private val flushFailure: FlushFailure,
private val defaultNamespace: String,
private val defaultNamespace: Optional<String>,
workerPool: ExecutorService,
private val dataTransformer: StreamAwareDataTransformer,
private val deserializationUtil: DeserializationUtil,
Expand Down Expand Up @@ -94,7 +95,7 @@ constructor(
flusher: DestinationFlushFunction,
catalog: ConfiguredAirbyteCatalog,
bufferManager: BufferManager,
defaultNamespace: String,
defaultNamespace: Optional<String>,
) : this(
outputRecordCollector,
onStart,
Expand All @@ -113,7 +114,7 @@ constructor(
flusher: DestinationFlushFunction,
catalog: ConfiguredAirbyteCatalog,
bufferManager: BufferManager,
defaultNamespace: String,
defaultNamespace: Optional<String>,
dataTransformer: StreamAwareDataTransformer,
) : this(
outputRecordCollector,
Expand All @@ -136,7 +137,7 @@ constructor(
flusher: DestinationFlushFunction,
catalog: ConfiguredAirbyteCatalog,
bufferManager: BufferManager,
defaultNamespace: String,
defaultNamespace: Optional<String>,
workerPool: ExecutorService,
) : this(
outputRecordCollector,
Expand All @@ -161,7 +162,7 @@ constructor(
catalog: ConfiguredAirbyteCatalog,
bufferManager: BufferManager,
flushFailure: FlushFailure,
defaultNamespace: String,
defaultNamespace: Optional<String>,
) : this(
outputRecordCollector,
onStart,
Expand Down Expand Up @@ -206,7 +207,7 @@ constructor(
)
if (AirbyteMessage.Type.RECORD == message.type) {
if (Strings.isNullOrEmpty(message.record?.namespace)) {
message.record?.namespace = defaultNamespace
message.record?.namespace = defaultNamespace.getOrNull()
}
validateRecord(message)

Expand Down
Expand Up @@ -9,6 +9,7 @@ import io.airbyte.cdk.integrations.destination.async.partial_messages.PartialAir
import io.airbyte.cdk.integrations.destination.async.state.GlobalAsyncStateManager
import io.airbyte.protocol.models.v0.AirbyteMessage
import io.airbyte.protocol.models.v0.StreamDescriptor
import java.util.Optional
import java.util.concurrent.ConcurrentMap

/**
Expand All @@ -30,12 +31,12 @@ class BufferEnqueue(
fun addRecord(
message: PartialAirbyteMessage,
sizeInBytes: Int,
defaultNamespace: String,
defaultNamespace: Optional<String>,
) {
if (message.type == AirbyteMessage.Type.RECORD) {
handleRecord(message, sizeInBytes)
} else if (message.type == AirbyteMessage.Type.STATE) {
stateManager.trackState(message, sizeInBytes.toLong(), defaultNamespace)
stateManager.trackState(message, sizeInBytes.toLong(), defaultNamespace.orElse(""))
}
}

Expand Down
Expand Up @@ -34,6 +34,7 @@ import io.airbyte.protocol.models.v0.StreamDescriptor
import java.io.IOException
import java.math.BigDecimal
import java.time.Instant
import java.util.Optional
import java.util.concurrent.Executors
import java.util.concurrent.TimeUnit
import java.util.concurrent.TimeoutException
Expand Down Expand Up @@ -148,7 +149,7 @@ class AsyncStreamConsumerTest {
catalog = CATALOG,
bufferManager = BufferManager(),
flushFailure = flushFailure,
defaultNamespace = "default_ns",
defaultNamespace = Optional.of("default_ns"),
dataTransformer = streamAwareDataTransformer,
deserializationUtil = deserializationUtil,
workerPool = Executors.newFixedThreadPool(5),
Expand Down Expand Up @@ -268,7 +269,7 @@ class AsyncStreamConsumerTest {
CATALOG,
BufferManager((1024 * 10).toLong()),
flushFailure,
"default_ns",
Optional.of("default_ns"),
)
Mockito.`when`(flushFunction.optimalBatchSizeBytes).thenReturn(0L)

Expand Down
Expand Up @@ -12,6 +12,7 @@ import io.airbyte.protocol.models.v0.AirbyteMessage
import io.airbyte.protocol.models.v0.StreamDescriptor
import java.time.Instant
import java.time.temporal.ChronoUnit
import java.util.Optional
import org.junit.jupiter.api.Assertions
import org.junit.jupiter.api.Nested
import org.junit.jupiter.api.Test
Expand All @@ -36,10 +37,26 @@ class BufferDequeueTest {
val enqueue = bufferManager.bufferEnqueue
val dequeue = bufferManager.bufferDequeue

enqueue.addRecord(RECORD_MSG_20_BYTES, RECORD_SIZE_20_BYTES, DEFAULT_NAMESPACE)
enqueue.addRecord(RECORD_MSG_20_BYTES, RECORD_SIZE_20_BYTES, DEFAULT_NAMESPACE)
enqueue.addRecord(RECORD_MSG_20_BYTES, RECORD_SIZE_20_BYTES, DEFAULT_NAMESPACE)
enqueue.addRecord(RECORD_MSG_20_BYTES, RECORD_SIZE_20_BYTES, DEFAULT_NAMESPACE)
enqueue.addRecord(
RECORD_MSG_20_BYTES,
RECORD_SIZE_20_BYTES,
Optional.of(DEFAULT_NAMESPACE)
)
enqueue.addRecord(
RECORD_MSG_20_BYTES,
RECORD_SIZE_20_BYTES,
Optional.of(DEFAULT_NAMESPACE)
)
enqueue.addRecord(
RECORD_MSG_20_BYTES,
RECORD_SIZE_20_BYTES,
Optional.of(DEFAULT_NAMESPACE)
)
enqueue.addRecord(
RECORD_MSG_20_BYTES,
RECORD_SIZE_20_BYTES,
Optional.of(DEFAULT_NAMESPACE)
)

// total size of records is 80, so we expect 50 to get us 2 records (prefer to
// under-pull records
Expand All @@ -64,9 +81,21 @@ class BufferDequeueTest {
val enqueue = bufferManager.bufferEnqueue
val dequeue = bufferManager.bufferDequeue

enqueue.addRecord(RECORD_MSG_20_BYTES, RECORD_SIZE_20_BYTES, DEFAULT_NAMESPACE)
enqueue.addRecord(RECORD_MSG_20_BYTES, RECORD_SIZE_20_BYTES, DEFAULT_NAMESPACE)
enqueue.addRecord(RECORD_MSG_20_BYTES, RECORD_SIZE_20_BYTES, DEFAULT_NAMESPACE)
enqueue.addRecord(
RECORD_MSG_20_BYTES,
RECORD_SIZE_20_BYTES,
Optional.of(DEFAULT_NAMESPACE)
)
enqueue.addRecord(
RECORD_MSG_20_BYTES,
RECORD_SIZE_20_BYTES,
Optional.of(DEFAULT_NAMESPACE)
)
enqueue.addRecord(
RECORD_MSG_20_BYTES,
RECORD_SIZE_20_BYTES,
Optional.of(DEFAULT_NAMESPACE)
)

try {
dequeue.take(STREAM_DESC, 60).use { take ->
Expand All @@ -83,8 +112,16 @@ class BufferDequeueTest {
val enqueue = bufferManager.bufferEnqueue
val dequeue = bufferManager.bufferDequeue

enqueue.addRecord(RECORD_MSG_20_BYTES, RECORD_SIZE_20_BYTES, DEFAULT_NAMESPACE)
enqueue.addRecord(RECORD_MSG_20_BYTES, RECORD_SIZE_20_BYTES, DEFAULT_NAMESPACE)
enqueue.addRecord(
RECORD_MSG_20_BYTES,
RECORD_SIZE_20_BYTES,
Optional.of(DEFAULT_NAMESPACE)
)
enqueue.addRecord(
RECORD_MSG_20_BYTES,
RECORD_SIZE_20_BYTES,
Optional.of(DEFAULT_NAMESPACE)
)

try {
dequeue.take(STREAM_DESC, Long.MAX_VALUE).use { take ->
Expand All @@ -102,13 +139,17 @@ class BufferDequeueTest {
val enqueue = bufferManager.bufferEnqueue
val dequeue = bufferManager.bufferDequeue

enqueue.addRecord(RECORD_MSG_20_BYTES, RECORD_SIZE_20_BYTES, DEFAULT_NAMESPACE)
enqueue.addRecord(RECORD_MSG_20_BYTES, RECORD_SIZE_20_BYTES, DEFAULT_NAMESPACE)
enqueue.addRecord(RECORD_MSG_20_BYTES, RECORD_SIZE_20_BYTES, Optional.of(DEFAULT_NAMESPACE))
enqueue.addRecord(RECORD_MSG_20_BYTES, RECORD_SIZE_20_BYTES, Optional.of(DEFAULT_NAMESPACE))

val secondStream = StreamDescriptor().withName("stream_2")
val recordFromSecondStream = Jsons.clone(RECORD_MSG_20_BYTES)
recordFromSecondStream.record?.withStream(secondStream.name)
enqueue.addRecord(recordFromSecondStream, RECORD_SIZE_20_BYTES, DEFAULT_NAMESPACE)
enqueue.addRecord(
recordFromSecondStream,
RECORD_SIZE_20_BYTES,
Optional.of(DEFAULT_NAMESPACE)
)

Assertions.assertEquals(60, dequeue.totalGlobalQueueSizeBytes)

Expand Down Expand Up @@ -157,15 +198,15 @@ class BufferDequeueTest {
)

// allocate a block for new stream
enqueue.addRecord(RECORD_MSG_20_BYTES, RECORD_SIZE_20_BYTES, DEFAULT_NAMESPACE)
enqueue.addRecord(RECORD_MSG_20_BYTES, RECORD_SIZE_20_BYTES, Optional.of(DEFAULT_NAMESPACE))
Assertions.assertEquals(
2 * GlobalMemoryManager.BLOCK_SIZE_BYTES,
memoryManager.getCurrentMemoryBytes(),
)

enqueue.addRecord(RECORD_MSG_20_BYTES, RECORD_SIZE_20_BYTES, DEFAULT_NAMESPACE)
enqueue.addRecord(RECORD_MSG_20_BYTES, RECORD_SIZE_20_BYTES, DEFAULT_NAMESPACE)
enqueue.addRecord(RECORD_MSG_20_BYTES, RECORD_SIZE_20_BYTES, DEFAULT_NAMESPACE)
enqueue.addRecord(RECORD_MSG_20_BYTES, RECORD_SIZE_20_BYTES, Optional.of(DEFAULT_NAMESPACE))
enqueue.addRecord(RECORD_MSG_20_BYTES, RECORD_SIZE_20_BYTES, Optional.of(DEFAULT_NAMESPACE))
enqueue.addRecord(RECORD_MSG_20_BYTES, RECORD_SIZE_20_BYTES, Optional.of(DEFAULT_NAMESPACE))

// no re-allocates as we haven't breached block size
Assertions.assertEquals(
Expand Down
Expand Up @@ -10,6 +10,7 @@ import io.airbyte.cdk.integrations.destination.async.partial_messages.PartialAir
import io.airbyte.cdk.integrations.destination.async.state.GlobalAsyncStateManager
import io.airbyte.protocol.models.v0.AirbyteMessage
import io.airbyte.protocol.models.v0.StreamDescriptor
import java.util.Optional
import java.util.concurrent.ConcurrentHashMap
import org.junit.jupiter.api.Assertions
import org.junit.jupiter.api.Test
Expand Down Expand Up @@ -41,7 +42,7 @@ class BufferEnqueueTest {
PartialAirbyteRecordMessage().withStream(streamName),
)

enqueue.addRecord(record, RECORD_SIZE_20_BYTES, DEFAULT_NAMESPACE)
enqueue.addRecord(record, RECORD_SIZE_20_BYTES, Optional.of(DEFAULT_NAMESPACE))
Assertions.assertEquals(1, streamToBuffer[stream]!!.size())
Assertions.assertEquals(20L, streamToBuffer[stream]!!.currentMemoryUsage)
}
Expand All @@ -68,8 +69,8 @@ class BufferEnqueueTest {
PartialAirbyteRecordMessage().withStream(streamName),
)

enqueue.addRecord(record, RECORD_SIZE_20_BYTES, DEFAULT_NAMESPACE)
enqueue.addRecord(record, RECORD_SIZE_20_BYTES, DEFAULT_NAMESPACE)
enqueue.addRecord(record, RECORD_SIZE_20_BYTES, Optional.of(DEFAULT_NAMESPACE))
enqueue.addRecord(record, RECORD_SIZE_20_BYTES, Optional.of(DEFAULT_NAMESPACE))
Assertions.assertEquals(2, streamToBuffer[stream]!!.size())
Assertions.assertEquals(40, streamToBuffer[stream]!!.currentMemoryUsage)
}
Expand Down
Expand Up @@ -79,7 +79,7 @@ public static SerializedAirbyteMessageConsumer createAsync(final Consumer<Airbyt
new JdbcInsertFlushFunction(recordWriterFunction(database, sqlOperations, writeConfigs, catalog)),
catalog,
new BufferManager((long) (Runtime.getRuntime().maxMemory() * 0.2)),
defaultNamespace,
Optional.ofNullable(defaultNamespace),
Executors.newFixedThreadPool(2));
}

Expand Down
Expand Up @@ -211,7 +211,7 @@ public SerializedAirbyteMessageConsumer createAsync() {
flusher,
catalog,
new BufferManager(getMemoryLimit(bufferMemoryLimit)),
defaultNamespace);
Optional.ofNullable(defaultNamespace));
}

private static long getMemoryLimit(final Optional<Long> bufferMemoryLimit) {
Expand Down

0 comments on commit cec938f

Please sign in to comment.