# frozen graph转化成.tflite

In [1]:
import tensorflow as tf

# Convert the frozen GraphDef on disk into a TFLite flatbuffer.
graph_def_file = "./model/fp32_frozen_graph.pb"
tflite_model_file = "./model/converted_model.tflite"
# Names passed to the converter as the graph's input and output arrays.
input_arrays = ["input_1"]
output_arrays = ['Logits/Softmax']

converter = tf.lite.TFLiteConverter.from_frozen_graph(
  graph_def_file, input_arrays, output_arrays)
tflite_model = converter.convert()

# Write through a context manager so the file handle is flushed and closed
# deterministically instead of whenever the anonymous file object is collected.
with open(tflite_model_file, "wb") as model_file:
    bytes_written = model_file.write(tflite_model)

# Last expression of the cell: display how many bytes were written.
bytes_written

13974588

In [2]:
def tfLiteInference(input_details, interpreter, output_details, x):
    """Feed ``x`` to the interpreter, run it once, and return the first output.

    Args:
        input_details: sequence whose first item has an ``'index'`` entry
            identifying the tensor that receives ``x``.
        interpreter: object exposing ``set_tensor``, ``invoke`` and
            ``get_tensor``.
        output_details: sequence whose first item has an ``'index'`` entry
            identifying the tensor to read back after the run.
        x: value written to the input tensor.

    Returns:
        Whatever ``interpreter.get_tensor`` returns for the first output index.
    """
    input_index = input_details[0]['index']
    interpreter.set_tensor(input_index, x)
    interpreter.invoke()

    # `get_tensor()` hands back a copy of the tensor data; `tensor()` would be
    # the call to use for a pointer to the tensor instead.
    output_index = output_details[0]['index']
    return interpreter.get_tensor(output_index)

# 读取一张图片用于预测

In [3]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input, decode_predictions

# Build the TFLite interpreter from the converted model file and allocate
# its tensors before any inference call.
interpreter = tf.lite.Interpreter(model_path="./model/converted_model.tflite")
interpreter.allocate_tensors()

# Keep the input/output tensor descriptors around for the inference helper.
input_details, output_details = (
    interpreter.get_input_details(),
    interpreter.get_output_details(),
)

# Sample picture used to sanity-check the model's prediction.
img_path = './elephant.jpg'
image_size = [224, 224, 3]

# Load the picture resized to the first two entries of image_size, convert it
# to an array, add a leading axis, then apply MobileNetV2 preprocessing.
img = image.load_img(img_path, target_size=image_size[:2])
x = preprocess_input(np.expand_dims(image.img_to_array(img), axis=0))

# 预测FP32的.tflite

In [4]:
# test FP32 CPU
import time

# Number of timed inference calls averaged into the reported latency.
NUM_RUNS = 2000
times = []

# One untimed call first: its result is decoded below to check the prediction,
# and it keeps any first-call overhead out of the timed loop.
output_data = tfLiteInference(input_details, interpreter, output_details, x)

# decode the results into a list of tuples (class, description, probability)
# (one such list for each sample in the batch)
print('Predicted:', decode_predictions(output_data, top=3)[0])

for i in range(NUM_RUNS):
    # perf_counter() is a monotonic, high-resolution clock meant for measuring
    # short durations; time.time() can jump with system clock adjustments.
    start_time = time.perf_counter()
    output_data = tfLiteInference(input_details, interpreter, output_details, x)
    delta = (time.perf_counter() - start_time)
    times.append(delta)
mean_delta = np.array(times).mean()
fps = 1 / mean_delta
print('average(sec):{:.2f},fps:{:.2f}'.format(mean_delta, fps))

Predicted: [('n02504458', 'African_elephant', 0.5241098), ('n01871265', 'tusker', 0.17159039), ('n02504013', 'Indian_elephant', 0.15654095)]
average(sec):0.02,fps:46.53
