In [None]:
import numpy as np
import os

In [None]:
# Pin the notebook to the first GPU and request deterministic TF ops. These are
# set before `import tensorflow` (next cell) so they are in place at TF start-up.
os.environ.update({
    "CUDA_VISIBLE_DEVICES": "0",
    'TF_DETERMINISTIC_OPS': '1',
})

In [None]:
import tensorflow as tf
from tensorflow.python.framework.ops import enable_eager_execution
# Public alias of the same function imported above from the private
# `tensorflow.python` path. Eager mode is already the TF2 default; this keeps the
# notebook explicit about requiring it.
tf.compat.v1.enable_eager_execution()

In [None]:
# Sanity check: which GPUs TensorFlow can see after CUDA_VISIBLE_DEVICES is applied.
tf.config.list_physical_devices(device_type='GPU')

In [None]:
import tensorflow_datasets as tfds

In [None]:
# Model input geometry: 224x224 images with 3 channels.
img_rows = img_cols = 224
num_channel = 3

# Load Dataset - Imagenet2012 validation set

In [None]:
def preprocess_image_mobilenet(features, image_size=(224, 224)):
    """Resize and normalize one TFDS example for MobileNet.

    Args:
        features: dict with an ``"image"`` tensor of arbitrary spatial size and a
            ``"label"`` entry, as produced by ``tfds.load`` in this notebook.
            It is not modified.
        image_size: ``(height, width)`` the image is resized to. Defaults to
            224x224, the input size of the MobileNet models built in this notebook.

    Returns:
        ``(image, label)``: the resized image after
        ``tf.keras.applications.mobilenet.preprocess_input``, and the unchanged label.
    """
    image = tf.image.resize(features["image"], image_size)
    image = tf.keras.applications.mobilenet.preprocess_input(image)
    return image, features["label"]

In [None]:
BATCH_SIZE: int = 32  # batch size shared by every tf.data pipeline in this notebook

In [None]:
# Training slice: the first 40% of the 50000 validation images (20000 examples).
tfds_dataset1, tfds_info = tfds.load(
    name='imagenet2012_subset',
    split='validation[:40%]',
    with_info=True,
    data_dir='../../datasets/ImageNet/',
)

In [None]:
# Evaluation slice: the last 60% of the 50000 validation images (30000 examples).
# It becomes `val_ds` below and is disjoint from the 40% training slice above.
tfds_dataset2, tfds_info = tfds.load(
    name='imagenet2012_subset',
    split='validation[-60%:]',
    with_info=True,
    data_dir='../../datasets/ImageNet/',
)

In [None]:
# Both slices share one pipeline: per-example preprocessing, batching, then
# prefetching one batch ahead.
train_ds = (
    tfds_dataset1
    .map(preprocess_image_mobilenet)
    .batch(BATCH_SIZE)
    .prefetch(1)
)
val_ds = (
    tfds_dataset2
    .map(preprocess_image_mobilenet)
    .batch(BATCH_SIZE)
    .prefetch(1)
)

In [None]:
num_images = tfds_info.splits['validation[:40%]'].num_examples  # size of the training slice
num_classes = tfds_info.features['label'].num_classes
print(num_images, num_classes, sep='\n')

In [None]:
# Sanity check: plot a grid of sample images from the training slice.
# NOTE(review): arguments are kept positional on purpose — older TFDS releases
# took (ds_info, ds) instead of (ds, ds_info); confirm against the installed version.
figs = tfds.show_examples(tfds_dataset1, tfds_info)

# Model Instability Evaluation Functions

In [1]:
def Instability(model, q_model, ds, batch_size=None):
    """Compare per-example correctness of two models over a labelled dataset.

    An example is "correct" for a model when the arg-max of the model's output
    equals its label. Each example is identified by its running position in
    ``ds``, so the two models' sets of correct examples can be compared.

    Args:
        model: reference model (the fp32 model in this notebook), called as
            ``model(images, training=False)``.
        q_model: model compared against ``model`` (quantized or distilled in
            this notebook), called the same way.
        ds: iterable of batches whose element 0 is the images and element 1 the
            integer class labels.
        batch_size: kept for backward compatibility only. Example indices are
            derived from the actual length of each batch, so a value that does
            not match the dataset's batching can no longer make indices collide.

    Returns:
        ``(q_correct, orig_correct, q_correct_orig_wrong, q_wrong_orig_correct)``:
        the number of examples ``q_model`` gets right, the number ``model`` gets
        right, the number only ``q_model`` gets right, and the number only
        ``model`` gets right.
    """
    accurate_pred = set()
    accurate_q_pred = set()
    offset = 0  # global index of the first example in the current batch
    for n, batch in enumerate(ds):
        images = batch[0]
        # reshape(-1) also accepts (batch, 1) label tensors.
        labels = np.asarray(batch[1]).reshape(-1)
        # Call the models directly: Model.predict is built for whole datasets and
        # carries significant per-call overhead when used on single batches.
        preds = np.argmax(np.asarray(model(images, training=False)), axis=-1)
        q_preds = np.argmax(np.asarray(q_model(images, training=False)), axis=-1)
        indices = np.arange(offset, offset + len(labels))
        accurate_pred.update(indices[preds == labels].tolist())
        accurate_q_pred.update(indices[q_preds == labels].tolist())
        offset += len(labels)
        if (n + 1) % 50 == 0:
            print("Finished %d examples" % offset)
    q_correct = len(accurate_q_pred)
    orig_correct = len(accurate_pred)
    q_correct_orig_wrong = len(accurate_q_pred.difference(accurate_pred))
    q_wrong_orig_correct = len(accurate_pred.difference(accurate_q_pred))
    return q_correct, orig_correct, q_correct_orig_wrong, q_wrong_orig_correct

# Build and Evaluate Models (original model = fp32 MobileNet, q_model = fake-quantized int8 MobileNet)

In [None]:
import tensorflow_model_optimization as tfmot

In [None]:
# fp32 MobileNet; its weights are replaced by the checkpoint loaded further down.
model = tf.keras.applications.MobileNet(
    input_shape=(img_rows, img_cols, num_channel),
)

In [None]:
# `learning_rate` is RMSprop's documented argument; `lr` is a deprecated alias
# that newer Keras releases no longer accept.
model.compile(optimizer=tf.keras.optimizers.RMSprop(learning_rate=2e-5),
              # from_logits=False: Keras' MobileNet ends in a softmax by default,
              # so the model emits probabilities.
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['accuracy'])

In [None]:
#model.fit(train_ds,epochs=10)

In [None]:
#model.save("../../weights/fp_model_40_mobilenet.h5")

In [None]:
# Restore the fine-tuned fp32 weights (training/saving cells above are disabled).
model.load_weights(filepath="../../weights/fp_model_40_mobilenet.h5")

In [None]:
model.evaluate(x=val_ds)  # fp32 loss/accuracy on the held-out 60% slice

In [None]:
# Wrap the fp32 model with fake-quantization ops (quantization-aware training).
q_model = tfmot.quantization.keras.quantize_model(model)
# `learning_rate` replaces the deprecated `lr` alias.
q_model.compile(optimizer=tf.keras.optimizers.RMSprop(learning_rate=2e-5),
                loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
                metrics=['accuracy'])

In [None]:
# import datetime
# log_dir = "./logs/fit/mobilenet/q_model" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
# tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)

In [None]:
# checkpoint_filepath = './tmp/checkpoint_q_model_40_mobilenet'
# model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
#     filepath=checkpoint_filepath,
#     save_weights_only=True,
#     monitor='val_accuracy',
#     mode='max',
#     save_best_only=True)
# q_model.fit(train_ds,
#           epochs=10,
#           validation_data= val_ds,
#           callbacks=[model_checkpoint_callback, tensorboard_callback])

In [None]:
#q_model.save("../../weights/q_model_40_mobilenet.h5")

In [None]:
# Alternative: restore the best checkpoint from the (disabled) QAT run instead.
#q_model.load_weights(checkpoint_filepath)
q_model.load_weights(filepath="../../weights/q_model_40_mobilenet.h5")

In [None]:
q_model.evaluate(x=val_ds)  # quantized-model loss/accuracy on the same held-out slice

In [None]:
# fp32 vs quantized: (q_correct, orig_correct, q_correct_orig_wrong, q_wrong_orig_correct)
Instability(model, q_model, val_ds, batch_size=BATCH_SIZE)

# Make Surrogate Model

In [None]:
from tensorflow.keras import Model

In [None]:
# Surrogate-model data: the full train split of imagenet2012_subset (1% of the
# ImageNet 2012 training set), preprocessed and batched like the other splits.
tfds_dataset3, tfds_info = tfds.load(
    name='imagenet2012_subset',
    split='train[:100%]',
    with_info=True,
    data_dir='../../datasets/ImageNet/',
)
train_ds_ = (
    tfds_dataset3
    .map(preprocess_image_mobilenet)
    .batch(BATCH_SIZE)
    .prefetch(1)
)

In [None]:
class Distiller(Model):
    """Knowledge-distillation model that trains ``student`` to imitate ``teacher``.

    ``fit`` updates only the student's trainable variables; the teacher is always
    run with ``training=False``. ``evaluate`` (and the validation pass of ``fit``)
    scores the student alone against the ground-truth labels.
    """

    def __init__(self, student, teacher):
        super(Distiller, self).__init__()
        self.teacher = teacher
        self.student = student

    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
        alpha=0.1,
        temperature=3,
        outputs_are_probabilities=False,
    ):
        """Configure the distiller.

        Args:
            optimizer: Keras optimizer for the student weights.
            metrics: Keras metrics, updated with the student's predictions.
            student_loss_fn: Loss between ground-truth labels and student
                predictions.
            distillation_loss_fn: Loss between the softened teacher and softened
                student distributions, called as
                ``distillation_loss_fn(teacher_soft, student_soft)``.
            alpha: Weight of ``student_loss_fn``; ``1 - alpha`` weights
                ``distillation_loss_fn``.
            temperature: Temperature for softening probability distributions.
                Larger temperature gives softer distributions.
            outputs_are_probabilities: Set True when teacher and student already
                end in a softmax. Their outputs are then moved to log space before
                temperature scaling, because dividing probabilities (rather than
                logits) by the temperature pushes both distributions towards
                uniform and leaves almost no distillation signal. Defaults to
                False (outputs treated as logits).
        """
        super(Distiller, self).compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature
        self.outputs_are_probabilities = outputs_are_probabilities

    def _soften(self, predictions):
        """Return the temperature-softened class distribution for a batch of outputs."""
        if self.outputs_are_probabilities:
            # log(p) equals the logits minus a per-row constant, which softmax
            # cancels, so softmax(log(p) / T) is the tempered distribution.
            # Clipping keeps log() finite when a probability is exactly zero.
            predictions = tf.math.log(
                tf.clip_by_value(predictions, tf.keras.backend.epsilon(), 1.0)
            )
        return tf.nn.softmax(predictions / self.temperature, axis=1)

    def _gather_metric_results(self):
        """Flatten the current results of ``self.metrics`` into one name -> value dict.

        Metrics whose ``result()`` is itself a dict are merged key by key, so no
        nested dict ends up in the returned logs.
        """
        results = {}
        for metric in self.metrics:
            result = metric.result()
            if isinstance(result, dict):
                results.update(result)
            else:
                results[metric.name] = result
        return results

    def train_step(self, data):
        """Run one distillation step on an ``(x, y)`` batch and return the logged metrics."""
        x, y = data

        # Teacher is inference-only: evaluated outside the tape with training=False.
        teacher_predictions = self.teacher(x, training=False)

        with tf.GradientTape() as tape:
            student_predictions = self.student(x, training=True)

            student_loss = self.student_loss_fn(y, student_predictions)
            distillation_loss = self.distillation_loss_fn(
                self._soften(teacher_predictions),
                self._soften(student_predictions),
            )
            loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss

        # Only the student is optimized.
        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        self.compiled_metrics.update_state(y, student_predictions)

        return_metrics = self._gather_metric_results()
        return_metrics.update(
            {"student_loss": student_loss, "distillation_loss": distillation_loss}
        )
        return return_metrics

    def test_step(self, data):
        """Evaluate the student alone on an ``(x, y)`` batch and return the logged metrics."""
        x, y = data

        y_prediction = self.student(x, training=False)
        student_loss = self.student_loss_fn(y, y_prediction)
        self.compiled_metrics.update_state(y, y_prediction)

        return_metrics = self._gather_metric_results()
        return_metrics.update({"student_loss": student_loss})
        return return_metrics

In [None]:
# Student for distillation: a fresh MobileNet with the same input shape as `model`.
d_model = tf.keras.applications.MobileNet(input_shape=(img_rows, img_cols, num_channel))
# `learning_rate` replaces the deprecated `lr` alias.
d_model.compile(optimizer=tf.keras.optimizers.RMSprop(learning_rate=2e-5),
                loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
                metrics=['accuracy'])

In [None]:
# Keras' MobileNet ends in a softmax by default, so the student (d_model) and
# teacher (q_model) emit probabilities, not logits. The student loss therefore
# uses from_logits=False, matching how d_model itself is compiled above.
# NOTE(review): the distillation term applies temperature scaling to these
# outputs; confirm Distiller treats them as probabilities rather than logits.
distiller = Distiller(student=d_model, teacher=q_model)
distiller.compile(
    optimizer=tf.keras.optimizers.RMSprop(learning_rate=2e-5),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
    student_loss_fn=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    distillation_loss_fn=tf.keras.losses.KLDivergence(),
    alpha=0.1,
    temperature=10,
)

# Distill teacher to student
distiller.fit(train_ds_, validation_data=val_ds, epochs=5)

In [None]:
#d_model.save("../../weights/distilled_fp_model_40_mobilenet.h5")

In [None]:
# Restore the previously distilled surrogate weights (saving cell above is disabled).
d_model.load_weights(filepath="../../weights/distilled_fp_model_40_mobilenet.h5")

In [None]:
d_model.evaluate(x=val_ds)  # surrogate loss/accuracy on the held-out slice

In [None]:
# fp32 vs distilled surrogate: (q_correct, orig_correct, q_correct_orig_wrong, q_wrong_orig_correct)
Instability(model, d_model, val_ds, batch_size=BATCH_SIZE)