In [1]:
!pip install tensorflow-model-optimization numpy tensorflow



# Train base model on CIFAR-10

In [2]:
import tensorflow as tf

# CIFAR-10 ships a training split and a test split; the test split is divided
# below into a validation part and a held-out test part.
(train_imgs, train_lbls), (val_imgs, val_lbls) = tf.keras.datasets.cifar10.load_data()

# Scale pixels to [0, 1] directly in float32 (the dtype the model and the
# TFLite converter work in). Dividing the uint8 arrays by 255.0 would instead
# produce float64 arrays at twice the memory.
train_imgs = train_imgs.astype('float32') / 255.0
val_imgs = val_imgs.astype('float32') / 255.0

# The last-2000 / first-8000 slices are disjoint only for a 10,000-image split.
assert len(val_imgs) == 10000, 'expected the 10,000-image CIFAR-10 test split'
# Carve off the test slice *before* truncating val_imgs.
test_imgs, test_lbls = val_imgs[-2000:], val_lbls[-2000:]
val_imgs, val_lbls = val_imgs[:8000], val_lbls[:8000]

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
[1m170498071/170498071[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 0us/step


In [3]:
from tensorflow_model_optimization.python.core.keras.compat import keras

def separable_conv(i, ch):
  """Depthwise-separable convolution block for the functional API.

  A 3x3 depthwise convolution and a 1x1 pointwise convolution with `ch`
  output channels, each followed by batch normalisation and ReLU.

  Args:
    i: input Keras tensor.
    ch: number of output channels of the pointwise (1x1) convolution.

  Returns:
    The Keras tensor produced by the final ReLU activation.
  """
  stack = [
      keras.layers.DepthwiseConv2D((3, 3), padding='same'),
      keras.layers.BatchNormalization(),
      keras.layers.Activation('relu'),
      keras.layers.Conv2D(ch, (1, 1), padding='same'),
      keras.layers.BatchNormalization(),
      keras.layers.Activation('relu'),
  ]
  out = i
  for layer in stack:
    out = layer(out)
  return out

In [4]:
# def dwsepcnn_block(ch):
#   return keras.Sequential([
#     keras.layers.DepthwiseConv2D((3,3), padding='same'),
#     keras.layers.BatchNormalization(),
#     keras.layers.Activation('relu'),
#     keras.layers.Conv2D(ch, (1,1), padding='same'),
#     keras.layers.BatchNormalization(),
#     keras.layers.Activation('relu')
#   ])

Convolution base

In [5]:
# Stem: a regular 3x3 convolution maps the 3 input channels to 16 feature maps.
# NOTE: `input` shadows the Python builtin; kept because the model-assembly
# cell references it by this name.
input = keras.layers.Input((32, 32, 3))
x = keras.layers.Conv2D(16, (3, 3), padding='same')(input)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.Activation('relu')(x)

# Body: depthwise-separable blocks given as (channels, followed-by-pooling).
# The three 2x2 max-pools shrink the spatial size 32 -> 16 -> 8 -> 4.
for _ch, _pool_after in ((16, True), (48, True), (96, False), (192, True)):
  x = separable_conv(x, _ch)
  if _pool_after:
    x = keras.layers.MaxPooling2D((2, 2))(x)
del _ch, _pool_after

In [6]:
# model = keras.Sequential([
#     keras.layers.Input((32,32,3)),
#     keras.layers.Conv2D(16, (3, 3), padding='same'),
#     keras.layers.BatchNormalization(),
#     keras.layers.Activation('relu'),
#     dwsepcnn_block(16),
#     keras.layers.MaxPooling2D((2,2)),
#     dwsepcnn_block(48),
#     keras.layers.MaxPooling2D((2,2)),
#     dwsepcnn_block(96),
#     dwsepcnn_block(192),
#     keras.layers.MaxPooling2D((2,2))])

Classification head

In [7]:
# Classification head: flatten the final feature map, apply dropout, and
# project to 10 raw logits (the Dense layer has no activation).
for _head_layer in (keras.layers.Flatten(),
                    keras.layers.Dropout(0.2),
                    keras.layers.Dense(10)):
  x = _head_layer(x)
del _head_layer

In [8]:
# model.add(keras.layers.Flatten())
# model.add(keras.layers.Dropout(0.2))
# model.add(keras.layers.Dense(10))

In [9]:
# Assemble the functional model from the input tensor and the logits tensor.
model = keras.Model(inputs=input, outputs=x)
model.summary()

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 32, 32, 3)]       0         
                                                                 
 conv2d (Conv2D)             (None, 32, 32, 16)        448       
                                                                 
 batch_normalization (Batch  (None, 32, 32, 16)        64        
 Normalization)                                                  
                                                                 
 activation (Activation)     (None, 32, 32, 16)        0         
                                                                 
 depthwise_conv2d (Depthwis  (None, 32, 32, 16)        160       
 eConv2D)                                                        
                                                                 
 batch_normalization_1 (Bat  (None, 32, 32, 16)        64    

In [10]:
# Take the loss from the same Keras namespace the model was built with (the
# tfmot compat `keras`). `tf.keras` can resolve to a different Keras major
# version, which would mix Keras 2 and Keras 3 objects in one compile call.
# from_logits=True because the final Dense(10) layer has no activation.
loss_f = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(optimizer='adam', loss=loss_f, metrics=['accuracy'])

history = model.fit(train_imgs, train_lbls, epochs=10, batch_size=32,
                    validation_data=(val_imgs, val_lbls))

# Write a SavedModel directory; the post-training quantization step loads it.
model.export('cifar10')

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Saved artifact at 'cifar10'. The following endpoints are available:

* Endpoint 'serve'
  args_0 (POSITIONAL_ONLY): TensorSpec(shape=(None, 32, 32, 3), dtype=tf.float32, name='input_1')
Output Type:
  TensorSpec(shape=(None, 10), dtype=tf.float32, name=None)
Captures:
  135954260309968: TensorSpec(shape=(), dtype=tf.resource, name=None)
  135954260310544: TensorSpec(shape=(), dtype=tf.resource, name=None)
  135954278586064: TensorSpec(shape=(), dtype=tf.resource, name=None)
  135954260309776: TensorSpec(shape=(), dtype=tf.resource, name=None)
  135954260307280: TensorSpec(shape=(), dtype=tf.resource, name=None)
  135954260310352: TensorSpec(shape=(), dtype=tf.resource, name=None)
  135954260310736: TensorSpec(shape=(), dtype=tf.resource, name=None)
  135954260311696: TensorSpec(shape=(), dtype=tf.resource, name=None)
  135954260311312: TensorSpec(shape=(), dtype=tf.resource, n

# Quantize the model
Full-integer quantization needs a representative sample of the training data: the converter
runs inference on it to calibrate the scale and zero point of every activation tensor.

In [11]:
cifar_ds = tf.data.Dataset.from_tensor_slices(train_imgs)

def representative_data_gen():
  """Yield calibration samples for full-integer quantization.

  Produces the first 1000 training images as batches of one, each cast to
  float32 and wrapped in a single-element list.
  """
  calibration_batches = cifar_ds.batch(1).take(1000)
  yield from ([tf.cast(batch, tf.float32)] for batch in calibration_batches)

# Post-training full-integer quantization of the exported SavedModel: the
# representative dataset calibrates activation ranges, only int8 builtin ops
# are allowed, and the model's input/output tensors are int8 too.
tfl_conv = tf.lite.TFLiteConverter.from_saved_model('cifar10')
tfl_conv.optimizations = [tf.lite.Optimize.DEFAULT]
tfl_conv.representative_dataset = tf.lite.RepresentativeDataset(
    representative_data_gen)
tfl_conv.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
tfl_conv.inference_input_type = tfl_conv.inference_output_type = tf.int8

In [12]:
# Convert to a serialized TFLite flatbuffer (bytes); report its size with
# units instead of an unlabeled number.
tfl_model = tfl_conv.convert()
print(f'Quantized model size: {len(tfl_model)} bytes')

79672


Evaluate the quantized model using the validation dataset

In [13]:
# Load the int8 flatbuffer into an interpreter and allocate its tensors.
tfl_interp = tf.lite.Interpreter(model_content=tfl_model)
tfl_interp.allocate_tensors()

# Only the first input and the first output tensor are used.
i_details = tfl_interp.get_input_details()[0]
o_details = tfl_interp.get_output_details()[0]

# Affine quantization parameters (real = scale * (q - zero_point)); the first
# entry of each per-tensor array is used.
i_quant = i_details['quantization_parameters']
o_quant = o_details['quantization_parameters']
i_scale, i_zero_point = i_quant['scales'][0], i_quant['zero_points'][0]
o_scale, o_zero_point = o_quant['scales'][0], o_quant['zero_points'][0]

    TF 2.20. Please use the LiteRT interpreter from the ai_edge_litert package.
    See the [migration guide](https://ai.google.dev/edge/litert/migration)
    for details.
    


In [14]:
def classify(i_data):
  """Classify one image with the int8 TFLite interpreter.

  Quantizes the float input with the interpreter's input scale/zero point,
  runs inference, and dequantizes the int8 output back to float scores.

  Args:
    i_data: one image that reshapes to (1, 32, 32, 3), scaled like the
      calibration data (pixel values divided by 255 earlier in the notebook).

  Returns:
    A float NumPy array with one dequantized score (logit) per class; take
    ``argmax`` for the predicted label.

  Relies on the module-level ``tfl_interp``, ``i_details``, ``o_details``,
  ``i_scale``, ``i_zero_point``, ``o_scale`` and ``o_zero_point``.
  """
  import numpy as np  # numpy is otherwise imported only in a later cell

  input_data = np.asarray(i_data, dtype=np.float32).reshape((1, 32, 32, 3))
  # q = round(x / scale + zero_point), saturated to the int8 range. A bare
  # cast truncates toward zero (e.g. 254.9999 -> 254) and has
  # implementation-defined results for out-of-range values.
  i_value_s8 = np.clip(np.round(input_data / i_scale + i_zero_point),
                       -128, 127).astype(np.int8)

  tfl_interp.set_tensor(i_details['index'], i_value_s8)
  tfl_interp.invoke()
  o_pred = tfl_interp.get_tensor(o_details['index'])[0]

  # Widen before subtracting the zero point so int8 arithmetic cannot overflow.
  return (o_pred.astype(np.float32) - o_zero_point) * o_scale

In [15]:
import numpy as np

# Accuracy of the post-training-quantized int8 model on the validation split,
# scored one image at a time through classify().
# np.ravel flattens the label array so each comparison is scalar vs scalar.
num_correct_samples = sum(
    int(np.argmax(classify(img)) == label)
    for img, label in zip(val_imgs, np.ravel(val_lbls)))

# len() works on the array directly; no need to materialise a Python list.
total_samples = len(val_imgs)
print('Accuracy: ', num_correct_samples/total_samples)

Accuracy:  0.731


In [16]:
with open('cifar10.tflite', 'wb') as file:
  file.write(tfl_model)

!apt-get update && apt-get -qq install xxd
!xxd -i cifar10.tflite > model.h
!sed -i 's/unsigned char/const unsigned char/g' model.h
!sed -i 's/const/alignas(8) const/g' model.h

0% [Working]            Get:1 https://cli.github.com/packages stable InRelease [3,917 B]
0% [Waiting for headers] [Connecting to security.ubuntu.com (185.125.190.36)] [0% [Waiting for headers] [Connecting to security.ubuntu.com (185.125.190.36)] [                                                                               Get:2 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  InRelease [1,581 B]
0% [Waiting for headers] [Connecting to security.ubuntu.com (185.125.190.36)] [                                                                               Hit:3 http://archive.ubuntu.com/ubuntu jammy InRelease
Get:4 https://cloud.r-project.org/bin/linux/ubuntu jammy-cran40/ InRelease [3,632 B]
Get:5 http://archive.ubuntu.com/ubuntu jammy-updates InRelease [128 kB]
Get:6 https://cli.github.com/packages stable/main amd64 Packages [346 B]
Get:7 http://security.ubuntu.com/ubuntu jammy-security InRelease [129 kB]
Get:8 https://r2u.stat.illinois.edu/ubuntu ja

# Quantization aware training
https://www.tensorflow.org/model_optimization/guide/quantization/training_example

In [17]:
import tensorflow_model_optimization as tfmot

# Wrap the float model's layers with quantization-aware ("quant_") wrappers.
quantize_model = tfmot.quantization.keras.quantize_model
q_aware_model = quantize_model(model)

# quantize_model requires a recompile; reuse the float model's loss and metric.
q_aware_model.compile(loss=loss_f, optimizer='adam', metrics=['accuracy'])
q_aware_model.summary()

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 32, 32, 3)]       0         
                                                                 
 quantize_layer (QuantizeLa  (None, 32, 32, 3)         3         
 yer)                                                            
                                                                 
 quant_conv2d (QuantizeWrap  (None, 32, 32, 16)        481       
 perV2)                                                          
                                                                 
 quant_batch_normalization   (None, 32, 32, 16)        65        
 (QuantizeWrapperV2)                                             
                                                                 
 quant_activation (Quantize  (None, 32, 32, 16)        3         
 WrapperV2)                                                  

All layers are now prefixed with "quant". Note that the resulting model is quantization-aware but not yet quantized (e.g. its weights are still float32 rather than int8).

In [18]:
# Fine-tune the quantization-aware model briefly on a small training slice.
train_images_subset = train_imgs[:1000]
train_labels_subset = train_lbls[:1000]

q_aware_model.fit(train_images_subset, train_labels_subset,
                  batch_size=64, epochs=2, validation_data=(val_imgs, val_lbls))

# Score both Keras models on the held-out test split. return_dict=True selects
# the metric by name and avoids assigning to `_`, which IPython uses for the
# last cell output.
baseline_model_accuracy = model.evaluate(
    test_imgs, test_lbls, verbose=0, return_dict=True)['accuracy']
q_aware_model_accuracy = q_aware_model.evaluate(
    test_imgs, test_lbls, verbose=0, return_dict=True)['accuracy']


Epoch 1/2
Epoch 2/2


In [19]:
# NOTE: the quantization-aware model also received 2 extra fine-tuning epochs,
# so the gap is not purely a quantization effect.
print(f'Baseline test accuracy:  {baseline_model_accuracy}')
print(f'Quant aware test accuracy:  {q_aware_model_accuracy}')

Baseline test accuracy:  0.7174999713897705
Quant aware test accuracy:  0.7404999732971191


## Convert the quantization-aware model to an int8 TFLite model

In [20]:
# Convert the quantization-aware Keras model with the same full-integer int8
# settings used for the post-training conversion.
tfl_conv = tf.lite.TFLiteConverter.from_keras_model(q_aware_model)
tfl_conv.optimizations = [tf.lite.Optimize.DEFAULT]
tfl_conv.representative_dataset = tf.lite.RepresentativeDataset(
    representative_data_gen)
tfl_conv.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
tfl_conv.inference_input_type = tfl_conv.inference_output_type = tf.int8

quantized_q_aware_model = tfl_conv.convert()



In [21]:
with open('cifar10_q_aware.tflite', 'wb') as file:
  file.write(quantized_q_aware_model)

!xxd -i 'cifar10_q_aware.tflite' > q_aware_model.h
!sed -i 's/unsigned char/const unsigned char/g' q_aware_model.h
!sed -i 's/const/alignas(8) const/g' q_aware_model.h