# Self Distillation for Keras

---


---


In [3]:
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
import tensorflow as tf
os.environ['TF_DETERMINISTIC_OPS'] = '1'
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import sys
import os
import keras

sys.path.append("..") # Adds higher directory to python modules path.
# import brevis
# from brevis import branches
# from brevis import evaluate
# dataset = branching.dataset.prepare.dataset(tf.keras.datasets.cifar10.load_data(),64,5000,22500,(227,227),include_targets=False)

In [5]:
class BranchModel(tf.keras.Model):
    '''
    Branched model sub-class. 
    Acts as a drop in replacement keras model class, with the additional functionality of adding branches to the model.
            
    '''
    def __init__(self, inputs=None, outputs=None, name="", model=None, transfer=True,custom_objects={}):
        ## add default custom objects to the custom objects dictionary, this saves having to define them everytime.
        custom_objects = {**branching.default_custom_objects,**custom_objects} 
        if inputs  is None and model is None and name is not "":
            model = tf.keras.models.load_model(name,custom_objects=custom_objects)
            self.saveLocation = name
            super(BranchModel, self).__init__(inputs = model.inputs, outputs=model.outputs,name=model.name)            
        elif model is None:
            super(BranchModel, self).__init__(inputs = inputs, outputs=outputs,name=name)
        elif model is not None:
            super(BranchModel, self).__init__(inputs = model.inputs, outputs=model.outputs,name=name)
        self.transfer = transfer
        self.custom_objects = custom_objects
        self.branch_active = False
 
    
class Distiller(keras.Model):
    def __init__(self, student, teacher):
        super(Distiller, self).__init__()
        self.teacher = teacher
        self.student = tf.keras.models.clone_model(student)
    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
        alpha=0.1,
        temperature=1,
    ):
        """ Configure the distiller.
        Args:
            optimizer: Keras optimizer for the student weights
            metrics: Keras metrics for evaluation
            student_loss_fn: Loss function of difference between student
                predictions and ground-truth
            distillation_loss_fn: Loss function of difference between soft
                student predictions and soft teacher predictions
            alpha: weight to student_loss_fn and 1-alpha to distillation_loss_fn
            temperature: Temperature for softening probability distributions.
                Larger temperature gives softer distributions.
        """
        super(Distiller, self).compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature

    def train_step(self, data):
        # Unpack data
        x, y = data
        # Forward pass of teacher
        teacher_predictions = self.teacher(x, training=False)

        with tf.GradientTape() as tape:
            # Forward pass of student
            student_predictions = self.student(x, training=True)

            # Compute losses
            student_loss = self.student_loss_fn(y, student_predictions)
            distillation_loss = self.distillation_loss_fn(
                (teacher_predictions / self.temperature),
                (student_predictions / self.temperature),
            )
#             student_loss = student_loss #* self.alpha
#             distillation_loss = (distillation_loss) #* (1 - self.alpha)
            loss = distillation_loss
            #loss=student_loss +distillation_loss
        # Compute gradients
        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update the metrics configured in `compile()`.
        self.compiled_metrics.update_state(y, student_predictions)

        # Return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        results.update(
            {"loss":loss,"student_loss": student_loss, "distillation_loss": distillation_loss}
        )
        return results

    def test_step(self, data):
        # Unpack the data
        x, y = data
        # Compute predictions
        y_prediction = self.student(x, training=False)

        # Calculate the loss
        student_loss = self.student_loss_fn(y, y_prediction)

        # Update the metrics.
        self.compiled_metrics.update_state(y, y_prediction)

        # Return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        results.update({"student_loss": student_loss})
        return results

In [6]:
from tensorflow.keras import layers, models

def summarize_keras_trainable_variables(model, message):
    s = sum(map(lambda x: x.sum(), model.get_weights()))
    print("summary of trainable variables %s: %.13f" % (message, s))
    return s


In [12]:

# tf.debugging.experimental.enable_dump_debug_info(logdir, tensor_debug_mode="FULL_HEALTH", circular_buffer_size=-1)
(train_images, train_labels), (test_images, test_labels) = keras.datasets.cifar10.load_data()

CLASS_NAMES= ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

# import csv
# with open('results/altTrain_labels.csv', newline='') as f:
    # reader = csv.reader(f,quoting=csv.QUOTE_NONNUMERIC)
    # alt_trainLabels = list(reader)
# with open('results/altTest_labels.csv', newline='') as f:
    # reader = csv.reader(f,quoting=csv.QUOTE_NONNUMERIC)
    # alt_testLabels = list(reader)

# altTraining = tf.data.Dataset.from_tensor_slices((train_images,alt_trainLabels))

# validation_images, validation_labels = train_images[:5000], alt_trainLabels[:5000]
# train_ds = tf.data.Dataset.from_tensor_slices((train_images, alt_trainLabels))
# test_ds = tf.data.Dataset.from_tensor_slices((test_images, alt_testLabels))
train_labels = tf.keras.utils.to_categorical(train_labels,10)
test_labels = tf.keras.utils.to_categorical(test_labels,10)

###normal method
validation_images, validation_labels = train_images[:5000], train_labels[:5000] #get the first 5k training samples as validation set
train_images, train_labels = train_images[5000:], train_labels[5000:] # now remove the validation set from the training set.
train_ds = tf.data.Dataset.from_tensor_slices((train_images, train_labels))
test_ds = tf.data.Dataset.from_tensor_slices((test_images, test_labels))
validation_ds = tf.data.Dataset.from_tensor_slices((validation_images, validation_labels))

def augment_images(image, label):
    # Normalize images to have a mean of 0 and standard deviation of 1
    # image = tf.image.per_image_standardization(image)
    # Resize images from 32x32 to 277x277
    image = tf.image.resize(image, (227,227))
    return image, label

train_ds_size = len(list(train_ds))
test_ds_size = len(list(test_ds))
validation_ds_size = len(list(validation_ds))

train_ds = (train_ds
                  .map(augment_images)
                  .shuffle(buffer_size=train_ds_size,seed=42,reshuffle_each_iteration=False)
                  .batch(batch_size=32, drop_remainder=True))

test_ds = (test_ds
                  .map(augment_images)
                #   .shuffle(buffer_size=train_ds_size)
                  .batch(batch_size=32, drop_remainder=True))

validation_ds = (validation_ds
                  .map(augment_images)
                #   .shuffle(buffer_size=validation_ds_size)
                  .batch(batch_size=32, drop_remainder=True))
print("train_ds batch count", len(train_ds))
print("validation_ds batch count", len(validation_ds))
print("test_ds batch count", len(test_ds))

train_ds batch count 1406
validation_ds batch count 156
test_ds batch count 312


In [15]:
model_teacher = tf.keras.models.load_model("models/alexNetv6_logits_teacher.hdf5")
model_teacher.summary()

Model: "alexnet"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_2 (InputLayer)        [(None, 227, 227, 3)]     0         
                                                                 
 conv2d_1 (Conv2D)           (None, 55, 55, 96)        34944     
                                                                 
 batch_normalization (BatchN  (None, 55, 55, 96)       384       
 ormalization)                                                   
                                                                 
 max_pooling2d (MaxPooling2D  (None, 27, 27, 96)       0         
 )                                                               
                                                                 
 conv2d_2 (Conv2D)           (None, 27, 27, 256)       614656    
                                                                 
 batch_normalization_1 (Batc  (None, 27, 27, 256)      1024

In [6]:
model_teacher.evaluate(test_ds)



[0.6905280947685242, 0.7939703464508057]

## First, train the student model without the teacher input to get a baseline

In [14]:
seed = 66
# random.seed(seed)
tf.random.set_seed(seed)
np.random.seed(seed)
## keep setting the seed so that it doesn't matter what order you complete cells in. 
student_model = tf.keras.models.load_model("models/alexNetv6_logits_student.hdf5")
student_model.summary()
# summarize_keras_trainable_variables(student_model,"before training")
# student_model.fit(train_ds, validation_data = validation_ds, epochs=9)
# summarize_keras_trainable_variables(student_model,"after training")


Model: "alexnet"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_3 (InputLayer)        [(None, 227, 227, 3)]     0         
                                                                 
 conv2d_5 (Conv2D)           (None, 55, 55, 96)        34944     
                                                                 
 batch_normalization_5 (Batc  (None, 55, 55, 96)       384       
 hNormalization)                                                 
                                                                 
 max_pooling2d_3 (MaxPooling  (None, 27, 27, 96)       0         
 2D)                                                             
                                                                 
 conv2d_6 (Conv2D)           (None, 27, 27, 256)       614656    
                                                                 
 batch_normalization_6 (Batc  (None, 27, 27, 256)      1024

In [12]:
student_model.evaluate(test_ds)



[0.7575096487998962, 0.7360777258872986]

In [7]:
seed = 66
# random.seed(seed)
tf.random.set_seed(seed)
np.random.seed(seed)
## keep setting the seed so that it doesn't matter what order you complete cells in. 
student_model = tf.keras.models.load_model("models/alexNetv6_logits_student.hdf5")
summarize_keras_trainable_variables(student_model,"before training")
student_model.fit(train_ds, validation_data = validation_ds, epochs=3)
summarize_keras_trainable_variables(student_model,"after training")

summary of trainable variables before training: 2688.9488805532455
Epoch 1/3
Epoch 2/3
Epoch 3/3
summary of trainable variables after training: 9552084.5724323093891


9552084.57243231

In [8]:
student_model.evaluate(test_ds)



[1.0212757587432861, 0.643629789352417]

In [8]:
seed = 66
# random.seed(seed)
tf.random.set_seed(seed)
np.random.seed(seed)
## keep setting the seed so that it doesn't matter what order you complete cells in. 
student_model = tf.keras.models.load_model("models/alexNetv6_logits_student.hdf5")
summarize_keras_trainable_variables(student_model,"before training")
student_model.fit(train_ds, validation_data = validation_ds, epochs=3)
summarize_keras_trainable_variables(student_model,"after training")


summary of trainable variables before training: 2688.9488805532455
Epoch 1/3
Epoch 2/3
Epoch 3/3
summary of trainable variables after training: 9552084.5724323093891


9552084.57243231

In [9]:
student_model.evaluate(test_ds)



[1.0212757587432861, 0.643629789352417]

## Next, train the student with the teacher model input as well to see the difference the teacher made. 

In [10]:
seed = 66
# random.seed(seed)
tf.random.set_seed(seed)
np.random.seed(seed)
def createDistiller(alpha=1):
    loaded_student = tf.keras.models.load_model("models/alexNetv6_logits_student.hdf5")
    summarize_keras_trainable_variables(loaded_student,"before compiling in distiller")
    distiller = Distiller(student=loaded_student, teacher=model_teacher)
    distiller.compile(
        optimizer=tf.optimizers.SGD(lr=0.001,momentum=0.9),
        metrics=[keras.metrics.CategoricalAccuracy()],
        student_loss_fn=keras.losses.CategoricalCrossentropy(from_logits=False),

        distillation_loss_fn=keras.losses.KLDivergence(),
        alpha=alpha,
        temperature=1,
    )
    return distiller


In [11]:
seed = 66
# random.seed(seed)
tf.random.set_seed(seed)
np.random.seed(seed)
distiller = createDistiller(1)
summarize_keras_trainable_variables(distiller.student,"before training")
# Distill teacher to student
distiller.fit(train_ds, validation_data = validation_ds,epochs=3,verbose=1)
# Evaluate student on test dataset
res = distiller.evaluate(test_ds)

summary of trainable variables before compiling in distiller: 2688.9488805532455


  "The `lr` argument is deprecated, use `learning_rate` instead.")


summary of trainable variables before training: 2684.4909174442291
Epoch 1/3
Epoch 2/3
Epoch 3/3


In [9]:
seed = 66
# random.seed(seed)
tf.random.set_seed(seed)
np.random.seed(seed)
distiller = createDistiller(1)
summarize_keras_trainable_variables(distiller.student,"before training")
# Distill teacher to student
distiller.fit(train_ds, validation_data = validation_ds,epochs=3,verbose=1)
# Evaluate student on test dataset
res = distiller.evaluate(test_ds)

summary of trainable variables before compiling in distiller: 2688.9488805532455
summary of trainable variables before training: 2684.4909174442291
Epoch 1/3
Epoch 2/3
Epoch 3/3


In [8]:
results = {}
results[0.1] = "test"
# for i in range(10):
seed = 66
# random.seed(seed)
tf.random.set_seed(seed)
np.random.seed(seed)
distiller = createDistiller(0.1)
summarize_keras_trainable_variables(distiller.student,"before training")
# Distill teacher to student
distiller.fit(train_ds, validation_data = validation_ds,epochs=9,verbose=1)
# Evaluate student on test dataset
res = distiller.evaluate(test_ds)
print("results: ",res)


summary of trainable variables before compiling in distiller: 2688.9488805532455


  "The `lr` argument is deprecated, use `learning_rate` instead.")


summary of trainable variables before training: 2684.4909174442291
Epoch 1/9
Epoch 2/9
Epoch 3/9
Epoch 4/9
Epoch 5/9
Epoch 6/9
Epoch 7/9
Epoch 8/9
Epoch 9/9
results:  [0.6296073794364929, 1.137315273284912]


NameError: name 'i' is not defined

In [15]:

tf.random.set_seed(seed)
np.random.seed(seed)
distiller = createDistiller(0.5)
summarize_keras_trainable_variables(distiller.student,"before training")
# Distill teacher to student
distiller.fit(train_ds, validation_data = validation_ds,epochs=9,verbose=1)
# Evaluate student on test dataset
res = distiller.evaluate(test_ds)
print("results: ",res)

tf.random.set_seed(seed)
np.random.seed(seed)
distiller = createDistiller(0.6)
summarize_keras_trainable_variables(distiller.student,"before training")
# Distill teacher to student
distiller.fit(train_ds, validation_data = validation_ds,epochs=9,verbose=1)
# Evaluate student on test dataset
res = distiller.evaluate(test_ds)
print("results: ",res)

tf.random.set_seed(seed)
np.random.seed(seed)
distiller = createDistiller(0.7)
summarize_keras_trainable_variables(distiller.student,"before training")
# Distill teacher to student
distiller.fit(train_ds, validation_data = validation_ds,epochs=9,verbose=1)
# Evaluate student on test dataset
res = distiller.evaluate(test_ds)
print("results: ",res)

tf.random.set_seed(seed)
np.random.seed(seed)
distiller = createDistiller(0.8)
summarize_keras_trainable_variables(distiller.student,"before training")
# Distill teacher to student
distiller.fit(train_ds, validation_data = validation_ds,epochs=9,verbose=1)
# Evaluate student on test dataset
res = distiller.evaluate(test_ds)
print("results: ",res)

tf.random.set_seed(seed)
np.random.seed(seed)
distiller = createDistiller(1)
summarize_keras_trainable_variables(distiller.student,"before training")
# Distill teacher to student
distiller.fit(train_ds, validation_data = validation_ds,epochs=9,verbose=1)
# Evaluate student on test dataset
res = distiller.evaluate(test_ds)
print("results: ",res)
    

summary of trainable variables before compiling in distiller: 2688.9488805532455


  "The `lr` argument is deprecated, use `learning_rate` instead.")


summary of trainable variables before training: 2684.4909174442291
Epoch 1/9
Epoch 2/9
Epoch 3/9
Epoch 4/9
Epoch 5/9
Epoch 6/9
Epoch 7/9
Epoch 8/9
Epoch 9/9
results:  [0.7475961446762085, 1.0419151782989502]
summary of trainable variables before compiling in distiller: 2688.9488805532455
summary of trainable variables before training: 2684.4909174442291
Epoch 1/9
Epoch 2/9
Epoch 3/9
Epoch 4/9
Epoch 5/9
Epoch 6/9
Epoch 7/9
Epoch 8/9
Epoch 9/9
results:  [0.7394831776618958, 1.1406097412109375]
summary of trainable variables before compiling in distiller: 2688.9488805532455
summary of trainable variables before training: 2684.4909174442291
Epoch 1/9
Epoch 2/9
Epoch 3/9
Epoch 4/9
Epoch 5/9
Epoch 6/9
Epoch 7/9
Epoch 8/9
Epoch 9/9
results:  [0.7182492017745972, 1.1012463569641113]
summary of trainable variables before compiling in distiller: 2688.9488805532455
summary of trainable variables before training: 2684.4909174442291
Epoch 1/9
Epoch 2/9
Epoch 3/9

KeyboardInterrupt: 

In [19]:
for i in range(3):
    seed = 66
    # random.seed(seed)
    tf.random.set_seed(seed)
    np.random.seed(seed)
    distiller = createDistiller(0.1)
#     summarize_keras_trainable_variables(distiller.student,"before training")
    # Distill teacher to student
    tf.random.set_seed(seed)
    np.random.seed(seed)
    model = tf.keras.models.load_model("models/alexNetv6_logits_student.hdf5")
    model.compile(loss='categorical_crossentropy', optimizer=tf.optimizers.SGD(lr=0.001,momentum=0.9), metrics=['accuracy'])
    model.fit(train_ds, validation_data=validation_ds, epochs=1,verbose=1)
    model.evaluate(test_ds)
#     distiller.fit(train_ds, validation_data = validation_ds,epochs=1,verbose=1)
    # Evaluate student on test dataset
#     res = distiller.evaluate(test_ds)
#     print("results: ",res)


summary of trainable variables before compiling in distiller: 2688.9488805532455
summary of trainable variables before compiling in distiller: 2688.9488805532455
summary of trainable variables before compiling in distiller: 2688.9488805532455


In [21]:

# tf.debugging.experimental.enable_dump_debug_info(logdir, tensor_debug_mode="FULL_HEALTH", circular_buffer_size=-1)
(train_images, train_labels), (test_images, test_labels) = keras.datasets.cifar10.load_data()

CLASS_NAMES= ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

# import csv
# with open('results/altTrain_labels.csv', newline='') as f:
    # reader = csv.reader(f,quoting=csv.QUOTE_NONNUMERIC)
    # alt_trainLabels = list(reader)
# with open('results/altTest_labels.csv', newline='') as f:
    # reader = csv.reader(f,quoting=csv.QUOTE_NONNUMERIC)
    # alt_testLabels = list(reader)

# altTraining = tf.data.Dataset.from_tensor_slices((train_images,alt_trainLabels))

# validation_images, validation_labels = train_images[:5000], alt_trainLabels[:5000]
# train_ds = tf.data.Dataset.from_tensor_slices((train_images, alt_trainLabels))
# test_ds = tf.data.Dataset.from_tensor_slices((test_images, alt_testLabels))
train_labels = tf.keras.utils.to_categorical(train_labels,10)
test_labels = tf.keras.utils.to_categorical(test_labels,10)

###normal method
validation_images, validation_labels = train_images[:5000], train_labels[:5000] #get the first 5k training samples as validation set
train_images, train_labels = train_images[5000:], train_labels[5000:] # now remove the validation set from the training set.
train_ds = tf.data.Dataset.from_tensor_slices((train_images, train_labels))
test_ds = tf.data.Dataset.from_tensor_slices((test_images, test_labels))
validation_ds = tf.data.Dataset.from_tensor_slices((validation_images, validation_labels))

def augment_images(image, label):
    # Normalize images to have a mean of 0 and standard deviation of 1
    # image = tf.image.per_image_standardization(image)
    # Resize images from 32x32 to 277x277
    image = tf.image.resize(image, (227,227))
    return image, label

train_ds_size = len(list(train_ds))
test_ds_size = len(list(test_ds))
validation_ds_size = len(list(validation_ds))

train_ds = (train_ds
                  .map(augment_images)
                  .shuffle(buffer_size=train_ds_size,seed=42,reshuffle_each_iteration=False)
                  .batch(batch_size=128, drop_remainder=True))

test_ds = (test_ds
                  .map(augment_images)
                #   .shuffle(buffer_size=train_ds_size)
                  .batch(batch_size=128, drop_remainder=True))

validation_ds = (validation_ds
                  .map(augment_images)
                #   .shuffle(buffer_size=validation_ds_size)
                  .batch(batch_size=128, drop_remainder=True))

for i in range(3):
    tf.random.set_seed(seed)
    np.random.seed(seed)
    model = tf.keras.models.load_model("distil_test.hdf5")
    summarize_keras_trainable_variables(model,"before training")
    
    model.fit(train_ds, shuffle=False,validation_data = validation_ds,epochs=1)
    summarize_keras_trainable_variables(model,"after training")
    results = model.evaluate(test_ds)
    print(results)

summary of trainable variables before training: 2931.7305606603622
summary of trainable variables after training: 2272256.2266048272140
[1.5863159894943237, 0.43609777092933655]
summary of trainable variables before training: 2931.7305606603622
summary of trainable variables after training: 2272256.2266048272140
[1.5863159894943237, 0.43609777092933655]
summary of trainable variables before training: 2931.7305606603622
summary of trainable variables after training: 2272256.2266048272140
[1.5863159894943237, 0.43609777092933655]


In [14]:
for i in range(10):
    print(i/100)


0.0
0.01
0.02
0.03
0.04
0.05
0.06
0.07
0.08
0.09


In [None]:
seed = 42
# random.seed(seed)
tf.random.set_seed(seed)
np.random.seed(seed)
## keep setting the seed so that it doesn't matter what order you complete cells in. 
student_model = tf.keras.models.load_model("models/alexNetv6_logits_student.hdf5")
summarize_keras_trainable_variables(student_model,"before training")
student_model.fit(train_ds, validation_data = validation_ds, epochs=3)
summarize_keras_trainable_variables(student_model,"after training")

In [15]:
outputs =[]
inputs = keras.Input(shape=(28,28,))
x = layers.Flatten(input_shape=(28,28))(inputs)
x = layers.Dense(512, activation="relu")(x)
#exit 1 The main branch exit is refered to as "exit 1" or "main exit" to avoid confusion when adding addtional exits
output1 = layers.Dense(10, name="output1")(x)
softmax = layers.Softmax()(output1)

outputs.append(softmax)
print(len(outputs))
model_student = keras.Model(inputs=inputs, outputs=outputs, name="mnist_model_normal")
model_student.summary()
#visualize_model(model,"mnist_normal")
print(len(model_student.outputs))


1
Model: "mnist_model_normal"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_3 (InputLayer)         [(None, 28, 28)]          0         
_________________________________________________________________
flatten_1 (Flatten)          (None, 784)               0         
_________________________________________________________________
dense_3 (Dense)              (None, 512)               401920    
_________________________________________________________________
output1 (Dense)              (None, 10)                5130      
_________________________________________________________________
softmax (Softmax)            (None, 10)                0         
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
_________________________________________________________________
1


In [None]:
inputs = keras.Input(shape=(227,227,3))
# targets = keras.Input(shape=(10,))
x = keras.layers.Conv2D(filters=96, kernel_size=(11,11), strides=(4,4), activation='relu', input_shape=(227,227,3))(inputs)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.MaxPool2D(pool_size=(3,3), strides=(2,2))(x)
x = keras.layers.Conv2D(filters=256, kernel_size=(5,5), strides=(1,1), activation='relu', padding="same")(x)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.MaxPool2D(pool_size=(3,3), strides=(2,2))(x)
x = keras.layers.Conv2D(filters=384, kernel_size=(3,3), strides=(1,1), activation='relu', padding="same")(x)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.Conv2D(filters=384, kernel_size=(1,1), strides=(1,1), activation='relu', padding="same")(x)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.Conv2D(filters=256, kernel_size=(1,1), strides=(1,1), activation='relu', padding="same")(x)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.MaxPool2D(pool_size=(3,3), strides=(2,2))(x)
x = keras.layers.Flatten()(x)
x = keras.layers.Dense(4096, activation='relu')(x)
x = keras.layers.Dropout(0.5)(x)

# ### first branch
# branchLayer = keras.layers.Flatten(name=tf.compat.v1.get_default_graph().unique_name("branch_flatten"))(x)
# branchLayer = keras.layers.Dense(124, activation="relu",name=tf.compat.v1.get_default_graph().unique_name("branch124"))(branchLayer)
# branchLayer = keras.layers.Dense(64, activation="relu",name=tf.compat.v1.get_default_graph().unique_name("branch64"))(branchLayer)
# branchLayer = keras.layers.Dense(10, name=tf.compat.v1.get_default_graph().unique_name("branch_output"))(branchLayer)

x = keras.layers.Dense(4096, activation='relu')(x)
x = keras.layers.Dropout(0.5)(x)

# ### second Branch
# branchLayer2 = keras.layers.Flatten(name=tf.compat.v1.get_default_graph().unique_name("branch_flatten"))(x)
# branchLayer2 = keras.layers.Dense(10, name=tf.compat.v1.get_default_graph().unique_name("branch_output"))(branchLayer2)

x = keras.layers.Dense(10, activation='softmax')(x)

# model = keras.Model(inputs=inputs, outputs=[x,branchLayer,branchLayer2], name="alexnet")
teacher = keras.Model(inputs=(inputs), outputs=[x], name="alexnet")

teacher.compile(loss='categorical_crossentropy', optimizer=tf.optimizers.SGD(lr=0.001,momentum=0.9), metrics=['accuracy'])
teacher.save("models/alexNetv6_new_teacher.hdf5")

In [None]:
summarize_keras_trainable_variables(student_model,"before training")
teacher.fit(train_ds, validation_data = validation_ds, epochs=30)
summarize_keras_trainable_variables(student_model,"after training")
teacher.save("models/alexNetv6_new_teacher.hdf5")
