In [123]:
import os 
import random

import tensorflow as tf
import numpy as np
from PIL import Image

In [135]:
# Which trained model to benchmark; all model artefacts share this version prefix.
model_version = '7seg2912'

MODELS_DIR = '../models'
MODEL_FILE_PATH = f'{MODELS_DIR}/{model_version}.keras'        # Keras model
TFLITE_FILE_PATH = f'{MODELS_DIR}/{model_version}.tflite'      # TFLite model
TFLITE_Q_FILE_PATH = f'{MODELS_DIR}/{model_version}q.tflite'   # quantized TFLite model ('q' suffix)

# Seed the global RNG used by get_image() so the same sample images are picked
# on every Restart & Run All.
SEED = 42
random.seed(SEED)

In [136]:
def get_image(digits_dir: str = '../training/digits_resized',
              input_shape: tuple = (32, 20, 3),
              rng=None) -> tuple:
    """Load a random labelled digit image as a batch-of-one model input.

    Files in ``digits_dir`` are expected to be named ``<digit>_<anything>``;
    the integer before the first underscore is used as the ground-truth label.

    Args:
        digits_dir: Directory containing the labelled digit images.
        input_shape: Per-image shape the pixel array is reshaped to (after a
            leading batch dimension of 1 is added).
        rng: Optional ``random.Random`` instance used to pick the file.
            Defaults to the global ``random`` module; seed it for reproducible
            picks.

    Returns:
        ``(correct_digit, img)``, where ``img`` is a float32 array of shape
        ``(1, *input_shape)``.

    Raises:
        ValueError: If ``digits_dir`` contains no candidate files, or the image
            cannot be reshaped to ``input_shape``.
    """
    chooser = random if rng is None else rng

    # os.listdir order is arbitrary, so sort it; otherwise a seeded RNG would
    # still pick different files on different filesystems. Hidden files
    # (e.g. .DS_Store) and subdirectories carry no digit label, so skip them.
    candidates = sorted(
        name for name in os.listdir(digits_dir)
        if not name.startswith('.')
        and os.path.isfile(os.path.join(digits_dir, name))
    )
    if not candidates:
        raise ValueError(f'No image files found in {digits_dir!r}')

    image_name = chooser.choice(candidates)

    # Context manager closes the underlying file handle once pixels are read.
    with Image.open(os.path.join(digits_dir, image_name)) as image_in:
        pixels = np.array(image_in, dtype="float32")

    # Prepend a batch dimension of 1; raises ValueError if the sizes disagree.
    img = np.reshape(pixels, (1, *input_shape))

    correct_digit = int(image_name.split('_')[0])

    return correct_digit, img

# Standard Keras model

In [137]:
# Load the saved Keras model that serves as the benchmark baseline.
model = tf.keras.models.load_model(filepath=MODEL_FILE_PATH)

In [138]:
_, img = get_image()
standard_times = %timeit -r 10 -n 50 -o model.predict(img, verbose=0)

41.9 ms ± 1.84 ms per loop (mean ± std. dev. of 10 runs, 50 loops each)


In [139]:
# Re-display the TimeitResult captured with `%timeit -o` (mean ± std. dev. per loop).
standard_times

<TimeitResult : 41.9 ms ± 1.84 ms per loop (mean ± std. dev. of 10 runs, 50 loops each)>

# TFLite model

In [140]:
# Create a TFLite interpreter for the converted model and allocate its tensor
# buffers; this must happen before any tensor is set or the model is invoked.
interpreter = tf.lite.Interpreter(model_path=TFLITE_FILE_PATH)
interpreter.allocate_tensors()

# Tensor metadata for the model's inputs and outputs: one dict per tensor.
# Each dict's 'index' entry identifies the tensor for set_tensor/get_tensor.
input_details, output_details = (
    interpreter.get_input_details(),
    interpreter.get_output_details(),
)

In [141]:
correct_digit, img = get_image()
interpreter.set_tensor(input_details[0]['index'], img)

tflite_times = %timeit -r 10 -n 1000 -o interpreter.invoke()

86.2 µs ± 11.6 µs per loop (mean ± std. dev. of 10 runs, 1,000 loops each)


In [142]:
# Read the model output. `get_tensor()` returns a copy of the tensor data;
# use `tensor()` instead to get a pointer into the interpreter's buffer.
output_data = interpreter.get_tensor(output_details[0]['index'])

# Print, one per line: the raw output scores, the index of the highest score,
# and the ground-truth digit.
print(output_data, output_data.argmax(), correct_digit, sep='\n')

[[1.0000000e+00 1.6525056e-11 1.9402249e-22 1.4331142e-19 4.5395783e-27
  5.8753799e-23 1.1936453e-18 2.3520623e-09 5.3731721e-13 5.8168747e-19
  1.5427888e-19]]
0
0


# TFLite quantized model

In [143]:
# Build an interpreter for the quantized TFLite model and allocate its tensor
# buffers. This rebinds `interpreter`, `input_details` and `output_details`,
# replacing the previously loaded model's objects.
interpreter = tf.lite.Interpreter(model_path=TFLITE_Q_FILE_PATH)
interpreter.allocate_tensors()

# Per-tensor metadata dicts; 'index' addresses the tensor in set/get calls.
input_details, output_details = (
    interpreter.get_input_details(),
    interpreter.get_output_details(),
)

In [144]:
correct_digit, img = get_image()
interpreter.set_tensor(input_details[0]['index'], img)

tflite_q_times = %timeit -r 10 -n 1000 interpreter.invoke()

79.8 µs ± 21.8 µs per loop (mean ± std. dev. of 10 runs, 1,000 loops each)


In [145]:
# Fetch the quantized model's output. `get_tensor()` gives a copy of the data;
# `tensor()` would give a pointer into the interpreter's buffer instead.
output_data = interpreter.get_tensor(output_details[0]['index'])

# Raw scores, then predicted class (argmax), then the true digit, one per line.
print(output_data, output_data.argmax(), correct_digit, sep='\n')

[[0.         0.         0.99609375 0.         0.         0.
  0.         0.         0.         0.         0.        ]]
2
2
