# ONNX implementation of FITS

In [2]:
from onnx import TensorProto, numpy_helper
from onnx.helper import make_node, make_graph, make_model, make_tensor_value_info
from onnx.checker import check_model
import numpy as np
import torch
from onnx.reference import ReferenceEvaluator

In [3]:
weights_real=np.ones((2,4)).astype(np.float32)
weights_imag=np.ones((2,4)).astype(np.float32)



WR=numpy_helper.from_array(weights_real, name = 'WR')
WI=numpy_helper.from_array(weights_imag, name = 'WI')

In [4]:
weights_imag

array([[1., 1., 1., 1.],
       [1., 1., 1., 1.]], dtype=float32)

## prepare variables

In [5]:
X = make_tensor_value_info('X', TensorProto.FLOAT, [None, None])
# DSR = make_tensor_value_info('DSR', TensorProto.FLOAT, [None, None])
# nomean= make_tensor_value_info('x_nomean', TensorProto.FLOAT, [None, None])
# var= make_tensor_value_info('x_var', TensorProto.FLOAT, [None, None])
x_normed = make_tensor_value_info('x_normed', TensorProto.FLOAT, [None, None])
x_rfft = make_tensor_value_info('x_rfft', TensorProto.FLOAT, [None, None])
x_rfft_imag = make_tensor_value_info('x_rfft_lpf_imag_perm', TensorProto.FLOAT, [None, None])
x_rfft_real = make_tensor_value_info('x_rfft_lpf_real_perm', TensorProto.FLOAT, [None, None])

out_R = make_tensor_value_info('R_perm', TensorProto.FLOAT, [None, None])
out_I = make_tensor_value_info('I_perm', TensorProto.FLOAT, [None, None])
RR = make_tensor_value_info('RR', TensorProto.FLOAT, [None, None])
WWR = make_tensor_value_info('WR', TensorProto.FLOAT, [None, None])
II = make_tensor_value_info('II', TensorProto.FLOAT, [None, None])
WWI = make_tensor_value_info('WI', TensorProto.FLOAT, [None, None])

out_rfft = make_tensor_value_info('out_rfft', TensorProto.FLOAT, [None, None])
out = make_tensor_value_info('out', TensorProto.FLOAT, [None, None])


L_axis = numpy_helper.from_array(np.array([1]), name='L_axis')
LPF_start = numpy_helper.from_array(np.array([2]), name='LPF_start')
LPF_end = numpy_helper.from_array(np.array([4]), name='LPF_end')
# Perm_index = numpy_helper.from_array(np.array([0,2,1]), name='Perm_index')
zero = numpy_helper.from_array(np.array([0]), name='zero')
pad_Setting = numpy_helper.from_array(np.array([0,4,0,0,3,0]), name='pad_Setting')

neg_one = numpy_helper.from_array(np.array([-1]), name='neg_one')
neg_two = numpy_helper.from_array(np.array([-2]), name='neg_two')
one = numpy_helper.from_array(np.array([1]), name='one')

DSR = numpy_helper.from_array(np.array([3]).astype(np.float32), name='DSR')


## Define the calculation graph

In [6]:
# RIN
# calculate the mean of the input tensor and remove the mean from the input with onnx
Get_mean=make_node('ReduceMean', ['X', 'L_axis'], ['xmean'], keepdims=1)
Get_nomean=make_node('Sub', ['X', 'xmean'], ['x_nomean']) # X_nomean = X - Xmean
# calculate the sqrt(E(X-EX)^2) with onnx
Get_nomean_square=make_node('Mul', ['x_nomean', 'x_nomean'], ['x_nomean_square'])
Get_nomean_square_mean=make_node('ReduceMean', ['x_nomean_square', 'L_axis'], ['x_nomean_square_sum'], keepdims=1)
Get_sqrt=make_node('Sqrt', ['x_nomean_square_sum'], ['x_var'])
# calculate X_nomean/Xsqrt with onnx
Get_x_normed=make_node('Div', ['x_nomean', 'x_var'], ['x_normed'])

# LPF
Get_rfft = make_node('DFT', ['x_normed'], ['x_rfft'], axis=1, onesided=1, inverse=0)
Get_LPF = make_node('Slice', ['x_rfft', 'LPF_start', 'LPF_end', 'L_axis'], ['x_rfft_lpf'])
Get_Split_Complex = make_node('Split', ['x_rfft_lpf'], ['x_rfft_lpf_real', 'x_rfft_lpf_imag'], axis=-1, num_outputs=2)

# Complex Linear
Get_PermuteI = make_node('Transpose', ['x_rfft_lpf_imag'], ['x_rfft_lpf_imag_perm'], perm=[0,2,1])
Get_PermuteR = make_node('Transpose', ['x_rfft_lpf_real'], ['x_rfft_lpf_real_perm'], perm=[0,2,1])

Get_RR = make_node('MatMul', ['x_rfft_lpf_real_perm', 'WR'], ['RR'])
Get_RI = make_node('MatMul', ['x_rfft_lpf_real_perm', 'WI'], ['RI'])
Get_IR = make_node('MatMul', ['x_rfft_lpf_imag_perm', 'WR'], ['IR'])
Get_II = make_node('MatMul', ['x_rfft_lpf_imag_perm', 'WI'], ['II'])

Get_R = make_node('Sub', ['RR', 'II'], ['R'])
Get_I = make_node('Add', ['RI', 'IR'], ['I'])

Get_inverse_PermuteI = make_node('Transpose', ['I'], ['I_perm'], perm=[0,2,1])
Get_inverse_PermuteR = make_node('Transpose', ['R'], ['R_perm'], perm=[0,2,1])

# Amplitude compensate
Get_compensate_R = make_node('Mul', ['R_perm', 'DSR'], ['R_compensate'])
Get_compensate_I = make_node('Mul', ['I_perm', 'DSR'], ['I_compensate'])

# Zero padding
Get_R_pad = make_node('Pad', ['R_compensate', 'pad_Setting', 'zero'], ['R_pad'], mode='constant')
Get_I_pad = make_node('Pad', ['I_compensate', 'pad_Setting', 'zero'], ['I_pad'], mode='constant')

# Prepare rfft
Get_reverse_real = make_node('Slice', ['R_pad', 'neg_two', 'zero', 'one', 'neg_one'], ['R_rev'])
Get_reverse_imag = make_node('Slice', ['I_pad', 'neg_two', 'zero', 'one', 'neg_one'], ['I_rev'])

Get_neg = make_node('Neg', ['I_rev'], ['I_rev_neg']) # self conjugate
Get_rev_Concat = make_node('Concat', ['R_rev', 'I_rev_neg'], ['x_rfft_rev'], axis=-1)
Get_Concat = make_node('Concat', ['R_pad', 'I_pad'], ['x_rfft_forward'], axis=-1)

Get_fin_Concat = make_node('Concat', ['x_rfft_forward', 'x_rfft_rev'], ['out_rfft'], axis=1) # 0-freq, forward spec, reverse spec

# Inverse rfft, only accept full complex spectrum
Get_inverse_rfft = make_node('DFT', ['out_rfft'], ['out_complex'], axis=1, inverse=1) 
Get_Temporal = make_node('Slice', ['out_complex', 'zero', 'one', 'neg_one'], ['out_temporal']) # only keep temporal domain

# Rev RIN
Get_out_nomean=make_node('Mul', ['out_temporal', 'x_var'], ['out_nomean']) 
Get_out=make_node('Add', ['out_nomean', 'xmean'], ['out'])


In [7]:
# make the graph (GraphProto)
graph = make_graph(
    [Get_mean, Get_nomean, Get_nomean_square, Get_nomean_square_mean, Get_sqrt, Get_x_normed, Get_rfft, Get_LPF, Get_Split_Complex, Get_PermuteI, Get_PermuteR,
     Get_RR, Get_RI, Get_IR, Get_II, Get_R, Get_I, Get_inverse_PermuteI, Get_inverse_PermuteR, Get_compensate_I, Get_compensate_R ,Get_R_pad, Get_I_pad,
     Get_reverse_real, Get_reverse_imag, Get_neg, Get_rev_Concat, Get_Concat, Get_fin_Concat, Get_inverse_rfft, Get_Temporal,
     Get_out_nomean, Get_out],
    'test-model',
    [X],
    [x_rfft, x_rfft_imag, x_rfft_real, out_R, out_I, out_rfft, out],
    [L_axis,LPF_start,LPF_end,WR,WI, pad_Setting, zero, DSR, neg_one, neg_two, one]
)

onnx_model = make_model(graph, producer_name='onnx-example')

check_model(onnx_model)

## Run the model

In [8]:
sees = ReferenceEvaluator(onnx_model)
x=np.random.rand(2,10,1).astype(np.float32)
xx=sees.run(None,{'X':x})
print(xx[0].shape)

(2, 6, 2)


## Serialize and save the model as onnx file

In [9]:
# save the onnx model with example input
with open("test.onnx", "wb") as f:
    f.write(onnx_model.SerializeToString())
    

## check the input and output

In [10]:
realin=xx[2]
realin # real part of the projection layer input

array([[[-2.113672  ,  3.2053206 ]],

       [[-0.00395393,  3.129188  ]]], dtype=float32)

In [11]:
imagin=xx[1]
imagin # imag part of the projection layer input

array([[[-4.205752 ,  2.6030076]],

       [[-2.4222276, -4.660607 ]]], dtype=float32)

In [12]:
x # input

array([[[0.3036429 ],
        [0.60120684],
        [0.48510543],
        [0.75284505],
        [0.23437363],
        [0.26324   ],
        [0.96830004],
        [0.81264067],
        [0.17516446],
        [0.42030957]],

       [[0.7057335 ],
        [0.96225375],
        [0.03172521],
        [0.28581664],
        [0.5284767 ],
        [0.15667306],
        [0.5521558 ],
        [0.9609049 ],
        [0.6574765 ],
        [0.2306728 ]]], dtype=float32)

In [13]:
xx[-1] # output

array([[[ 1.337628  ],
        [ 0.52596164],
        [-0.14773268],
        [ 0.6264461 ],
        [ 0.6940575 ],
        [ 0.5016829 ],
        [ 0.6142318 ],
        [ 0.32213527],
        [ 0.39940047],
        [ 0.6208314 ],
        [ 0.5016829 ],
        [ 0.62821245],
        [ 0.47480467],
        [ 0.28371495],
        [ 0.5676295 ],
        [ 0.5016829 ],
        [ 0.64745516],
        [ 0.77443516],
        [-0.07232857],
        [ 0.23172596]],

       [[ 4.2864046 ],
        [ 1.3112125 ],
        [-2.606667  ],
        [ 0.72204316],
        [ 1.4868536 ],
        [ 0.5071889 ],
        [ 1.1259707 ],
        [-0.21387142],
        [-0.13313806],
        [ 1.0284302 ],
        [ 0.5071889 ],
        [ 1.0966313 ],
        [ 0.56359494],
        [-0.5688746 ],
        [ 0.695366  ],
        [ 0.5071889 ],
        [ 1.056249  ],
        [ 2.0894585 ],
        [-1.9099339 ],
        [-1.4075183 ]]], dtype=float32)

In [14]:
xx[0] # rfft of the input

array([[[-7.4505806e-07,  0.0000000e+00],
        [-1.4925214e+00,  3.3711782e-01],
        [-2.1136720e+00, -4.2057519e+00],
        [ 3.2053206e+00,  2.6030076e+00],
        [-2.1069021e+00, -7.1971887e-01],
        [-2.6442459e+00,  0.0000000e+00]],

       [[-1.0281801e-06,  0.0000000e+00],
        [ 1.5157268e+00,  2.6613832e+00],
        [-3.9539281e-03, -2.4222276e+00],
        [ 3.1291881e+00, -4.6606069e+00],
        [-1.2275178e+00, -1.2868351e+00],
        [-3.9140186e-01,  4.4408921e-16]]], dtype=float32)

In [16]:
xx[-3] # RR RI

array([[[-0.51109576],
        [-0.51109576],
        [-0.51109576],
        [-0.51109576]],

       [[-3.9576    ],
        [-3.9576    ],
        [-3.9576    ],
        [-3.9576    ]]], dtype=float32)

In [17]:
xx[-1]

array([[[ 1.337628  ],
        [ 0.52596164],
        [-0.14773268],
        [ 0.6264461 ],
        [ 0.6940575 ],
        [ 0.5016829 ],
        [ 0.6142318 ],
        [ 0.32213527],
        [ 0.39940047],
        [ 0.6208314 ],
        [ 0.5016829 ],
        [ 0.62821245],
        [ 0.47480467],
        [ 0.28371495],
        [ 0.5676295 ],
        [ 0.5016829 ],
        [ 0.64745516],
        [ 0.77443516],
        [-0.07232857],
        [ 0.23172596]],

       [[ 4.2864046 ],
        [ 1.3112125 ],
        [-2.606667  ],
        [ 0.72204316],
        [ 1.4868536 ],
        [ 0.5071889 ],
        [ 1.1259707 ],
        [-0.21387142],
        [-0.13313806],
        [ 1.0284302 ],
        [ 0.5071889 ],
        [ 1.0966313 ],
        [ 0.56359494],
        [-0.5688746 ],
        [ 0.695366  ],
        [ 0.5071889 ],
        [ 1.056249  ],
        [ 2.0894585 ],
        [-1.9099339 ],
        [-1.4075183 ]]], dtype=float32)

In [18]:
xx[4]

array([[[-0.51109576],
        [-0.51109576],
        [-0.51109576],
        [-0.51109576]],

       [[-3.9576    ],
        [-3.9576    ],
        [-3.9576    ],
        [-3.9576    ]]], dtype=float32)

In [19]:
# permute the 1 and 2 axis of xx[1]
outr=xx[3].transpose(0,2,1)
outi=xx[2].transpose(0,2,1) # imag
outR= np.matmul(outr,weights_real)-np.matmul(outi,weights_imag)

ValueError: matmul: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (n?,k),(k,m?)->(n?,m?) (size 2 is different from 4)

In [None]:
outR

array([[[-2.1290371, -2.1290371, -2.1290371, -2.1290371]],

       [[-4.4575706, -4.4575706, -4.4575706, -4.4575706]]], dtype=float32)

In [None]:
outI = np.matmul(outr,weights_imag)+ np.matmul(outi,weights_real)
outI

array([[[-4.0712576, -4.0712576, -4.0712576, -4.0712576]],

       [[-2.6118875, -2.6118875, -2.6118875, -2.6118875]]], dtype=float32)

In [None]:
xx[1]

array([[[ 8.5867941e-07,  0.0000000e+00],
        [-5.2885100e-02,  4.9447557e-01],
        [ 1.1048088e+00,  1.6531996e+00],
        [-2.1126856e-01, -1.6675388e+00],
        [-2.1087587e+00,  5.7152481e+00],
        [-3.4242604e+00,  1.1102230e-16]],

       [[ 8.5681677e-08,  0.0000000e+00],
        [-3.2358274e+00, -9.8743612e-01],
        [ 4.3768554e+00,  2.5047356e-01],
        [ 3.4072566e+00, -2.0510821e+00],
        [ 1.6312933e+00, -4.0278584e-01],
        [ 1.1790700e+00,  2.2204460e-16]]], dtype=float32)