In [1]:
import tensorflow as tf
from tensorflow import keras
import math

In [2]:
class LogisticEndpoint(keras.layers.Layer):
    def __init__(self, name=None):
        super(LogisticEndpoint, self).__init__(name=name)
        self.loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)
        self.accuracy_fn = keras.metrics.BinaryAccuracy()

    def call(self, targets, logits, sample_weights=None):
        # Compute the training-time loss value and add it
        # to the layer using `self.add_loss()`.
        loss = self.loss_fn(targets, logits, sample_weights)
        self.add_loss(loss)

        # Log accuracy as a metric and add it
        # to the layer using `self.add_metric()`.
        acc = self.accuracy_fn(targets, logits, sample_weights)
        self.add_metric(acc, name="accuracy")

        # Return the inference-time prediction tensor (for `.predict()`).
        return tf.nn.softmax(logits)
    
class BranchEndpoint(keras.layers.Layer):
    def __init__(self, name=None):
        super(BranchEndpoint, self).__init__(name=name)
        self.loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)
        self.accuracy_fn = keras.metrics.BinaryAccuracy()

    def call(self, targets, logits, additionalLoss, sample_weights=None):
        # Compute the training-time loss value and add it
        # to the layer using `self.add_loss()`.
        loss = self.loss_fn(targets, logits, sample_weights)
        
        loss += self.loss_fn(additionalLoss, logits, sample_weights)
        self.add_loss(loss)

        # Log accuracy as a metric and add it
        # to the layer using `self.add_metric()`.
        acc = self.accuracy_fn(targets, logits, sample_weights)
        self.add_metric(acc, name="accuracy")

        # Return the inference-time prediction tensor (for `.predict()`).
        return tf.nn.softmax(logits)

In [3]:
layer = LogisticEndpoint()

targets = tf.fill((4, 1),.4)
logits = tf.ones((4, 1))
# print(logits)


y = layer(targets, logits)
print(y)


print("layer.metrics:", layer.metrics)
print("current accuracy value:", float(layer.metrics[0].result()))

tf.Tensor(
[[1.]
 [1.]
 [1.]
 [1.]], shape=(4, 1), dtype=float32)
layer.metrics: [<tensorflow.python.keras.metrics.BinaryAccuracy object at 0x0000029F4E1FD080>]
current accuracy value: 0.0


In [4]:
import numpy as np 
inputs = keras.Input(shape=(3,), name="inputs")
targets = keras.Input(shape=(10,), name="targets")
logits = keras.layers.Dense(10)(inputs)
predictions = LogisticEndpoint(name="predictions")(logits, targets)

branch_logits = keras.layers.Dense(10)(logits)
branch_predictions = BranchEndpoint(name="branch_predictions")(branch_logits, targets, predictions)

model = keras.Model(inputs=[inputs, targets], outputs=[predictions,branch_predictions])
model.compile(optimizer="adam")

targets = np.random.random((3, 10))
data = {
    "inputs": np.random.random((3, 3)),
    "targets": targets,
}
model.fit(data,targets)
print(model.outputs)

[<KerasTensor: shape=(None, 10) dtype=float32 (created by layer 'predictions')>, <KerasTensor: shape=(None, 10) dtype=float32 (created by layer 'branch_predictions')>]


In [28]:
inputs = keras.Input(shape=(227,227,3))
x = keras.layers.Conv2D(filters=96, kernel_size=(11,11), strides=(4,4), activation='relu', input_shape=(227,227,3))(inputs)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.MaxPool2D(pool_size=(3,3), strides=(2,2))(x)
x = keras.layers.Conv2D(filters=256, kernel_size=(5,5), strides=(1,1), activation='relu', padding="same")(x)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.MaxPool2D(pool_size=(3,3), strides=(2,2))(x)
x = keras.layers.Conv2D(filters=384, kernel_size=(3,3), strides=(1,1), activation='relu', padding="same")(x)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.Conv2D(filters=384, kernel_size=(1,1), strides=(1,1), activation='relu', padding="same")(x)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.Conv2D(filters=256, kernel_size=(1,1), strides=(1,1), activation='relu', padding="same")(x)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.MaxPool2D(pool_size=(3,3), strides=(2,2))(x)
x = keras.layers.Flatten()(x)
x = keras.layers.Dense(4096, activation='relu')(x)
x = keras.layers.Dropout(0.5)(x)
x = keras.layers.Dense(4096, activation='relu')(x)
x = keras.layers.Dropout(0.5)(x)
x = keras.layers.Dense(10, activation='softmax')(x)

model = keras.Model(inputs=inputs, outputs=[x], name="alexnet")
model.compile(loss='sparse_categorical_crossentropy', optimizer=tf.optimizers.SGD(lr=0.001,momentum=0.9), metrics=['accuracy'])
    
model.summary()

Model: "alexnet"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_4 (InputLayer)         [(None, 227, 227, 3)]     0         
_________________________________________________________________
conv2d_16 (Conv2D)           (None, 55, 55, 96)        34944     
_________________________________________________________________
batch_normalization_15 (Batc (None, 55, 55, 96)        384       
_________________________________________________________________
max_pooling2d_9 (MaxPooling2 (None, 27, 27, 96)        0         
_________________________________________________________________
conv2d_17 (Conv2D)           (None, 27, 27, 256)       614656    
_________________________________________________________________
batch_normalization_16 (Batc (None, 27, 27, 256)       1024      
_________________________________________________________________
max_pooling2d_10 (MaxPooling (None, 13, 13, 256)       0   