<a href="https://colab.research.google.com/github/awf/awf-misc/blob/main/FX_Experiments.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch, torch.fx
import copy

def aux(p, q):
    """Return ``-relu(1.234 * p * q)``.

    The negation is deliberately written as a method call (``.neg()``) so
    that tracing produces a ``call_method`` node alongside the
    ``call_function`` node for ``torch.relu``.
    """
    scaled = 1.234 * p * q
    activated = torch.relu(scaled)
    return activated.neg()

def my_func(x, b):
    """Small demo function for symbolic tracing.

    Applies ``aux`` twice to ``x`` (each time paired with ``2 * x`` computed
    from the original ``x``) and returns ``atan2(b, result)``.
    """
    doubled = 2 * x
    # The Python loop is unrolled by the tracer, and each call to aux is inlined.
    for _step in range(2):
        x = aux(x, doubled)
    return torch.atan2(b, x)

# Symbolically trace my_func and print the Python source that torch.fx
# generated for the traced graph.
my_func_trace = torch.fx.symbolic_trace(my_func)
print(my_func_trace.code)




def forward(self, x, b):
    mul = 2 * x
    mul_1 = 1.234 * x;  x = None
    mul_2 = mul_1 * mul;  mul_1 = None
    relu = torch.relu(mul_2);  mul_2 = None
    neg = relu.neg();  relu = None
    mul_3 = 1.234 * neg;  neg = None
    mul_4 = mul_3 * mul;  mul_3 = mul = None
    relu_1 = torch.relu(mul_4);  mul_4 = None
    neg_1 = relu_1.neg();  relu_1 = None
    atan2 = torch.atan2(b, neg_1);  b = neg_1 = None
    return atan2
    


## Example: convert relu to gelu

In [3]:
def relu_to_gelu(mod: torch.fx.GraphModule):
    """Retarget every ``torch.relu`` call in *mod* to ``torch.nn.functional.gelu``.

    Only ``call_function`` nodes whose target is ``torch.relu`` are touched.
    The node keeps its existing name; only its target changes. The graph is
    modified in place and the module is recompiled, so nothing is returned.
    """
    relu_calls = [
        node
        for node in mod.graph.nodes
        if node.op == 'call_function' and node.target == torch.relu
    ]
    for node in relu_calls:
        node.target = torch.nn.functional.gelu

    # Regenerate mod's Python code from the edited graph.
    mod.recompile()
    return None # in-place modification of the graph

# Re-trace my_func so the rewrite starts from a fresh graph, apply the
# relu -> gelu pass in place, and print the regenerated code.
my_func_trace = torch.fx.symbolic_trace(my_func)
relu_to_gelu(my_func_trace)
print(my_func_trace.code)




def forward(self, x, b):
    mul = 2 * x
    mul_1 = 1.234 * x;  x = None
    mul_2 = mul_1 * mul;  mul_1 = None
    relu = torch._C._nn.gelu(mul_2);  mul_2 = None
    neg = relu.neg();  relu = None
    mul_3 = 1.234 * neg;  neg = None
    mul_4 = mul_3 * mul;  mul_3 = mul = None
    relu_1 = torch._C._nn.gelu(mul_4);  mul_4 = None
    neg_1 = relu_1.neg();  relu_1 = None
    atan2 = torch.atan2(b, neg_1);  b = neg_1 = None
    return atan2
    


## Methods to functions

Why _do_ we distinguish methods and functions?

In [4]:
# map from method name to the torch function that computes the same thing
fn = {
    'neg': torch.neg
}

def method_to_function(mod: torch.fx.GraphModule):
    """Rewrite ``call_method`` nodes of *mod* into ``call_function`` nodes.

    Each ``call_method`` node whose method name has an entry in ``fn`` is
    replaced by a ``call_function`` node calling the mapped function with the
    same positional and keyword arguments. All other nodes are left as they
    are: ``call_function`` nodes silently, everything else (placeholders,
    outputs, methods without an entry in ``fn``) with a printed notice.

    The graph is modified in place and the module is recompiled, so nothing
    is returned.
    """
    g = mod.graph
    # Iterate over a snapshot because nodes are inserted and erased in the loop.
    for n in list(g.nodes):
        if n.op == 'call_function':
            continue

        replacement = fn.get(n.target) if n.op == 'call_method' else None
        if replacement is None:
            # Not a method call, or a method we have no function equivalent for.
            print('doing nothing to', n)
            continue

        with g.inserting_after(n):
            # For a method call the receiver is args[0], so args carry over
            # unchanged. kwargs are forwarded too; dropping them would
            # silently change what the call computes.
            new_n = g.call_function(replacement, n.args, n.kwargs)
        n.replace_all_uses_with(new_n)
        g.erase_node(n)

    mod.recompile()
    return None # in-place modification of the graph

# Re-trace my_func, convert its method calls to function calls in place,
# and print the regenerated code.
my_func_trace = torch.fx.symbolic_trace(my_func)
method_to_function(my_func_trace)
print(my_func_trace.code)

# Look for:
#   neg_2 = torch.neg(relu);

doing nothing to x
doing nothing to b
doing nothing to output



def forward(self, x, b):
    mul = 2 * x
    mul_1 = 1.234 * x;  x = None
    mul_2 = mul_1 * mul;  mul_1 = None
    relu = torch.relu(mul_2);  mul_2 = None
    neg_2 = torch.neg(relu);  relu = None
    mul_3 = 1.234 * neg_2;  neg_2 = None
    mul_4 = mul_3 * mul;  mul_3 = mul = None
    relu_1 = torch.relu(mul_4);  mul_4 = None
    neg_3 = torch.neg(relu_1);  relu_1 = None
    atan2 = torch.atan2(b, neg_3);  b = neg_3 = None
    return atan2
    


In [6]:
from vjp_rules import *

# Function to vjp
def foo(x):
    """Example function to differentiate: ``neg(relu(scale(1.1, x)))``.

    ``scale`` comes from ``vjp_rules``; presumably it multiplies ``x`` by the
    given factor (the recorded trace shows ``1.1 * x``) -- confirm there.
    """
    scaled = scale(1.1, x)
    return torch.neg(torch.relu(scaled))

# Desired vjp code
def foo_grad(x, dret):
    """Hand-written vector-Jacobian product of ``foo``.

    Runs the ``*_fwd`` rule of each step of ``foo`` in order, keeping the
    auxiliary value each one returns, then feeds ``dret`` back through the
    matching ``*_bwd`` rules in reverse order. Returns the second element of
    ``scale_bwd``'s result, i.e. the cotangent for ``x``.
    """
    # Forward sweep: each rule returns (value, aux).
    scaled, scale_saved = scale_fwd(1.1, x)
    activated, relu_saved = relu_fwd(scaled)
    _negated, neg_saved = neg_fwd(activated)

    # Reverse sweep: propagate the output cotangent back to the input.
    d_activated = neg_bwd(neg_saved, dret)
    d_scaled = relu_bwd(relu_saved, d_activated)
    _d_factor, dx = scale_bwd(scale_saved, d_scaled)
    return dx


# Trace foo and print the resulting GraphModule (its repr includes the
# generated forward code).
my_func_trace = torch.fx.symbolic_trace(foo)
print(my_func_trace)


foo()



def forward(self, x):
    mul = 1.1 * x;  x = None
    relu = torch.relu(mul);  mul = None
    neg = torch.neg(relu);  relu = None
    return neg
    


In [7]:
from ast import Add
import operator
import torch.fx as fx

def add_vjp(aux, d_add):
    """Backward rule for addition: hand ``d_add`` to both operands.

    ``aux`` is accepted for signature compatibility with the other backward
    rules and is not used.
    """
    return (d_add,) * 2

# https://github.com/pytorch/pytorch/blob/master/torch/fx/OVERVIEW.md#the-fx-ir

# Dispatch table used by ADInterpreter.call_function:
#   traced call target -> (forward rule, backward rule)
# The forward rule is called with the node's positional arguments and must
# return (value, aux); the backward rule is later called as
# backward(aux, d_value) by ad_wrapper.
# Both the operator.* and torch.* spellings of neg are listed because either
# can appear as a call target.
ad_map = {
    operator.neg: (neg_fwd, neg_bwd),
    torch.neg: (neg_fwd, neg_bwd),
    operator.add: (add_fwd, add_bwd),
    operator.mul: (mul_fwd, mul_bwd),
    torch.relu: (relu_fwd, relu_bwd)
}

class ADInterpreter(torch.fx.Interpreter):
    """Interpreter that evaluates a traced module via forward AD rules.

    Every ``call_function`` node is executed through the forward rule
    registered for its target in ``ad_map``. For each call a tuple
    ``(aux, val, backward_rule, args)`` is appended to ``self.stack`` so a
    caller can replay the calls in reverse to propagate cotangents.
    """

    def __init__(self, f):
        # Tape of (aux, val, backward_rule, args), in execution order.
        self.stack = []
        super().__init__(f)

    def call_function(self, target, args, kwargs):
        """Run the forward rule for *target* and record it on the tape.

        Returns the primal value produced by the forward rule.

        Raises:
            KeyError: if ``ad_map`` has no rule for *target*.
        """
        # The rules are called positionally only, so kwargs are not supported.
        assert not kwargs, 'keyword arguments are not supported by the AD rules'
        print('FWD', target, args)
        try:
            f_fwd, f_bwd = ad_map[target]
        except KeyError:
            raise KeyError(
                f'no AD rule registered in ad_map for {target!r}') from None
        val, aux = f_fwd(*args)
        self.stack.append((aux, val, f_bwd, args))
        return val

def ensure_tuple(x):
    """Return *x* unchanged if it is a tuple, otherwise wrap it in a 1-tuple."""
    return x if isinstance(x, tuple) else (x,)

def ad_wrapper(f_trace):
    """Build a vjp function ``run(x, dret)`` for the traced module *f_trace*.

    ``run`` evaluates *f_trace* on ``x`` through an ``ADInterpreter``, which
    records one tape entry per function call, then walks the tape backwards,
    calling each backward rule with the cotangent of that call's value and
    accumulating the results per argument. It returns the cotangent
    accumulated for ``x``.

    Only single-input modules are supported (``run`` forwards just ``x``).
    """
    ad = ADInterpreter(f_trace)
    def run(x, dret):
        # The interpreter is shared across calls, so start from an empty tape;
        # otherwise a second call would replay entries from the first.
        ad.stack.clear()
        ret = ad.run(x)
        # Cotangents, keyed by the primal object they belong to.
        d = {ret: dret}
        # reversed() leaves the tape itself in execution order.
        for (aux, val, f_rev, args) in reversed(ad.stack):
            dargs = f_rev(aux, d[val])
            print(val, aux, f_rev, args, ":", dargs)
            for (a, da) in zip(args, ensure_tuple(dargs)):
                if a in d:
                    # Accumulate out of place: the stored cotangent may be the
                    # very object the caller passed in as dret, or be shared
                    # with another argument, and must not be mutated.
                    d[a] = d[a] + da
                else:
                    d[a] = da
        return d[x]

    return run
     


# Trace foo, wrap it with the AD interpreter, then trace the wrapper itself so
# the forward and reverse sweeps are captured together in one GraphModule.
foo_trace = torch.fx.symbolic_trace(foo)
new_fn = torch.fx.symbolic_trace(ad_wrapper(foo_trace))
print(new_fn)

# For comparison, trace the hand-written vjp and print its graph.
my_func_trace = torch.fx.symbolic_trace(foo_grad)
print(my_func_trace)

# Evaluate the generated vjp on a random input and a small random cotangent.
x = torch.randn(3,4)
dret = 0.001 * torch.randn(3, 4)
new_fn(x, dret)


FWD <built-in function mul> (1.1, Proxy(x))
FWD <built-in method relu of type object at 0x7f3835408d60> (Proxy(mul),)
FWD <built-in method neg of type object at 0x7f3835408d60> (Proxy(relu),)
Proxy(neg) None <function neg_bwd at 0x7f380d4f5ca0> (Proxy(relu),) : Proxy(neg_1)
Proxy(relu) Proxy(mul) <function relu_bwd at 0x7f380d557d30> (Proxy(mul),) : Proxy(mul_1)
Proxy(mul) (1.1, Proxy(x)) <function mul_bwd at 0x7f380d56a310> (1.1, Proxy(x)) : (Ellipsis, Proxy(mul_2))
run()



def forward(self, x, dret):
    mul = 1.1 * x;  x = None
    relu = torch.relu(mul)
    neg = -relu;  relu = None
    neg_1 = -dret;  dret = None
    gt = mul > 0;  mul = None
    mul_1 = gt * neg_1;  gt = neg_1 = None
    mul_2 = 1.1 * mul_1;  mul_1 = None
    return mul_2
    
foo_grad()



def forward(self, x, dret):
    mul = 1.1 * x;  x = None
    relu = torch.relu(mul)
    neg = -relu;  relu = None
    neg_1 = -dret;  dret = None
    gt = mul > 0;  mul = None
    mul_1 = gt * neg_1;  gt = neg_1 = None
    mu

tensor([[-0.0000,  0.0000,  0.0000,  0.0015],
        [-0.0000,  0.0000,  0.0023,  0.0000],
        [-0.0003,  0.0000, -0.0000,  0.0003]])