In [None]:
!pip install tensorflow-model-optimization

In [None]:
import os

import tensorflow as tf
import tensorflow_model_optimization as tfmot
import tensorflow_datasets as tfds
from tqdm import tqdm
import keras

from google.colab import drive
# Mount Google Drive at /content/drive; all dataset/model paths used below live under it.
drive.mount(mountpoint='/content/drive')

### Observe Pruning Metrics

Note: run the following cell __after__ the *Pruning Execution* cell has finished, so that the pruning logs already exist for TensorBoard to display.

In [None]:
# Load the TensorBoard notebook extension and open it on the pruning log directory.
# This --logdir must be the same directory the PruningSummaries callback (Pruning
# Execution cell) writes to, otherwise TensorBoard has nothing to show.
%load_ext tensorboard
%tensorboard --logdir /content/drive/MyDrive/Computer-Vision/Models/FineTuning/Pruning/logs

## Parameter Configuration

In [None]:
# Input pipeline / training parameters.
SHUFFLE_BUFFER_SIZE, BATCH_SIZE = 1000, 32
IMG_SIZE = 224  # images are resized to IMG_SIZE x IMG_SIZE
EPOCHS = 5

# Adjust BASE_PATH as needed; every other path in this cell is derived from it.
BASE_PATH = '/content/drive/MyDrive/Computer-Vision'
DATASET_PATH = f'{BASE_PATH}/Dataset/'
MODEL_LOAD_PATH = f'{BASE_PATH}/Models/FineTuning/V1-Batch32.h5'
PRUNED_MODEL_SAVE_PATH = f'{BASE_PATH}/Models/Pruning/pruned-model.h5'
COMPRESSED_MODEL_SAVE_PATH = f'{BASE_PATH}/TFlite-Models/tfLite-model.tflite'

## Model Loading

In [None]:
# Fine-tuned Keras model (HDF5 file) that is pruned in the cells below.
trained_model = tf.keras.models.load_model(filepath=MODEL_LOAD_PATH)

## Data Loading

In [None]:
def format_image(image, label):
    """Pre-process one (image, label) example for the network input.

    Pixels are cast to float32, rescaled with ``x / 127.5 - 1`` (which maps
    [0, 255] to [-1, 1]; assumes 8-bit pixel values) and resized to
    IMG_SIZE x IMG_SIZE. The label is returned unchanged.
    """
    pixels = tf.cast(image, tf.float32) / 127.5 - 1
    return tf.image.resize(pixels, (IMG_SIZE, IMG_SIZE)), label

# DATASET_PATH is expected to contain one subfolder per split
# (tfds ImageFolder layout: <split>/<label>/<image>):
#  - train
#  - val  (validation data)
#  - test (not loaded in this cell)
builder = tfds.folder_dataset.ImageFolder(DATASET_PATH)
raw_train = builder.as_dataset(split='train', as_supervised=True, shuffle_files=True)
raw_validation = builder.as_dataset(split='val', as_supervised=True)

info = builder.info
label_names = info.features['label'].names

print("Total training images: {}  ".format(len(raw_train)))
print("Total validation images: {} ".format(len(raw_validation)))
print("Label names: {}".format(label_names))

# Pre-process every image to match the input expected by MobileNetV2.
# Parallel map uses all available cores; element order is preserved
# (tf.data map is deterministic by default).
train = raw_train.map(format_image, num_parallel_calls=tf.data.AUTOTUNE)
validation = raw_validation.map(format_image, num_parallel_calls=tf.data.AUTOTUNE)

# prefetch overlaps input preprocessing with the training step; it keeps the
# dataset cardinality, so len(train_batches) still works downstream.
train_batches = train.shuffle(SHUFFLE_BUFFER_SIZE).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
validation_batches = validation.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

## Pruning Setup

In [None]:
# end_step is measured in training batches: len(train_batches) per epoch,
# for EPOCHS epochs, so sparsity ramps from 10% to 40% over the whole run.
end_step = EPOCHS * len(train_batches)

pruning_schedule = tfmot.sparsity.keras.PolynomialDecay(
    initial_sparsity=0.1,
    final_sparsity=0.40,
    begin_step=0,
    end_step=end_step,
)
pruning_params = {'pruning_schedule': pruning_schedule}

# Wrap the fine-tuned model for low-magnitude (weight) pruning.
model_to_prune = tfmot.sparsity.keras.prune_low_magnitude(trained_model, **pruning_params)

# NOTE(review): SparseCategoricalCrossentropy() uses from_logits=False, i.e. it
# assumes the model already ends in a softmax; confirm against V1-Batch32.h5.
model_to_prune.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    metrics=['accuracy'],
)
model_to_prune.summary()

## Pruning Execution

In [None]:
# The "Observe Pruning Metrics" TensorBoard cell reads .../Pruning/logs, so the
# PruningSummaries callback must write to that same directory.
pruning_log_dir = BASE_PATH + '/Models/FineTuning/Pruning/logs'
# Checkpoint of the model *with* pruning wrappers, overwritten every epoch
# (save_best_only=False, so `monitor`/`mode` have no effect).
checkpoint_path = BASE_PATH + '/Models/FineTuning/Pruning/mattiaTest2/pruned-model.h5'

callbacks = [
    # Required by tfmot when fitting a prune_low_magnitude model.
    tfmot.sparsity.keras.UpdatePruningStep(),
    tfmot.sparsity.keras.PruningSummaries(log_dir=pruning_log_dir),
    tf.keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_path,
        monitor='val_loss',
        save_best_only=False,
        save_weights_only=False,
        mode='min',
        verbose=1
    )
]

history = model_to_prune.fit(train_batches, epochs=EPOCHS, validation_data=validation_batches, verbose=1, callbacks=callbacks)
acc = history.history['accuracy']
print(acc)

# Remove the pruning wrappers so the exported model contains only the original
# layers (carrying the pruned weights).
model_for_export = tfmot.sparsity.keras.strip_pruning(model_to_prune)
# Use tf.keras like the rest of the notebook: the standalone `keras` package may
# resolve to a different Keras version than the tf.keras model being saved.
tf.keras.models.save_model(model_for_export, PRUNED_MODEL_SAVE_PATH, include_optimizer=False)
print('Saved pruned Keras model to: ', PRUNED_MODEL_SAVE_PATH)

converter = tf.lite.TFLiteConverter.from_keras_model(model_for_export)
# NOTE(review): no converter.optimizations are set, so weights are stored densely;
# per the TFMOT pruning guide the size benefit of pruning shows up only after
# compression (e.g. gzip) or quantization.
tflite_model = converter.convert()

# open() does not create missing parent directories.
tflite_dir = os.path.dirname(COMPRESSED_MODEL_SAVE_PATH)
if tflite_dir:
    os.makedirs(tflite_dir, exist_ok=True)
with open(COMPRESSED_MODEL_SAVE_PATH, 'wb') as f:
    f.write(tflite_model)