官方示例库：https://github.com/pytorch/examples/tree/main/fx

## symbolic_trace和graph

In [141]:
import torch
from torch.fx import symbolic_trace

class MyModule(torch.nn.Module):
    """Tiny demo network: a 4 -> 5 linear layer whose output is clamped to [0, 1]."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        projected = self.linear(x)
        # `clamp` is called as a tensor *method*, so FX records it as `call_method`.
        return projected.clamp(min=0.0, max=1.0)

model = MyModule()

# Frontend: symbolic tracing runs forward() on Proxy objects and records every op.
symbolic_traced: torch.fx.GraphModule = symbolic_trace(model)

# The recorded IR is an fx.Graph; printing it shows one line per node.
graph_ir = symbolic_traced.graph
print(graph_ir)

graph():
    %x : [#users=1] = placeholder[target=x]
    %linear : [#users=1] = call_module[target=linear](args = (%x,), kwargs = {})
    %clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
    return clamp


In [142]:
# A GraphModule prints its submodules followed by the Python code FX generated for forward().
print(str(symbolic_traced))

MyModule(
  (linear): Linear(in_features=4, out_features=5, bias=True)
)



def forward(self, x):
    linear = self.linear(x);  x = None
    clamp = linear.clamp(min = 0.0, max = 1.0);  linear = None
    return clamp
    


In [143]:
# print_tabular() renders one row per node with the columns
# [n.op, n.name, n.target, n.args, n.kwargs].
traced_graph = symbolic_traced.graph
traced_graph.print_tabular()

opcode       name    target    args       kwargs
-----------  ------  --------  ---------  ------------------------
placeholder  x       x         ()         {}
call_module  linear  linear    (x,)       {}
call_method  clamp   clamp     (linear,)  {'min': 0.0, 'max': 1.0}
output       output  output    (clamp,)   {}


## Graph Manipulation

### Direct Graph Manipulation

E.g., Replace `torch.add()` calls with `torch.mul()` calls:

In [144]:
import torch
from torch import fx

class M(torch.nn.Module):
    """Adds its two inputs with an explicit `torch.add` call."""

    def forward(self, x, y):
        total = torch.add(x, y)
        return total

m_add = M()
# Tracer.trace() returns a bare fx.Graph (not wrapped in a GraphModule).
add_graph = fx.Tracer().trace(m_add)
print(add_graph)

graph():
    %x : [#users=1] = placeholder[target=x]
    %y : [#users=1] = placeholder[target=y]
    %add : [#users=1] = call_function[target=torch.add](args = (%x, %y), kwargs = {})
    return add


注意这里判断和修改的都是`node.target`而非`node`本身：只需把`node.target`从`torch.add`换成`torch.mul`，节点的名字（仍为`add`）、参数等都保持不变：

In [145]:

def transform(m, tracer_class=fx.Tracer) -> torch.nn.Module:
    """Return a GraphModule equivalent to ``m`` with every ``torch.add`` call
    executed as ``torch.mul`` instead.

    Args:
        m: module to trace.
        tracer_class: ``fx.Tracer`` subclass used to build the graph.

    Returns:
        A new ``fx.GraphModule`` built from ``m`` and the rewritten graph.
    """
    graph: fx.Graph = tracer_class().trace(m)
    add_calls = (
        node for node in graph.nodes
        if node.op == 'call_function' and node.target == torch.add
    )
    for node in add_calls:
        # Only the callee changes; the node's name, args and kwargs stay as they were.
        node.target = torch.mul
    graph.lint()  # sanity-check that the rewritten Graph is well-formed
    return fx.GraphModule(m, graph)

# Rewrite the add-module; its generated forward() now calls torch.mul.
m_mul = transform(m_add)
print(str(m_mul))

GraphModule()



def forward(self, x, y):
    add = torch.mul(x, y);  x = y = None
    return add
    


E.g., Graph rewrites

In [146]:
class M(torch.nn.Module):
    """Computes ((x + y) * 3) * ((x + y) * 2) with explicit torch calls."""

    def forward(self, x, y):
        summed = torch.add(x, y)
        tripled = torch.mul(summed, 3)
        doubled = torch.mul(summed, 2)
        return tripled * doubled

m = M()

在`torch.add`后面插入`torch.relu`。注意`node.replace_all_uses_with(new_node)`会替换`node`的所有使用者，其中也包括刚插入的`relu`节点本身（否则会变成`relu = torch.relu(relu)`的自引用），因此必须让`relu`保留原来的输入`add`：

In [147]:
from copy import deepcopy
m_trace = fx.symbolic_trace(m)
for node in m_trace.graph.nodes:
    if (node.op, node.target) == ("call_function", torch.add):
        with m_trace.graph.inserting_after(node):
            # Insert a new `call_function` node calling `torch.relu` on the add result.
            new_node = m_trace.graph.call_function(torch.relu, args=(node,))
            # Redirect every consumer of `node` to the relu output, except the relu
            # itself, which must keep reading `node` (otherwise it would become
            # `relu = torch.relu(relu)`). Passing `node` directly, rather than a
            # deepcopy, keeps relu's input inside this graph. That way `add` keeps
            # a real user and survives lint / dead-code elimination.
            node.replace_all_uses_with(
                new_node, delete_user_cb=lambda user: user is not new_node)
m_trace.graph.lint()
m_trace.recompile()
print(m_trace)

M()



def forward(self, x, y):
    add = torch.add(x, y);  x = y = None
    relu = torch.relu(add);  add = None
    mul = torch.mul(relu, 3)
    mul_1 = torch.mul(relu, 2);  relu = None
    mul_2 = mul * mul_1;  mul = mul_1 = None
    return mul_2
    


### Subgraph Rewriting With `replace_pattern()`

更多内容和示例参见[官网](https://pytorch.org/docs/stable/fx.html#subgraph-rewriting-with-replace-pattern)，包括：
- Replace one op
- Conv/Batch Norm fusion
- replace_pattern: Basic usage
- Quantization
- Invert Transformation

## Proxy/Retracing

E.g., Create a Graph Using Proxy Objects Instead of Tracing:

In [148]:
import torch
from torch.fx import Proxy, Graph, GraphModule

# Build a Graph by hand, without symbolic tracing.
graph = Graph()
tracer = torch.fx.proxy.GraphAppendingTracer(graph)

# Two raw placeholder Nodes; they become the generated forward()'s `x` and `y` parameters.
raw1 = graph.placeholder('x')
raw2 = graph.placeholder('y')

# Wrapping each Node in a Proxy lets ordinary torch calls append new Nodes to `graph`.
proxy_x = Proxy(raw1, tracer)
proxy_y = Proxy(raw2, tracer)

# Each torch call below records one call_function node in `graph`.
a = torch.cat([proxy_x, proxy_y])
b = torch.tanh(a)
c = torch.neg(b)
# `d` is never consumed, yet its node still appears in the graph and the generated code.
d = torch.add(b, c)

# Mark `c` as the graph's return value.
graph.output(c.node)

# A plain nn.Module serves as the root; the result is a runnable GraphModule.
mod = GraphModule(torch.nn.Module(), graph)

print(mod)

GraphModule()



def forward(self, x, y):
    cat = torch.cat([x, y]);  x = y = None
    tanh = torch.tanh(cat);  cat = None
    neg = torch.neg(tanh)
    add = torch.add(tanh, neg);  tanh = None
    return neg
    


E.g., Decomposing ReLU into its mathematical definition.

Decompose model into smaller constituent operations. 
 
Here, we decompose `ReLU` into its mathematical definition: `(x > 0) * x`:

In [149]:
class M(torch.nn.Module):
    """Shifts the input by 2, then applies ReLU."""

    def forward(self, x):
        return torch.relu(torch.add(x, 2))

m = M()

# Expected result: relu(2 + 2).
inp = torch.tensor([2])
m(inp)

tensor([4])

In [150]:
from torch.fx import map_arg

def relu_decomposition(x):
    """Re-express ReLU as ``(x > 0) * x``.

    Inside ``decompose`` this runs on ``fx.Proxy`` arguments. The ``print``
    therefore fires once, at transform time, and the ``>``/``*`` operations are
    recorded as new graph nodes; the generated ``forward()`` has no ``print``.
    """
    print('my_relu:', x)
    return (x > 0) * x

# Original call target -> function that rebuilds it from simpler ops.
decomposition_rules = {}
decomposition_rules[torch.relu] = relu_decomposition

def decompose(model: torch.nn.Module,
              tracer_class : type = fx.Tracer) -> torch.nn.Module:
    """Return a copy of ``model`` in which every call listed in
    ``decomposition_rules`` is replaced by the operations its rule emits.

    Args:
        model: module to decompose.
        tracer_class: ``fx.Tracer`` subclass used to trace ``model``.

    Returns:
        A new ``fx.GraphModule`` built from ``model`` and the rewritten graph.
    """
    graph : fx.Graph = tracer_class().trace(model)
    new_graph = fx.Graph()
    # Name of a node in `graph` -> its counterpart Node in `new_graph`.
    env = {}
    # Proxies must append to the graph being *built*. A tracer bound to the
    # source `graph` would target the wrong graph.
    tracer = torch.fx.proxy.GraphAppendingTracer(new_graph)

    def to_proxy(n):
        # Wrap the already-rewritten counterpart of `n` so rule code can trace on it.
        return fx.Proxy(env[n.name], tracer)

    for node in graph.nodes:
        if node.op == 'call_function' and node.target in decomposition_rules:
            # Wrapping the arguments (and kwargs) with proxies lets the
            # decomposition rule run symbolically, which appends its ops to `new_graph`.
            proxy_args = map_arg(node.args, to_proxy)
            proxy_kwargs = map_arg(node.kwargs, to_proxy)
            output_proxy = decomposition_rules[node.target](*proxy_args, **proxy_kwargs)
            # Rules return a Proxy; keep its underlying Node for later lookups.
            new_node = output_proxy.node
        else:
            # Default case: no decomposition rule for this node, so copy it
            # into the new graph with its inputs remapped through `env`.
            new_node = new_graph.node_copy(node, lambda x: env[x.name])
        env[node.name] = new_node
    return fx.GraphModule(model, new_graph)

# Decomposing traces `m`; relu_decomposition's print runs here, during the transform.
m_new = decompose(m)
print(str(m_new))

my_relu: Proxy(add)
GraphModule()



def forward(self, x):
    add = torch.add(x, 2);  x = None
    gt = add > 0
    mul = gt * add;  gt = add = None
    return mul
    


`print('my_relu:', x)`在调用`decompose()`时就打印了出来：分解规则是在构建新图时以`Proxy`作为参数执行的，所以这里的`x`是包裹上一个节点`add`的`Proxy(add)`。生成的`forward()`里只有`gt`、`mul`等节点，并不包含`print()`，因此实际运行模型时不会再打印：

In [151]:
# Running the decomposed module does not trigger relu_decomposition's print.
decomposed_out = m_new(inp)
decomposed_out

tensor([4])

## The Interpreter Pattern

### E.g., [Shape Propagation](https://github.com/pytorch/pytorch/blob/master/torch/fx/passes/shape_prop.py)

实现一个简单的Shape Propagation：
> 注意：当`result`不是`torch.Tensor`（例如`tuple`、`torch.Size`或Python标量）时，无法从中读取`shape`/`dtype`，需要单独处理。

In [175]:
import torch
import torch.fx
from torch.fx.node import Node

from typing import Dict

class ShapeProp:
    """Interpret an FX ``GraphModule`` node by node, recording tensor metadata.

    After :meth:`propagate`, each node whose computed value is a
    ``torch.Tensor`` has ``node.shape`` and ``node.dtype`` set from that value.
    Nodes whose value is anything else (a ``tuple``, ``torch.Size``, a Python
    scalar, ...) get ``None`` for both attributes.

    Args:
        mod: the ``GraphModule`` to interpret; its ``graph`` and submodules are used.
    """

    def __init__(self, mod):
        self.mod = mod
        self.graph = mod.graph
        # Qualified submodule name -> module, used to resolve `call_module` targets.
        self.modules = dict(self.mod.named_modules())

    def propagate(self, *args):
        """Run the graph on ``args``, annotate and print every node, and return the output.

        Args:
            *args: concrete values for the graph's placeholder nodes, in order.

        Returns:
            The value computed for the graph's ``output`` node.

        Raises:
            RuntimeError: if a ``get_attr`` target does not exist, or a node has
                an opcode this interpreter does not handle.
        """
        args_iter = iter(args)
        # Node name -> runtime value already computed for that node.
        env: Dict[str, object] = {}

        def load_arg(a):
            # Swap every Node inside the (possibly nested) structure `a` for its value.
            return torch.fx.graph.map_arg(a, lambda n: env[n.name])

        def fetch_attr(target: str):
            # Resolve a dotted path such as "sub.weight" starting from the root module.
            target_atoms = target.split('.')
            attr_itr = self.mod
            for i, atom in enumerate(target_atoms):
                if not hasattr(attr_itr, atom):
                    # Report the path up to and including the missing atom.
                    raise RuntimeError(
                        f"Node referenced nonexistent target {'.'.join(target_atoms[:i + 1])}")
                attr_itr = getattr(attr_itr, atom)
            return attr_itr

        output_value = None
        for node in self.graph.nodes:
            if node.op == 'placeholder':
                result = next(args_iter)
            elif node.op == 'get_attr':
                result = fetch_attr(node.target)
            elif node.op == 'call_function':
                result = node.target(*load_arg(node.args), **load_arg(node.kwargs))
            elif node.op == 'call_method':
                # Distinct name so the `args` parameter is not shadowed.
                self_obj, *method_args = load_arg(node.args)
                kwargs = load_arg(node.kwargs)
                result = getattr(self_obj, node.target)(*method_args, **kwargs)
            elif node.op == 'call_module':
                result = self.modules[node.target](*load_arg(node.args), **load_arg(node.kwargs))
            elif node.op == 'output':
                # The graph's return value is the evaluated first argument of `output`.
                result = load_arg(node.args[0])
                output_value = result
            else:
                # Fail loudly instead of silently reusing the previous node's `result`.
                raise RuntimeError(f"Unsupported node opcode: {node.op!r}")

            # This is the only code specific to shape propagation.
            # Without this if/else it is a generic GraphModule interpreter.
            if isinstance(result, torch.Tensor):
                node.shape = result.shape
                node.dtype = result.dtype
            else:
                # Non-tensor values have no shape/dtype; also clears values from earlier runs.
                node.shape = None
                node.dtype = None

            env[node.name] = result

        for node in self.graph.nodes:
            print(node.name, node.shape, node.dtype)

        return output_value

In [181]:
from torch.fx import symbolic_trace

class MyModule(torch.nn.Module):
    """Tiny demo network: a 4 -> 5 linear layer whose output is clamped to [0, 1]."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        projected = self.linear(x)
        # Method-form clamp, traced as a `call_method` node.
        return projected.clamp(min=0.0, max=1.0)

m = MyModule()
traced_m = symbolic_trace(m)

# Run the hand-written ShapeProp on a single 4-element sample input.
sample = torch.randn(4)
s = ShapeProp(traced_m)
s.propagate(sample)

x torch.Size([4]) torch.float32
linear torch.Size([5]) torch.float32
clamp torch.Size([5]) torch.float32
output torch.Size([5]) torch.float32


PyTorch提供了[`class ShapeProp(torch.fx.Interpreter)`](https://github.com/pytorch/pytorch/blob/master/torch/fx/passes/shape_prop.py)。

In [184]:
from torch.fx.passes.shape_prop import ShapeProp

# torch's ShapeProp stores its results in node.meta['tensor_meta'].
ShapeProp(traced_m).propagate(torch.randn(4))
for node in traced_m.graph.nodes:
    meta = node.meta['tensor_meta']
    print(node.name, meta.dtype, meta.shape)

x torch.float32 torch.Size([4])
linear torch.float32 torch.Size([5])
clamp torch.float32 torch.Size([5])
output torch.float32 torch.Size([5])


### [fx.Interpreter](https://github.com/pytorch/pytorch/blob/master/torch/fx/interpreter.py)

```
run()
    +-- run_node
        +-- placeholder()
        +-- get_attr()
        +-- call_function()
        +-- call_method()
        +-- call_module()
        +-- output()
```

交换`sigmoid()`和`neg()`：

In [189]:
from torch.fx.interpreter import Interpreter
from typing import Tuple, Any
from torch.fx.node import Target

class NegSigmSwapInterpreter(Interpreter):
    """Interpreter that runs a graph with ``torch.sigmoid`` and ``Tensor.neg`` swapped.

    ``torch.sigmoid(...)`` calls execute as ``torch.neg(...)``, and
    ``tensor.neg(...)`` method calls execute as ``tensor.sigmoid(...)``.
    Every other node is delegated to the stock :class:`Interpreter`.
    """

    def call_function(self, target : Target,
                      args : Tuple, kwargs : Dict) -> Any:
        if target == torch.sigmoid:
            return torch.neg(*args, **kwargs)
        # Any other function must still run normally; without this fallback
        # the node's value would silently become None.
        return super().call_function(target, args, kwargs)

    def call_method(self, target : Target,
                    args : Tuple, kwargs : Dict) -> Any:
        if target == 'neg':
            call_self, *args_tail = args
            return call_self.sigmoid(*args_tail, **kwargs)
        # Same fallback as above for every other method call.
        return super().call_method(target, args, kwargs)

def fn(x):
    # Traced as a call_function (torch.sigmoid) followed by a call_method (neg).
    sigm = torch.sigmoid(x)
    return sigm.neg()

gm = torch.fx.symbolic_trace(fn)
# Named `x_in` rather than `input` to avoid shadowing the builtin.
x_in = torch.randn(3, 4)
result = NegSigmSwapInterpreter(gm).run(x_in)
# The swapped graph computes sigmoid(neg(x)) instead of neg(sigmoid(x)).
# torch.testing.assert_allclose is deprecated; assert_close replaces it.
torch.testing.assert_close(result, torch.neg(x_in).sigmoid())

## feature_extraction
torch.fx还有一个比较实用的使用场景，那就是对模型进行特征提取，比如我们希望得到模型中间特征用来分析，或者用一些中间特征用于构建其它模型，比如检测和分割模型。比如我们要提取ResNet模型的C4和C5特征：

In [3]:
import torchvision, torch
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names

# Build the model (no pretrained weights are requested here).
model = torchvision.models.resnet50()

# Names of all graph nodes, in train mode and in eval mode respectively.
train_nodes, eval_nodes = get_graph_node_names(model)

# Which internal nodes to expose, and the key each one gets in the output dict.
return_nodes = {'layer3.5.relu_2': 'C4', 'layer4.2.relu_2': 'C5'}

# Rebuild the model so that it returns the requested intermediate features.
n_model = create_feature_extractor(model, return_nodes)

dummy_batch = torch.rand(1, 3, 224, 224)
out = n_model(dummy_batch)
feature_shapes = [(name, feat.shape) for name, feat in out.items()]
print(feature_shapes)

[('C4', torch.Size([1, 1024, 14, 14])), ('C5', torch.Size([1, 2048, 7, 7]))]
