In [None]:
import numpy as np
import tensorflow as tf
from keras_quant_utils import gen_anchors, relu6, ssd_focal_loss
import cv2
import tqdm
import os

In [None]:
DATASET_PATH = 'path/to/dataset'
MODEL_PATH = 'path/to/keras/model'
# Model input size; presumably (height, width, channels) in Keras channels-last order — TODO confirm.
INPUT_SHAPE = (192, 192, 3)

# Preprocessing maps pixel values from [0, 255] to [-1, 1] via x / SCALE + OFFSET.
# Default values for mobilenetv1 with alpha=0.25.
SCALE = 127.5
OFFSET = -1

# Load the Keras model from the configured MODEL_PATH variable (passing the
# literal string 'MODEL_PATH' would ignore the configuration above).
# The custom objects are the project helpers from keras_quant_utils that the
# saved model refers to by name.
model = tf.keras.models.load_model(
    MODEL_PATH,
    custom_objects={
        'gen_anchors': gen_anchors,
        'relu6': relu6,
        '_loss': ssd_focal_loss,
    },
)

# I/O tensor types of the converted TFLite model: 'int8', 'uint8' or 'float32'
# ('float32' keeps the converter's default float I/O).
input_type = 'uint8'
output_type = 'float32'

In [None]:
def representative_data_gen():
    """Yield preprocessed calibration samples for post-training quantization.

    Iterates over the JPEG files (``.jpg`` / ``.jpeg``, any case) in
    ``DATASET_PATH`` in sorted order, so that calibration is reproducible
    across runs. Each image is converted to RGB and resized to the model
    input size. Its pixels are then rescaled with ``x / SCALE + OFFSET``.
    Files that OpenCV cannot decode are skipped with a message; previously
    they crashed the whole conversion.

    Yields:
        list[np.ndarray]: a single-element list holding a float32 batch of
        shape ``(1, INPUT_SHAPE[0], INPUT_SHAPE[1], 3)``. This is the format
        that ``TFLiteConverter.representative_dataset`` consumes.
    """
    height, width = INPUT_SHAPE[0], INPUT_SHAPE[1]
    # Filter before wrapping in tqdm so the progress bar total counts only images.
    image_files = sorted(
        name for name in os.listdir(DATASET_PATH)
        if name.lower().endswith((".jpg", ".jpeg"))
    )
    for image_file in tqdm.tqdm(image_files):
        image = cv2.imread(os.path.join(DATASET_PATH, image_file))
        if image is None:
            # cv2.imread signals unreadable/corrupt files by returning None, not raising.
            tqdm.tqdm.write(f"Skipping unreadable image: {image_file}")
            continue
        if image.ndim != 3:
            # Defensive: with the default IMREAD_COLOR flag imread already returns 3 channels.
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # cv2.resize expects dsize as (width, height), not (height, width).
        resized_image = cv2.resize(image, (width, height), interpolation=cv2.INTER_LINEAR)
        image_data = (resized_image / SCALE + OFFSET).astype(np.float32)
        # Add a leading batch dimension of 1.
        yield [np.expand_dims(image_data, 0)]

In [None]:
converter = tf.lite.TFLiteConverter.from_keras_model(model)

# Supported I/O type names -> TFLite dtype. None means "keep the converter's
# default float I/O".
io_dtypes = {'int8': tf.int8, 'uint8': tf.uint8, 'float32': None}
for io_name, io_type in (('input', input_type), ('output', output_type)):
    if io_type not in io_dtypes:
        # Fail loudly on typos instead of silently producing float I/O.
        raise ValueError(
            f"Unsupported {io_name}_type {io_type!r}; expected one of {sorted(io_dtypes)}"
        )

if io_dtypes[input_type] is not None:
    converter.inference_input_type = io_dtypes[input_type]
if io_dtypes[output_type] is not None:
    # The output type must go to inference_output_type; assigning it to
    # inference_input_type would overwrite the input setting above.
    converter.inference_output_type = io_dtypes[output_type]

# Post-training quantization calibrated on the representative dataset.
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_data_gen

tflite_model_quantized = converter.convert()
output_filename = "quantized_model.tflite"
with open(output_filename, 'wb') as f:
    f.write(tflite_model_quantized)