<a href="https://colab.research.google.com/github/amelft81/EmbeddedAI/blob/main/Python_Code_for_Model_Pruning_(Revised_with_Specific_Architecture_Load).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
import gzip
import os
import re

import numpy as np
import tensorflow as tf
import tensorflow_model_optimization as tfmot

# --- Configuration ---
# Path to the Keras .h5 model to prune. The default is the Colab upload location
# (/content/); set the SIMPLE_MODEL_PATH environment variable to use another file
# without editing this cell.
SIMPLE_MODEL_PATH = os.environ.get('SIMPLE_MODEL_PATH', '/content/simple_embedded_model.h5')

# Output directory for optimized models, relative to the current working directory.
# Created up front so the export step does not fail on a missing folder.
OUTPUT_DIR = 'optimized_models'
os.makedirs(OUTPUT_DIR, exist_ok=True)

def _last_layer_activation_is(model, activation_fn):
    """Return True when the model's final layer uses `activation_fn` as its activation.

    Layers without an ``activation`` attribute yield False.
    """
    return getattr(model.layers[-1], 'activation', None) == activation_fn


def _load_model_with_fallback(model_path):
    """Load a Keras model from `model_path`, falling back to a weights-only load.

    First tries ``tf.keras.models.load_model``. If that raises, rebuilds a hard-coded
    InputLayer -> Dense(8, relu) -> Dense(1, sigmoid) model and calls ``load_weights``
    on it. The fallback architecture is an assumption about simple_embedded_model.h5.

    Returns:
        The loaded ``tf.keras.Model``, or ``None`` if both strategies fail.
    """
    try:
        model = tf.keras.models.load_model(model_path)
        print(f"Original model '{os.path.basename(model_path)}' loaded successfully.")
        model.summary()
        return model
    except Exception as e:
        print(f"ERROR: Could not load model from '{model_path}'.")
        print(f"Details: {e}")
        print("Attempting to load with a specific architecture for simple_embedded_model.h5...")
        try:
            # Default input width, used when the error text does not reveal it.
            input_shape_for_dummy = (10,)

            # The failing InputLayer config may carry "'batch_shape': [None, N]"; if the
            # error message contains it, recover the feature count N from there.
            match = re.search(r"'batch_shape': \[None, (\d+)\]", str(e))
            if match:
                input_shape_for_dummy = (int(match.group(1)),)
                print(f"Inferred input shape from error: {input_shape_for_dummy}")
            else:
                print(f"Could not infer input shape from error. Using default: {input_shape_for_dummy}")

            # Assumed architecture: an InputLayer followed by TWO Dense layers. For a
            # 10-feature input this is 10*8+8 + 8*1+1 = 97 trainable parameters. If
            # load_weights() fails, the real architecture differs and must be supplied.
            model = tf.keras.Sequential([
                tf.keras.layers.InputLayer(input_shape=input_shape_for_dummy),
                tf.keras.layers.Dense(8, activation='relu'),
                tf.keras.layers.Dense(1, activation='sigmoid'),
            ])

            model.load_weights(model_path)
            print(f"Model architecture defined and weights loaded from '{os.path.basename(model_path)}'. Model summary:")
            model.summary()
            return model
        except Exception as load_weights_e:
            print(f"CRITICAL ERROR: Failed to load model even with specific architecture attempt: {load_weights_e}")
            print("This indicates the assumed architecture might be incorrect or another issue.")
            print("Please provide the exact Keras architecture of your 'simple_embedded_model.h5' if this persists.")
            print("Exiting pruning process for this model.")
            return None


def _make_dummy_training_data(model, num_samples, rng):
    """Build random placeholder training data matching the model's input/output.

    REPLACE the returned arrays with real training data for meaningful fine-tuning.

    Args:
        model: A single-input, single-output Keras model.
        num_samples (int): Number of rows to generate.
        rng (np.random.Generator): Source of randomness.

    Returns:
        tuple: ``(X_train, y_train, loss_function, metrics)``. Multi-unit outputs get
        sparse-categorical labels; a single sigmoid unit gets binary labels;
        anything else is treated as regression with an MSE loss.
    """
    input_shape = model.input_shape[1:]  # drop the batch dimension
    output_shape = model.output_shape
    # The last output axis is the class/unit count; an unknown (None) width counts as 1.
    num_classes = (output_shape[-1] or 1) if len(output_shape) > 1 else 1

    X_train = rng.random((num_samples, *input_shape), dtype=np.float32)

    if num_classes > 1:
        y_train = rng.integers(0, num_classes, num_samples)
        # The outputs are only logits when the final layer does not already apply softmax.
        from_logits = not _last_layer_activation_is(model, tf.keras.activations.softmax)
        loss_function = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=from_logits)
        metrics = ['accuracy']
    elif _last_layer_activation_is(model, tf.keras.activations.sigmoid):
        y_train = rng.integers(0, 2, num_samples).astype(np.float32)
        loss_function = tf.keras.losses.BinaryCrossentropy(from_logits=False)
        metrics = ['accuracy']
    else:
        y_train = rng.random((num_samples, num_classes), dtype=np.float32)
        loss_function = tf.keras.losses.MeanSquaredError()
        metrics = ['mse']
    return X_train, y_train, loss_function, metrics


def prune_and_save_model(model_path, target_sparsity=0.75, epochs=10, batch_size=32,
                         target_size_to_beat_kb=15407.12, seed=None):
    """
    Loads a Keras model, applies pruning, retrains it, and saves the pruned TFLite model.

    Args:
        model_path (str): Path to the original .h5 model file.
        target_sparsity (float): The final target sparsity (e.g., 0.75 for 75% sparse weights).
        epochs (int): Number of epochs to fine-tune the pruned model. Higher sparsity usually
                      needs more epochs to recover accuracy.
        batch_size (int): Batch size for fine-tuning.
        target_size_to_beat_kb (float): Reference size in KB that the uncompressed .tflite
                      file is compared against. The default is a size from an earlier run
                      (its origin is not recorded here).
        seed (int | None): Seed for the dummy training data; None gives fresh random data.

    Returns:
        str | None: Path of the saved .tflite file, or None if the model could not be loaded.
    """
    print(f"\n--- Starting Pruning for {os.path.basename(model_path)} ---")

    # 1. Load the original model (with a weights-only fallback).
    model = _load_model_with_fallback(model_path)
    if model is None:
        return None

    # 2. Dummy data for demonstration (REPLACE THIS WITH YOUR ACTUAL TRAINING DATA).
    rng = np.random.default_rng(seed)
    num_samples = 1000
    X_train, y_train, loss_function, metrics = _make_dummy_training_data(model, num_samples, rng)
    print(f"Dummy training data created with shape X_train: {X_train.shape}, y_train: {y_train.shape}")

    # 3. Pruning schedule. Keras runs ceil(num_samples / batch_size) steps per epoch
    # (the last batch may be partial), so use ceil. Final sparsity is then reached
    # on the last training step instead of partway through the last epoch.
    steps_per_epoch = int(np.ceil(len(X_train) / batch_size))
    end_step = steps_per_epoch * epochs

    pruning_schedule = tfmot.sparsity.keras.PolynomialDecay(
        initial_sparsity=0.0,
        final_sparsity=target_sparsity,
        begin_step=0,
        end_step=end_step
    )

    # 4. Wrap prunable layers with pruning wrappers.
    pruned_model = tfmot.sparsity.keras.prune_low_magnitude(model, pruning_schedule=pruning_schedule)

    # 5. Recompile: the wrapped model is a new model and needs its own compile.
    pruned_model.compile(optimizer='adam', loss=loss_function, metrics=metrics)
    print("Pruned model compiled.")
    pruned_model.summary()

    # 6. Fine-tune. UpdatePruningStep advances the schedule's step counter each batch.
    print(f"\nTraining pruned model with target sparsity {target_sparsity*100:.0f}% over {epochs} epochs...")
    pruned_model.fit(
        X_train, y_train,
        epochs=epochs,
        batch_size=batch_size,
        callbacks=[tfmot.sparsity.keras.UpdatePruningStep()],
        verbose=1
    )
    print("Pruned model training complete.")

    # 7. Strip the pruning wrappers so the exported model contains plain layers.
    model_for_export = tfmot.sparsity.keras.strip_pruning(pruned_model)
    print("Pruning wrappers stripped.")

    # 8. Convert to TFLite.
    converter = tf.lite.TFLiteConverter.from_keras_model(model_for_export)
    tflite_model = converter.convert()

    # 9. Save. round() rather than int(): e.g. 0.29 * 100 == 28.999..., which int()
    # would truncate to 28.
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    sparsity_pct = int(round(target_sparsity * 100))
    pruned_tflite_path = os.path.join(OUTPUT_DIR, f'pruned_model_sparsity_{sparsity_pct}.tflite')
    with open(pruned_tflite_path, 'wb') as f:
        f.write(tflite_model)

    pruned_tflite_size_kb = os.path.getsize(pruned_tflite_path) / 1024
    # Per the TF-MOT pruning guide, pruning's size benefit only appears after
    # compression, so also report the gzip-compressed size of the same bytes.
    gzipped_size_kb = len(gzip.compress(tflite_model)) / 1024
    print(f"Pruned TFLite model saved to: {pruned_tflite_path}")
    print(f"Pruned TFLite model size: {pruned_tflite_size_kb:.2f} KB")
    print(f"Pruned TFLite model size (gzip-compressed): {gzipped_size_kb:.2f} KB")

    # Compare the uncompressed size with the reference size.
    if pruned_tflite_size_kb < target_size_to_beat_kb:
        print(f"SUCCESS: Pruned model size ({pruned_tflite_size_kb:.2f} KB) is LESS than the previous pruned size ({target_size_to_beat_kb:.2f} KB).")
    else:
        print(f"NOTE: Pruned model size ({pruned_tflite_size_kb:.2f} KB) is NOT less than the previous pruned size ({target_size_to_beat_kb:.2f} KB).")
        print("Consider increasing 'target_sparsity' or 'epochs' for more aggressive pruning.")
        print("Note: sparsity alone does not shrink the uncompressed .tflite file; compare the gzip-compressed size or enable converter optimizations (quantization).")

    return pruned_tflite_path

if __name__ == "__main__":
    # Prune simple_embedded_model.h5 aggressively: a high final sparsity (85%)
    # paired with a longer fine-tuning run (20 epochs) to help recover accuracy.
    prune_and_save_model(
        SIMPLE_MODEL_PATH,
        target_sparsity=0.85,
        epochs=20,
    )

    # Optional, much heavier run: MobileNetV2 takes far longer and needs more resources,
    # so start it at a lower sparsity and fewer epochs.
    # MOBILENET_V2_MODEL_PATH = 'mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_224 (1).h5'
    # prune_and_save_model(MOBILENET_V2_MODEL_PATH, target_sparsity=0.5, epochs=5)


--- Starting Pruning for simple_embedded_model.h5 ---
ERROR: Could not load model from '/content/simple_embedded_model.h5'.
Details: Error when deserializing class 'InputLayer' using config={'batch_shape': [None, 10], 'dtype': 'float32', 'sparse': False, 'name': 'input_layer'}.

Exception encountered: Unrecognized keyword arguments: ['batch_shape']
Attempting to load with a specific architecture for simple_embedded_model.h5...
Inferred input shape from error: (10,)
Model architecture defined and weights loaded from 'simple_embedded_model.h5'. Model summary:
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense (Dense)               (None, 8)                 88        
                                                                 
 dense_1 (Dense)             (None, 1)                 9         
                                                                 
Total params: 97 (3

In [2]:
!pip install tensorflow-model-optimization

Collecting tensorflow-model-optimization
  Downloading tensorflow_model_optimization-0.8.0-py2.py3-none-any.whl.metadata (904 bytes)
Collecting numpy~=1.23 (from tensorflow-model-optimization)
  Downloading numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.0/61.0 kB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0m
Downloading tensorflow_model_optimization-0.8.0-py2.py3-none-any.whl (242 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m242.5/242.5 kB[0m [31m10.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m18.3/18.3 MB[0m [31m80.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: numpy, tensorflow-model-optimization
  Attempting uninstall: numpy
    Found existing installation: numpy 2.0.2
    Uninstalling 