Merged
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,6 +1,6 @@
name = "ChainRulesCore"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
-version = "0.9.44"
+version = "0.9.45"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
4 changes: 2 additions & 2 deletions docs/src/index.md
@@ -316,7 +316,7 @@ Most importantly: `+` and `*`, which let them act as mathematical objects.
The most important `AbstractTangent`s when getting started are the ones about avoiding work:

- [`Thunk`](@ref): this is a deferred computation. A thunk is a [word for a zero argument closure](https://en.wikipedia.org/wiki/Thunk). A computation wrapped in a `@thunk` doesn't get evaluated until [`unthunk`](@ref) is called on the thunk. `unthunk` is a no-op on non-thunked inputs.
-- [`One`](@ref), [`ZeroTangent`](@ref): There are special representations of `1` and `0`. They do great things around avoiding expanding `Thunks` in multiplication and (for `ZeroTangent`) addition.
+- [`ZeroTangent`](@ref): It is a special representation of `0`. It does great things around avoiding expanding `Thunks` in addition.
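
As a rough illustration of the deferred-computation idea described in the list above, here is a minimal Base-only sketch of a thunk as a zero-argument closure. `SketchThunk`, `force`, `expensive`, and `CALLS` are hypothetical names invented for this example. It is not how ChainRulesCore implements `Thunk` or `unthunk`; it only mirrors the behaviour the docs describe (no work until forced, no-op on ordinary values).

```julia
# Hypothetical sketch: a thunk is a zero-argument closure whose work is deferred.
const CALLS = Ref(0)                      # counts how often the expensive work runs

expensive() = (CALLS[] += 1; sum(abs2, 1:1000))

struct SketchThunk{F}
    f::F                                  # the wrapped zero-argument closure
end

force(t::SketchThunk) = t.f()             # run the deferred computation
force(x) = x                              # no-op on values that are not thunks

t = SketchThunk(() -> expensive())        # nothing has been computed yet
println("calls before forcing: ", CALLS[])
println("forced value: ", force(t))
println("calls after forcing: ", CALLS[])
```

If a tangent is never used downstream, the closure is never called and the expensive work is skipped entirely.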

### Other `AbstractTangent`s:
- [`Tangent{P}`](@ref Tangent): this is the differential for tuples and structs. Use it like a `Tuple` or `NamedTuple`. The type parameter `P` is for the primal type.
@@ -345,7 +345,7 @@ end

# Define rules (alternatively get them for free via `using ChainRules`)
@scalar_rule(sin(x), cos(x))
-@scalar_rule(+(x, y), (One(), One()))
+@scalar_rule(+(x, y), (1.0, 1.0))
@scalar_rule(asin(x), inv(sqrt(1 - x^2)))
# output

7 changes: 3 additions & 4 deletions docs/src/writing_good_rules.md
@@ -1,10 +1,9 @@
# On writing good `rrule` / `frule` methods

-## Use `ZeroTangent()` or `One()` as return value
+## Use `ZeroTangent()` as the return value

-The `ZeroTangent()` and `One()` differential objects exist as an alternative to directly returning
-`0` or `zeros(n)`, and `1` or `I`.
-They allow more optimal computation when chaining pullbacks/pushforwards, to avoid work.
+The `ZeroTangent()` object exists as an alternative to directly returning `0` or `zeros(n)`.
+It allows more optimal computation when chaining pullbacks/pushforwards, to avoid work.
It should be used where possible.
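
To see why a dedicated zero object can avoid work when accumulating tangents, consider the following Base-only sketch. `SketchZero` is a hypothetical stand-in written for this example. It is not ChainRulesCore's `ZeroTangent`, and the real type defines many more methods.

```julia
# Hypothetical sketch of a "structural zero"; not ChainRulesCore's ZeroTangent.
struct SketchZero end

# Adding zero returns the other operand untouched: no allocation, no arithmetic.
Base.:+(::SketchZero, x) = x
Base.:+(x, ::SketchZero) = x
Base.:+(::SketchZero, ::SketchZero) = SketchZero()

# Multiplying by zero never needs to look at the other operand.
Base.:*(::SketchZero, x) = SketchZero()
Base.:*(x, ::SketchZero) = SketchZero()
Base.:*(::SketchZero, ::SketchZero) = SketchZero()

g = [1.0, 2.0, 3.0]
acc = SketchZero() + g        # the very same array object as g
println(acc === g)
println(zeros(3) + g === g)   # a dense zero allocates a fresh array instead
```

Returning `zeros(n)` forces every downstream `+` to allocate and loop, whereas the singleton lets dispatch skip the operation altogether.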

## Use `Thunk`s appropriately
4 changes: 3 additions & 1 deletion src/differentials/one.jl
@@ -3,7 +3,9 @@
The Differential which is the multiplicative identity.
Basically, this represents `1`.
"""
-struct One <: AbstractTangent end
+struct One <: AbstractTangent
+One() = (Base.depwarn("`One()` is deprecated; use `true` instead", :One); return new())
+end

extern(x::One) = true # true is a strong 1.

3 changes: 3 additions & 0 deletions src/rule_definition_tools.jl
@@ -56,6 +56,9 @@ derivative/setup expressions.
This macro assumes complex functions are holomorphic. In general, for non-holomorphic
functions, the `frule` and `rrule` must be defined manually.

+If the derivative is one, (e.g. for identity functions) `true` can be used as the most
+general multiplicative identity.
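
The claim that `true` is the most general multiplicative identity can be checked in plain Julia, without ChainRulesCore. The sketch below uses only `Base`; the variable names are arbitrary. It shows that multiplying by `true` leaves the other operand's type alone, whereas a `Float64` one promotes narrower types.

```julia
x = 1.5f0                  # a Float32 derivative value
kept     = true * x        # Bool leaves the Float32 type alone
promoted = 1.0 * x         # a Float64 one widens the result to Float64
z = (2 + 3im) * true       # complex integers also pass through unchanged

println(typeof(kept), " ", typeof(promoted), " ", typeof(z))
```

This is presumably why the tests in this PR replace `One()` with `true` rather than with a literal such as `1` or `1.0`.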

The `@setup` argument can be elided if no setup code is needed. In other
words:

67 changes: 67 additions & 0 deletions test/deprecated.jl
@@ -1,6 +1,73 @@
+# Define some rules to test One on
+dummy_identity(x) = x
+@scalar_rule(dummy_identity(x), One())
+
+very_nice(x, y) = x + y
+@scalar_rule(very_nice(x, y), (One(), One()))

@testset "deprecations" begin
@test ChainRulesCore.AbstractDifferential === ChainRulesCore.AbstractTangent
@test Zero === ZeroTangent
@test DoesNotExist === NoTangent
@test Composite === Tangent
+@test_deprecated One()
end

+@testset "One()" begin
+
+o = One()
+@test extern(o) === true
+@test o + o == 2
+@test o + 1 == 2
+@test 1 + o == 2
+@test o * o == o
+@test o * 17 == 17
+@test 6 * o == 6
+@test dot(2 + im, o) == 2 - im
+@test dot(o, 2 + im) == 2 + im
+for x in o
+@test x === o
+end
+@test broadcastable(o) isa Ref{One}
+@test conj(o) == o
+
+@test reim(o) === (One(), ZeroTangent())
+@test real(o) === One()
+@test imag(o) === ZeroTangent()
+
+@test complex(o) === o
+@test complex(o, ZeroTangent()) === o
+@test complex(ZeroTangent(), o) === im
+
+@test frule((nothing, nothing, 5.0), Core._apply, dummy_identity, 4.0) == (4.0, 5.0)
+
+@testset "broadcasting One" begin
+sx = @SVector [1, 2]
+sy = @SVector [3, 4]
+
+# Test that @scalar_rule and `One()` play nice together, w.r.t broadcasting
+@inferred frule((ZeroTangent(), sx, sy), very_nice, 1, 2)
Review comment (Contributor): What is this testing? It doesn't appear to be using `One` anywhere.

Reply (Member Author, @mzgubic, May 27, 2021): I think it is testing that `One() * sx` is type stable?

+end
+
+@testset "interaction with other types" begin
+c = Tangent{Foo}(y=1.5, x=2.5)
+@test One() * c === c
+@test c * One() === c
+
+z = ZeroTangent()
+@test zero(One()) === z
+@test zero(One) === z
+
+ni = ChainRulesCore.NotImplemented(
+@__MODULE__, LineNumberNode(@__LINE__, @__FILE__), "error"
+)
+@test ni + One() === ni
+@test One() + ni === ni
+E = ChainRulesCore.NotImplementedException
+@test_throws E ni - One()
+@test_throws E One() - ni
+@test_throws E One() * ni
+@test_throws E dot(ni, One())
+@test_throws E dot(One(), ni)
+end
+end
2 changes: 0 additions & 2 deletions test/differentials/abstract_zero.jl
@@ -27,9 +27,7 @@
end
@test broadcastable(z) isa Ref{ZeroTangent}
@test zero(@thunk(3)) === z
-@test zero(One()) === z
@test zero(NoTangent()) === z
-@test zero(One) === z
@test zero(ZeroTangent) === z
@test zero(NoTangent) === z
@test zero(Tangent{Tuple{Int,Int}}((1, 2))) === z
4 changes: 2 additions & 2 deletions test/differentials/composite.jl
@@ -304,8 +304,8 @@ end
@test dot(ZeroTangent(), c) == ZeroTangent()
@test dot(c, ZeroTangent()) == ZeroTangent()

-@test One() * c === c
-@test c * One() === c
+@test true * c === c
+@test c * true === c

t = @thunk 2
@test t * c == 2 * c
14 changes: 7 additions & 7 deletions test/differentials/notimplemented.jl
@@ -31,12 +31,12 @@
@test ni + rand() === ni
@test ni + ZeroTangent() === ni
@test ni + NoTangent() === ni
-@test ni + One() === ni
+@test ni + true === ni
@test ni + @thunk(x^2) === ni
@test rand() + ni === ni
@test ZeroTangent() + ni === ni
@test NoTangent() + ni === ni
-@test One() + ni === ni
+@test true + ni === ni
@test @thunk(x^2) + ni === ni
@test ni + ni2 === ni
@test ni * rand() === ni
@@ -55,26 +55,26 @@
@test_throws E ni - rand()
@test_throws E ni - ZeroTangent()
@test_throws E ni - NoTangent()
-@test_throws E ni - One()
+@test_throws E ni - true
@test_throws E ni - @thunk(x^2)
@test_throws E rand() - ni
@test_throws E ZeroTangent() - ni
@test_throws E NoTangent() - ni
-@test_throws E One() - ni
+@test_throws E true - ni
@test_throws E @thunk(x^2) - ni
@test_throws E ni - ni2
@test_throws E rand() * ni
@test_throws E NoTangent() * ni
-@test_throws E One() * ni
+@test_throws E true * ni
@test_throws E @thunk(x^2) * ni
@test_throws E ni * ni2
@test_throws E dot(ni, rand())
@test_throws E dot(ni, NoTangent())
-@test_throws E dot(ni, One())
+@test_throws E dot(ni, true)
@test_throws E dot(ni, @thunk(x^2))
@test_throws E dot(rand(), ni)
@test_throws E dot(NoTangent(), ni)
-@test_throws E dot(One(), ni)
+@test_throws E dot(true, ni)
@test_throws E dot(@thunk(x^2), ni)
@test_throws E dot(ni, ni2)
@test_throws E ni / rand()
25 changes: 0 additions & 25 deletions test/differentials/one.jl

This file was deleted.

12 changes: 6 additions & 6 deletions test/rules.jl
@@ -6,13 +6,13 @@ cool(x, y) = x + y + 1

# a rule we define so we can test rules
dummy_identity(x) = x
-@scalar_rule(dummy_identity(x), One())
+@scalar_rule(dummy_identity(x), true)

nice(x) = 1
@scalar_rule(nice(x), ZeroTangent())

-very_nice(x, y) = x + y
-@scalar_rule(very_nice(x, y), (One(), One()))
+sum_two(x, y) = x + y
+@scalar_rule(sum_two(x, y), (true, true))

complex_times(x) = (1 + 2im) * x
@scalar_rule(complex_times(x), 1 + 2im)
@@ -116,12 +116,12 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t))

@test frule((nothing, nothing, 5.0), Core._apply, dummy_identity, 4.0) == (4.0, 5.0)

-@testset "broadcasting One" begin
+@testset "broadcasting true" begin
Review comment (Contributor): How do this name and these comments relate to the `@inferred` test here?

Reply (Member Author): Maybe we don't need this test anymore?

sx = @SVector [1, 2]
sy = @SVector [3, 4]

-# Test that @scalar_rule and `One()` play nice together, w.r.t broadcasting
-@inferred frule((ZeroTangent(), sx, sy), very_nice, 1, 2)
+# Test that @scalar_rule and `true` play nice together, w.r.t broadcasting
+@inferred frule((ZeroTangent(), sx, sy), sum_two, 1, 2)
end

@testset "complex inputs" begin
1 change: 0 additions & 1 deletion test/runtests.jl
@@ -9,7 +9,6 @@ using Test
@testset "ChainRulesCore" begin
@testset "differentials" begin
include("differentials/abstract_zero.jl")
-include("differentials/one.jl")
include("differentials/thunks.jl")
include("differentials/composite.jl")
include("differentials/notimplemented.jl")