In [1]:
import os;os.environ["TF_USE_LEGACY_KERAS"]="1"

import tensorflow_model_optimization as tfmot
import numpy as np
from tensorflow import keras

from preprocessing import dataset_preprocessing
from utils import get_zipped_model_size, print_model_weights_sparsity




In [2]:
# --- Data and training hyper-parameters -----------------------------------
BATCH_SIZE = 16
IMAGE_SIZE = 224
INPUT_SHAPE = (IMAGE_SIZE, IMAGE_SIZE, 3)
EPOCHS = 70      # upper bound for the fit that follows pruning (early stopping is used there)
DROPOUT = 0.2    # forwarded to the MobileNet_v2 builder
# Forwarded to dataset_preprocessing as scale/offset.
# Presumably pixels are mapped as x / SCALE + OFFSET, i.e. into [-1, 1] -- TODO confirm.
SCALE = 127.5
OFFSET = -1

# --- Pruning schedule -----------------------------------------------------
PRUNING_EPOCHS = 3         # epochs of the pruning fit; also sets the schedule's end step
INITIAL_SPARSITY = 0.20
FINAL_SPARSITY = 0.60

# Built with os.path.join instead of the literal "..\coffe_dataset": "\c" is an
# invalid escape sequence (deprecated, SyntaxWarning on recent Python versions)
# and a hard-coded backslash only works as a separator on Windows.
PATH = os.path.join("..", "coffe_dataset")

# Base names (no extension) of the model files under saved_models/.
LOADED_MODEL = "01_mobilenet_v2_0.35_224_distilled"
MODEL_NAME = "01_pruned_mobilenet_v2_0.35_224_distilled"

In [3]:
# Read every image under PATH into one shuffled, batched dataset; the class
# labels come from the dataset's own class_names attribute.
dataset = keras.utils.image_dataset_from_directory(
    PATH,
    shuffle=True,
    batch_size=BATCH_SIZE,
    image_size=INPUT_SHAPE[:2],  # (IMAGE_SIZE, IMAGE_SIZE)
)

class_names = dataset.class_names
number_classes = len(class_names)

# Options for the project helper that produces the three splits.
# Only train and validation fractions are given; presumably the remainder
# (0.23) becomes the test split -- TODO confirm in preprocessing.py.
preprocessing_options = {
    "train_size": 0.60,
    "validation_size": 0.17,
    "augmentation_flag": True,
    "rescaling_flag": True,
    "prefetch_flag": True,
    "scale": SCALE,
    "offset": OFFSET,
}

training_dataset, validation_dataset, testing_dataset = dataset_preprocessing(
    dataset, **preprocessing_options
)

Found 1379 files belonging to 9 classes.



In [4]:
# Load the previously saved network whose weights are transferred and pruned below.
model = keras.models.load_model(f"saved_models/{LOADED_MODEL}.keras")






In [5]:
#_, baseline_accuracy = model.evaluate(testing_dataset, verbose = 0)
#print('Baseline accuracy: ', round(baseline_accuracy*100, 3), '%')

In [6]:
#model.summary()

In [7]:
from custom_mobilenet_v2 import MobileNet_v2

# Rebuild the network with the project's own MobileNet_v2 builder.
# 0.35 is presumably the width multiplier (it matches the "0.35" in the saved
# model's name) -- TODO confirm against custom_mobilenet_v2.
custom_model = MobileNet_v2(INPUT_SHAPE, 0.35, number_classes, dropout=DROPOUT)

# Weights are copied by pairing layers purely by position, so both models must
# expose the same number of layers. Without this check a shorter loaded model
# would silently leave the trailing layers of custom_model at their initial
# weights, and a longer one would fail with an unhelpful IndexError.
source_layers = model.layers
target_layers = custom_model.layers
if len(source_layers) != len(target_layers):
    raise ValueError(
        f"Cannot transfer weights: loaded model has {len(source_layers)} layers "
        f"but custom model has {len(target_layers)}."
    )

# set_weights() itself rejects arrays whose shapes do not match the target layer.
for source_layer, target_layer in zip(source_layers, target_layers):
    target_layer.set_weights(source_layer.get_weights())

In [8]:
# Compile so that evaluate() has a loss and an accuracy metric; this cell does
# not call fit(). The loss is configured for probability outputs
# (from_logits=False) and integer labels.
custom_model.compile(
    optimizer=keras.optimizers.Adam(1e-3),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=['accuracy'],
)

# With one loss and one metric, evaluate() yields [loss, accuracy];
# only the accuracy is kept as the reference for the pruned model.
_, baseline_accuracy = custom_model.evaluate(testing_dataset, verbose=0)

print('Baseline accuracy: ', round(100 * baseline_accuracy, 3), '%')













Baseline accuracy:  2.381 %


In [9]:
#custom_model.summary()

In [10]:
# Mark the whole model trainable before wrapping it for pruning.
custom_model.trainable = True

# fit() performs one step per batch, so the schedule ends after
# len(training_dataset) steps for each pruning epoch. This is the same value
# the previous "batches * BATCH_SIZE / BATCH_SIZE" round trip produced, now
# kept as a plain Python int.
steps_per_epoch = len(training_dataset)
end_step = steps_per_epoch * PRUNING_EPOCHS

# Value passed as the schedule's `frequency` argument (steps between updates).
PRUNING_FREQUENCY = 30

# Ramp sparsity from INITIAL_SPARSITY to FINAL_SPARSITY between step 0 and end_step.
pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(
        initial_sparsity=INITIAL_SPARSITY,
        final_sparsity=FINAL_SPARSITY,
        begin_step=0,
        end_step=end_step,
        frequency=PRUNING_FREQUENCY,
    )
}
# Alternative schedule kept for reference:
#   tfmot.sparsity.keras.ConstantSparsity(0.5, begin_step=0, frequency=100)

prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude
pruned_model = prune_low_magnitude(custom_model, **pruning_params)

# UpdatePruningStep is passed to fit() so the pruning step advances during training.
callbacks = [tfmot.sparsity.keras.UpdatePruningStep()]

# Small learning rate (1e-5) while the pruning schedule is running.
pruned_model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-5),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=['accuracy'],
)

# Fine-tune the model while the pruning schedule is applied.
pruned_model.fit(
    training_dataset,
    validation_data=validation_dataset,
    epochs=PRUNING_EPOCHS,
    verbose=1,
    callbacks=callbacks,
)

Epoch 1/3
Epoch 2/3
Epoch 3/3


<tf_keras.src.callbacks.History at 0x1bf1d87ae20>

In [11]:
# Re-compile with a fresh Adam optimizer at 1e-3 (the pruning fit used 1e-5).
pruned_model.compile(
    optimizer=keras.optimizers.Adam(1e-3),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=['accuracy'],
)

# Stop once val_accuracy has not improved for 5 epochs and roll back to the
# weights of the best epoch.
early_stopping = keras.callbacks.EarlyStopping(
    monitor='val_accuracy',
    patience=5,
    restore_best_weights=True,
)

# NOTE(review): UpdatePruningStep is not passed to this fit, unlike the pruning
# fit above -- confirm this is intentional.
pruned_model.fit(
    training_dataset,
    validation_data=validation_dataset,
    epochs=EPOCHS,
    callbacks=[early_stopping],
)

Epoch 1/70
Epoch 2/70
Epoch 3/70
Epoch 4/70
Epoch 5/70
Epoch 6/70


<tf_keras.src.callbacks.History at 0x1bf15a53d00>

In [12]:
# Evaluate the pruned model on the same test split as the baseline; the
# returned loss is discarded.
_, pruned_accuracy = pruned_model.evaluate(testing_dataset, verbose=0)

# Report both accuracies as percentages rounded to three decimals.
for label, accuracy in (
    ('Baseline accuracy: ', baseline_accuracy),
    ('Pruned accuracy: ', pruned_accuracy),
):
    print(label, round(100 * accuracy, 3), '%')

Baseline accuracy:  2.381 %
Pruned accuracy:  18.75 %


In [13]:
# Pass the trained model through tfmot's strip_pruning, then print per-weight
# sparsity with the project helper.
strip_pruning = tfmot.sparsity.keras.strip_pruning
stripped_pruned_model = strip_pruning(pruned_model)

print_model_weights_sparsity(stripped_pruned_model)

conv2d/kernel:0: 59.95% sparsity  (259/432)
batch_normalization/gamma:0: 0.00% sparsity  (0/16)
batch_normalization/beta:0: 0.00% sparsity  (0/16)
batch_normalization/moving_mean:0: 0.00% sparsity  (0/16)
batch_normalization/moving_variance:0: 0.00% sparsity  (0/16)
depthwise_conv2d/depthwise_kernel:0: 0.00% sparsity  (0/144)
batch_normalization_1/gamma:0: 0.00% sparsity  (0/16)
batch_normalization_1/beta:0: 0.00% sparsity  (0/16)
batch_normalization_1/moving_mean:0: 0.00% sparsity  (0/16)
batch_normalization_1/moving_variance:0: 0.00% sparsity  (0/16)
conv2d_1/kernel:0: 60.16% sparsity  (77/128)
batch_normalization_2/gamma:0: 0.00% sparsity  (0/8)
batch_normalization_2/beta:0: 0.00% sparsity  (0/8)
batch_normalization_2/moving_mean:0: 0.00% sparsity  (0/8)
batch_normalization_2/moving_variance:0: 0.00% sparsity  (0/8)
conv2d_2/kernel:0: 59.90% sparsity  (230/384)
batch_normalization_3/gamma:0: 0.00% sparsity  (0/48)
batch_normalization_3/beta:0: 0.00% sparsity  (0/48)
batch_normalizat

In [14]:
# Persist the stripped model under saved_models/ using the pruned model's name.
stripped_pruned_model.save(f"saved_models/{MODEL_NAME}.keras")

In [15]:
# Size reported by the project helper for the saved file, divided by 10**6 for
# display (presumably bytes -> MB -- TODO confirm in utils.get_zipped_model_size).
zipped_size = get_zipped_model_size(f"saved_models/{MODEL_NAME}.keras")
print("Pruned model size: ", zipped_size / 10**6, ' MB')

Pruned model size:  0.529102  MB
