In [1]:
import tensorflow as tf


class CenteredLayer(tf.keras.Model):
    """Parameter-free layer that shifts its input to zero mean.

    One scalar mean is computed over *every* element of ``inputs`` (all
    axes, the batch axis included) and subtracted element-wise.
    """

    def __init__(self):
        super(CenteredLayer, self).__init__()

    def call(self, inputs):
        # tf.reduce_mean returns a result in the input's dtype: for integer
        # tensors the mean is integer-valued, so the output is only exactly
        # zero-mean when the true mean is itself an integer.
        global_mean = tf.reduce_mean(inputs)
        centered = inputs - global_mean
        return centered

In [2]:
# Apply the parameter-free layer to a small integer vector; the displayed
# result is the input minus its (integer) mean.
layer = CenteredLayer()
layer(tf.constant([1, 2, 3, 4, 5]))

<tf.Tensor: shape=(5,), dtype=int32, numpy=array([-2, -1,  0,  1,  2])>

In [3]:
# The custom layer composes like a built-in one: here it centers the
# output of a 128-unit Dense layer.
net = tf.keras.Sequential([tf.keras.layers.Dense(128), CenteredLayer()])

In [4]:
# Sanity check: after centering, the mean of the network output should be
# zero up to floating-point rounding (hence a tiny nonzero value).
Y = net(tf.random.uniform((4, 8)))
tf.reduce_mean(Y)

<tf.Tensor: shape=(), dtype=float32, numpy=4.656613e-10>

In [5]:
class MyDense(tf.keras.Model):
    """Fully connected layer computing ``activation(X @ weight + bias)``.

    The input feature size is not needed at construction time: ``weight``
    and ``bias`` are created in ``build`` when the layer is first called.

    Args:
        units: Number of output features.
        activation: Callable applied to the affine output. Defaults to
            ``tf.nn.relu`` (the original behavior). Pass ``None`` for a
            purely linear layer, e.g. a regression output layer.
    """

    def __init__(self, units, activation=tf.nn.relu):
        super().__init__()
        self.units = units
        self.activation = activation

    def build(self, X_shape):
        """Create ``weight`` (in_features, units) and ``bias`` (units,).

        Raises:
            ValueError: if the last (feature) dimension of the input is
                unknown, since the weight matrix cannot be sized.
        """
        in_features = X_shape[-1]
        if in_features is None:
            raise ValueError(
                'The last dimension of the input to MyDense must be '
                'defined; got shape {}.'.format(X_shape))
        self.weight = self.add_weight(name='weight',
            shape=[in_features, self.units],
            initializer=tf.random_normal_initializer())
        # Bias starts at zero so the initial output depends only on weight.
        self.bias = self.add_weight(
            name='bias', shape=[self.units],
            initializer=tf.zeros_initializer())

    def call(self, X):
        # assumes X is (batch, in_features), as in the cells below — other
        # ranks are not checked here.
        linear = tf.matmul(X, self.weight) + self.bias
        if self.activation is None:
            return linear
        return self.activation(linear)

In [6]:
# Weights are created lazily in build(), so call the layer once on a
# (2, 5) input before inspecting them: weight is (5, 3) and bias is (3,).
dense = MyDense(3)
dense(tf.random.uniform((2, 5)))
dense.get_weights()

[array([[ 0.00817996,  0.03145441, -0.04453003],
        [ 0.01324688, -0.08035804, -0.05214444],
        [-0.03046067, -0.03616345, -0.00811583],
        [-0.05690337,  0.04095849, -0.05580854],
        [-0.02450579,  0.01157974,  0.03508516]], dtype=float32),
 array([0., 0., 0.], dtype=float32)]

In [7]:
# Forward pass on a fresh random batch using the already-built weights.
dense(tf.random.uniform((2, 5)))

<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[0.        , 0.04356256, 0.        ],
       [0.        , 0.04003603, 0.        ]], dtype=float32)>

In [8]:
# Stack custom layers; each sizes its weights on the first call
# (64 -> 8 -> 1). NOTE: this rebinds `net` from In [3]. The final
# MyDense also applies ReLU, so these outputs are non-negative.
net = tf.keras.models.Sequential([MyDense(8), MyDense(1)])
net(tf.random.uniform((2, 64)))

<tf.Tensor: shape=(2, 1), dtype=float32, numpy=
array([[0.02140534],
       [0.02268417]], dtype=float32)>