# DiffFX Derivation

This notebook is a walkthrough of the initial design and implementation of difffx,
a simple source-to-source autodiff tool using Torch FX.
See https://github.com/pytorch/pytorch/blob/master/torch/fx/OVERVIEW.md#the-fx-ir for
an overview of the FX IR.

In [1]:
## Import prerequisites, define utility functions
from collections import defaultdict

import operator

import torch
import torch.fx as fx

from icecream import ic 

from awfutils.pytree_utils import PyTree

from vjp_check import vjp_check

def ensure_tuple(x):
    """Return `x` unchanged if it is already a tuple, else wrap it in a 1-tuple."""
    return x if isinstance(x, tuple) else (x,)

## FX_VJP: Source-to-source reverse mode algorithmic differentiation via FX 

An FX Interpreter that implements reverse-mode automatic differentiation.

It may help to recall the basic laws of AD.
Given a function `f` which takes an arbitrary pytree type `S`, and returns an arbitrary pytree type `T`,
we define the vector-Jacobian product `vjp{f}` which takes types `(S, dT)` and returns type `dS` with `vjp{f}(s, dt) = dt * J{f}(s)`.
When `T` is a scalar or has only one element, then the VJP is the gradient,
so `grad{f}(s) = vjp{f}(s, 1.0)`.
```py
def f(s : S) -> T: 
  ...
def vjp{f}(s : S, dt : dT) -> dS: 
  ...
```
In practice we generally divide the computation into
'forward' and 'backward' passes: the forward pass returns the result of `f`
together with some auxiliary information of type `Aux{f}`, which the backward pass consumes:
```py
def fwd{f}(s : S) -> (T, Aux{f}): 
  ...
def bwd{f}(aux : Aux{f}, dt : dT) -> dS:
  ...
```
in terms of which we could just write `vjp{f}` as
```py
def vjp{f}(s : S, dt : dT) -> dS:
  _t, Aux = fwd{f}(s)
  return bwd{f}(Aux, dt)
```

Here's an example for `sin`, where we provide two implementations, 
one likely to be more memory efficient, the other likely to be faster. 

In [2]:
# Example: sin
# Two interchangeable (fwd, bwd) rule pairs for sin; flip the condition to switch.
# Each fwd returns (value, aux) and each bwd maps (aux, dret) to the derivative
# with respect to the input.
if True:
    # Save x in Aux - may save memory, as x is likely kept alive for the
    # backward pass anyway
    def sin_fwd(x):
        return (torch.sin(x), x)

    def sin_bwd(aux_is_x, dret):
        # d/dx sin(x) = cos(x), computed lazily in the backward pass
        return torch.cos(aux_is_x) * dret
else:
    # Save cos(x) in Aux - may save compute if sin and cos can be evaluated
    # together more cheaply than separately
    def sin_fwd(x):
        # torch has no fused `sincos`, so evaluate the two halves explicitly
        return torch.sin(x), torch.cos(x)

    def sin_bwd(aux_is_cosx, dret):
        # The cosine was already computed in the forward pass
        return aux_is_cosx * dret

### Chain rule

An FX traced function is of the general form
```py
def foo(a1..an):
  t1 = f1(a1..an)
  t2 =  f2(a1..an,t1) # wlog, fk uses all in-scope variables
  ...
  tm =     fm(a1..an,t1..t{m-1}) 
  return tm
```
Then the VJP (vector-jacobian product) is of the form
```py
def foo_vjp(a1..an, dtm):
  # Forward pass
  t1, aux1               = f1_fwd(a1..an)
  t2, aux2               =   f2_fwd(a1..an,t1)
  ...
  tm, auxm               =      fm_fwd(a1..an,t1..t{m-1})

  # Backward pass
  da1..dan,dt1..dt{m-1} +=      fm_bwd(auxm, dtm)
  ...
  da{1..n},dt1          +=   f2_bwd(aux2, dt2)
  da{1..n}              += f1_bwd(aux1, dt1)

  return da{1..n}
```

So let's make a transformer `fx_vjp` that does that.



In [3]:

# Registry of AD rules: python function -> (forward, backward).
# It starts out holding the sin rule defined above.
ad_map = {torch.sin: (sin_fwd, sin_bwd)}

# Get shnty tracers
from fx_shnty import shnty_trace, abstractify, get_return_abstract_value

def fx_vjp(f, sample_input):
    """
    An FX transform that implements reverse-mode automatic differentiation.

    If the traced function is of the form
    ```py
    def foo(a1..an):
      t1 = f1(a1..an)
      t2 = f2(a1..an,t1) # wlog, fk uses all in-scope variables
      ...
      tm = fm(a1..an,t1..t{m-1})
      return tm
    ```
    then the VJP (vector-Jacobian product) is of the form
    ```py
    def foo_vjp(a1..an, dtm):
      t1, aux1 = f1_fwd(a1..an)
      t2, aux2 = f2_fwd(a1..an,t1)
      ...
      tm, auxm = fm_fwd(a1..an,t1..t{m-1})

      da{1..n},dt{1..m-1} += fm_bwd(auxm, dtm)
      ...
      da{1..n},dt1 += f2_bwd(aux2, dt2)
      da{1..n} += f1_bwd(aux1, dt1)

      return tm, da{1..n}
    ```

    Args:
        f: The function to differentiate. Every function it calls must have a
            `(fwd, bwd)` entry in `ad_map`. The template below passes a single
            argument `x`, so only one-argument functions are handled here.
        sample_input: Tuple of abstract input values handed to `shnty_trace`
            (the abstract return value of `f` is appended to it to stand in
            for `dret`).

    Returns:
        The `shnty_trace` of the VJP template: a function of `(x, dret)` that
        returns `(f(x), dx)`.

    Raises:
        NotImplementedError: if `f` calls a function with no rule in `ad_map`,
            or uses a method call or attribute fetch.
    """

    class ADInterpreter(torch.fx.Interpreter):
        """
        This interpreter runs through the forward transformation,
        replacing calls to `fk` with `fk_fwd = ad_map[fk][0]`,
        and recording the operations on a stack.
        """
        def __init__(self, f):
            super().__init__(f)
            # One (args, bwd, aux, val) record per call, in execution order
            self.stack = []

        def call_function(self, target, args, kwargs):
            # The rules in `ad_map` take positional arguments only
            assert not kwargs, "keyword arguments are not supported by VJP rules"

            if target not in ad_map:
                raise NotImplementedError(f"Need VJP rule for {target}")
            # Look up forward/backward functions in `ad_map`
            fwd, bwd = ad_map[target]
            # Call the fwd function, getting proxies for returns
            val, aux = fwd(*args)
            # In the backward pass, we will compute:
            #  d[args[0]],...,d[args[-1]] = bwd(aux, d{val})
            # So remember: (args, bwd, aux, val)
            # Note that all of these are symbolic, so it's cheap to remember them
            self.stack.append((args, bwd, aux, val))
            # And return the return value (a proxy)
            return val

        def call_method(self, target, args, kwargs):
            raise NotImplementedError(
                f"Method call {target} not supported; use method_to_function"
            )

        def get_attr(self, target, args, kwargs):
            raise NotImplementedError(f"get_attr {target} not supported")  # TODO

    # Grab the FX graph, using shnty_trace to get shapes
    f_trace = shnty_trace(f, sample_input)

    # This is the "template" function for the VJP
    def vjp_template(x, dret):
        # Run the forward computations, and collect them in ad.stack
        ad = ADInterpreter(f_trace)
        ret = ad.run(x)
        # Build a dict to hold derivatives; anything not yet touched reads as 0
        d = defaultdict(int)
        # Add dret to derivatives dict
        d[ret] = dret
        # And run down the stack, last operation first...
        for (args, bwd, aux, val) in reversed(ad.stack):
            dargs = bwd(aux, d[val])
            # NOTE(review): zip stops at the shorter sequence, so a rule that
            # returns too few derivatives is silently accepted - confirm intended
            for (a, da) in zip(args, ensure_tuple(dargs)):
                d[a] += da
        # And return ret and J'*dret
        return ret, d[x]

    # Trace through vjp_template and return.
    return shnty_trace(vjp_template, sample_input + (get_return_abstract_value(f_trace),))

## And now define the AD rules

These are "just python", nothing is built in. We define VJPs for just a couple of primitives here, and with minimal error checking.
See `vjp_rules.py` for more primitives (e.g. `mul` rather than the special case of `scale` here.)

In [4]:


# Trace
def trace_fwd(x):
    """Forward rule for trace: return (trace(x), aux), where aux is x's shape."""
    value = torch.trace(x)
    aux = x.shape
    return value, aux


def trace_bwd(x_shape, dret):
    """Backward rule for trace: scale an identity of shape `x_shape` by `dret`."""
    # Only the diagonal entries of the input contribute to the trace
    identity = torch.eye(*x_shape)
    return dret * identity

# Register the (forward, backward) pair for torch.trace in the rule table
ad_map[torch.trace] = (trace_fwd, trace_bwd)


# Add
def add_fwd(A, B):
    """Forward rule for add: return (A + B, None); no aux data is needed."""
    # This rule does not handle broadcasting, so insist on equal shapes
    assert A.shape == B.shape
    total = A + B
    return (total, None)


def add_bwd(aux, dret):
    """Backward rule for add: both operands receive `dret` unchanged (`aux` is unused)."""
    dA = dB = dret
    return (dA, dB)

# Register the (forward, backward) pair for operator.add in the rule table
ad_map[operator.add] = (add_fwd, add_bwd)

# Special case of mul (scalar * Tensor)
def mul_fwd(a, T):
    """Forward rule for `a * T`: return the product and keep both operands as aux."""
    product = a * T
    saved = (a, T)
    return product, saved

def mul_bwd(aux, dret):
    """
    Backward rule for `a * T`, where aux is the saved `(a, T)` pair.

    Returns `(da, dT)`: `da` sums `T * dret` over all elements and
    `dT` is `dret` scaled by `a`.
    """
    # T and dret are both mxn
    scalar, tensor = aux
    return torch.sum(tensor * dret), scalar * dret

# Register the (forward, backward) pair for operator.mul in the rule table
ad_map[operator.mul] = (mul_fwd, mul_bwd)


## ... and let's try it out

We define a function `foo`, compute its vjp using DiffFX, and compare to `autograd.functional.vjp`.

In [8]:
# Function to vjp
def foo(x):
    """Test function: with w = sin(trace(x)), return w * x + w * sin(x)."""
    w = torch.sin(torch.trace(x))
    return w * x + w * torch.sin(x)

import fx_shnty
# Replace fx_shnty's `_log` hook with a do-nothing function
# (presumably this silences trace logging - confirm against fx_shnty)
fx_shnty._log = lambda x:...
# Fix the RNG seed so the random tensors drawn below are reproducible
torch.manual_seed(42)

x = torch.randn(3,3)
# Build the DiffFX VJP of foo from an abstract version of the sample input
foo_vjp = fx_vjp(foo, (abstractify(x),))

# Random cotangent with the same shape as foo's output
dret = torch.randn_like(foo(x))
# Reference implementation: PyTorch autograd's functional VJP
foo_vjp_pt = lambda x,dret: torch.autograd.functional.vjp(foo, x, dret)

# Check that the DiffFX result agrees with the PyTorch reference
PyTree.assert_close(foo_vjp(x,dret), foo_vjp_pt(x, dret))
print('VJP matches PT')

VJP matches PT


And now let's show that it's a real source-to-source transformation, by printing the code for the VJP

In [9]:
from fx_print import fx_print
# Print the generated VJP function as source code
fx_print(foo_vjp)

def vjp_template(x,dret):
  v10 = x
  v11 = dret
  v12 = torch.trace(v10)
  v13 = torch.sin(v12)
  v14 = mul(v13,v10)
  v15 = torch.sin(v10)
  v16 = mul(v13,v15)
  v17 = add(v14,v16)
  v18 = add(int(0),v11)
  v19 = add(int(0),v11)
  v20 = mul(v15,v19)
  v21 = torch.sum(v20)
  v22 = mul(v13,v19)
  v23 = add(int(0),v21)
  v24 = add(int(0),v22)
  v25 = torch.cos(v10)
  v26 = mul(v25,v24)
  v27 = add(int(0),v26)
  v28 = mul(v10,v18)
  v29 = torch.sum(v28)
  v30 = mul(v13,v18)
  v31 = add(v23,v29)
  v32 = add(v27,v30)
  v33 = torch.cos(v12)
  v34 = mul(v33,v31)
  v35 = add(int(0),v34)
  v36 = _tensor_constant0 # tensor([[1., 0., 0.],\n        [0., 1., 
  v37 = mul(v35,v36)
  v38 = add(v32,v37)
  return (v17,v38)


## Using FX IR to print source code of functorch.jacrev(f)

This is rather more low-level than the AD above, as it reflects the operations that hit the torch dispatcher.
Let's first print the DiffFX gradient for a simple function

In [10]:
def f(x):
    """Simple scalar-valued test function: sin(trace(x + x))."""
    return torch.sin(torch.trace(x + x))

# Non-square sample input
x = torch.randn(3,5)
# Build the DiffFX VJP of f from an abstract version of the sample input
f_vjp = fx_vjp(f, (abstractify(x),))

# Random cotangent matching f's output
dret = torch.rand_like(f(x))
# Evaluate the VJP: prints the pair (f(x), derivative with respect to x)
print(f_vjp(x, dret))

# Print the generated VJP function as source code
fx_print(f_vjp)

(tensor(-0.8336), tensor([[0.2229, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.2229, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.2229, 0.0000, 0.0000]]))
def vjp_template(x,dret):
  v10 = x
  v11 = dret
  v12 = add(v10,v10)
  v13 = torch.trace(v12)
  v14 = torch.sin(v13)
  v15 = torch.cos(v13)
  v16 = mul(v15,v11)
  v17 = add(int(0),v16)
  v18 = _tensor_constant0 # tensor([[1., 0., 0., 0., 0.],\n        [
  v19 = mul(v17,v18)
  v20 = add(int(0),v19)
  v21 = add(int(0),v20)
  v22 = add(v21,v20)
  return (v14,v22)


And then the gradient from jacrev:

In [11]:
import functorch

# Trace functorch's reverse-mode Jacobian of f at x into an FX graph
grad_f = functorch.make_fx(functorch.jacrev(f))(x)
# Print that graph as source code, for comparison with the DiffFX output
fx_print(grad_f)

def f(x_1):
  v10 = x_1
  v11 = torch._ops.aten.add.Tensor(v10,v10)
  v12 = torch._ops.aten.trace.default(v11)
  v13 = torch._ops.aten.sin.default(v12)
  v14 = _tensor_constant0 # tensor([1])
  v15 = torch._ops.aten.lift_fresh_copy.default(v14)
  v16 = torch._ops.aten.cumsum.default(v15,int(0))
  v17 = torch._ops.aten.slice.Tensor(v16,int(0),int(0),int(-1))
  v18 = torch._ops.aten.neg.default(v17)
  v19 = torch._ops.aten.unbind.int(v18)
  v20 = torch._ops.aten.new_zeros.default(v13,[1, 1][<class 'torch.fx.immutable_collections.immutable_list'>])
  v21 = torch._ops.aten.diagonal.default(v20)
  v22 = torch._ops.aten.fill_.Scalar(v21,int(1))
  v23 = torch._ops.aten.view.default(v20,[1][<class 'torch.fx.immutable_collections.immutable_list'>])
  v24 = torch._ops.aten.cos.default(v12)
  v25 = torch._ops.aten.select.int(v23,int(0),int(0))
  v26 = torch._ops.aten.mul.Tensor(v23,v24)
  v27 = torch._ops.aten.zeros.default([15][<class 'torch.fx.immutable_collections.immutable_list'>])
  v28 = to