### What

In this notebook, we build a simple prototype of source-to-source reverse-mode automatic differentiation. This is the same strategy used by `Zygote.jl`.

We opt to implement **reverse mode** automatic differentiation because it efficiently computes the gradient of functions with a (very) large number of inputs and a single output. This kind of problem shows up in all kinds of optimization, most notably machine learning.

We also choose a **source-to-source** transformation. This is substantially more complex to implement than a dynamic system, which builds up a computation graph and then differentiates through it on every program run. The main benefit is that our gradient computation can be compiled ahead of time and optimized by LLVM, instead of being (expensively) rebuilt and (poorly) optimized on every run.

Source-to-source transformation is also better than trace-based systems, which usually end up with a huge number of traced operations. With so many operations, optimization algorithms that require `O(N^2)` time become infeasible to run, so the trace cannot be optimized effectively. (source: Zygote paper)

### How

### Version 1

We begin by building a program to automatically differentiate `simple_math`.

In [1]:
function simple_math(x, y)
    x * y + sin(x)
end

simple_math (generic function with 1 method)

Our automatic differentiator operates at the SSA level: it works on a simplified representation, called static single-assignment (SSA) form, in which each variable is assigned exactly once and control flow is reduced to simple branch/goto statements.

Julia performs this lowering [internally](https://docs.julialang.org/en/v1/devdocs/ssair/), but doesn't expose a convenient way to manipulate the result, so we use the `IRTools.jl` library instead. This is the same library Zygote uses.

Converting `simple_math` to SSA form using `IRTools.jl`, we have:

In [3]:
using IRTools

In [4]:
IRTools.@code_ir simple_math(3, 5)

1: (%1, %2, %3)
  %4 = %2 * %3
  %5 = Main.sin(%2)
  %6 = %4 + %5
  return %6

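In this listing, `%1` is the function itself, `%2` and `%3` are the arguments `x` and `y`, and `%4` through `%6` are the intermediate results.

For our first version, the algorithm is as follows:
* Find the return value (`%6` here) and set its adjoint to 1.
* Iterate backwards through the program, computing the adjoint `gN` of each `%N`.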
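
To make the two steps concrete, here is a hand-written sketch of the code such a transformation might produce for `simple_math`. It is an illustration only, not the output of `IRTools.jl` or Zygote: the name `simple_math_grad` and the variables `v4`, `v5`, `v6` are hypothetical, chosen to mirror `%4`, `%5`, `%6` in the SSA listing, and the adjoints `g2` and `g3` correspond to the inputs `x` and `y`.

```julia
# Hand-written reverse pass for simple_math, mirroring the SSA listing.
# `simple_math_grad` is a hypothetical name used only for illustration.
function simple_math_grad(x, y)
    # Forward pass: compute the intermediate values %4, %5, %6.
    v4 = x * y
    v5 = sin(x)
    v6 = v4 + v5

    # Backward pass: seed the adjoint of the return value with 1.
    g6 = one(v6)

    # %6 = %4 + %5: addition passes the adjoint through to both operands.
    g4 = g6
    g5 = g6

    # %5 = sin(%2): the derivative of sin is cos.
    g2 = g5 * cos(x)

    # %4 = %2 * %3: each operand's adjoint is scaled by the other operand.
    # %2 is used twice, so its contributions are accumulated.
    g2 += g4 * y
    g3 = g4 * x

    # Return the value together with the gradient with respect to (x, y).
    return v6, (g2, g3)
end

value, (gx, gy) = simple_math_grad(3.0, 5.0)
```

Note that `%2` (that is, `x`) feeds two statements, so its adjoint is the sum of both contributions. Accumulating adjoints this way is what makes a single backward sweep sufficient for all inputs.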

### References

* https://arxiv.org/pdf/1810.07951.pdf
* https://fluxml.ai/Zygote.jl/latest/internals/
* https://rufflewind.com/2016-12-30/reverse-mode-automatic-differentiation