# DiffFX Derivation

This notebook is a walkthrough of the initial design and implementation of difffx,
a simple source-to-source autodiff tool using Torch FX.
See https://github.com/pytorch/pytorch/blob/master/torch/fx/OVERVIEW.md#the-fx-ir for
an overview of the FX IR.

See `difffx.py` for the actual implementation, which covers cases not shown here,
particularly:
 - Dealing with shapes (e.g. the derivative of `trace` needs to know the argument size)
 - Dealing with method calls (this code just handles function calls)

In [1]:
## Import prerequisites, define utility functions
from collections import defaultdict
import torch
import torch.fx as fx

from awfutils.pytree_utils import PyTree

from vjp_check import vjp_check


## FX_VJP: Source-to-source reverse-mode algorithmic differentiation via FX

An FX Interpreter that implements reverse-mode automatic differentiation.

It may help to recall the basic laws of AD.
Given a function `f` that takes an arbitrary pytree type `S` and returns an arbitrary pytree type `T`,
we define the vector-Jacobian product `vjp{f}`, which takes arguments of types `(S, dT)` and returns type `dS`, with `vjp{f}(s, dt) = dt * J{f}(s)`, where `J{f}(s)` is the Jacobian of `f` at `s`.
When `T` is a scalar or has only one element, the VJP is the gradient,
so `grad{f}(s) = vjp{f}(s, 1.0)`.
```py
def f(s : S) -> T: 
  ...
def vjp{f}(s : S, dt : dT) -> dS: 
  ...
```
We generally divide the computation into
'forward' and 'backward' passes: the forward pass returns the result of `f`
together with some auxiliary information of type `Aux{f}`, which the backward pass consumes:
```py
def fwd{f}(s : S) -> (T, Aux{f}): 
  ...
def bwd{f}(aux : Aux{f}, dt : dT) -> dS:
  ...
```
in terms of which we could just write `vjp{f}` as
```py
def vjp{f}(s : S, dt : dT) -> dS:
  _t, aux = fwd{f}(s)
  return bwd{f}(aux, dt)
```
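As a numeric sanity check of these definitions, the sketch below uses `torch.autograd.functional` (not DiffFX) on a small hypothetical function `f(s) = sin(A @ s)` that maps a length-3 vector to a length-2 vector. It confirms that the VJP equals `dt @ J`, and that for a scalar-valued function a VJP with `dt = 1.0` is the gradient. The functions `f` and `g` and the matrix `A` are illustrative assumptions, not part of the notebook.

```python
import torch
from torch.autograd.functional import jacobian, vjp

torch.manual_seed(0)
A = torch.randn(2, 3)

def f(s):
    # Hypothetical example: S is a length-3 vector, T is a length-2 vector
    return torch.sin(A @ s)

s = torch.randn(3)
dt = torch.randn(2)                  # dt has the shape of T
J = jacobian(f, s)                   # shape (2, 3): J[i, j] = d f_i / d s_j
_, ds = vjp(f, s, dt)                # ds has the shape of S
print(torch.allclose(ds, dt @ J, atol=1e-6))

def g(s):
    # Hypothetical scalar-valued function whose gradient is 2 * s
    return (s * s).sum()

# For a scalar output, vjp with dt = 1.0 gives the gradient
_, grad = vjp(g, s, torch.tensor(1.0))
print(torch.allclose(grad, 2 * s))
```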

Here are examples for `add`, `mul`, and `matmul`:

In [2]:
# If they were scalars: d/da(a+b) = 1, d/db(a+b) = 1
def add_fwd(a, b):
    return a + b, None # Aux is empty

def add_bwd(aux, dret):
    return (dret, dret)

# If they were scalars: d/da(a*b) = b, d/db(a*b) = a
def mul_fwd(a, b):
    return a * b, (a,b) # Aux needs to remember both a and b

def mul_bwd(aux, dret):
    a,b = aux # Unpack the inputs
    return (dret*b, a*dret)

# And for matmul, the sizes need to line up
# MxK @ KxN -> MxN
def matmul_fwd(a, b):
    return a @ b, (a,b) # Aux needs to remember both a and b

def matmul_bwd(aux, dret):
    a,b = aux # Unpack the inputs
    # dret is MxN, a is MxK, b is KxN
    # dA should be MxK, dB should be KxN
    da = dret @ b.t()
    db = a.t() @ dret
    return (da, db)


## Just checking the above are correct
from vjp_check import vjp_check_fwdbwd
vjp_check_fwdbwd(torch.add, add_fwd, add_bwd, (torch.randn(3,2), torch.randn(3,2)))
vjp_check_fwdbwd(torch.mul, mul_fwd, mul_bwd, (torch.randn(3,2), torch.randn(3,2)))
vjp_check_fwdbwd(torch.mm, matmul_fwd, matmul_bwd, (torch.randn(3,2), torch.randn(2,5)))
print('All asserts passed')

All asserts passed
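
The `vjp_check` module is not shown in this notebook. As a rough idea of what a check like `vjp_check_fwdbwd` might do, here is a hypothetical `check_fwdbwd` that compares the forward value against `f` and the backward result against `torch.autograd.functional.vjp` for a random cotangent. It is a sketch under that assumption, not the actual helper.

```python
import torch
from torch.autograd.functional import vjp

def check_fwdbwd(f, fwd, bwd, args, atol=1e-5):
    """Hypothetical checker (not the notebook's vjp_check): compare fwd/bwd with f and autograd."""
    val, aux = fwd(*args)
    assert torch.allclose(val, f(*args), atol=atol), "forward value differs from f"
    dret = torch.randn_like(val)                # random cotangent
    dargs = bwd(aux, dret)
    if not isinstance(dargs, tuple):            # single-input rules may return a bare tensor
        dargs = (dargs,)
    _, expected = vjp(f, args, dret)            # args is a tuple, so expected is a tuple too
    assert len(dargs) == len(expected)
    for got, want in zip(dargs, expected):
        assert torch.allclose(got, want, atol=atol), "backward differs from autograd"
    return True

# The matmul rule from above, restated so this cell runs on its own
def matmul_fwd(a, b):
    return a @ b, (a, b)

def matmul_bwd(aux, dret):
    a, b = aux
    return dret @ b.t(), a.t() @ dret

print(check_fwdbwd(torch.mm, matmul_fwd, matmul_bwd, (torch.randn(3, 2), torch.randn(2, 5))))
```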


### A choice of VJPs?

Here's a more interesting example for `sin`, where we provide two implementations:
one saves `x` in Aux (likely more memory-efficient), the other saves `cos(x)` (likely faster).
How to choose automatically between them is a question for another day.

In [3]:
# Example: sin
OPTIMIZE_FOR_MEMORY = True
if OPTIMIZE_FOR_MEMORY:
    # sin saves x in Aux - may save memory as x is likely preserved for backward pass
    def sin_fwd(x):
        return (torch.sin(x), x)
    def sin_bwd(aux_is_x,dret): 
        return torch.cos(aux_is_x) * dret
else:
    # sin saves cos(x) in Aux - may save compute if a fused `sincos` is faster than separate `sin` and `cos`
    def sin_fwd(x):
        # torch has no `sincos` function, so compute both explicitly
        ret, aux = torch.sin(x), torch.cos(x)
        return ret, aux
    def sin_bwd(aux_is_cosx,dret): 
        return aux_is_cosx * dret

vjp_check_fwdbwd(torch.sin, sin_fwd, sin_bwd, (torch.randn(3,2),))
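
A third option, shown here with an illustrative `exp` rule that is not part of the notebook, is to save the forward *output* in Aux: since d/dx exp(x) = exp(x), the backward pass is a single multiply and keeps only a tensor that is likely alive anyway. The sketch below checks that rule, and both `sin` strategies from the cell above (given hypothetical distinct names so they can coexist), against `torch.autograd.functional.vjp`.

```python
import torch
from torch.autograd.functional import vjp

# exp: the derivative is the output itself, so Aux can just be the output
def exp_fwd(x):
    y = torch.exp(x)
    return y, y

def exp_bwd(aux_is_expx, dret):
    return aux_is_expx * dret

# The two sin strategies: save x (recompute cos in bwd) or save cos(x)
def sin_fwd_save_x(x):
    return torch.sin(x), x

def sin_bwd_save_x(x, dret):
    return torch.cos(x) * dret

def sin_fwd_save_cos(x):
    return torch.sin(x), torch.cos(x)

def sin_bwd_save_cos(cosx, dret):
    return cosx * dret

def agrees_with_autograd(f, fwd, bwd, x):
    dret = torch.randn_like(x)
    _, aux = fwd(x)
    _, expected = vjp(f, x, dret)
    return torch.allclose(bwd(aux, dret), expected)

x = torch.randn(3, 2)
results = {
    "exp": agrees_with_autograd(torch.exp, exp_fwd, exp_bwd, x),
    "sin/save x": agrees_with_autograd(torch.sin, sin_fwd_save_x, sin_bwd_save_x, x),
    "sin/save cos": agrees_with_autograd(torch.sin, sin_fwd_save_cos, sin_bwd_save_cos, x),
}
print(results)
```

All three rules compute the same VJP; they differ only in what is kept between the passes and how much work the backward pass redoes.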


### Chain rule

An FX traced function is of the general form
```py
def foo(a1..an):
  t1 = f1(a1..an)
  t2 =  f2(a1..an,t1) # wlog, fk uses all in-scope variables
  ...
  tm =     fm(a1..an,t1..t{m-1}) 
  return tm
```
Then the VJP (vector-Jacobian product) is of the form
```py
def foo_vjp(a1..an, dtm):
  # Forward pass
  t1, aux1               = f1_fwd(a1..an)
  t2, aux2               =   f2_fwd(a1..an,t1)
  ...
  tm, auxm               =      fm_fwd(a1..an,t1..t{m-1})

  # Backward pass (all da*, dt* start at zero, except dtm which is given)
  da1..dan,dt1..dt{m-1} +=      fm_bwd(auxm, dtm)
  ...
  da1..dan,dt1          +=   f2_bwd(aux2, dt2)
  da1..dan              += f1_bwd(aux1, dt1)

  return da1..dan
```
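To make the template concrete, here is the backward pass written out by hand for a small hypothetical function `g(x) = sin(x) * x + x`, using the `sin`, `mul`, and `add` rules above (restated so the cell runs on its own). Because `x` is used three times, its derivative `dx` collects three contributions, which is why the template accumulates with `+=`.

```python
import torch

# Rules from the cells above (restated so this cell is self-contained)
def add_fwd(a, b): return a + b, None
def add_bwd(aux, dret): return dret, dret
def mul_fwd(a, b): return a * b, (a, b)
def mul_bwd(aux, dret):
    a, b = aux
    return dret * b, a * dret
def sin_fwd(x): return torch.sin(x), x
def sin_bwd(x, dret): return torch.cos(x) * dret

def g(x):                             # hypothetical example function
    t1 = torch.sin(x)
    t2 = t1 * x
    t3 = t2 + x
    return t3

def g_vjp(x, dt3):
    # Forward pass: run each fk_fwd, keeping its Aux
    t1, aux1 = sin_fwd(x)
    t2, aux2 = mul_fwd(t1, x)
    t3, aux3 = add_fwd(t2, x)

    # Backward pass: walk the ops in reverse; dx accumulates with +=
    dx = torch.zeros_like(x)
    d_t2, d_x = add_bwd(aux3, dt3)    # t3 = t2 + x
    dx += d_x
    d_t1, d_x = mul_bwd(aux2, d_t2)   # t2 = t1 * x
    dx += d_x
    dx += sin_bwd(aux1, d_t1)         # t1 = sin(x)
    return t3, dx

x, dt3 = torch.randn(4), torch.randn(4)
ours = g_vjp(x, dt3)
ref = torch.autograd.functional.vjp(g, x, dt3)
print(torch.allclose(ours[0], ref[0]), torch.allclose(ours[1], ref[1], atol=1e-6))
```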

So let's make a transformer `fx_vjp` that does that.
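
It helps to first see what such a transformer will be looking at. The sketch below traces `foo` (the test function used later in this notebook) with `torch.fx.symbolic_trace` and lists the `call_function` targets. Infix `*` and `+` on proxies are recorded as `operator.mul` and `operator.add`, which is why `ad_map` below is keyed on those rather than on `torch.mul` and `torch.add`.

```python
import operator
import torch
import torch.fx as fx

def foo(x):
    w = torch.sin(x)
    y = w * x + x
    z = w * torch.sin(x)
    return y + z

gm = fx.symbolic_trace(foo)
for node in gm.graph.nodes:
    # node.op is one of: placeholder, call_function, call_method, call_module, get_attr, output
    print(node.op, node.name, node.target)

# The Python callables a VJP rule table would need to cover, in program order
targets = [n.target for n in gm.graph.nodes if n.op == "call_function"]
print(targets)
```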



In [None]:

# A mapping from python function to (forward, backward)
ad_map = {}

# Register the AD functions from above (lots more are done in vjp_rules.py)
import operator
ad_map[operator.add] = add_fwd, add_bwd
ad_map[operator.mul] = mul_fwd, mul_bwd
ad_map[operator.matmul] = matmul_fwd, matmul_bwd
ad_map[torch.sin] = sin_fwd, sin_bwd

def fx_vjp(f, sample_input):
    """
    An FX transform that implements reverse-mode automatic differentiation.
    """

    class ADInterpreter(torch.fx.Interpreter):
        """
        This interpreter runs through the forward transformation, 
        replacing calls to `fk` with `fk_fwd = ad_map[fk][0]`,
        and recording the operations on a stack.
        """
        def __init__(self, f):
            super().__init__(f)
            self.stack = []

        def call_function(self, target, args, kwargs):
            assert kwargs is None or len(kwargs) == 0

            if target not in ad_map:
                raise NotImplementedError(f"Need VJP rule for {target}")
            # Look up forward/backward functions in `ad_map`
            fwd,bwd = ad_map[target]
            # Call the fwd function, getting proxies for returns
            val,aux = fwd(*args)
            # In the backward pass, we will compute:
            #  d[args[0]],...,d[args[-1]] = bwd(aux, d{val})
            # So remember: (args, bwd, aux, val)
            # Note that all of these are symbolic, so it's cheap to remember them
            self.stack.append((args, bwd, aux, val))
            # And return the return value (a proxy)
            return val

        def call_method(self, target, args, kwargs):
            raise NotImplementedError # see difffx.py for how to implement this

        def get_attr(self, target, args, kwargs):
            raise NotImplementedError # see difffx.py for how to implement this

    # Grab the FX graph
    f_trace = torch.fx.symbolic_trace(f)
    
    # Run torch's shape analysis, record answers in the graph
    torch.fx.passes.graph_manipulation.ShapeProp(f_trace).run(sample_input)

    # This is the "template" function for the VJP - symbolically tracing this template
    # will generate the VJP function.
    def vjp_template(x, dret):
        # Run the forward computations, and collect them in ad.stack
        ad = ADInterpreter(f_trace)
        ret = ad.run(x)
        # Build a dict to hold derivatives
        d = defaultdict(lambda: 0)
        # Add dret to derivatives dict
        d[ret] = dret
        # And run down the stack...
        for (args, bwd, aux, val) in reversed(ad.stack):
            dargs = bwd(aux, d[val])
            for (a,da) in zip(args, dargs if isinstance(dargs, tuple) else (dargs,)):
                d[a] += da
        # And return ret and J'*dret
        return ret, d[x]

    # Trace through vjp_template and return.
    return torch.fx.symbolic_trace(vjp_template)
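
The heart of `fx_vjp` is that `vjp_template` runs an `fx.Interpreter` over the traced graph while `symbolic_trace` is tracing the template itself, so every value the interpreter handles is a `Proxy` and each operation it performs is recorded into a new graph. The same trick works for other graph rewrites; the minimal, hypothetical `SwapSinForCos` interpreter below (not part of DiffFX) re-emits a graph with every `torch.sin` call replaced by `torch.cos`.

```python
import torch
import torch.fx as fx

def f(x):
    return torch.sin(x) + 1

class SwapSinForCos(fx.Interpreter):
    """Hypothetical example: re-emit every call, but swap torch.sin for torch.cos."""
    def call_function(self, target, args, kwargs):
        if target is torch.sin:
            target = torch.cos
        # With Proxy args, this call is recorded by the outer tracer
        return super().call_function(target, args, kwargs)

f_trace = fx.symbolic_trace(f)

def template(x):
    # x is a Proxy while this function is being traced
    return SwapSinForCos(f_trace).run(x)

g = fx.symbolic_trace(template)
print(g.code)                       # generated source calls torch.cos, not torch.sin

x = torch.randn(5)
print(torch.allclose(g(x), torch.cos(x) + 1))
```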

## ... and let's try it out

We define a function `foo`, compute its vjp using DiffFX, and compare to `autograd.functional.vjp`.

In [8]:
# Function to vjp
def foo(x):
    w = torch.sin(x)
    y = w * x + x
    z = w * torch.sin(x)
    return y + z

torch.manual_seed(42)

x = torch.randn(3,3)
foo_vjp = fx_vjp(foo, x)

dret = torch.randn_like(foo(x))
foo_vjp_pt = lambda x,dret: torch.autograd.functional.vjp(foo, x, dret)

PyTree.assert_close(foo_vjp(x,dret), foo_vjp_pt(x, dret))
print('VJP matches PyTorch')

VJP matches PyTorch


And now let's show that it's a real source-to-source transformation, by printing the code for the VJP.

(There is nicer printing in `fx_print`; `print_readable` is used here to keep this notebook easy to run in Colab.)

In [9]:
foo_vjp.print_readable();

class vjp_template(torch.nn.Module):
    def forward(self, x, dret):
        # No stacktrace found for following nodes
        sin = torch.sin(x)
        mul = sin * x
        add = mul + x;  mul = None
        sin_1 = torch.sin(x)
        mul_1 = sin * sin_1
        add_1 = add + mul_1;  add = mul_1 = None
        add_2 = 0 + dret
        add_3 = 0 + dret;  dret = None
        mul_2 = add_3 * sin_1;  sin_1 = None
        mul_3 = sin * add_3;  add_3 = None
        add_4 = 0 + mul_2;  mul_2 = None
        add_5 = 0 + mul_3;  mul_3 = None
        cos = torch.cos(x)
        mul_4 = cos * add_5;  cos = add_5 = None
        add_6 = 0 + mul_4;  mul_4 = None
        add_7 = 0 + add_2
        add_8 = add_6 + add_2;  add_6 = add_2 = None
        mul_5 = add_7 * x
        mul_6 = sin * add_7;  sin = add_7 = None
        add_9 = add_4 + mul_5;  add_4 = mul_5 = None
        add_10 = add_8 + mul_6;  add_8 = mul_6 = None
        cos_1 = torch.cos(x);  x = None
        mul_7 = cos_1 * add_9;  cos_1 =

## Compare to PyTorch's jacrev

Here's the output from PyTorch's `jacrev`. Because it is implemented at a
slightly lower level, its output is arguably harder to understand, and harder to
modify if that were one's goal. On the other hand, PyTorch's implementation
handles a huge number of cases that ours does not, which underlines that this
is primarily a pedagogical exercise.
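Note that `jacrev` materializes the full Jacobian rather than a VJP. For a 3x3 input and a 3x3 output, that is a tensor of shape `(3, 3, 3, 3)` (output dims, then input dims), and contracting it with `dret` over the output dimensions recovers the VJP. A small check of that relationship using `torch.func`, independent of DiffFX:

```python
import torch
from torch.func import jacrev, vjp

def foo(x):
    w = torch.sin(x)
    y = w * x + x
    z = w * torch.sin(x)
    return y + z

torch.manual_seed(0)
x = torch.randn(3, 3)
dret = torch.randn(3, 3)

# Full Jacobian: output dims (3, 3) followed by input dims (3, 3)
J = jacrev(foo)(x)

# Contract dret against the output dims: sum over i, j of dret[i, j] * J[i, j, k, l]
vjp_from_J = torch.einsum("ij,ijkl->kl", dret, J)

# Direct VJP, which never materializes J
_, vjp_fn = vjp(foo, x)
(vjp_direct,) = vjp_fn(dret)

print(J.shape, torch.allclose(vjp_from_J, vjp_direct, atol=1e-5))
```

For large inputs the full Jacobian grows as the product of output and input sizes, which is one reason a VJP transform like `fx_vjp` is the more practical building block for gradients.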

In [13]:
from torch.fx.experimental.proxy_tensor import make_fx  # traces down to ATen ops

grad_f = make_fx(torch.func.jacrev(foo))(x)
grad_f.print_readable();

class foo(torch.nn.Module):
    def forward(self, x_1: "f32[3, 3]"):
        # No stacktrace found for following nodes
        sin: "f32[3, 3]" = torch.ops.aten.sin.default(x_1)
        mul: "f32[3, 3]" = torch.ops.aten.mul.Tensor(sin, x_1)
        add: "f32[3, 3]" = torch.ops.aten.add.Tensor(mul, x_1);  mul = None
        sin_1: "f32[3, 3]" = torch.ops.aten.sin.default(x_1)
        mul_1: "f32[3, 3]" = torch.ops.aten.mul.Tensor(sin, sin_1)
        add_1: "f32[3, 3]" = torch.ops.aten.add.Tensor(add, mul_1);  add = mul_1 = None
        _tensor_constant0 = self._tensor_constant0
        lift_fresh_copy: "i64[1]" = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0);  _tensor_constant0 = None
        cumsum: "i64[1]" = torch.ops.aten.cumsum.default(lift_fresh_copy, 0);  lift_fresh_copy = None
        slice_1: "i64[0]" = torch.ops.aten.slice.Tensor(cumsum, 0, 0, -1);  cumsum = None
        neg: "i64[0]" = torch.ops.aten.neg.default(slice_1);  slice_1 = None
        unbind = torch.ops

## Conclusion

This notebook shows the from-scratch derivation of an automatic differentiation
transformation in torch FX.  The interested reader should look at `difffx.py` to 
see more details.