<a href="https://colab.research.google.com/github/awf/awf-misc/blob/main/FX_Experiments.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [85]:
import torch

In [87]:
from fx_shnty import shnty_trace, ShapeAndType
from fx_print import fx_print

def aux(p, q):
  return torch.relu(1.234 * p * q).neg()

def my_func(x, b):
  y = 2 * x
  for _ in range(2):  # Loops will be unrolled
    x = aux(x, y)  # Function calls will be inlined
  return torch.atan2(b, x)

gm = shnty_trace(my_func, arg_shntys=(ShapeAndType((3, 5), torch.Tensor),))

fx_print(gm)


TypeError: Union[arg, ...]: each arg must be a type. Got Ellipsis.

In [84]:
def fx_print(gm):
  def commajoin(vs): return ",".join(vs)
  name2ord = {}
  ord = 10
  
  
  args = [n.name for n in gm.graph.nodes if n.op == 'placeholder']
  name = torch.nn.Module._get_name(gm)
  print(f'def {name}({commajoin(args)}):')
  for n in gm.graph.nodes:
      assert n.name not in name2ord
      name2ord[n.name] = f'v{ord}'
      ord += 1
      target = n.target.__name__ if n.op == 'call_function' else n.target
      def argstr(a):
        if isinstance(a, tfx.Node):
          return name2ord[a.name]
        return str(a)
      args = [argstr(a) for a in n.args]
      pr = lambda x: print('  ' + x)
      if n.op == 'placeholder':
        pr(f'{argstr(n)} = {target}')
      elif n.op == 'call_function':
        pr(f'{argstr(n)} = {target}({commajoin(args)})')
      elif n.op == 'call_method':
        pr(f'{argstr(n)} = {argstr(args[0])}.{target}({",".join(args[1:])})')
      elif n.op == 'output':
        pr(f'return {argstr(args[0])}')
      else:
        pr(f'# unhandled {n.op} {argstr(n)} = {target}({",".join(args)})')
      
fx_print(gm)

def my_func(x,b):
  v10 = x
  v11 = b
  v12 = mul(2,v10)
  v13 = mul(1.234,v10)
  v14 = mul(v13,v12)
  v15 = relu(v14)
  v16 = v15.neg()
  v17 = mul(1.234,v16)
  v18 = mul(v17,v12)
  v19 = relu(v18)
  v20 = v19.neg()
  v21 = atan2(v11,v20)
  return v21


In [4]:
import torch, torch.fx
import copy

def aux(p,q):
    return torch.relu(1.234*p*q).neg()

def my_func(x, b):
    y = 2 * x
    for _ in range(2): # Loops will be unrolled
        x = aux(x,y) # Function calls will be inlined
    return torch.atan2(b,x)

my_func_trace = torch.fx.symbolic_trace(my_func)
print(my_func_trace.code)




def forward(self, x, b):
    mul = 2 * x
    mul_1 = 1.234 * x;  x = None
    mul_2 = mul_1 * mul;  mul_1 = None
    relu = torch.relu(mul_2);  mul_2 = None
    neg = relu.neg();  relu = None
    mul_3 = 1.234 * neg;  neg = None
    mul_4 = mul_3 * mul;  mul_3 = mul = None
    relu_1 = torch.relu(mul_4);  mul_4 = None
    neg_1 = relu_1.neg();  relu_1 = None
    atan2 = torch.atan2(b, neg_1);  b = neg_1 = None
    return atan2
    


## Example: convert relu to gelu

In [3]:
def relu_to_gelu(mod: torch.fx.GraphModule):
    g = mod.graph
    for n in g.nodes:
        if n.op == 'call_function' and n.target == torch.relu:
            n.target = torch.nn.functional.gelu

    mod.recompile()
    return None # in-place modification of the graph

my_func_trace = torch.fx.symbolic_trace(my_func)
relu_to_gelu(my_func_trace)
print(my_func_trace.code)




def forward(self, x, b):
    mul = 2 * x
    mul_1 = 1.234 * x;  x = None
    mul_2 = mul_1 * mul;  mul_1 = None
    relu = torch._C._nn.gelu(mul_2);  mul_2 = None
    neg = relu.neg();  relu = None
    mul_3 = 1.234 * neg;  neg = None
    mul_4 = mul_3 * mul;  mul_3 = mul = None
    relu_1 = torch._C._nn.gelu(mul_4);  mul_4 = None
    neg_1 = relu_1.neg();  relu_1 = None
    atan2 = torch.atan2(b, neg_1);  b = neg_1 = None
    return atan2
    


## Methods to functions

Why _do_ we distinguish methods and functions?

In [6]:
# map from method name to function name
def method_to_function(mod: torch.fx.GraphModule):
    g = mod.graph
    for n in g.nodes:
        if n.op == 'call_method':
            # create IR to call new activate
            with g.inserting_after(n):
                new_n = g.call_function(fn[n.target], n.args)
                n.replace_all_uses_with(new_n)
                g.erase_node(n)
        else:
            print('doing nothing to', n)

    mod.recompile()
    return None # in-place modification of the graph

my_func_trace = torch.fx.symbolic_trace(my_func)
method_to_function(my_func_trace)
print(my_func_trace.code)

# Look for:
#   neg_2 = torch.neg(relu);

doing nothing to x
doing nothing to b
doing nothing to mul
doing nothing to mul_1
doing nothing to mul_2
doing nothing to relu
doing nothing to neg_2
doing nothing to mul_3
doing nothing to mul_4
doing nothing to relu_1
doing nothing to neg_3
doing nothing to atan2
doing nothing to output



def forward(self, x, b):
    mul = 2 * x
    mul_1 = 1.234 * x;  x = None
    mul_2 = mul_1 * mul;  mul_1 = None
    relu = torch.relu(mul_2);  mul_2 = None
    neg_2 = torch.neg(relu);  relu = None
    mul_3 = 1.234 * neg_2;  neg_2 = None
    mul_4 = mul_3 * mul;  mul_3 = mul = None
    relu_1 = torch.relu(mul_4);  mul_4 = None
    neg_3 = torch.neg(relu_1);  relu_1 = None
    atan2 = torch.atan2(b, neg_3);  b = neg_3 = None
    return atan2
    
