Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix null bug in Validated#zip & Ior#zip #2338

Merged
merged 1 commit into from
Mar 25, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
91 changes: 53 additions & 38 deletions arrow-libs/core/arrow-core-data/src/main/kotlin/arrow/core/Ior.kt
Original file line number Diff line number Diff line change
Expand Up @@ -908,59 +908,74 @@ inline fun <A, B, C, D, E, F, G, H, I, J, K, L> Ior<A, B>.zip(
k: Ior<A, K>,
map: (B, C, D, E, F, G, H, I, J, K) -> L
): Ior<A, L> {
val rightValue: L? = Nullable.zip(
(this@zip as? Right)?.value ?: (this@zip as? Both)?.rightValue,
(c as? Right)?.value ?: (c as? Both)?.rightValue,
(d as? Right)?.value ?: (d as? Both)?.rightValue,
(e as? Right)?.value ?: (e as? Both)?.rightValue,
(f as? Right)?.value ?: (f as? Both)?.rightValue,
(g as? Right)?.value ?: (g as? Both)?.rightValue,
(h as? Right)?.value ?: (h as? Both)?.rightValue,
(i as? Right)?.value ?: (i as? Both)?.rightValue,
(j as? Right)?.value ?: (j as? Both)?.rightValue,
(k as? Right)?.value ?: (k as? Both)?.rightValue,
map
)
// If any of the values is Right or Both then we can calculate L otherwise it results in MY_NULL
val rightValue: Any? = if (
(this@zip.isRight || this@zip.isBoth) &&
(c.isRight || c.isBoth) &&
(d.isRight || d.isBoth) &&
(e.isRight || e.isBoth) &&
(f.isRight || f.isBoth) &&
(g.isRight || g.isBoth) &&
(h.isRight || h.isBoth) &&
(i.isRight || i.isBoth) &&
(j.isRight || j.isBoth) &&
(k.isRight || k.isBoth)
) {
map(
this@zip.orNull() as B,
c.orNull() as C,
d.orNull() as D,
e.orNull() as E,
f.orNull() as F,
g.orNull() as G,
h.orNull() as H,
i.orNull() as I,
j.orNull() as J,
k.orNull() as K
)
} else EmptyValue

val leftValue: Any? = SA.run {
var accumulatedLeft: Any? = EmptyValue

val leftValue: A? = SA.run {
var accumulatedLeft: A? = null
if (this@zip is Left) value.maybeCombine(accumulatedLeft) else accumulatedLeft
accumulatedLeft = if (this@zip is Both) leftValue.maybeCombine(accumulatedLeft) else accumulatedLeft
if (this@zip is Left) return@zip Left(this@zip.value)
accumulatedLeft =
if (this@zip is Both) this@zip.leftValue else accumulatedLeft

if (c is Left) return Left(c.value.maybeCombine(accumulatedLeft))
accumulatedLeft = if (c is Both) c.leftValue.maybeCombine(accumulatedLeft) else accumulatedLeft
if (c is Left) return@zip Left(emptyCombine(accumulatedLeft, c.value))
accumulatedLeft = if (c is Both) emptyCombine(accumulatedLeft, c.leftValue) else accumulatedLeft

if (d is Left) return Left(d.value.maybeCombine(accumulatedLeft))
accumulatedLeft = if (d is Both) d.leftValue.maybeCombine(accumulatedLeft) else accumulatedLeft
if (d is Left) return@zip Left(emptyCombine(accumulatedLeft, d.value))
accumulatedLeft = if (d is Both) emptyCombine(accumulatedLeft, d.leftValue) else accumulatedLeft

if (e is Left) return Left(e.value.maybeCombine(accumulatedLeft))
accumulatedLeft = if (e is Both) e.leftValue.maybeCombine(accumulatedLeft) else accumulatedLeft
if (e is Left) return@zip Left(emptyCombine(accumulatedLeft, e.value))
accumulatedLeft = if (e is Both) emptyCombine(accumulatedLeft, e.leftValue) else accumulatedLeft

if (f is Left) return Left(f.value.maybeCombine(accumulatedLeft))
accumulatedLeft = if (f is Both) f.leftValue.maybeCombine(accumulatedLeft) else accumulatedLeft
if (f is Left) return@zip Left(emptyCombine(accumulatedLeft, f.value))
accumulatedLeft = if (f is Both) emptyCombine(accumulatedLeft, f.leftValue) else accumulatedLeft

if (g is Left) return Left(g.value.maybeCombine(accumulatedLeft))
accumulatedLeft = if (g is Both) g.leftValue.maybeCombine(accumulatedLeft) else accumulatedLeft
if (g is Left) return@zip Left(emptyCombine(accumulatedLeft, g.value))
accumulatedLeft = if (g is Both) emptyCombine(accumulatedLeft, g.leftValue) else accumulatedLeft

if (h is Left) return Left(h.value.maybeCombine(accumulatedLeft))
accumulatedLeft = if (h is Both) h.leftValue.maybeCombine(accumulatedLeft) else accumulatedLeft
if (h is Left) return@zip Left(emptyCombine(accumulatedLeft, h.value))
accumulatedLeft = if (h is Both) emptyCombine(accumulatedLeft, h.leftValue) else accumulatedLeft

if (i is Left) return Left(i.value.maybeCombine(accumulatedLeft))
accumulatedLeft = if (i is Both) i.leftValue.maybeCombine(accumulatedLeft) else accumulatedLeft
if (i is Left) return@zip Left(emptyCombine(accumulatedLeft, i.value))
accumulatedLeft = if (i is Both) emptyCombine(accumulatedLeft, i.leftValue) else accumulatedLeft

if (j is Left) return Left(j.value.maybeCombine(accumulatedLeft))
accumulatedLeft = if (j is Both) j.leftValue.maybeCombine(accumulatedLeft) else accumulatedLeft
if (j is Left) return@zip Left(emptyCombine(accumulatedLeft, j.value))
accumulatedLeft = if (j is Both) emptyCombine(accumulatedLeft, j.leftValue) else accumulatedLeft

if (k is Left) return Left(k.value.maybeCombine(accumulatedLeft))
accumulatedLeft = if (k is Both) k.leftValue.maybeCombine(accumulatedLeft) else accumulatedLeft
if (k is Left) return@zip Left(emptyCombine(accumulatedLeft, k.value))
accumulatedLeft = if (k is Both) emptyCombine(accumulatedLeft, k.leftValue) else accumulatedLeft

accumulatedLeft
}

return when {
rightValue != null && leftValue == null -> Right(rightValue)
rightValue != null && leftValue != null -> Both(leftValue, rightValue)
rightValue == null && leftValue != null -> Left(leftValue)
rightValue != EmptyValue && leftValue == EmptyValue -> Right(rightValue as L)
rightValue != EmptyValue && leftValue != EmptyValue -> Both(leftValue as A, rightValue as L)
rightValue == EmptyValue && leftValue != EmptyValue -> Left(leftValue as A)
else -> throw ArrowCoreInternalException
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -942,28 +942,29 @@ inline fun <E, A, B, C, D, EE, F, G, H, I, J, Z> Validated<E, A>.zip(
if (this is Validated.Valid && b is Validated.Valid && c is Validated.Valid && d is Validated.Valid && e is Validated.Valid && ff is Validated.Valid && g is Validated.Valid && h is Validated.Valid && i is Validated.Valid && j is Validated.Valid) {
Validated.Valid(f(this.a, b.a, c.a, d.a, e.a, ff.a, g.a, h.a, i.a, j.a))
} else SE.run {
var accumulatedError: E? = null
var accumulatedError: Any? = EmptyValue
accumulatedError =
if (this@zip is Validated.Invalid) this@zip.e.maybeCombine(accumulatedError) else accumulatedError
if (this@zip is Validated.Invalid) this@zip.e else accumulatedError
accumulatedError =
if (b is Validated.Invalid) accumulatedError?.let { it.combine(b.e) } ?: b.e else accumulatedError
if (b is Validated.Invalid) emptyCombine(accumulatedError, b.e) else accumulatedError
accumulatedError =
if (c is Validated.Invalid) accumulatedError?.let { it.combine(c.e) } ?: c.e else accumulatedError
if (c is Validated.Invalid) emptyCombine(accumulatedError, c.e) else accumulatedError
accumulatedError =
if (d is Validated.Invalid) accumulatedError?.let { it.combine(d.e) } ?: d.e else accumulatedError
if (d is Validated.Invalid) emptyCombine(accumulatedError, d.e) else accumulatedError
accumulatedError =
if (e is Validated.Invalid) accumulatedError?.let { it.combine(e.e) } ?: e.e else accumulatedError
if (e is Validated.Invalid) emptyCombine(accumulatedError, e.e) else accumulatedError
accumulatedError =
if (ff is Validated.Invalid) accumulatedError?.let { it.combine(ff.e) } ?: ff.e else accumulatedError
if (ff is Validated.Invalid) emptyCombine(accumulatedError, ff.e) else accumulatedError
accumulatedError =
if (g is Validated.Invalid) accumulatedError?.let { it.combine(g.e) } ?: g.e else accumulatedError
if (g is Validated.Invalid) emptyCombine(accumulatedError, g.e)else accumulatedError
accumulatedError =
if (h is Validated.Invalid) accumulatedError?.let { it.combine(h.e) } ?: h.e else accumulatedError
if (h is Validated.Invalid) emptyCombine(accumulatedError, h.e) else accumulatedError
accumulatedError =
if (i is Validated.Invalid) accumulatedError?.let { it.combine(i.e) } ?: i.e else accumulatedError
if (i is Validated.Invalid) emptyCombine(accumulatedError, i.e) else accumulatedError
accumulatedError =
if (j is Validated.Invalid) accumulatedError?.let { it.combine(j.e) } ?: j.e else accumulatedError
Validated.Invalid(accumulatedError!!)
if (j is Validated.Invalid) emptyCombine(accumulatedError, j.e) else accumulatedError

Validated.Invalid(accumulatedError as E)
}

inline fun <E, A, B, Z> ValidatedNel<E, A>.zip(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package arrow.core

import arrow.typeclasses.Semigroup

inline fun <A> identity(a: A): A = a

inline fun <A, B, Z> ((A, B) -> Z).curry(): (A) -> (B) -> Z = { p1: A -> { p2: B -> this(p1, p2) } }
Expand All @@ -20,3 +22,23 @@ internal object ArrowCoreInternalException : RuntimeException(

const val TailRecMDeprecation: String =
"tailRecM is deprecated together with the Kind type classes since it's meant for writing kind-based polymorphic stack-safe programs."

/**
* This is a work-around for having nested nulls in generic code.
* This allows for writing faster generic code instead of using `Option`.
* This is only used as an optimisation technique in low-level code,
* always prefer to use `Option` in actual business code when needed in generic code.
*/
@PublishedApi
internal object EmptyValue {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👌

@Suppress("UNCHECKED_CAST", "NOTHING_TO_INLINE")
inline fun <A> unbox(value: Any?): A =
if (value === this) null as A else value as A
}

/**
* Like [Semigroup.maybeCombine] but for using with [EmptyValue]
*/
@PublishedApi
internal fun <T> Semigroup<T>.emptyCombine(first: Any?, second: T): T =
if (first == EmptyValue) second else (first as T).combine(second)
Original file line number Diff line number Diff line change
Expand Up @@ -79,21 +79,67 @@ class IorTest : UnitSpec() {
BicrosswalkLaws.laws(Ior.bicrosswalk(), Ior.genK2(), Ior.eqK2())
)

val nullableLongSemigroup = object : Semigroup<Long?> {
override fun Long?.combine(b: Long?): Long? =
Nullable.zip(this, b) { a, bb -> a + bb }
}

"zip identity" {
forAll(Gen.ior(Gen.long(), Gen.int())) { ior ->
val res = ior.zip(Semigroup.long(), Ior.Right(Unit)) { a, _ -> a }
forAll(Gen.ior(Gen.long().orNull(), Gen.int().orNull())) { ior ->
val res = ior.zip(nullableLongSemigroup, Ior.Right(Unit)) { a, _ -> a }
res == ior
}
}

"zip short-circuits on left & accumulates with both" {
forAll(Gen.ior(Gen.long(), Gen.int()).filterNot(Ior<Long, Int>::isLeft), Gen.long()) { ior, l ->
val res = ior.zip(Semigroup.long(), Ior.Left(l)) { a, _ -> a }
val expected = ior.leftOrNull()?.let { Semigroup.long().run { Ior.Left(it.combine(l)) } } ?: Ior.Left(l)
"zip is derived from flatMap" {
forAll(
Gen.ior(Gen.long().orNull(), Gen.int().orNull()),
Gen.ior(Gen.long().orNull(), Gen.int().orNull()),
Gen.ior(Gen.long().orNull(), Gen.int().orNull()),
Gen.ior(Gen.long().orNull(), Gen.int().orNull()),
Gen.ior(Gen.long().orNull(), Gen.int().orNull()),
Gen.ior(Gen.long().orNull(), Gen.int().orNull()),
Gen.ior(Gen.long().orNull(), Gen.int().orNull()),
Gen.ior(Gen.long().orNull(), Gen.int().orNull()),
Gen.ior(Gen.long().orNull(), Gen.int().orNull()),
Gen.ior(Gen.long().orNull(), Gen.int().orNull())
) { a, b, c, d, e, f, g, h, i, j ->
val res = a.zip(
nullableLongSemigroup,
b, c, d, e, f, g, h, i, j
) { a, b, c, d, e, f, g, h, i, j ->
Nullable.zip(
a,
b,
c,
d,
e,
f,
g,
h,
i,
j
) { a, b, c, d, e, f, g, h, i, j -> a + b + c + d + e + f + g + h + i + j }
}

val expected = listOf(a, b, c, d, e, f, g, h, i, j)
.fold<Ior<Long?, Int?>, Ior<Long?, Int?>>(Ior.Right(0)) { acc, ior ->
val mid = acc.flatMap(nullableLongSemigroup) { a -> ior.map { b -> Nullable.zip(a, b) { a, b -> a + b } } }
mid
}

res == expected
}
}

"zip should combine left values in correct order" {
Ior.Both("fail1", -1).zip(
Semigroup.string(),
Ior.Left("fail2"),
Ior.Right(-1)
) { _, _, _ -> "success!" } shouldBe Ior.Left("fail1fail2")
}

"bimap() should allow modify both value" {
forAll { a: Int, b: String ->
Ior.Right(b).bimap({ "5" }, { a * 2 }) == Ior.Right(a * 2) &&
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,11 @@ import arrow.core.test.laws.SemigroupKLaws
import arrow.core.test.laws.ShowLaws
import arrow.core.test.laws.TraverseLaws
import arrow.typeclasses.Eq
import arrow.typeclasses.Monoid
import arrow.typeclasses.Semigroup
import io.kotlintest.fail
import io.kotlintest.properties.Gen
import io.kotlintest.properties.forAll
import io.kotlintest.shouldBe

@Suppress("RedundantSuspendModifier")
Expand Down Expand Up @@ -201,6 +203,60 @@ class ValidatedTest : UnitSpec() {
Invalid(10).findValid(Semigroup.int()) { Invalid(5) } shouldBe Invalid(15)
}

val nullableLongSemigroup = object : Monoid<Long?> {
override fun empty(): Long? = 0
override fun Long?.combine(b: Long?): Long? =
Nullable.zip(this@combine, b) { a, bb -> a + bb }
}

"zip identity" {
forAll(Gen.validated(Gen.long().orNull(), Gen.int().orNull())) { validated ->
val res = validated.zip(nullableLongSemigroup, Valid(Unit)) { a, _ -> a }
res == validated
}
}

"zip is derived from flatMap" {
forAll(
Gen.validated(Gen.long().orNull(), Gen.int().orNull()),
Gen.validated(Gen.long().orNull(), Gen.int().orNull()),
Gen.validated(Gen.long().orNull(), Gen.int().orNull()),
Gen.validated(Gen.long().orNull(), Gen.int().orNull()),
Gen.validated(Gen.long().orNull(), Gen.int().orNull()),
Gen.validated(Gen.long().orNull(), Gen.int().orNull()),
Gen.validated(Gen.long().orNull(), Gen.int().orNull()),
Gen.validated(Gen.long().orNull(), Gen.int().orNull()),
Gen.validated(Gen.long().orNull(), Gen.int().orNull()),
Gen.validated(Gen.long().orNull(), Gen.int().orNull())
) { a, b, c, d, e, f, g, h, i, j ->
val res = a.zip(
nullableLongSemigroup,
b, c, d, e, f, g, h, i, j
) { a, b, c, d, e, f, g, h, i, j ->
Nullable.zip(
a,
b,
c,
d,
e,
f,
g,
h,
i,
j
) { a, b, c, d, e, f, g, h, i, j -> a + b + c + d + e + f + g + h + i + j }
}

val all = listOf(a, b, c, d, e, f, g, h, i, j)
val isValid = all.all(Validated<Long?, Int?>::isValid)
val expected: Validated<Long?, Int?> =
if (isValid) Valid(all.fold<Validated<Long?, Int?>, Int?>(0) { acc, validated -> Nullable.zip(acc, validated.orNull()) { a, b -> a + b } })
else Invalid(all.filterIsInstance<Invalid<Long?>>().map(Invalid<Long?>::value).combineAll(nullableLongSemigroup))

res == expected
}
}

"zip should return Valid(f(a)) if both are Valid" {
Valid(10).zip(Semigroup.int(), Valid { a: Int -> a + 5 }) { a, ff -> ff(a) } shouldBe Valid(15)
}
Expand Down