From 6767d67c74ee887e4e72fb209b647b78a35daa93 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Wed, 17 Mar 2021 11:59:33 +0000 Subject: [PATCH] fix some typos in docs --- docs/src/design/changing_the_primal.md | 46 +++++++++++++------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/docs/src/design/changing_the_primal.md b/docs/src/design/changing_the_primal.md index 5ffa1e2b1..9d875329f 100644 --- a/docs/src/design/changing_the_primal.md +++ b/docs/src/design/changing_the_primal.md @@ -25,7 +25,7 @@ We will call this function `pullback_at`, as it pulls back the sensitivity at a To make this concrete: ```julia y = f(x) # primal program -x̄ = pullback_at(f, x, y, ȳ) +x̄ = pullback_at(f, x, y, ȳ) ``` Let's illustrate this with examples for `sin` and for the [logistic sigmoid](https://en.wikipedia.org/wiki/Logistic_function#Derivative). @@ -34,9 +34,9 @@ Let's illustrate this with examples for `sin` and for the [logistic sigmoid](htt ``` ```julia y = sin(x) -pullback_at(::typeof(sin), x, y, ȳ) = ȳ * cos(x) +pullback_at(::typeof(sin), x, y, ȳ) = ȳ * cos(x) ``` -`pullback_at` uses the primal input `x`, and the sensitivity being pulled back `ȳ`. +`pullback_at` uses the primal input `x`, and the sensitivity being pulled back `ȳ`. ```@raw html @@ -48,7 +48,7 @@ pullback_at(::typeof(sin), x, y, ȳ) = ȳ * cos(x) ```julia σ(x) = 1/(1 + exp(-x)) # = exp(x) / (1 + exp(x)) y = σ(x) -pullback_at(::typeof(σ), x, y, ȳ) = ȳ * y * σ(-x) # = ȳ * σ(x) * σ(-x) +pullback_at(::typeof(σ), x, y, ȳ) = ȳ * y * σ(-x) # = ȳ * σ(x) * σ(-x) ``` Notice that in `pullback_at` we are not only using input `x` but also using the primal output `y` . This is a nice bit of symmetry that shows up around `exp`. @@ -130,7 +130,7 @@ So we are talking about a 30-40% speed-up from these optimizations.[^4] It is faster to compute `sin` and `cos` at the same time via `sincos` than it is to compute them one after the other. And it is faster to reuse the `exp(x)` in computing `σ(x)` and `σ(-x)`. How can we incorporate this insight into our system? -We know we can compute both of these in the primal — because they only depend on `x` and not on `ȳ` — but there is nowhere to put them that is accessible both to the primal pass and the gradient pass code. +We know we can compute both of these in the primal — because they only depend on `x` and not on `ȳ` — but there is nowhere to put them that is accessible both to the primal pass and the gradient pass code. What if we introduced some variable called `intermediates` that is also recorded onto the tape during the primal pass? We would need to be able to modify the primal pass to do this, so that we can actually put the data into the `intermediates`. @@ -140,7 +140,7 @@ So that would look like: ```julia y = f(x) # primal program y, intermediates = augmented_primal(f, x) -x̄ = pullback_at(f, x, y, ȳ, intermediates) +x̄ = pullback_at(f, x, y, ȳ, intermediates) ``` ```@raw html @@ -152,7 +152,7 @@ function augmented_primal(::typeof(sin), x) return y, (; cx=cx) # use a NamedTuple for the intermediates end -pullback_at(::typeof(sin), x, y, ȳ, intermediates) = ȳ * intermediates.cx +pullback_at(::typeof(sin), x, y, ȳ, intermediates) = ȳ * intermediates.cx ``` ```@raw html @@ -168,7 +168,7 @@ function augmented_primal(::typeof(σ), x) return y, (; ex=ex) # use a NamedTuple for the intermediates end -pullback_at(::typeof(σ), x, y, ȳ, intermediates) = ȳ * y / (1 + intermediates.ex) +pullback_at(::typeof(σ), x, y, ȳ, intermediates) = ȳ * y / (1 + intermediates.ex) ``` ```@raw html @@ -202,7 +202,7 @@ So changing our API we have: ```julia y = f(x) # primal program y, pb = augmented_primal(f, x) -x̄ = pullback_at(pb, ȳ) +x̄ = pullback_at(pb, ȳ) ``` which is much cleaner. @@ -215,7 +215,7 @@ function augmented_primal(::typeof(sin), x) return y, PullbackMemory(sin; cx=cx) end -pullback_at(pb::PullbackMemory{typeof(sin)}, ȳ) = ȳ * pb.cx +pullback_at(pb::PullbackMemory{typeof(sin)}, ȳ) = ȳ * pb.cx ``` ```@raw html @@ -231,7 +231,7 @@ function augmented_primal(::typeof(σ), x) return y, PullbackMemory(σ; y=y, ex=ex) end -pullback_at(pb::PullbackMemory{typeof(σ)}, ȳ) = ȳ * pb.y / (1 + pb.ex) +pullback_at(pb::PullbackMemory{typeof(σ)}, ȳ) = ȳ * pb.y / (1 + pb.ex) ``` ```@raw html @@ -242,13 +242,13 @@ That now looks much simpler; `pullback_at` only ever has 2 arguments. One way we could make it nicer to use is by making `PullbackMemory` a callable object. Conceptually the `PullbackMemory` is a fixed thing it the contents of the tape for a particular operation. It is fully determined by the end of the primal pass. -The during the gradient (reverse) pass the `PullbackMemory` is used to successively compute the `ȳ` argument. +The during the gradient (reverse) pass the `PullbackMemory` is used to successively compute the `ȳ` argument. So it makes sense to make `PullbackMemory` a callable object that acts on the sensitivity. We can do that via call overloading: ```julia y = f(x) # primal program y, pb = augmented_primal(f, x) -x̄ = pb(ȳ) +x̄ = pb(ȳ) ``` ```@raw html @@ -259,7 +259,7 @@ function augmented_primal(::typeof(sin), x) y, cx = sincos(x) return y, PullbackMemory(sin; cx=cx) end -(pb::PullbackMemory)(ȳ) = ȳ * pb.cx +(pb::PullbackMemory{typeof(sin)})(ȳ) = ȳ * pb.cx ``` ```@raw html @@ -276,14 +276,14 @@ function augmented_primal(::typeof(σ), x) return y, PullbackMemory(σ; y=y, ex=ex) end -(pb::PullbackMemory{typeof(σ)})(ȳ) = ȳ * pb.y / (1 + pb.ex) +(pb::PullbackMemory{typeof(σ)})(ȳ) = ȳ * pb.y / (1 + pb.ex) ``` ```@raw html ``` Let's recap what we have done here. -We now have an object `pb` that acts on the cotangent of the output of the primal `ȳ` to give us the cotangent of the input of the primal function `x̄`. +We now have an object `pb` that acts on the cotangent of the output of the primal `ȳ` to give us the cotangent of the input of the primal function `x̄`. _`pb` is not just the **memory** of state required for the `pullback`, it **is** the pullback._ We have one final thing to do, which is to think about how we make the code easy to modify. @@ -298,7 +298,7 @@ function augmented_primal(::typeof(sin), x) y = sin(x) return y, PullbackMemory(sin; x=x) end -(pb::PullbackMemory)(ȳ) = ȳ * cos(pb.x) +(pb::PullbackMemory)(ȳ) = ȳ * cos(pb.x) ``` To go from that to: ```julia @@ -306,7 +306,7 @@ function augmented_primal(::typeof(sin), x) y, cx = sincos(x) return y, PullbackMemory(sin; cx=cx) end -(pb::PullbackMemory)(ȳ) = ȳ * pb.cx +(pb::PullbackMemory)(ȳ) = ȳ * pb.cx ``` ```@raw html @@ -320,7 +320,7 @@ function augmented_primal(::typeof(σ), x) y = σ(x) return y, PullbackMemory(σ; y=y, x=x) end -(pb::PullbackMemory{typeof(σ)})(ȳ) = ȳ * pb.y * σ(-pb.x) +(pb::PullbackMemory{typeof(σ)})(ȳ) = ȳ * pb.y * σ(-pb.x) ``` to get to: ```julia @@ -329,7 +329,7 @@ function augmented_primal(::typeof(σ), x) y = ex/(1 + ex) return y, PullbackMemory(σ; y=y, ex=ex) end -(pb::PullbackMemory{typeof(σ)})(ȳ) = ȳ * pb.y/(1 + pb.ex) +(pb::PullbackMemory{typeof(σ)})(ȳ) = ȳ * pb.y/(1 + pb.ex) ``` ```@raw html @@ -344,7 +344,7 @@ We need to make a series of changes: It's important these parts all stay in sync. It's not too bad for this simple example with just one or two things to remember. For more complicated multi-argument functions, which we will show below, you often end up needing to remember half a dozen things, like sizes and indices relating to each input/output, so it gets a little more fiddly to make sure you remember all the things you need to and give them the same name in both places. -_Is there a way we can automatically just have all the things we use remembered for us?_ +_Is there a way we can automatically just have all the things we use remembered for us?_ Surprisingly for such a specific request, there actually is: a closure. A closure in Julia is a callable structure that automatically contains a field for every object from its parent scope that is used in its body. @@ -357,7 +357,7 @@ Replacing `PullbackMemory` with a closure that works the same way lets us avoid ```julia function augmented_primal(::typeof(sin), x) y, cx = sincos(x) - pb = ȳ -> cx * ȳ # pullback closure. closes over `cx` + pb = ȳ -> cx * ȳ # pullback closure. closes over `cx` return y, pb end ``` @@ -372,7 +372,7 @@ end function augmented_primal(::typeof(σ), x) ex = exp(x) y = ex / (1 + ex) - pb = ȳ -> ȳ * y / (1 + ex) # pullback closure. closes over `y` and `ex` + pb = ȳ -> ȳ * y / (1 + ex) # pullback closure. closes over `y` and `ex` return y, pb end ```