In [1]:
!pip install tensorflow-model-optimization
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_model_optimization as tfmot

Collecting tensorflow-model-optimization
  Downloading tensorflow_model_optimization-0.8.0-py2.py3-none-any.whl.metadata (904 bytes)
Downloading tensorflow_model_optimization-0.8.0-py2.py3-none-any.whl (242 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m242.5/242.5 kB[0m [31m994.7 kB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: tensorflow-model-optimization
Successfully installed tensorflow-model-optimization-0.8.0


In [2]:
# Load the Oxford Flowers 102 dataset
dataset, info = tfds.load("oxford_flowers102", as_supervised=True, with_info=True)
num_classes = info.features['label'].num_classes

# Preprocess function
def preprocess(image, label):
    image = tf.image.resize(image, (224, 224)) / 255.0  # Normalize
    return image, label

# Prepare train and validation datasets
train_data = dataset['train'].map(preprocess).batch(32).shuffle(1000).prefetch(tf.data.AUTOTUNE)
val_data = dataset['validation'].map(preprocess).batch(32).prefetch(tf.data.AUTOTUNE)

Downloading and preparing dataset 328.90 MiB (download: 328.90 MiB, generated: 331.34 MiB, total: 660.25 MiB) to /root/tensorflow_datasets/oxford_flowers102/2.1.1...


Dl Completed...: 0 url [00:00, ? url/s]

Dl Size...: 0 MiB [00:00, ? MiB/s]

Extraction completed...: 0 file [00:00, ? file/s]

Generating splits...:   0%|          | 0/3 [00:00<?, ? splits/s]

Generating train examples...:   0%|          | 0/1020 [00:00<?, ? examples/s]

Shuffling /root/tensorflow_datasets/oxford_flowers102/incomplete.CM0ZSI_2.1.1/oxford_flowers102-train.tfrecord…

Generating test examples...:   0%|          | 0/6149 [00:00<?, ? examples/s]

Shuffling /root/tensorflow_datasets/oxford_flowers102/incomplete.CM0ZSI_2.1.1/oxford_flowers102-test.tfrecord*…

Generating validation examples...:   0%|          | 0/1020 [00:00<?, ? examples/s]

Shuffling /root/tensorflow_datasets/oxford_flowers102/incomplete.CM0ZSI_2.1.1/oxford_flowers102-validation.tfr…

Dataset oxford_flowers102 downloaded and prepared to /root/tensorflow_datasets/oxford_flowers102/2.1.1. Subsequent calls will reuse this data.


In [3]:
# Define a MobileNetV2-based CNN model
base_model = tf.keras.applications.MobileNetV2(input_shape=(224, 224, 3), include_top=False, weights='imagenet')
base_model.trainable = False

model = tf.keras.Sequential([
    base_model,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(num_classes, activation='softmax')
])

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# Train the baseline model
model.fit(train_data, validation_data=val_data, epochs=5)


Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_224_no_top.h5
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<tf_keras.src.callbacks.History at 0x7ea138e54dd0>

In [4]:
# Apply random pruning
def apply_random_pruning(model):
    prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

    pruning_params = {
        'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(
            initial_sparsity=0.2,
            final_sparsity=0.5,
            begin_step=0,
            end_step=1000
        )
    }

    pruned_model = prune_low_magnitude(model, **pruning_params)
    return pruned_model

# Apply pruning and recompile
pruned_model = apply_random_pruning(model)
pruned_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# Add the UpdatePruningStep callback
callbacks = [tfmot.sparsity.keras.UpdatePruningStep()]

# Fine-tune pruned model
pruned_model.fit(train_data, validation_data=val_data, epochs=5, callbacks=callbacks)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<tf_keras.src.callbacks.History at 0x7ea13957dd90>

In [5]:
# Evaluate pruned model
_, pruned_acc = pruned_model.evaluate(val_data)
print(f"Pruned Model Accuracy: {pruned_acc:.4f}")

Pruned Model Accuracy: 0.6775
