Merge b55fe43 into d7be84a
oxinabox committed Jan 20, 2020
2 parents d7be84a + b55fe43 commit 2d6dc10
Showing 1 changed file with 124 additions and 83 deletions.
207 changes: 124 additions & 83 deletions docs/src/index.md
@@ -18,11 +18,23 @@ Knowing rules for more complicated functions speeds up the autodiff process as i

**ChainRules is an AD-independent collection of rules to use in a differentiation system.**

### Introduction

!!! note "The whole field is a mess for terminology"
It isn't just ChainRules, it is everyone.
Internally ChainRules tries to be consistent.
Help with that is always welcomed.

!!! terminology "Primal"
Often we will talk about something as _primal_.
That means it is related to the original problem, not its derivative.
For example, for `y = foo(x)`:
`foo` is the _primal_ function,
computing `foo(x)` is doing the _primal_ computation.
`y` is the _primal_ return, and `x` is a _primal_ argument.
`typeof(y)` and `typeof(x)` are both _primal_ types.


### `frule` and `rrule`

!!! terminology "`frule` and `rrule`"
@@ -32,35 +44,50 @@ Knowing rules for more complicated functions speeds up the autodiff process as i

The rules are encoded as `frule`s and `rrule`s, for use in forward-mode and reverse-mode differentiation respectively.

The `frule` is written:
The `rrule` for some function `foo`, which takes the positional arguments `args` and keyword arguments `kwargs`, is written:

```julia
function frule(::typeof(foo), args; kwargs...)
function rrule(::typeof(foo), args...; kwargs...)
...
return y, pushforward
return y, pullback
end
```
where `y = foo(args; kwargs...)`, and `pushforward` is a function to propagate the derivative information forwards at that point (more later).
where `y` (the primal result) must be equal to `foo(args...; kwargs...)`.
`pullback` is a function to propagate the derivative information backwards at that point.
That pullback function is used like:
`∂self, ∂args... = pullback(Δy)`


Almost always the _pullback_ will be declared locally within the `rrule`, and will be a _closure_ over some of the other arguments, and potentially over the primal result too.
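
As a minimal sketch of that pattern, the block below writes a rule by hand for a made-up function. The names `foo`, `toy_rrule`, and `foo_pullback` are hypothetical and exist only for this illustration; `toy_rrule` is a standalone stand-in that does not use ChainRulesCore, and `nothing` stands in for "the function has no fields".

```julia
# Hypothetical primal function, defined only for this sketch.
foo(x, y) = x * y + sin(x)

# Standalone stand-in for an `rrule`; this is NOT the ChainRulesCore function.
function toy_rrule(::typeof(foo), x, y)
    primal = foo(x, y)
    # The pullback is declared locally and closes over the primal arguments `x` and `y`.
    function foo_pullback(Δprimal)
        ∂self = nothing                  # assumption: stand-in for "no fields"
        ∂x = Δprimal * (y + cos(x))      # ∂foo/∂x = y + cos(x)
        ∂y = Δprimal * x                 # ∂foo/∂y = x
        return ∂self, ∂x, ∂y
    end
    return primal, foo_pullback
end

primal, pullback = toy_rrule(foo, 2.0, 3.0)
∂self, ∂x, ∂y = pullback(1.0)
```

The caller keeps hold of `pullback` and only invokes it once derivative information about the output is available.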

The `rrule` for some function `foo`, which takes the positional argument `args` and keyword argument `kwargs`, is written:

The `frule` is written:
```julia
function rrule(::typeof(foo), args; kwargs...)
function frule(::typeof(foo), args..., Δself, Δargs...; kwargs...)
...
return y, pullback
return y, ∂Y
end
```
again `y` must be equal to `foo(args; kwargs...)`, and `pullback` is a function to propagate the derivative information backwards at that point (more later).
where again `y = foo(args...; kwargs...)`,
and `∂Y` is the result of propagating the derivative information forwards at that point.
This propagation is called the pushforward.
One could think of writing `∂Y = pushforward(Δself, Δargs...)`, and often we will think of the `frule` as having the primal computation `y = foo(args...; kwargs...)` and the pushforward `∂Y = pushforward(Δself, Δargs...)`.
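
A hand-written sketch of this fused form, for a made-up function, might look as follows. The names `foo` and `toy_frule` are hypothetical; `toy_frule` is a standalone stand-in that only mirrors the argument order described above and does not use ChainRulesCore, and `nothing` is passed where a real caller would pass a differential for the function itself.

```julia
# Hypothetical primal function, defined only for this sketch.
foo(x, y) = x * y + sin(x)

# Standalone stand-in for an `frule`; this is NOT the ChainRulesCore function.
# Primal arguments come first, then the perturbation for the function itself,
# then one perturbation per primal argument.
function toy_frule(::typeof(foo), x, y, Δself, Δx, Δy)
    s, c = sincos(x)                   # work shared by the primal and the pushforward
    primal = x * y + s                 # the primal computation
    ∂primal = (y + c) * Δx + x * Δy    # the pushforward, fused into the same call
    return primal, ∂primal
end

# Perturb only `x`: the result is the partial derivative with respect to `x`.
primal, ∂primal = toy_frule(foo, 2.0, 3.0, nothing, 1.0, 0.0)
```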


!!! note "Why `rrule` returns a pullback but `frule` doesn't return a pushforward"
While `rrule` takes only the arguments to the original function (the primal arguments) and returns a function (the pullback) that operates with the derivative information, the `frule` does it all at once.
This is because the `frule` fuses the primal computation and the pushforward.
This is an optimization that allows `frule`s to contain single large operations that perform both the primal computation and the pushforward at the same time (for example solving an ODE).
This operation is only possible in forward mode (where `frule` is used) because the derivative information needed by the pushforward is available when the `frule` is invoked -- it is about the primal function's inputs.
In contrast, in reverse mode the derivative information needed by the pullback is about the primal function's output.
Thus the reverse mode returns the pullback function which the caller (usually an AD system) keeps hold of until derivative information about the output is available.

Almost always the _pushforward_/_pullback_ will be declared locally within the `frule`/`rrule`, and will be a _closure_ over some of the other arguments.

### The propagators: pushforward and pullback


!!! terminology "pushforward and pullback"

_Pushforward_ and _pullback_ are fancy words that the autodiff community adopted from Differential Geometry.
_Pushforward_ and _pullback_ are fancy words that the autodiff community recently adopted from Differential Geometry.
They are broadly in agreement with the use of [pullback](https://en.wikipedia.org/wiki/Pullback_(differential_geometry)) and [pushforward](https://en.wikipedia.org/wiki/Pushforward_(differential)) in differential geometry.
But any geometer will tell you these are the super-boring flat cases. Some will also frown at you.
They are also sometimes described in terms of the Jacobian:
@@ -114,60 +141,69 @@ pushforward to find ``\dfrac{∂f}{∂x}``:
``\dfrac{∂f}{∂x}=\mathrm{pushforward}_{h(b)|b=g(x)}\left(\left.\dfrac{∂g}{∂a}\right|_{a=x}\right)``


#### The anatomy of pushforward and pullback
#### The anatomy of pullback and pushforward

For our function `foo(args...; kwargs) = Y`:
For our function `foo(args...; kwargs...) = y`:

The pushforward is a function:

```julia
function pushforward(Δself, Δargs...)
function pullback(Δy)
...
return ∂Y
return ∂self, ∂args...
end
```

The input to the pushforward is often called the _perturbation_.
If the function is `y = f(x)` often the pushforward will be written `ẏ = pushforward(ṡelf, ẋ)`.
(`ẏ` is commonly used to represent the perturbation for `y`)
The input to the pullback is often called the _seed_.
If the function is `y = f(x)` often the pullback will be written `s̄elf, x̄ = pullback(ȳ)`.

!!! note

There is one `Δarg` per `arg` to the original function.
The `Δargs` are similar in type/structure to the corresponding inputs `args` (`Δself` is explained below).
The `∂Y` are similar in type/structure to the original function's output `Y`.
In particular if that function returned a tuple then `∂Y` will be a tuple of same size.
The pullback returns one `∂arg` per `arg` to the original function, plus one `∂self` for the fields of the function itself (explained below).

!!! terminology "perturbation, seed, sensitivity"
Sometimes _perturbation_, _seed_, and even _sensitivity_ will be used interchangeably.
They are not generally synonymous, and ChainRules shouldn't mix them up.
One must be careful when reading literature.
At the end of the day, they are all _wiggles_ or _wobbles_.


The pullback is a function:
The pushforward is a part of the `frule` function.
Considered alone it would look like:

```julia
function pullback(ΔY)
function pushforward(Δself, Δargs...)
...
return ∂self, ∂args...
return ∂y
end
```
But because it is fused into frule we see it as part of:
```julia
function frule(::typeof(foo), args..., Δself, Δargs...; kwargs...)
...
return y, ∂y
end
```

The input to the pullback is often called the _seed_.
If the function is `y = f(x)` often the pullback will be written `s̄elf, x̄ = pullback(ȳ)`.

!!! note

The pullback returns one `∂arg` per `arg` to the original function, plus one for the fields of the function itself (explained below).
The input to the pushforward is often called the _perturbation_.
If the function is `y = f(x)` often the pushforward will be written `ẏ = last(frule(f, x, ṡelf, ẋ))`.
`ẏ` is commonly used to represent the perturbation for `y`.

!!! terminology
Sometimes _perturbation_, _seed_, and even _sensitivity_ will be used interchangeably.
They are not generally synonymous, and ChainRules shouldn't mix them up.
One must be careful when reading literature.
At the end of the day, they are all _wiggles_ or _wobbles_.
!!! note

In the `frule`/pushforward,
there is one `Δarg` per `arg` to the original function.
The `Δargs` are similar in type/structure to the corresponding inputs `args` (`Δself` is explained below).
The `∂y` is similar in type/structure to the original function's output `y`.
In particular if that function returned a tuple then `∂y` will be a tuple of the same size.

### Self derivative `Δself`, `∂self`, `s̄elf`, `ṡelf` etc.
### Self derivative `Δself`, `∂self`, `s̄elf`, `ṡelf` etc.

!!! terminology `Δself`, `∂self`, `s̄elf`, `ṡelf`
!!! terminology "Δself, ∂self, s̄elf, ṡelf"
These are the derivatives with respect to the internal fields of the function.
To the best of our knowledge there is no standard terminology for this.
Other good names might be `Δinternal`/`∂internal`.


From the mathematical perspective, one may have been wondering what all this `Δself`, `∂self` is.
Given that a function with two inputs, say `f(a, b)`, only has two partial derivatives:
``\dfrac{∂f}{∂a}``, ``\dfrac{∂f}{∂b}``.
@@ -181,59 +217,68 @@ So every `pushforward` takes in an extra argument, which is ignored unless the o
It is common to write `function foo_pushforward(_, Δargs...)` in the case when `foo` does not have fields.
Similarly every `pullback` returns an extra `∂self`, which for things without fields is the constant `NO_FIELDS`, indicating there are no fields within the function itself.
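
To make the self derivative concrete, here is a small sketch for a callable struct, which is a function that does have a field. The names `Scale`, `toy_rrule`, and `scale_pullback` are hypothetical, the rule is written by hand without ChainRulesCore, and returning a `NamedTuple` for `∂self` is an assumption made for illustration rather than the library's required representation.

```julia
# Hypothetical callable struct: a "function" with one internal field `k`.
struct Scale
    k::Float64
end
(s::Scale)(x) = s.k * x

# Standalone stand-in for an `rrule`; this is NOT the ChainRulesCore function.
function toy_rrule(s::Scale, x)
    primal = s(x)
    function scale_pullback(Δprimal)
        ∂self = (k = Δprimal * x,)    # derivative with respect to the field `k`
        ∂x = Δprimal * s.k            # derivative with respect to the argument `x`
        return ∂self, ∂x
    end
    return primal, scale_pullback
end

double = Scale(2.0)
primal, pullback = toy_rrule(double, 5.0)
∂self, ∂x = pullback(1.0)
```

For a plain function without fields there is nothing to differentiate in `∂self`, which is why the rules return a placeholder there.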

#### Pushforward / Pullback summary

- **Pushforward:**
- returned by `frule`
- takes input space wiggles, gives output space wobbles
- 1 argument per original function argument + 1 for the function itself
- 1 return per original function return
#### Pushforward / Pullback summary

- **Pullback**
- returned by `rrule`
- takes output space wobbles, gives input space wiggles
- 1 argument per original function return
- 1 return per original function argument + 1 for the function itself

#### Pushforward/Pullback and Total Derivative/Gradient
- **Pushforward:**
- part of `frule`
- takes input space wiggles, gives output space wobbles
- 1 argument per original function argument + 1 for the function itself
- 1 return per original function return


#### Pullback/Pushforward and Directional Derivative/Gradient

The most trivial use of the `pushforward` from within `frule` is to calculate the directional derivative:

The most trivial use of `frule` and returned `pushforward` is to calculate the [Total Derivative](https://en.wikipedia.org/wiki/Total_derivative):
If we would like to know the directional derivative of `f` for an input change of `(1.5, 0.4, -1)`:

```julia
y, f_pushforward = frule(f, a, b, c)
ẏ = f_pushforward(1, 1, 1, 1) # for appropriate `1`-like perturbation.
direction = (1.5, 0.4, -1) # (ȧ, ḃ, ċ)
y, ẏ = frule(f, a, b, c, Zero(), direction...)
```

Then we have that `ẏ` is the _total derivative_ of `f` at `(a, b, c)`, written mathematically as ``df_{(a,b,c)}``
On the basis directions one gets the partial derivatives of `y`:
```julia
y, ∂y_∂a = frule(f, a, b, c, Zero(), 1, 0, 0)
y, ∂y_∂b = frule(f, a, b, c, Zero(), 0, 1, 0)
y, ∂y_∂c = frule(f, a, b, c, Zero(), 0, 0, 1)
```
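
The relationship between the directional derivative and the partial derivatives can be checked with a hand-written rule. This is only a sketch: `f` is a made-up function, `toy_frule` is a hypothetical standalone stand-in that does not use ChainRulesCore, and `nothing` is passed in place of a differential for the function itself.

```julia
# Hypothetical primal function, defined only for this sketch.
f(a, b, c) = a * b + c^2

# Standalone stand-in for an `frule`; this is NOT the ChainRulesCore function.
function toy_frule(::typeof(f), a, b, c, Δself, Δa, Δb, Δc)
    primal = f(a, b, c)
    ∂primal = b * Δa + a * Δb + 2c * Δc
    return primal, ∂primal
end

a, b, c = 2.0, 3.0, 4.0
direction = (1.5, 0.4, -1.0)    # (ȧ, ḃ, ċ)

# Directional derivative along `direction`.
_, ẏ = toy_frule(f, a, b, c, nothing, direction...)

# Partial derivatives from the basis directions.
_, ∂y_∂a = toy_frule(f, a, b, c, nothing, 1.0, 0.0, 0.0)
_, ∂y_∂b = toy_frule(f, a, b, c, nothing, 0.0, 1.0, 0.0)
_, ∂y_∂c = toy_frule(f, a, b, c, nothing, 0.0, 0.0, 1.0)
```

Because the pushforward is linear in the perturbations, `ẏ` equals the combination of the three partials weighted by the entries of `direction`.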

Similarly, the most trivial use of `rrule` and returned `pullback` is to calculate the [Gradient](https://en.wikipedia.org/wiki/Gradient):

```julia
y, f_pullback = rrule(f, a, b, c)
∇f = f_pullback(1) # for appropriate `1`-like seed.
s̄elf, ā, b̄, c̄ = ∇f
s̄elf, ā, b̄, c̄ = ∇f
```
Then we have that `∇f` is the _gradient_ of `f` at `(a, b, c)`.
And we thus have the partial derivatives ``\overline{\mathrm{self}} = \dfrac{∂f}{∂\mathrm{self}}``, ``\overline{a} = \dfrac{∂f}{∂a}``, ``\overline{b} = \dfrac{∂f}{∂b}``, ``\overline{c} = \dfrac{∂f}{∂c}``, including the self-partial derivative, ``\overline{\mathrm{self}}``.

### Differentials

The values that come back from pullbacks or pushforwards are not always the same type as the input/outputs of the original function.
They are differentials, which correspond roughly to something able to represent the difference between two values of the original types.
The values that come back from pullbacks or pushforwards are not always the same type as the input/outputs of the primal function.
They are differentials, which correspond roughly to something able to represent the difference between two values of the primal types.
A differential might be such a regular type, like a `Number`, or a `Matrix`, matching to the original type;
or it might be one of the `AbstractDifferential` subtypes.

Differentials support a number of operations.
Most importantly: `+` and `*`, which let them act as mathematical objects.
And `extern` which converts `AbstractDifferential` types into a conventional non-ChainRules type.

The most important `AbstractDifferential`s when getting started are the ones about avoiding work:

- `Thunk`: this is a deferred computation. A thunk is a [word for a zero argument closure](https://en.wikipedia.org/wiki/Thunk). A computation wrapped in a `@thunk` doesn't get evaluated until `extern` is called on the `Thunk`. More on thunks later.
- `Thunk`: this is a deferred computation. A thunk is a [word for a zero argument closure](https://en.wikipedia.org/wiki/Thunk). A computation wrapped in a `@thunk` doesn't get evaluated until `unthunk` is called on the thunk. `unthunk` is a no-op on non-thunked inputs.
- `One`, `Zero`: These are special representations of `1` and `0`. They do great things around avoiding expanding `Thunk`s in multiplication and (for `Zero`) addition.

#### Other `AbstractDifferential`s: don't worry about them right now
- `Casted`: it implements broadcasting mechanics. See [#10](https://github.com/JuliaDiff/ChainRulesCore.jl/issues/10)
#### Other `AbstractDifferential`s:
- `Composite{P}`: this is the differential for tuples and structs. Use it like a `Tuple` or `NamedTuple`. The type parameter `P` is for the primal type.
- `DoesNotExist`: Zero-like, represents that the operation on this input is not differentiable. Its primal type is normally `Integer` or `Bool`.
- `InplaceableThunk`: it is like a Thunk but it can do `store!` and `accumulate!` in-place.

-------------------------------
@@ -267,27 +312,22 @@ c, c_pullback = rrule(asin, b)
# Then the backward pass calculating gradients
c̄ = 1; # ∂c/∂c
_, b̄ = c_pullback(extern(c̄)); # ∂c/∂b
_, _, ā = b_pullback(extern(b̄)); # ∂c/∂a
_, x̄ = a_pullback(extern(ā)); # ∂c/∂x = ∂f/∂x
_, _, ā = b_pullback(extern(b̄)); # ∂c/∂a
_, x̄ = a_pullback(extern(ā)); # ∂c/∂x = ∂f/∂x
extern(x̄)
# -2.0638950738662625

#### Find dfoo/dx via frules

# Unlike with rrule, we can interleave evaluation and derivative evaluation
x = 3;
ẋ = 1; # ∂x/∂x
nofields = NamedTuple(); # ∂self/∂self

a, a_pushforward = frule(sin, x);
ȧ = a_pushforward(nofields, extern(ẋ)); # ∂a/∂x
ẋ = 1; # ∂x/∂x
nofields = Zero(); # ∂self/∂self

b, b_pushforward = frule(*, 2, a);
ḃ = b_pushforward(nofields, 0, extern(ȧ)); # ∂b/∂x = ∂b/∂a⋅∂a/∂x
a, ȧ = frule(sin, x, nofields, ẋ); # ∂a/∂x
b, ḃ = frule(*, 2, a, nofields, Zero(), unthunk(ȧ)); # ∂b/∂x = ∂b/∂a⋅∂a/∂x

c, c_pushforward = frule(asin, b);
ċ = c_pushforward(nofields, extern(ḃ)); # ∂c/∂x = ∂c/∂b⋅∂b/∂x = ∂f/∂x
extern(ċ)
c, ċ = frule(asin, b, nofields, unthunk(ḃ)); # ∂c/∂x = ∂c/∂b⋅∂b/∂x = ∂f/∂x
unthunk(ċ)
# -2.0638950738662625

#### Find dfoo/dx via finite-differences
@@ -350,8 +390,8 @@ Use named local functions for the `pushforward`/`pullback`:
# good:
function frule(::typeof(foo), x)
Y = foo(x)
function foo_pushforward(_, ẋ)
return bar(ẋ)
function foo_pushforward(_, ẋ)
return bar(ẋ)
end
return Y, foo_pushforward
end
@@ -362,7 +402,7 @@ julia> frule(foo, 2)

# bad:
function frule(::typeof(foo), x)
return foo(x), (_, ẋ) -> bar(ẋ)
return foo(x), (_, ẋ) -> bar(ẋ)
end
#== output:
julia> frule(foo, 2)
@@ -397,14 +437,14 @@ It is very easy to check gradients or derivatives with a computer algebra system

#### `Δx`, `∂x`, `dx`
ChainRules uses these perhaps atypically.
As a notation that is the same across propagators, regardless of direction. (In contrast, see `ẋ` and `x̄` below.)
As a notation that is the same across propagators, regardless of direction (in contrast, see `ẋ` and `x̄` below).

- `Δx` is the input to a propagator (i.e. a _seed_ for a _pullback_, or a _perturbation_ for a _pushforward_)
- `∂x` is the output of a propagator
- `dx` could be anything, including a pullback/pushforward. It really should not show up outside of tests.
- `dx` could be either


#### ``\dot{y} = \dfrac{∂y}{∂x} = \overline{x}``
#### dots and bars: ``\dot{y} = \dfrac{∂y}{∂x} = \overline{x}``
- `v̇` is a derivative of the input moving forward: ``v̇ = \frac{∂v}{∂x}`` for input ``x``, intermediate value ``v``.
- `v̄` is a derivative of the output moving backward: ``v̄ = \frac{∂y}{∂v}`` for output ``y``, intermediate value ``v``.

@@ -414,16 +454,17 @@ As a notation that is the same across propagators, regardless of direction. (Inc
- `∂Ω` is thus the output of a pushforward.


### Why do `frule` and `rrule` return the function evaluation?
You might wonder why `frule(f, x)` returns `f(x)` and the pushforward for `f` at `x`, and similarly for `rrule` returning `f(x)` and the pullback for `f` at `x`.
### Why does `rrule` return the primal function evaluation?
You might wonder why `frule(f, x)` returns `f(x)` and the derivative of `f` at `x`, and similarly for `rrule` returning `f(x)` and the pullback for `f` at `x`.
Why not just return the pushforward/pullback, and let the user call `f(x)` to get the answer separately?

There are two reasons the rules also calculate the `f(x)`.
1. For some rules the output value is used in the definition of its propagator. For example `tan`.
2. For some rules an alternative way of calculating `f(x)` can give the same answer while also generating intermediate values that can be used in the calculations within the propagator.
There are three reasons the rules also calculate the `f(x)`.
1. For some rules an alternative way of calculating `f(x)` can give the same answer while also generating intermediate values that can be used in the calculations required to propagate the derivative.
2. For many `rrule`s the output value is used in the definition of the pullback. For example `tan`, `sigmoid` etc.
3. For some `frule`s there exists a single, non-separable operation that will compute both derivative and primal result. For example many of the methods for [differential equation sensitivity analysis](https://docs.juliadiffeq.org/latest/analysis/sensitivity/#sensitivity-1).
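
The `tan` case from the list above can be sketched by hand. The name `toy_rrule` is hypothetical and the rule is a standalone stand-in, not the rule that ChainRules ships; `nothing` stands in for "no fields". The point is that the pullback reuses the primal output instead of recomputing anything, using the identity ``\dfrac{d}{dx}\tan(x) = 1 + \tan^2(x)``.

```julia
# Standalone stand-in for an `rrule`; this is NOT the ChainRulesCore function.
function toy_rrule(::typeof(tan), x)
    primal = tan(x)
    function tan_pullback(Δprimal)
        # The derivative is expressed through the already computed output `primal`.
        return nothing, Δprimal * (1 + primal^2)
    end
    return primal, tan_pullback
end

primal, pullback = toy_rrule(tan, 0.5)
_, x̄ = pullback(1.0)
```

If the rule returned only the pullback, a caller who also needed `tan(x)` would have to evaluate it a second time.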

### Where are the gradients for keyword arguments?
_pullbacks_ do not return a gradient for keyword arguments;
### Where are the derivatives for keyword arguments?
_pullbacks_ do not return a sensitivity for keyword arguments;
similarly _pushforwards_ do not accept a perturbation for keyword arguments.
This is because in practice functions are very rarely differentiable with respect to keyword arguments.
As a rule keyword arguments tend to control side-effects, like logging verbosity,