### What

In this notebook, we build a simple prototype of source-to-source reverse-mode automatic differentiation. This is the same strategy used by `Zygote.jl`.

We opt to implement **reverse mode** automatic differentiation because it efficiently computes the gradient of functions with a (very) large number of inputs and a single output. This kind of problem shows up in all kinds of optimization, most notably machine learning.

We also choose to do **source-to-source** transformation. This is substantially more complex to implement than dynamic systems, which build up a computation graph and then differentiate through it on every program run. The main benefit is that our gradient computation can be compiled ahead of time and optimized down by LLVM, instead of (expensively) re-built and (poorly) optimized every time.

This is also better than trace-based systems, which usually end up with a huge number of traced operations. With that many operations, it's infeasible to run optimization passes that need `O(N^2)` time, so the trace can't be optimized effectively. (source: Zygote paper)

### Version 1: Basics

We begin by building a program to automatically differentiate `simple_math`.

In [1]:
function simple_math(x, y)
    x * y + sin(x)
end

simple_math (generic function with 1 method)

Our automatic differentiator operates at the SSA-level: it deals with a simplified structure, called static single-assignment form, where variables can only be assigned once, and control flow is simplified into only simple branch/goto statements.

Julia does this pass [internally](https://docs.julialang.org/en/v1/devdocs/ssair/), but doesn't offer a convenient public API for rewriting the result, so we use the `IRTools.jl` library instead. This is the same one used by Zygote.

Converting `simple_math` to SSA form using `IRTools.jl`, we have:

In [2]:
using IRTools

In [3]:
IRTools.@code_ir simple_math(3, 5)

1: (%1, %2, %3)
  %4 = %2 * %3
  %5 = Main.sin(%2)
  %6 = %4 + %5
  return %6

If you look closely, this is functionally identical to `simple_math` above.

So how do we differentiate it? Our algorithm is as follows:
* Find the return value (`%6` here), and set its adjoint to 1.
* Iterate through the program backwards and compute the adjoints `gN` for each `%N`.

For example, given `%6 = %4 + %5`, we want to convert it (using the sum rule for derivatives) into
```
g4 += g6
g5 += g6
```
which possibly also involves creating `g4` and `g5`.

As we iterate through computations, we also keep a map from each variable to the SSA variable that holds its most recent adjoint (`x_to_gx` in the code).
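
To make the two steps concrete, here's a hand-written sketch of the same reverse pass for `simple_math` in plain Julia, with no IRTools involved. The function name `simple_math_with_gradient` and the locals (`v4`, `g4`, ...) are made up for illustration; they only mirror the `%N` / `gN` naming above and are not taken from `autodiff.jl`.

```julia
# Hand-written reverse pass for simple_math(x, y) = x * y + sin(x).
# v4, v5, v6 mirror the SSA values %4, %5, %6; gN is the adjoint of %N.
function simple_math_with_gradient(x, y)
    # Forward pass: same statements as the SSA form.
    v4 = x * y
    v5 = sin(x)
    v6 = v4 + v5

    # Reverse pass: seed the adjoint of the return value with 1.
    g6 = 1

    # %6 = %4 + %5  (sum rule: both inputs receive g6)
    g4 = g6
    g5 = g6

    # %5 = sin(%2)  (chain rule with d/dx sin(x) = cos(x))
    gx = g5 * cos(x)

    # %4 = %2 * %3  (product rule: each input gets g4 times the other input)
    gx += g4 * y
    gy = g4 * x

    return v6, gx, gy
end

value, gx, gy = simple_math_with_gradient(3, 5)
```

For the inputs `(3, 5)` the adjoints work out to `5 + cos(3)` for `x` and `3` for `y`.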

A quick note: Zygote constructs a gradient function and returns that:

```
return a, function(a) 
    ... compute gradient ...
    end
```

whereas we just compute the gradient inline while we compute the value and return them both.

```
return a, ga
```

There's no particular reason for this difference, other than that it seemed harmless and it's easiest to learn if you don't clone things _exactly_!
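
As a rough illustration of the two return styles, here's a sketch for one small function, `a * sin(b)`, written by hand. Both function names are hypothetical, and neither is what Zygote or this notebook actually generates; the point is only the shape of what gets returned.

```julia
# Pullback style: return the value plus a closure that maps an
# output adjoint `g` to the adjoints of the inputs.
function scaled_sin_pullback(a, b)
    out = a * sin(b)
    back(g) = (g * sin(b), g * a * cos(b))
    return out, back
end

# Inline style: compute the adjoints right away, seeded with 1,
# and return them next to the value.
function scaled_sin_inline(a, b)
    return a * sin(b), sin(b), a * cos(b)
end

out, back = scaled_sin_pullback(2.0, 0.0)
adjoints = back(1.0)                        # caller decides when to ask for adjoints
inline = scaled_sin_inline(2.0, 0.0)        # value and adjoints arrive together
```

The closure version defers the gradient work until someone calls `back`, which is the property that matters later in this notebook.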

In [4]:
include("autodiff.jl")

gradient_1 (generic function with 1 method)

Anyway, what does this transformation look like in practice?

The output of this transformation on `simple_math` is
```
1: (%1, %2, %3)
  %4 = %2 * %3
  %5 = Main.sin(%2)
  %6 = %4 + %5
  %7 = 1
  %8 = 0
  %9 = 0
  %10 = %8 + %7
  %11 = %9 + %7
  %12 = 0
  %13 = cos(%2)
  %14 = %11 * %13
  %15 = %12 + %14
  %16 = 0
  %17 = 0
  %18 = %10 * %3
  %19 = %15 + %18
  %20 = %10 * %2
  %21 = %17 + %20
  %22 = Core.tuple(%6, %19, %21)
  return %22
```

Running that SSA code, we get:

In [5]:
gradient_1(simple_math, 3, 5)

(4.010007503399555, 3)

We check the same result with Zygote:

In [6]:
using Zygote

In [7]:
Zygote.gradient(simple_math, 3, 5)

(4.010007503399555, 3.0)

Hooray! This is the same as what Zygote reports for our function, so we've computed the gradient of `simple_math` correctly.

### Improvement 2: Extensibility

Here's what our addition rule currently looks like:

In [8]:
fn = "" #ignore me!

""

In [9]:
if fn == "Main.:+"
    # where is the adjoint of the assignee?
    g_assignee = x_to_gx[var(assignee)]

    # either get the variable that we've accumulated adjoints into so far, 
    # or if it doesn't exist, initialize it to zero.
    g_arg1_old = get(x_to_gx, ex.args[2], IRTools.push!(ir, :(0)))
    g_arg2_old = get(x_to_gx, ex.args[3], IRTools.push!(ir, :(0)))

    # finally, add the new bit (g_assignee) and accumulate it into the new variable.
    g_arg1_new = IRTools.push!(ir, :($(g_arg1_old) + $(g_assignee)))
    g_arg2_new = IRTools.push!(ir, :($(g_arg2_old) + $(g_assignee)))

    x_to_gx[ex.args[2]] = g_arg1_new
    x_to_gx[ex.args[3]] = g_arg2_new
end

And we've repeated nearly the exact same thing for our multiplication rule:

In [10]:
if fn == "Main.:*"
    g_assignee = x_to_gx[var(assignee)]

    g_arg1_old = get(x_to_gx, ex.args[2], IRTools.push!(ir, :(0)))
    g_arg2_old = get(x_to_gx, ex.args[3], IRTools.push!(ir, :(0)))

    g_arg1_new = IRTools.push!(ir, :($(g_arg1_old) + $(g_assignee) * $(ex.args[3])))
    g_arg2_new = IRTools.push!(ir, :($(g_arg2_old) + $(g_assignee) * $(ex.args[2])))

    x_to_gx[ex.args[2]] = g_arg1_new
    x_to_gx[ex.args[3]] = g_arg2_new
end

Notice the similarity in the `g_arg1_new` line. They're almost exactly the same for `+` and `*`, with the only exception being the factor that `g_assignee` is multiplied by.

We take advantage of this (general, it turns out) property to express all these rules more succinctly with the following snippet.

In [11]:
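if ex.head == :call
    # adjoint of the value this statement assigns
    g_assignee = x_to_gx[var(assignee)]
    # ex.args[1] is the called function; the rest are its arguments
    for (i, a) in enumerate(ex.args[2:end])
        # adjoint accumulated into `a` so far, or a fresh zero
        g_old = get(x_to_gx, a, IRTools.push!(ir, :(0)))
        # partial derivative of the call with respect to its i-th argument
        factor = derivative_rule(eval(ex.args[1]), ex.args[2:end], i)
        g_new = IRTools.push!(ir, :($(g_old) + $(g_assignee) * $(factor)))
        x_to_gx[a] = g_new
    end
end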

LoadError: UndefVarError: ex not defined

The magic happens in `derivative_rule` via multiple dispatch. For each operation, we implement the derivative rule with respect to the `i`th argument. Here are the three functions we implemented in version 1. Notice how much more succinct and self-contained they are this time.

In [12]:
function derivative_rule(::typeof(sin), args, i)
    return :(cos($(args[1])))
end

function derivative_rule(::typeof(+), args, i)
    return 1
end

# not quite the * rule in full generality, but good enough for our two-argument purposes.
function derivative_rule(::typeof(*), args, i)
    if i == 1
        return args[2]
    elseif i == 2
        return args[1]
    end
end

derivative_rule (generic function with 3 methods)
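
The accumulation loop and the dispatch-based rules can be tried together without IRTools. The sketch below is a toy under stated assumptions: `Stmt`, `partial`, and `toy_gradient` are hypothetical names, the "IR" is just a vector of statements over integer slots, and the rules return numbers rather than expressions. It is not how `autodiff_v2.jl` is implemented, but it follows the same `g_new = g_old + g_assignee * factor` pattern.

```julia
# One toy SSA statement: slot `out` = f(slots in `args`...).
struct Stmt
    out::Int
    f::Function
    args::Vector{Int}
end

# Numeric partial derivative of f with respect to its i-th argument,
# selected by multiple dispatch on the function's type.
partial(::typeof(sin), vals, i) = cos(vals[1])
partial(::typeof(+), vals, i) = 1
partial(::typeof(*), vals, i) = i == 1 ? vals[2] : vals[1]

function toy_gradient(stmts, inputs::Dict{Int,<:Real})
    # Forward pass: evaluate every statement in order.
    vals = Dict{Int,Float64}(k => v for (k, v) in inputs)
    for s in stmts
        vals[s.out] = s.f((vals[a] for a in s.args)...)
    end

    # Reverse pass: seed the last statement's adjoint with 1, then
    # accumulate g_old + g_assignee * factor into each argument.
    adj = Dict{Int,Float64}(stmts[end].out => 1.0)
    for s in reverse(stmts)
        g_assignee = get(adj, s.out, 0.0)
        argvals = [vals[a] for a in s.args]
        for (i, a) in enumerate(s.args)
            adj[a] = get(adj, a, 0.0) + g_assignee * partial(s.f, argvals, i)
        end
    end
    return vals[stmts[end].out], adj
end

# simple_math in toy form: %4 = %2 * %3, %5 = sin(%2), %6 = %4 + %5
stmts = [Stmt(4, *, [2, 3]), Stmt(5, sin, [2]), Stmt(6, +, [4, 5])]
value, adj = toy_gradient(stmts, Dict(2 => 3, 3 => 5))
```

Here `adj[2]` and `adj[3]` hold the adjoints of the two inputs.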

And just to make sure our differentiator still works: 

In [19]:
include("autodiff_v2.jl")
gradient_2(simple_math, 3, 5)

(4.010007503399555, 3)

Great! Now we have a basic reverse mode automatic differentiator, capable of handling long sequences of primitive operations. And if the user wants to use a new function, they can define their own `derivative_rule` for it without reaching into the internals of our library.
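
As a sketch of what such a user-side extension might look like, here are rules for `exp` and `cos` that follow the same `(f, args, i)` convention as the three rules above. These two methods are hypothetical additions, not part of `autodiff_v2.jl`, and the symbol `:x` stands in for whatever IR variable the differentiator would pass.

```julia
# Hypothetical user-defined rules. Each returns an expression for the
# partial derivative of f(args...) with respect to args[i].
function derivative_rule(::typeof(exp), args, i)
    return :(exp($(args[1])))
end

function derivative_rule(::typeof(cos), args, i)
    return :(-sin($(args[1])))
end

exp_rule = derivative_rule(exp, Any[:x], 1)   # the expression exp(x)
cos_rule = derivative_rule(cos, Any[:x], 1)   # the expression -sin(x)
```

Because the rules are selected by dispatch on `typeof(f)`, adding one is just adding a method.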

### Improvement 3: Recursive descent

Consider the following series of functions.

In [20]:
function mult(a, b)
    a * b
end

mult (generic function with 1 method)

In [21]:
function add(a, b)
    a + b
end

add (generic function with 1 method)

In [22]:
function composite(a, b)
    mult(a, add(a, b))
end

composite (generic function with 1 method)

In [23]:
#function derivative_rule(::typeof(mult), args, i)
#    return gradient_2(mult, args...)[i]
#end

Our current implementation will fail to differentiate `composite` (with `MethodError: no method matching derivative_rule(::typeof(mult), ::Vector{Any}, ::Int64)`) since `add` and `mult` aren't functions it knows how to differentiate.

In [24]:
gradient_2(composite, 3, 5)

LoadError: MethodError: no method matching *(::IRTools.Inner.Variable, ::IRTools.Inner.Variable)
Closest candidates are:
  *(::Any, ::Any, ::Any, ::Any...) at operators.jl:591
  *(::SpecialFunctions.SimplePoly, ::Any) at ~/.julia/packages/SpecialFunctions/hefUc/src/expint.jl:8
  *(::ChainRulesCore.AbstractThunk, ::Any) at ~/.julia/packages/ChainRulesCore/Z4Jry/src/tangent_arithmetic.jl:125
  ...

But if we can teach it to recursively descend into `add` and `mult`, it'll discover that they're made of functions we recognize, so we should be able to differentiate the entire thing.

That seems simple: we'll just introduce a simple fallback function, like this, which catches derivatives of previously unseen functions:

In [25]:
function derivative_rule(unknown_function, args, i)
    return gradient_2(unknown_function, args...)[i]
end

derivative_rule (generic function with 4 methods)

Then, instead of computing `derivative_rule` at compile time, we'll insert the `derivative_rule` symbolically and compute it at runtime. This should allow us to do things like infinite recursion without issue. Here's some SSA code generated by `gradient_3`.

```
1: (%1, %2, %3)
  %4 = %2 * %3
  %5 = 1
  %6 = 0
  %7 = derivative_rule(*, Any[%2, %3], i)
  %8 = %5 * %7
  %9 = %6 + %8
  %10 = 0
  %11 = derivative_rule(*, Any[%2, %3], i)
  %12 = %5 * %11
  %13 = %10 + %12
  %14 = Core.tuple(%4, %9, %13)
  return %14
```

In [26]:
include("autodiff_v3.jl")

derivative_rule (generic function with 4 methods)

Now we go to call our new SSA code:

In [27]:
gradient_3(composite, 3, 5)

LoadError: MethodError: no method matching *(::IRTools.Inner.Variable, ::IRTools.Inner.Variable)
Closest candidates are:
  *(::Any, ::Any, ::Any, ::Any...) at operators.jl:591
  *(::SpecialFunctions.SimplePoly, ::Any) at ~/.julia/packages/SpecialFunctions/hefUc/src/expint.jl:8
  *(::ChainRulesCore.AbstractThunk, ::Any) at ~/.julia/packages/ChainRulesCore/Z4Jry/src/tangent_arithmetic.jl:125
  ...

Agh! We got `MethodError: no method matching *(::IRTools.Inner.Variable, ::IRTools.Inner.Variable)`. So we successfully recursed into `mult` and `add` and transformed the syntax, but then something fell apart when we tried to run it. What happened?

Here's where we run into one of the first real issues with our divergence from Zygote's design.

Our design choice was to build around a function `gradient`, which takes a function (`composite`) and two values (`3` and `5`), and returns the gradients of each argument. But we encounter a problem: internally, we need to be able to compute the gradient without inputting actual values. That means that we need to insert the `derivative_rule` function as a symbol, which is fine. But IRTools doesn't seem to support "deep-resolving" variables in function calls like we'd require to successfully pass in and then resolve e.g. lists of symbolic variables.

For example, I was hoping for / expecting `%7 = derivative_rule(*, Any[%2, %3], i)` to pass in the values for `%2` and `%3` at runtime, but instead we pass in the literal `::IRTools.Inner.Variable`. And then that throws `MethodError: no method matching *(::IRTools.Inner.Variable, ::IRTools.Inner.Variable)`, as expected.
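
The failure mode can be reproduced in isolation. In the sketch below, `Placeholder` is a made-up stand-in for `IRTools.Inner.Variable` (a name for a value that only exists once the generated code runs), and `toy_mult` is a hypothetical copy of `mult`. This is an assumption about the shape of the problem, not a trace of what IRTools does internally.

```julia
# Made-up stand-in for an IR variable: it names a slot, it is not a number.
struct Placeholder
    id::Int
end

toy_mult(a, b) = a * b

# With real values, the call behaves as usual.
with_values = toy_mult(3, 5)

# With placeholders, the call reaches `*` with arguments it has no method for,
# which is the same kind of MethodError as in the output above.
err = try
    toy_mult(Placeholder(2), Placeholder(3))
catch e
    e
end
is_method_error = err isa MethodError
```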

My current understanding is that this is a somewhat unfortunate implementation shortcoming but not a fundamental limitation of `IRTools`. However, it wouldn't surprise me if there's some more fundamental reason this is impossible that I'm not seeing here.

Zygote resolves this problem by building around a function called a pullback `J`, which returns an output and a gradient function. Then it can symbolically call those gradient functions in its newly constructed gradient function without issue.
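
To see why the pullback shape composes where our inline shape does not, here's a hand-written sketch for `composite`. The three `pullback_*` functions are hypothetical and written by hand; Zygote generates its equivalents from the IR, and the details there differ.

```julia
# Hand-written pullbacks for the two leaf functions: each returns the value
# and a closure from the output adjoint to the input adjoints.
pullback_mult(a, b) = (a * b, g -> (g * b, g * a))
pullback_add(a, b) = (a + b, g -> (g, g))

# composite(a, b) = mult(a, add(a, b)). The forward pass keeps each inner
# closure; the backward closure calls them in reverse order.
function pullback_composite(a, b)
    s, back_add = pullback_add(a, b)
    out, back_mult = pullback_mult(a, s)
    function back(g)
        ga_from_mult, gs = back_mult(g)
        ga_from_add, gb = back_add(gs)
        # `a` is used twice, so its two adjoint contributions are summed.
        return (ga_from_mult + ga_from_add, gb)
    end
    return out, back
end

out, back = pullback_composite(3, 5)
grads = back(1)
```

Nothing here needs the derivative of `mult` or `add` before real values exist: the inner closures are created during the forward pass and only called at runtime. For `(3, 5)`, `grads` is `(11, 3)`, matching `2a + b` and `a`.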

This is a bummer, because it means we're unlikely to be able to pull off recursion, which fundamentally requires the same trick (defer `derivative_rule` evaluation until runtime).

### Improvement 4: Conditional

Ok, we got unexpectedly bitten by our cute design choice in recursive descent, but can we pull off conditionals?

Let's look at a program that includes a conditional return, like this one:

In [28]:
function conditional(a, b)
    c = a * b
    if a >= b
        return a * a
    else
        return c * sin(c)
    end    
end

conditional (generic function with 1 method)

Here's its SSA-form IR, which now has basic control flow.

In [29]:
IRTools.@code_ir conditional(1, 2)

1: (%1, %2, %3)
  %4 = %2 * %3
  %5 = %2 >= %3
  br 3 unless %5
2:
  %6 = %2 * %2
  return %6
3:
  %7 = Main.sin(%4)
  %8 = %4 * %7
  return %8

Recall that our earlier algorithm starts at the return value, and then propagates the adjoints up from there. That's immediately a problem for us, since by symbolically analyzing this program, we can't tell whether `return %6` or `return %8` will ultimately be run!

To deal with `N` conditional returns, we could just compute N separate gradients, and then return the one we need. This would be wasteful, since reverse mode autodiff needs to do a pass for every output. 

Instead, we trace control flow (and only control flow! No expensive per-operator traces allowed), take note of which branches actually run, and then only propagate adjoints through the variables in those branches. This is also Zygote's solution.

Since Zygote constructs the gradient function in the function itself, they can collect an inline record of what control flow ran. Since our implementation is slightly different, we first generate an instrumented function, run it to extract the control flow, and then use that to determine which adjoints should contribute to the final output.


But oh no! We get bitten hard by our design decision again. We need to know the final path of control flow to compute our gradients, but we compute our gradients inline. That means that we need to run the function twice (once to extract control flow, and again to calculate gradients), and we incur any side effects twice. If I had realized this at the beginning, I would have just implemented `J` Zygote's way; it's not particularly easier or harder, and it's now clear to me that it has very substantial benefits!

Enough talk. Let's see some tracing!

We handle the tracing by writing `trace_executed_blocks`, which annotates the IR like this:

```
1: (%1, %2, %3)
  %10 = push!(Main.gradient_trace, 1)
  %4 = %2 * %3
  %5 = %2 >= %3
  br 3 unless %5
2:
  %11 = push!(Main.gradient_trace, 2)
  %6 = %2 * %2
  return %6
3:
  %12 = push!(Main.gradient_trace, 3)
  %7 = %2 * %3
  %8 = Main.sin(%7)
  %9 = %7 * %8
  return %9
```

`trace_executed_blocks` then returns the blocks that executed.
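
In plain Julia, the effect of that instrumentation is roughly the following. This is a hand-instrumented sketch: `conditional_traced` and `executed_blocks` are hypothetical names, and the trace is passed in as an argument here, whereas the annotated IR above pushes to a global `Main.gradient_trace`.

```julia
# Hand-instrumented copy of `conditional`: every basic block records its
# number in `trace` before doing its work.
function conditional_traced(a, b, trace::Vector{Int})
    push!(trace, 1)          # block 1: entry
    c = a * b
    if a >= b
        push!(trace, 2)      # block 2: the `a >= b` branch
        return a * a
    else
        push!(trace, 3)      # block 3: the fallthrough branch
        return c * sin(c)
    end
end

# Run the instrumented function once and return only the block trace.
function executed_blocks(a, b)
    trace = Int[]
    conditional_traced(a, b, trace)
    return trace
end

trace_lt = executed_blocks(3, 5)   # a < b
trace_ge = executed_blocks(5, 3)   # a >= b
```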

In [32]:
include("autodiff_v4.jl")

trace_executed_blocks (generic function with 1 method)

In [36]:
trace_executed_blocks(conditional, 3, 5) # blocks 1 and 3 run when a < b

2-element Vector{Any}:
 1
 3

In [34]:
trace_executed_blocks(conditional, 5, 3) # blocks 1 and 2 run when a >= b

2-element Vector{Any}:
 1
 2

Now that we know which blocks ran, we need to parse the IR to determine which statements to propagate the adjoint through. We find the block enclosing each statement (and each return value), and only compute gradients for the variables in blocks that ran.

In [27]:
statements_to_surrounding_block(conditional, 5, 3)

(Dict{Any, Any}(5 => 1, 4 => 1, 6 => 2, 7 => 3, 9 => 3, 8 => 3), Dict{Any, Any}(2 => 6, 3 => 9))
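
Given those two maps and a block trace, picking the statements to differentiate is a filter. The sketch below reuses the dictionaries printed above together with the trace `[1, 2]`; the variable names are hypothetical and the real logic lives in `autodiff_v4.jl`, which may be organized differently.

```julia
# The two maps from the output above, plus the trace for conditional(5, 3).
statement_to_block = Dict(4 => 1, 5 => 1, 6 => 2, 7 => 3, 8 => 3, 9 => 3)
block_to_return = Dict(2 => 6, 3 => 9)
executed = [1, 2]

# Statements in executed blocks receive adjoints; the rest are skipped.
live = sort([s for (s, b) in statement_to_block if b in executed])
skipped = sort([s for (s, b) in statement_to_block if !(b in executed)])

# The adjoint seed goes on the return variable of the one executed block
# that ends in a return.
return_var = only(block_to_return[b] for b in executed if haskey(block_to_return, b))
```

With this trace, statements `7`, `8`, and `9` are skipped and the seed lands on `%6`.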

Then we also need to mangle the block/goto structure, since I don't think IRTools supports inserting SSA code in arbitrary blocks (or if it does, I couldn't figure out how to do it!). So we re-arrange the block structure to put the "true" return statement last, and then inject all our suffix code (dealing with returning the gradient) right at the end of the program.

Side note: IRTools doesn't seem to support a native way to do these branch/block analyses, so we end up using regexes (!) on the stringification. Yuck! :p
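
For a sense of what regex-on-stringification looks like, here's a minimal sketch that recovers both maps from the printed IR of `conditional`. The function `parse_blocks` and its three patterns are assumptions made for this illustration; the regexes actually used in `autodiff_v4.jl` may differ.

```julia
# Printed IR of `conditional`, copied from the @code_ir output earlier.
ir_text = """
1: (%1, %2, %3)
  %4 = %2 * %3
  %5 = %2 >= %3
  br 3 unless %5
2:
  %6 = %2 * %2
  return %6
3:
  %7 = Main.sin(%4)
  %8 = %4 * %7
  return %8
"""

# Walk the text line by line, tracking the current block label.
function parse_blocks(text)
    statement_to_block = Dict{Int,Int}()
    block_to_return = Dict{Int,Int}()
    current = 0
    for line in split(text, '\n')
        if (m = match(r"^(\d+):", line)) !== nothing
            # Block header such as "2:" or "1: (%1, %2, %3)".
            current = parse(Int, m.captures[1])
        elseif (m = match(r"^\s+%(\d+) =", line)) !== nothing
            # Assignment such as "  %4 = %2 * %3".
            statement_to_block[parse(Int, m.captures[1])] = current
        elseif (m = match(r"^\s+return %(\d+)", line)) !== nothing
            # Return such as "  return %6".
            block_to_return[current] = parse(Int, m.captures[1])
        end
    end
    return statement_to_block, block_to_return
end

statement_to_block, block_to_return = parse_blocks(ir_text)
```

This is brittle by construction: any change in how the IR is printed would break the patterns, which is part of why it feels yucky.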

Here's what the traced function looks like:
```

Dict{Any, Any}(5 => 1, 4 => 1, 6 => 2, 7 => 3, 9 => 3, 8 => 3) # statement => block
Any[1, 2] # blocks that were run
Dict{Any, Any}(2 => 6, 3 => 9) # block => return variable

skipped 9 # these variables don't contribute, so we skip them
skipped 8
skipped 7
1: (%1, %2, %3)
  %4 = %2 * %3
  %5 = %2 >= %3
  br 1 unless %5 # notice that this has been mangled to avoid a dangling reference to block 3
2:
  %6 = %2 * %2
  %10 = 1
  %11 = 0
  %12 = %10 * %2
  %13 = %11 + %12
  %14 = 0
  %15 = %10 * %2
  %16 = %13 + %15
  %17 = nothing
  %18 = Core.tuple(%6, %16, %17)
  return %18
  
  # notice that block 3 is now gone
```

That's all great, but does it work?

In [38]:
gradient_4(conditional, 5, 3)

Dict{Any, Any}(5 => 1, 4 => 1, 6 => 2, 7 => 3, 8 => 3)Any[1, 2]Dict{Any, Any}(2 => 6, 3 => 8)
skipped 8
skipped 7
1: (%1, %2, %3)
  %4 = %2 * %3
  %5 = %2 >= %3
  br 1 unless %5
2:
  %6 = %2 * %2
  %9 = 1
  %10 = 0
  %11 = %9 * %2
  %12 = %10 + %11
  %13 = 0
  %14 = %9 * %2
  %15 = %12 + %14
  %16 = nothing
  %17 = Core.tuple(%6, %15, %16)
  return %17


(10, nothing)

In [39]:
Zygote.gradient(conditional, 5, 3)

(10.0, nothing)

Amazing! It does! But that conditional was kind of easy; it only had two cases, and its conditions and computations were simple. What if we push a bit harder?

In [40]:
function conditional_hard(a, b)
    c = a * b * sin(b) * sin(a)
    
    if a >= b * 3
        return a * a
        
    elseif a >= b
        c = a * b
        return c * sin(c)
        
    elseif a > sin(b)
        return sin(b)
        
    else
        e = sin(sin(sin(a)) * b) * a
        
    end    
end     

conditional_hard (generic function with 1 method)

In [43]:
Zygote.gradient(conditional_hard, 5, 1), 
Zygote.gradient(conditional_hard, 5, 3), 
Zygote.gradient(conditional_hard, 5, 10π), 
Zygote.gradient(conditional_hard, -3, 3)

((10.0, nothing), (-32.23509255817561, -53.725154263626024), (nothing, 1.0), (7.638088029360477, 0.38494624699166546))

In [45]:
gradient_4(conditional_hard, 5, 1), 
gradient_4(conditional_hard, 5, 3), 
gradient_4(conditional_hard, 5, 10π), 
gradient_4(conditional_hard, -3, 3)

Dict{Any, Any}(5 => 1, 16 => 6, 20 => 7, 12 => 4, 8 => 1, 17 => 7, 19 => 7, 6 => 1, 11 => 4, 9 => 2, 14 => 5, 7 => 1, 4 => 1, 13 => 4, 15 => 5, 21 => 7, 10 => 3, 18 => 7)Any[1, 2]Dict{Any, Any}(4 => 13, 6 => 16, 7 => 21, 2 => 9)
skipped 21
skipped 20
skipped 19
skipped 18
skipped 17
skipped 16
skipped 15
skipped 14
skipped 13
skipped 12
skipped 11
skipped 10
1: (%1, %2, %3)
  %4 = Main.sin(%3)
  %5 = Main.sin(%2)
  %6 = %2 * %3 * %4 * %5
  %7 = %3 * 3
  %8 = %2 >= %7
  br 1 unless %8
2:
  %9 = %2 * %2
  %22 = 1
  %23 = 0
  %24 = %22 * %2
  %25 = %23 + %24
  %26 = 0
  %27 = %22 * %2
  %28 = %25 + %27
  %29 = nothing
  %30 = Core.tuple(%9, %28, %29)
  return %30
Dict{Any, Any}(5 => 1, 16 => 6, 20 => 7, 12 => 4, 8 => 1, 17 => 7, 19 => 7, 6 => 1, 11 => 4, 9 => 2, 14 => 5, 7 => 1, 4 => 1, 13 => 4, 15 => 5, 21 => 7, 10 => 3, 18 => 7)Any[1, 3, 4]Dict{Any, Any}(4 => 13, 6 => 16, 7 => 21, 2 => 9)
skipped 21
skipped 20
skipped 19
skipped 18
skipped 17
skipped 16
skipped 15
skipped 14
skipped 9
1

((10, nothing), (-32.23509255817561, -53.725154263626024), (nothing, 1.0), (7.638088029360477, 0.38494624699166546))

The output, `((10, nothing), (-32.23509255817561, -53.725154263626024), (nothing, 1.0), (7.638088029360477, 0.38494624699166546))`, is exactly the same in all four cases! Hooray! (This also confirms that our original reverse AD code above is reasonably robust.)

### Conclusion
This project started out with the intention of writing an AD system like `Zygote.jl`. It's no fun to clone things *exactly*, so I figured that instead of returning a value and a gradient (pullback) function, as Zygote does, I'd just compute both the value and the gradient values inline and then return both. To do so, I'd manipulate a simplified syntax (SSA-form generated via `IRTools`), as Zygote does, but inject my gradient calculations between each line instead of all at the end.

This works well up to computing the derivative of functions composed of functions whose derivatives we know. But it falls apart when we try to compose functions with unknown derivatives, due to what I suspect is a limitation of `IRTools`.

Nevertheless, we continue to successfully implement extensibility and support for conditional returns.

When I started this project I was expecting to learn a lot about automatic differentiators. And in some sense I did--I now have a pretty good understanding of the design space tradeoffs and where Zygote (and several other popular AD systems) sit in it. 

But really I spent most of my time (a way larger % than expected) learning how to do metaprogramming and debugging a small number of extremely dense snippets (for example, the six lines of code inside `if ex.head == :call` took a veritable eternity to get correct.) Lots of quoting and unquoting issues in all directions.

I'm pretty happy with this, since metaprogramming is increasingly something that seems important for me to master, but probably a bit less actual AD meat involved than I originally expected. So it goes!

If you want to see my scratch work for any of this, there are several other Jupyter notebooks included in this repo for your browsing :).

### Quick note on effort allocation
I figured I'd write a quick note on where exactly my ~3 psets worth of time and effort went on this project, since some parts took longer than expected and some shorter than expected.

* Understanding how to implement reverse mode AD was pretty fast. I read up on it and did a few chalkboard examples, and then felt like I understood it. ~2 hours.
* Really understanding Julia's metaprogramming (quoting, `$()`, `eval`, etc) took ~2 hours. It's been a while since I've dealt with a lispy language.
* I spent 2-3 hours trying to inspect Zygote itself with a debugger to understand it, but this was a nightmare. One of the questions I'd like to ask someone in the Julia lab is how the pros inspect libraries in Julia. Generated functions, multiple dispatch, etc., made it extremely hard for me to track down who called whom. After lots of trying I gave up on this approach.
* I spent 2 hours reading and trying to understand the Zygote paper, which was actually pretty useful. In particular, when I started this project, I didn't really understand the design tradeoffs made by Zygote vs e.g. PyTorch.
* I spent ~8 hours learning `IRTools` inside out, and debugging various issues with it. For example, it doesn't correctly parse injected expressions like `Core.tuple(outs)`, since it thinks the dot is a function call, so instead we're using `IRTools.Inner.Statement(:($(GlobalRef(Core, :tuple))($(outs...),)))`, which took me forever to figure out. This was a much, much larger timesink than I expected or is probably obvious from looking at the project.
* There's probably an hour or two of learning how to use Julia at all (this is my first Julia project!).
* The real AD coding (most of what's in this notebook) probably took about 15 hours, excluding IRTools and Julia language time.
* I spent 1-2 total hours writing and another hour arranging code to make the history of the project clear to you, dear grader! :)

This seems roughly equivalent to three heavier psets of work, but I'm very happy with what I've practiced here and generally how I've spent this time, if you're happy with the output. Thanks for letting me branch out and do my own project!

### References

* https://arxiv.org/pdf/1810.07951.pdf. This is the original Zygote paper, and provides useful context on the tradeoffs Zygote makes in design space.
* https://fluxml.ai/Zygote.jl/latest/internals/. This page gives an overview of how Zygote works, and includes an example of a loop being unrolled.
* https://rufflewind.com/2016-12-30/reverse-mode-automatic-differentiation. This post helped me deeply understand how reverse-mode autodiff works.