In [15]:
import tensorflow as tf
import tensorflow.keras as keras
import numpy as np
from tensorflow.keras import layers
num_classes = 10
input_shape = (28, 28, 1)

# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Scale images to the [0, 1] range
x_train = x_train.astype("float32") / 255
x_test = x_test.astype("float32") / 255
# Make sure images have shape (28, 28, 1)
x_train = np.expand_dims(x_train, -1)
x_test = np.expand_dims(x_test, -1)
print("x_train shape:", x_train.shape)
print(x_train.shape[0], "train samples")
print(x_test.shape[0], "test samples")


   # Model / data parameters
num_classes = 10
input_shape = (28, 28, 1)
  

x_train shape: (60000, 28, 28, 1)
60000 train samples
10000 test samples


In [187]:
 # convert class vectors to binary class matrices

dataset = tf.keras.datasets.mnist.load_data()
(train_images, train_labels), (test_images, test_labels) = dataset
train_images = train_images.reshape(60000, 784).astype("float32") / 255
test_images = test_images.reshape(10000, 784).astype("float32") / 255

validation_images, validation_labels = train_images[:12000], train_labels[:12000]
train_images, train_labels = train_images[12000:], train_labels[12000:] # now remove the validation set from the training set.
train_ds = tf.data.Dataset.from_tensor_slices(((train_images,train_labels)))
test_ds = tf.data.Dataset.from_tensor_slices((test_images, test_labels))
validation_ds = tf.data.Dataset.from_tensor_slices((validation_images, validation_labels))

train_ds_size = len(list(train_ds))
test_ds_size = len(list(test_ds))
validation_ds_size = len(list(validation_ds))

target = tf.data.Dataset.from_tensor_slices((train_labels))
train_ds = tf.data.Dataset.zip((train_ds,target))

v_target = tf.data.Dataset.from_tensor_slices((validation_labels))
validation_ds = tf.data.Dataset.zip((validation_ds,v_target))


train_ds = (train_ds
    # .map(prepare.augment_images)
    .shuffle(buffer_size=int(train_ds_size),reshuffle_each_iteration=True)
    .batch(batch_size=batch_size, drop_remainder=True))
test_ds = (test_ds
    # .map(prepare.augment_images)
    .shuffle(buffer_size=int(test_ds_size)) ##why would you shuffle the test set?
    .batch(batch_size=batch_size, drop_remainder=True))

validation_ds = (validation_ds
    # .map(prepare.augment_images)
    .shuffle(buffer_size=int(validation_ds_size))
    .batch(batch_size=batch_size, drop_remainder=True))


In [237]:
class LogisticEndpoint(keras.layers.Layer):
    def __init__(self, name=None):
        super(LogisticEndpoint, self).__init__(name=name)
        self.loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)
        self.accuracy_fn = keras.metrics.BinaryAccuracy()

    def call(self, targets, logits, sample_weights=None):
        # Compute the training-time loss value and add it
        # to the layer using `self.add_loss()`.
        loss = self.loss_fn(targets, logits, sample_weights)
        self.add_loss(loss)

        # Log accuracy as a metric and add it
        # to the layer using `self.add_metric()`.
        acc = self.accuracy_fn(targets, logits, sample_weights)
        self.add_metric(acc, name="accuracy")

        # Return the inference-time prediction tensor (for `.predict()`).
        return tf.nn.softmax(logits)
    
    
class FeatureDistillation(keras.layers.Layer):
    def __init__(self, name=None):
        super(FeatureDistillation, self).__init__(name=name)
        self.loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)
        self.loss_coefficient = 1
        self.feature_loss_coefficient = 1
#         self.loss_fn = keras.losses.sparse_categorical_crossentropy()

    def call(self, prediction, teaching_features=None, sample_weights=None):
        # Compute the training-time loss value and add it
        # to the layer using `self.add_loss()`.
#         print(prediction)
        #loss functions are (True, Prediction)
        loss = self.loss_fn(targets, prediction, sample_weights)
        
        #if loss is a list of additional loss objects
        if isinstance(additional_loss,list):
            for i in range(len(additional_loss)):
                loss += self.loss_fn(targets, additional_loss[i], sample_weights) * self.loss_coefficient
        elif additional_loss is not None:
            loss += self.loss_fn(targets, additional_loss, sample_weights) * self.loss_coefficient
            
        #feature distillation
        if teaching_features is not None:
            diff = tf.math.abs(prediction - teaching_features) * self.feature_loss_coefficient
            loss += self.loss_fn(targets, additional_loss, sample_weights)
            
        
        #TODO might be faster to concatenate all elements together and then perform the loss once on all the elements.
        
        self.add_loss(loss)

        return prediction
    
    
    
class BranchEndpoint(keras.layers.Layer):
    def __init__(self, name=None):
        super(BranchEndpoint, self).__init__(name=name)
        self.loss_fn = keras.losses.SparseCategoricalCrossentropy()
        self.loss_coefficient = 1
        self.feature_loss_coefficient = 1
#         self.loss_fn = keras.losses.sparse_categorical_crossentropy()

    def call(self, prediction, targets, additional_loss=None, student_features=None, teaching_features=None, sample_weights=None):
        # Compute the training-time loss value and add it
        # to the layer using `self.add_loss()`.
        print(prediction, targets, additional_loss)
        #loss functions are (True, Prediction)
        loss = self.loss_fn(targets, prediction, sample_weights)
        print(loss)
        #if loss is a list of additional loss objects
        if isinstance(additional_loss,list):
            for i in range(len(additional_loss)):
                loss += self.loss_fn(targets, additional_loss[i], sample_weights) * self.loss_coefficient
        elif additional_loss is not None:
            loss += self.loss_fn(targets, additional_loss, sample_weights) * self.loss_coefficient
            
        #feature distillation
        if teaching_features is not None and student_features is not None:
            diff = tf.norm(tf.math.abs(student_features - teaching_features)) * self.feature_loss_coefficient
            loss += self.loss_fn(targets, additional_loss, sample_weights)
            
        
        #TODO might be faster to concatenate all elements together and then perform the loss once on all the elements.
        
        self.add_loss(loss)

        return tf.nn.softmax(prediction)
    
    

In [247]:
tf.keras.backend.clear_session()
inputLayer = keras.Input(shape=(784,),name="input")
targets = keras.Input(shape=(1,), name="targets")
x = layers.Flatten(input_shape=(28,28))(inputLayer)
x = layers.Dense(512, activation="relu")(x)
x= layers.Dropout(0.2)(x)
#exit 2
x = layers.Dense(512, activation="relu")(x)
x= layers.Dropout(0.2)(x)
branch1_256 = keras.layers.Dense(256,activation="relu")(x)
branch1_dense = keras.layers.Dense(10)(branch1_256)

#exit 3
x = layers.Dense(512, activation="relu")(x)
x= layers.Dropout(0.2)(x)

branch2_256 = keras.layers.Dense(256,activation="relu")(x)
branch2_dense = keras.layers.Dense(10)(branch2_256)

#exit 4
x = layers.Dense(512, activation="relu")(x)
x= layers.Dropout(0.2)(x)
#exit 5
x = layers.Dense(512, activation="relu")(x)
x= layers.Dropout(0.2)(x)

x_teacher = layers.Dense(256, activation="relu")(x)
x= layers.Dropout(0.2)(x_teacher)
#exit 1 The main branch exit is refered to as "exit 1" or "main exit" to avoid confusion when adding addtional exits
output = layers.Dense(10, name="output")(x)
softmax = layers.Softmax()(output)

#branch ends
branch1_predictions = BranchEndpoint(name="branch1_predictions")(branch1_dense, targets)
branch2_predictions = BranchEndpoint(name="branch2_predictions")(branch2_dense, targets)

outputs =[softmax, branch1_predictions,branch2_predictions]
model = keras.Model(inputs=[inputLayer,targets], outputs=outputs)


Tensor("Placeholder:0", shape=(None, 10), dtype=float32) Tensor("Placeholder_1:0", shape=(None, 1), dtype=float32) None
Tensor("branch1_predictions/sparse_categorical_crossentropy/weighted_loss/value:0", shape=(), dtype=float32)
Tensor("Placeholder:0", shape=(None, 10), dtype=float32) Tensor("Placeholder_1:0", shape=(None, 1), dtype=float32) None
Tensor("branch2_predictions/sparse_categorical_crossentropy/weighted_loss/value:0", shape=(), dtype=float32)


In [248]:
# model.summary()
batch_size = 128
epochs = 2

# for i in train_ds.take(1):
#     print(i)

In [249]:

model.compile(loss="sparse_categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
data = {"input":train_images, "targets":train_labels}
model.fit(train_ds, validation_data = validation_ds, batch_size=32, epochs=epochs )

Epoch 1/2
Tensor("model/dense_6/BiasAdd:0", shape=(128, 10), dtype=float32) Tensor("model/Cast:0", shape=(128, 1), dtype=float32) None
Tensor("model/branch2_predictions/sparse_categorical_crossentropy/weighted_loss/value:0", shape=(), dtype=float32)
Tensor("model/dense_3/BiasAdd:0", shape=(128, 10), dtype=float32) Tensor("model/Cast:0", shape=(128, 1), dtype=float32) None
Tensor("model/branch1_predictions/sparse_categorical_crossentropy/weighted_loss/value:0", shape=(), dtype=float32)
Tensor("model/dense_6/BiasAdd:0", shape=(128, 10), dtype=float32) Tensor("model/Cast:0", shape=(128, 1), dtype=float32) None
Tensor("model/branch2_predictions/sparse_categorical_crossentropy/weighted_loss/value:0", shape=(), dtype=float32)
Tensor("model/dense_3/BiasAdd:0", shape=(128, 10), dtype=float32) Tensor("model/Cast:0", shape=(128, 1), dtype=float32) None
Tensor("model/branch1_predictions/sparse_categorical_crossentropy/weighted_loss/value:0", shape=(), dtype=float32)
Tensor("model/branch2_predicti

<tensorflow.python.keras.callbacks.History at 0x1d32d638f28>

In [None]:
score = model.evaluate(x_test, y_test, verbose=0)
print("Test loss:", score[0])
print("Test accuracy:", score[1])

375/375 [==============================] - 10s 27ms/step - loss: 1.7662 - softmax_loss: 0.1891 - branch1_predictions_loss: 0.1604 - softmax_accuracy: 0.9477 - branch1_predictions_accuracy: 0.9510 - val_loss: 1.2677 - val_softmax_loss: 0.1185 - val_branch1_predictions_loss: 0.1081 - val_softmax_accuracy: 0.9671 - val_branch1_predictions_accuracy: 0.9670

With Feature Training:
375/375 [==============================] - 9s 24ms/step - loss: 3.3208 - softmax_loss: 0.2790 - branch1_predictions_loss: 0.4166 - softmax_accuracy: 0.9220 - branch1_predictions_accuracy: 0.9117 - val_loss: 2.7620 - val_softmax_loss: 0.1886 - val_branch1_predictions_loss: 0.2611 - val_softmax_accuracy: 0.9469 - val_branch1_predictions_accuracy: 0.9367