In [None]:
!pip install matplotlib tensorflow_model_optimization
import os
import tempfile

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_model_optimization as tfmot
from tensorflow.keras.datasets import cifar100
from tqdm import tqdm

Using TensorFlow backend


In [None]:
# Experiment configuration.
num_classes = 100          # CIFAR-100 label count
n_images = 50_000          # size of the training split (the test split has 10 000 images)
taux_validation = 0.1      # fraction of the training split held out for validation
batch_size = 100
# Upper bound on epochs; the original note says early stopping is meant to end training
# sooner. NOTE(review): no EarlyStopping callback is registered in this notebook.
n_epoch = 1_000

In [None]:
!mc cp s3/afeldmann/modele_enseignant.keras ~/modele_enseignant.keras
modele_enseignant = tf.keras.models.load_model("~/modele_enseignant.keras")
!mc cp s3/afeldmann/modele_eleve_2.keras ~/modele_eleve_2.keras
modele_eleve_2 = tf.keras.models.load_model("~/modele_eleve_2.keras")

In [None]:
def format_image(image, label):
    """Preprocess one (image, label) example for evaluation.

    The image is resized to 224x224 and divided by 255. The label is one-hot encoded
    over `num_classes` classes, and its leading axis is squeezed away.

    Args:
        image: a single, unbatched image tensor.
        label: the integer label. Presumably shape (1,), as Keras' CIFAR-100 labels are
            per example; squeeze(axis=0) requires that leading size-1 axis.

    Returns:
        (resized and scaled image, one-hot label vector).
    """
    resized = tf.image.resize(image, (224, 224))
    one_hot = tf.one_hot(label, depth = num_classes)
    return resized / 255.0, tf.squeeze(one_hot, axis = 0)

def distillation_hors_ligne(image, label):
    """Replace an example's ground-truth label by the teacher model's output for it.

    The image first goes through `format_image`, the same preprocessing applied to the
    test set. The teacher therefore labels, and the student trains on, images prepared
    the same way as at evaluation time. The original passed the unprocessed image to
    both.

    Args:
        image: a single, unbatched image tensor, as yielded by `from_tensor_slices` over
            the CIFAR-100 arrays.
        label: the ground-truth label; only used to satisfy `format_image`'s signature.

    Returns:
        (preprocessed image, teacher output vector for that image).
    """
    image, _ = format_image(image, label)
    # Model.predict() cannot be used inside a tf.data map function: the function is traced
    # into a graph, while predict() expects concrete data. Call the model directly in
    # inference mode on a batch of one, then drop the batch axis.
    label = modele_enseignant(tf.expand_dims(image, axis=0), training=False)[0]
    return  image, label

train_arrays, test_arrays = cifar100.load_data()

validation_size = int(n_images * taux_validation)

# Unprocessed (image, integer label) examples.
raw_train_dataset = tf.data.Dataset.from_tensor_slices(train_arrays)
raw_test = tf.data.Dataset.from_tensor_slices(test_arrays)

# Fix the train/validation split ONCE. With shuffle()'s default
# reshuffle_each_iteration=True, take()/skip() draw a different subset every epoch, so
# validation images leak into training. The seed makes the split reproducible.
raw_train_shuffled = raw_train_dataset.shuffle(n_images, seed=0, reshuffle_each_iteration=False)
raw_validation = raw_train_shuffled.take(validation_size)
raw_train = raw_train_shuffled.skip(validation_size)

# Reshuffle the training split every epoch while its elements are still the small raw
# images. The original shuffled after the map, so the buffer held n_images preprocessed
# 224x224 float images. Labels are then replaced by the teacher's outputs.
train_dataset = (
    raw_train
    .shuffle(n_images - validation_size)
    .map(distillation_hors_ligne)
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)
validation_dataset = (
    raw_validation
    .map(distillation_hors_ligne)
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)

# Test set with ground-truth one-hot labels; left unbatched, as before.
test_dataset = raw_test.map(format_image)

In [None]:
def apply_pruning_to_dense(layer):
  """clone_model() clone function: wrap Dense layers for magnitude pruning.

  Any layer that is not a Dense layer is returned unchanged (the same object).
  """
  if not isinstance(layer, tf.keras.layers.Dense):
    return layer
  return tfmot.sparsity.keras.prune_low_magnitude(layer)
    
# clone_model() uses, in the copy, whatever layer instance the clone function returns.
# apply_pruning_to_dense returns non-Dense layers as-is and wraps Dense layers without
# copying them. Cloning modele_eleve_2 directly would therefore make the pruned model share
# its weights with modele_eleve_2: fine-tuning and pruning would silently modify the
# baseline too, and the later baseline evaluation would no longer measure the unpruned
# student. First clone into fresh layers with the same weights.
modele_eleve_2_copie = tf.keras.models.clone_model(modele_eleve_2)
modele_eleve_2_copie.set_weights(modele_eleve_2.get_weights())

modele_eleve_2_pruning = tf.keras.models.clone_model(
    modele_eleve_2_copie,
    clone_function=apply_pruning_to_dense
)



In [None]:
# UpdatePruningStep is required when fitting a model that contains pruning wrappers.
callbacks = [
    tfmot.sparsity.keras.UpdatePruningStep()
]

# The training labels are the teacher's outputs, compared with the student's via KL.
# NOTE(review): this assumes the student's last layer outputs probabilities (e.g. softmax)
# rather than logits; confirm.
modele_eleve_2_pruning.compile(optimizer='adam', loss='kl_divergence', metrics=['accuracy'])

# Short fine-tuning of the pruned student. validation_dataset was built in the data cell
# but never used; pass it so validation metrics are reported each epoch.
# NOTE(review): the config's n_epoch / early stopping is not used for this fine-tuning.
n_epoch_pruning = 2
historique_pruning = modele_eleve_2_pruning.fit(
    train_dataset,
    validation_data=validation_dataset,
    callbacks=callbacks,
    epochs=n_epoch_pruning,
)

Epoch 1/2
Epoch 2/2


<keras.src.callbacks.History at 0x7a408c100cd0>

In [None]:
# Test-set loss and accuracy of the pruned student. X_test / y_test were never defined in
# this notebook. Use test_dataset instead: it is unbatched, so batch it before evaluate().
modele_eleve_2_pruning.evaluate(test_dataset.batch(batch_size))



[2.223886489868164, 0.41589999198913574]

In [None]:
# Test-set loss and accuracy of the unpruned student, for comparison. `test` was never
# defined. Use the unbatched test_dataset, batched for evaluate().
# NOTE(review): relies on modele_eleve_2 having been saved in a compiled state; confirm.
modele_eleve_2.evaluate(test_dataset.batch(batch_size))



[2.223886489868164, 0.41589999198913574]

In [None]:
# Remove the pruning wrappers (keeping the pruned weights) before export.
model_for_export = tfmot.sparsity.keras.strip_pruning(modele_eleve_2_pruning)

# mkstemp() returns an OS-level file descriptor that is already open; close it so it does
# not leak. save_model() reopens the file by path.
pruned_keras_fd, pruned_keras_file = tempfile.mkstemp('.h5')
os.close(pruned_keras_fd)
tf.keras.models.save_model(model_for_export, pruned_keras_file, include_optimizer=False)
print('Saved pruned Keras model to:', pruned_keras_file)

  tf.keras.models.save_model(model_for_export, pruned_keras_file, include_optimizer=False)


Saved pruned Keras model to: /tmp/tmpfu09hwmk.h5


In [None]:
# Convert the stripped model to TFLite with the default optimizations.
converter = tf.lite.TFLiteConverter.from_keras_model(model_for_export)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
pruned_tflite_model = converter.convert()

# open() does not expand "~" (only the shell does). Without expanduser the write fails
# unless a directory literally named "~" exists in the working directory.
pruned_tflite_file = os.path.expanduser("~/modele_eleve_2_pruning.tflite")
with open(pruned_tflite_file, 'wb') as f:
  f.write(pruned_tflite_model)

print('Saved pruned TFLite model to:', pruned_tflite_file)


Saved pruned TFLite model to: /content/drive/MyDrive/projet_cnam/modele_prune.tflite


In [None]:
def format_image_raw(image, label):
    """Resize an image to 224x224 and divide it by 255; the label passes through as-is.

    Unlike format_image, the label is not one-hot encoded, so it can be compared
    directly with an argmax prediction.
    """
    return tf.image.resize(image, (224, 224)) / 255.0, label

# Number of test images scored one at a time with the TFLite interpreter.
n_tflite_eval = 1000

# Raw CIFAR-100 test set with integer labels. `raw_test` was not defined anywhere in the
# notebook, so build it here from the Keras dataset.
_, (x_test_raw, y_test_raw) = cifar100.load_data()
raw_test = tf.data.Dataset.from_tensor_slices((x_test_raw, y_test_raw))
test_batches = raw_test.map(format_image_raw).batch(1)

# The Interpreter does not expand "~". Expanding here also keeps an already-expanded
# path unchanged.
interpreter = tf.lite.Interpreter(model_path=os.path.expanduser(pruned_tflite_file))
interpreter.allocate_tensors()

input_index = interpreter.get_input_details()[0]["index"]
output_index = interpreter.get_output_details()[0]["index"]

# Only the predictions and labels are kept. The original also stored every 224x224
# float32 input image in a list that was never used (about 600 MB for 1000 images).
predictions, test_labels = [], []
for img, label in tqdm(test_batches.take(n_tflite_eval)):
    interpreter.set_tensor(input_index, img.numpy())
    interpreter.invoke()
    predictions.append(interpreter.get_tensor(output_index))
    # The label arrives batched, with a leading batch axis of 1; reduce it to a plain
    # int class index.
    test_labels.append(int(label.numpy().ravel()[0]))

# Each prediction comes from a batch of one, so the flattened argmax is the class index.
predicted_classes = np.array([np.argmax(p) for p in predictions], dtype=np.int64)
score = int(np.sum(predicted_classes == np.array(test_labels, dtype=np.int64)))

print("Out of " + str(len(predictions)) + " predictions I got " + str(score) + " correct")

100%|██████████| 1000/1000 [00:08<00:00, 115.75it/s]

Out of 1000 predictions I got 421 correct



