In [1]:
import tensorflow as tf


class CenteredLayer(tf.keras.Model):
    """Parameter-free block that shifts its input so the result has zero mean.

    The mean is one scalar taken over every element of the input
    (``tf.reduce_mean`` with no ``axis``). It is subtracted by broadcasting,
    so the output has the same shape and dtype as the input.

    Caveat: ``tf.reduce_mean`` keeps the input dtype. For integer tensors
    the mean is therefore itself an integer (e.g. the mean of ``[1, 0, 1, 0]``
    comes out as ``0``). The output is only exactly centered when the true
    mean is an integer, as it is for ``[1, 2, 3, 4, 5]``.
    """

    def __init__(self):
        super(CenteredLayer, self).__init__()

    def call(self, inputs):
        # A single scalar for the whole tensor; broadcasting handles any shape.
        overall_mean = tf.reduce_mean(inputs)
        return inputs - overall_mean

In [2]:
# Sanity check on integers whose mean (3) is exact: the result should be the
# input shifted by -3 and should keep the int32 dtype of the input.
layer = CenteredLayer()
layer(tf.constant([1, 2, 3, 4, 5]))

<tf.Tensor: shape=(5,), dtype=int32, numpy=array([-2, -1,  0,  1,  2])>

In [3]:
# The custom layer composes like any built-in one: center the 128 outputs of a
# standard Dense layer.
net = tf.keras.Sequential([tf.keras.layers.Dense(128), CenteredLayer()])

In [4]:
# CenteredLayer subtracts the overall mean, so the network output should
# average to zero up to float32 round-off. Expect a tiny value such as 3e-09,
# not an exact 0. Check that explicitly instead of only eyeballing the number.
# NOTE(review): no random seed is set, so the displayed value changes per run.
y = net(tf.random.uniform((4, 8)))
y_mean = tf.reduce_mean(y)
tf.debugging.assert_near(y_mean, 0.0, atol=1e-5)
y_mean

<tf.Tensor: shape=(), dtype=float32, numpy=3.0267984e-09>

In [5]:
class MyDense(tf.keras.Model):
    """Fully connected block computing ``activation(x @ weight + bias)``.

    Weights are created lazily in ``build`` from the last dimension of the
    first input, so the input width never has to be given up front.

    Args:
        units: Number of output features (columns of ``weight``).
        activation: Function applied to the affine output. Anything accepted
            by ``tf.keras.activations.get`` works: a callable, a name such as
            ``"relu"``, or ``None`` for the identity. ``None`` gives a purely
            linear layer, e.g. an output layer whose values may be negative.
            Defaults to ``tf.nn.relu``, the behaviour this block always had.
        **kwargs: Forwarded to the Keras base class (e.g. ``name``).
    """

    # NOTE(review): a reusable building block like this would normally
    # subclass tf.keras.layers.Layer. tf.keras.Model is kept so that any
    # Model-only API that callers rely on (summary, fit, save, ...) stays
    # available.

    def __init__(self, units, activation=tf.nn.relu, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        # activations.get returns callables unchanged, maps names to Keras
        # activations and maps None to the identity.
        self.activation = tf.keras.activations.get(activation)

    def build(self, x_shape):
        """Create ``weight`` (in_features, units) and a zero ``bias`` (units,).

        ``in_features`` is read from the last entry of ``x_shape``.
        """
        self.weight = self.add_weight(name='weight',
            shape=[x_shape[-1], self.units],
            initializer=tf.random_normal_initializer())
        self.bias = self.add_weight(name='bias',
            shape=[self.units],
            initializer=tf.zeros_initializer())

    def call(self, x):
        """Return ``activation(x @ weight + bias)``.

        ``x`` is presumably (batch, in_features); its last dimension must
        match the first dimension of ``weight`` for ``tf.matmul``.
        """
        linear = tf.matmul(x, self.weight) + self.bias
        return self.activation(linear)

In [6]:
# MyDense creates its weights in build(), which needs the input shape. Call
# the layer once on a (2, 5) batch so the weights exist before inspecting them.
dense = MyDense(3)
dense(tf.random.uniform((2, 5)))
# Expect a (5, 3) weight matrix (input width 5, 3 units) and a zero bias of
# length 3 (zeros_initializer).
dense.get_weights()

[array([[ 0.08048212, -0.00303731, -0.04358453],
        [ 0.02270886,  0.04479399, -0.0508699 ],
        [ 0.08357178,  0.04450287,  0.01045294],
        [ 0.04281571,  0.01404852, -0.10602387],
        [ 0.05098827, -0.01931676, -0.00984185]], dtype=float32),
 array([0., 0., 0.], dtype=float32)]

In [7]:
# Forward pass on a fresh batch. MyDense.call ends in a ReLU, so any unit whose
# affine output is negative shows up as an exact 0.
dense(tf.random.uniform((2, 5)))

<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[0.10378342, 0.02193992, 0.        ],
       [0.13338718, 0.02882155, 0.        ]], dtype=float32)>

In [8]:
# Stack two custom layers. The first call builds MyDense(8) for 64 input
# features and MyDense(1) for the 8 features produced by the first layer.
# NOTE(review): this rebinds `net`, shadowing the CenteredLayer model from
# In [3].
# Both MyDense layers here end in a ReLU, including the single-unit output
# layer. Negative pre-activations therefore come out as exactly 0, which is
# why this output can be all zeros.
net = tf.keras.models.Sequential([MyDense(8), MyDense(1)])
net(tf.random.uniform((2, 64)))

<tf.Tensor: shape=(2, 1), dtype=float32, numpy=
array([[0.],
       [0.]], dtype=float32)>