In [1]:
%load_ext autoreload
%autoreload 2

# Some experiments with DiffFX

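This notebook explores `difffx`: it builds the vector-Jacobian product (VJP) of small PyTorch functions as FX graphs, checks the result against `torch.autograd` and a hand-written VJP, and prints the generated code.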

In [2]:
import torch
from fx_print import fx_print
from difffx import vjp as fx_vjp, fx_add_shapes
import vjp_rules # TODO: don't require this
from awfutils.pytree_utils import PyTree

# Function to vjp
def foo(x):
    """Return ``sin(trace(x)) * x`` for a 2-D tensor ``x``."""
    # Scalar factor: sine of the matrix trace.
    scale = torch.sin(torch.trace(x))
    return scale * x

# Seed the RNG so the random input and cotangent below are reproducible.
torch.manual_seed(42)

# Example input, and the VJP of `foo` produced by difffx for that input.
x = torch.randn(3, 3)
foo_vjp = fx_vjp(foo, x)

# Random cotangent shaped like the output of `foo`.
dret = torch.randn_like(foo(x))


def foo_vjp_pt(x, dret):
    """Reference VJP of `foo`, computed by PyTorch autograd."""
    return torch.autograd.functional.vjp(foo, x, dret)


# The difffx VJP must agree with the autograd reference.
PyTree.assert_close(foo_vjp_pt(x, dret), foo_vjp(x, dret))
print('VJPs match OK')

# Presumably attaches shape info to the graph (the printout below shows
# per-node shape comments), then prints the generated code.
fx_add_shapes(foo_vjp, (x, dret))
fx_print(foo_vjp)

VJPs match OK
def vjp_template(x,dret):
  v10 = x[32m # Tensor[3x3, torch.float32][0m
  v11 = dret[32m # Tensor[3x3, torch.float32][0m
  v12 = torch.trace(v10)[32m # Tensor[(), torch.float32][0m
  v13 = torch.sin(v12)[32m # Tensor[(), torch.float32][0m
  v14 = mul(v13,v10)[32m # Tensor[3x3, torch.float32][0m
  v15 = mul(v10,v11)[32m # Tensor[3x3, torch.float32][0m
  v16 = mul(v13,v11)[32m # Tensor[3x3, torch.float32][0m
  v17 = v15.sum((0,1))[32m # Tensor[(), torch.float32][0m
  v18 = v17.reshape(())[32m # Tensor[(), torch.float32][0m
  v19 = v16.reshape((3,3))[32m # Tensor[3x3, torch.float32][0m
  v20 = add(0,v18)[32m # Tensor[(), torch.float32][0m
  v21 = add(0,v19)[32m # Tensor[3x3, torch.float32][0m
  v22 = torch.cos(v12)[32m # Tensor[(), torch.float32][0m
  v23 = mul(v22,v20)[32m # Tensor[(), torch.float32][0m
  v24 = add(0,v23)[32m # Tensor[(), torch.float32][0m
  v25 = self._tensor_constant0[32m # Tensor[3x3, torch.float32]f32[3x3] [[1.000 0.000 0.

In [3]:
# Manual VJP to compare to
def foo_vjp_manual(x, dret):
    """Hand-derived VJP of ``foo(x) = sin(trace(x)) * x``.

    Args:
        x: 2-D input tensor (``torch.trace`` requires a matrix).
        dret: Cotangent for the output of ``foo``; multiplied elementwise
            with ``x``, so it must be broadcast-compatible with it.

    Returns:
        Tuple ``(ret, dx)``: the forward value ``foo(x)`` and the cotangent
        propagated back to ``x``.
    """
    # Forward pass.
    w = torch.trace(x)
    w1 = torch.sin(w)
    ret = w1 * x

    # Reverse through ret = w1 * x.
    dw1 = torch.sum(dret * x)
    dx = w1 * dret

    # Reverse through w1 = sin(w).
    dw = torch.cos(w) * dw1

    # Reverse through w = trace(x): d(trace)/dx is the identity pattern.
    # Build it with x's dtype and device; a default (float32, CPU) identity
    # would demote the float64 scalar `dw` to float32 under type promotion
    # and would not live on the same device as a non-CPU `x`.
    identity = torch.eye(*x.shape, dtype=x.dtype, device=x.device)
    dx = dx + dw * identity

    return ret, dx

# Check the hand-written VJP against the autograd reference.
manual_result = foo_vjp_manual(x, dret)
autograd_result = foo_vjp_pt(x, dret)
PyTree.assert_close(manual_result, autograd_result)
print('VJPs match')

VJPs match


In [4]:
# Print the generated VJP module in readable form. The trailing ';' stops
# Jupyter from also echoing the value of the cell's last expression.
foo_vjp.print_readable();

class vjp_template(torch.nn.Module):
    def forward(self, x: "f32[3, 3]", dret: "f32[3, 3]"):
        # No stacktrace found for following nodes
        trace: "f32[]" = torch.trace(x)
        sin: "f32[]" = torch.sin(trace)
        mul: "f32[3, 3]" = sin * x
        mul_1: "f32[3, 3]" = x * dret;  x = None
        mul_2: "f32[3, 3]" = sin * dret;  sin = dret = None
        sum_1: "f32[]" = mul_1.sum((0, 1));  mul_1 = None
        reshape: "f32[]" = sum_1.reshape(());  sum_1 = None
        reshape_1: "f32[3, 3]" = mul_2.reshape((3, 3));  mul_2 = None
        add: "f32[]" = 0 + reshape;  reshape = None
        add_1: "f32[3, 3]" = 0 + reshape_1;  reshape_1 = None
        cos: "f32[]" = torch.cos(trace);  trace = None
        mul_3: "f32[]" = cos * add;  cos = add = None
        add_2: "f32[]" = 0 + mul_3;  mul_3 = None
        _tensor_constant0: "f32[3, 3]" = self._tensor_constant0
        mul_4: "f32[3, 3]" = add_2 * _tensor_constant0;  add_2 = _tensor_constant0 = None
        add_3: "

## Using FX IR to print source code of functorch.jacrev(f)

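This is considerably lower-level than the AD above, because the trace records the operations that actually reach the torch dispatcher (the `aten.*` ops).
It also means the result is size-specialized: the input size (13 here) is baked into the traced graph.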

In [9]:
from functorch import make_fx
def f(x):
    """Return ``sin(x) + x`` elementwise."""
    sin_x = torch.sin(x)
    return sin_x + x
# Example input of length 13; the trace below is made for this input.
x = torch.randn(13)

# Build the reverse-mode Jacobian of `f`, then trace it to an FX graph.
jacobian_of_f = torch.func.jacrev(f)
grad_f = make_fx(jacobian_of_f)(x)

fx_print(grad_f)

def f(x_1):
  v10 = x_1[32m # Tensor[13, torch.float32],val:FakeTensor(..., size=(13,))[0m
  v11 = aten.sin.default(v10)[32m # Tensor[13, torch.float32],val:FakeTensor(..., size=(13,))[0m
  v12 = aten.add.Tensor(v11,v10)[32m # Tensor[13, torch.float32],val:FakeTensor(..., size=(13,))[0m
  v13 = self._tensor_constant0[32m # i64[1] [13][0m
  v14 = aten.lift_fresh_copy.default(v13)[32m # Tensor[1, torch.int64],val:FakeTensor(..., size=(1,), dtype=torch.int64)[0m
  v15 = aten.cumsum.default(v14,0)[32m # Tensor[1, torch.int64],val:FakeTensor(..., size=(1,), dtype=torch.int64)[0m
  v16 = aten.slice.Tensor(v15,0,0,-1)[32m # Tensor[0, torch.int64],val:FakeTensor(..., size=(0,), dtype=torch.int64)[0m
  v17 = aten.neg.default(v16)[32m # Tensor[0, torch.int64],val:FakeTensor(..., size=(0,), dtype=torch.int64)[0m
  v18 = aten.unbind.int(v17)[32m # val:[][0m
  v19 = aten.new_zeros.default(v12,[13,13])[32m # Tensor[13x13, torch.float32],val:FakeTensor(..., size=(13, 13))[0m
  v20

For comparison, the FX AD version is closer to what one might write by hand:

In [10]:
# VJP of the same `f`, built by difffx rather than by tracing jacrev.
f_vjp = fx_vjp(f, x)
fx_print(f_vjp)

def vjp_template(x,dret):
  v10 = x
  v11 = dret
  v12 = torch.sin(v10)
  v13 = add(v12,v10)
  v14 = v11.reshape((13))
  v15 = v11.reshape((13))
  v16 = add(0,v14)
  v17 = add(0,v15)
  v18 = torch.cos(v10)
  v19 = mul(v18,v16)
  v20 = add(v17,v19)
  return (v13,v20)
