From cfc1bb2a4b43ba583af8e864381267ef25155851 Mon Sep 17 00:00:00 2001
From: Nick Robinson
Date: Sun, 4 Apr 2021 12:08:21 +0100
Subject: [PATCH 1/4] Add entry about `typeof` gotcha

---
 docs/src/writing_good_rules.md | 55 ++++++++++++++++++++++++++++++++++
 1 file changed, 55 insertions(+)

diff --git a/docs/src/writing_good_rules.md b/docs/src/writing_good_rules.md
index 9a64d14f4..13aa0be10 100644
--- a/docs/src/writing_good_rules.md
+++ b/docs/src/writing_good_rules.md
@@ -1,5 +1,60 @@
 # On writing good `rrule` / `frule` methods
 
+## Use `Type{T}`, not `typeof(T)`, to define rules for constructors
+
+To define an `frule` or `rrule` for a _function_ `foo` we dispatch on the type of `foo`, which is `typeof(foo)`.
+For example, the `rrule` signature would be like:
+
+```julia
+function rrule(::typeof(foo), args...; kwargs...)
+    ...
+    return y, foo_pullback
+end
+```
+
+But to define an `rrule` for a constructor for a _type_ `T` we need to be careful to dispatch only on `Type{T}`.
+
+For example, the `rrule` signature for a constructor would be like:
+
+```julia
+function rrule(::Type{T}, args...; kwargs...)
+    ...
+    return y, T_pullback
+end
+```
+
+In particular, be careful not to use `typeof(T)` here.
+Because `typeof(T)` is `DataType`, using this to define an `rrule`/`frule` will define an `rrule`/`frule` for all constructors.
+
+You can check which to use with `Core.Typeof`:
+
+```julia
+julia> function foob end
+foob (generic function with 0 methods)
+
+julia> typeof(foob)
+typeof(foob)
+
+julia> Core.Typeof(foob)
+typeof(foob)
+
+julia> abstract type AbstractT end
+
+julia> struct ExampleT <: AbstractT end
+
+julia> typeof(AbstractT)
+DataType
+
+julia> typeof(ExampleT)
+DataType
+
+julia> Core.Typeof(AbstractT)
+Type{AbstractT}
+
+julia> Core.Typeof(ExampleT)
+Type{ExampleT}
+```
+
 ## Use `ZeroTangent()` as the return value
 
 The `ZeroTangent()` object exists as an alternative to directly returning `0` or `zeros(n)`.
From c9f9168b7abb568f4674492751f6793cd137cd93 Mon Sep 17 00:00:00 2001
From: Miha Zgubic
Date: Mon, 7 Jun 2021 12:47:17 +0100
Subject: [PATCH 2/4] include functor and move code style to the top

---
 docs/src/writing_good_rules.md | 166 ++++++++++++++++++---------------
 1 file changed, 91 insertions(+), 75 deletions(-)

diff --git a/docs/src/writing_good_rules.md b/docs/src/writing_good_rules.md
index 13aa0be10..eccb98b3c 100644
--- a/docs/src/writing_good_rules.md
+++ b/docs/src/writing_good_rules.md
@@ -1,58 +1,31 @@
 # On writing good `rrule` / `frule` methods
 
-## Use `Type{T}`, not `typeof(T)`, to define rules for constructors
+## Code Style
 
-To define an `frule` or `rrule` for a _function_ `foo` we dispatch on the type of `foo`, which is `typeof(foo)`.
-For example, the `rrule` signature would be like:
+Use named local functions for the `pullback` in an `rrule`.
 
 ```julia
-function rrule(::typeof(foo), args...; kwargs...)
-    ...
-    return y, foo_pullback
+# good:
+function rrule(::typeof(foo), x)
+    Y = foo(x)
+    function foo_pullback(Ȳ)
+        return NoTangent(), bar(Ȳ)
+    end
+    return Y, foo_pullback
 end
-```
-
-But to define an `rrule` for a constructor for a _type_ `T` we need to be careful to dispatch only on `Type{T}`.
-
-For example, the `rrule` signature for a constructor would be like:
+#== output
+julia> rrule(foo, 2)
+(4, var"#foo_pullback#11"())
+==#
 
-```julia
-function rrule(::Type{T}, args...; kwargs...)
-    ...
-    return y, T_pullback
+# bad:
+function rrule(::typeof(foo), x)
+    return foo(x), x̄ -> (NoTangent(), bar(x̄))
 end
-```
-
-In particular, be careful not to use `typeof(T)` here.
-Because `typeof(T)` is `DataType`, using this to define an `rrule`/`frule` will define an `rrule`/`frule` for all constructors.
-
-You can check which to use with `Core.Typeof`:
-
-```julia
-julia> function foob end
-foob (generic function with 0 methods)
-
-julia> typeof(foob)
-typeof(foob)
-
-julia> Core.Typeof(foob)
-typeof(foob)
-
-julia> abstract type AbstractT end
-
-julia> struct ExampleT <: AbstractT end
-
-julia> typeof(AbstractT)
-DataType
-
-julia> typeof(ExampleT)
-DataType
-
-julia> Core.Typeof(AbstractT)
-Type{AbstractT}
-
-julia> Core.Typeof(ExampleT)
-Type{ExampleT}
+#== output:
+julia> rrule(foo, 2)
+(4, var"##9#10"())
+==#
 ```
 
 ## Use `ZeroTangent()` as the return value
@@ -90,6 +63,77 @@ Examples being:
 - There is only one derivative being returned, so from the fact that the user called
   `frule`/`rrule` they clearly will want to use that one.
 
+## Structs: constructors and functors
+
+To define an `frule` or `rrule` for a _function_ `foo` we dispatch on the type of `foo`, which is `typeof(foo)`.
+For example, the `rrule` signature would be like:
+
+```julia
+function rrule(::typeof(foo), args...; kwargs...)
+    ...
+    return y, foo_pullback
+end
+```
+
+For a struct `Bar`,
+```julia
+struct Bar
+    a::Float64
+end
+
+(bar::Bar)(x, y) = return bar.a + x + y # functor
+```
+we can define an `frule`/`rrule` for the `Bar` constructor(s), as well as any `Bar` [functors](https://docs.julialang.org/en/v1/manual/methods/#Function-like-objects).
+
+To define an `rrule` for a constructor for a _type_ `Bar` we need to be careful to dispatch only on `Type{Bar}`.
+For example, the `rrule` signature for a `Bar` constructor would be like:
+```julia
+function ChainRulesCore.rrule(::Type{Bar}, a)
+    ...
+    return Bar(a), Bar_pullback
+end
+```
+
+In particular, be careful not to use `typeof(Bar)` here.
+Because `typeof(Bar)` is `DataType`, using this to define an `rrule`/`frule` will define an `rrule`/`frule` for all constructors.
+
+You can check which to use with `Core.Typeof`:
+
+```julia
+julia> function foo end
+foo (generic function with 0 methods)
+
+julia> typeof(foo)
+typeof(foo)
+
+julia> Core.Typeof(foo)
+typeof(foo)
+
+julia> typeof(Bar)
+DataType
+
+julia> Core.Typeof(Bar)
+Type{Bar}
+
+julia> abstract type AbstractT end
+
+julia> typeof(AbstractT)
+DataType
+
+julia> Core.Typeof(AbstractT)
+Type{AbstractT}
+```
+
+For the functor, use `bar::Bar`, i.e.
+
+```julia
+function ChainRulesCore.rrule(bar::Bar, x, y)
+    ...
+    return bar(x, y), Bar_pullback
+end
+```
+
+
 ## Use `@not_implemented` appropriately
 
 One can use [`@not_implemented`](@ref) to mark missing differentials.
@@ -107,34 +151,6 @@ https://github.com/JuliaMath/SpecialFunctions.jl/issues/160
 
 Do not use `@not_implemented` if the differential does not exist mathematically (use `NoTangent()` instead).
 
-## Code Style
-
-Use named local functions for the `pullback` in an `rrule`.
-
-```julia
-# good:
-function rrule(::typeof(foo), x)
-    Y = foo(x)
-    function foo_pullback(Ȳ)
-        return NoTangent(), bar(Ȳ)
-    end
-    return Y, foo_pullback
-end
-#== output
-julia> rrule(foo, 2)
-(4, var"#foo_pullback#11"())
-==#
-
-# bad:
-function rrule(::typeof(foo), x)
-    return foo(x), x̄ -> (NoTangent(), bar(x̄))
-end
-#== output:
-julia> rrule(foo, 2)
-(4, var"##9#10"())
-==#
-```
-
 While this is more verbose, it ensures that if an error is thrown during the `pullback` the [`gensym`](https://docs.julialang.org/en/v1/base/base/#Base.gensym) name of the local function will include the name you gave it.
 This makes it a lot simpler to debug from the stacktrace.
From dc5c3ae6a395bd4d0f81f72e808abef456292caf Mon Sep 17 00:00:00 2001
From: Miha Zgubic
Date: Mon, 7 Jun 2021 13:36:09 +0100
Subject: [PATCH 3/4] code review

---
 docs/src/writing_good_rules.md | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/docs/src/writing_good_rules.md b/docs/src/writing_good_rules.md
index eccb98b3c..855bcbfec 100644
--- a/docs/src/writing_good_rules.md
+++ b/docs/src/writing_good_rules.md
@@ -81,7 +81,7 @@ struct Bar
     a::Float64
 end
 
-(bar::Bar)(x, y) = return bar.a + x + y # functor
+(bar::Bar)(x, y) = return bar.a + x + y # functor (i.e. callable object, overloading the call action)
 ```
 we can define an `frule`/`rrule` for the `Bar` constructor(s), as well as any `Bar` [functors](https://docs.julialang.org/en/v1/manual/methods/#Function-like-objects).
 
@@ -89,11 +89,12 @@ To define an `rrule` for a constructor for a _type_ `Bar` we need to be careful
 For example, the `rrule` signature for a `Bar` constructor would be like:
 ```julia
 function ChainRulesCore.rrule(::Type{Bar}, a)
-    ...
+    Bar_pullback(Δbar) = NoTangent(), Δbar.a
     return Bar(a), Bar_pullback
 end
 ```
 
+Use `Type{<:Bar}` (with the `<:`) for non-concrete types, such that the `rrule` is defined for all subtypes.
 In particular, be careful not to use `typeof(Bar)` here.
 Because `typeof(Bar)` is `DataType`, using this to define an `rrule`/`frule` will define an `rrule`/`frule` for all constructors.
 
@@ -128,7 +129,8 @@ For the functor, use `bar::Bar`, i.e.
 
 ```julia
 function ChainRulesCore.rrule(bar::Bar, x, y)
-    ...
+    # Notice the first return is not `NoTangent()`
+    Bar_pullback(Δy) = Tangent{Bar}(;a=Δy), Δy, Δy
     return bar(x, y), Bar_pullback
 end
 ```
From 9813ec2e9fbadba7ee3f666c5ccfd81eafbe68d5 Mon Sep 17 00:00:00 2001
From: Miha Zgubic
Date: Mon, 7 Jun 2021 13:40:27 +0100
Subject: [PATCH 4/4] improve wording

---
 docs/src/writing_good_rules.md | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/docs/src/writing_good_rules.md b/docs/src/writing_good_rules.md
index 855bcbfec..872a160b4 100644
--- a/docs/src/writing_good_rules.md
+++ b/docs/src/writing_good_rules.md
@@ -85,6 +85,8 @@ end
 ```
 we can define an `frule`/`rrule` for the `Bar` constructor(s), as well as any `Bar` [functors](https://docs.julialang.org/en/v1/manual/methods/#Function-like-objects).
 
+### Constructors
+
 To define an `rrule` for a constructor for a _type_ `Bar` we need to be careful to dispatch only on `Type{Bar}`.
 For example, the `rrule` signature for a `Bar` constructor would be like:
 ```julia
@@ -125,7 +127,10 @@ julia> Core.Typeof(AbstractT)
 Type{AbstractT}
 ```
 
-For the functor, use `bar::Bar`, i.e.
+### Functors (callable objects)
+
+In contrast to defining a rule for a constructor, it is possible to define rules for calling an instance of an object.
+In that case, use `bar::Bar`, i.e.
 
 ```julia
 function ChainRulesCore.rrule(bar::Bar, x, y)
@@ -134,7 +139,7 @@ function ChainRulesCore.rrule(bar::Bar, x, y)
     return bar(x, y), Bar_pullback
 end
 ```
-
+to define the rules.
 
 ## Use `@not_implemented` appropriately
 