In [None]:
import numpy as np
import tensorflow as tf

# Restore the previously trained emotion-detection Keras model from its HDF5 file.
model = tf.keras.models.load_model(filepath='/mnt/data/emotion_detection_model.h5')

In [None]:
import tensorflow_model_optimization as tfmot

# Magnitude-pruning configuration: target sparsity goes from 50% at step 0 to
# 80% at step 2000, following TF-MOT's PolynomialDecay schedule.
# NOTE(review): end_step is a hard-coded optimizer-step count. If the training
# cell runs fewer than 2000 steps in total, the 80% target is never reached —
# confirm against (batches per epoch) * epochs.
pruning_params = dict(
    pruning_schedule=tfmot.sparsity.keras.PolynomialDecay(
        initial_sparsity=0.50,
        final_sparsity=0.80,
        begin_step=0,
        end_step=2000,
    ),
)

# Wrap the loaded model for low-magnitude pruning using the schedule above.
pruned_model = tfmot.sparsity.keras.prune_low_magnitude(model, **pruning_params)

# Compile for fine-tuning. The sparse loss expects integer class-id labels.
pruned_model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

In [None]:
# Fine-tune the pruning-wrapped model so the surviving weights recover accuracy.
# UpdatePruningStep is mandatory: it advances the step counter that the pruning
# schedule reads, and the prune_low_magnitude wrapper refuses to train without it.
# NOTE(review): train_dataset / validation_dataset are not defined in any cell
# visible here — define them above this cell so Restart & Run All works.
pruned_model.fit(
    train_dataset,
    epochs=2,
    validation_data=validation_dataset,
    callbacks=[tfmot.sparsity.keras.UpdatePruningStep()],
)

# Strip the pruning wrappers so the exported model is a plain Keras model
# (the pruned, zeroed weights are kept).
model_for_export = tfmot.sparsity.keras.strip_pruning(pruned_model)

# Save the pruned model
model_for_export.save('/mnt/data/pruned_model.h5')

In [None]:
# Convert the stripped, pruned Keras model to TensorFlow Lite. Optimize.DEFAULT
# with no representative dataset requests post-training (dynamic-range)
# quantization, per the TFLite post-training quantization guide.
converter = tf.lite.TFLiteConverter.from_keras_model(model_for_export)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
quantized_tflite_model = converter.convert()

# Persist the serialized TFLite model bytes.
with open('/mnt/data/quantized_model.tflite', mode='wb') as tflite_file:
    tflite_file.write(quantized_tflite_model)

In [None]:
# Open the quantized .tflite file written in the previous cell and allocate its tensors.
interpreter = tf.lite.Interpreter(
    model_path='/mnt/data/quantized_model.tflite',
)
interpreter.allocate_tensors()

# Metadata dicts for the model's input and output tensors; each entry's
# 'index' addresses the tensor in set_tensor / get_tensor.
input_details, output_details = (
    interpreter.get_input_details(),
    interpreter.get_output_details(),
)

def evaluate_tflite_model(interpreter, test_dataset):
    """Return the top-1 accuracy of a TFLite classifier over a labelled dataset.

    Args:
        interpreter: A ``tf.lite.Interpreter`` whose tensors have been allocated.
            Its first input receives the images; its first output is read as
            per-class scores of shape ``(batch, num_classes)``.
        test_dataset: An iterable of ``(images, labels)`` batches. ``images``
            must be convertible to a NumPy array; ``labels`` holds integer
            class ids, shaped ``(batch,)`` or ``(batch, 1)``.

    Returns:
        The fraction of samples whose argmax prediction equals the label.

    Raises:
        ValueError: If ``test_dataset`` yields no samples.
    """
    # Read tensor metadata from the interpreter passed in, not notebook globals.
    input_detail = interpreter.get_input_details()[0]
    output_index = interpreter.get_output_details()[0]['index']

    total, correct = 0, 0
    for images, labels in test_dataset:
        # set_tensor requires the exact dtype the interpreter declares.
        # NOTE: a plain cast — a fully integer-quantized input would also need
        # the scale/zero-point from input_detail['quantization'].
        batch = np.asarray(images, dtype=input_detail['dtype'])

        # The input is allocated with a fixed shape (Keras conversions are
        # typically batch 1); resize whenever this batch differs, including a
        # shorter final batch.
        if tuple(input_detail['shape']) != batch.shape:
            interpreter.resize_tensor_input(input_detail['index'], list(batch.shape))
            interpreter.allocate_tensors()
            input_detail = interpreter.get_input_details()[0]

        interpreter.set_tensor(input_detail['index'], batch)
        interpreter.invoke()
        predictions = interpreter.get_tensor(output_index)

        # Flatten so (batch, 1) labels don't broadcast into a (batch, batch) compare.
        label_ids = np.asarray(labels).reshape(-1)
        total += label_ids.shape[0]
        correct += int(np.count_nonzero(np.argmax(predictions, axis=1) == label_ids))

    if total == 0:
        raise ValueError('test_dataset yielded no samples; accuracy is undefined')
    return correct / total

# Score the quantized model on the test split and report it as a percentage.
# NOTE(review): test_dataset is not defined in any cell visible here — define
# it earlier so the notebook survives Restart & Run All.
accuracy = evaluate_tflite_model(
    interpreter=interpreter,
    test_dataset=test_dataset,
)
print(f'Accuracy of the pruned and quantized model: {accuracy * 100:.2f}%')