In [15]:
import tensorflow as tf
from tensorflow import keras

import numpy as np
import matplotlib.pyplot as plt

from keras.layers import Conv2D, MaxPool2D, Flatten, Dense

from keras.preprocessing.image import ImageDataGenerator
from keras.utils import load_img
from keras.utils import img_to_array
from keras.applications import imagenet_utils
import os

import pandas as pd
import tensorflow_datasets as tfds
import tensorflow_model_optimization as tfmot

In [16]:
model = tf.keras.applications.MobileNet(weights='imagenet', input_shape=(224, 224, 3),
                                             )  #include_preprocessing=False)

In [17]:
model.summary()

Model: "mobilenet_1.00_224"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_3 (InputLayer)        [(None, 224, 224, 3)]     0         
                                                                 
 conv1 (Conv2D)              (None, 112, 112, 32)      864       
                                                                 
 conv1_bn (BatchNormalizatio  (None, 112, 112, 32)     128       
 n)                                                              
                                                                 
 conv1_relu (ReLU)           (None, 112, 112, 32)      0         
                                                                 
 conv_dw_1 (DepthwiseConv2D)  (None, 112, 112, 32)     288       
                                                                 
 conv_dw_1_bn (BatchNormaliz  (None, 112, 112, 32)     128       
 ation)                                         

In [18]:
def process_image(data):
    data['image'] = (tf.image.resize(data['image'], (224, 224)) * 2.0 / 255.0) - 1.0
    return data

Compile model so it can be evaluated

In [19]:
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(),
              metrics=['accuracy'])

Prepare training and testing datasets

In [21]:
tr_ds = tfds.load('imagenet_v2', split='test[:90%]')
tr_ds = tr_ds.map(process_image)

train_ds = tr_ds\
    .map(lambda data: (data['image'], data['label']))\
    .batch(32)

t_ds = tfds.load('imagenet_v2', split='test[90%:]')
t_ds = t_ds.map(process_image)

test_ds = t_ds.map(lambda data: (data['image'], data['label'])).batch(64)

Evaluate the model on imagenet dataset (note: it needs to be tested on bigger subset of imagenet for accurate results)

In [22]:
loss, acc = model.evaluate(test_ds)
print(f'Top-1 accuracy (float): {acc * 100:.2f}%')

Top-1 accuracy (float): 54.50%


In [23]:
quantize_model = tfmot.quantization.keras.quantize_model

# q_aware stands for for quantization aware.
q_aware_model = quantize_model(model)

# `quantize_model` requires a recompile.
q_aware_model.compile(optimizer='adam',
                      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
                      metrics=['accuracy'])

q_aware_model.summary()

Model: "mobilenet_1.00_224"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_3 (InputLayer)        [(None, 224, 224, 3)]     0         
                                                                 
 quantize_layer_1 (QuantizeL  (None, 224, 224, 3)      3         
 ayer)                                                           
                                                                 
 quant_conv1 (QuantizeWrappe  (None, 112, 112, 32)     929       
 rV2)                                                            
                                                                 
 quant_conv1_bn (QuantizeWra  (None, 112, 112, 32)     129       
 pperV2)                                                         
                                                                 
 quant_conv1_relu (QuantizeW  (None, 112, 112, 32)     3         
 rapperV2)                                      

Fine-tune Quantization Aware Model

In [25]:
q_aware_model.fit(train_ds, epochs=5, validation_data=test_ds)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.callbacks.History at 0x7f7f8f071c90>

In [26]:
qa_loss, qa_acc = q_aware_model.evaluate(test_ds)
print(f'Top-1 accuracy (quantize aware float): {qa_acc * 100:.2f}%')

Top-1 accuracy (quantize aware float): 1.50%


In [42]:
for qlayer in q_aware_model.layers:
    print(qlayer)
    print(qlayer.weights)
    print(qlayer.get_weights())

<keras.engine.input_layer.InputLayer object at 0x7f8005d9f250>
[]
[]
<tensorflow_model_optimization.python.core.quantization.keras.quantize_layer.QuantizeLayer object at 0x7f8003b933a0>
[<tf.Variable 'quantize_layer_1/quantize_layer_1_min:0' shape=() dtype=float32, numpy=-1.0>, <tf.Variable 'quantize_layer_1/quantize_layer_1_max:0' shape=() dtype=float32, numpy=1.0>, <tf.Variable 'quantize_layer_1/optimizer_step:0' shape=() dtype=int32, numpy=-1>]
[-1.0, 1.0, -1]
<tensorflow_model_optimization.python.core.quantization.keras.quantize_wrapper.QuantizeWrapperV2 object at 0x7f801fc55ba0>
[<tf.Variable 'conv1/kernel:0' shape=(3, 3, 3, 32) dtype=float32, numpy=
array([[[[ 8.87084053e-14,  3.06156009e-01,  1.92597583e-02,
           1.65614850e-13,  1.67507152e-14,  2.29323477e-01,
          -2.17220068e-01, -2.97841370e-01, -4.32953903e-15,
           7.74574950e-02,  3.76683533e-01,  9.35078702e-14,
           1.49377331e-01,  1.52193442e-01, -1.13821197e-02,
           4.56319394e-05, -5.1