In [4]:
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf

## Create a Basic Model of the Form y = mx + c

In [25]:
# Build a small synthetic dataset that follows y = 2x - 1 plus Gaussian noise.
# A seeded generator keeps the noise (and everything trained on it)
# reproducible across Restart & Run All.
SEED = 42
rng = np.random.default_rng(SEED)

x = np.arange(1, 11) * 2 - 2 * 5  # ten even integers: -8, -6, ..., 10
y = x * 2 - 1 + rng.standard_normal(x.size) * 0.1  # noise std-dev is 0.1

print(x, y)

[-8 -6 -4 -2  0  2  4  6  8 10] [-17.04088579 -12.95622548  -8.9484593   -4.85814225  -0.8697331
   2.84736949   7.06092187  11.06872737  14.99017446  19.16439948]


In [26]:
# Pair every feature with its label. Each x value is wrapped in its own
# one-element list so that a single sample carries one feature column.
features = [[val] for val in x]
dataset = tf.data.Dataset.from_tensor_slices((features, y))

In [27]:
# Group consecutive samples into mini-batches of two.
# NOTE: this rebinds `dataset`, so running the cell a second time batches the
# already-batched dataset; re-run the cell that creates `dataset` first.
dataset = dataset.batch(batch_size=2)

In [28]:
# Peek at the first batch to check the feature/label shapes and dtypes.
list(dataset.take(1))

[(<tf.Tensor: shape=(2, 1), dtype=int32, numpy=
  array([[-8],
         [-6]], dtype=int32)>,
  <tf.Tensor: shape=(2,), dtype=float64, numpy=array([-17.04088579, -12.95622548])>)]

In [31]:
# One Dense unit on a one-feature input: the network for the y = mx + c fit.
linear_layer = tf.keras.layers.Dense(units=1, input_shape=[1])
model = tf.keras.models.Sequential([linear_layer])

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.01),
    loss='mean_absolute_error',
)

# Stop once the training loss has gone 5 epochs without improving instead of
# always running all 500 epochs.
early_stopping = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=5)
model.fit(dataset, epochs=500, callbacks=[early_stopping])

Epoch 1/500
Epoch 2/500
Epoch 3/500
Epoch 4/500
Epoch 5/500
Epoch 6/500
Epoch 7/500
Epoch 8/500
Epoch 9/500
Epoch 10/500
Epoch 11/500
Epoch 12/500
Epoch 13/500
Epoch 14/500
Epoch 15/500
Epoch 16/500
Epoch 17/500
Epoch 18/500
Epoch 19/500
Epoch 20/500
Epoch 21/500
Epoch 22/500
Epoch 23/500
Epoch 24/500
Epoch 25/500
Epoch 26/500
Epoch 27/500
Epoch 28/500
Epoch 29/500
Epoch 30/500
Epoch 31/500
Epoch 32/500
Epoch 33/500
Epoch 34/500
Epoch 35/500
Epoch 36/500
Epoch 37/500
Epoch 38/500
Epoch 39/500
Epoch 40/500
Epoch 41/500
Epoch 42/500
Epoch 43/500
Epoch 44/500
Epoch 45/500
Epoch 46/500
Epoch 47/500
Epoch 48/500
Epoch 49/500
Epoch 50/500
Epoch 51/500
Epoch 52/500
Epoch 53/500
Epoch 54/500
Epoch 55/500


<keras.callbacks.History at 0x7ff08028f160>

## Generate a SavedModel

In [33]:
# Directory that will receive the exported model; create it (and any missing
# parents) up front, tolerating the case where it already exists.
output_path = './model/model_v1'
export_dir = Path(output_path)
export_dir.mkdir(parents=True, exist_ok=True)

In [34]:
# Export the trained model to the directory stored in `output_path`.
model.save(filepath=output_path)

INFO:tensorflow:Assets written to: ./model/model_v1/assets


## Convert the SavedModel to TFLite

In [35]:
# Point a TFLite converter at the exported SavedModel directory, then run the
# conversion; `tflite_model` holds the converted model as in-memory bytes.
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir=output_path)
tflite_model = converter.convert()

In [36]:
# Persist the converted TFLite model next to the SavedModel directory.
# The path is relative (./model/...), consistent with `output_path`; an
# absolute '/model/...' would target the filesystem root, which normally does
# not exist or is not writable.
tflite_model_file = Path('./model/model_v1.tflite')
tflite_model_file.parent.mkdir(parents=True, exist_ok=True)  # safe if the cell runs on its own
tflite_model_file.write_bytes(tflite_model)  # evaluates to the number of bytes written

916

## Initialize the TFLite Interpreter To Try It Out

[{'name': 'serving_default_dense_8_input:0',
  'index': 0,
  'shape': array([1, 1], dtype=int32),
  'shape_signature': array([-1,  1], dtype=int32),
  'dtype': numpy.float32,
  'quantization': (0.0, 0),
  'quantization_parameters': {'scales': array([], dtype=float32),
   'zero_points': array([], dtype=int32),
   'quantized_dimension': 0},
  'sparsity_parameters': {}}]

In [37]:
# Build an interpreter directly from the in-memory TFLite bytes, then allocate
# its tensors so that inputs can be set and outputs read later on.
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()

# Metadata (name, index, shape, dtype, ...) for the model's outputs and inputs.
output_details = interpreter.get_output_details()
input_details = interpreter.get_input_details()

In [44]:
# Display the input tensor metadata returned by interpreter.get_input_details().
input_details

[{'name': 'serving_default_dense_8_input:0',
  'index': 0,
  'shape': array([1, 1], dtype=int32),
  'shape_signature': array([-1,  1], dtype=int32),
  'dtype': numpy.float32,
  'quantization': (0.0, 0),
  'quantization_parameters': {'scales': array([], dtype=float32),
   'zero_points': array([], dtype=int32),
   'quantized_dimension': 0},
  'sparsity_parameters': {}}]

In [45]:
# Display the output tensor metadata returned by interpreter.get_output_details().
output_details

[{'name': 'StatefulPartitionedCall:0',
  'index': 3,
  'shape': array([1, 1], dtype=int32),
  'shape_signature': array([-1,  1], dtype=int32),
  'dtype': numpy.float32,
  'quantization': (0.0, 0),
  'quantization_parameters': {'scales': array([], dtype=float32),
   'zero_points': array([], dtype=int32),
   'quantized_dimension': 0},
  'sparsity_parameters': {}}]

In [48]:
# Tensor indices the interpreter uses for the first input and first output.
in_index = input_details[0]['index']
out_index = output_details[0]['index']

# A single float32 sample with shape (1, 1).
test_in = np.array([[5.]], dtype=np.float32)

# Feed the sample, run one inference, and read the prediction back.
interpreter.set_tensor(in_index, test_in)
interpreter.invoke()
test_out = interpreter.get_tensor(out_index)
print(test_in, test_out)


[[5.]] [[9.057872]]
