fix some typos in docs #310
Merged
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -25,7 +25,7 @@ We will call this function `pullback_at`, as it pulls back the sensitivity at a | |
| To make this concrete: | ||
| ```julia | ||
| y = f(x) # primal program | ||
| x̄ = pullback_at(f, x, y, ȳ) | ||
| x̄ = pullback_at(f, x, y, ȳ) | ||
| ``` | ||
| Let's illustrate this with examples for `sin` and for the [logistic sigmoid](https://en.wikipedia.org/wiki/Logistic_function#Derivative). | ||
|
|
||
|
|
@@ -34,9 +34,9 @@ Let's illustrate this with examples for `sin` and for the [logistic sigmoid](htt | |
| ``` | ||
| ```julia | ||
| y = sin(x) | ||
| pullback_at(::typeof(sin), x, y, ȳ) = ȳ * cos(x) | ||
| pullback_at(::typeof(sin), x, y, ȳ) = ȳ * cos(x) | ||
| ``` | ||
| `pullback_at` uses the primal input `x`, and the sensitivity being pulled back `ȳ`. | ||
| `pullback_at` uses the primal input `x`, and the sensitivity being pulled back `ȳ`. | ||
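One way to sanity-check the `sin` rule is to compare it against a central finite difference. This is only a minimal sketch: the helper `finite_difference`, the step size, and the test point are hypothetical and are not part of the documentation being changed.

```julia
# Rule from the text: pull the sensitivity ȳ back through `sin`.
pullback_at(::typeof(sin), x, y, ȳ) = ȳ * cos(x)

# Hypothetical helper (an assumption for illustration): central finite difference.
finite_difference(f, x; h=1e-6) = (f(x + h) - f(x - h)) / (2h)

x = 0.7
y = sin(x)                      # primal program
ȳ = 1.0                         # seed sensitivity, so x̄ is just dy/dx
x̄ = pullback_at(sin, x, y, ȳ)   # gradient pass
```

With a seed of `ȳ = 1.0`, `x̄` should agree with `finite_difference(sin, x)` up to the truncation and rounding error of the difference quotient.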
|
|
||
| ```@raw html | ||
| </details> | ||
|
|
@@ -48,7 +48,7 @@ pullback_at(::typeof(sin), x, y, ȳ) = ȳ * cos(x) | |
| ```julia | ||
| σ(x) = 1/(1 + exp(-x)) # = exp(x) / (1 + exp(x)) | ||
| y = σ(x) | ||
| pullback_at(::typeof(σ), x, y, ȳ) = ȳ * y * σ(-x) # = ȳ * σ(x) * σ(-x) | ||
| pullback_at(::typeof(σ), x, y, ȳ) = ȳ * y * σ(-x) # = ȳ * σ(x) * σ(-x) | ||
| ``` | ||
| Notice that in `pullback_at` we are not only using input `x` but also using the primal output `y`. | ||
| This is a nice bit of symmetry that shows up around `exp`. | ||
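The reuse of `y` can be pushed one step further, since `σ(-x) = 1 - σ(x)`. The sketch below checks that the rule from the text matches the output-only form `y * (1 - y)`. The variable name `x̄_from_output` and the test point are assumptions for illustration.

```julia
σ(x) = 1 / (1 + exp(-x))

# Rule from the text: uses both the primal input x and the primal output y.
pullback_at(::typeof(σ), x, y, ȳ) = ȳ * y * σ(-x)

x = 0.3
y = σ(x)                        # primal program
x̄ = pullback_at(σ, x, y, 1.0)   # gradient pass with seed ȳ = 1.0

# The same derivative written only in terms of the primal output,
# using the identity σ(-x) = 1 - σ(x).
x̄_from_output = y * (1 - y)
```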
|
|
@@ -130,7 +130,7 @@ So we are talking about a 30-40% speed-up from these optimizations.[^4] | |
| It is faster to compute `sin` and `cos` at the same time via `sincos` than it is to compute them one after the other. | ||
| And it is faster to reuse the `exp(x)` in computing `σ(x)` and `σ(-x)`. | ||
| How can we incorporate this insight into our system? | ||
| We know we can compute both of these in the primal — because they only depend on `x` and not on `ȳ` — but there is nowhere to put them that is accessible both to the primal pass and the gradient pass code. | ||
| We know we can compute both of these in the primal — because they only depend on `x` and not on `ȳ` — but there is nowhere to put them that is accessible both to the primal pass and the gradient pass code. | ||
|
|
||
| What if we introduced some variable called `intermediates` that is also recorded onto the tape during the primal pass? | ||
| We would need to be able to modify the primal pass to do this, so that we can actually put the data into the `intermediates`. | ||
|
|
@@ -140,7 +140,7 @@ So that would look like: | |
| ```julia | ||
| y = f(x) # primal program | ||
| y, intermediates = augmented_primal(f, x) | ||
| x̄ = pullback_at(f, x, y, ȳ, intermediates) | ||
| x̄ = pullback_at(f, x, y, ȳ, intermediates) | ||
| ``` | ||
|
|
||
| ```@raw html | ||
|
|
@@ -152,7 +152,7 @@ function augmented_primal(::typeof(sin), x) | |
| return y, (; cx=cx) # use a NamedTuple for the intermediates | ||
| end | ||
|
|
||
| pullback_at(::typeof(sin), x, y, ȳ, intermediates) = ȳ * intermediates.cx | ||
| pullback_at(::typeof(sin), x, y, ȳ, intermediates) = ȳ * intermediates.cx | ||
| ``` | ||
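Put together, the `intermediates` version for `sin` can be run end to end roughly as follows. This is a sketch assembled from the fragments shown in this diff. The input value and the seed `2.0` are arbitrary choices, not values from the docs.

```julia
function augmented_primal(::typeof(sin), x)
    y, cx = sincos(x)        # one call computes both sin(x) and cos(x)
    return y, (; cx=cx)      # NamedTuple of intermediates for the gradient pass
end

pullback_at(::typeof(sin), x, y, ȳ, intermediates) = ȳ * intermediates.cx

x = 1.2
y, intermediates = augmented_primal(sin, x)       # primal pass
x̄ = pullback_at(sin, x, y, 2.0, intermediates)    # gradient pass with ȳ = 2.0
```

Note that the gradient pass never calls `cos`: it only reads the value recorded during the primal pass.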
| ```@raw html | ||
| </details> | ||
|
|
@@ -168,7 +168,7 @@ function augmented_primal(::typeof(σ), x) | |
| return y, (; ex=ex) # use a NamedTuple for the intermediates | ||
| end | ||
|
|
||
| pullback_at(::typeof(σ), x, y, ȳ, intermediates) = ȳ * y / (1 + intermediates.ex) | ||
| pullback_at(::typeof(σ), x, y, ȳ, intermediates) = ȳ * y / (1 + intermediates.ex) | ||
| ``` | ||
| ```@raw html | ||
| </details> | ||
|
|
@@ -202,7 +202,7 @@ So changing our API we have: | |
| ```julia | ||
| y = f(x) # primal program | ||
| y, pb = augmented_primal(f, x) | ||
| x̄ = pullback_at(pb, ȳ) | ||
| x̄ = pullback_at(pb, ȳ) | ||
| ``` | ||
| which is much cleaner. | ||
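The definition of `PullbackMemory` falls outside the hunks of this diff. One way such a type could be written is sketched below. This is a hypothetical sketch, not necessarily the definition used in the docs: the field names `primal_function` and `state` and the `getproperty` forwarding are assumptions.

```julia
# Hypothetical sketch of a `PullbackMemory` type (field names are assumptions).
struct PullbackMemory{P,S}
    primal_function::P   # e.g. `sin`; its type is what `pullback_at` dispatches on
    state::S             # NamedTuple of values remembered from the primal pass
end

# Convenience constructor so that `PullbackMemory(sin; cx=cx)` works.
# `values(state)` extracts the NamedTuple behind the keyword arguments.
PullbackMemory(primal_function; state...) = PullbackMemory(primal_function, values(state))

# Forward `pb.cx` to the remembered state.
Base.getproperty(pb::PullbackMemory, name::Symbol) = getproperty(getfield(pb, :state), name)

function augmented_primal(::typeof(sin), x)
    y, cx = sincos(x)
    return y, PullbackMemory(sin; cx=cx)
end

pullback_at(pb::PullbackMemory{typeof(sin)}, ȳ) = ȳ * pb.cx

y, pb = augmented_primal(sin, 0.5)   # primal pass
x̄ = pullback_at(pb, 1.0)             # gradient pass
```

Parameterising on the type of the primal function is what lets `pullback_at` dispatch on `PullbackMemory{typeof(sin)}` without taking `f` as a separate argument.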
|
|
||
|
|
@@ -215,7 +215,7 @@ function augmented_primal(::typeof(sin), x) | |
| return y, PullbackMemory(sin; cx=cx) | ||
| end | ||
|
|
||
| pullback_at(pb::PullbackMemory{typeof(sin)}, ȳ) = ȳ * pb.cx | ||
| pullback_at(pb::PullbackMemory{typeof(sin)}, ȳ) = ȳ * pb.cx | ||
| ``` | ||
| ```@raw html | ||
| </details> | ||
|
|
@@ -231,7 +231,7 @@ function augmented_primal(::typeof(σ), x) | |
| return y, PullbackMemory(σ; y=y, ex=ex) | ||
| end | ||
|
|
||
| pullback_at(pb::PullbackMemory{typeof(σ)}, ȳ) = ȳ * pb.y / (1 + pb.ex) | ||
| pullback_at(pb::PullbackMemory{typeof(σ)}, ȳ) = ȳ * pb.y / (1 + pb.ex) | ||
| ``` | ||
| ```@raw html | ||
| </details> | ||
|
|
@@ -242,13 +242,13 @@ That now looks much simpler; `pullback_at` only ever has 2 arguments. | |
| One way we could make it nicer to use is by making `PullbackMemory` a callable object. | ||
| Conceptually the `PullbackMemory` is a fixed thing: it is the contents of the tape for a particular operation. | ||
| It is fully determined by the end of the primal pass. | ||
| Then during the gradient (reverse) pass the `PullbackMemory` is used to successively compute the `ȳ` argument. | ||
| Then during the gradient (reverse) pass the `PullbackMemory` is used to successively compute the `ȳ` argument. | ||
| So it makes sense to make `PullbackMemory` a callable object that acts on the sensitivity. | ||
| We can do that via call overloading: | ||
| ```julia | ||
| y = f(x) # primal program | ||
| y, pb = augmented_primal(f, x) | ||
| x̄ = pb(ȳ) | ||
| x̄ = pb(ȳ) | ||
| ``` | ||
|
|
||
| ```@raw html | ||
|
|
@@ -259,7 +259,7 @@ function augmented_primal(::typeof(sin), x) | |
| y, cx = sincos(x) | ||
| return y, PullbackMemory(sin; cx=cx) | ||
| end | ||
| (pb::PullbackMemory)(ȳ) = ȳ * pb.cx | ||
| (pb::PullbackMemory{typeof(sin)})(ȳ) = ȳ * pb.cx | ||
| ``` | ||
|
|
||
| ```@raw html | ||
|
|
@@ -276,14 +276,14 @@ function augmented_primal(::typeof(σ), x) | |
| return y, PullbackMemory(σ; y=y, ex=ex) | ||
| end | ||
|
|
||
| (pb::PullbackMemory{typeof(σ)})(ȳ) = ȳ * pb.y / (1 + pb.ex) | ||
| (pb::PullbackMemory{typeof(σ)})(ȳ) = ȳ * pb.y / (1 + pb.ex) | ||
| ``` | ||
| ```@raw html | ||
| </details> | ||
| ``` | ||
|
|
||
| Let's recap what we have done here. | ||
| We now have an object `pb` that acts on the cotangent of the output of the primal `ȳ` to give us the cotangent of the input of the primal function `x̄`. | ||
| We now have an object `pb` that acts on the cotangent of the output of the primal `ȳ` to give us the cotangent of the input of the primal function `x̄`. | ||
| _`pb` is not just the **memory** of state required for the `pullback`, it **is** the pullback._ | ||
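The callable-object form only works for several primal functions at once if each call overload is restricted by the type parameter, as in `PullbackMemory{typeof(sin)}`. An unparameterised `(pb::PullbackMemory)(ȳ)` method would apply to every `PullbackMemory`. The following sketch shows `sin` and `σ` side by side. The `PullbackMemory` definition here is a hypothetical stand-in, repeated only so the block is self-contained.

```julia
# Hypothetical minimal `PullbackMemory` (an assumption, not the docs' definition).
struct PullbackMemory{P,S}
    primal_function::P
    state::S
end
PullbackMemory(f; state...) = PullbackMemory(f, values(state))
Base.getproperty(pb::PullbackMemory, name::Symbol) = getproperty(getfield(pb, :state), name)

σ(x) = 1 / (1 + exp(-x))

function augmented_primal(::typeof(sin), x)
    y, cx = sincos(x)
    return y, PullbackMemory(sin; cx=cx)
end

function augmented_primal(::typeof(σ), x)
    ex = exp(x)
    y = ex / (1 + ex)
    return y, PullbackMemory(σ; y=y, ex=ex)
end

# Call overloading, with one method per primal function type.
(pb::PullbackMemory{typeof(sin)})(ȳ) = ȳ * pb.cx
(pb::PullbackMemory{typeof(σ)})(ȳ) = ȳ * pb.y / (1 + pb.ex)

y_sin, pb_sin = augmented_primal(sin, 0.5)
y_σ, pb_σ = augmented_primal(σ, 0.5)
x̄_sin = pb_sin(1.0)   # dispatches to the `sin` method
x̄_σ = pb_σ(1.0)       # dispatches to the `σ` method
```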
|
|
||
| We have one final thing to do, which is to think about how we make the code easy to modify. | ||
|
|
@@ -298,15 +298,15 @@ function augmented_primal(::typeof(sin), x) | |
| y = sin(x) | ||
| return y, PullbackMemory(sin; x=x) | ||
| end | ||
| (pb::PullbackMemory)(ȳ) = ȳ * cos(pb.x) | ||
| (pb::PullbackMemory)(ȳ) = ȳ * cos(pb.x) | ||
| ``` | ||
| To go from that to: | ||
| ```julia | ||
| function augmented_primal(::typeof(sin), x) | ||
| y, cx = sincos(x) | ||
| return y, PullbackMemory(sin; cx=cx) | ||
| end | ||
| (pb::PullbackMemory)(ȳ) = ȳ * pb.cx | ||
| (pb::PullbackMemory)(ȳ) = ȳ * pb.cx | ||
| ``` | ||
| ```@raw html | ||
| </details> | ||
|
|
@@ -320,7 +320,7 @@ function augmented_primal(::typeof(σ), x) | |
| y = σ(x) | ||
| return y, PullbackMemory(σ; y=y, x=x) | ||
| end | ||
| (pb::PullbackMemory{typeof(σ)})(ȳ) = ȳ * pb.y * σ(-pb.x) | ||
| (pb::PullbackMemory{typeof(σ)})(ȳ) = ȳ * pb.y * σ(-pb.x) | ||
| ``` | ||
| to get to: | ||
| ```julia | ||
|
|
@@ -329,7 +329,7 @@ function augmented_primal(::typeof(σ), x) | |
| y = ex/(1 + ex) | ||
| return y, PullbackMemory(σ; y=y, ex=ex) | ||
| end | ||
| (pb::PullbackMemory{typeof(σ)})(ȳ) = ȳ * pb.y/(1 + pb.ex) | ||
| (pb::PullbackMemory{typeof(σ)})(ȳ) = ȳ * pb.y/(1 + pb.ex) | ||
| ``` | ||
| ```@raw html | ||
| </details> | ||
|
|
@@ -344,7 +344,7 @@ We need to make a series of changes: | |
| It's important these parts all stay in sync. | ||
| It's not too bad for this simple example with just one or two things to remember. | ||
| For more complicated multi-argument functions, which we will show below, you often end up needing to remember half a dozen things, like sizes and indices relating to each input/output, so it gets a little more fiddly to make sure you remember all the things you need to and give them the same name in both places. | ||
| _Is there a way we can automatically just have all the things we use remembered for us?_ | ||
| _Is there a way we can automatically just have all the things we use remembered for us?_ | ||
|
Member
Author
and added a space here
||
| Surprisingly for such a specific request, there actually is: a closure. | ||
|
|
||
| A closure in Julia is a callable structure that automatically contains a field for every object from its parent scope that is used in its body. | ||
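The "field for every captured object" can be observed directly. The sketch below inspects the closure returned for `sin`. It relies on how Julia lowers closures to structs, so the field inspection is for illustration only and not something to depend on in real code. The name `captured` is an assumption.

```julia
function augmented_primal(::typeof(sin), x)
    y, cx = sincos(x)
    pb = ȳ -> cx * ȳ     # pullback closure; uses only `cx` from the parent scope
    return y, pb
end

y, pb = augmented_primal(sin, 0.5)

# Names of the variables the closure captured (implementation detail of closures).
captured = fieldnames(typeof(pb))

x̄ = pb(1.0)              # the closure is called like any other function
```

Only `cx` is used in the closure body, so `x` and `y` are not stored in it.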
|
|
@@ -357,7 +357,7 @@ Replacing `PullbackMemory` with a closure that works the same way lets us avoid | |
| ```julia | ||
| function augmented_primal(::typeof(sin), x) | ||
| y, cx = sincos(x) | ||
| pb = ȳ -> cx * ȳ # pullback closure. closes over `cx` | ||
| pb = ȳ -> cx * ȳ # pullback closure. closes over `cx` | ||
| return y, pb | ||
| end | ||
| ``` | ||
|
|
@@ -372,7 +372,7 @@ end | |
| function augmented_primal(::typeof(σ), x) | ||
| ex = exp(x) | ||
| y = ex / (1 + ex) | ||
| pb = ȳ -> ȳ * y / (1 + ex) # pullback closure. closes over `y` and `ex` | ||
| pb = ȳ -> ȳ * y / (1 + ex) # pullback closure. closes over `y` and `ex` | ||
| return y, pb | ||
| end | ||
| ``` | ||
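Because each closure is itself the pullback, a reverse pass over a composed program is just the closures applied in reverse order. The sketch below differentiates `sin(σ(x))` by hand. The composition and the variable names `y1`, `y2`, `pb1`, `pb2` are assumptions for illustration, and no tape machinery is shown.

```julia
σ(x) = 1 / (1 + exp(-x))

function augmented_primal(::typeof(sin), x)
    y, cx = sincos(x)
    pb = ȳ -> cx * ȳ               # closes over `cx`
    return y, pb
end

function augmented_primal(::typeof(σ), x)
    ex = exp(x)
    y = ex / (1 + ex)
    pb = ȳ -> ȳ * y / (1 + ex)     # closes over `y` and `ex`
    return y, pb
end

x = 0.4
# Primal pass: run the program forwards, keeping each pullback.
y1, pb1 = augmented_primal(σ, x)
y2, pb2 = augmented_primal(sin, y1)

# Gradient pass: apply the pullbacks in reverse order, seeded with ȳ = 1.0.
x̄ = pb1(pb2(1.0))
```

By the chain rule the result should equal `cos(σ(x)) * σ(x) * (1 - σ(x))`.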
|
|
||
I just changed this