In [1]:
!pip install tensorflow-model-optimization

import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_model_optimization as tfmot
import numpy as np

# Load the oxford_flowers102 dataset
dataset_name = "oxford_flowers102"
(ds_train, ds_test), ds_info = tfds.load(dataset_name, split=["train[:80%]", "train[80%:]"], as_supervised=True, with_info=True)

# Preprocessing function
IMG_SIZE = 224

def preprocess(image, label):
    image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE)) / 255.0
    return image, label

ds_train = ds_train.map(preprocess).batch(32).prefetch(tf.data.AUTOTUNE)
ds_test = ds_test.map(preprocess).batch(32).prefetch(tf.data.AUTOTUNE)

# Define a simple CNN model
base_model = tf.keras.applications.MobileNetV2(input_shape=(IMG_SIZE, IMG_SIZE, 3), include_top=False, weights='imagenet')
base_model.trainable = False  # Freeze base model

model = tf.keras.Sequential([
    base_model,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(128, activation="relu"),
    tf.keras.layers.Dense(ds_info.features["label"].num_classes, activation="softmax")
])

# Compile the model
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])

# Train the model
model.fit(ds_train, epochs=5, validation_data=ds_test)

# Apply Magnitude-Based Pruning
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

pruning_params = {
    "pruning_schedule": tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.2, final_sparsity=0.8, begin_step=0, end_step=1000),
}

pruned_model = prune_low_magnitude(model, **pruning_params)

# Compile the pruned model
pruned_model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])

# Train the pruned model for fine-tuning
pruned_model.fit(ds_train, epochs=3, validation_data=ds_test, callbacks=[tfmot.sparsity.keras.UpdatePruningStep()])

# Strip pruning for deployment
pruned_model = tfmot.sparsity.keras.strip_pruning(pruned_model)

# Compile the stripped model again before evaluation
pruned_model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"]) # Re-compile the stripped model

# Evaluate the pruned model
loss, acc = pruned_model.evaluate(ds_test)
print(f"Pruned model accuracy: {acc:.4f}")


Collecting tensorflow-model-optimization
  Downloading tensorflow_model_optimization-0.8.0-py2.py3-none-any.whl.metadata (904 bytes)
Downloading tensorflow_model_optimization-0.8.0-py2.py3-none-any.whl (242 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m242.5/242.5 kB[0m [31m10.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: tensorflow-model-optimization
Successfully installed tensorflow-model-optimization-0.8.0
Downloading and preparing dataset 328.90 MiB (download: 328.90 MiB, generated: 331.34 MiB, total: 660.25 MiB) to /root/tensorflow_datasets/oxford_flowers102/2.1.1...


Dl Completed...: 0 url [00:00, ? url/s]

Dl Size...: 0 MiB [00:00, ? MiB/s]

Extraction completed...: 0 file [00:00, ? file/s]

Generating splits...:   0%|          | 0/3 [00:00<?, ? splits/s]

Generating train examples...:   0%|          | 0/1020 [00:00<?, ? examples/s]

Shuffling /root/tensorflow_datasets/oxford_flowers102/incomplete.Y1XFS7_2.1.1/oxford_flowers102-train.tfrecord…

Generating test examples...:   0%|          | 0/6149 [00:00<?, ? examples/s]

Shuffling /root/tensorflow_datasets/oxford_flowers102/incomplete.Y1XFS7_2.1.1/oxford_flowers102-test.tfrecord*…

Generating validation examples...:   0%|          | 0/1020 [00:00<?, ? examples/s]

Shuffling /root/tensorflow_datasets/oxford_flowers102/incomplete.Y1XFS7_2.1.1/oxford_flowers102-validation.tfr…

Dataset oxford_flowers102 downloaded and prepared to /root/tensorflow_datasets/oxford_flowers102/2.1.1. Subsequent calls will reuse this data.
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_224_no_top.h5
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
Epoch 1/3
Epoch 2/3
Epoch 3/3
Pruned model accuracy: 0.6716
