In [1]:
import numpy as np
import os

In [2]:
import tensorflow_model_optimization as tfmot

In [3]:
os.environ["CUDA_VISIBLE_DEVICES"]="1"
os.environ['TF_DETERMINISTIC_OPS'] = '1'

In [4]:
import PIL
from collections import Counter
import tensorflow as tf
import random
from tensorflow.python.framework.ops import disable_eager_execution
from tensorflow.python.framework.ops import enable_eager_execution
#disable_eager_execution()
enable_eager_execution()
from tensorflow.keras.layers import Input

In [5]:
from tensorflow.keras.applications.resnet50 import ResNet50
import tensorflow_datasets as tfds

In [6]:
# r: resnet
# d: densenet
# m: mobilenet
m = "d"

In [7]:
def preprocess_image_resnet(features):
    """Preprocesses the given image.

      Args:
        image: `Tensor` representing an image of arbitrary size.

  """
    image = features["image"]
    image = tf.image.resize(image,[224,224])
    image = tf.keras.applications.resnet.preprocess_input(image)
    
    features["image"] = image
    return features["image"], features["label"]

In [8]:
def preprocess_image_densenet(features):
    """Preprocesses the given image.

      Args:
        image: `Tensor` representing an image of arbitrary size.

  """
    image = features["image"]
    image = tf.image.resize(image,[224,224])
    image = tf.keras.applications.densenet.preprocess_input(image)
    
    features["image"] = image
    return features["image"], features["label"]

In [9]:
def preprocess_image_mobilenet(features):
    """Preprocesses the given image.

      Args:
        image: `Tensor` representing an image of arbitrary size.

  """
    image = features["image"]
    image = tf.image.resize(image,[224,224])
    image = tf.keras.applications.mobilenet.preprocess_input(image)
    
    features["image"] = image
    return features["image"], features["label"]

In [10]:
tfds_dataset1, tfds_info  = tfds.load(name='imagenet2012_subset', split='train', with_info=True,
                                     data_dir='../datasets/ImageNet/')

In [11]:
tfds_dataset2, tfds_info  = tfds.load(name='imagenet2012_subset', split='validation', with_info=True,
                                     data_dir='../datasets/ImageNet/')

In [12]:
BATCH_SIZE=20

In [14]:
if 'm' == 'r':
    preprocess_image = preprocess_image_resnet
elif 'm' == 'd':
    preprocess_image = preprocess_image_densenet
else:
    preprocess_image = preprocess_image_mobilenet

In [15]:
train_ds = tfds_dataset1.map(preprocess_image).batch(BATCH_SIZE).prefetch(1)
val_ds = tfds_dataset2.map(preprocess_image).batch(BATCH_SIZE).prefetch(1)

In [16]:
num_images = tfds_info.splits['train'].num_examples
num_classes = tfds_info.features['label'].num_classes
print(num_images)

12811


In [17]:
img_rows, img_cols = 224 ,224

# Base

In [18]:
if 'm' == 'r':
    model = tf.keras.applications.ResNet50(input_shape=(img_rows, img_cols,3))
elif 'm' == 'd':
    model = tf.keras.applications.DenseNet121(input_shape=(img_rows, img_cols,3))
else:
    model = tf.keras.applications.MobileNet(input_shape=(img_rows, img_cols,3))

In [19]:
if 'm' == 'r':
    model.load_weights("../weights/fp_model_40_resnet50.h5")
elif 'm' == 'd':
    model.load_weights("../weights/fp_model_40_densenet121.h5")
else:
    model.load_weights("../weights/fp_model_40_mobilenet.h5")

In [20]:
model.compile(optimizer=tf.keras.optimizers.RMSprop(lr=2e-5),
            loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

In [21]:
_, baseline_model_accuracy = model.evaluate(val_ds, verbose=0)
print(baseline_model_accuracy)

0.7421600222587585


# Train pruned model

In [22]:
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

In [32]:
epochs = 7

In [33]:
end_step = np.ceil(num_images / BATCH_SIZE).astype(np.int32) * epochs

In [34]:
pruning_params = {
    "pruning_schedule": tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0,
                                                            final_sparsity=0.8,
                                                            begin_step=0,
                                                            end_step=end_step)
}

In [35]:
model_for_pruning = prune_low_magnitude(model, **pruning_params)

In [36]:
model_for_pruning.compile(optimizer=tf.keras.optimizers.RMSprop(lr=2e-5),
            loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

In [37]:
model_for_pruning.summary()

Model: "densenet121"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_2 (InputLayer)            [(None, 224, 224, 3) 0                                            
__________________________________________________________________________________________________
prune_low_magnitude_zero_paddin (None, 230, 230, 3)  1           input_2[0][0]                    
__________________________________________________________________________________________________
prune_low_magnitude_conv1/conv  (None, 112, 112, 64) 18818       prune_low_magnitude_zero_padding2
__________________________________________________________________________________________________
prune_low_magnitude_conv1/bn (P (None, 112, 112, 64) 257         prune_low_magnitude_conv1/conv[0]
________________________________________________________________________________________

In [38]:
callbacks = [
    tfmot.sparsity.keras.UpdatePruningStep(),
]

In [39]:
model_for_pruning.fit(train_ds,
          epochs=epochs,
          batch_size=BATCH_SIZE,
          validation_data=val_ds,
          callbacks=callbacks,
          verbose=1)

Epoch 1/7
Epoch 2/7
Epoch 3/7
Epoch 4/7
Epoch 5/7
Epoch 6/7
Epoch 7/7


<tensorflow.python.keras.callbacks.History at 0x7effe0741c50>

In [40]:
_, model_for_pruning_accuracy = model_for_pruning.evaluate(val_ds, verbose=0)
print(model_for_pruning_accuracy)

0.6249799728393555


In [41]:
model_for_pruning.get_weights()

[1799,
 array([[[[ 0.        , -0.        ,  0.        , ...,  0.        ,
            0.        ,  0.        ],
          [ 0.16445662, -0.        , -0.        , ...,  0.        ,
           -0.        ,  0.        ],
          [-0.        ,  0.        ,  0.        , ...,  0.        ,
            0.        ,  0.        ]],
 
         [[ 0.14914592, -0.        , -0.        , ...,  0.        ,
            0.        ,  0.        ],
          [ 0.2487163 ,  0.        , -0.        , ...,  0.        ,
           -0.        ,  0.        ],
          [-0.        ,  0.        ,  0.        , ..., -0.        ,
            0.        ,  0.        ]],
 
         [[ 0.16573551, -0.        , -0.        , ...,  0.        ,
            0.        ,  0.        ],
          [ 0.2661216 , -0.        , -0.        , ...,  0.        ,
            0.        , -0.        ],
          [-0.        ,  0.        ,  0.        , ..., -0.        ,
            0.        , -0.        ]],
 
         ...,
 
         [[ 0.

In [42]:
model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)

In [43]:
if 'm' == 'r':
    model_for_export.save("../weights/p_model_40_resnet50.h5")
elif 'm' == 'd':
    model_for_export.save("../weights/p_model_40_densenet121.h5")
else:
    model_for_export.save("../weights/p_model_40_mobilenet.h5")

# Quantize pruned model + finetune

In [23]:
model_for_export = tf.keras.applications.DenseNet121(input_shape=(img_rows, img_cols,3))
model_for_export.load_weights("../weights/fp_model_40_densenet121.h5")

In [25]:
class DefaultBNQuantizeConfig(tfmot.quantization.keras.QuantizeConfig):

    def get_weights_and_quantizers(self, layer):
        return []

    def get_activations_and_quantizers(self, layer):
        return []

    def set_quantize_weights(self, layer, quantize_weights):
        pass
    def set_quantize_activations(self, layer, quantize_activations):
        pass
    def get_output_quantizers(self, layer):
        return [tfmot.quantization.keras.quantizers.MovingAverageQuantizer(num_bits=8, per_axis=False, symmetric=False, narrow_range=False)]

    def get_config(self):
        return {}
    
    
class NoOpQuantizeConfig(tfmot.quantization.keras.QuantizeConfig):
    """Use this config object if the layer has nothing to be quantized for 
    quantization aware training."""

    def get_weights_and_quantizers(self, layer):
        return []

    def get_activations_and_quantizers(self, layer):
        return []

    def set_quantize_weights(self, layer, quantize_weights):
        pass

    def set_quantize_activations(self, layer, quantize_activations):
        pass

    def get_output_quantizers(self, layer):
        # Does not quantize output, since we return an empty list.
        return []

    def get_config(self):
        return {}
    
    
def apply_quantization(layer):
    if 'bn'  in layer.name:
        return tfmot.quantization.keras.quantize_annotate_layer(layer,DefaultBNQuantizeConfig())
    elif 'concat' in layer.name:
        return tfmot.quantization.keras.quantize_annotate_layer(layer,NoOpQuantizeConfig())
    else:
        return tfmot.quantization.keras.quantize_annotate_layer(layer)

In [26]:
if m == 'd':
    LastValueQuantizer = tfmot.quantization.keras.quantizers.LastValueQuantizer
    MovingAverageQuantizer = tfmot.quantization.keras.quantizers.MovingAverageQuantizer

    # Use `tf.keras.models.clone_model` to apply `apply_quantization_to_dense` 
    # to the layers of the model.
    annotated_model = tf.keras.models.clone_model(
        model_for_export,
        clone_function=apply_quantization,
    )
    with tfmot.quantization.keras.quantize_scope({'DefaultBNQuantizeConfig': DefaultBNQuantizeConfig, 'NoOpQuantizeConfig': NoOpQuantizeConfig}):
            pqat_model = tfmot.quantization.keras.quantize_apply(annotated_model, tfmot.experimental.combine.Default8BitPrunePreserveQuantizeScheme())
else:
    quant_aware_annotate_model  = tfmot.quantization.keras.quantize_annotate_model(model_for_export)
    pqat_model = tfmot.quantization.keras.quantize_apply(
                quant_aware_annotate_model,
                tfmot.experimental.combine.Default8BitPrunePreserveQuantizeScheme())

In [27]:
pqat_model.summary()

Model: "densenet121"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_2 (InputLayer)            [(None, 224, 224, 3) 0                                            
__________________________________________________________________________________________________
quantize_layer_1 (QuantizeLayer (None, 224, 224, 3)  3           input_2[0][0]                    
__________________________________________________________________________________________________
quant_zero_padding2d (QuantizeW (None, 230, 230, 3)  1           quantize_layer_1[1][0]           
__________________________________________________________________________________________________
quant_conv1/conv (QuantizeWrapp (None, 112, 112, 64) 9539        quant_zero_padding2d[0][0]       
________________________________________________________________________________________

quant_bn (QuantizeWrapper)      (None, 7, 7, 1024)   4099        quant_conv5_block16_concat[0][0] 
__________________________________________________________________________________________________
quant_relu (QuantizeWrapper)    (None, 7, 7, 1024)   3           quant_bn[0][0]                   
__________________________________________________________________________________________________
quant_avg_pool (QuantizeWrapper (None, 1024)         3           quant_relu[0][0]                 
__________________________________________________________________________________________________
quant_predictions (QuantizeWrap (None, 1000)         1025005     quant_avg_pool[0][0]             
Total params: 8,084,151
Trainable params: 7,978,856
Non-trainable params: 105,295
__________________________________________________________________________________________________


In [17]:
pqat_model.compile(optimizer=tf.keras.optimizers.RMSprop(lr=2e-5),
            loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

In [None]:
pqat_model.fit(train_ds,
          epochs=5,
          validation_data= val_ds,
          verbose=1)

Epoch 1/5
Epoch 2/5
Epoch 3/5

In [None]:
print(pqat_model.evaluate(val_ds, verbose=0))

In [None]:
if 'm' == 'r':
    pqat_model.save("../weights/pqat_model_40_resnet50.h5")
elif 'm' == 'd':
    pqat_model.save("../weights/pqat_model_40_densenet121.h5")
else:
    pqat_model.save("../weights/pqat_model_40_mobilenet.h5")