Skip to content

Commit

Permalink
debuggability improvements to the CDK
Browse files Browse the repository at this point in the history
  • Loading branch information
stephane-airbyte committed May 2, 2024
1 parent 914f044 commit 961d513
Show file tree
Hide file tree
Showing 9 changed files with 144 additions and 74 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@ import io.airbyte.protocol.models.v0.AirbyteMessage
import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog
import io.airbyte.validation.json.JsonSchemaValidator
import java.io.*
import java.lang.reflect.Method
import java.nio.charset.StandardCharsets
import java.nio.file.Path
import java.time.Instant
import java.util.*
import java.util.concurrent.*
import java.util.function.Consumer
import java.util.function.Predicate
import java.util.stream.Collectors
import org.apache.commons.lang3.ThreadUtils
import org.apache.commons.lang3.concurrent.BasicThreadFactory
import org.slf4j.Logger
Expand Down Expand Up @@ -84,6 +84,7 @@ internal constructor(
(destination != null) xor (source != null),
"can only pass in a destination or a source"
)
threadCreationInfo.set(ThreadCreationInfo())
this.cliParser = cliParser
this.outputRecordCollector = outputRecordCollector
// integration iface covers the commands that are the same for both source and destination.
Expand Down Expand Up @@ -189,17 +190,20 @@ internal constructor(
}
}
Command.WRITE -> {
val config = parseConfig(parsed.getConfigPath())
validateConfig(integration.spec().connectionSpecification, config, "WRITE")
// save config to singleton
DestinationConfig.Companion.initialize(
config,
(integration as Destination).isV2Destination
)
val catalog =
parseConfig(parsed.getCatalogPath(), ConfiguredAirbyteCatalog::class.java)!!

try {
val config = parseConfig(parsed.getConfigPath())
validateConfig(integration.spec().connectionSpecification, config, "WRITE")
// save config to singleton
DestinationConfig.Companion.initialize(
config,
(integration as Destination).isV2Destination
)
val catalog =
parseConfig(
parsed.getCatalogPath(),
ConfiguredAirbyteCatalog::class.java
)!!

destination!!
.getSerializedMessageConsumer(config, catalog, outputRecordCollector)
.use { consumer -> consumeWriteStream(consumer!!) }
Expand Down Expand Up @@ -339,11 +343,37 @@ internal constructor(
}
}

class ThreadCreationInfo {
val stack: List<StackTraceElement> = Thread.currentThread().stackTrace.asList()
val time: Instant = Instant.now()
override fun toString(): String {
return "creationStack=${stack.joinToString("\n ")}\ncreationTime=$time"
}
}

companion object {
private val LOGGER: Logger = LoggerFactory.getLogger(IntegrationRunner::class.java)
private val threadCreationInfo: InheritableThreadLocal<ThreadCreationInfo> =
object : InheritableThreadLocal<ThreadCreationInfo>() {
override fun childValue(parentValue: ThreadCreationInfo): ThreadCreationInfo {
return ThreadCreationInfo()
}
}

const val TYPE_AND_DEDUPE_THREAD_NAME: String = "type-and-dedupe"

// ThreadLocal.get(Thread) is private. So we open it and keep a reference to the
// opened method
private val getMethod: Method =
ThreadLocal::class.java.getDeclaredMethod("get", Thread::class.java).also {
it.isAccessible = true
}

@JvmStatic
fun getThreadCreationInfo(thread: Thread): ThreadCreationInfo {
return getMethod.invoke(threadCreationInfo, thread) as ThreadCreationInfo
}

/**
* Filters threads that should not be considered when looking for orphaned threads at
* shutdown of the integration runner.
Expand All @@ -353,11 +383,12 @@ internal constructor(
* active so long as the database connection pool is open.
*/
@VisibleForTesting
val ORPHANED_THREAD_FILTER: Predicate<Thread> = Predicate { runningThread: Thread ->
(runningThread.name != Thread.currentThread().name &&
!runningThread.isDaemon &&
TYPE_AND_DEDUPE_THREAD_NAME != runningThread.name)
}
private val orphanedThreadPredicates: MutableList<(Thread) -> Boolean> =
mutableListOf({ runningThread: Thread ->
(runningThread.name != Thread.currentThread().name &&
!runningThread.isDaemon &&
TYPE_AND_DEDUPE_THREAD_NAME != runningThread.name)
})

const val INTERRUPT_THREAD_DELAY_MINUTES: Int = 1
const val EXIT_THREAD_DELAY_MINUTES: Int = 2
Expand Down Expand Up @@ -398,6 +429,15 @@ internal constructor(
LOGGER.info("Finished buffered read of input stream")
}

@JvmStatic
fun addOrphanedThreadFilter(predicate: (Thread) -> (Boolean)) {
orphanedThreadPredicates.add(predicate)
}

fun filterOrphanedThread(thread: Thread): Boolean {
return orphanedThreadPredicates.all { it(thread) }
}

/**
* Stops any non-daemon threads that could block the JVM from exiting when the main thread
* is done.
Expand Down Expand Up @@ -425,11 +465,7 @@ internal constructor(
) {
val currentThread = Thread.currentThread()

val runningThreads =
ThreadUtils.getAllThreads()
.stream()
.filter(ORPHANED_THREAD_FILTER)
.collect(Collectors.toList())
val runningThreads = ThreadUtils.getAllThreads().filter(::filterOrphanedThread).toList()
if (runningThreads.isNotEmpty()) {
LOGGER.warn(
"""
Expand All @@ -450,7 +486,10 @@ internal constructor(
.build()
)
for (runningThread in runningThreads) {
val str = "Active non-daemon thread: " + dumpThread(runningThread)
val str =
"Active non-daemon thread: " +
dumpThread(runningThread) +
"\ncreationStack=${getThreadCreationInfo(runningThread)}"
LOGGER.warn(str)
// even though the main thread is already shutting down, we still leave some
// chances to the children
Expand Down
Original file line number Diff line number Diff line change
@@ -1 +1 @@
version=0.31.3
version=0.31.4
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ ${Jsons.serialize(message2)}""".toByteArray(
val runningThreads =
ThreadUtils.getAllThreads()
.stream()
.filter(IntegrationRunner.ORPHANED_THREAD_FILTER)
.filter(IntegrationRunner::filterOrphanedThread)
.collect(Collectors.toList())
// all threads should be interrupted
Assertions.assertEquals(listOf<Any>(), runningThreads)
Expand Down Expand Up @@ -505,7 +505,7 @@ ${Jsons.serialize(message2)}""".toByteArray(
val runningThreads =
ThreadUtils.getAllThreads()
.stream()
.filter(IntegrationRunner.ORPHANED_THREAD_FILTER)
.filter(IntegrationRunner::filterOrphanedThread)
.collect(Collectors.toList())
// a thread that refuses to be interrupted should remain
Assertions.assertEquals(1, runningThreads.size)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,18 +46,18 @@ class LoggingInvocationInterceptor : InvocationInterceptor {
ExtensionContext::class.java
) == null
) {
LOGGER!!.error(
LOGGER.error(
"Junit LoggingInvocationInterceptor executing unknown interception point {}",
method.name
)
return method.invoke(proxy, *(args!!))
return method.invoke(proxy, *(args))
}
val invocation = args!![0] as InvocationInterceptor.Invocation<*>?
val invocationContext = args[1] as ReflectiveInvocationContext<*>?
val invocation = args[0] as InvocationInterceptor.Invocation<*>?
val invocationContext = args[1] as ReflectiveInvocationContext<*>
val extensionContext = args[2] as ExtensionContext?
val methodName = method.name
val logLineSuffix: String?
val methodMatcher = methodPattern!!.matcher(methodName)
val logLineSuffix: String
val methodMatcher = methodPattern.matcher(methodName)
if (methodName == "interceptDynamicTest") {
logLineSuffix =
"execution of DynamicTest %s".formatted(extensionContext!!.displayName)
Expand All @@ -66,12 +66,19 @@ class LoggingInvocationInterceptor : InvocationInterceptor {
"instance creation for %s".formatted(invocationContext!!.targetClass)
} else if (methodMatcher.matches()) {
val interceptedEvent = methodMatcher.group(1)
val methodRealClassName = invocationContext!!.executable!!.declaringClass.simpleName
val methodName = invocationContext.executable!!.name
val targetClassName = invocationContext!!.targetClass.simpleName
val methodDisplayName =
if (targetClassName == methodRealClassName) methodName
else "$methodName($methodRealClassName)"
logLineSuffix =
"execution of @%s method %s.%s".formatted(
interceptedEvent,
invocationContext!!.executable!!.declaringClass.simpleName,
invocationContext.executable!!.name
targetClassName,
methodDisplayName
)
TestContext.CURRENT_TEST_NAME.set("$targetClassName.$methodName")
} else {
logLineSuffix = "execution of unknown intercepted call %s".formatted(methodName)
}
Expand All @@ -81,15 +88,15 @@ class LoggingInvocationInterceptor : InvocationInterceptor {
try {
val timeout = getTimeout(invocationContext)
if (timeout != null) {
LOGGER!!.info(
LOGGER.info(
"Junit starting {} with timeout of {}",
logLineSuffix,
DurationFormatUtils.formatDurationWords(timeout.toMillis(), true, true)
)
Timer("TimeoutTimer-" + currentThread.name, true)
.schedule(timeoutTask, timeout.toMillis())
} else {
LOGGER!!.warn("Junit starting {} with no timeout", logLineSuffix)
LOGGER.warn("Junit starting {} with no timeout", logLineSuffix)
}
val retVal = invocation!!.proceed()
val elapsedMs = Duration.between(start, Instant.now()).toMillis()
Expand Down Expand Up @@ -136,7 +143,7 @@ class LoggingInvocationInterceptor : InvocationInterceptor {
}
}
val stackTrace = StringUtils.join(stackToDisplay, "\n ")
LOGGER!!.error(
LOGGER.error(
"Junit exception throw during {} after {}:\n{}",
logLineSuffix,
DurationFormatUtils.formatDurationWords(elapsedMs, true, true),
Expand All @@ -145,24 +152,29 @@ class LoggingInvocationInterceptor : InvocationInterceptor {
throw t1
} finally {
timeoutTask.cancel()
TestContext.CURRENT_TEST_NAME.set(null)
}
}

private class TimeoutInteruptor(private val parentThread: Thread?) : TimerTask() {
private class TimeoutInteruptor(private val parentThread: Thread) : TimerTask() {
@Volatile var wasTriggered: Boolean = false

override fun run() {
LOGGER.info(
"interrupting running task on ${parentThread.name}. Current Stacktrace is ${parentThread.stackTrace.asList()}"
)
wasTriggered = true
parentThread!!.interrupt()
parentThread.interrupt()
}

override fun cancel(): Boolean {
LOGGER.info("cancelling timer task on ${parentThread.name}")
return super.cancel()
}
}

companion object {
private val methodPattern: Pattern? = Pattern.compile("intercept(.*)Method")
private val methodPattern: Pattern = Pattern.compile("intercept(.*)Method")

private val PATTERN: Pattern =
Pattern.compile(
Expand Down Expand Up @@ -201,11 +213,11 @@ class LoggingInvocationInterceptor : InvocationInterceptor {
)
}

private fun getTimeout(invocationContext: ReflectiveInvocationContext<*>?): Duration? {
private fun getTimeout(invocationContext: ReflectiveInvocationContext<*>): Duration {
var timeout: Duration? = null
var m = invocationContext!!.executable
var m = invocationContext.executable
if (m is Method) {
var timeoutAnnotation: Timeout? = m.getAnnotation<Timeout?>(Timeout::class.java)
var timeoutAnnotation: Timeout? = m.getAnnotation(Timeout::class.java)
if (timeoutAnnotation == null) {
timeoutAnnotation =
invocationContext.targetClass.getAnnotation(Timeout::class.java)
Expand Down Expand Up @@ -328,9 +340,9 @@ class LoggingInvocationInterceptor : InvocationInterceptor {
}

companion object {
private val LOGGER: Logger? =
private val LOGGER: Logger =
LoggerFactory.getLogger(LoggingInvocationInterceptor::class.java)
private val JUNIT_METHOD_EXECUTION_TIMEOUT_PROPERTY_NAME: String? =
private val JUNIT_METHOD_EXECUTION_TIMEOUT_PROPERTY_NAME: String =
"JunitMethodExecutionTimeout"
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
/*
* Copyright (c) 2024 Airbyte, Inc., all rights reserved.
*/

package io.airbyte.cdk.extensions

object TestContext {
val CURRENT_TEST_NAME: ThreadLocal<String?> = ThreadLocal()
}
Original file line number Diff line number Diff line change
Expand Up @@ -1469,7 +1469,7 @@ abstract class DestinationAcceptanceTest {
}

/** Whether the destination should be tested against different namespaces. */
protected open fun supportNamespaceTest(): Boolean {
open protected fun supportNamespaceTest(): Boolean {
return false
}

Expand Down Expand Up @@ -1571,19 +1571,21 @@ abstract class DestinationAcceptanceTest {
}

protected val destination: AirbyteDestination
get() =
DefaultAirbyteDestination(
AirbyteIntegrationLauncher(
JOB_ID,
JOB_ATTEMPT,
imageName,
processFactory,
null,
null,
false,
EnvVariableFeatureFlags()
)
get() {
return DefaultAirbyteDestination(
integrationLauncher =
AirbyteIntegrationLauncher(
JOB_ID,
JOB_ATTEMPT,
imageName,
processFactory,
null,
null,
false,
EnvVariableFeatureFlags()
)
)
}

@Throws(Exception::class)
protected fun runSyncAndVerifyStateOutput(
Expand Down
Loading

0 comments on commit 961d513

Please sign in to comment.