Merged
46 changes: 23 additions & 23 deletions docs/src/design/changing_the_primal.md
@@ -25,7 +25,7 @@ We will call this function `pullback_at`, as it pulls back the sensitivity at a
To make this concrete:
```julia
y = f(x) # primal program
x̄ = pullback_at(f, x, y, )
x̄ = pullback_at(f, x, y, ȳ)
```
Let's illustrate this with examples for `sin` and for the [logistic sigmoid](https://en.wikipedia.org/wiki/Logistic_function#Derivative).

@@ -34,9 +34,9 @@ Let's illustrate this with examples for `sin` and for the [logistic sigmoid](htt
```
```julia
y = sin(x)
pullback_at(::typeof(sin), x, y, ) = * cos(x)
pullback_at(::typeof(sin), x, y, ȳ) = ȳ * cos(x)
```
`pullback_at` uses the primal input `x`, and the sensitivity being pulled back ``.
`pullback_at` uses the primal input `x`, and the sensitivity being pulled back `ȳ`.

```@raw html
</details>
@@ -48,7 +48,7 @@ pullback_at(::typeof(sin), x, y, ȳ) = ȳ * cos(x)
```julia
σ(x) = 1/(1 + exp(-x)) # = exp(x) / (1 + exp(x))
y = σ(x)
pullback_at(::typeof(σ), x, y, ) = * y * σ(-x) # = * σ(x) * σ(-x)
pullback_at(::typeof(σ), x, y, ȳ) = ȳ * y * σ(-x) # = ȳ * σ(x) * σ(-x)
```
Notice that in `pullback_at` we are using not only the input `x` but also the primal output `y`.
This is a nice bit of symmetry that shows up around `exp`.
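
As a quick sanity check, the two `pullback_at` definitions above can be compared against a central finite difference. This is only a sketch: the helper `finite_difference` and its step size `h` are assumptions introduced here for illustration and are not part of the original document.

```julia
# Definitions copied from the examples above.
σ(x) = 1 / (1 + exp(-x))
pullback_at(::typeof(sin), x, y, ȳ) = ȳ * cos(x)
pullback_at(::typeof(σ), x, y, ȳ) = ȳ * y * σ(-x)

# Hypothetical helper (not from the original text): central finite difference.
finite_difference(f, x; h=1e-6) = (f(x + h) - f(x - h)) / (2h)

x, ȳ = 0.7, 2.0
for f in (sin, σ)
    y = f(x)                      # primal program
    x̄ = pullback_at(f, x, y, ȳ)   # pull the sensitivity back to the input
    @assert isapprox(x̄, ȳ * finite_difference(f, x); rtol=1e-6)
end
```

Both rules agree with the numerical derivative scaled by `ȳ`, which is all a pullback for a scalar function has to do.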
@@ -130,7 +130,7 @@ So we are talking about a 30-40% speed-up from these optimizations.[^4]
It is faster to compute `sin` and `cos` at the same time via `sincos` than it is to compute them one after the other.
And it is faster to reuse the `exp(x)` in computing `σ(x)` and `σ(-x)`.
How can we incorporate this insight into our system?
We know we can compute both of these in the primal — because they only depend on `x` and not on `` — but there is nowhere to put them that is accessible both to the primal pass and the gradient pass code.
We know we can compute both of these in the primal — because they only depend on `x` and not on `ȳ` — but there is nowhere to put them that is accessible both to the primal pass and the gradient pass code.
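
The reuse of `exp(x)` rests on the identity `σ(-x) = 1 / (1 + exp(x))`, so the derivative `σ(x) * σ(-x)` can be written as `y / (1 + ex)` with `ex = exp(x)`. A minimal sketch checking both shared computations numerically (the test value of `x` is arbitrary, and the variable names `sx`, `cx` and `ex` are chosen here for illustration):

```julia
σ(x) = 1 / (1 + exp(-x))

x = 0.3

# `sincos` (from Base) returns the sine and cosine of `x` in one call.
sx, cx = sincos(x)
@assert sx ≈ sin(x) && cx ≈ cos(x)

# A single `exp(x)` is enough for both σ(x) and σ(-x).
ex = exp(x)
y = ex / (1 + ex)              # σ(x)
@assert y ≈ σ(x)
@assert 1 / (1 + ex) ≈ σ(-x)   # σ(-x) = 1 / (1 + exp(x))
```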

What if we introduced some variable called `intermediates` that is also recorded onto the tape during the primal pass?
We would need to be able to modify the primal pass to do this, so that we can actually put the data into the `intermediates`.
@@ -140,7 +140,7 @@ So that would look like:
```julia
y = f(x) # primal program
y, intermediates = augmented_primal(f, x)
x̄ = pullback_at(f, x, y, , intermediates)
x̄ = pullback_at(f, x, y, ȳ, intermediates)
```

```@raw html
@@ -152,7 +152,7 @@ function augmented_primal(::typeof(sin), x)
return y, (; cx=cx) # use a NamedTuple for the intermediates
end

pullback_at(::typeof(sin), x, y, , intermediates) = * intermediates.cx
pullback_at(::typeof(sin), x, y, ȳ, intermediates) = ȳ * intermediates.cx
```
```@raw html
</details>
@@ -168,7 +168,7 @@ function augmented_primal(::typeof(σ), x)
return y, (; ex=ex) # use a NamedTuple for the intermediates
end

pullback_at(::typeof(σ), x, y, , intermediates) = * y / (1 + intermediates.ex)
pullback_at(::typeof(σ), x, y, ȳ, intermediates) = ȳ * y / (1 + intermediates.ex)
```
```@raw html
</details>
@@ -202,7 +202,7 @@ So changing our API we have:
```julia
y = f(x) # primal program
y, pb = augmented_primal(f, x)
x̄ = pullback_at(pb, )
x̄ = pullback_at(pb, ȳ)
```
which is much cleaner.
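
The definition of `PullbackMemory` is not shown in this excerpt. The following is a minimal sketch of what such a type could look like, assuming it stores the primal function (used for dispatch) plus a `NamedTuple` of remembered values. The field names `primal_function` and `state`, and the `getproperty` forwarding, are assumptions made for illustration rather than the document's confirmed definition.

```julia
# Assumed layout: the primal function (for dispatch) and the remembered values.
struct PullbackMemory{P,S}
    primal_function::P
    state::S
end

# Convenience constructor: collect the keyword arguments into a NamedTuple.
PullbackMemory(primal_function; state...) = PullbackMemory(primal_function, (; state...))

# Forward `pb.name` to the stored state, so that `pb.cx` reads `state.cx`.
Base.getproperty(pb::PullbackMemory, name::Symbol) = getproperty(getfield(pb, :state), name)

function augmented_primal(::typeof(sin), x)
    y, cx = sincos(x)
    return y, PullbackMemory(sin; cx=cx)
end

pullback_at(pb::PullbackMemory{typeof(sin)}, ȳ) = ȳ * pb.cx

y, pb = augmented_primal(sin, 1.0)   # primal pass
x̄ = pullback_at(pb, 2.0)             # gradient pass with ȳ = 2.0
@assert y ≈ sin(1.0)
@assert x̄ ≈ 2.0 * cos(1.0)
```

Parameterizing the type on the primal function is what lets `pullback_at` dispatch on `PullbackMemory{typeof(sin)}` without receiving `f`, `x` or `y` as separate arguments.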

@@ -215,7 +215,7 @@ function augmented_primal(::typeof(sin), x)
return y, PullbackMemory(sin; cx=cx)
end

pullback_at(pb::PullbackMemory{typeof(sin)}, ) = * pb.cx
pullback_at(pb::PullbackMemory{typeof(sin)}, ȳ) = ȳ * pb.cx
```
```@raw html
</details>
@@ -231,7 +231,7 @@ function augmented_primal(::typeof(σ), x)
return y, PullbackMemory(σ; y=y, ex=ex)
end

pullback_at(pb::PullbackMemory{typeof(σ)}, ) = * pb.y / (1 + pb.ex)
pullback_at(pb::PullbackMemory{typeof(σ)}, ȳ) = ȳ * pb.y / (1 + pb.ex)
```
```@raw html
</details>
@@ -242,13 +242,13 @@ That now looks much simpler; `pullback_at` only ever has 2 arguments.
One way we could make it nicer to use is by making `PullbackMemory` a callable object.
Conceptually, the `PullbackMemory` is a fixed thing: it is the contents of the tape for a particular operation.
It is fully determined by the end of the primal pass.
Then, during the gradient (reverse) pass, the `PullbackMemory` is used to successively compute the `` argument.
Then, during the gradient (reverse) pass, the `PullbackMemory` is used to successively compute the `ȳ` argument.
So it makes sense to make `PullbackMemory` a callable object that acts on the sensitivity.
We can do that via call overloading:
```julia
y = f(x) # primal program
y, pb = augmented_primal(f, x)
x̄ = pb()
x̄ = pb(ȳ)
```
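
Call overloading in Julia means defining a method on instances of a type, written `(pb::T)(args...)`. Below is a self-contained sketch for `σ`, assuming a `PullbackMemory` type that stores the primal function and a `NamedTuple` of remembered values. The struct definition is an illustrative assumption, since this excerpt does not show the original one.

```julia
# Assumed definition of the memory type (not shown in the excerpt).
struct PullbackMemory{P,S}
    primal_function::P
    state::S
end
PullbackMemory(primal_function; state...) = PullbackMemory(primal_function, (; state...))
Base.getproperty(pb::PullbackMemory, name::Symbol) = getproperty(getfield(pb, :state), name)

σ(x) = 1 / (1 + exp(-x))

function augmented_primal(::typeof(σ), x)
    ex = exp(x)
    y = ex / (1 + ex)
    return y, PullbackMemory(σ; y=y, ex=ex)
end

# Make instances callable: `pb(ȳ)` now performs the pullback.
(pb::PullbackMemory{typeof(σ)})(ȳ) = ȳ * pb.y / (1 + pb.ex)

y, pb = augmented_primal(σ, 0.0)
x̄ = pb(1.0)
@assert y ≈ 0.5
@assert x̄ ≈ 0.25   # σ'(0) = σ(0) * σ(-0) = 0.5 * 0.5
```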

```@raw html
@@ -259,7 +259,7 @@ function augmented_primal(::typeof(sin), x)
y, cx = sincos(x)
return y, PullbackMemory(sin; cx=cx)
end
(pb::PullbackMemory)(ȳ) = * pb.cx
(pb::PullbackMemory{typeof(sin)})(ȳ) = ȳ * pb.cx
Member Author

I just changed this

```

```@raw html
@@ -276,14 +276,14 @@ function augmented_primal(::typeof(σ), x)
return y, PullbackMemory(σ; y=y, ex=ex)
end

(pb::PullbackMemory{typeof(σ)})() = * pb.y / (1 + pb.ex)
(pb::PullbackMemory{typeof(σ)})(ȳ) = ȳ * pb.y / (1 + pb.ex)
```
```@raw html
</details>
```

Let's recap what we have done here.
We now have an object `pb` that acts on the cotangent of the output of the primal `` to give us the cotangent of the input of the primal function `x̄`.
We now have an object `pb` that acts on the cotangent of the output of the primal `ȳ` to give us the cotangent of the input of the primal function `x̄`.
_`pb` is not just the **memory** of state required for the `pullback`, it **is** the pullback._

We have one final thing to do, which is to think about how we make the code easy to modify.
@@ -298,15 +298,15 @@ function augmented_primal(::typeof(sin), x)
y = sin(x)
return y, PullbackMemory(sin; x=x)
end
(pb::PullbackMemory)() = * cos(pb.x)
(pb::PullbackMemory)(ȳ) = ȳ * cos(pb.x)
```
To go from that to:
```julia
function augmented_primal(::typeof(sin), x)
y, cx = sincos(x)
return y, PullbackMemory(sin; cx=cx)
end
(pb::PullbackMemory)() = * pb.cx
(pb::PullbackMemory)(ȳ) = ȳ * pb.cx
```
```@raw html
</details>
@@ -320,7 +320,7 @@ function augmented_primal(::typeof(σ), x)
y = σ(x)
return y, PullbackMemory(σ; y=y, x=x)
end
(pb::PullbackMemory{typeof(σ)})() = * pb.y * σ(-pb.x)
(pb::PullbackMemory{typeof(σ)})(ȳ) = ȳ * pb.y * σ(-pb.x)
```
to get to:
```julia
@@ -329,7 +329,7 @@ function augmented_primal(::typeof(σ), x)
y = ex/(1 + ex)
return y, PullbackMemory(σ; y=y, ex=ex)
end
(pb::PullbackMemory{typeof(σ)})() = * pb.y/(1 + pb.ex)
(pb::PullbackMemory{typeof(σ)})(ȳ) = ȳ * pb.y/(1 + pb.ex)
```
```@raw html
</details>
@@ -344,7 +344,7 @@ We need to make a series of changes:
It's important these parts all stay in sync.
It's not too bad for this simple example with just one or two things to remember.
For more complicated multi-argument functions, which we will show below, you often end up needing to remember half a dozen things, like sizes and indices relating to each input/output, so it gets a little more fiddly to make sure you remember all the things you need to and give them the same name in both places.
_Is there a way we can automatically just have all the things we use remembered for us?_
_Is there a way we can automatically just have all the things we use remembered for us?_
Member Author

and added a space here

Surprisingly for such a specific request, there actually is: a closure.

A closure in Julia is a callable structure that automatically contains a field for every object from its parent scope that is used in its body.
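
This can be observed directly: the captured variables show up as fields of the closure's compiler-generated type and can be inspected with `fieldnames`. The function name `make_pullback` below is hypothetical, and relying on the exact field layout is best treated as an implementation detail, so this is a sketch for building intuition rather than a pattern to depend on.

```julia
# Hypothetical helper: returns a closure that captures `cx` from its parent scope.
function make_pullback(x)
    cx = cos(x)
    return ȳ -> cx * ȳ
end

pb = make_pullback(0.0)

@assert fieldnames(typeof(pb)) == (:cx,)  # one field per captured variable
@assert pb.cx == 1.0                      # the captured value, cos(0.0)
@assert pb(3.0) == 3.0                    # calling the closure uses that field
```

Note that `x` is not a field: only variables actually used in the closure body are captured.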
@@ -357,7 +357,7 @@ Replacing `PullbackMemory` with a closure that works the same way lets us avoid
```julia
function augmented_primal(::typeof(sin), x)
y, cx = sincos(x)
pb = -> cx * # pullback closure. closes over `cx`
pb = ȳ -> cx * ȳ # pullback closure. closes over `cx`
return y, pb
end
```
@@ -372,7 +372,7 @@ end
function augmented_primal(::typeof(σ), x)
ex = exp(x)
y = ex / (1 + ex)
pb = -> * y / (1 + ex) # pullback closure. closes over `y` and `ex`
pb = ȳ -> ȳ * y / (1 + ex) # pullback closure. closes over `y` and `ex`
return y, pb
end
```
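
Putting both closure-based rules together, a usage sketch might look as follows. The definitions are repeated so that the block runs on its own, and the input values are arbitrary choices made for illustration.

```julia
σ(x) = 1 / (1 + exp(-x))

function augmented_primal(::typeof(sin), x)
    y, cx = sincos(x)
    pb = ȳ -> cx * ȳ  # pullback closure. closes over `cx`
    return y, pb
end

function augmented_primal(::typeof(σ), x)
    ex = exp(x)
    y = ex / (1 + ex)
    pb = ȳ -> ȳ * y / (1 + ex)  # pullback closure. closes over `y` and `ex`
    return y, pb
end

# Primal pass: compute the output and record the pullback.
y, pb = augmented_primal(σ, 0.0)
# Gradient pass: apply the pullback to the output sensitivity ȳ = 1.0.
x̄ = pb(1.0)
@assert (y, x̄) == (0.5, 0.25)

y, pb = augmented_primal(sin, 0.0)
@assert (y, pb(2.0)) == (0.0, 2.0)
```

No memory type or field list has to be kept in sync here; whatever the closure body mentions is remembered automatically.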