In [7]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import model_from_json
import tensorflow_model_optimization as tfmot
import os
import tempfile

In [8]:
# Rebuild the CsiNet Keras model from its JSON architecture description, then
# load the weights from the matching .h5 file stored alongside it.
MODEL_JSON_PATH = 'saved_model/model_CsiNet_indoor_dim64.json'
MODEL_WEIGHTS_PATH = 'saved_model/model_CsiNet_indoor_dim64.h5'

# `with` guarantees the handle is closed even if read() raises; the original
# open()/close() pair leaked the file on error.
with open(MODEL_JSON_PATH, 'r') as json_file:
    base_model_json = json_file.read()

base_model = model_from_json(base_model_json)
# load weights into new model
base_model.load_weights(MODEL_WEIGHTS_PATH)
print("Loaded base model from disk")
base_model.summary()

Loaded base model from disk
Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 2, 32, 32)]  0           []                               
                                                                                                  
 conv2d_1 (Conv2D)              (None, 2, 32, 32)    38          ['input_1[0][0]']                
                                                                                                  
 batch_normalization_1 (BatchNo  (None, 2, 32, 32)   128         ['conv2d_1[0][0]']               
 rmalization)                                                                                     
                                                                                                  
 leaky_re_lu_1 (LeakyReLU)      (None, 2, 32, 32)    0          

In [9]:
# Reference conversion: a plain float32 TFLite flatbuffer with no optimization
# flags set. Its size is the baseline the quantized variants are compared to.
converter = tf.lite.TFLiteConverter.from_keras_model(model=base_model)
tflite_base_model = converter.convert()




INFO:tensorflow:Assets written to: C:\Users\Omar\AppData\Local\Temp\tmp55ihgdrh\assets


INFO:tensorflow:Assets written to: C:\Users\Omar\AppData\Local\Temp\tmp55ihgdrh\assets


In [10]:
# Dynamic range quantization: Optimize.DEFAULT with no representative dataset
# and no target-type override. Weights are stored as 8-bit integers at
# conversion time. Per the TFLite post-training quantization guide, supported
# ops may additionally quantize activations on the fly at inference time.
converter = tf.lite.TFLiteConverter.from_keras_model(model=base_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_base_model_dynamic_range_quant = converter.convert()



INFO:tensorflow:Assets written to: C:\Users\Omar\AppData\Local\Temp\tmpka1l7dk0\assets


INFO:tensorflow:Assets written to: C:\Users\Omar\AppData\Local\Temp\tmpka1l7dk0\assets


In [11]:
# Float16 quantization: restricting the target spec to float16 while enabling
# Optimize.DEFAULT stores the weights as 16-bit floats. In this notebook's
# output, that roughly halves the float32 file size.
converter = tf.lite.TFLiteConverter.from_keras_model(model=base_model)
converter.target_spec.supported_types = [tf.float16]
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_fp16_model = converter.convert()




INFO:tensorflow:Assets written to: C:\Users\Omar\AppData\Local\Temp\tmp580b7m0e\assets


INFO:tensorflow:Assets written to: C:\Users\Omar\AppData\Local\Temp\tmp580b7m0e\assets


In [12]:
# Save each serialized TFLite model and compare their on-disk sizes.
OUTPUT_DIR = 'omar_saved_model'

# open(..., 'wb') does not create missing parent directories, so a fresh
# checkout without this folder would fail with FileNotFoundError.
os.makedirs(OUTPUT_DIR, exist_ok=True)


def _write_tflite(flatbuffer, filename):
    """Write a serialized TFLite model into OUTPUT_DIR.

    Args:
        flatbuffer: bytes returned by ``TFLiteConverter.convert()``.
        filename: file name (not path) to create inside OUTPUT_DIR.

    Returns:
        Size of the written file in bytes.
    """
    path = os.path.join(OUTPUT_DIR, filename)
    with open(path, 'wb') as f:
        f.write(flatbuffer)
    return os.path.getsize(path)


# Get size of quantized model and base model
tflite_base_model_size = _write_tflite(tflite_base_model, 'base_model.tflite')
tflite_base_model_dynamic_range_quant_size = _write_tflite(
    tflite_base_model_dynamic_range_quant, 'dynamic_range_quant_model.tflite')
tflite_fp16_model_size = _write_tflite(tflite_fp16_model, 'fp16_quant_model.tflite')

# NOTE: the "Mb" labels are kept for continuity with the saved output, but the
# divisor is 2**20 bytes, so the values are mebibytes (MiB), not megabits.
print("Base model in Mb:", tflite_base_model_size / float(2**20))
print("Dynamic Range Quantized model in Mb:", tflite_base_model_dynamic_range_quant_size / float(2**20))
print("FP16 Quantized model in Mb:", tflite_fp16_model_size / float(2**20))
# Compression ratio of the dynamic-range model relative to the float32 baseline.
print("Number of times larger", tflite_base_model_size/tflite_base_model_dynamic_range_quant_size)

Base model in Mb: 1.0368881225585938
Dynamic Range Quantized model in Mb: 0.28081512451171875
FP16 Quantized model in Mb: 0.5305137634277344
Number of times larger 3.6924226369983972
