From 3037ae23753c4fdbdc80b56e1e0873dbdc1df455 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Thu, 27 May 2021 11:09:47 +0100 Subject: [PATCH 1/9] add depwarn message --- src/differentials/one.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/differentials/one.jl b/src/differentials/one.jl index 14390a3e5..6cbbeeeb7 100644 --- a/src/differentials/one.jl +++ b/src/differentials/one.jl @@ -3,7 +3,9 @@ The Differential which is the multiplicative identity. Basically, this represents `1`. """ -struct One <: AbstractTangent end +struct One <: AbstractTangent + One() = (Base.depwarn("`One()` is deprecated; use `true` instead", :One); return new()) +end extern(x::One) = true # true is a strong 1. From ecbe9530a53ef8949062e981eb65ba15aa7b658f Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Thu, 27 May 2021 11:12:04 +0100 Subject: [PATCH 2/9] update docs --- docs/src/index.md | 4 ++-- docs/src/writing_good_rules.md | 7 +++---- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/docs/src/index.md b/docs/src/index.md index ca868886b..62c681039 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -316,7 +316,7 @@ Most importantly: `+` and `*`, which let them act as mathematical objects. The most important `AbstractTangent`s when getting started are the ones about avoiding work: - [`Thunk`](@ref): this is a deferred computation. A thunk is a [word for a zero argument closure](https://en.wikipedia.org/wiki/Thunk). A computation wrapped in a `@thunk` doesn't get evaluated until [`unthunk`](@ref) is called on the thunk. `unthunk` is a no-op on non-thunked inputs. - - [`One`](@ref), [`ZeroTangent`](@ref): There are special representations of `1` and `0`. They do great things around avoiding expanding `Thunks` in multiplication and (for `ZeroTangent`) addition. + - [`ZeroTangent`](@ref): It is a special representation of `0`. It does great things around avoiding expanding `Thunks` in addition. 
### Other `AbstractTangent`s: - [`Tangent{P}`](@ref Tangent): this is the differential for tuples and structs. Use it like a `Tuple` or `NamedTuple`. The type parameter `P` is for the primal type. @@ -345,7 +345,7 @@ end # Define rules (alternatively get them for free via `using ChainRules`) @scalar_rule(sin(x), cos(x)) -@scalar_rule(+(x, y), (One(), One())) +@scalar_rule(+(x, y), (1.0, 1.0)) @scalar_rule(asin(x), inv(sqrt(1 - x^2))) # output diff --git a/docs/src/writing_good_rules.md b/docs/src/writing_good_rules.md index 8c505ca39..afce46e37 100644 --- a/docs/src/writing_good_rules.md +++ b/docs/src/writing_good_rules.md @@ -1,10 +1,9 @@ # On writing good `rrule` / `frule` methods -## Use `ZeroTangent()` or `One()` as return value +## Use `ZeroTangent()` as the return value -The `ZeroTangent()` and `One()` differential objects exist as an alternative to directly returning -`0` or `zeros(n)`, and `1` or `I`. -They allow more optimal computation when chaining pullbacks/pushforwards, to avoid work. +The `ZeroTangent()` object exists as an alternative to directly returning `0` or `zeros(n)`. +It allows more optimal computation when chaining pullbacks/pushforwards, to avoid work. They should be used where possible. 
## Use `Thunk`s appropriately From 0449df06337219da357f866cff3218b216cd05f6 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Thu, 27 May 2021 11:12:52 +0100 Subject: [PATCH 3/9] update tests --- test/deprecated.jl | 66 ++++++++++++++++++++++++++++ test/differentials/abstract_zero.jl | 3 +- test/differentials/composite.jl | 4 +- test/differentials/notimplemented.jl | 14 +++--- test/differentials/one.jl | 25 ----------- test/rules.jl | 8 ++-- 6 files changed, 80 insertions(+), 40 deletions(-) delete mode 100644 test/differentials/one.jl diff --git a/test/deprecated.jl b/test/deprecated.jl index 6f8403725..2aa785e63 100644 --- a/test/deprecated.jl +++ b/test/deprecated.jl @@ -1,6 +1,72 @@ +# Define some rules to test One on +dummy_identity(x) = x +@scalar_rule(dummy_identity(x), One()) + +very_nice(x, y) = x + y +@scalar_rule(very_nice(x, y), (One(), One())) + @testset "deprecations" begin @test ChainRulesCore.AbstractDifferential === ChainRulesCore.AbstractTangent @test Zero === ZeroTangent @test DoesNotExist === NoTangent @test Composite === Tangent end + +@testset "One()" begin + + o = One() + @test extern(o) === true + @test o + o == 2 + @test o + 1 == 2 + @test 1 + o == 2 + @test o * o == o + @test o * 17 == 17 + @test 6 * o == 6 + @test dot(2 + im, o) == 2 - im + @test dot(o, 2 + im) == 2 + im + for x in o + @test x === o + end + @test broadcastable(o) isa Ref{One} + @test conj(o) == o + + @test reim(o) === (One(), ZeroTangent()) + @test real(o) === One() + @test imag(o) === ZeroTangent() + + @test complex(o) === o + @test complex(o, ZeroTangent()) === o + @test complex(ZeroTangent(), o) === im + + @test frule((nothing, nothing, 5.0), Core._apply, dummy_identity, 4.0) == (4.0, 5.0) + + @testset "broadcasting One" begin + sx = @SVector [1, 2] + sy = @SVector [3, 4] + + # Test that @scalar_rule and `One()` play nice together, w.r.t broadcasting + @inferred frule((ZeroTangent(), sx, sy), very_nice, 1, 2) + end + + @testset "interaction with other types" begin + c = 
Tangent{Foo}(y=1.5, x=2.5) + @test One() * c === c + @test c * One() === c + + z = ZeroTangent() + @test zero(One()) === z + @test zero(One) === z + + ni = ChainRulesCore.NotImplemented( + @__MODULE__, LineNumberNode(@__LINE__, @__FILE__), "error" + ) + @test ni + One() === ni + @test One() + ni === ni + E = ChainRulesCore.NotImplementedException + @test_throws E ni - One() + @test_throws E One() - ni + @test_throws E One() * ni + @test_throws E dot(ni, One()) + @test_throws E dot(One(), ni) + end +end diff --git a/test/differentials/abstract_zero.jl b/test/differentials/abstract_zero.jl index 1ae19663c..5948efb15 100644 --- a/test/differentials/abstract_zero.jl +++ b/test/differentials/abstract_zero.jl @@ -27,9 +27,8 @@ end @test broadcastable(z) isa Ref{ZeroTangent} @test zero(@thunk(3)) === z - @test zero(One()) === z + @test zero(true) === z @test zero(NoTangent()) === z - @test zero(One) === z @test zero(ZeroTangent) === z @test zero(NoTangent) === z @test zero(Tangent{Tuple{Int,Int}}((1, 2))) === z diff --git a/test/differentials/composite.jl b/test/differentials/composite.jl index 0c66e1684..9a64e7f1c 100644 --- a/test/differentials/composite.jl +++ b/test/differentials/composite.jl @@ -304,8 +304,8 @@ end @test dot(ZeroTangent(), c) == ZeroTangent() @test dot(c, ZeroTangent()) == ZeroTangent() - @test One() * c === c - @test c * One() === c + @test true * c === c + @test c * true === c t = @thunk 2 @test t * c == 2 * c diff --git a/test/differentials/notimplemented.jl b/test/differentials/notimplemented.jl index ffd3320c5..e80a1461e 100644 --- a/test/differentials/notimplemented.jl +++ b/test/differentials/notimplemented.jl @@ -31,12 +31,12 @@ @test ni + rand() === ni @test ni + ZeroTangent() === ni @test ni + NoTangent() === ni - @test ni + One() === ni + @test ni + true === ni @test ni + @thunk(x^2) === ni @test rand() + ni === ni @test ZeroTangent() + ni === ni @test NoTangent() + ni === ni - @test One() + ni === ni + @test true + ni === ni @test 
@thunk(x^2) + ni === ni @test ni + ni2 === ni @test ni * rand() === ni @@ -55,26 +55,26 @@ @test_throws E ni - rand() @test_throws E ni - ZeroTangent() @test_throws E ni - NoTangent() - @test_throws E ni - One() + @test_throws E ni - true @test_throws E ni - @thunk(x^2) @test_throws E rand() - ni @test_throws E ZeroTangent() - ni @test_throws E NoTangent() - ni - @test_throws E One() - ni + @test_throws E true - ni @test_throws E @thunk(x^2) - ni @test_throws E ni - ni2 @test_throws E rand() * ni @test_throws E NoTangent() * ni - @test_throws E One() * ni + @test_throws E true * ni @test_throws E @thunk(x^2) * ni @test_throws E ni * ni2 @test_throws E dot(ni, rand()) @test_throws E dot(ni, NoTangent()) - @test_throws E dot(ni, One()) + @test_throws E dot(ni, true) @test_throws E dot(ni, @thunk(x^2)) @test_throws E dot(rand(), ni) @test_throws E dot(NoTangent(), ni) - @test_throws E dot(One(), ni) + @test_throws E dot(true, ni) @test_throws E dot(@thunk(x^2), ni) @test_throws E dot(ni, ni2) @test_throws E ni / rand() diff --git a/test/differentials/one.jl b/test/differentials/one.jl deleted file mode 100644 index f9a51b2c1..000000000 --- a/test/differentials/one.jl +++ /dev/null @@ -1,25 +0,0 @@ -@testset "One" begin - o = One() - @test extern(o) === true - @test o + o == 2 - @test o + 1 == 2 - @test 1 + o == 2 - @test o * o == o - @test o * 17 == 17 - @test 6 * o == 6 - @test dot(2 + im, o) == 2 - im - @test dot(o, 2 + im) == 2 + im - for x in o - @test x === o - end - @test broadcastable(o) isa Ref{One} - @test conj(o) == o - - @test reim(o) === (One(), ZeroTangent()) - @test real(o) === One() - @test imag(o) === ZeroTangent() - - @test complex(o) === o - @test complex(o, ZeroTangent()) === o - @test complex(ZeroTangent(), o) === im -end diff --git a/test/rules.jl b/test/rules.jl index 31a76f982..533a4cd8d 100644 --- a/test/rules.jl +++ b/test/rules.jl @@ -6,13 +6,13 @@ cool(x, y) = x + y + 1 # a rule we define so we can test rules dummy_identity(x) = x 
-@scalar_rule(dummy_identity(x), One()) +@scalar_rule(dummy_identity(x), true) nice(x) = 1 @scalar_rule(nice(x), ZeroTangent()) very_nice(x, y) = x + y -@scalar_rule(very_nice(x, y), (One(), One())) +@scalar_rule(very_nice(x, y), (true, true)) complex_times(x) = (1 + 2im) * x @scalar_rule(complex_times(x), 1 + 2im) @@ -116,11 +116,11 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) @test frule((nothing, nothing, 5.0), Core._apply, dummy_identity, 4.0) == (4.0, 5.0) - @testset "broadcasting One" begin + @testset "broadcasting true" begin sx = @SVector [1, 2] sy = @SVector [3, 4] - # Test that @scalar_rule and `One()` play nice together, w.r.t broadcasting + # Test that @scalar_rule and `true` play nice together, w.r.t broadcasting @inferred frule((ZeroTangent(), sx, sy), very_nice, 1, 2) end From 32143e2c47395c8ff6f5aa5cd87ce6de5f6de743 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Thu, 27 May 2021 11:14:26 +0100 Subject: [PATCH 4/9] fix tests --- test/differentials/abstract_zero.jl | 1 - test/runtests.jl | 1 - 2 files changed, 2 deletions(-) diff --git a/test/differentials/abstract_zero.jl b/test/differentials/abstract_zero.jl index 5948efb15..7a1adb577 100644 --- a/test/differentials/abstract_zero.jl +++ b/test/differentials/abstract_zero.jl @@ -27,7 +27,6 @@ end @test broadcastable(z) isa Ref{ZeroTangent} @test zero(@thunk(3)) === z - @test zero(true) === z @test zero(NoTangent()) === z @test zero(ZeroTangent) === z @test zero(NoTangent) === z diff --git a/test/runtests.jl b/test/runtests.jl index 36d1fa55c..080258c46 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -9,7 +9,6 @@ using Test @testset "ChainRulesCore" begin @testset "differentials" begin include("differentials/abstract_zero.jl") - include("differentials/one.jl") include("differentials/thunks.jl") include("differentials/composite.jl") include("differentials/notimplemented.jl") From 61ac743caebec42f6cee3c24e4c87d877e73096c Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Thu, 27 
May 2021 11:15:04 +0100 Subject: [PATCH 5/9] bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index db0b66414..c222e0394 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesCore" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.9.44" +version = "0.9.45" [deps] Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" From 821ad3a6352ff7b3ee5fb0202b85dce17b480546 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Thu, 27 May 2021 11:17:30 +0100 Subject: [PATCH 6/9] test deprecation of One() --- test/deprecated.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/deprecated.jl b/test/deprecated.jl index 2aa785e63..4035a8989 100644 --- a/test/deprecated.jl +++ b/test/deprecated.jl @@ -10,6 +10,7 @@ very_nice(x, y) = x + y @test Zero === ZeroTangent @test DoesNotExist === NoTangent @test Composite === Tangent + @test_deprecated One() end @testset "One()" begin From 7de64f4ee4d1e829551f1ad1f130f768fd40ff36 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Thu, 27 May 2021 12:26:20 +0100 Subject: [PATCH 7/9] rename function --- test/rules.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/rules.jl b/test/rules.jl index 533a4cd8d..78334e229 100644 --- a/test/rules.jl +++ b/test/rules.jl @@ -11,8 +11,8 @@ dummy_identity(x) = x nice(x) = 1 @scalar_rule(nice(x), ZeroTangent()) -very_nice(x, y) = x + y -@scalar_rule(very_nice(x, y), (true, true)) +sum_two(x, y) = x + y +@scalar_rule(sum_two(x, y), (true, true)) complex_times(x) = (1 + 2im) * x @scalar_rule(complex_times(x), 1 + 2im) @@ -121,7 +121,7 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) sy = @SVector [3, 4] # Test that @scalar_rule and `true` play nice together, w.r.t broadcasting - @inferred frule((ZeroTangent(), sx, sy), very_nice, 1, 2) + @inferred frule((ZeroTangent(), sx, sy), sum_two, 1, 2) end @testset "complex inputs" begin From 
4ca0f15bde4c9013895dba7459f7b5b1d07218c3 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Thu, 27 May 2021 17:58:32 +0100 Subject: [PATCH 8/9] comment on using true --- src/rule_definition_tools.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index a9a0aaf67..ffa90d7cd 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -56,6 +56,8 @@ derivative/setup expressions. This macro assumes complex functions are holomorphic. In general, for non-holomorphic functions, the `frule` and `rrule` must be defined manually. +Prefer using `true` over `1.0` to express multiplicative identity. + The `@setup` argument can be elided if no setup code is need. In other words: From 08453d5cc939fa1030dadb050007f443080ff64b Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Thu, 27 May 2021 19:35:18 +0100 Subject: [PATCH 9/9] Update src/rule_definition_tools.jl Co-authored-by: Lyndon White --- src/rule_definition_tools.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index ffa90d7cd..95f4f338a 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -56,7 +56,8 @@ derivative/setup expressions. This macro assumes complex functions are holomorphic. In general, for non-holomorphic functions, the `frule` and `rrule` must be defined manually. -Prefer using `true` over `1.0` to express multiplicative identity. +If the derivative is one (e.g. for identity functions), `true` can be used as the most +general multiplicative identity. The `@setup` argument can be elided if no setup code is need. In other words: