In [1]:
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
import utils
import torch.fx as fx
from torch.fx import Proxy, Graph, GraphModule


In [2]:
class TestModel(torch.nn.Module):
  
  def __init__(self, input_dim, output_dim):
    super().__init__()
    self.mlp1 = torch.nn.Linear(input_dim, output_dim)
    
    
  def forward(self,x):
    return self.mlp1(x)

In [3]:
model = TestModel(100,300)

In [4]:
import torch.fx as fx
traced = fx.symbolic_trace(model)

In [5]:
traced.register_module("modifiy",torch.nn.Linear(300,100))

In [6]:
traced.graph.print_tabular()


opcode       name    target    args     kwargs
-----------  ------  --------  -------  --------
placeholder  x       x         ()       {}
call_module  mlp1    mlp1      (x,)     {}
output       output  output    (mlp1,)  {}


In [7]:
env = utils.get_env(traced)

In [16]:
node = env['mlp1']
node.users

{modifiy: None}

torch.fx.node.Node

In [9]:
with traced.graph.inserting_after(node):
  new_node = traced.graph.call_module("modifiy",(env['mlp1'],))
  utils.replace_use_with(node,new_node)

{}

In [10]:
graph = traced.graph
# graph.call_module("modifiy",(env['mlp1'],))

In [11]:
graph.print_tabular()

opcode       name     target    args        kwargs
-----------  -------  --------  ----------  --------
placeholder  x        x         ()          {}
call_module  mlp1     mlp1      (x,)        {}
call_module  modifiy  modifiy   (mlp1,)     {}
output       output   output    (modifiy,)  {}


In [12]:
graph.lint()

In [13]:
graph.print_tabular()

opcode       name     target    args        kwargs
-----------  -------  --------  ----------  --------
placeholder  x        x         ()          {}
call_module  mlp1     mlp1      (x,)        {}
call_module  modifiy  modifiy   (mlp1,)     {}
output       output   output    (modifiy,)  {}


In [14]:
traced.recompile()

PythonCode(src='\n\n\ndef forward(self, x):\n    mlp1 = self.mlp1(x);  x = None\n    modifiy = self.modifiy(mlp1);  mlp1 = None\n    return modifiy\n    ', globals={'inf': inf, 'nan': nan, 'NoneType': <class 'NoneType'>, 'torch': <module 'torch' from '/home/yssun/miniconda3/envs/deepctr-torch/lib/python3.9/site-packages/torch/__init__.py'>, 'device': <class 'torch.device'>, 'fx_pytree': <module 'torch.fx._pytree' from '/home/yssun/miniconda3/envs/deepctr-torch/lib/python3.9/site-packages/torch/fx/_pytree.py'>, 'pytree': <module 'torch.utils._pytree' from '/home/yssun/miniconda3/envs/deepctr-torch/lib/python3.9/site-packages/torch/utils/_pytree.py'>}, _lineno_map={1: 1, 2: 2, 3: 3, 4: 3})

In [15]:
print(traced.code)




def forward(self, x):
    mlp1 = self.mlp1(x);  x = None
    modifiy = self.modifiy(mlp1);  mlp1 = None
    return modifiy
    


In [47]:
interp = utils.ProfilingInterpreter(traced)
interp.run(torch.ones((4096,100)))
print(interp.summary(True))

total time: 2.330780029296875 ms
Op type      Op         Average runtime (s)    Pct total runtime
-----------  -------  ---------------------  -------------------
call_module  mlp1               0.00124717              53.5086
call_module  modifiy            0.000793695             34.0528
placeholder  x                  4.43459e-05              1.90262
output       output             3.38554e-05              1.45254


In [49]:
modify_mod = GraphModule(traced, traced.graph)
torch.onnx.export(modify_mod,torch.ones((4096,100)),f'modify_mod.onnx')