Skip to content

Commit

Permalink
Possible fix to problem with suspended transaction functions and spri…
Browse files Browse the repository at this point in the history
…ng transaction manager
  • Loading branch information
Tapac committed May 10, 2020
1 parent 0ea4082 commit 7762500
Show file tree
Hide file tree
Showing 6 changed files with 117 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,18 @@ class ThreadLocalTransactionManager(private val db: Database,
outerTransaction = outerTransaction
)
)).apply {
threadLocal.set(this)
bindTransactionToThread(this)
}

override fun currentOrNull(): Transaction? = threadLocal.get()

override fun bindTransactionToThread(transaction: Transaction?) {
if (transaction != null)
threadLocal.set(transaction)
else
threadLocal.remove()
}

private class ThreadLocalTransaction(
override val db: Database,
override val transactionIsolation: Int,
Expand Down Expand Up @@ -206,11 +213,11 @@ fun <T> inTopLevelTransaction(
}

internal fun <T> keepAndRestoreTransactionRefAfterRun(db: Database? = null, block: () -> T): T {
val manager = db.transactionManager as? ThreadLocalTransactionManager
val currentTransaction = manager?.currentOrNull()
val manager = db.transactionManager
val currentTransaction = manager.currentOrNull()
return try {
block()
} finally {
manager?.threadLocal?.set(currentTransaction)
manager.bindTransactionToThread(currentTransaction)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ private object NotInitializedManager : TransactionManager {
override fun newTransaction(isolation: Int, outerTransaction: Transaction?): Transaction = error("Please call Database.connect() before using this code")

override fun currentOrNull(): Transaction? = error("Please call Database.connect() before using this code")

override fun bindTransactionToThread(transaction: Transaction?) {
error("Please call Database.connect() before using this code")
}
}

interface TransactionManager {
Expand All @@ -48,6 +52,8 @@ interface TransactionManager {

fun currentOrNull(): Transaction?

fun bindTransactionToThread(transaction: Transaction?)

companion object {

private val managers = ConcurrentLinkedDeque<TransactionManager>().apply {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,26 +18,19 @@ internal class TransactionScope(internal val tx: Transaction, parent: CoroutineC
companion object : CoroutineContext.Key<TransactionScope>
}

internal class TransactionCoroutineElement(val newTransaction: Transaction, manager: TransactionManager) : ThreadContextElement<TransactionContext> {
internal class TransactionCoroutineElement(val newTransaction: Transaction, val manager: TransactionManager) : ThreadContextElement<TransactionContext> {
override val key: CoroutineContext.Key<TransactionCoroutineElement> = Companion
private val tlManager = manager as? ThreadLocalTransactionManager

override fun updateThreadContext(context: CoroutineContext): TransactionContext {
val currentTransaction = TransactionManager.currentOrNull()
val currentManager = currentTransaction?.db?.transactionManager
tlManager?.let {
it.threadLocal.set(newTransaction)
TransactionManager.resetCurrent(it)
}
manager.bindTransactionToThread(newTransaction)
TransactionManager.resetCurrent(manager)
return TransactionContext(currentManager, currentTransaction)
}

override fun restoreThreadContext(context: CoroutineContext, oldState: TransactionContext) {

if (oldState.transaction == null)
tlManager?.threadLocal?.remove()
else
tlManager?.threadLocal?.set(oldState.transaction)
manager.bindTransactionToThread(oldState.transaction)
TransactionManager.resetCurrent(oldState.manager)
}

Expand Down Expand Up @@ -75,7 +68,7 @@ private fun Transaction.commitInAsync() {
val currentTransaction = TransactionManager.currentOrNull()
try {
val temporaryManager = this.db.transactionManager
(temporaryManager as? ThreadLocalTransactionManager)?.threadLocal?.set(this)
temporaryManager.bindTransactionToThread(this)
TransactionManager.resetCurrent(temporaryManager)
try {
commit()
Expand All @@ -95,7 +88,7 @@ private fun Transaction.commitInAsync() {
}
} finally {
val transactionManager = currentTransaction?.db?.transactionManager
(transactionManager as? ThreadLocalTransactionManager)?.threadLocal?.set(currentTransaction)
transactionManager?.bindTransactionToThread(currentTransaction)
TransactionManager.resetCurrent(transactionManager)
}
}
Expand Down
2 changes: 2 additions & 0 deletions spring-transaction/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ dependencies {
implementation("org.jetbrains.kotlinx", "kotlinx-coroutines-core", Versions.kotlinCoroutines)

testImplementation(project(":exposed-dao"))
testImplementation(project(":exposed-tests"))
testImplementation(kotlin("test-junit"))
testImplementation("org.jetbrains.kotlinx","kotlinx-coroutines-debug", Versions.kotlinCoroutines)
testImplementation("org.springframework", "spring-test", Versions.springFramework)
testImplementation("org.slf4j", "slf4j-log4j12", "1.7.26")
testImplementation("log4j", "log4j", "1.2.17")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,18 @@ class SpringTransactionManager(private val _dataSource: DataSource,
}

override fun currentOrNull(): Transaction? = TransactionSynchronizationManager.getResource(this) as Transaction?
override fun bindTransactionToThread(transaction: Transaction?) {
if (transaction != null) {
bindResourceForSure(this, transaction)
} else {
TransactionSynchronizationManager.unbindResourceIfPossible(this)
}
}

private fun bindResourceForSure(key: Any, value: Any) {
TransactionSynchronizationManager.unbindResourceIfPossible(key)
TransactionSynchronizationManager.bindResource(key, value)
}

private inner class SpringTransaction(
override val connection: ExposedConnection<*>,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package org.jetbrains.exposed.spring

import kotlinx.coroutines.*
import kotlinx.coroutines.debug.junit4.CoroutinesTimeout
import org.jetbrains.exposed.sql.*
import org.jetbrains.exposed.sql.tests.utils.RepeatableTest
import org.jetbrains.exposed.sql.transactions.experimental.suspendedTransactionAsync
import org.jetbrains.exposed.sql.transactions.transaction
import org.junit.Rule
import org.junit.Test
import org.springframework.test.annotation.Commit
import org.springframework.transaction.annotation.Transactional
import kotlin.test.assertEquals


open class SpringCoroutineTest : SpringTransactionTestBase() {

@Rule
@JvmField
val timeout = CoroutinesTimeout.seconds(60)

object Testing : Table("COROUTINE_TESTING") {
val id = integer("id").autoIncrement() // Column<Int>

override val primaryKey = PrimaryKey(id)
}

@RepeatableTest(times = 5)
@Test @Transactional @Commit
open fun testNestedCoroutineTransaction() {
try {
SchemaUtils.create(Testing)

val mainJob = GlobalScope.async {

val results = (1..5).map { indx ->
suspendedTransactionAsync(Dispatchers.IO) {
Testing.insert { }
indx
}
}.awaitAll()

assertEquals(15, results.sum())
}

while (!mainJob.isCompleted) Thread.sleep(100)
mainJob.getCompletionExceptionOrNull()?.let { throw it }

transaction {
assertEquals(5L, Testing.selectAll().count())
}
} finally {
SchemaUtils.drop(Testing)
}
/*withTables(Testing) {
val mainJob = GlobalScope.async {
val job = launch(Dispatchers.IO) {
newSuspendedTransaction(db = db) {
Testing.insert {}
suspendedTransaction {
assertEquals(1, Testing.select { Testing.id.eq(1) }.singleOrNull()?.getOrNull(Testing.id))
}
}
}
job.join()
val result = newSuspendedTransaction(Dispatchers.Default, db = db) {
Testing.select { Testing.id.eq(1) }.single()[Testing.id]
}
kotlin.test.assertEquals(1, result)
}
while (!mainJob.isCompleted) Thread.sleep(100)
mainJob.getCompletionExceptionOrNull()?.let { throw it }
}*/
}
}

0 comments on commit 7762500

Please sign in to comment.