In [80]:
import tensorflow as tf
from tensorflow import keras
import math

In [408]:
class LogisticEndpoint(keras.layers.Layer):
    def __init__(self, name=None):
        super(LogisticEndpoint, self).__init__(name=name)
        self.loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)
        self.accuracy_fn = keras.metrics.BinaryAccuracy()

        
#     get the three losses:
#     1. normal loss, predictions vs labels
#     2. KL divergence loss, aka the difference between the student and teacher's softmax
#     3. L2 loss from hints.
    def call(self, inputs, labels, sample_weights=None):
        # Compute the training-time loss value and add it
        # to the layer using `self.add_loss()`.
        normal_loss = self.loss_fn(inputs, labels, sample_weights)
        self.add_loss(normal_loss)
#         self.add_metric(normal_loss, name=self.name+"_KL")
#         self.add_metric(normal_loss, name=self.name+"loss")

        # Log accuracy as a metric and add it
        # to the layer using `self.add_metric()`.
#         acc = self.accuracy_fn(inputs, labels, sample_weights)
#         self.add_metric(acc, name="accuracy")

        # Return the inference-time prediction tensor (for `.predict()`).
        return tf.nn.softmax(inputs)
    
class BranchEndpoint(keras.layers.Layer):
    def __init__(self, name=None):
        super(BranchEndpoint, self).__init__(name=name)
        self.loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)
        self.loss_coefficient = 1
        self.feature_loss_coefficient = 1
        self.kl = tf.keras.losses.KLDivergence()
#         self.loss_fn = keras.losses.sparse_categorical_crossentropy

    def call(self, inputs, labels, teacher_sm=None, sample_weights=None):
        # Compute the training-time loss value and add it
        # to the layer using `self.add_loss()`.
        #loss functions are (True, Prediction)
        softmax = tf.nn.softmax(inputs)
        
        #loss 1. normal loss, predictions vs labels
#         normal_loss = self.loss_fn(inputs, labels, sample_weights)
#         self.add_loss(normal_loss)
        
        #loss 2. KL divergence loss, aka the difference between the student and teacher's softmax
        if teacher_sm is not None:
            kl_loss = self.kl(softmax,teacher_sm)
            self.add_loss(kl_loss)
            self.add_metric(kl_loss, name=self.name+"_KL")
        return softmax
    

class FeatureDistillation(keras.layers.Layer):
    def __init__(self, name=None):
        super(FeatureDistillation, self).__init__(name=name)
        self.loss_coefficient = 1
        self.feature_loss_coefficient = 0.3
    def call(self, inputs, teaching_features, sample_weights=None):
        #loss 3. Feature distillation of the difference between features of teaching layer and student layer.
        l2_loss = self.feature_loss_coefficient * tf.reduce_sum(tf.square(inputs - teaching_features))
        self.add_loss(l2_loss)
        self.add_metric(l2_loss, name=self.name+"_distill") # metric so this loss value can be monitored.
        return inputs

In [409]:
def logit(num, count =1,classes=10):
    output = np.zeros(classes) 
    pos = max(0,num-1)
    output[pos] = 1
    
    return output



targets =[]

for i in range(30):
    target = logit(np.random.randint(0,10))
    targets.append(target)
targets = np.array(targets)
data = {
    "inputs": np.random.random((30, 3)),
    "targets": targets,
}


In [410]:
#process for self distilation
#add y_true as an input for the model, here called 'targets'. targets is not linked to the main model path
#targets is added as input at the model define call
#targets is used as an additional input to the endpoint layers
#in endpoint layers, perform the loss function using the prev_layer input and the 'targets'

#determine if the additional loss is precomputed or computed at the endpoint layer.

from tensorflow.keras import layers

import numpy as np 
inputs = keras.Input(shape=(3,), name="inputs")
targets = keras.Input(shape=(10,), name="targets")
x = layers.Dense(512, activation="relu")(inputs)
x= layers.Dropout(0.2)(x)

x = layers.Dense(512, activation="relu")(x)
x= layers.Dropout(0.2)(x)

branch1_256 = keras.layers.Dense(256,activation="relu")(x)
# print(branch1_256.shape)


x = layers.Dense(512, activation="relu")(x)
x= layers.Dropout(0.2)(x)

# branch2_256 = keras.layers.Dense(256,activation="relu")(x)

x = layers.Dense(512, activation="relu")(x)
x= layers.Dropout(0.2)(x)

teaching_feat = layers.Dense(256, activation="relu")(x)
# teacher_feat = featureDistil(x)
# x= layers.Dropout(0.2)(teaching_feat)

output = layers.Dense(10, name="output")(teaching_feat)

# softmax = layers.Softmax()(output)
endpoint = LogisticEndpoint(name="endpoint")(output,targets)

#rest of branches
branch1_teaching = FeatureDistillation(name="branch1_teaching")(branch1_256,teaching_feat)
branch1_dense = keras.layers.Dense(10,name="branch1_dense")(branch1_teaching)
branch1_predictions = BranchEndpoint(name="branch1_predictions")(branch1_dense, targets, endpoint)


# branch2_teaching = FeatureDistillation(name="branch2_teaching")(branch2_256,teaching_feat)
# branch2_dense = keras.layers.Dense(10)(branch2_teaching)
# branch2_predictions = BranchEndpoint(name="branch2_predictions")(branch2_dense, targets, softmax, [branch2_256], teaching_feat)


model = keras.Model(inputs=[inputs, targets], outputs=[endpoint,branch1_predictions])
model.compile(optimizer="adam", loss =keras.losses.BinaryCrossentropy(from_logits=True))

targets =[]

for i in range(30):
    target = logit(np.random.randint(0,10))
    targets.append(target)
targets = np.array(targets)
data = {
    "inputs": np.random.random((30, 3)),
    "targets": targets,
}

# print(data['targets'])
hist = model.fit(data,targets)


# The total loss is different from parts_loss because it includes the regularization term.
# In other words, loss is computed as loss = parts_loss + k*R, where R is the regularization term 
# (typically the L1 or L2 norm of the model's weights) and k a hyperparameter that controls the 
# contribution of the regularization loss in the total loss.




 1s 1s/step - loss: 5.3907 - endpoint_loss: 0.7344 - branch1_predictions_loss: 0.7345 - binary_accuracy: 0.0000e+00
 1s 1s/step - loss: 2.9807 - endpoint_loss: 0.7344 - branch1_predictions_loss: 0.7345 - binary_accuracy: 0.0000e+00




In [413]:
x = 1
y = 2 
z = 3
data = (x,y,z)
data
a, b, c = data
print(a, b, c)

1 2 3


In [282]:
inputs = keras.Input(shape=(227,227,3))
x = keras.layers.Conv2D(filters=96, kernel_size=(11,11), strides=(4,4), activation='relu', input_shape=(227,227,3))(inputs)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.MaxPool2D(pool_size=(3,3), strides=(2,2))(x)
x = keras.layers.Conv2D(filters=256, kernel_size=(5,5), strides=(1,1), activation='relu', padding="same")(x)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.MaxPool2D(pool_size=(3,3), strides=(2,2))(x)
x = keras.layers.Conv2D(filters=384, kernel_size=(3,3), strides=(1,1), activation='relu', padding="same")(x)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.Conv2D(filters=384, kernel_size=(1,1), strides=(1,1), activation='relu', padding="same")(x)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.Conv2D(filters=256, kernel_size=(1,1), strides=(1,1), activation='relu', padding="same")(x)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.MaxPool2D(pool_size=(3,3), strides=(2,2))(x)
x = keras.layers.Flatten()(x)
x = keras.layers.Dense(4096, activation='relu')(x)
x = keras.layers.Dropout(0.5)(x)
x = keras.layers.Dense(4096, activation='relu')(x)
x = keras.layers.Dropout(0.5)(x)
x = keras.layers.Dense(10, activation='softmax')(x)

model = keras.Model(inputs=inputs, outputs=[x], name="alexnet")
model.compile(loss='sparse_categorical_crossentropy', optimizer=tf.optimizers.SGD(lr=0.001,momentum=0.9), metrics=['accuracy'])
    
model.summary()

Model: "alexnet"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_3 (InputLayer)         [(None, 227, 227, 3)]     0         
_________________________________________________________________
conv2d_10 (Conv2D)           (None, 55, 55, 96)        34944     
_________________________________________________________________
batch_normalization_10 (Batc (None, 55, 55, 96)        384       
_________________________________________________________________
max_pooling2d_6 (MaxPooling2 (None, 27, 27, 96)        0         
_________________________________________________________________
conv2d_11 (Conv2D)           (None, 27, 27, 256)       614656    
_________________________________________________________________
batch_normalization_11 (Batc (None, 27, 27, 256)       1024      
_________________________________________________________________
max_pooling2d_7 (MaxPooling2 (None, 13, 13, 256)       0   

In [1]:
x = 10 //2
x

5

In [15]:
model.layers[1].output_shape

(None, 55, 55, 96)

In [13]:
model.layers[1].input_shape

(None, 227, 227, 3)