In [2]:
import torch.nn as nn
import torch
from onnx_helper import export_model

In [None]:
from modules.fir_cnn_module import FIRCNNModule

def export_fir(P: int = 256, M: int = 16, batch_size: int = 5,
               onnx_path: str = "FIR.onnx") -> None:
    """Build a zero-weight ``FIRCNNModule``, run it once, and export it to ONNX.

    The defaults reproduce the configuration that used to be hard-coded in
    this function, so calling ``export_fir()`` with no arguments behaves as
    before.

    Args:
        P: First argument forwarded to ``FIRCNNModule``. The weight tensor and
            the example input both use ``2 * P`` along their channel axis.
        M: Second argument forwarded to ``FIRCNNModule``; also the size of the
            last axis of the weight tensor and of the example input.
        batch_size: Leading dimension of the all-zero example input.
        onnx_path: File the exported model is written to.
    """
    weights = torch.zeros(2 * P, 1, 1, M)
    module = FIRCNNModule(P, M, weights)
    example_inputs = torch.zeros(batch_size, 2 * P, 1, M)

    # Run the module eagerly once so that shape/type problems surface here,
    # before the exporter is invoked.
    out = module(example_inputs)
    print("Output:", out.squeeze())

    # opset_version is intentionally not set; the exporter default is used.
    torch.onnx.export(
        module,
        example_inputs,
        onnx_path,
        input_names=["input"],
        output_names=["output"],
        do_constant_folding=True,
    )

# Run the FIR export when this cell executes: prints the module output and writes FIR.onnx.
export_fir()

Output: tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])


In [17]:
from modules.dft_cnn_module import DFTCNNModule

def export_dft(P: int = 256, batch_size: int = 5,
               onnx_path: str = "DFT.onnx") -> None:
    """Build a ``DFTCNNModule``, run it once on zeros, and export it to ONNX.

    The defaults reproduce the configuration that used to be hard-coded in
    this function, so calling ``export_dft()`` with no arguments behaves as
    before.

    Args:
        P: Argument forwarded to ``DFTCNNModule``. The example input uses
            ``2 * P`` along its channel axis.
        batch_size: Leading dimension of the all-zero example input.
        onnx_path: File the exported model is written to.
    """
    module = DFTCNNModule(P)
    example_inputs = torch.zeros(batch_size, 2 * P, 1, 1)

    # Run the module eagerly once so that shape/type problems surface here,
    # before the exporter is invoked.
    out = module(example_inputs)
    print("Output:", out.squeeze())

    torch.onnx.export(
        module,
        example_inputs,
        onnx_path,
        input_names=["input"],
        output_names=["output"],
        do_constant_folding=True,
    )

# Run the DFT export when this cell executes: prints the module output and writes DFT.onnx.
export_dft()


Output: tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])


In [19]:
from modules.fft_cnn_module import FFTCNNModule

def export_fft(P: int = 8, batch_size: int = 5,
               onnx_path: str = "FFT.onnx") -> None:
    """Build an ``FFTCNNModule``, run it once on zeros, and export it to ONNX.

    The defaults reproduce the configuration that used to be hard-coded in
    this function, so calling ``export_fft()`` with no arguments behaves as
    before.

    Args:
        P: Argument forwarded to ``FFTCNNModule``. The example input uses
            ``2 * P`` along its channel axis.
            NOTE(review): the module may only accept certain values of P
            (e.g. powers of two) - confirm against FFTCNNModule.
        batch_size: Leading dimension of the all-zero example input.
        onnx_path: File the exported model is written to.
    """
    module = FFTCNNModule(P)
    example_inputs = torch.zeros(batch_size, 2 * P, 1, 1)

    # Run the module eagerly once so that shape/type problems surface here,
    # before the exporter is invoked.
    out = module(example_inputs)
    print("Output:", out.squeeze())

    torch.onnx.export(
        module,
        example_inputs,
        onnx_path,
        input_names=["input"],
        output_names=["output"],
        do_constant_folding=True,
    )

# Run the FFT export when this cell executes: prints the module output and writes FFT.onnx.
export_fft()


Output: tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])


In [12]:
def export_test_module():
    """Export a minimal two-input module to ``fc.onnx``.

    The module adds its two inputs element-wise. It is run once eagerly on a
    (5, 1, 1, 1) tensor of zeros and a (1, 1, 1) tensor holding the value 2,
    the squeezed result is printed, and the module is then exported with the
    ONNX inputs named "input" and "input2".
    """

    class ExampleModule(nn.Module):
        """Toy module whose forward pass is a single ``torch.add``."""

        def __init__(self):
            super().__init__()
            # Note: this layer is never used in forward().
            self.fc = nn.Linear(4, 4)

        def forward(self, x, y):
            return torch.add(x, y)

    model = ExampleModule()

    # First operand: a batch of five zeros. Second operand: a lower-rank
    # tensor set to 2, which is combined with the first inside torch.add.
    first = torch.zeros(5, 1, 1, 1)
    second = torch.zeros(1, 1, 1)
    second[-1] = 2
    example_inputs = (first, second)

    result = model(first, second)
    print("Output:", result.squeeze())

    torch.onnx.export(
        model,
        example_inputs,
        "fc.onnx",
        input_names=["input", "input2"],
        output_names=["output"],
        do_constant_folding=True,
    )

# Announce the step, then export the toy two-input add module (writes fc.onnx).
print("Exporting test module...")
export_test_module()

Exporting test module...
Output: tensor([2., 2., 2., 2., 2.])
