In [15]:
import tensorflow as tf
import numpy as np

## Define the custom model (Lepage test integrand)

In [91]:
DTYPE = tf.float32  # keep the whole graph in float32: TFLite's float64 support is limited

@tf.function
def quick_integrand(xarr):
    """Lepage test function: a normalized Gaussian peak centred at 0.5.

    For each row ``x`` of ``xarr`` this returns
    ``(1 / (a * sqrt(pi)))**n_dim * exp(-sum_i ((x_i - 0.5) / a)**2)`` with
    ``a = 0.1``.  At ``x = (0.5, ..., 0.5)`` the exponent vanishes and the value
    is the prefactor ``(1 / (a * sqrt(pi)))**n_dim``.

    Parameters
    ----------
    xarr : array-like or tf.Tensor, shape (n_events, n_dim), dtype DTYPE
        Evaluation points.  ``n_dim`` is taken from the static last dimension
        of ``xarr`` at trace time.

    Returns
    -------
    tf.Tensor, shape (n_events,), dtype DTYPE

    Raises
    ------
    ValueError
        If the last dimension of ``xarr`` is not known when the function is
        traced (e.g. a ``TensorSpec([None, None])`` signature).
    """
    xarr = tf.convert_to_tensor(xarr, dtype=DTYPE)
    n_dim = xarr.shape[-1]
    if n_dim is None:
        raise ValueError("quick_integrand needs a statically known last dimension (n_dim)")

    a = tf.constant(0.1, dtype=DTYPE)
    pref = tf.pow(1.0 / a / np.sqrt(np.pi), n_dim)

    n100 = tf.cast(100 * n_dim, dtype=DTYPE)
    # sum(0..n100) - n100*(n100+1)/2 is exactly 0 in float32 for moderate n_dim
    # (all partial sums are integers below 2**24).  Presumably it is kept to
    # exercise tf.range / tf.reduce_sum in the TFLite conversion — TODO confirm.
    zero = tf.reduce_sum(tf.range(n100 + 1)) - (n100 + 1) * n100 / 2.0
    # Cancel the large terms *before* adding the data term.  Otherwise the ~5050
    # intermediate rounds the exponent to multiples of ulp(5050) ~ 5e-4.
    coef = zero + tf.reduce_sum(tf.square((xarr - 1.0 / 2.0) / a), axis=1)
    return pref * tf.exp(-coef)

In [113]:
# One point at the peak centre, already float32 (no extra conversion needed).
xx = np.array([[0.5]], dtype=np.float32)
r1 = quick_integrand(xx).numpy()
print(r1)
# At x = 0.5 the exponent is zero, so the result must equal the prefactor
# 1 / (a * sqrt(pi)) with a = 0.1 (n_dim = 1).
np.testing.assert_allclose(r1, [1.0 / (0.1 * np.sqrt(np.pi))], rtol=1e-6)

[5.641896]


## TensorFlow Lite conversion

In [116]:
# Trace the integrand for a variable number of 1-D points, then serialise it to TFLite.
input_spec = tf.TensorSpec(shape=[None, 1], dtype=DTYPE)
concrete_func = quick_integrand.get_concrete_function(input_spec)
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
tflite_model = converter.convert()
with open('model.tflite', 'wb') as model_file:
    model_file.write(tflite_model)

## TensorFlow Lite testing

In [117]:
import tflite_runtime.interpreter as tflite

In [118]:
# Load the exported model with the standalone `tflite_runtime` interpreter imported
# as `tflite` above (previously imported but unused), so the check below does not
# go through TensorFlow's bundled interpreter.
interpreter = tflite.Interpreter(model_path='model.tflite')

In [119]:
# Allocate the interpreter's tensors before any set_tensor / invoke call
# (required by the TFLite Interpreter API).
interpreter.allocate_tensors()

In [120]:
# Metadata for the model's input and output tensors; the 'index' entry of the
# first element of each is used in the cells below.
input_details, output_details = (
    interpreter.get_input_details(),
    interpreter.get_output_details(),
)

In [121]:
# Feed the same point that was evaluated with the TensorFlow graph above.
input_index = input_details[0]['index']
interpreter.set_tensor(input_index, xx)

In [122]:
# Run the converted model on the input tensor set in the previous cell.
interpreter.invoke()

In [123]:
# Read the model's result from its (single) output tensor.
output_index = output_details[0]['index']
output_data = interpreter.get_tensor(output_index)

In [124]:
# TFLite result; compare with r1 from the TensorFlow evaluation above.
output_data

array([5.641896], dtype=float32)

In [125]:
# TensorFlow and TFLite may use different kernels, so compare with a float32-scale
# tolerance instead of exact equality.  assert_allclose also checks that the shapes
# agree, and it works for batches of more than one point.  A bare `assert a == b`
# on arrays with more than one element raises "truth value ... is ambiguous".
np.testing.assert_allclose(output_data, r1, rtol=1e-5)