Skip to content

Commit

Permalink
minor changes for destination-snowflake
Browse files Browse the repository at this point in the history
  • Loading branch information
stephane-airbyte committed May 1, 2024
1 parent 0f51cbd commit ac8e4ba
Show file tree
Hide file tree
Showing 12 changed files with 165 additions and 75 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ import java.util.function.Consumer
import java.util.function.Function
import java.util.stream.Stream
import java.util.stream.StreamSupport
import org.slf4j.Logger
import org.slf4j.LoggerFactory

/** Database object for interacting with a JDBC connection. */
abstract class JdbcDatabase(protected val sourceOperations: JdbcCompatibleSourceOperations<*>?) :
Expand Down Expand Up @@ -211,6 +213,7 @@ abstract class JdbcDatabase(protected val sourceOperations: JdbcCompatibleSource
abstract fun <T> executeMetadataQuery(query: Function<DatabaseMetaData?, T>): T

companion object {
private val LOGGER: Logger = LoggerFactory.getLogger(JdbcDatabase::class.java)
/**
* Map records returned in a result set. It is an "unsafe" stream because the stream must be
* manually closed. Otherwise, there will be a database connection leak.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ internal constructor(
(destination != null) xor (source != null),
"can only pass in a destination or a source"
)
threadCreationStack.set("main")
this.cliParser = cliParser
this.outputRecordCollector = outputRecordCollector
// integration iface covers the commands that are the same for both source and destination.
Expand Down Expand Up @@ -199,13 +200,9 @@ internal constructor(
val catalog =
parseConfig(parsed.getCatalogPath(), ConfiguredAirbyteCatalog::class.java)!!

try {
destination!!
.getSerializedMessageConsumer(config, catalog, outputRecordCollector)
.use { consumer -> consumeWriteStream(consumer!!) }
} finally {
stopOrphanedThreads()
}
destination!!
.getSerializedMessageConsumer(config, catalog, outputRecordCollector)
.use { consumer -> consumeWriteStream(consumer!!) }
}
}
} catch (e: Exception) {
Expand Down Expand Up @@ -242,6 +239,8 @@ internal constructor(
return
}
throw e
} finally {
stopOrphanedThreads()
}

LOGGER.info("Completed integration: {}", integration.javaClass.name)
Expand Down Expand Up @@ -341,6 +340,12 @@ internal constructor(

companion object {
private val LOGGER: Logger = LoggerFactory.getLogger(IntegrationRunner::class.java)
private val threadCreationStack: InheritableThreadLocal<String> =
object : InheritableThreadLocal<String>() {
override fun childValue(parentValue: String): String {
return Thread.currentThread().stackTrace.joinToString("\n at ")
}
}

const val TYPE_AND_DEDUPE_THREAD_NAME: String = "type-and-dedupe"

Expand Down Expand Up @@ -449,8 +454,13 @@ internal constructor(
.daemon(true)
.build()
)
val getMethod = ThreadLocal::class.java.getDeclaredMethod("get", Thread::class.java)
getMethod.isAccessible = true
for (runningThread in runningThreads) {
val str = "Active non-daemon thread: " + dumpThread(runningThread)
val str =
"Active non-daemon thread: " +
dumpThread(runningThread) +
"\ncreationStack=${getMethod.invoke(threadCreationStack, runningThread)}"
LOGGER.warn(str)
// even though the main thread is already shutting down, we still leave some
// chances to the children
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,18 +46,18 @@ class LoggingInvocationInterceptor : InvocationInterceptor {
ExtensionContext::class.java
) == null
) {
LOGGER!!.error(
LOGGER.error(
"Junit LoggingInvocationInterceptor executing unknown interception point {}",
method.name
)
return method.invoke(proxy, *(args!!))
return method.invoke(proxy, *(args))
}
val invocation = args!![0] as InvocationInterceptor.Invocation<*>?
val invocationContext = args[1] as ReflectiveInvocationContext<*>?
val invocation = args[0] as InvocationInterceptor.Invocation<*>?
val invocationContext = args[1] as ReflectiveInvocationContext<*>
val extensionContext = args[2] as ExtensionContext?
val methodName = method.name
val logLineSuffix: String?
val methodMatcher = methodPattern!!.matcher(methodName)
val logLineSuffix: String
val methodMatcher = methodPattern.matcher(methodName)
if (methodName == "interceptDynamicTest") {
logLineSuffix =
"execution of DynamicTest %s".formatted(extensionContext!!.displayName)
Expand All @@ -66,12 +66,19 @@ class LoggingInvocationInterceptor : InvocationInterceptor {
"instance creation for %s".formatted(invocationContext!!.targetClass)
} else if (methodMatcher.matches()) {
val interceptedEvent = methodMatcher.group(1)
val methodRealClassName = invocationContext!!.executable!!.declaringClass.simpleName
val methodName = invocationContext.executable!!.name
val targetClassName = invocationContext!!.targetClass.simpleName
val methodDisplayName =
if (targetClassName == methodRealClassName) methodName
else "$methodName($methodRealClassName)"
logLineSuffix =
"execution of @%s method %s.%s".formatted(
interceptedEvent,
invocationContext!!.executable!!.declaringClass.simpleName,
invocationContext.executable!!.name
targetClassName,
methodDisplayName
)
TestContext.CURRENT_TEST_NAME.set("$targetClassName.$methodName")
} else {
logLineSuffix = "execution of unknown intercepted call %s".formatted(methodName)
}
Expand All @@ -81,15 +88,16 @@ class LoggingInvocationInterceptor : InvocationInterceptor {
try {
val timeout = getTimeout(invocationContext)
if (timeout != null) {
LOGGER!!.info(
LOGGER.info(
"Junit starting {} with timeout of {}",
logLineSuffix,
DurationFormatUtils.formatDurationWords(timeout.toMillis(), true, true)
)
Timer("TimeoutTimer-" + currentThread.name, true)
.schedule(timeoutTask, timeout.toMillis())
var timer =
Timer("TimeoutTimer-" + currentThread.name, true)
.schedule(timeoutTask, timeout.toMillis())
} else {
LOGGER!!.warn("Junit starting {} with no timeout", logLineSuffix)
LOGGER.warn("Junit starting {} with no timeout", logLineSuffix)
}
val retVal = invocation!!.proceed()
val elapsedMs = Duration.between(start, Instant.now()).toMillis()
Expand Down Expand Up @@ -136,7 +144,7 @@ class LoggingInvocationInterceptor : InvocationInterceptor {
}
}
val stackTrace = StringUtils.join(stackToDisplay, "\n ")
LOGGER!!.error(
LOGGER.error(
"Junit exception throw during {} after {}:\n{}",
logLineSuffix,
DurationFormatUtils.formatDurationWords(elapsedMs, true, true),
Expand All @@ -145,24 +153,29 @@ class LoggingInvocationInterceptor : InvocationInterceptor {
throw t1
} finally {
timeoutTask.cancel()
TestContext.CURRENT_TEST_NAME.set(null)
}
}

private class TimeoutInteruptor(private val parentThread: Thread?) : TimerTask() {
private class TimeoutInteruptor(private val parentThread: Thread) : TimerTask() {
@Volatile var wasTriggered: Boolean = false

override fun run() {
LOGGER.info(
"interrupting running task on ${parentThread.name}. Current Stacktrace is ${parentThread.stackTrace.asList()}"
)
wasTriggered = true
parentThread!!.interrupt()
parentThread.interrupt()
}

override fun cancel(): Boolean {
LOGGER.info("cancelling timer task on ${parentThread.name}")
return super.cancel()
}
}

companion object {
private val methodPattern: Pattern? = Pattern.compile("intercept(.*)Method")
private val methodPattern: Pattern = Pattern.compile("intercept(.*)Method")

private val PATTERN: Pattern =
Pattern.compile(
Expand Down Expand Up @@ -201,11 +214,11 @@ class LoggingInvocationInterceptor : InvocationInterceptor {
)
}

private fun getTimeout(invocationContext: ReflectiveInvocationContext<*>?): Duration? {
private fun getTimeout(invocationContext: ReflectiveInvocationContext<*>): Duration {
var timeout: Duration? = null
var m = invocationContext!!.executable
var m = invocationContext.executable
if (m is Method) {
var timeoutAnnotation: Timeout? = m.getAnnotation<Timeout?>(Timeout::class.java)
var timeoutAnnotation: Timeout? = m.getAnnotation(Timeout::class.java)
if (timeoutAnnotation == null) {
timeoutAnnotation =
invocationContext.targetClass.getAnnotation(Timeout::class.java)
Expand Down Expand Up @@ -328,9 +341,9 @@ class LoggingInvocationInterceptor : InvocationInterceptor {
}

companion object {
private val LOGGER: Logger? =
private val LOGGER: Logger =
LoggerFactory.getLogger(LoggingInvocationInterceptor::class.java)
private val JUNIT_METHOD_EXECUTION_TIMEOUT_PROPERTY_NAME: String? =
private val JUNIT_METHOD_EXECUTION_TIMEOUT_PROPERTY_NAME: String =
"JunitMethodExecutionTimeout"
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
/*
* Copyright (c) 2024 Airbyte, Inc., all rights reserved.
*/

package io.airbyte.cdk.extensions

class TestContext {
companion object {
val CURRENT_TEST_NAME: ThreadLocal<String?> = ThreadLocal()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,29 @@ abstract class JdbcDestinationHandler<DestinationState>(
return actualColumns == intendedColumns
}

protected open fun getDeleteStatesSql(
destinationStates: Map<StreamId, DestinationState>
): String {
return dslContext
.deleteFrom(table(quotedName(rawTableSchemaName, DESTINATION_STATE_TABLE_NAME)))
.where(
destinationStates.keys
.stream()
.map { streamId: StreamId ->
field(quotedName(DESTINATION_STATE_TABLE_COLUMN_NAME))
.eq(streamId.originalName)
.and(
field(quotedName(DESTINATION_STATE_TABLE_COLUMN_NAMESPACE))
.eq(streamId.originalNamespace)
)
}
.reduce(DSL.falseCondition()) { obj: Condition, arg2: Condition? ->
obj.or(arg2)
}
)
.getSQL(ParamType.INLINED)
}

@Throws(Exception::class)
override fun commitDestinationStates(destinationStates: Map<StreamId, DestinationState>) {
try {
Expand All @@ -408,25 +431,7 @@ abstract class JdbcDestinationHandler<DestinationState>(
}

// Delete all state records where the stream name+namespace match one of our states
val deleteStates =
dslContext
.deleteFrom(table(quotedName(rawTableSchemaName, DESTINATION_STATE_TABLE_NAME)))
.where(
destinationStates.keys
.stream()
.map { streamId: StreamId ->
field(quotedName(DESTINATION_STATE_TABLE_COLUMN_NAME))
.eq(streamId.originalName)
.and(
field(quotedName(DESTINATION_STATE_TABLE_COLUMN_NAMESPACE))
.eq(streamId.originalNamespace)
)
}
.reduce(DSL.falseCondition()) { obj: Condition, arg2: Condition? ->
obj.or(arg2)
}
)
.getSQL(ParamType.INLINED)
var deleteStates = getDeleteStatesSql(destinationStates)

// Reinsert all of our states
var insertStatesStep =
Expand Down Expand Up @@ -461,12 +466,17 @@ abstract class JdbcDestinationHandler<DestinationState>(
}
val insertStates = insertStatesStep.getSQL(ParamType.INLINED)

jdbcDatabase.executeWithinTransaction(listOf(deleteStates, insertStates))
executeWithinTransaction(listOf(deleteStates, insertStates))
} catch (e: Exception) {
LOGGER.warn("Failed to commit destination states", e)
}
}

@Throws(Exception::class)
protected open fun executeWithinTransaction(statements: List<String>) {
jdbcDatabase.executeWithinTransaction(statements)
}

/**
* Convert to the TYPE_NAME retrieved from [java.sql.DatabaseMetaData.getColumns]
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import com.fasterxml.jackson.databind.node.ObjectNode
import com.google.common.collect.ImmutableMap
import com.google.common.collect.Lists
import com.google.common.collect.Sets
import io.airbyte.cdk.extensions.TestContext
import io.airbyte.cdk.integrations.destination.NamingConventionTransformer
import io.airbyte.cdk.integrations.standardtest.destination.*
import io.airbyte.cdk.integrations.standardtest.destination.argproviders.DataArgumentsProvider
Expand Down Expand Up @@ -47,6 +48,7 @@ import io.airbyte.workers.helper.ConnectorConfigUpdater
import io.airbyte.workers.helper.EntrypointEnvChecker
import io.airbyte.workers.internal.AirbyteDestination
import io.airbyte.workers.internal.DefaultAirbyteDestination
import io.airbyte.workers.internal.DefaultAirbyteStreamFactory
import io.airbyte.workers.normalization.DefaultNormalizationRunner
import io.airbyte.workers.normalization.NormalizationRunner
import io.airbyte.workers.process.AirbyteIntegrationLauncher
Expand Down Expand Up @@ -1573,16 +1575,22 @@ abstract class DestinationAcceptanceTest {
protected val destination: AirbyteDestination
get() =
DefaultAirbyteDestination(
AirbyteIntegrationLauncher(
JOB_ID,
JOB_ATTEMPT,
imageName,
processFactory,
null,
null,
false,
EnvVariableFeatureFlags()
)
integrationLauncher =
AirbyteIntegrationLauncher(
JOB_ID,
JOB_ATTEMPT,
imageName,
processFactory,
null,
null,
false,
EnvVariableFeatureFlags()
),
streamFactory =
DefaultAirbyteStreamFactory(
DefaultAirbyteDestination.createContainerLogMdcBuilder()
.setLogPrefix("destination-(${TestContext.CURRENT_TEST_NAME.get()})")
)
)

@Throws(Exception::class)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class DefaultAirbyteDestination
constructor(
private val integrationLauncher: IntegrationLauncher,
private val streamFactory: AirbyteStreamFactory =
DefaultAirbyteStreamFactory(CONTAINER_LOG_MDC_BUILDER),
DefaultAirbyteStreamFactory(createContainerLogMdcBuilder()),
private val messageWriterFactory: AirbyteMessageBufferedWriterFactory =
DefaultAirbyteMessageBufferedWriterFactory(),
private val protocolSerializer: ProtocolSerializer = DefaultProtocolSerializer()
Expand Down Expand Up @@ -87,7 +87,7 @@ constructor(
destinationProcess!!.errorStream,
{ msg: String? -> LOGGER.error(msg) },
"airbyte-destination",
CONTAINER_LOG_MDC_BUILDER
createContainerLogMdcBuilder()
)

writer =
Expand Down Expand Up @@ -179,7 +179,7 @@ constructor(

companion object {
private val LOGGER: Logger = LoggerFactory.getLogger(DefaultAirbyteDestination::class.java)
val CONTAINER_LOG_MDC_BUILDER: MdcScope.Builder =
fun createContainerLogMdcBuilder(): MdcScope.Builder =
MdcScope.Builder()
.setLogPrefix("destination")
.setPrefixColor(LoggingHelper.Color.YELLOW_BACKGROUND)
Expand Down
Loading

0 comments on commit ac8e4ba

Please sign in to comment.