In [1]:
import tempfile
import os

import tensorflow as tf
import numpy as np

from tensorflow_model_optimization.python.core.keras.compat import keras

import warnings
warnings.filterwarnings("ignore")

# Create a Model for MNIST Dataset

In [2]:
# Load MNIST dataset
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

# Normalize the input image so that each pixel value is between 0 and 1.
train_images = train_images / 255.0
test_images = test_images / 255.0

# Define the model architecture.
original_model = keras.Sequential([
  keras.layers.InputLayer(input_shape=(28, 28)),
  keras.layers.Reshape(target_shape=(28, 28, 1)),
  keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
  keras.layers.MaxPooling2D(pool_size=(2, 2)),
  keras.layers.Flatten(),
  keras.layers.Dense(10)
])

# Train the digit classification model
original_model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

original_model.fit(
  train_images,
  train_labels,
  epochs=4,
  validation_split=0.1,
  verbose=False
)

<tf_keras.src.callbacks.History at 0x286956bd0>

In [3]:
loss, original_model_accuracy = original_model.evaluate(
    test_images, test_labels, verbose=0)

print('Original model test accuracy:', original_model_accuracy)

Original model test accuracy: 0.9779999852180481


In [5]:
original_model_path = 'original_model.h5'
original_model.save(original_model_path) # save the model for comparison

# Pruning

In [6]:
import tensorflow_model_optimization as tfmot

prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

batch_size = 128
epochs = 2
validation_split = 0.1 

num_images = train_images.shape[0] * (1 - validation_split)
end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs

# Define model for pruning.
pruning_params = {
      'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.50,
                                                               final_sparsity=0.80,
                                                               begin_step=0,
                                                               end_step=end_step)
}

model_for_pruning = prune_low_magnitude(original_model, **pruning_params)

# `prune_low_magnitude` requires a recompile.
model_for_pruning.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

In [7]:
callbacks = [
  tfmot.sparsity.keras.UpdatePruningStep(),
]

model_for_pruning.fit(train_images, train_labels,
                  batch_size=batch_size, epochs=epochs, validation_split=validation_split,
                  callbacks=callbacks, verbose=False)

<tf_keras.src.callbacks.History at 0x286951a90>

In [9]:
loss, model_for_pruning_accuracy = model_for_pruning.evaluate(
   test_images, test_labels, verbose=0)

print('Baseline test accuracy:', original_model_accuracy) 
print('Pruned test accuracy:', model_for_pruning_accuracy)

Baseline test accuracy: 0.9779999852180481
Pruned test accuracy: 0.9710000157356262


In [10]:
model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)
pruned_model_path = 'pruned_model.h5'
model_for_export.save(pruned_model_path)



# Quantization

In [11]:
converter = tf.lite.TFLiteConverter.from_keras_model(model_for_export)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
quantized_and_pruned_tflite_model = converter.convert()

quantized_model_path = 'quantized_model.tflite'
with open(quantized_model_path, 'wb') as f:
    f.write(quantized_and_pruned_tflite_model)

INFO:tensorflow:Assets written to: /var/folders/9p/ygcng_xn489bsl49504j5l780000gn/T/tmpu2gogy17/assets


INFO:tensorflow:Assets written to: /var/folders/9p/ygcng_xn489bsl49504j5l780000gn/T/tmpu2gogy17/assets
2024-03-03 15:34:14.348148: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:378] Ignored output_format.
2024-03-03 15:34:14.348160: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:381] Ignored drop_control_dependency.
2024-03-03 15:34:14.348338: I tensorflow/cc/saved_model/reader.cc:83] Reading SavedModel from: /var/folders/9p/ygcng_xn489bsl49504j5l780000gn/T/tmpu2gogy17
2024-03-03 15:34:14.348702: I tensorflow/cc/saved_model/reader.cc:51] Reading meta graph with tags { serve }
2024-03-03 15:34:14.348706: I tensorflow/cc/saved_model/reader.cc:146] Reading SavedModel debug info (if present) from: /var/folders/9p/ygcng_xn489bsl49504j5l780000gn/T/tmpu2gogy17
2024-03-03 15:34:14.349406: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:388] MLIR V1 optimization pass is not enabled
2024-03-03 15:34:14.349605: I tensorflow/cc/saved_model/load

In [12]:
# Function to print file size
def print_file_size(file_path):
    size = os.path.getsize(file_path) / 1024  # size in KB
    print(f"Size of {file_path}: {size:.2f} KB")

# Compare file sizes
print_file_size('original_model.h5')
print_file_size('pruned_model.h5')
print_file_size('quantized_model.tflite')

Size of original_model.h5: 265.77 KB
Size of pruned_model.h5: 96.65 KB
Size of quantized_model.tflite: 23.50 KB
