In [1]:
!pip install tensorflow-model-optimization
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_model_optimization as tfmot

Collecting tensorflow-model-optimization
  Downloading tensorflow_model_optimization-0.8.0-py2.py3-none-any.whl.metadata (904 bytes)
Downloading tensorflow_model_optimization-0.8.0-py2.py3-none-any.whl (242 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/242.5 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m242.5/242.5 kB[0m [31m9.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: tensorflow-model-optimization
Successfully installed tensorflow-model-optimization-0.8.0


In [2]:
# Load the TF Flowers dataset
dataset, info = tfds.load("tf_flowers", as_supervised=True, with_info=True)
num_classes = info.features['label'].num_classes

def preprocess(image, label):
    image = tf.image.resize(image, (224, 224)) / 255.0  # Normalize
    return image, label

# Prepare training and validation datasets
train_data = dataset['train'].map(preprocess).batch(32).shuffle(1000).prefetch(tf.data.AUTOTUNE)
val_data = dataset['train'].map(preprocess).batch(32).prefetch(tf.data.AUTOTUNE)

Downloading and preparing dataset 218.21 MiB (download: 218.21 MiB, generated: 221.83 MiB, total: 440.05 MiB) to /root/tensorflow_datasets/tf_flowers/3.0.1...


Dl Completed...:   0%|          | 0/5 [00:00<?, ? file/s]

Dataset tf_flowers downloaded and prepared to /root/tensorflow_datasets/tf_flowers/3.0.1. Subsequent calls will reuse this data.


In [3]:
# Define a simple CNN model
base_model = tf.keras.applications.MobileNetV2(input_shape=(224, 224, 3), include_top=False, weights='imagenet')
base_model.trainable = False

model = tf.keras.Sequential([
    base_model,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(num_classes, activation='softmax')
])

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# Train the baseline model
model.fit(train_data, validation_data=val_data, epochs=5)

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_224_no_top.h5
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<tf_keras.src.callbacks.History at 0x78e63bcef090>

In [5]:
# Apply gradient-based pruning
def apply_gradient_based_pruning(model):
    prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

    # Define pruning parameters with gradient-based selection
    pruning_params = {
        'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.2,
                                                                  final_sparsity=0.5,
                                                                  begin_step=0,
                                                                  end_step=1000)
    }

    pruned_model = prune_low_magnitude(model, **pruning_params)
    return pruned_model

# Apply pruning and recompile
pruned_model = apply_gradient_based_pruning(model)
pruned_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# Fine-tune pruned model
# Include the UpdatePruningStep callback during training
callbacks = [
    tfmot.sparsity.keras.UpdatePruningStep()
]
pruned_model.fit(train_data, validation_data=val_data, epochs=5, callbacks=callbacks)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<tf_keras.src.callbacks.History at 0x78e5adacce50>

In [6]:
# Evaluate pruned model
_, pruned_acc = pruned_model.evaluate(val_data)
print(f"Pruned Model Accuracy: {pruned_acc:.4f}")

Pruned Model Accuracy: 0.7330
