In [66]:
import tensorflow as tf
from tensorflow import keras
import numpy as np
real_type = tf.float32

In [67]:
class ForwardLayer(keras.layers.Layer):
    def __init__(self, units, activation='softplus', **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.activation = keras.activations.get(activation)

    def build(self, batch_input_shape):
        self.w = self.add_weight(
            name = 'weights',
            shape=(batch_input_shape[-1], self.units),
            initializer="glorot_normal",
            trainable=True
        )
        self.b = self.add_weight(
            name = 'bias',
            shape=(self.units,), 
            initializer="zeros", 
            trainable=True
        )
        super().build(batch_input_shape)

    def call(self, inputs):
        return self.activation(tf.matmul(inputs, self.w) + self.b)

class BackpropLayer(keras.layers.Layer):
    def __init__(self, units, twin: ForwardLayer, activation='sigmoid', **kwargs):
        super().__init__(**kwargs)
        self.twin = twin
        self.units = self.twin.units
        self.activation = keras.activations.get(activation)
        self.built = twin.built

    def build(self, batch_input_shape):
        if not self.built:
            self.twin.build(batch_input_shape)
        self.built = True
        super().build(batch_input_shape)

    def call(self, inputs, zbar, last=False):
        if last:
            return tf.matmul(zbar, tf.transpose(self.twin.w))
        else:
            return tf.matmul(zbar, tf.transpose(self.twin.w)) * self.activation(inputs)

class DeepNeuralNet(keras.Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.forward1 = ForwardLayer(units=5)
        self.forward2 = ForwardLayer(units=3)
        self.forward3 = ForwardLayer(units=1)
        self.backprop1 = BackpropLayer(twin=self.forward3)
        self.backprop2 = BackpropLayer(twin=self.forward2)
        self.backprop3 = BackpropLayer(twin=self.forward1)

    def call(self, inputs):
        Z1 = self.forward1(inputs)
        Z2 = self.forward2(Z1)
        Y = self.forward3(Z2)
        Z3_bar = tf.ones_like(Y, dtype=real_type)
        Z2_bar = self.backprop1(inputs=Z2, zbar=Z3_bar)
        Z1_bar = self.backprop2(inputs=Z1, zbar=Z2_bar)
        Y_bar = self.backprop3(inputs=Z1, zbar=Z1_bar, last=True)
        return Y, Y_bar

In [81]:
model = DeepNeuralNet()
model.build(input_shape=(3,3))
model.summary()

Model: "deep_neural_net_18"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 forward_layer_58 (ForwardLa  multiple                 30        
 yer)                                                            
                                                                 
 forward_layer_59 (ForwardLa  multiple                 18        
 yer)                                                            
                                                                 
 forward_layer_60 (ForwardLa  multiple                 4         
 yer)                                                            
                                                                 
 backprop_layer_58 (Backprop  multiple                 4         
 Layer)                                                          
                                                                 
 backprop_layer_59 (Backprop  multiple          

In [69]:
testf = ForwardLayer(units=3)
testb = BackpropLayer(twin=testf)

In [70]:
x = tf.ones((2,5))
testf(x)

<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[0.15079384, 0.6294772 , 1.6144779 ],
       [0.15079384, 0.6294772 , 1.6144779 ]], dtype=float32)>

In [71]:
testf.weights

[<tf.Variable 'forward_layer_33/weights:0' shape=(5, 3) dtype=float32, numpy=
 array([[-0.3795917 ,  0.47555545,  0.4316404 ],
        [-0.35289612,  0.18576394, -0.05235577],
        [-0.9895086 , -0.12197417, -0.7805339 ],
        [-0.07881267, -0.15414277,  1.1322336 ],
        [-0.01468835, -0.5168736 ,  0.6616059 ]], dtype=float32)>,
 <tf.Variable 'forward_layer_33/bias:0' shape=(3,) dtype=float32, numpy=array([0., 0., 0.], dtype=float32)>]

In [72]:
testb.weights

[<tf.Variable 'forward_layer_33/weights:0' shape=(5, 3) dtype=float32, numpy=
 array([[-0.3795917 ,  0.47555545,  0.4316404 ],
        [-0.35289612,  0.18576394, -0.05235577],
        [-0.9895086 , -0.12197417, -0.7805339 ],
        [-0.07881267, -0.15414277,  1.1322336 ],
        [-0.01468835, -0.5168736 ,  0.6616059 ]], dtype=float32)>,
 <tf.Variable 'forward_layer_33/bias:0' shape=(3,) dtype=float32, numpy=array([0., 0., 0.], dtype=float32)>]