<a href="https://colab.research.google.com/github/awf/awf-misc/blob/main/FX_Experiments.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# FX Experiments

Some small experiments with `torch.fx`.

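Some of these use the `fx_shnty` module in this package (`shnty_trace`, `abstractify`, `AbstractTensor`, ...), which is largely superseded by
PyTorch's built-in `torch.fx.passes.shape_prop.ShapeProp`.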

In [1]:
import torch

In [2]:
from fx_shnty import shnty_trace, abstractify, AbstractTensor
from fx_print import fx_print

def aux(p, q):
  """Scale p*q by 1.234, clamp at zero with ReLU, then negate via the `.neg()` method."""
  scaled = 1.234 * p * q
  return torch.relu(scaled).neg()

def foo(x, b, n):
  """Scale x by b, apply aux n times to the running value, then take atan2.

  Under shnty_trace the loop over range(n) is unrolled (n is a concrete int),
  and each aux call is inlined because aux is an ordinary, unwrapped function.
  """
  y = b * x
  result = x
  for _ in range(n):  # unrolled when traced
    result = aux(result, y)  # body of aux appears inline in the trace
  return torch.atan2(y, result)

torch.manual_seed(0)  # make the eager sanity run below reproducible
x = torch.randn(3,5)
b = 8.2
n = 2
# Eager sanity check: foo runs on real data and preserves x's shape
# (previously the result of this call was silently discarded).
assert foo(x, b, n).shape == x.shape

foo_gm = shnty_trace(foo, (abstractify(x), abstractify(b), n))

fx_print(foo_gm)


shnty_trace foo at (AbTensor[(3,5),torch.float32], abval[torch.float32], 2)
shnty_trace -- x = x...AbTensor[(3,5),torch.float32]
shnty_trace -- b = b...abval[torch.float32]
shnty_trace -- n = n...2
shnty_trace -- mul = mul(b,x)...AbTensor[(3,5),torch.float32]
shnty_trace -- mul_1 = mul(float(1.234),x)...AbTensor[(3,5),torch.float32]
shnty_trace -- mul_2 = mul(mul_1,mul)...AbTensor[(3,5),torch.float32]
shnty_trace -- relu = torch.relu(mul_2)...AbTensor[(3,5),torch.float32]
shnty_trace -- neg = relu.neg()...AbTensor[(3,5),torch.float32]
shnty_trace -- mul_3 = mul(float(1.234),neg)...AbTensor[(3,5),torch.float32]
shnty_trace -- mul_4 = mul(mul_3,mul)...AbTensor[(3,5),torch.float32]
shnty_trace -- relu_1 = torch.relu(mul_4)...AbTensor[(3,5),torch.float32]
shnty_trace -- neg_1 = relu_1.neg()...AbTensor[(3,5),torch.float32]
shnty_trace -- atan2 = torch.atan2(mul,neg_1)...AbTensor[(3,5),torch.float32]
def foo(x,b,n):
  v10 = x[32m # $abval:AbTensor[(3,5),torch.float32][0m
  v11 = b[32m # $

In [3]:
# Suppose we didn't even want to build x: describe it only by shape and dtype.
foo_gm = shnty_trace(
  foo,
  (
    AbstractTensor(torch.Size((3, 5)), torch.float32),
    abstractify(b),
    n,
  ),
)

fx_print(foo_gm)


shnty_trace foo at (AbTensor[(3,5),torch.float32], abval[torch.float32], 2)
shnty_trace -- x = x...AbTensor[(3,5),torch.float32]
shnty_trace -- b = b...abval[torch.float32]
shnty_trace -- n = n...2
shnty_trace -- mul = mul(b,x)...AbTensor[(3,5),torch.float32]
shnty_trace -- mul_1 = mul(float(1.234),x)...AbTensor[(3,5),torch.float32]
shnty_trace -- mul_2 = mul(mul_1,mul)...AbTensor[(3,5),torch.float32]
shnty_trace -- relu = torch.relu(mul_2)...AbTensor[(3,5),torch.float32]
shnty_trace -- neg = relu.neg()...AbTensor[(3,5),torch.float32]
shnty_trace -- mul_3 = mul(float(1.234),neg)...AbTensor[(3,5),torch.float32]
shnty_trace -- mul_4 = mul(mul_3,mul)...AbTensor[(3,5),torch.float32]
shnty_trace -- relu_1 = torch.relu(mul_4)...AbTensor[(3,5),torch.float32]
shnty_trace -- neg_1 = relu_1.neg()...AbTensor[(3,5),torch.float32]
shnty_trace -- atan2 = torch.atan2(mul,neg_1)...AbTensor[(3,5),torch.float32]
def foo(x,b,n):
  v10 = x[32m # $abval:AbTensor[(3,5),torch.float32][0m
  v11 = b[32m # $

In [4]:
from torch import fx
from fx_shnty import shnty_trace, abstractify, shnty_propagator, AbstractTensor, _shnty_propagator_dict
from fx_print import fx_print

@fx.wrap  # register aux as an FX leaf: traces record one aux(...) call instead of its body
def aux(p, q):
  """Negated ReLU of 1.234 * p * q (computed with the `.neg()` method)."""
  return torch.relu(1.234 * p * q).neg()

# Abstract-value propagator for the wrapped `aux`. Because aux is an FX leaf,
# shnty_trace does not see its body and presumably consults this rule to get
# the abstract value of `aux(p, q)`.
# NOTE(review): returning p's abstract value assumes the output has p's shape
# and dtype. That holds in foo (q = b * x has x's shape), but aux multiplies p
# by q elementwise, which broadcasts: e.g. p of shape (3,1) and q of shape
# (1,5) really give (3,5). TODO: use a broadcasting/type-promotion rule if aux
# is reused with differently-shaped arguments.
@shnty_propagator(aux)
def _(p_abval, q_abval):
  return p_abval


from icecream import ic
# Debug: dump the entire propagator registry (a large output). The aux
# propagator registered just above should presumably be among the entries.
# TODO(review): narrow or remove this dump before sharing the notebook.
ic(_shnty_propagator_dict)

def foo(x, b, n):
  """Scale x by b, apply the wrapped aux n times, then take atan2.

  Tracing still unrolls the loop over range(n), but aux is registered as an
  FX leaf via fx.wrap, so each call shows up as a single `aux(...)` node
  instead of being inlined.
  """
  y = b * x
  for _ in range(n):  # Loops will be unrolled
    x = aux(x, y)  # aux is wrapped, so calls stay as single leaf nodes (NOT inlined)
  return torch.atan2(y, x)

torch.manual_seed(0)  # make the eager sanity run below reproducible
x = torch.randn(3,5)
b = 8.2
n = 2
# Eager sanity check: foo (with the wrapped aux) runs on real data and
# preserves x's shape (previously the result of this call was discarded).
assert foo(x, b, n).shape == x.shape

foo_gm = shnty_trace(foo, (abstractify(x), abstractify(b), n))

fx_print(foo_gm)


ic| _shnty_propagator_dict: {<built-in method cumsum of type object at 0x7ffae3cc5280>: <function fx_shnty_propagators._<shnty_propagator(function `cumsum`)> at 0x7ff9d9bfe7a0>,
                             <built-in method neg of type object at 0x7ffae3cc5280>: <function shnty_propagate_broadcast.<locals>.<lambda> at 0x7ff9d9bfd6c0>,
                             <built-in method transpose of type object at 0x7ffae3cc5280>: <function fx_shnty_propagators._<shnty_propagator(function `transpose`)> at 0x7ff9d9bfe5c0>,
                             <built-in method ones_like of type object at 0x7ffae3cc5280>: <function shnty_propagate_broadcast.<locals>.<

lambda> at 0x7ff9d9bfd580>,
                             <built-in method cos of type object at 0x7ffae3cc5280>: <function shnty_propagate_broadcast.<locals>.<lambda> at 0x7ff9d9bfd800>,
                             <built-in method relu of type object at 0x7ffae3cc5280>: <function shnty_propagate_broadcast.<locals>.<lambda> at 0x7ff9d9bfd620>,
                             <built-in method sin of type object at 0x7ffae3cc5280>: <function shnty_propagate_broadcast.<locals>.<lambda> at 0x7ff9d9bfd760>,
                             <built-in method sum of type object at 0x7ffae3cc5280>: <function fx_shnty_propagators._<shnty_propagator(function `sum`)> at 0x7ff9d9bfe700>,
                             <built-in method trace of type object at 0x7ffae3cc5280>: <function fx_shnty_propagators._<shnty_propagator(function `torch.trace`)> at 0x7ff9d9bfe3e0>,
                             <built-in method atan2 of type object at 0x7ffae3cc5280>: <function shnty_propagate_broadcast.<locals>.<lambda>

shnty_trace foo at (AbTensor[(3,5),torch.float32], abval[torch.float32], 2)
shnty_trace -- x = x...AbTensor[(3,5),torch.float32]
shnty_trace -- b = b...abval[torch.float32]
shnty_trace -- n = n...2
shnty_trace -- mul = mul(b,x)...AbTensor[(3,5),torch.float32]
shnty_trace -- aux = __main__.aux(x,mul)[32m # [/tmp/ipykernel_1580975/1924174572.py:5][0m...AbTensor[(3,5),torch.float32]
shnty_trace -- aux_1 = __main__.aux(aux,mul)[32m # [/tmp/ipykernel_1580975/1924174572.py:5][0m...AbTensor[(3,5),torch.float32]
shnty_trace -- atan2 = torch.atan2(mul,aux_1)...AbTensor[(3,5),torch.float32]
def foo(x,b,n):
  v10 = x[32m # $abval:AbTensor[(3,5),torch.float32][0m
  v11 = b[32m # $abval:abval[torch.float32][0m
  v12 = n
  v13 = mul(v11,v10)[32m # $abval:AbTensor[(3,5),torch.float32][0m
  v14 = __main__.aux(v10,v13)[32m # $abval:AbTensor[(3,5),torch.float32],is_wrapped:True[/tmp/ipykernel_1580975/1924174572.py:5][0m
  v15 = __main__.aux(v14,v13)[32m # $abval:AbTensor[(3,5),torch.float32

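## Example: convert ReLU to GELU

Retarget every `torch.relu` call in a traced graph to `torch.nn.functional.gelu`, in place, then recompile the module.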

In [5]:
def my_func(x):
    """ReLU of x plus a residual x: the graph rewritten below."""
    activated = torch.relu(x)
    return activated + x

def relu_to_gelu(mod: torch.fx.GraphModule):
    """Retarget every `torch.relu` call in `mod` to `torch.nn.functional.gelu`, in place.

    Only call_function nodes whose target is exactly torch.relu are changed.
    The module is recompiled so its generated code reflects the new targets.
    Returns None.
    """
    relu_nodes = [
        node
        for node in mod.graph.nodes
        if node.op == 'call_function' and node.target == torch.relu
    ]
    for node in relu_nodes:
        node.target = torch.nn.functional.gelu

    mod.recompile()
    return None  # in-place modification of the graph

my_func_trace = torch.fx.symbolic_trace(my_func)
relu_to_gelu(my_func_trace)

# Sanity check: the rewritten module now computes gelu(x) + x.
_probe = torch.linspace(-2.0, 2.0, 5)
assert torch.allclose(my_func_trace(_probe), torch.nn.functional.gelu(_probe) + _probe)

fx_print(my_func_trace)

def my_func(x):
  v10 = x
  v11 = torch._C._nn.gelu(v10)
  v12 = add(v11,v10)
  return v12


## Methods to functions

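Why _do_ we distinguish methods and functions? FX records `x.neg()` as a `call_method` node (target `"neg"`) but `torch.neg(x)` as a `call_function` node, even though both compute the same thing. The pass below rewrites selected method calls into function calls, so later passes only need to match one form.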

In [6]:
from fx_print import fx_print_node
from difffx import fx_type, fx_add_shapes


def my_func(x):
    """ReLU of -x plus a residual x; the negation is written as the method call x.neg()."""
    negated = x.neg()
    return torch.relu(negated) + x


# Baseline trace: the negation appears as a call_method node (`v10.neg()`).
print("Original: contains t.neg()")
fx_print(torch.fx.symbolic_trace(my_func))


# map from method to function
def method_to_function(
    mod: torch.fx.GraphModule,
    replacements=None,
    verbose=False,
):
    """Rewrite selected call_method nodes in ``mod`` into call_function nodes, in place.

    For example ``x.neg()`` becomes ``torch.neg(x)``.

    Args:
        mod: traced module to modify. Nodes must carry ``meta["type"]``
            (e.g. populated by ``fx_add_shapes``). The method is looked up on
            the type of the receiver (the node's ``args[0]``), not on the
            type of the call's result.
        replacements: mapping from unbound method (e.g. ``torch.Tensor.neg``)
            to the free function replacing it. Defaults to
            ``{torch.Tensor.neg: torch.neg}``.
        verbose: if True, also report nodes that are left unchanged.

    Returns:
        None; ``mod`` is modified and recompiled.

    Raises:
        KeyError: if a receiver node has no ``meta["type"]``.
    """
    if replacements is None:
        # Built per call to avoid a shared mutable default argument.
        replacements = {torch.Tensor.neg: torch.neg}

    g = mod.graph
    # Snapshot the node list, since nodes are inserted and erased as we go.
    for n in list(g.nodes):
        if n.op != "call_method":
            if verbose:
                print("doing nothing to", fx_print_node(n))
            continue

        key = getattr(_method_receiver_type(n), n.target, None)
        if key not in replacements:
            if verbose:
                print(f"key not in fn [{key}]", fx_print_node(n))
            continue

        print(f"method_to_function: replacing {key} in", fx_print_node(n))
        with g.inserting_after(n):
            # Keep kwargs too, e.g. x.sum(dim=0) -> torch.sum(x, dim=0).
            new_n = g.call_function(replacements[key], n.args, n.kwargs)
        # Carry over type/shape metadata so printers and chained replacements still see it.
        new_n.meta.update(n.meta)
        n.replace_all_uses_with(new_n)
        g.erase_node(n)

    mod.recompile()
    return None  # in-place modification of the graph


def _method_receiver_type(node):
    """Return the type of the object a call_method node is invoked on (its args[0])."""
    receiver = node.args[0]
    if isinstance(receiver, torch.fx.Node):
        if "type" not in receiver.meta:
            raise KeyError(
                f"node {receiver.name!r} has no meta['type']; run fx_add_shapes first"
            )
        return receiver.meta["type"]
    return type(receiver)


my_func_trace = torch.fx.symbolic_trace(my_func)
fx_add_shapes(my_func_trace, torch.zeros((2, 3)))
method_to_function(my_func_trace)

# Sanity checks: no `.neg()` method calls remain, and the numerics are unchanged.
assert not any(
    node.op == "call_method" and node.target == "neg"
    for node in my_func_trace.graph.nodes
)
_probe = torch.linspace(-1.0, 1.0, 6).reshape(2, 3)
assert torch.allclose(my_func_trace(_probe), my_func(_probe))

print("Modified: contains torch.neg(t)")
fx_print(my_func_trace)

Original: contains t.neg()
def my_func(x):
  v10 = x
  v11 = v10.neg()
  v12 = torch.relu(v11)
  v13 = add(v12,v10)
  return v13
method_to_function: replacing <method 'neg' of 'torch._C.TensorBase' objects> in neg = x.neg()[32m # Tensor[3, torch.float32][0m
Modified: contains torch.neg(t)
def my_func(x):
  v10 = x[32m # Tensor[3, torch.float32][0m
  v11 = torch.neg(v10)
  v12 = torch.relu(v11)[32m # Tensor[3, torch.float32][0m
  v13 = add(v12,v10)[32m # Tensor[3, torch.float32][0m
  return v13[32m # Tensor[3, torch.float32][0m
