In [3]:
import tensorflow as tf
import tensorflow_model_optimization as tfmot
from tensorflow.keras.applications import VGG16
from tensorflow.keras import layers, models
import numpy as np

# Step 1: Load the ImageNet-pretrained VGG16 convolutional backbone.
# include_top=False drops VGG16's own classifier, so the head is rebuilt below.
base_model = VGG16(
    include_top=False,
    weights='imagenet',
    input_shape=(224, 224, 3),
)

# Step 2: Stack a fresh classification head on top of the backbone.
# The head is two 4096-unit ReLU layers, each followed by 50% dropout, and a
# 1000-way softmax (assuming ImageNet with 1000 classes).
model = models.Sequential()
model.add(base_model)
model.add(layers.Flatten())
model.add(layers.Dense(4096, activation='relu'))
model.add(layers.Dropout(0.5))
model.add(layers.Dense(4096, activation='relu'))
model.add(layers.Dropout(0.5))
model.add(layers.Dense(1000, activation='softmax'))

# Step 3: Apply Pruning to the Model with High Sparsity (90%)
# Sizes of the dummy fine-tuning run. They are declared before the schedule
# because the schedule has to be sized to the number of optimizer steps that
# fine-tuning will actually execute.
NUM_CLASSES = 1000   # must equal the width of the model's final softmax layer
NUM_SAMPLES = 10
BATCH_SIZE = 5
EPOCHS = 2
SEED = 42            # makes the dummy inputs/labels repeatable across runs

steps_per_epoch = int(np.ceil(NUM_SAMPLES / BATCH_SIZE))
total_train_steps = steps_per_epoch * EPOCHS

pruning_schedule = tfmot.sparsity.keras.PolynomialDecay(
    initial_sparsity=0.30,  # Start pruning at 30% sparsity
    final_sparsity=0.90,    # Ramp towards 90% sparsity for a smaller model
    begin_step=0,
    # Tie the end of the ramp to the real training length. A fixed end_step far
    # beyond the number of executed steps leaves the model close to
    # initial_sparsity instead of final_sparsity.
    end_step=total_train_steps,
    # Update the masks on every step: this demo runs only a handful of steps,
    # so a coarse update interval (tfmot's documented default is 100 steps)
    # would fire just once. Use a larger value for real training runs.
    frequency=1,
)

# Wrap the model with pruning
pruned_model = tfmot.sparsity.keras.prune_low_magnitude(
    model,
    pruning_schedule=pruning_schedule
)

# Compile the Pruned Model
pruned_model.compile(optimizer='adam',
                     loss='categorical_crossentropy',
                     metrics=['accuracy'])

# Step 4: Fine-Tune the Pruned Model (Using Dummy Data for Demo)
rng = np.random.default_rng(SEED)
dummy_data = rng.random((NUM_SAMPLES, 224, 224, 3), dtype=np.float32)

# categorical_crossentropy expects one-hot targets: exactly one 1 per row.
# Sample one class id per example and expand it to a (NUM_SAMPLES, NUM_CLASSES)
# one-hot matrix by indexing the identity matrix.
class_ids = rng.integers(0, NUM_CLASSES, size=NUM_SAMPLES)
dummy_labels = np.eye(NUM_CLASSES, dtype=np.float32)[class_ids]

# Define Pruning Callbacks
callbacks = [
    tfmot.sparsity.keras.UpdatePruningStep(),  # Required for pruning
    tfmot.sparsity.keras.PruningSummaries(log_dir='./pruning_logs')  # Optional: Logs pruning stats
]

# Train the Model with Pruning Callbacks
pruned_model.fit(
    dummy_data,
    dummy_labels,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    callbacks=callbacks,
)

# Step 5: Strip the Pruning Wrappers for TensorFlow Lite Conversion
stripped_model = tfmot.sparsity.keras.strip_pruning(pruned_model)

# Step 6: Weight clustering for further size reduction.
# Each clustered layer's weights are restricted to 8 shared centroid values,
# with the centroids initialised by k-means++.
# NOTE(review): plain cluster_weights is not told about the pruning masks, so
# the zeros produced in Step 3 may not survive this stage. Confirm the final
# sparsity, or consider tfmot's sparsity-preserving clustering variant.
clustering_params = dict(
    number_of_clusters=8,
    cluster_centroids_init=tfmot.clustering.keras.CentroidInitialization.KMEANS_PLUS_PLUS,
)

# Wrap the pruning-stripped model with clustering wrappers.
clustered_model = tfmot.clustering.keras.cluster_weights(
    stripped_model,
    **clustering_params
)

# The wrapped model is a new Keras model, so it must be compiled again.
clustered_model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)

# Fine-tune so the centroids can adapt (same dummy data as the pruning stage).
clustered_model.fit(
    dummy_data,
    dummy_labels,
    batch_size=5,
    epochs=2,
)

# Step 7: Remove the clustering wrappers before conversion.
final_model = tfmot.clustering.keras.strip_clustering(clustered_model)

# Step 8: Define a Representative Dataset for Full Integer Quantization
def representative_data_gen():
    """Yield calibration batches for post-training integer quantization.

    Produces 100 items. Each item is a one-element list holding a float32
    array of shape (1, 224, 224, 3) filled with uniform random values drawn
    from NumPy's global random state.

    The random values are placeholders: replace them with real samples from
    the dataset for accurate calibration.
    """
    sample_shape = (1, 224, 224, 3)
    remaining = 100
    while remaining > 0:
        calibration_batch = np.random.random_sample(sample_shape)
        yield [calibration_batch.astype(np.float32)]
        remaining -= 1

# Step 9: Convert to a full-integer quantized TFLite model.
converter = tf.lite.TFLiteConverter.from_keras_model(final_model)

# Calibration: default optimizations plus a representative dataset.
converter.representative_dataset = representative_data_gen
converter.optimizations = [tf.lite.Optimize.DEFAULT]

# Restrict the converter to int8 builtin ops and use int8 for both the model
# input and the model output tensors.
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = converter.inference_output_type = tf.int8

# Run the conversion; the result is the serialized model that is written out
# below.
quantized_tflite_model = converter.convert()

# Step 10: Write the converted model to a .tflite file in the working directory.
with open("optimized_vgg16.tflite", "wb") as tflite_file:
    tflite_file.write(quantized_tflite_model)

print("Fully optimized and quantized VGG16 model has been saved as 'optimized_vgg16.tflite'.")

Epoch 1/2
Epoch 2/2
INFO:tensorflow:Assets written to: /var/folders/hb/fpmv_swd6pb2nt84zqhq62fh0000gq/T/tmpl8m51r1u/assets


INFO:tensorflow:Assets written to: /var/folders/hb/fpmv_swd6pb2nt84zqhq62fh0000gq/T/tmpl8m51r1u/assets
2024-10-09 21:22:50.970916: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:364] Ignored output_format.
2024-10-09 21:22:50.970946: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:367] Ignored drop_control_dependency.
2024-10-09 21:22:50.971591: I tensorflow/cc/saved_model/reader.cc:45] Reading SavedModel from: /var/folders/hb/fpmv_swd6pb2nt84zqhq62fh0000gq/T/tmpl8m51r1u
2024-10-09 21:22:50.974013: I tensorflow/cc/saved_model/reader.cc:91] Reading meta graph with tags { serve }
2024-10-09 21:22:50.974020: I tensorflow/cc/saved_model/reader.cc:132] Reading SavedModel debug info (if present) from: /var/folders/hb/fpmv_swd6pb2nt84zqhq62fh0000gq/T/tmpl8m51r1u
2024-10-09 21:22:50.978793: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:375] MLIR V1 optimization pass is not enabled
2024-10-09 21:22:50.981153: I tensorflow/cc/saved_model/load

Pruned and Quantized VGG16 model has been saved as 'pruned_quantized_vgg16.tflite'.
