Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug cancellation parMap #2579

Merged
merged 2 commits into from
Nov 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ public final class arrow/fx/coroutines/NamedThreadFactory : java/util/concurrent

public final class arrow/fx/coroutines/Predef_testKt {
public static final fun assertThrowable (Lkotlin/jvm/functions/Function0;)Ljava/lang/Throwable;
public static final fun awaitExitCase (Lkotlinx/coroutines/CompletableDeferred;Lkotlinx/coroutines/CompletableDeferred;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
public static final fun awaitExitCase (Lkotlinx/coroutines/channels/Channel;Lkotlinx/coroutines/CompletableDeferred;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
public static final fun charRange (Lio/kotest/property/Arb$Companion;)Lio/kotest/property/Arb;
public static final fun either (Larrow/core/Either;)Lio/kotest/matchers/Matcher;
public static final fun either (Lio/kotest/property/Arb$Companion;Lio/kotest/property/Arb;Lio/kotest/property/Arb;)Lio/kotest/property/Arb;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,19 @@ import io.kotest.property.arbitrary.list
import io.kotest.property.arbitrary.long
import io.kotest.property.arbitrary.map
import io.kotest.property.arbitrary.string
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.asFlow
import kotlin.coroutines.Continuation
import kotlin.coroutines.CoroutineContext
import kotlin.coroutines.intrinsics.COROUTINE_SUSPENDED
import kotlin.coroutines.intrinsics.intercepted
import kotlin.coroutines.intrinsics.suspendCoroutineUninterceptedOrReturn
import kotlin.coroutines.resume
import kotlin.coroutines.startCoroutine
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.awaitCancellation
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.asFlow
import kotlinx.coroutines.flow.buffer
import kotlinx.coroutines.flow.channelFlow
import kotlinx.coroutines.flow.emptyFlow
Expand Down Expand Up @@ -217,3 +219,15 @@ public fun <A> either(e: Either<Throwable, A>): Matcher<Either<Throwable, A>> =
is Either.Right -> equalityMatcher(e).test(value)
}
}

public suspend fun <A> awaitExitCase(send: Channel<Unit>, exit: CompletableDeferred<ExitCase>): A =
guaranteeCase({
send.receive()
awaitCancellation()
}) { ex -> exit.complete(ex) }

public suspend fun <A> awaitExitCase(start: CompletableDeferred<Unit>, exit: CompletableDeferred<ExitCase>): A =
guaranteeCase({
start.complete(Unit)
awaitCancellation()
}) { ex -> exit.complete(ex) }
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import kotlinx.coroutines.coroutineScope
import kotlin.coroutines.ContinuationInterceptor
import kotlin.coroutines.CoroutineContext
import kotlin.coroutines.EmptyCoroutineContext
import kotlinx.coroutines.awaitAll

/**
* Runs [fa], [fb] in parallel on [Dispatchers.Default] and combines their results using the provided function.
Expand Down Expand Up @@ -76,9 +77,10 @@ public suspend inline fun <A, B, C> parZip(
crossinline fb: suspend CoroutineScope.() -> B,
crossinline f: suspend CoroutineScope.(A, B) -> C
): C = coroutineScope {
val a = async(ctx) { fa() }
val b = async(ctx) { fb() }
f(a.await(), b.await())
val faa = async(ctx) { fa() }
val fbb = async(ctx) { fb() }
val (a, b) = awaitAll(faa, fbb)
f(a as A, b as B)
}

/**
Expand Down Expand Up @@ -155,10 +157,11 @@ public suspend inline fun <A, B, C, D> parZip(
crossinline fc: suspend CoroutineScope.() -> C,
crossinline f: suspend CoroutineScope.(A, B, C) -> D
): D = coroutineScope {
val a = async(ctx) { fa() }
val b = async(ctx) { fb() }
val c = async(ctx) { fc() }
f(a.await(), b.await(), c.await())
val faa = async(ctx) { fa() }
val fbb = async(ctx) { fb() }
val fcc = async(ctx) { fc() }
val (a, b, c) = awaitAll(faa, fbb, fcc)
f(a as A, b as B, c as C)
}

/**
Expand Down Expand Up @@ -242,11 +245,12 @@ public suspend inline fun <A, B, C, D, E> parZip(
crossinline fd: suspend CoroutineScope.() -> D,
crossinline f: suspend CoroutineScope.(A, B, C, D) -> E
): E = coroutineScope {
val a = async(ctx) { fa() }
val b = async(ctx) { fb() }
val c = async(ctx) { fc() }
val d = async(ctx) { fd() }
f(a.await(), b.await(), c.await(), d.await())
val faa = async(ctx) { fa() }
val fbb = async(ctx) { fb() }
val fcc = async(ctx) { fc() }
val fdd = async(ctx) { fd() }
val (a, b, c, d) = awaitAll(faa, fbb, fcc, fdd)
f(a as A, b as B, c as C, d as D)
}

/**
Expand Down Expand Up @@ -337,12 +341,13 @@ public suspend inline fun <A, B, C, D, E, F> parZip(
crossinline fe: suspend CoroutineScope.() -> E,
crossinline f: suspend CoroutineScope.(A, B, C, D, E) -> F
): F = coroutineScope {
val a = async(ctx) { fa() }
val b = async(ctx) { fb() }
val c = async(ctx) { fc() }
val d = async(ctx) { fd() }
val e = async(ctx) { fe() }
f(a.await(), b.await(), c.await(), d.await(), e.await())
val faa = async(ctx) { fa() }
val fbb = async(ctx) { fb() }
val fcc = async(ctx) { fc() }
val fdd = async(ctx) { fd() }
val fee = async(ctx) { fe() }
val (a, b, c, d, e) = awaitAll(faa, fbb, fcc, fdd, fee)
f(a as A, b as B, c as C, d as D, e as E)
}

/**
Expand Down Expand Up @@ -439,13 +444,14 @@ public suspend inline fun <A, B, C, D, E, F, G> parZip(
crossinline ff: suspend CoroutineScope.() -> F,
crossinline f: suspend CoroutineScope.(A, B, C, D, E, F) -> G
): G = coroutineScope {
val a = async(ctx) { fa() }
val b = async(ctx) { fb() }
val c = async(ctx) { fc() }
val d = async(ctx) { fd() }
val e = async(ctx) { fe() }
val g = async(ctx) { ff() }
f(a.await(), b.await(), c.await(), d.await(), e.await(), g.await())
val faa = async(ctx) { fa() }
val fbb = async(ctx) { fb() }
val fcc = async(ctx) { fc() }
val fdd = async(ctx) { fd() }
val fee = async(ctx) { fe() }
val fgg = async(ctx) { ff() }
val res = awaitAll(faa, fbb, fcc, fdd, fee, fgg)
f(res[0] as A, res[1] as B, res[2] as C, res[3] as D, res[4] as E, res[5] as F)
}

/**
Expand Down Expand Up @@ -548,14 +554,15 @@ public suspend inline fun <A, B, C, D, E, F, G, H> parZip(
crossinline fg: suspend CoroutineScope.() -> G,
crossinline f: suspend CoroutineScope.(A, B, C, D, E, F, G) -> H
): H = coroutineScope {
val a = async(ctx) { fa() }
val b = async(ctx) { fb() }
val c = async(ctx) { fc() }
val d = async(ctx) { fd() }
val e = async(ctx) { fe() }
val faa = async(ctx) { fa() }
val fbb = async(ctx) { fb() }
val fcc = async(ctx) { fc() }
val fdd = async(ctx) { fd() }
val fee = async(ctx) { fe() }
val fDef = async(ctx) { ff() }
val g = async(ctx) { fg() }
f(a.await(), b.await(), c.await(), d.await(), e.await(), fDef.await(), g.await())
val fgg = async(ctx) { fg() }
val res = awaitAll(faa, fbb, fcc, fdd, fee, fDef, fgg)
f(res[0] as A, res[1] as B, res[2] as C, res[3] as D, res[4] as E, res[5] as F, res[6] as G)
}

/**
Expand Down Expand Up @@ -663,13 +670,14 @@ public suspend inline fun <A, B, C, D, E, F, G, H, I> parZip(
crossinline fh: suspend CoroutineScope.() -> H,
crossinline f: suspend CoroutineScope.(A, B, C, D, E, F, G, H) -> I
): I = coroutineScope {
val a = async(ctx) { fa() }
val b = async(ctx) { fb() }
val c = async(ctx) { fc() }
val d = async(ctx) { fd() }
val e = async(ctx) { fe() }
val faa = async(ctx) { fa() }
val fbb = async(ctx) { fb() }
val fcc = async(ctx) { fc() }
val fdd = async(ctx) { fd() }
val fee = async(ctx) { fe() }
val fDef = async(ctx) { ff() }
val g = async(ctx) { fg() }
val h = async(ctx) { fh() }
f(a.await(), b.await(), c.await(), d.await(), e.await(), fDef.await(), g.await(), h.await())
val fgg = async(ctx) { fg() }
val fhh = async(ctx) { fh() }
val res = awaitAll(faa, fbb, fcc, fdd, fee, fDef, fgg, fhh)
f(res[0] as A, res[1] as B, res[2] as C, res[3] as D, res[4] as E, res[5] as F, res[6] as G, res[7] as H)
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,28 @@ class BracketCaseTest : ArrowFxSpec(
spec = {
"Immediate acquire bracketCase finishes successfully" {
checkAll(Arb.int(), Arb.int()) { a, b ->
var once = true
bracketCase(
acquire = { a },
use = { aa -> Pair(aa, b) },
release = { _, _ -> Unit }
release = { _, _ ->
require(once)
once = false
}
) shouldBe Pair(a, b)
}
}

"Suspended acquire bracketCase finishes successfully" {
checkAll(Arb.int(), Arb.int()) { a, b ->
var once = true
bracketCase(
acquire = { a.suspend() },
use = { aa -> Pair(aa, b) },
release = { _, _ -> Unit }
release = { _, _ ->
require(once)
once = false
}
) shouldBe Pair(a, b)
}
}
Expand Down Expand Up @@ -60,20 +68,28 @@ class BracketCaseTest : ArrowFxSpec(

"Immediate use bracketCase finishes successfully" {
checkAll(Arb.int(), Arb.int()) { a, b ->
var once = true
bracketCase(
acquire = { a },
use = { aa -> Pair(aa, b).suspend() },
release = { _, _ -> Unit }
release = { _, _ ->
require(once)
once = false
}
) shouldBe Pair(a, b)
}
}

"Suspended use bracketCase finishes successfully" {
checkAll(Arb.int(), Arb.int()) { a, b ->
var once = true
bracketCase(
acquire = { a },
use = { aa -> Pair(aa, b).suspend() },
release = { _, _ -> Unit }
release = { _, _ ->
require(once)
once = false
}
) shouldBe Pair(a, b)
}
}
Expand Down Expand Up @@ -309,7 +325,7 @@ class BracketCaseTest : ArrowFxSpec(
mVar.send(y)
},
use = { never<Unit>() },
release = { _, exitCase -> p.complete(exitCase) }
release = { _, exitCase -> require(p.complete(exitCase)) }
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class GuaranteeCaseTest : ArrowFxSpec(

val res = guaranteeCase(
fa = { i },
finalizer = { ex -> p.complete(ex) }
finalizer = { ex -> require(p.complete(ex)) }
)

p.await() shouldBe ExitCase.Completed
Expand All @@ -31,7 +31,7 @@ class GuaranteeCaseTest : ArrowFxSpec(
val attempted = Either.catch {
guaranteeCase<Int>(
fa = { throw e },
finalizer = { ex -> p.complete(ex) }
finalizer = { ex -> require(p.complete(ex)) }
)
}

Expand All @@ -50,7 +50,7 @@ class GuaranteeCaseTest : ArrowFxSpec(
start.complete(Unit)
never<Unit>()
},
finalizer = { ex -> p.complete(ex) }
finalizer = { ex -> require(p.complete(ex)) }
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,28 @@ import arrow.core.Either
import arrow.fx.coroutines.ArrowFxSpec
import arrow.fx.coroutines.Atomic
import arrow.fx.coroutines.ExitCase
import arrow.fx.coroutines.awaitExitCase
import arrow.fx.coroutines.guaranteeCase
import arrow.fx.coroutines.leftException
import arrow.fx.coroutines.never
import arrow.fx.coroutines.parZip
import arrow.fx.coroutines.throwable
import io.kotest.matchers.should
import io.kotest.matchers.shouldBe
import io.kotest.matchers.types.shouldBeInstanceOf
import io.kotest.matchers.types.shouldBeTypeOf
import io.kotest.property.Arb
import io.kotest.property.arbitrary.boolean
import io.kotest.property.arbitrary.int
import kotlinx.coroutines.CoroutineScope
import io.kotest.property.arbitrary.string
import kotlin.time.ExperimentalTime
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.async
import kotlinx.coroutines.awaitCancellation
import kotlinx.coroutines.channels.Channel

@OptIn(ExperimentalTime::class)
class ParMap2Test : ArrowFxSpec(
spec = {

Expand Down Expand Up @@ -51,8 +57,10 @@ class ParMap2Test : ArrowFxSpec(
val pa = CompletableDeferred<Pair<Int, ExitCase>>()
val pb = CompletableDeferred<Pair<Int, ExitCase>>()

val loserA: suspend CoroutineScope.() -> Int = { guaranteeCase({ s.receive(); never<Int>() }) { ex -> pa.complete(Pair(a, ex)) } }
val loserB: suspend CoroutineScope.() -> Int = { guaranteeCase({ s.receive(); never<Int>() }) { ex -> pb.complete(Pair(b, ex)) } }
val loserA: suspend CoroutineScope.() -> Int =
{ guaranteeCase({ s.receive(); never<Int>() }) { ex -> pa.complete(Pair(a, ex)) } }
val loserB: suspend CoroutineScope.() -> Int =
{ guaranteeCase({ s.receive(); never<Int>() }) { ex -> pb.complete(Pair(b, ex)) } }

val f = async { parZip(loserA, loserB) { _a, _b -> Pair(_a, _b) } }

Expand All @@ -62,38 +70,50 @@ class ParMap2Test : ArrowFxSpec(

pa.await().let { (res, exit) ->
res shouldBe a
exit.shouldBeInstanceOf<ExitCase.Cancelled>()
exit.shouldBeTypeOf<ExitCase.Cancelled>()
}
pb.await().let { (res, exit) ->
res shouldBe b
exit.shouldBeInstanceOf<ExitCase.Cancelled>()
exit.shouldBeTypeOf<ExitCase.Cancelled>()
}
}
}

"parMapN 2 cancels losers if a failure occurs in one of the tasks" {
checkAll(
Arb.throwable(),
Arb.boolean(),
Arb.int()
) { e, leftWinner, a ->
checkAll(Arb.throwable(), Arb.boolean()) { e, leftWinner ->
val s = Channel<Unit>()
val pa = CompletableDeferred<Pair<Int, ExitCase>>()
val pa = CompletableDeferred<ExitCase>()

val winner: suspend CoroutineScope.() -> Unit = { s.send(Unit); throw e }
val loserA: suspend CoroutineScope.() -> Int = { guaranteeCase({ s.receive(); never<Int>() }) { ex -> pa.complete(Pair(a, ex)) } }
val loserA: suspend CoroutineScope.() -> Int =
{ guaranteeCase({ s.receive(); awaitCancellation() }) { ex -> pa.complete(ex) } }

val r = Either.catch {
if (leftWinner) parZip(winner, loserA) { _, _ -> Unit }
else parZip(loserA, winner) { _, _ -> Unit }
}

pa.await().let { (res, exit) ->
res shouldBe a
exit.shouldBeInstanceOf<ExitCase.Cancelled>()
}
pa.await().shouldBeTypeOf<ExitCase.Cancelled>()
r should leftException(e)
}
}

"parMapN CancellationException on right can cancel rest" {
checkAll(Arb.string()) { msg ->
val exit = CompletableDeferred<ExitCase>()
val start = CompletableDeferred<Unit>()
try {
parZip<Unit, Unit, Unit>({
awaitExitCase(start, exit)
}, {
start.await()
throw CancellationException(msg)
}) { _, _ -> }
} catch (e: CancellationException) {
e.message shouldBe msg
}
exit.await().shouldBeTypeOf<ExitCase.Cancelled>()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🙌🏾

}
}
}
)
Loading