In [2]:
import torch
import torch.fx

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return torch.topk(torch.sum(
            self.linear(x + self.linear.weight).relu(), dim=-1), 3)

m = MyModule()
gm = torch.fx.symbolic_trace(m)

gm.graph.print_tabular()

opcode         name           target                                                args                kwargs
-------------  -------------  ----------------------------------------------------  ------------------  -----------
placeholder    x              x                                                     ()                  {}
get_attr       linear_weight  linear.weight                                         ()                  {}
call_function  add            <built-in function add>                               (x, linear_weight)  {}
call_module    linear         linear                                                (add,)              {}
call_method    relu           relu                                                  (linear,)           {}
call_function  sum_1          <built-in method sum of type object at 0x10bd77788>   (relu,)             {'dim': -1}
call_function  topk           <built-in method topk of type object at 0x10bd77788>  (sum_1, 3)          {}
output         

In [5]:
img = torch.tensor(torch.arange(3 * 9 * 9).reshape(3,9,9),requires_grad=True,dtype=torch.float32)
ker = torch.tensor(torch.arange(3 * 3 * 3).reshape(1,3,3,3),requires_grad=True,dtype=torch.float32)
fet = torch.conv2d(img,ker,stride=2,dilation=1)
x = fet.sum(0).sum(0).sum(0)
# gm = torch.fx.symbolic_trace(x)
# gm

  img = torch.tensor(torch.arange(3 * 9 * 9).reshape(3,9,9),requires_grad=True,dtype=torch.float32)
  ker = torch.tensor(torch.arange(3 * 3 * 3).reshape(1,3,3,3),requires_grad=True,dtype=torch.float32)


In [7]:
class MySingleConvolution(torch.nn.Module):
    def __init__(self):
        super().__init__()
        ker = torch.tensor(torch.arange(3 * 3 * 3).reshape(1,3,3,3),requires_grad=True,dtype=torch.float32)
        self.ker = torch.nn.Parameter(ker)

    def forward(self, img):
        # img = torch.tensor(torch.arange(3 * 9 * 9).reshape(3,9,9),requires_grad=True,dtype=torch.float32)
        fet = torch.conv2d(img,self.ker,stride=2,dilation=1)
        return fet

my_conv = MySingleConvolution()

from torch.fx import symbolic_trace
trace = symbolic_trace(my_conv)

  ker = torch.tensor(torch.arange(3 * 3 * 3).reshape(1,3,3,3),requires_grad=True,dtype=torch.float32)


In [10]:
print(trace.graph)

graph():
    %img : [num_users=1] = placeholder[target=img]
    %ker : [num_users=1] = get_attr[target=ker]
    %conv2d : [num_users=1] = call_function[target=torch.conv2d](args = (%img, %ker), kwargs = {stride: 2, dilation: 1})
    return conv2d


In [44]:
import torch
class MySingleConvolution(torch.nn.Module):
    def __init__(self):
        super().__init__()
        ker = torch.tensor(torch.arange(4 * 3 * 3 * 3).reshape(4,3,3,3),dtype=torch.float32)
        self.ker = torch.nn.Parameter(ker)

    def forward(self, img):
        # img = torch.tensor(torch.arange(3 * 9 * 9).reshape(3,9,9),requires_grad=True,dtype=torch.float32)
        fet = torch.conv2d(img,self.ker,stride=1,dilation=1)
        return fet

my_conv = MySingleConvolution()

dummy_input = torch.tensor(torch.arange(3 * 9 * 9).reshape(3,9,9),requires_grad=False,dtype=torch.float32)
my_conv(dummy_input)

input_names = ['image_batch', 'kernel_weights']
output_names = ['features_batch']
torch.onnx.export(my_conv, dummy_input, 'MySingleConv.onnx', verbose=False, input_names=input_names, output_names=output_names)

  ker = torch.tensor(torch.arange(4 * 3 * 3 * 3).reshape(4,3,3,3),dtype=torch.float32)
  dummy_input = torch.tensor(torch.arange(3 * 9 * 9).reshape(3,9,9),requires_grad=False,dtype=torch.float32)


In [45]:
dummy_input = torch.tensor(torch.arange(3 * 9 * 9).reshape(3,9,9),requires_grad=True,dtype=torch.float32)
my_conv = MySingleConvolution()
my_conv(dummy_input)

s = torch.onnx.export_to_pretty_string(my_conv, dummy_input,keep_initializers_as_inputs=True,export_params=True, verbose=True, input_names=input_names, output_names=output_names)
print(s)

ModelProto {
  producer_name: "pytorch"
  domain: ""
  doc_string: ""
  graph:
    GraphProto {
      name: "main_graph"
      inputs: [{name: "image_batch", type:Tensor dtype: 1, Tensor dims: 3 9 9},{name: "kernel_weights", type:Tensor dtype: 1, Tensor dims: 4 3 3 3}]
      outputs: [{name: "features_batch", type:Tensor dtype: 1, Tensor dims: 4 7 7}]
      value_infos: []
      initializers: [TensorProto shape: [4 3 3 3]]
      nodes: [
        Node {type: "Constant", inputs: [], outputs: [onnx::Unsqueeze_2], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: [1]}]},
        Node {type: "Unsqueeze", inputs: [image_batch,onnx::Unsqueeze_2], outputs: [onnx::Conv_3], attributes: []},
        Node {type: "Conv", inputs: [onnx::Conv_3,kernel_weights], outputs: [onnx::Squeeze_4], attributes: [{ name: 'dilations', type: ints, values: [1 1]},{ name: 'group', type: int, value: 1},{ name: 'kernel_shape', type: ints, values: [3 3]},{ name: 'pads', type: ints, values: [0 0 0 0]}

  dummy_input = torch.tensor(torch.arange(3 * 9 * 9).reshape(3,9,9),requires_grad=True,dtype=torch.float32)
  ker = torch.tensor(torch.arange(4 * 3 * 3 * 3).reshape(4,3,3,3),dtype=torch.float32)


In [46]:
import onnx

# Load the ONNX model
model = onnx.load('MySingleConv.onnx')

# Check that the model is well formed
onnx.checker.check_model(model)

# Print a human readable representation of the graph
print(onnx.helper.printable_graph(model.graph))

graph main_graph (
  %image_batch[FLOAT, 3x9x9]
) initializers (
  %kernel_weights[FLOAT, 4x3x3x3]
) {
  %/Constant_output_0 = Constant[value = <Tensor>]()
  %/Unsqueeze_output_0 = Unsqueeze(%image_batch, %/Constant_output_0)
  %/Conv_output_0 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [0, 0, 0, 0], strides = [1, 1]](%/Unsqueeze_output_0, %kernel_weights)
  %/Constant_1_output_0 = Constant[value = <Tensor>]()
  %features_batch = Squeeze(%/Conv_output_0, %/Constant_1_output_0)
  return %features_batch
}
