In [2]:
import torch.nn as nn
import torch
from onnx_helper import export_model

In [53]:

def c_weights(P, M):
    """Build FIR weights for a P-branch filter bank with paired (complex) channels.

    Fetches the reference Kaiser taps from ``modules.fir_helper`` and duplicates
    each branch's filter so that two consecutive channels share it. Both parts of
    a complex channel use the same FIR filter, so the real and imaginary channels
    get identical taps.

    Args:
        P: number of filter branches.
        M: number of taps per branch.

    Returns:
        Tensor of shape (2 * P, 1, M, 1) where channels 2k and 2k + 1 both hold
        the taps of branch k. For P == 0 an empty (0, 1, M, 1) tensor is returned.
    """
    from modules.fir_helper import ref_kaiser_weights

    # ref_kaiser_weights must yield a torch tensor with P * M elements so that it
    # can be reshaped to one (1, M, 1) filter per branch.
    weights = ref_kaiser_weights(P, M, reversed=False).reshape(P, 1, M, 1)
    # Interleave: branch k is emitted twice in a row (channels 2k and 2k + 1).
    # Unlike stacking a Python list, this also works when P == 0.
    return weights.repeat_interleave(2, dim=0)

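# Weight tensor layout (kept for reference):
#   old: torch.zeros(P * 2, 1, 1, M) -- the M filter taps were in the last dimension
#   new: torch.zeros(P * 2, 1, M, 1) -- the M filter taps are in dimension 2 (what c_weights returns)
#   example: P = 6, M = 4 -> torch.zeros(12, 1, 4, 1)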

# c_weights(6, 4)


In [55]:

def export_fir(P=256, M=16, B=1, onnx_path="FIR.onnx"):
    """Build the FIR CNN module, run a sanity forward pass, and export it to ONNX.

    Args:
        P: number of filter branches; dimension 1 of the example input has size 2 * P.
        M: number of taps per branch (also dimension 2 of the example input).
        B: batch size of the example input. No ``dynamic_axes`` are passed, so this
            size is presumably fixed in the exported graph -- confirm if variable
            batch sizes are needed.
        onnx_path: destination file for the exported model.
    """
    from modules.fir_cnn_module import FIRCNNModule

    weights = c_weights(P, M)
    module = FIRCNNModule(P, M, weights)
    example_inputs = torch.zeros(B, 2 * P, M, 1)

    # Sanity forward pass: an all-zero input only checks shapes/plumbing, and
    # no autograd graph is needed for that.
    with torch.no_grad():
        out = module(example_inputs)
    print(example_inputs.shape)
    print("Output:", out.squeeze())
    print(out.shape)

    torch.onnx.export(
        module,
        example_inputs,
        onnx_path,
        input_names=["input"],
        output_names=["output"],
        do_constant_folding=True,
    )

export_fir()


torch.Size([1, 512, 16, 1])
Output: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 

In [17]:
from modules.dft_cnn_module import DFTCNNModule

def export_dft(P=256, B=5, onnx_path="DFT.onnx"):
    """Build the DFT CNN module, run a sanity forward pass, and export it to ONNX.

    Args:
        P: size parameter passed to ``DFTCNNModule``; dimension 1 of the example
            input has size 2 * P.
        B: batch size of the example input. No ``dynamic_axes`` are passed, so this
            size is presumably fixed in the exported graph -- confirm if variable
            batch sizes are needed.
        onnx_path: destination file for the exported model.
    """
    module = DFTCNNModule(P)
    example_inputs = torch.zeros(B, 2 * P, 1, 1)

    # Sanity forward pass on an all-zero input; gradients are not needed.
    with torch.no_grad():
        out = module(example_inputs)
    print("Output:", out.squeeze())

    torch.onnx.export(
        module,
        example_inputs,
        onnx_path,
        input_names=["input"],
        output_names=["output"],
        do_constant_folding=True,
    )

export_dft()


Output: tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])


In [19]:
from modules.fft_cnn_module import FFTCNNModule

def export_fft(P=8, B=5, onnx_path="FFT.onnx"):
    """Build the FFT CNN module, run a sanity forward pass, and export it to ONNX.

    Args:
        P: size parameter passed to ``FFTCNNModule``; dimension 1 of the example
            input has size 2 * P.
        B: batch size of the example input. No ``dynamic_axes`` are passed, so this
            size is presumably fixed in the exported graph -- confirm if variable
            batch sizes are needed.
        onnx_path: destination file for the exported model.
    """
    module = FFTCNNModule(P)
    example_inputs = torch.zeros(B, 2 * P, 1, 1)

    # Sanity forward pass on an all-zero input; gradients are not needed.
    with torch.no_grad():
        out = module(example_inputs)
    print("Output:", out.squeeze())

    torch.onnx.export(
        module,
        example_inputs,
        onnx_path,
        input_names=["input"],
        output_names=["output"],
        do_constant_folding=True,
    )

export_fft()


Output: tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])


In [12]:
def export_test_module():
    """Export a minimal two-input model that adds its inputs, written to ``fc.onnx``.

    This is a smoke test for multi-input ONNX export. Input ``x`` has shape
    (5, 1, 1, 1) and ``y`` has shape (1, 1, 1) filled with 2. ``torch.add``
    broadcasts them to (5, 1, 1, 1), so the printed output is five 2s.
    """

    class ExampleModule(nn.Module):
        """Element-wise (broadcasting) addition of two tensors."""

        def forward(self, x, y):
            return torch.add(x, y)

    module = ExampleModule()
    # Build the second input directly instead of mutating it through the tuple.
    example_inputs = (torch.zeros(5, 1, 1, 1), torch.full((1, 1, 1), 2.0))
    with torch.no_grad():
        out = module(*example_inputs)
    print("Output:", out.squeeze())

    torch.onnx.export(
        module,
        example_inputs,
        "fc.onnx",
        input_names=["input", "input2"],
        output_names=["output"],
        do_constant_folding=True,
    )

print("Exporting test module...")
export_test_module()

Exporting test module...
Output: tensor([2., 2., 2., 2., 2.])
