### What

In this notebook, we build a simple prototype of source-to-source reverse-mode automatic differentiation. This is the same strategy used by `Zygote.jl`.

We opt to implement **reverse mode** automatic differentiation because it efficiently computes the gradient of functions with a (very) large number of inputs and a single output: one backward pass yields the derivative with respect to every input. This kind of problem shows up in all kinds of optimization, most notably machine learning.

We also choose to do **source-to-source** transformation. This is substantially more complex to implement than dynamic (tape-based) systems, which build up a computation graph on every program run and then differentiate through it. The main benefit is that our gradient code is compiled once and optimized by LLVM like any other Julia code, instead of being (expensively) rebuilt and (poorly) optimized on every run.

This is also better than trace-based systems, which usually end up with a huge number of traced operations. With that many operations, optimization passes that need `O(N^2)` time are infeasible, so the trace can't be optimized effectively. (Source: the Zygote paper.)

### Version 1: Basics

We begin by building a program to automatically differentiate `simple_math`.

```julia
function simple_math(x, y)
    x * y + sin(x)
end
```
```
simple_math (generic function with 1 method)
```

Our automatic differentiator operates at the SSA level: it works on a simplified representation called static single-assignment (SSA) form, where each variable is assigned exactly once and control flow is reduced to simple branch/goto statements.

Julia performs this conversion [internally](https://docs.julialang.org/en/v1/devdocs/ssair/), but doesn't expose a convenient interface for transforming it, so we use the `IRTools.jl` library instead. This is the same library Zygote uses.

Converting `simple_math` to SSA form using `IRTools.jl`, we have:

```julia
using IRTools
IRTools.@code_ir simple_math(3, 5)
```
```
1: (%1, %2, %3)
  %4 = %2 * %3
  %5 = Main.sin(%2)
  %6 = %4 + %5
  return %6
```

For our first version, our algorithm is as follows:
* Find the return value (`%6` here), and set its adjoint to 1.
* Iterate through the program backwards and compute the adjoints `gN` for each `%N`.
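The two steps above can be sketched in plain Julia, without IRTools, by running the same bookkeeping directly on values instead of emitting new SSA statements. The `(assignee, op, args)` tuple encoding and the helpers `apply_op`, `local_partial`, and `reverse_gradient` are hypothetical stand-ins, not IRTools' IR or the notebook's `autodiff.jl`.

```julia
# Hypothetical encoding of simple_math's SSA form: (assignee, op, argument names).
program = [
    (:v4, :*,   [:v2, :v3]),
    (:v5, :sin, [:v2]),
    (:v6, :+,   [:v4, :v5]),
]

# Evaluate one primitive on concrete argument values.
function apply_op(op::Symbol, vals)
    op === :+   && return vals[1] + vals[2]
    op === :*   && return vals[1] * vals[2]
    op === :sin && return sin(vals[1])
    error("unknown op $op")
end

# Local partial derivative of `op` with respect to its i-th argument.
function local_partial(op::Symbol, vals, i)
    op === :+   && return one(vals[i])
    op === :*   && return i == 1 ? vals[2] : vals[1]
    op === :sin && return cos(vals[1])
    error("no derivative rule for $op")
end

function reverse_gradient(prog, inputs::Vector{Symbol}, xs)
    env = Dict{Symbol,Any}(zip(inputs, xs))
    for (dest, op, args) in prog                     # forward pass
        env[dest] = apply_op(op, [env[a] for a in args])
    end
    out = prog[end][1]                               # the returned variable
    adj = Dict{Symbol,Any}(out => one(env[out]))     # seed its adjoint with 1
    for (dest, op, args) in Iterators.reverse(prog)  # backward pass
        g = get(adj, dest, 0)
        vals = [env[a] for a in args]
        for (i, a) in enumerate(args)
            # accumulate: adjoint(arg) += adjoint(assignee) * local partial
            adj[a] = get(adj, a, 0) + g * local_partial(op, vals, i)
        end
    end
    return env[out], Tuple(get(adj, x, 0) for x in inputs)
end

val, grads = reverse_gradient(program, [:v2, :v3], [3, 5])
```

Unlike the notebook's transformation, this sketch interprets the program on every call rather than generating new code, so it is closer to the dynamic approach described earlier; it only illustrates the adjoint accumulation.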

For example, given `%6 = %4 + %5`, we want to convert it (using the sum rule for derivatives) into
```
g4 += g6
g5 += g6
```
which may also involve creating `g4` and `g5` (initialized to zero) if they don't exist yet.
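Applying the same rules by hand to all of `simple_math` gives the values we expect the transformed code to produce. Here $v_N$ denotes the value of `%N` and $g_N = \partial v_6 / \partial v_N$ its adjoint, with $v_2 = x$ and $v_3 = y$:

$$
\begin{aligned}
g_6 &= 1 \\
g_4 &= g_6 \cdot 1 = 1, \qquad g_5 = g_6 \cdot 1 = 1 \\
g_3 &= g_4 \, v_2 = x \\
g_2 &= g_4 \, v_3 + g_5 \cos(v_2) = y + \cos x
\end{aligned}
$$

At $(x, y) = (3, 5)$ this gives $g_2 = 5 + \cos 3 \approx 4.0100$ and $g_3 = 3$. Note that $g_2$ receives two contributions, because `%2` is used in two statements.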

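As we iterate through the statements, we also keep a map (`x_to_gx` in the code) from each variable to the SSA statement holding its most recent adjoint. Since SSA variables can't be reassigned, each `+=` above actually creates a new statement, and the map is updated to point at it.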
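The `x_to_gx` bookkeeping can be illustrated with a toy emitter that produces SSA text instead of IRTools statements. `Emitter`, `emit!`, and `accumulate_adjoint!` are hypothetical names for this sketch; unlike a `get(dict, key, default)` call, whose default is evaluated eagerly, it only emits a zero statement when a variable has no adjoint yet.

```julia
# Hypothetical toy emitter: SSA statements are plain strings, not IRTools objects.
mutable struct Emitter
    lines::Vector{String}
    counter::Int          # number to use for the next fresh %N
end

# Append "%N = rhs" and return the fresh name "%N".
function emit!(em::Emitter, rhs::AbstractString)
    name = "%$(em.counter)"
    push!(em.lines, "$name = $rhs")
    em.counter += 1
    return name
end

# SSA forbids `gx += c`, so accumulation emits a new statement and repoints x_to_gx.
function accumulate_adjoint!(em::Emitter, x_to_gx::Dict{String,String}, x, contribution)
    old = haskey(x_to_gx, x) ? x_to_gx[x] : emit!(em, "0")   # create a zero adjoint only if needed
    x_to_gx[x] = emit!(em, "$old + $contribution")
    return x_to_gx[x]
end

em = Emitter(String[], 7)                  # %1 to %6 belong to simple_math's forward pass
x_to_gx = Dict("%6" => emit!(em, "1"))     # seed: g6 = 1
# %6 = %4 + %5 (sum rule): both operands receive g6 unchanged
accumulate_adjoint!(em, x_to_gx, "%4", x_to_gx["%6"])
accumulate_adjoint!(em, x_to_gx, "%5", x_to_gx["%6"])
# %2 is used twice, so its adjoint is accumulated twice and the map is repointed
accumulate_adjoint!(em, x_to_gx, "%2", "$(x_to_gx["%5"]) * cos(%2)")
accumulate_adjoint!(em, x_to_gx, "%2", "$(x_to_gx["%4"]) * %3")

foreach(println, em.lines)
```

After the last call, `x_to_gx["%2"]` points at `%14`, the statement holding the fully accumulated adjoint of `%2`.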

```julia
include("autodiff.jl")
```
```
gradient_1 (generic function with 1 method)
```

If you want to see it explicitly, the output of this transformation on `simple_math` is
```
1: (%1, %2, %3)
  %4 = %2 * %3
  %5 = Main.sin(%2)
  %6 = %4 + %5
  %7 = 1
  %8 = 0
  %9 = 0
  %10 = %8 + %7
  %11 = %9 + %7
  %12 = 0
  %13 = cos(%2)
  %14 = %11 * %13
  %15 = %12 + %14
  %16 = 0
  %17 = 0
  %18 = %10 * %3
  %19 = %15 + %18
  %20 = %10 * %2
  %21 = %17 + %20
  %22 = Core.tuple(%6, %19, %21)
  return %22
```

```julia
gradient_1(simple_math, 3, 5)
```
```
(4.010007503399555, 3)
```

We check the result against Zygote:

```julia
using Zygote
Zygote.gradient(simple_math, 3, 5)
```
```
(4.010007503399555, 3.0)
```

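Hooray! This matches what Zygote reports for our function. The only difference is that our second component is the integer `3` rather than `3.0`, since our generated code seeds the output's adjoint with the integer literal `1` (`%7 = 1` above) and only ever multiplies it by the integer `%2`.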
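An independent sanity check that needs no AD library at all is a finite-difference approximation. `fd_gradient` is a hypothetical helper, and the step `h = 1e-6` is an arbitrary choice; central differences are only accurate up to truncation and rounding error.

```julia
simple_math(x, y) = x * y + sin(x)

# Central finite differences for a two-argument function.
function fd_gradient(f, x, y; h = 1e-6)
    dx = (f(x + h, y) - f(x - h, y)) / (2h)
    dy = (f(x, y + h) - f(x, y - h)) / (2h)
    return (dx, dy)
end

fd = fd_gradient(simple_math, 3.0, 5.0)   # approximately (4.0100075, 3.0)
```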

### Improvement 2: Generality & Extensibility

Here's what our addition rule currently looks like:


```julia
# Excerpt from the transformation loop: `ex`, `assignee`, `ir`, `fn`, and `x_to_gx`
# come from the surrounding code, so this isn't runnable on its own.
if fn == "Main.:+"
    # where is the adjoint of the assignee?
    g_assignee = x_to_gx[var(assignee)]

    # either get the variable that we've accumulated adjoints into so far,
    # or if it doesn't exist, initialize it to zero.
    # (`get` evaluates its default eagerly, so a `0` statement is pushed even when
    # the key exists; that's where unused lines like `%16 = 0` above come from.)
    g_arg1_old = get(x_to_gx, ex.args[2], IRTools.push!(ir, :(0)))
    g_arg2_old = get(x_to_gx, ex.args[3], IRTools.push!(ir, :(0)))

    # finally, add the new bit (g_assignee) and accumulate it into the new variable.
    g_arg1_new = IRTools.push!(ir, :($(g_arg1_old) + $(g_assignee)))
    g_arg2_new = IRTools.push!(ir, :($(g_arg2_old) + $(g_assignee)))

    x_to_gx[ex.args[2]] = g_arg1_new
    x_to_gx[ex.args[3]] = g_arg2_new
end
```

And we've repeated nearly the exact same thing for our multiplication rule:

```julia
if fn == "Main.:*"
    g_assignee = x_to_gx[var(assignee)]

    g_arg1_old = get(x_to_gx, ex.args[2], IRTools.push!(ir, :(0)))
    g_arg2_old = get(x_to_gx, ex.args[3], IRTools.push!(ir, :(0)))

    g_arg1_new = IRTools.push!(ir, :($(g_arg1_old) + $(g_assignee) * $(ex.args[3])))
    g_arg2_new = IRTools.push!(ir, :($(g_arg2_old) + $(g_assignee) * $(ex.args[2])))

    x_to_gx[ex.args[2]] = g_arg1_new
    x_to_gx[ex.args[3]] = g_arg2_new
end
```

Notice the similarity between the `g_arg1_new` lines: they're identical for `+` and `*` except for the factor that `g_assignee` is multiplied by (an implicit `1` for `+`).

We take advantage of this property (which, it turns out, holds in general: it's the chain rule) to express all of these rules more succinctly with the following snippet.

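```julia
# Excerpt from the transformation loop (not runnable on its own: `ex`, `assignee`,
# `ir`, and `x_to_gx` come from the surrounding code).
if ex.head == :call
    g_assignee = x_to_gx[var(assignee)]
    f = eval(ex.args[1])      # resolve the called function, e.g. Main.:* to *
    args = ex.args[2:end]     # the call's arguments (SSA variables or constants)
    for (i, a) in enumerate(args)
        g_old = get(x_to_gx, a, IRTools.push!(ir, :(0)))
        factor = derivative_rule(f, args, i)   # ∂f/∂(i-th argument), as an expression
        g_new = IRTools.push!(ir, :($(g_old) + $(g_assignee) * $(factor)))
        x_to_gx[a] = g_new
    end
end
```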
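Written out, the loop performs one chain-rule update per argument of a call $y = f(a_1, \dots, a_n)$; the only operation-specific piece is the local partial derivative, i.e. the `factor` returned by `derivative_rule`. Here $\bar y$ corresponds to `g_assignee` and $\bar a_i$ to the adjoint stored in `x_to_gx`:

$$
\bar a_i \leftarrow \bar a_i + \bar y \cdot \frac{\partial f}{\partial a_i}(a_1, \dots, a_n), \qquad i = 1, \dots, n
$$

For the three primitives so far:

$$
\frac{\partial (a_1 + a_2)}{\partial a_i} = 1, \qquad
\frac{\partial (a_1 a_2)}{\partial a_1} = a_2, \quad
\frac{\partial (a_1 a_2)}{\partial a_2} = a_1, \qquad
\frac{\partial \sin a_1}{\partial a_1} = \cos a_1
$$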

The magic happens in `derivative_rule` via multiple dispatch. For each operation, we implement the derivative rule with respect to the `i`th argument. Here are the rules for the three functions we supported in version 1, written more succinctly this time.

```julia
function derivative_rule(::typeof(sin), args, i)
    return :(cos($(args[1])))
end

function derivative_rule(::typeof(+), args, i)
    return 1
end

# not quite the * rule in full generality, but good enough for our two-argument purposes.
function derivative_rule(::typeof(*), args, i)
    if i == 1
        return args[2]
    elseif i == 2
        return args[1]
    end
end
```
```
derivative_rule (generic function with 3 methods)
```

And it still works: 

```julia
include("autodiff_v2.jl")
gradient_2(simple_math, 3, 5)
```
```
(4.010007503399555, 3)
```

Great! Now we have a basic reverse-mode automatic differentiator that can handle long sequences of primitive operations. And if users want to differentiate through a new function, they can define a `derivative_rule` method for it.
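For instance, rules for `exp` and `-` could be added in the same style. These two methods are a hypothetical sketch, not part of the notebook: like the rules above, they return a factor (an expression or a constant) that the generic loop splices into the generated code and multiplies by `g_assignee`.

```julia
# d/dx exp(x) = exp(x): return an expression to be spliced into the generated code.
function derivative_rule(::typeof(exp), args, i)
    return :(exp($(args[1])))
end

function derivative_rule(::typeof(-), args, i)
    length(args) == 1 && return -1    # unary minus: d/dx (-x) = -1
    return i == 1 ? 1 : -1            # binary: d/da (a - b) = 1, d/db (a - b) = -1
end

exp_rule = derivative_rule(exp, Any[:x], 1)     # :(exp(x))
sub_rule = derivative_rule(-, Any[:x, :y], 2)   # -1
```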

### Improvement 3: Recursive descent

Consider the following series of functions.

```julia
function mult(a, b)
    a * b
end

function add(a, b)
    a + b
end

function composite(a, b)
    mult(a, add(a, b))
end
```
```
mult (generic function with 1 method)
add (generic function with 1 method)
composite (generic function with 1 method)
```
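For reference, the chain rule through `mult` and `add` tells us what a correct differentiator should return for `composite(3, 5)`:

$$
\text{composite}(a, b) = \text{mult}(a, \text{add}(a, b)) = a\,(a + b)
$$

$$
\frac{\partial}{\partial a} = (a + b) + a \cdot 1 = 2a + b, \qquad
\frac{\partial}{\partial b} = a \cdot 1 = a
$$

so at $(a, b) = (3, 5)$ the gradient is $(11, 3)$.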


Our current implementation fails to differentiate `composite` with `MethodError: no method matching derivative_rule(::typeof(mult), ::Vector{Any}, ::Int64)`, since `add` and `mult` aren't functions it knows how to differentiate.

```julia
gradient_2(composite, 3, 5)
```
```
MethodError: no method matching derivative_rule(::typeof(mult), ::Vector{Any}, ::Int64)
```

But if we teach it to recursively descend into `add` and `mult`, it will discover that they're built from functions it already recognizes, so it should be able to differentiate the whole thing.

That seems simple: we'll just introduce a fallback method, like this one, which handles previously unseen functions by recursively calling `gradient_2` on them:

```julia
function derivative_rule(unknown_function, args, i)
    return gradient_2(unknown_function, args...)[i]
end
```
```
derivative_rule (generic function with 4 methods)
```

Then, instead of computing `derivative_rule` at transformation time, we'll insert a call to `derivative_rule` into the generated code and evaluate it at runtime. This should also let us handle recursive functions, whose call depth isn't known at transformation time, without issue. Here's the SSA code that `gradient_3` generates for `mult`:

```
1: (%1, %2, %3)
  %4 = %2 * %3
  %5 = 1
  %6 = 0
  %7 = derivative_rule(*, Any[%2, %3], i)
  %8 = %5 * %7
  %9 = %6 + %8
  %10 = 0
  %11 = derivative_rule(*, Any[%2, %3], i)
  %12 = %5 * %11
  %13 = %10 + %12
  %14 = Core.tuple(%4, %9, %13)
  return %14
```

```julia
include("autodiff_v3.jl")
```
```
derivative_rule (generic function with 4 methods)
```

Now we call the newly generated code:

```julia
gradient_3(composite, 3, 5)
```
```
LoadError: MethodError: no method matching *(::IRTools.Inner.Variable, ::IRTools.Inner.Variable)
Closest candidates are:
  *(::Any, ::Any, ::Any, ::Any...) at operators.jl:591
  *(::SpecialFunctions.SimplePoly, ::Any) at ~/.julia/packages/SpecialFunctions/hefUc/src/expint.jl:8
  *(::ChainRulesCore.AbstractThunk, ::Any) at ~/.julia/packages/ChainRulesCore/Z4Jry/src/tangent_arithmetic.jl:125
  ...
```

Agh! We got `MethodError: no method matching *(::IRTools.Inner.Variable, ::IRTools.Inner.Variable)`. So we successfully recursed into `mult` and `add` and transformed their code, but something fell apart when we tried to run it. What happened?

This is where we hit the first real consequence of diverging from Zygote's design.

Our design choice was to build around a function `gradient`, which takes a function (`composite`) and two values (`3` and `5`) and returns the gradient with respect to each argument. But internally, we need to compute gradients before we have actual input values. That means we need to insert the `derivative_rule` call symbolically, which is fine. But IRTools doesn't seem to support "deep-resolving" variables nested inside call arguments, which we'd need in order to successfully pass in, and then resolve, e.g. lists of symbolic variables.

For example, I expected `%7 = derivative_rule(*, Any[%2, %3], i)` to pass in the runtime values of `%2` and `%3`, but instead it passes in the `IRTools.Inner.Variable` objects themselves. Multiplying two of those then throws `MethodError: no method matching *(::IRTools.Inner.Variable, ::IRTools.Inner.Variable)`, as expected.

My current understanding is that this is a somewhat unfortunate implementation shortcoming, not a fundamental limitation of `IRTools`. However, it wouldn't surprise me if there's some more fundamental reason this is impossible that I'm not seeing here.
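One plausible (unverified) mechanism can be reproduced with plain Julia `Expr`s, no IRTools required: interpolating a whole vector with `$args` embeds the `Vector` object itself as a single constant, whose elements are never looked up, whereas splatting with `$(args...)` makes each variable a separate call argument that gets resolved when the code runs. The printed `Any[%2, %3]` is consistent with such an embedded vector. Whether IRTools' variable substitution behaves the same way is an assumption of this sketch; plain Symbols stand in for IR variables, and `takes_vector`/`takes_splat` are hypothetical helpers.

```julia
args = Any[:x, :y]      # stand-ins for the IR variables %2 and %3
x, y = 3, 5             # the "runtime" values those variables should resolve to

takes_vector(f, v, i) = v                 # returns whatever object was spliced in
takes_splat(f, i, vals...) = collect(vals)

# `$args` splices the Vector object into the call as one opaque constant.
embedded = :(takes_vector(*, $args, 1))
# `$(args...)` splices each element as its own call argument.
splatted = :(takes_splat(*, 1, $(args...)))

from_embedded = eval(embedded)   # the Symbols :x and :y, not their values
from_splatted = eval(splatted)   # [3, 5]
```

If this is indeed the cause, passing the arguments splatted (e.g. `derivative_rule(f, i, args...)`) might let the variables resolve at runtime.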

Zygote avoids this problem by building around a function `J` that returns both the output and a pullback: a function mapping the output's adjoint to the inputs' adjoints. Its generated gradient code can then call those pullbacks symbolically without issue.
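To see why the pullback design sidesteps the problem, here is a hand-written sketch of the idea on plain values. It is loosely modeled on the `J` design described above, not on Zygote's actual implementation; `pb_add`, `pb_mul`, and `pb_composite` are hypothetical names.

```julia
# Each "pullback-style" function returns (output, back), where `back` maps the
# output's adjoint to a tuple of input adjoints.
pb_add(a, b) = (a + b, Δ -> (Δ, Δ))
pb_mul(a, b) = (a * b, Δ -> (Δ * b, Δ * a))

# Hand-transformed `composite(a, b) = mult(a, add(a, b))`: the forward pass records
# pullbacks, and the backward pass calls them in reverse order. Each `back` closes
# over the runtime values it needs, so no symbolic variables are passed to a rule.
function pb_composite(a, b)
    s, back_add = pb_add(a, b)
    y, back_mul = pb_mul(a, s)
    function back(Δ)
        Δa1, Δs = back_mul(Δ)
        Δa2, Δb = back_add(Δs)
        return (Δa1 + Δa2, Δb)    # `a` is used twice, so its adjoints add up
    end
    return y, back
end

y, back = pb_composite(3, 5)
grad = back(1)   # (11, 3)
```

Because `pb_composite` has the same `(output, back)` shape as the primitives, it can itself be used inside a larger function in the same way, which is what makes nesting (and recursion) work.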

This is a bummer, because it means we're unlikely to be able to pull off recursion, which fundamentally requires the same trick (defer `derivative_rule` evaluation until runtime).

### Conclusion
This project started out with the intention of cloning `Zygote.jl`. It's no fun to clone things *exactly*, so I figured that instead of returning a value and a gradient (pullback) function, as Zygote does, I'd just compute both the value and the gradient values inline and then return both. To do so, I'd manipulate a simplified syntax (SSA-form generated via `IRTools`), as Zygote does, but inject my gradient calculations between each line instead of all at the end.

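This works well for straight-line code built from primitives we have rules for, like `simple_math`, but breaks down once we need to differentiate through calls to other user-defined functions, like `composite`.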

At the time, I thought that would be a basically inconsequential choice, but it turned out to be more substantial than expected.

When I started this project I was expecting to learn a lot about automatic differentiators. And in some sense I did: I now have a pretty good understanding of the design space and where Zygote sits in it. But I spent most of my time (a much larger share than expected) learning how to do metaprogramming and debugging a small number of extremely dense snippets. For example, the six lines of code inside `if ex.head == :call` took *forever* to get right, with lots of quoting and unquoting issues in all directions.

### Quick note on effort allocation
I figured I'd write a quick note on where exactly my ~3.5 psets worth of time and effort went on this project, since some parts took longer than expected and some shorter than expected.

* Understanding how to implement reverse mode AD was pretty fast. I read up on it and did a few chalkboard examples, and then felt like I understood it. ~1 hour.
* Really understanding Julia's metaprogramming (quoting, `$()`, `eval`, etc) took ~2 hours. It's been a while since I've dealt with a lispy language.
* I spent 2-3 hours trying to inspect Zygote itself with a debugger to understand it, but this was a nightmare. One of the questions I'd like to ask someone in the Julia lab is how the pros inspect libraries in Julia. Generated functions, multiple dispatch, etc., made it extremely hard for me to track down who called whom.
* I spent 2 hours reading and trying to understand the Zygote paper, which was actually pretty useful. In particular, when I started this project, I didn't really understand the design tradeoffs made by Zygote vs e.g. PyTorch.
* I spent 10+ hours learning `IRTools`, and debugging various issues with it. For example, it doesn't correctly parse injected expressions like `Core.tuple(a)`, so instead we're using `IRTools.Inner.Statement(:($(GlobalRef(Core, :tuple))($(outs...),)))`, which took me forever to figure out. This was a much, much larger timesink than I expected or is probably obvious from looking at the project.
* I spent 2-3 total hours writing.

### References

* https://arxiv.org/pdf/1810.07951.pdf. This is the original Zygote paper, and provides useful context on the tradeoffs Zygote makes in design space.
* https://fluxml.ai/Zygote.jl/latest/internals/. This page gives an overview of how Zygote works, and includes an example of a loop being unrolled.
* https://rufflewind.com/2016-12-30/reverse-mode-automatic-differentiation. This post helped me finally understand how reverse-mode autodiff works.