# In [1]:
import tensorflow as tf

# In [2]:
class MyDense(tf.keras.layers.Layer):
    """Fully connected layer computing ``activation(x @ kernel + bias)``.

    A minimal hand-written counterpart of ``tf.keras.layers.Dense``.

    Args:
        units: Number of output features; the kernel has ``units`` columns
            and the bias has ``units`` entries.
        activation: Activation identifier or callable, resolved through
            ``tf.keras.activations.get``. ``None`` resolves to the linear
            (identity) activation.
        **kwargs: Forwarded unchanged to ``tf.keras.layers.Layer``
            (e.g. ``name``, ``dtype``).
    """

    def __init__(self, units, activation=None, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.activation = tf.keras.activations.get(activation)

    def build(self, batch_input_shape):
        """Create the kernel and bias weights once the input width is known.

        Args:
            batch_input_shape: Shape of the first input batch; only its last
                dimension (the number of input features) is used here.
        """
        self.kernel = self.add_weight(
            name="kernel",
            shape=[batch_input_shape[-1], self.units],
            initializer="glorot_normal",
        )
        self.bias = self.add_weight(
            name="bias",
            shape=[self.units],
            initializer="zeros",
        )
        # Let the base class finish building (marks the layer as built).
        super().build(batch_input_shape)

    def call(self, x):
        """Apply the affine transform, then the configured activation."""
        return self.activation(x @ self.kernel + self.bias)

    def get_config(self):
        """Return the constructor arguments needed to re-create this layer.

        Keras calls ``get_config`` when saving or cloning a model. The default
        ``from_config`` rebuilds the layer by passing this dict back to
        ``__init__``, so every constructor argument must appear here.
        """
        base_config = super().get_config()
        return {
            **base_config,
            "units": self.units,
            "activation": tf.keras.activations.serialize(self.activation),
        }

    def base_config(self):
        """Backward-compatible alias for :meth:`get_config`.

        This was the method's original (misnamed) name, which Keras never
        called. It is kept so that existing direct callers keep working.
        """
        return self.get_config()