In [3]:
!pip install tensorflow-model-optimization
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_model_optimization as tfmot



In [4]:
# Load the TF Flowers dataset
dataset, info = tfds.load("tf_flowers", as_supervised=True, with_info=True)
num_classes = info.features['label'].num_classes

def preprocess(image, label):
    image = tf.image.resize(image, (224, 224)) / 255.0  # Normalize
    return image, label

# Prepare training and validation datasets
train_data = dataset['train'].map(preprocess).batch(32).shuffle(1000).prefetch(tf.data.AUTOTUNE)
val_data = dataset['train'].map(preprocess).batch(32).prefetch(tf.data.AUTOTUNE)

In [5]:
# Define a simple CNN model
base_model = tf.keras.applications.MobileNetV2(input_shape=(224, 224, 3), include_top=False, weights='imagenet')
base_model.trainable = False

model = tf.keras.Sequential([
    base_model,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(num_classes, activation='softmax')
])

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# Train the baseline model
model.fit(train_data, validation_data=val_data, epochs=5)

# Apply random pruning
def apply_random_pruning(model):
    prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

    # Define pruning parameters
    pruning_params = {
        'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.2,
                                                                  final_sparsity=0.5,
                                                                  begin_step=0,
                                                                  end_step=1000)
    }

    pruned_model = prune_low_magnitude(model, **pruning_params)
    return pruned_model

# Apply pruning and recompile
pruned_model = apply_random_pruning(model)
pruned_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# Add the UpdatePruningStep callback
callbacks = [tfmot.sparsity.keras.UpdatePruningStep()]

# Fine-tune pruned model with the callback
pruned_model.fit(train_data,
                  validation_data=val_data,
                  epochs=5,
                  callbacks=callbacks)  # Pass the callbacks here

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<tf_keras.src.callbacks.History at 0x7859b6723f90>

In [6]:
# Evaluate pruned model
_, pruned_acc = pruned_model.evaluate(val_data)
print(f"Pruned Model Accuracy: {pruned_acc:.4f}")

Pruned Model Accuracy: 0.7357
