In [None]:
import tensorflow as tf
from tensorflow import keras

class DumbLayer(keras.layers.Layer):
  """Toy layer that concatenates its input with itself along the last axis.

  The layer creates no weights; its output's last axis is twice as long as
  the input's.

  Args:
    **kwargs: Standard `keras.layers.Layer` keyword arguments such as
      `name` or `dtype`.
  """

  def __init__(self, **kwargs):
    # Forward the base-class keyword arguments so the layer can be given a
    # name and can be re-created from the config Keras produces for it.
    super().__init__(**kwargs)

  def call(self, inputs):
    """Return `inputs` joined with a copy of itself on axis -1."""
    return tf.concat([inputs, inputs], axis=-1)

# Wrap the layer in a minimal functional model so that the summary shows
# what it does to a 14-feature input.
inputs = keras.Input(shape=(14,))
dumb_layer = DumbLayer()
x = dumb_layer(inputs)
model = keras.Model(inputs=inputs, outputs=x)

model.summary()

In [None]:
class CustomDenseLayer(keras.layers.Layer):
  """Dense layer without activation: computes `matmul(inputs, w) + b`.

  Args:
    units: Size of the last axis of the output. Defaults to 32.
    **kwargs: Standard `keras.layers.Layer` keyword arguments such as
      `name` or `dtype`.
  """

  def __init__(self, units=32, **kwargs):
    # Forward the base-class keyword arguments so the layer can be given a
    # name and can be re-created from the dict returned by `get_config`.
    super().__init__(**kwargs)
    self.units = units

  # create the weights of the model based on the expected shape
  # only invoked once the input shape is known
  def build(self, input_shape):
    """Create the kernel `w` and bias `b` from the last input dimension."""
    # visual for weights: https://claude.site/artifacts/49da6081-3ce1-4b3e-910f-2c4157af3451
    self.w = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer='random_normal',
            trainable=True
        )
    self.b = self.add_weight(
            shape=(self.units,),
            initializer='zeros',
            trainable=True
        )

  # perform the actual operation of the layer.
  # in this case it's just the weighted sum
  def call(self, inputs):
    """Return the weighted sum of `inputs` plus the bias."""
    return tf.matmul(inputs, self.w) + self.b

  def get_config(self):
    """Return the base layer config extended with `units`."""
    config = super().get_config()
    config.update({'units': self.units})
    return config

# Build a functional model from the custom layer:
# 784 inputs -> 32 units -> ReLU -> 10 units.
inputs = keras.Input(shape=(784,))
hidden_layer = CustomDenseLayer(units=32)
x = hidden_layer(inputs)
relu = keras.layers.Activation('relu')
x = relu(x)
output_layer = CustomDenseLayer(units=10)
outputs = output_layer(x)
model = keras.Model(inputs=inputs, outputs=outputs)

model.summary()

## Subclassing Models
Subclassing `keras.Model` is useful when you need even greater control over the model architecture. You can write arbitrary Python code in the model's code, for example conditional logic, to support advanced use cases.

In [None]:
class SimpleFeedForwardModel(keras.Model):
  """Feed-forward model: two 32-unit ReLU Dense layers, then a 1-unit Dense output.

  Args:
    **kwargs: Standard `keras.Model` keyword arguments such as `name`.
  """

  # used to define layers and variable of the model
  def __init__(self, **kwargs):
    # Forward the base-class keyword arguments so the model can be given a
    # name and can be re-created from the config Keras produces for it.
    super().__init__(**kwargs)
    self.dense1 = keras.layers.Dense(32, activation='relu')
    self.dense2 = keras.layers.Dense(32, activation='relu')
    self.output_layer = keras.layers.Dense(1)

  # the forward pass of the model
  def call(self, inputs):
    """Run `inputs` through both hidden layers and the output layer."""
    hidden = self.dense1(inputs)
    hidden = self.dense2(hidden)
    return self.output_layer(hidden)

# Instantiate the subclassed model and configure training with the Adam
# optimizer, mean-squared-error loss, and mean-absolute-error as a metric.
model = SimpleFeedForwardModel()
model.compile(
    optimizer='adam',
    loss='mse',
    metrics=['mae'],
)