134 changes: 106 additions & 28 deletions docs/src/writing_good_rules.md
@@ -1,5 +1,33 @@
# On writing good `rrule` / `frule` methods

## Code Style

Use named local functions for the `pullback` in an `rrule`.

```julia
# good:
function rrule(::typeof(foo), x)
Y = foo(x)
function foo_pullback(Ȳ)
return NoTangent(), bar(Ȳ)
end
return Y, foo_pullback
end
#== output
julia> rrule(foo, 2)
(4, var"#foo_pullback#11"())
==#

# bad:
function rrule(::typeof(foo), x)
return foo(x), x̄ -> (NoTangent(), bar(x̄))
end
#== output:
julia> rrule(foo, 2)
(4, var"##9#10"())
==#
```

## Use `ZeroTangent()` as the return value

The `ZeroTangent()` object exists as an alternative to directly returning `0` or `zeros(n)`.
@@ -35,6 +63,84 @@ Examples being:
- There is only one derivative being returned, so from the fact that the user called
`frule`/`rrule` they clearly will want to use that one.

## Structs: constructors and functors

To define an `frule` or `rrule` for a _function_ `foo` we dispatch on the type of `foo`, which is `typeof(foo)`.
For example, the `rrule` signature would look like:

```julia
function rrule(::typeof(foo), args...; kwargs...)
...
return y, foo_pullback
end
```

For a struct `Bar`,
```julia
struct Bar
a::Float64
end

(bar::Bar)(x, y) = bar.a + x + y # functor (i.e. callable object, overloading the call action)
```
we can define an `frule`/`rrule` for the `Bar` constructor(s), as well as any `Bar` [functors](https://docs.julialang.org/en/v1/manual/methods/#Function-like-objects).

### Constructors

To define an `rrule` for a constructor for a _type_ `Bar` we need to be careful to dispatch only on `Type{Bar}`.
For example, the `rrule` signature for a `Bar` constructor would look like:
```julia
function ChainRulesCore.rrule(::Type{Bar}, a)
Review comment (Contributor):
I think Seth's point about `<:` (for non-concrete types) is worth making too
#331 (comment)

Review comment (Member, @oxinabox, Jun 7, 2021):
In general, one should probably always use `<:` inside `Type`,
since it means your code doesn't have to change if the type gains type parameters and thus is no longer concrete.
Plus it saves thinking.
(Possibly this should even be added to BlueStyle.)

Bar_pullback(Δbar) = NoTangent(), Δbar.a
return Bar(a), Bar_pullback
end
```

Use `Type{<:Bar}` (with the `<:`) for non-concrete types, such that the `rrule` is defined for all subtypes.
In particular, be careful not to use `typeof(Bar)` here.
Because `typeof(Bar)` is `DataType`, a rule that dispatches on `::typeof(Bar)` applies to the constructor of every type that is a `DataType`, not just to `Bar`.

You can check which to use with `Core.Typeof`:

```julia
julia> function foo end
foo (generic function with 0 methods)

julia> typeof(foo)
typeof(foo)

julia> Core.Typeof(foo)
typeof(foo)

julia> typeof(Bar)
DataType

julia> Core.Typeof(Bar)
Type{Bar}

julia> abstract type AbstractT end

julia> typeof(AbstractT)
DataType

julia> Core.Typeof(AbstractT)
Type{AbstractT}
```
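
The effect on dispatch can be sketched without ChainRulesCore at all. In the block below, `careless`, `careful`, `exact` and `subtyped` are hypothetical stand-ins for `rrule` methods, and `Other` and `Param` are illustrative types that do not appear in the docs; it is a minimal sketch of the dispatch behaviour, not a rule definition.

```julia
struct Bar
    a::Float64
end

struct Other      # hypothetical unrelated type
    b::Int
end

struct Param{T}   # hypothetical parametric type
    a::T
end

# `typeof(Bar)` is `DataType`, so this method accepts any `DataType` as first argument.
careless(::typeof(Bar), a) = "careless"

# `Type{<:Bar}` only matches `Bar` (and subtypes, if `Bar` were abstract).
careful(::Type{<:Bar}, a) = "careful"

# `Type{Param}` matches only the unparameterised `Param`, not `Param{Float64}`.
exact(::Type{Param}, a) = "exact"

# `Type{<:Param}` matches `Param` and every `Param{T}`.
subtyped(::Type{<:Param}, a) = "subtyped"

println(applicable(careless, Other, 1))           # true: the method leaks to `Other`
println(applicable(careful, Other, 1))            # false
println(applicable(careful, Bar, 1))              # true
println(applicable(exact, Param{Float64}, 1))     # false
println(applicable(subtyped, Param{Float64}, 1))  # true
```

The last two lines show why `<:` is the safer default: a signature written with `Type{<:Param}` keeps matching once the type has parameters.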

### Functors (callable objects)

In contrast to defining a rule for a constructor, it is possible to define rules for calling an instance of an object.
In that case, use `bar::Bar`, i.e.

```julia
function ChainRulesCore.rrule(bar::Bar, x, y)
Review comment (Member Author):
@oxinabox just double checking this is right?

Review comment (Member):
This is right.
Though I think that separating constructors into a separate section from functors might be clearer,
and having the functor section say something like "In contrast to defining the rule for constructing an object, one can overload the rule for calling an instance of an object (if that object is a functor)",
or something like that.

Still, even with them together this is a clear enhancement to the docs, so I am happy to have it merged and then do a follow-up if you like.

# Notice the first return is not `NoTangent()`
Bar_pullback(Δy) = Tangent{Bar}(;a=Δy), Δy, Δy
return bar(x, y), Bar_pullback
end
```
to define the rules.
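
As a usage sketch, the two rules from this section can be exercised together. It assumes ChainRulesCore is installed and repeats the `Bar` definition from above so that it runs on its own; the variable names (`ctor_pullback`, `call_pullback`, `Δbar`) are illustrative only.

```julia
using ChainRulesCore

struct Bar
    a::Float64
end

(bar::Bar)(x, y) = bar.a + x + y

# Rule for the constructor: dispatch on the type itself.
function ChainRulesCore.rrule(::Type{<:Bar}, a)
    Bar_pullback(Δbar) = NoTangent(), Δbar.a
    return Bar(a), Bar_pullback
end

# Rule for calling an instance: dispatch on the instance.
function ChainRulesCore.rrule(bar::Bar, x, y)
    # The first return is the tangent of `bar` itself, hence not `NoTangent()`.
    Bar_pullback(Δy) = Tangent{Bar}(; a=Δy), Δy, Δy
    return bar(x, y), Bar_pullback
end

bar, ctor_pullback = rrule(Bar, 1.0)
y, call_pullback = rrule(bar, 2.0, 3.0)

# Pull a cotangent of 1.0 back through the call, then through the constructor.
Δbar, Δx, Δy = call_pullback(1.0)
_, Δa = ctor_pullback(Δbar)

println((y, Δx, Δy, Δa))
```

Chaining the two pullbacks this way is what a reverse-mode AD system would do for `Bar(a)(x, y)`: the structural tangent returned for `bar` becomes the input to the constructor's pullback.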

## Use `@not_implemented` appropriately

One can use [`@not_implemented`](@ref) to mark missing differentials.
@@ -52,34 +158,6 @@ https://github.com/JuliaMath/SpecialFunctions.jl/issues/160

Do not use `@not_implemented` if the differential does not exist mathematically (use `NoTangent()` instead).

## Code Style
Review comment (Member, @oxinabox, Jun 7, 2021):
Moving code in the same PR as one that edits something else makes reviewing harder.
It's not terrible in this case since the PR is small, but still, about half the changed lines are just moving this.

Use named local functions for the `pullback` in an `rrule`.

```julia
# good:
function rrule(::typeof(foo), x)
Y = foo(x)
function foo_pullback(Ȳ)
return NoTangent(), bar(Ȳ)
end
return Y, foo_pullback
end
#== output
julia> rrule(foo, 2)
(4, var"#foo_pullback#11"())
==#

# bad:
function rrule(::typeof(foo), x)
return foo(x), x̄ -> (NoTangent(), bar(x̄))
end
#== output:
julia> rrule(foo, 2)
(4, var"##9#10"())
==#
```

While this is more verbose, it ensures that if an error is thrown during the `pullback` the [`gensym`](https://docs.julialang.org/en/v1/base/base/#Base.gensym) name of the local function will include the name you gave it.
This makes it a lot simpler to debug from the stacktrace.
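
The naming difference can be checked directly with plain Julia, without defining any rule. The helper names `make_named_pullback` and `make_anonymous_pullback` below are hypothetical, and the exact generated names vary between Julia versions and sessions, so the sketch only tests whether the chosen name shows up.

```julia
# Returns a named local function, as recommended above.
function make_named_pullback()
    function foo_pullback(ȳ)
        return 2 * ȳ
    end
    return foo_pullback
end

# Returns an anonymous function instead.
make_anonymous_pullback() = ȳ -> 2 * ȳ

# The closure's type name is what appears in stacktraces.
named_name = string(nameof(typeof(make_named_pullback())))
anonymous_name = string(nameof(typeof(make_anonymous_pullback())))

println(named_name)       # generated name that includes "foo_pullback"
println(anonymous_name)   # generated name without "foo_pullback"
println(occursin("foo_pullback", named_name))
```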
