# Some experiments with PyTorch FX.

In which we make a simple source-to-source autodiff tool.
See https://github.com/pytorch/pytorch/blob/master/torch/fx/OVERVIEW.md#the-fx-ir for
an overview of the FX IR.

In [11]:
## Import prerequisites, define utility functions
from collections import defaultdict

import operator

import torch
import torch.fx as fx

from icecream import ic 

from awfutils.pytree_utils import pt_rand_like, PyTree

from vjp_check import vjp_check

def ensure_tuple(x):
    if isinstance(x, tuple):
        return x
    return (x,)

## Shape propagation

This is an aside -- we need shapes for e.g. the gradient of 'trace', so let's
quickly assemble a solution, where we can call `fx_shape` on proxies.

Quick hack here, as we expect more thorough handling upstream:
 - https://discuss.pytorch.org/t/fx-proxy-and-shapes/113861/4
 - https://github.com/pytorch/pytorch/issues/54982
 - https://www.youtube.com/watch?v=pLni96jtcjY


In [12]:
from typing import Any
import torch.fx.passes

class AnnotatingInterpreter(torch.fx.Interpreter):
    """
    An FX Interpreter that attaches the original FX node to proxies.

    This allows annotations left by previous passes to be picked up, for example shapes
    """
    def run_node(self, n):
        val = super().run_node(n)
        val.fxi_node = n # Attach node to val
        return val

def fx_add_shapes(f_trace : torch.fx.GraphModule, sample_input : Any):
    """
    Run shape propagation on graph `f_trace`, which will add shape metadata in place.
    """
    torch.fx.passes.graph_manipulation.ShapeProp(f_trace).run(sample_input)

def fx_shape(x):
    """
    Return the shape of FX Proxy x.

    Assumes that ShapeProp has been run on the graph, so that x.fsi_node is set
    """
    return x.fxi_node.meta['tensor_meta'].shape



## FX_VJP: Source-to-source reverse mode algorithmic differentiation via FX 

An FX Interpreter that implements reverse-mode automatic differentiation.

It may help to recall the basic laws of AD.
Given function `f` which takes arbitrary pytree type `S`, and returns arbitary pytree `T`,
we define the vector-Jacobian product `vjp{f}` which takes types `(S, dT)` and returns type `dS` with `vjp{f}(s, dt) = dt * J{f}(s)`.
When `T` is a scalar or has only one element, then the VJP is the gradient,
so `grad{f}(s) = vjp{f}(s, 1.0)`.
```py
def f(s : S) -> T: ...
def vjp{f}(s : S, dt : dT) -> dS: ...
```
and as we generally divide the computation into
'forward' and 'backward' passes, returning the result of `f` in the forward pass,
as well as some auxiliary information of type `Aux{f}` 
```py
def fwd{f}(s : S) -> (T, Aux{f}): ...
def bwd{f}(aux : Aux{f}, dt : dT) -> dS: ...
```
in terms of which we could just write `vjp{f}` as
```py
def vjp{f}(s : S, dt : dT) -> dS:
  _t, Aux = fwd{f}(s)
  return bwd{f}(Aux, dt)
```

Here's an example for `sin`, where we provide two implementations, 
one likely to be more memory efficient, the other likely to be faster. 

In [23]:
# Example: sin
if True:
    # sin saves x in Aux - may save memory as x is likely preserved for backward pass
    def sin_fwd(x):
        return (torch.sin(x), x)
    def sin_bwd(aux_is_x,dret): 
        return torch.cos(aux_is_x) * dret
else:
    # sin saves cos(x) in Aux - may save compute if `sincos`` is faster than `sin` and `cos`
    def sin_fwd(x):
        ret, aux = torch.sincos(x)
        return ret, aux
    def sin_bwd(aux_is_cosx,dret): 
        return aux_is_cosx * dret

### Chain rule

An FX traced function is of the general form
```py
def foo(a1..an):
  t1 = f1(a1..an)
  t2 =  f2(a1..an,t1) # wlog, fk uses all in-scope variables
  ...
  tm =     fm(a1..an,t1..t{m-1}) 
  return tm
```
Then the VJP (vector-jacobian product) is of the form
```py
def foo_vjp(a1..an, dtm):
  # Forward pass
  t1, aux1               = f1_fwd(a1..an)
  t2, aux2               =   f2_fwd(a1..an,t1)
  ...
  tm, auxm               =      fm_fwd(a1..an,t1..t{m-1})

  # Backward pass
  da1..dan,dt1..dt{m-1} +=      fm_bwd(auxm, dtm)
  ...
  da{1..n},dt1          +=   f2_bwd(aux2, dt3)
  da{1..n}              += f1_bwd(aux1, dt1)

  return da{1..n}
```

So let's make a transformer `fx_vjp` that does that.



In [13]:

# A mapping from python function to (forward, backward)
ad_map = {}
ad_map[torch.sin] = sin_fwd, sin_bwd

def fx_vjp(f, sample_input):
    """
    An FX transform that implements reverse-mode automatic differentiation.

    If the traced function is of the form
    ```py
    def foo(a1..an):
      t1 = f1(a1..an)
      t2 =  f2(a1..an,t1) # wlog, fk uses all in-scope variables
      ...
      tm =   fm(a1..an,t1..t{m-1}) 
      return tm
    ```
    Then the VJP (vector-jacobian product is of the form)
    ```py
    def foo_vjp(a1..an, dtm):
      t1, aux1 = f1_fwd(a1..an)
      t2, aux2 = f2_fwd(a1..an,t1)
      ...
      tm, auxm = fm_fwd(a1..an,t1..{m-1})

      da{1..n},dt{1..m-1} += fm_bwd(auxm, dtm)
      ...
      da{1..n},dt1 += f2_bwd(aux2, dt3)
      da{1..n} += f1_bwd(aux1, dt1)

      return da{1..n}
    ```
    """

    class ADInterpreter(AnnotatingInterpreter):
        """
        This interpreter runs through the forward transformation, 
        replacing calls to `fk` with `fk_fwd = ad_map[fk][0]`,
        and recording the operations on a stack.
        """
        def __init__(self, f):
            super().__init__(f)
            self.stack = []

        def call_function(self, target, args, kwargs):
            assert kwargs == None or len(kwargs) == 0

            if target not in ad_map:
                raise NotImplementedError(f"Need VJP rule for {target}")
            # Look up forward/backward functions in `ad_map`
            fwd,bwd = ad_map[target]
            # Call the fwd function, getting proxies for returns
            val,aux = fwd(*args)
            # In the backward pass, we will compute:
            #  d[args[0]],...,d[args[-1]] = bwd(aux, d{val})
            # So remember: (args, bwd, aux, val)
            # Note that all of these are symbolic, so it's cheap to remember them
            self.stack.append((args, bwd, aux, val))
            # And return the return value (a proxy)
            return val

        def call_method(self, target, args, kwargs):
            raise NotImplementedError # use method_to_function

        def get_attr(self, target, args, kwargs):
            raise NotImplementedError # TODO

    # Grab the FX graph
    f_trace = torch.fx.symbolic_trace(f)
    
    # Run shape analysis, record answers in the graph
    fx_add_shapes(f_trace, sample_input)

    # This is the "template" function for the VJP
    def vjp_template(x, dret):
        # Run the forward computations, and collect them in ad.stack
        ad = ADInterpreter(f_trace)
        ret = ad.run(x)
        # Build a dict to hold derivatives
        d =  defaultdict(lambda: 0)
        # Add dret to derivatives dict
        d[ret] = dret
        # And run down the stack...
        for (args, bwd, aux, val) in reversed(ad.stack):
            dargs = bwd(aux, d[val])
            for (a,da) in zip(args, ensure_tuple(dargs)):
                d[a] += da
        # And return ret and J'*dret
        return ret, d[x]

    # Trace through vjp_template and return.
    return torch.fx.symbolic_trace(vjp_template)

## And now define the AD rules

These are "just python", nothing is built in.

In [16]:

import vjp_rules
from vjp_check import vjp_check

def vjp_linear(f):
    """
    Construct fwd and bwd for f a linear function of x
    """
    def fwd(*args): return f(*args), None
    def bwd(_, dret): return f(dret)
    return fwd, bwd


ad_map_aux = {
    operator.neg: vjp_linear(operator.neg),
    operator.add: (vjp_rules.add_fwd, vjp_rules.add_bwd),
    operator.mul: (vjp_rules.mul_fwd, vjp_rules.mul_bwd),
    operator.matmul: (vjp_rules.matmul_fwd, vjp_rules.matmul_bwd),
    torch.neg: vjp_linear(torch.neg),
    torch.sum: vjp_linear(torch.sum),
    torch.relu: (vjp_rules.relu_fwd, vjp_rules.relu_bwd),
    torch.transpose: (lambda x,m,n: (torch.transpose(x,m,n), (m,n)),
                      lambda aux,dret: torch.transpose(dret,*aux)),
    torch.diag: (vjp_rules.diag_fwd, vjp_rules.diag_bwd),
    vjp_rules.scale: (vjp_rules.scale_fwd, vjp_rules.scale_bwd),
}

ad_map = {**ad_map, **ad_map_aux}

def fx_trace_fwd(x):
    assert len(fx_shape(x)) == 2
    return torch.trace(x), fx_shape(x)

# Register custom passes for trace (as we needed to use fx_shape)
ad_map[torch.trace] = (fx_trace_fwd, vjp_rules.trace_bwd)

# And let's add some shape checking to add and mul, as we have it...
def fx_add_fwd(A,B):
    assert fx_shape(A) == fx_shape(B)
    return vjp_rules.add_fwd(A,B)
ad_map[operator.add] = (fx_add_fwd, vjp_rules.add_bwd)

def fx_mul_fwd(A,B):
    assert fx_shape(A) == fx_shape(B)
    return vjp_rules.mul_fwd(A,B)
ad_map[operator.mul] = (fx_mul_fwd, vjp_rules.mul_bwd)



In [17]:
# Function to vjp
def foo(x):
    w = torch.trace(x)
    w = torch.sin(w)
    a = vjp_rules.scale(w, x)
    return a

torch.manual_seed(42)

x = torch.randn(3,3)
foo_vjp = fx_vjp(foo, x)

dret = torch.randn_like(foo(x))
foo_vjp_pt = lambda x,dret: torch.autograd.functional.vjp(foo, x, dret)

PyTree.assert_close(foo_vjp_pt(x,dret), foo_vjp(x, dret))
print('VJPs match OK')

print(foo_vjp.code)

VJPs match OK

torch.fx._symbolic_trace.wrap("vjp_rules_scale")

def forward(self, x, dret):
    trace = torch.trace(x)
    sin = torch.sin(trace)
    scale = vjp_rules_scale(sin, x)
    mul = x * dret;  x = None
    sum_1 = torch.sum(mul);  mul = None
    scale_1 = vjp_rules_scale(sin, dret);  sin = dret = None
    add = 0 + sum_1;  sum_1 = None
    add_1 = 0 + scale_1;  scale_1 = None
    cos = torch.cos(trace);  trace = None
    mul_1 = cos * add;  cos = add = None
    add_2 = 0 + mul_1;  mul_1 = None
    _tensor_constant0 = self._tensor_constant0
    mul_2 = add_2 * _tensor_constant0;  add_2 = _tensor_constant0 = None
    add_3 = add_1 + mul_2;  add_1 = mul_2 = None
    return (scale, add_3)
    


In [18]:
# Manual VJP to compare to
def foo_vjp_manual(x, dret):
    w = torch.trace(x)
    w1 = torch.sin(w)
    ret = w1 * x
    
    dw1 = torch.sum(dret * x)
    dx = w1 * dret

    dw = torch.cos(w) * dw1

    dx += dw * torch.eye(*x.shape)
    
    return ret, dx

PyTree.assert_close(foo_vjp_manual(x, dret), foo_vjp_pt(x,dret))
print('VJPs match')

VJPs match



---------------
Misc below here

In [19]:

# Function to vjp
def bar(x):
    t = torch.transpose(x, 1, 0)
    v = operator.matmul(t, x)

    # w = torch.trace(v)
    # ret = vjp_rules.scale(w, v)
    
    return v

def bar_vjp_manual(x, dret):
    x1, x2 = x,x
    t = torch.transpose(x1, 1, 0)
    v,aux = vjp_rules.matmul_fwd(t, x2)

    v1,v2 = v,v
    w = torch.trace(v1)
    ret = w * v2
    
    dw = (dret * v2).sum()
    dv2 = w * dret
    dv1 = dw * torch.eye(*v1.shape)
    
    dv = dv1 + dv2
    dt,dx2 = vjp_rules.matmul_bwd(aux, dv)

    dx1 = torch.transpose(dt, 1, 0)
    dx = dx1 + dx2
    return ret, dx
dret = torch.randn_like(foo(x))

x = torch.randn(5,2)

bar_vjp = fx_vjp(bar, x)

dret = torch.randn_like(bar(x))
ic(bar_vjp_manual(x, dret))
ic(bar_vjp(x, dret))

vjp_check(bar, bar_vjp, x)

ic| bar_vjp_manual(x, dret): (tensor([[ 8.3484, -3.1498],
                                     [-3.1498,  3.7344]]),
                              tensor([[ 0.5019, -0.6084],
                                     [-1.3575,  1.0871],
                                     [-1.5760,  1.8188],
                                     [-4.3323, -0.3791],
                                     [ 2.5767, -1.5431]]))
ic| bar_vjp(x, dret): (tensor([[ 2.4017, -0.9061],
                              [-0.9061,  1.0743]]),
                       tensor([[ 0.0694, -0.1907],
                              [-0.2250,  0.3101],
                              [-0.2241,  0.5651],
                              [-0.9746, -0.4284],
                              [ 0.4618, -0.3968]]))


VJP OK: <function bar at 0x7f085d69ff40>


In [20]:
foo_vjp.print_readable()

class vjp_template(torch.nn.Module):
    torch.fx._symbolic_trace.wrap("vjp_rules_scale")
    
    def forward(self, x, dret):
        
        # No stacktrace found for following nodes 
        trace = torch.trace(x)
        sin = torch.sin(trace)
        scale = vjp_rules_scale(sin, x)
        mul = x * dret;  x = None
        sum_1 = torch.sum(mul);  mul = None
        scale_1 = vjp_rules_scale(sin, dret);  sin = dret = None
        add = 0 + sum_1;  sum_1 = None
        add_1 = 0 + scale_1;  scale_1 = None
        cos = torch.cos(trace);  trace = None
        mul_1 = cos * add;  cos = add = None
        add_2 = 0 + mul_1;  mul_1 = None
        _tensor_constant0 = self._tensor_constant0
        mul_2 = add_2 * _tensor_constant0;  add_2 = _tensor_constant0 = None
        add_3 = add_1 + mul_2;  add_1 = mul_2 = None
        return (scale, add_3)
        


## Using FX IR to print source code of functorch.jacrev(f)

This is rather more low-level than the AD above, as it reflects the operations that hit the torch dispatcher.
This also means it is size-specialized. 

In [21]:
from functorch import make_fx, grad, vjp, jacrev
def f(x):
    return torch.sin(x)
x = torch.randn(13)
grad_f = make_fx(jacrev(f))(x)
print(grad_f.code)




def forward(self, x_1):
    sin = torch.ops.aten.sin.default(x_1)
    _tensor_constant0 = self._tensor_constant0
    lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0);  _tensor_constant0 = None
    cumsum = torch.ops.aten.cumsum.default(lift_fresh_copy, 0);  lift_fresh_copy = None
    slice_1 = torch.ops.aten.slice.Tensor(cumsum, 0, 0, -1);  cumsum = None
    neg = torch.ops.aten.neg.default(slice_1);  slice_1 = None
    unbind = torch.ops.aten.unbind.int(neg);  neg = None
    new_zeros = torch.ops.aten.new_zeros.default(sin, [13, 13], dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False);  sin = None
    diagonal = torch.ops.aten.diagonal.default(new_zeros)
    fill_ = torch.ops.aten.fill_.Scalar(diagonal, 1);  diagonal = None
    view = torch.ops.aten.view.default(new_zeros, [13, 13]);  new_zeros = None
    cos = torch.ops.aten.cos.default(x_1);  x_1 = None
    mul = torch.ops.aten.mul.Tensor(view, cos);  view = c

For comparison, the FX AD version is closer to what one might write by hand:

In [41]:
print(fx_vjp(f, x).code)




def forward(self, x, dret):
    sin = torch.sin(x)
    cos = torch.cos(x);  x = None
    mul = cos * dret;  cos = dret = None
    add = 0 + mul;  mul = None
    return (sin, add)
    


# Trying to get functorch vjp working..

In [31]:
import functorch
from functorch import make_fx, grad, vjp, jacrev
def f(x):
    return torch.sin(x)
x = torch.randn(10)
f_vjp = functorch.vjp(f,x)[1]
grad_f = make_fx(f_vjp)(x, x, 1)
print(grad_f(x))
#print(torch.fx.symbolic_trace(grad_f)(x).code)


TypeError: only integer tensors of a single element can be converted to an index