In [1]:
import numpy as np
import tensorflow as tf
import tensorflow_model_optimization as tfmot
from tensorflow.keras.models import load_model
from tensorflow.keras.datasets import fashion_mnist
from tensorflow.keras.utils import to_categorical

# 1. Load Fashion MNIST dataset
(X_train, y_train), (X_test, y_test) = fashion_mnist.load_data()

# Normalize and reshape
X_train = X_train.astype('float32') / 255.0
X_test = X_test.astype('float32') / 255.0
X_train = np.expand_dims(X_train, -1)  # shape: (60000, 28, 28, 1)
X_test = np.expand_dims(X_test, -1)

# One-hot encode labels (because pruning loss expects categorical)
y_train = to_categorical(y_train, num_classes=10)
y_test = to_categorical(y_test, num_classes=10)

# 2. Load baseline model (make sure it's compiled if needed)
baseline_model = load_model('baseline_model1.h5')
print("✅ Baseline model loaded")

# 3. Setup pruning parameters
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude
batch_size = 128
epochs = 2
validation_split = 0.1

num_samples = int(X_train.shape[0] * (1 - validation_split))
end_step = np.ceil(num_samples / batch_size).astype(np.int32) * epochs

pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(
        initial_sparsity=0.40,
        final_sparsity=0.75,
        begin_step=0,
        end_step=end_step
    )
}

# 4. Apply pruning wrapper to baseline model
model_for_pruning = prune_low_magnitude(baseline_model, **pruning_params)
print("✅ Pruning wrapper applied")

# 5. Compile the model
model_for_pruning.compile(
    optimizer='adam',
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
    metrics=['accuracy']
)
print("✅ Model compiled with pruning")

# Now ready to call model_for_pruning.fit(...) to train



✅ Baseline model loaded
✅ Pruning wrapper applied
✅ Model compiled with pruning


In [2]:
print(model_for_pruning.summary())

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 prune_low_magnitude_conv2d  (None, 26, 26, 32)        610       
  (PruneLowMagnitude)                                            
                                                                 
 prune_low_magnitude_max_po  (None, 13, 13, 32)        1         
 oling2d (PruneLowMagnitude                                      
 )                                                               
                                                                 
 prune_low_magnitude_conv2d  (None, 11, 11, 64)        36930     
 _1 (PruneLowMagnitude)                                          
                                                                 
 prune_low_magnitude_max_po  (None, 5, 5, 64)          1         
 oling2d_1 (PruneLowMagnitu                                      
 de)                                                    

In [3]:
import tempfile
import pathlib

# 6. Setup pruning callbacks
log_dir = tempfile.mkdtemp()
callbacks = [
    tfmot.sparsity.keras.UpdatePruningStep(),
    tfmot.sparsity.keras.PruningSummaries(log_dir=log_dir),
]


In [4]:
# 7. Train the model (2 epochs with 10% validation split)
model_for_pruning.fit(
    X_train,
    y_train,
    batch_size=batch_size,
    epochs=epochs,
    validation_split=validation_split,
    callbacks=callbacks
)
print("✅ Model trained with pruning")

Epoch 1/2


  output, from_logits = _get_logits(


Epoch 2/2
✅ Model trained with pruning


In [5]:
# 1. Strip pruning wrappers to get final sparse model
model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)

# 2. Convert the stripped model to TFLite format
converter = tf.lite.TFLiteConverter.from_keras_model(model_for_export)
tflite_model_pruned = converter.convert()

# 3. Save the TFLite model file
tflite_models_dir = pathlib.Path('tflite_models/')  # Change to '/content/tflite_models/' if running on Colab
tflite_models_dir.mkdir(exist_ok=True, parents=True)
tflite_model_file = tflite_models_dir / 'model_pruned.tflite'
tflite_model_file.write_bytes(tflite_model_pruned)

print(f"✅ Pruned TFLite model saved at: {tflite_model_file}")

INFO:tensorflow:Assets written to: /var/folders/bs/x0lj933d1hv0py0d4w2ypdp40000gn/T/tmp9sy2ziyq/assets


INFO:tensorflow:Assets written to: /var/folders/bs/x0lj933d1hv0py0d4w2ypdp40000gn/T/tmp9sy2ziyq/assets


✅ Pruned TFLite model saved at: tflite_models/model_pruned.tflite


2025-06-05 18:34:52.648102: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:364] Ignored output_format.
2025-06-05 18:34:52.648319: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:367] Ignored drop_control_dependency.
2025-06-05 18:34:52.649257: I tensorflow/cc/saved_model/reader.cc:45] Reading SavedModel from: /var/folders/bs/x0lj933d1hv0py0d4w2ypdp40000gn/T/tmp9sy2ziyq
2025-06-05 18:34:52.650168: I tensorflow/cc/saved_model/reader.cc:91] Reading meta graph with tags { serve }
2025-06-05 18:34:52.650176: I tensorflow/cc/saved_model/reader.cc:132] Reading SavedModel debug info (if present) from: /var/folders/bs/x0lj933d1hv0py0d4w2ypdp40000gn/T/tmp9sy2ziyq
2025-06-05 18:34:52.653092: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:375] MLIR V1 optimization pass is not enabled
2025-06-05 18:34:52.653974: I tensorflow/cc/saved_model/loader.cc:231] Restoring SavedModel bundle.
2025-06-05 18:34:52.679308: I tensorflow/cc/saved_model/loader.

In [6]:
tflite_model_file = 'tflite_models/model_pruned.tflite'
interpreter = tf.lite.Interpreter(model_path=tflite_model_file)
interpreter.allocate_tensors()

input_index = interpreter.get_input_details()[0]['index']
output_index = interpreter.get_output_details()[0]['index']

pred_list = []

for images in X_test:
    input_data = np.array(images, dtype=np.float32)
    input_data = input_data.reshape(1, input_data.shape[0], input_data.shape[1], 1)  # Add batch & channel dims
    interpreter.set_tensor(input_index, input_data)
    interpreter.invoke()
    prediction = interpreter.get_tensor(output_index)
    predicted_label = np.argmax(prediction)
    pred_list.append(predicted_label)

accurate_count = 0
for index in range(len(pred_list)):
    if pred_list[index] == np.argmax(y_test[index]):
        accurate_count += 1

accuracy = accurate_count / len(pred_list)
print('Accuracy =', accuracy)

INFO: Created TensorFlow Lite XNNPACK delegate for CPU.


Accuracy = 0.9022


In [8]:
import os
import zipfile
import tempfile

def get_gzipped_model(file_path):
    # Create a temporary zip file
    _, zipped_file = tempfile.mkstemp('.zip')
    with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
        f.write(file_path, arcname=os.path.basename(file_path))
    size = os.path.getsize(zipped_file)
    os.remove(zipped_file)  # Clean up temp zip file
    return size

# Example usage:
#baseline_model_path = 'tflite_models/model.tflite'
pruned_model_path = 'tflite_models/model_pruned.tflite'

#print('Size of compressed baseline model: %.2f bytes' % get_gzipped_model(baseline_model_path))
print('Size of zipped pruned TFlite model: %.2f bytes' % get_gzipped_model(pruned_model_path))


Size of zipped pruned TFlite model: 245906.00 bytes


In [9]:
import time
import numpy as np
import tensorflow as tf
import os
import zipfile
import tempfile
import pathlib

# Paths
tflite_model_file = 'tflite_models/model_pruned.tflite'

# Load TFLite model
interpreter = tf.lite.Interpreter(model_path=tflite_model_file)
interpreter.allocate_tensors()
input_index = interpreter.get_input_details()[0]['index']
output_index = interpreter.get_output_details()[0]['index']

# Function to get zipped model size in KB
def get_gzipped_model_size(file_path):
    _, zipped_file = tempfile.mkstemp('.zip')
    with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
        f.write(file_path, arcname=os.path.basename(file_path))
    size = os.path.getsize(zipped_file)
    os.remove(zipped_file)
    return size / 1024  # KB

# Run inference on test data and compute accuracy + timing
pred_list = []
total_inference_time = 0.0

for image in X_test:
    input_data = np.array(image, dtype=np.float32)
    input_data = input_data.reshape(1, input_data.shape[0], input_data.shape[1], 1)  # Adjust if needed

    start_time = time.time()
    interpreter.set_tensor(input_index, input_data)
    interpreter.invoke()
    inference_time = time.time() - start_time
    total_inference_time += inference_time

    prediction = interpreter.get_tensor(output_index)
    pred_label = np.argmax(prediction)
    pred_list.append(pred_label)

# Calculate accuracy
accurate_count = sum([pred_list[i] == np.argmax(y_test[i]) for i in range(len(pred_list))])
accuracy = accurate_count / len(pred_list)

# Calculate average inference time per sample (in milliseconds)
avg_inference_time_ms = (total_inference_time / len(pred_list)) * 1000

# Get model sizes
model_size_kb = os.path.getsize(tflite_model_file) / 1024
compressed_model_size_kb = get_gzipped_model_size(tflite_model_file)

print(f"Accuracy on test set: {accuracy*100:.2f}%")
print(f"Average inference time per sample: {avg_inference_time_ms:.3f} ms")
print(f"Model size (uncompressed): {model_size_kb:.2f} KB")
print(f"Model size (compressed ZIP): {compressed_model_size_kb:.2f} KB")


Accuracy on test set: 90.22%
Average inference time per sample: 0.154 ms
Model size (uncompressed): 706.16 KB
Model size (compressed ZIP): 240.14 KB
