**FX is a toolkit for developers to use to transform nn.Module instances. FX consists of three main components: a symbolic tracer, an intermediate representation, and Python code generation. A demonstration of these components in action:**

In [6]:
import torch
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return self.linear(x + self.param).clamp(min=0.0, max=1.0)

module = MyModule()

from torch.fx import symbolic_trace
symbolic_traced : torch.fx.GraphModule = symbolic_trace(module)

print(symbolic_traced.graph)

graph():
    %x : [num_users=1] = placeholder[target=x]
    %param : [num_users=1] = get_attr[target=param]
    %add : [num_users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
    %linear : [num_users=1] = call_module[target=linear](args = (%add,), kwargs = {})
    %clamp : [num_users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
    return clamp


In [7]:
print(symbolic_traced.code)




def forward(self, x):
    param = self.param
    add = x + param;  x = param = None
    linear = self.linear(add);  add = None
    clamp = linear.clamp(min = 0.0, max = 1.0);  linear = None
    return clamp
    


**Writing Transformations**

**Build a New Graph**



In [8]:
import torch
import torch.fx

def transform(m: torch.nn.Module,
              tracer_class: type = torch.fx.Tracer) -> torch.nn.Module:
    graph: torch.fx.Graph = tracer_class().trace(m)

    for node in graph.nodes:
        if node.op == 'call_module' and 'linear' in node.target:
            with graph.inserting_before(node):
                relu_node = graph.create_node('call_function', torch.relu, args=(node.args[0],))

            node.args = (relu_node,)

    return torch.fx.GraphModule(m, graph)

**Modify the Existing Graph in Place**

In [10]:
import torch
import torch.fx

def transform(m: torch.nn.Module) -> torch.nn.Module:
    gm: torch.fx.GraphModule = torch.fx.symbolic_trace(m)

    for node in gm.graph.nodes:
        if node.op == 'call_module' and 'linear' in node.target:
            with gm.graph.inserting_after(node):
                pass

    gm.recompile()
    return gm

**GRAPH**

In [11]:
import torch
import torch.fx

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return torch.topk(torch.sum(
            self.linear(x + self.linear.weight).relu(), dim=-1), 3)

m = MyModule()
gm = torch.fx.symbolic_trace(m)

gm.graph.print_tabular()

opcode         name           target                                                   args                kwargs
-------------  -------------  -------------------------------------------------------  ------------------  -----------
placeholder    x              x                                                        ()                  {}
get_attr       linear_weight  linear.weight                                            ()                  {}
call_function  add            <built-in function add>                                  (x, linear_weight)  {}
call_module    linear         linear                                                   (add,)              {}
call_method    relu           relu                                                     (linear,)           {}
call_function  sum_1          <built-in method sum of type object at 0x7cdddfe647e0>   (relu,)             {'dim': -1}
call_function  topk           <built-in method topk of type object at 0x7cdddfe647e0>  (sum_1, 3) 

**Modifying the graph**

In [12]:
for node in gm.graph.nodes:
    if node.op == 'call_module' and node.target == 'linear':
        with gm.graph.inserting_before(node):
            const_node = gm.graph.create_node('call_function', torch.add, args=(node.args[0], torch.tensor(1.0)))
        node.args = (const_node,)

gm.recompile()
gm.graph.print_tabular()

opcode         name           target                                                   args                kwargs
-------------  -------------  -------------------------------------------------------  ------------------  -----------
placeholder    x              x                                                        ()                  {}
get_attr       linear_weight  linear.weight                                            ()                  {}
call_function  add            <built-in function add>                                  (x, linear_weight)  {}
call_function  add_1          <built-in method add of type object at 0x7cdddfe647e0>   (add, tensor(1.))   {}
call_module    linear         linear                                                   (add_1,)            {}
call_method    relu           relu                                                     (linear,)           {}
call_function  sum_1          <built-in method sum of type object at 0x7cdddfe647e0>   (relu,)             

**Direct Graph Manipulation**

**Example: Replacing torch.add with torch.mul**

* sample module M with a forward method that calls torch.add().

* The transform() function traces the module, iterates over the graph nodes, and replaces any call to torch.add() with torch.mul().

*   After transformation, running the modified module will multiply instead of adding the inputs.




In [13]:
import torch
import torch.fx as fx

class M(torch.nn.Module):
    def forward(self, x, y):
        return torch.add(x, y)

def transform(m: torch.nn.Module, tracer_class: type = fx.Tracer) -> torch.nn.Module:
    graph: fx.Graph = tracer_class().trace(m)

    for node in graph.nodes:
        if node.op == 'call_function' and node.target == torch.add:
            node.target = torch.mul

    graph.lint()
    return fx.GraphModule(m, graph)

module = M()
transformed_module = transform(module)

x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5, 6])

output = transformed_module(x, y)
print(output)

tensor([ 4, 10, 18])


**Subgraph Rewriting With replace_pattern()**

*   Define a ConvBNPattern module with a Conv2d followed by a BatchNorm2d.

*   Define ConvBNFused, where the batch normalization is fused with the convolution.
*   The replace_pattern() API finds instances of the Conv2d -> BatchNorm2d pattern and replaces them with the fused version.



In [15]:
import torch
import torch.nn as nn
import torch.fx as fx

class ConvBNPattern(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, 3)
        self.bn = nn.BatchNorm2d(16)

    def forward(self, x):
        return self.bn(self.conv(x))

class ConvBNFused(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, 3)

    def forward(self, x):
        return self.conv(x)

model = ConvBNPattern()
traced = fx.symbolic_trace(model)

fx.subgraph_rewriter.replace_pattern(traced, ConvBNPattern(), ConvBNFused())

[Match(anchor=bn, nodes_map={bn: bn, conv: conv, x: x})]

**Proxy/Retracing**

*   relu_decomposition(x): Replaces F.relu(x) with the element-wise operation (x > 0) * x.
*   decompose(model): Iterates through the traced graph of the model, applying the relu_decomposition transformation to any F.
*   Proxy objects: Used to wrap nodes when tracing. These Proxies capture the operations performed on them, which helps to build the new computation graph automatically.   



In [16]:
import torch
import torch.fx as fx
import torch.nn.functional as F

def relu_decomposition(x):
    return (x > 0) * x

decomposition_rules = {F.relu: relu_decomposition}

def decompose(model: torch.nn.Module, tracer_class: type = fx.Tracer) -> torch.nn.Module:
    graph: fx.Graph = tracer_class().trace(model)
    new_graph = fx.Graph()
    env = {}
    tracer = torch.fx.proxy.GraphAppendingTracer(new_graph)

    for node in graph.nodes:
        if node.op == 'call_function' and node.target in decomposition_rules:
            proxy_args = [fx.Proxy(env[x.name], tracer) if isinstance(x, fx.Node) else x for x in node.args]
            output_proxy = decomposition_rules[node.target](*proxy_args)
            new_node = output_proxy.node
            env[node.name] = new_node
        else:
            new_node = new_graph.node_copy(node, lambda x: env[x.name])
            env[node.name] = new_node

    return fx.GraphModule(model, new_graph)

class M(torch.nn.Module):
    def forward(self, x):
        return F.relu(x)

model = M()
transformed_model = decompose(model)

x = torch.tensor([-1.0, 0.0, 1.0])

output = transformed_model(x)
print(output)

tensor([-0., 0., 1.])


**Custom Transformer**

In [18]:
import torch.fx as fx

class CustomTransformer(fx.Transformer):
    def call_function(self, target, args, kwargs):
        if target == torch.add:
            return self.call_function(torch.mul, args, kwargs)
        return super().call_function(target, args, kwargs)

class M(torch.nn.Module):
    def forward(self, x, y):
        return torch.add(x, y)

model = M()
tracer = fx.Tracer()
graph = tracer.trace(model)
gm = fx.GraphModule(model, graph)
transformed_gm = CustomTransformer(gm).transform()

x, y = torch.tensor(2.0), torch.tensor(3.0)
print(transformed_gm(x, y))

tensor(6.)
