From 6488bb1d589c4b37323637afa3f15c6790e55e1b Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Wed, 9 Jun 2021 16:24:36 +0100 Subject: [PATCH 1/2] implement == for thunks --- Project.toml | 2 +- src/differentials/thunks.jl | 4 ++++ test/differentials/thunks.jl | 6 ++++++ 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index f5273df1d..b46577795 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesCore" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.10.2" +version = "0.10.3" [deps] Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" diff --git a/src/differentials/thunks.jl b/src/differentials/thunks.jl index efaebb0c7..0f54d8d23 100644 --- a/src/differentials/thunks.jl +++ b/src/differentials/thunks.jl @@ -13,6 +13,10 @@ end return element, (val, new_state) end +Base.:(==)(a::AbstractThunk, b::AbstractThunk) = unthunk(a) == unthunk(b) +Base.:(==)(a::AbstractThunk, b) = unthunk(a) == b +Base.:(==)(a, b::AbstractThunk) = a == unthunk(b) + """ @thunk expr diff --git a/test/differentials/thunks.jl b/test/differentials/thunks.jl index 2fb532fb0..28340287a 100644 --- a/test/differentials/thunks.jl +++ b/test/differentials/thunks.jl @@ -1,6 +1,12 @@ @testset "Thunk" begin @test @thunk(3) isa Thunk + @testset "==" begin + @test @thunk(3.2) == @thunk(3.2) + @test @thunk(3.2) == 3.2 + @test 3.2 == InplaceableThunk(@thunk(3.2), x -> x + 3.2) + end + @testset "show" begin rep = repr(Thunk(rand)) @test occursin(r"Thunk\(.*rand.*\)", rep) From f6d6bddb1983f03ffc2c8345574466b24a5c8196 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Wed, 9 Jun 2021 16:28:27 +0100 Subject: [PATCH 2/2] make test more explicit --- test/differentials/thunks.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/differentials/thunks.jl b/test/differentials/thunks.jl index 28340287a..8cd83a92f 100644 --- a/test/differentials/thunks.jl +++ b/test/differentials/thunks.jl @@ -2,7 +2,7 @@ @test @thunk(3) isa Thunk @testset "==" begin - @test @thunk(3.2) == @thunk(3.2) + @test @thunk(3.2) == InplaceableThunk(@thunk(3.2), x -> x + 3.2) @test @thunk(3.2) == 3.2 @test 3.2 == InplaceableThunk(@thunk(3.2), x -> x + 3.2) end