# Self Distillation rebuilt

## A new from the ground attempt at self distillation to try to sort out some of the issues that we faced in previous builds and testing.

## A new notebook for a new year.
---
## Primary issues
The main problems faced previously were to do with establishing a baseline of performance of a model to compare the self distilation work against. 


---


In [1]:
import os
import random
# os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
import tensorflow as tf
from tensorflow import keras
from keras import layers, models
os.environ['TF_DETERMINISTIC_OPS'] = '1'
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import sys
import os
sys.path.append("..") # Adds higher directory to python modules path.
import branchingdnn as branching
# dataset = branching.dataset.prepare.dataset(tf.keras.datasets.cifar10.load_data(),64,5000,22500,(227,227),include_targets=False)

In [2]:
print(tf.__version__)

2.8.0


In [3]:
class Distiller(keras.Model):
    def __init__(self, student, teacher):
        super(Distiller, self).__init__()
        self.teacher = teacher
        self.student = tf.keras.models.clone_model(student)
    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
        alpha=0.1,
        temperature=1,        
    ):
        """ Configure the distiller.
        Args:
            optimizer: Keras optimizer for the student weights
            metrics: Keras metrics for evaluation
            student_loss_fn: Loss function of difference between student
                predictions and ground-truth
            distillation_loss_fn: Loss function of difference between soft
                student predictions and soft teacher predictions
            alpha: weight to student_loss_fn and 1-alpha to distillation_loss_fn
            temperature: Temperature for softening probability distributions.
                Larger temperature gives softer distributions.
        """
        super(Distiller, self).compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature

    def train_step(self, data):
        # Unpack data
        x, y = data
        # Forward pass of teacher
        teacher_predictions = self.teacher(x, training=False)

        with tf.GradientTape() as tape:
            # Forward pass of student
            student_predictions = self.student(x, training=True)

            # Compute losses
            student_loss = self.student_loss_fn(y, student_predictions)
            distillation_loss = self.distillation_loss_fn(
                (teacher_predictions / self.temperature),
                (student_predictions / self.temperature),
            )
            student_loss = student_loss * self.alpha
            distillation_loss = (distillation_loss) * (1 - self.alpha)
#             loss = distillation_loss
            loss=student_loss +distillation_loss
        # Compute gradients
        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update the metrics configured in `compile()`.
        self.compiled_metrics.update_state(y, student_predictions)

        # Return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        results.update(
            {"loss":loss,"student_loss": student_loss, "distillation_loss": distillation_loss}
        )
        return results

    def test_step(self, data):
        # Unpack the data
        x, y = data
        # Compute predictions
        y_prediction = self.student(x, training=False)

        # Calculate the loss
        student_loss = self.student_loss_fn(y, y_prediction)

        # Update the metrics.
        self.compiled_metrics.update_state(y, y_prediction)

        # Return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        results.update({"student_loss": student_loss})
        return results

In [4]:
from tensorflow.keras import layers, models

def summarize_keras_trainable_variables(model, message):
    s = sum(map(lambda x: x.sum(), model.get_weights()))
    print("summary of trainable variables %s: %.13f" % (message, s))
    return s


In [5]:
seed = 66
random.seed(seed)
np.random.seed(seed)
tf.random.set_seed(seed)

# tf.debugging.experimental.enable_dump_debug_info(logdir, tensor_debug_mode="FULL_HEALTH", circular_buffer_size=-1)
(train_images, train_labels), (test_images, test_labels) = keras.datasets.cifar10.load_data()

CLASS_NAMES= ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

# import csv
# with open('results/altTrain_labels.csv', newline='') as f:
    # reader = csv.reader(f,quoting=csv.QUOTE_NONNUMERIC)
    # alt_trainLabels = list(reader)
# with open('results/altTest_labels.csv', newline='') as f:
    # reader = csv.reader(f,quoting=csv.QUOTE_NONNUMERIC)
    # alt_testLabels = list(reader)

# altTraining = tf.data.Dataset.from_tensor_slices((train_images,alt_trainLabels))

# validation_images, validation_labels = train_images[:5000], alt_trainLabels[:5000]
# train_ds = tf.data.Dataset.from_tensor_slices((train_images, alt_trainLabels))
# test_ds = tf.data.Dataset.from_tensor_slices((test_images, alt_testLabels))
train_labels = tf.keras.utils.to_categorical(train_labels,10)
test_labels = tf.keras.utils.to_categorical(test_labels,10)

###normal method
validation_images, validation_labels = train_images[:5000], train_labels[:5000] #get the first 5k training samples as validation set
train_images, train_labels = train_images[5000:], train_labels[5000:] # now remove the validation set from the training set.
train_ds = tf.data.Dataset.from_tensor_slices((train_images, train_labels))
test_ds = tf.data.Dataset.from_tensor_slices((test_images, test_labels))
validation_ds = tf.data.Dataset.from_tensor_slices((validation_images, validation_labels))

def augment_images(image, label):
    # Normalize images to have a mean of 0 and standard deviation of 1
    # image = tf.image.per_image_standardization(image)
    # Resize images from 32x32 to 277x277
    image = tf.image.resize(image, (227,227))
    return image, label

train_ds_size = len(list(train_ds))
test_ds_size = len(list(test_ds))
validation_ds_size = len(list(validation_ds))

train_ds = (train_ds
                  .map(augment_images)
                  .shuffle(buffer_size=train_ds_size,seed=42,reshuffle_each_iteration=False)
                  .batch(batch_size=32, drop_remainder=True))

test_ds = (test_ds
                  .map(augment_images)
                #   .shuffle(buffer_size=train_ds_size)
                  .batch(batch_size=32, drop_remainder=True))

validation_ds = (validation_ds
                  .map(augment_images)
                #   .shuffle(buffer_size=validation_ds_size)
                  .batch(batch_size=32, drop_remainder=True))


In [6]:
model_teacher = tf.keras.models.load_model("models/alexNetv6_logits_teacher.hdf5")


In [None]:
model_teacher.evaluate(train_ds)

In [60]:
seed = 42
# random.seed(seed)
tf.random.set_seed(seed)
np.random.seed(seed)
inputs = keras.Input(shape=(227,227,3))
# targets = keras.Input(shape=(10,))
x = keras.layers.Conv2D(filters=96, kernel_size=(11,11), strides=(4,4), activation='relu', input_shape=(227,227,3))(inputs)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.MaxPool2D(pool_size=(3,3), strides=(2,2))(x)
x = keras.layers.Conv2D(filters=256, kernel_size=(5,5), strides=(1,1), activation='relu', padding="same")(x)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.MaxPool2D(pool_size=(3,3), strides=(2,2))(x)
# x = keras.layers.Conv2D(filters=384, kernel_size=(3,3), strides=(1,1), activation='relu', padding="same")(x)
# x = keras.layers.BatchNormalization()(x)
# x = keras.layers.Conv2D(filters=384, kernel_size=(1,1), strides=(1,1), activation='relu', padding="same")(x)
# x = keras.layers.BatchNormalization()(x)
# x = keras.layers.Conv2D(filters=256, kernel_size=(1,1), strides=(1,1), activation='relu', padding="same")(x)
# x = keras.layers.BatchNormalization()(x)
# x = keras.layers.MaxPool2D(pool_size=(3,3), strides=(2,2))(x)
# x = keras.layers.Flatten()(x)
# x = keras.layers.Dense(4096, activation='relu')(x)
# x = keras.layers.Dropout(0.5)(x)

# ### first branch
branchLayer = keras.layers.Flatten(name=tf.compat.v1.get_default_graph().unique_name("branch_flatten"))(x)
branchLayer = keras.layers.Dense(124, activation="relu",name=tf.compat.v1.get_default_graph().unique_name("branch124"))(branchLayer)
branchLayer = keras.layers.Dense(64, activation="relu",name=tf.compat.v1.get_default_graph().unique_name("branch64"))(branchLayer)
x = keras.layers.Dense(10, activation="softmax", name=tf.compat.v1.get_default_graph().unique_name("branch_output"))(branchLayer)


student_model = keras.Model(inputs=(inputs), outputs=[x], name="alexnet")

student_model.compile(loss='categorical_crossentropy', optimizer=tf.optimizers.SGD(lr=0.001,momentum=0.9), metrics=['accuracy'])
student_model.save("models/alexNetv6_second_Exit.hdf5")



[0.10008857399225235, 0.9731730222702026]

## First, train the student model without the teacher input to get a baseline

In [31]:
seed = 66
random.seed(seed)
np.random.seed(seed)
tf.random.set_seed(seed)
## keep setting the seed so that it doesn't matter what order you complete cells in. 
student_model = tf.keras.models.load_model("models/alexNetv6_logits_student.hdf5")
summarize_keras_trainable_variables(student_model,"before training")
student_model.fit(train_ds, validation_data = validation_ds, epochs=3)
summarize_keras_trainable_variables(student_model,"after training")


summary of trainable variables before training: 2688.9488805532455
Epoch 1/3
Epoch 2/3
Epoch 3/3
summary of trainable variables after training: 12822187.2433519028127


12822187.243351903

In [12]:
student_model.evaluate(test_ds)



[0.7575096487998962, 0.7360777258872986]

In [8]:
class TestSetCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        results = self.model.evaluate(test_ds)
        print(results)
seed = 66
random.seed(seed)
np.random.seed(seed)
tf.random.set_seed(seed)
## keep setting the seed so that it doesn't matter what order you complete cells in. 
student_model = tf.keras.models.load_model("models/alexNetv6_first_Exit.hdf5")
summarize_keras_trainable_variables(student_model,"before training")
student_model.fit(train_ds, validation_data = validation_ds, epochs=9,callbacks=[TestSetCallback()])
summarize_keras_trainable_variables(student_model,"after training")
student_model.evaluate(test_ds)

summary of trainable variables before training: 192.5101304054260
Epoch 1/9
[1.456679344177246, 0.48747995495796204]
Epoch 2/9
[1.7721526622772217, 0.44370993971824646]
Epoch 3/9
[1.4880932569503784, 0.5155248641967773]
Epoch 4/9
[1.266702651977539, 0.5895432829856873]
Epoch 5/9
[1.3018970489501953, 0.6027644276618958]
Epoch 6/9
[1.5510960817337036, 0.5689102411270142]
Epoch 7/9
[1.694735050201416, 0.5645031929016113]
Epoch 8/9
[1.8729580640792847, 0.5557892918586731]
Epoch 9/9
[2.0702269077301025, 0.5476762652397156]
summary of trainable variables after training: 533177.2514293249696


[2.0702269077301025, 0.5476762652397156]

In [8]:
class TestSetCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        results = self.model.evaluate(test_ds)
        print(results)
seed = 66
random.seed(seed)
np.random.seed(seed)
tf.random.set_seed(seed)
## keep setting the seed so that it doesn't matter what order you complete cells in. 
student_model = tf.keras.models.load_model("models/alexNetv6_first_Exit.hdf5")
summarize_keras_trainable_variables(student_model,"before training")
student_model.fit(train_ds, validation_data = validation_ds, epochs=3,callbacks=[TestSetCallback()])
summarize_keras_trainable_variables(student_model,"after training")
student_model.evaluate(test_ds)

summary of trainable variables before training: 192.5101304054260
Epoch 1/3
[1.456679344177246, 0.48747995495796204]
Epoch 2/3
[1.7721526622772217, 0.44370993971824646]
Epoch 3/3
[1.4880932569503784, 0.5155248641967773]
summary of trainable variables after training: 479021.9750222191215


[1.4880932569503784, 0.5155248641967773]

In [59]:
seed = 66
random.seed(seed)
np.random.seed(seed)
tf.random.set_seed(seed)
## keep setting the seed so that it doesn't matter what order you complete cells in. 
student_model = tf.keras.models.load_model("models/alexNetv6_first_Exit.hdf5")
summarize_keras_trainable_variables(student_model,"before training")
student_model.fit(train_ds, validation_data = validation_ds, epochs=9,callbacks=[TestSetCallback()])
summarize_keras_trainable_variables(student_model,"after training")
student_model.evaluate(test_ds)


summary of trainable variables before training: 192.5101304054260
Epoch 1/9
[1.4765691757202148, 0.47255608439445496]
Epoch 2/9
[1.4188727140426636, 0.51171875]
Epoch 3/9
[1.2944990396499634, 0.5599960088729858]
Epoch 4/9
[1.2946268320083618, 0.5782251358032227]
Epoch 5/9
[1.3364394903182983, 0.5783253312110901]
Epoch 6/9
[1.4542863368988037, 0.578125]
Epoch 7/9
[1.5658166408538818, 0.573317289352417]
Epoch 8/9
[1.7475775480270386, 0.5616987347602844]
Epoch 9/9
[1.8955588340759277, 0.5756210088729858]
summary of trainable variables after training: 544908.3939273101278


[1.8955588340759277, 0.5756210088729858]

In [61]:
seed = 66
random.seed(seed)
np.random.seed(seed)
tf.random.set_seed(seed)
## keep setting the seed so that it doesn't matter what order you complete cells in. 
student_model = tf.keras.models.load_model("models/alexNetv6_second_Exit.hdf5")
summarize_keras_trainable_variables(student_model,"before training")
student_model.fit(train_ds, validation_data = validation_ds, epochs=9,callbacks=[TestSetCallback()])
summarize_keras_trainable_variables(student_model,"after training")
student_model.evaluate(test_ds)


summary of trainable variables before training: 722.7782071828842
Epoch 1/9
[1.3598397970199585, 0.5102163553237915]
Epoch 2/9
[1.155240774154663, 0.5949519276618958]
Epoch 3/9
[1.1066480875015259, 0.62109375]
Epoch 4/9
[1.0901410579681396, 0.6359174847602844]
Epoch 5/9
[1.3062540292739868, 0.6081730723381042]
Epoch 6/9
[1.3261829614639282, 0.6244992017745972]
Epoch 7/9
[1.6140880584716797, 0.5974559187889099]
Epoch 8/9
[1.8387786149978638, 0.5842347741127014]
Epoch 9/9
[1.4702694416046143, 0.6544471383094788]
summary of trainable variables after training: 2151495.0204285085201


[1.4702694416046143, 0.6544471383094788]

## Next, train the student with the teacher model input as well to see the difference the teacher made. 

In [7]:
seed = 66
random.seed(seed)
np.random.seed(seed)
tf.random.set_seed(seed)
def createDistiller(alpha,student,teacher):
    loaded_student = student
    model_teacher = teacher
    summarize_keras_trainable_variables(loaded_student,"before compiling in distiller")
    distiller = Distiller(student=loaded_student, teacher=model_teacher)
    distiller.compile(
        optimizer=tf.optimizers.SGD(lr=0.001,momentum=0.9),
        metrics=[keras.metrics.CategoricalAccuracy()],
        student_loss_fn=keras.losses.CategoricalCrossentropy(from_logits=False),

        distillation_loss_fn=keras.losses.KLDivergence(),
        alpha=alpha,
        temperature=1,
    )
    return distiller

class TestSetCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        results = self.model.evaluate(test_ds)
        print(results)
        
class alphaCooldownCallback(tf.keras.callbacks.Callback):
    """Reduce impact of the teacher input to a minimum of zero. 
    

  Arguments:
      schedule: a function that takes an epoch index
          (integer, indexed from 0) and current learning rate
          as inputs and returns a new learning rate as output (float).
  """
    def __init__(self, cooldownRate, cooldownPoint=1,method="sub"):
        super(alphaCooldownCallback, self).__init__()
        self.cooldownRate = cooldownRate
        self.cooldownPoint = cooldownPoint
        self.cooldownMethod = "sub"

    def on_epoch_begin(self, epoch,logs=None):
        print("cdP",self.cooldownPoint, " cdR",self.cooldownRate, " alpha", self.model.alpha, "epoch", epoch)
        if epoch+1 >= self.cooldownPoint: #-1 because epoch internally start at 0, but are displayed as starting from 1. 
#             if self.cooldownMethod == "sub":
            self.model.alpha = min(self.model.alpha + self.cooldownRate, 1)
#             else: 
#                 self.model.alpha = max(self.model.alpha + (self.model.alpha * self.cooldownRate), 1)
#                 self.model.alpha = 1
        
        tf.print("alpha set to: ",self.model.alpha)



In [12]:
            
seed = 66
random.seed(seed)
np.random.seed(seed)
tf.random.set_seed(seed)
# student_model = tf.keras.models.load_model("models/alexNetv6_second_Exit.hdf5")
distiller = createDistiller(1,student_model,model_teacher)
print(distiller.alpha)
summarize_keras_trainable_variables(distiller.student,"before training")
# Distill teacher to student
distiller.fit(train_ds, validation_data = validation_ds,epochs=5,verbose=1,
              callbacks=[TestSetCallback()])#alphaCooldownCallback(cooldownRate=.2,cooldownPoint=1)])
# Evaluate student on test dataset
res = distiller.evaluate(test_ds)

summary of trainable variables before compiling in distiller: 722.7782071828842
1
summary of trainable variables before training: 703.6863317489624
Epoch 1/5
[0.4925881326198578, 1.582078218460083]
Epoch 2/5
[0.62890625, 1.2472343444824219]
Epoch 3/5
[0.6438301205635071, 1.4574906826019287]
Epoch 4/5
[0.6608573794364929, 1.2307944297790527]
Epoch 5/5
[0.6412259340286255, 1.0740798711776733]


In [62]:
class TestSetCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        results = self.model.evaluate(test_ds)
        print(results)
            

seed = 66
random.seed(seed)
np.random.seed(seed)
tf.random.set_seed(seed)
student_model = tf.keras.models.load_model("models/alexNetv6_second_Exit.hdf5")
distiller = createDistiller(1,student_model,model_teacher)
summarize_keras_trainable_variables(distiller.student,"before training")
# Distill teacher to student
distiller.fit(train_ds, validation_data = validation_ds,epochs=9,verbose=1,callbacks=[TestSetCallback()])
# Evaluate student on test dataset
res = distiller.evaluate(test_ds)

summary of trainable variables before compiling in distiller: 722.7782071828842
summary of trainable variables before training: 723.0987946689129
Epoch 1/9
Tensor("alexnet/branch_output_7/Softmax:0", shape=(32, 10), dtype=float32)
Tensor("alexnet/branch_output_7/Softmax:0", shape=(32, 10), dtype=float32)
[0.5541867017745972, 1.3407378196716309]
Epoch 2/9
[0.5690104365348816, 1.4730888605117798]
Epoch 3/9
[0.6148838400840759, 1.2863210439682007]
Epoch 4/9
[0.5687099099159241, 1.3612890243530273]
Epoch 5/9
[0.5966546535491943, 1.627554178237915]
Epoch 6/9
[0.5796273946762085, 1.4880387783050537]
Epoch 7/9
[0.6228966116905212, 1.051383376121521]
Epoch 8/9
[0.6391226053237915, 0.9578129649162292]
Epoch 9/9
[0.6806890964508057, 1.1696693897247314]


In [13]:
class TestSetCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        results = self.model.evaluate(test_ds)
        print(results)
            

seed = 66
random.seed(seed)
np.random.seed(seed)
tf.random.set_seed(seed)
student_model = tf.keras.models.load_model("models/alexNetv6_second_Exit.hdf5")
distiller = createDistiller(.8,student_model,model_teacher)
summarize_keras_trainable_variables(distiller.student,"before training")
# Distill teacher to student
distiller.fit(train_ds, validation_data = validation_ds,epochs=3,verbose=1,callbacks=[TestSetCallback()])
# Evaluate student on test dataset
res = distiller.evaluate(test_ds)

summary of trainable variables before compiling in distiller: 722.7782071828842
summary of trainable variables before training: 723.0987946689129


  "The `lr` argument is deprecated, use `learning_rate` instead.")


Epoch 1/3
[0.5038061141967773, 1.4889311790466309]
Epoch 2/3
[0.5271434187889099, 1.2526040077209473]
Epoch 3/3
[0.6022636294364929, 1.262915015220642]


In [14]:
seed = 66
random.seed(seed)
np.random.seed(seed)
tf.random.set_seed(seed)

student_model = tf.keras.models.load_model("models/alexNetv6_second_Exit.hdf5")
distiller = createDistiller(0,student_model,model_teacher)
summarize_keras_trainable_variables(distiller.student,"before training")
# Distill teacher to student
distiller.fit(train_ds, validation_data = validation_ds,epochs=3,verbose=1,callbacks=[TestSetCallback()])
# Evaluate student on test dataset
res = distiller.evaluate(test_ds)

summary of trainable variables before compiling in distiller: 722.7782071828842
summary of trainable variables before training: 723.0987946689129
Epoch 1/3
[0.5505809187889099, 1.0943448543548584]
Epoch 2/3
[0.6091746687889099, 1.1122685670852661]
Epoch 3/3
[0.6431289911270142, 1.0406861305236816]
