## Custom Model Creation with Keras

This notebook shows how to keep using the model's `fit()` method with a custom training procedure. Staying with `fit()` lets us keep its conveniences, such as training callbacks, built-in loss handling, and more. To do so, we override the `train_step()` method of the `Model` class. This is the function that `fit()` calls for every batch of data.
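To see why overriding a single method is enough, here is a minimal sketch of the control flow in plain Python. `MiniModel` and `CountingModel` are hypothetical toy classes, not Keras code; the real `fit()` also takes care of callbacks, shuffling, and much more.

```python
class MiniModel:
    """Toy stand-in for a model with a fit() loop (not the Keras class)."""

    def train_step(self, data):
        # Default per-batch behaviour: do nothing and report a constant.
        return {"loss": 0.0}

    def fit(self, batches, epochs=1):
        history = []
        for _ in range(epochs):
            logs = {}
            for data in batches:
                # fit() delegates the per-batch work to train_step().
                logs = self.train_step(data)
            # The logs returned for the last batch are reported for the epoch.
            history.append(logs)
        return history


class CountingModel(MiniModel):
    """Overrides only train_step(); fit() is inherited unchanged."""

    def __init__(self):
        self.steps = 0

    def train_step(self, data):
        self.steps += 1
        return {"loss": 1.0 / self.steps}


model = CountingModel()
history = model.fit(batches=[("x0", "y0"), ("x1", "y1")], epochs=3)
print(model.steps)  # 2 batches x 3 epochs = 6 calls to train_step()
```

The subclass never touches the loop itself, which is the same division of labour we rely on in the rest of this notebook.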

The `data` argument of `train_step()` is whatever gets passed to `fit()` as training data:

* If you pass NumPy arrays by calling `fit(x, y, ...)`, then `data` will be the tuple `(x, y)`.
* If you pass a `tf.data.Dataset` by calling `fit(dataset, ...)`, then `data` will be whatever the dataset yields at each batch.
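As a rough illustration of the first case, the sketch below slices NumPy arrays into `(x, y)` batches, which is the shape of `data` that `train_step()` would see. `iter_batches` is a hypothetical helper, and the batch size of 32 is assumed here to match the default of `fit()`; the real input pipeline also shuffles the samples.

```python
import numpy as np


def iter_batches(x, y, batch_size=32):
    """Yield (x_batch, y_batch) tuples, one per training step."""
    for start in range(0, len(x), batch_size):
        yield x[start:start + batch_size], y[start:start + batch_size]


x = np.random.random((1000, 32))
y = np.random.random((1000, 1))

batches = list(iter_batches(x, y))
first_x, first_y = batches[0]
last_x, last_y = batches[-1]

# 1000 samples do not divide evenly by 32, so the last batch is smaller.
print(len(batches), first_x.shape, last_x.shape)
```

Because the last batch can be smaller than the others, a `train_step()` should not hard-code the batch size.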

We compute the loss via `self.compute_loss()`, which wraps the loss function(s) that were passed to `compile()`.

We call `metric.update_state(y, y_pred)` on metrics from `self.metrics`, to update the state of the metrics that were passed in `compile()`, and we query results from self.metrics at the end to retrieve their current value.
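To make the sequence of forward pass, loss, gradients, weight update, and metrics concrete, here is a sketch of the same steps for a single `Dense(1)` layer in plain NumPy. It is only an analogy: the gradients are derived by hand instead of being recorded by `tf.GradientTape`, plain gradient descent stands in for the optimizer, and `learning_rate` is an arbitrary assumed value.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.random((64, 32))
y = rng.random((64, 1))

# Parameters of a single Dense(1) layer: y_pred = x @ w + b
w = np.zeros((32, 1))
b = np.zeros((1,))
learning_rate = 0.01  # assumed value, for illustration only


def train_step(x, y, w, b):
    y_pred = x @ w + b                    # forward pass
    error = y_pred - y
    loss = np.mean(error ** 2)            # mean squared error
    grad_w = 2.0 * x.T @ error / len(x)   # d(loss)/dw
    grad_b = 2.0 * error.mean(axis=0)     # d(loss)/db
    w = w - learning_rate * grad_w        # weight update
    b = b - learning_rate * grad_b
    mae = np.mean(np.abs(error))          # metric, computed on the same predictions
    return w, b, {"loss": loss, "mae": mae}


losses = []
for _ in range(5):
    w, b, logs = train_step(x, y, w, b)
    losses.append(logs["loss"])

print(losses)
```

As in the Keras version, the reported loss and metric come from the predictions made before the weights are updated.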

Reference: [Customizing what happens in fit()](https://www.tensorflow.org/guide/keras/customizing_what_happens_in_fit)

In [1]:
import tensorflow as tf
import keras

In [None]:
class CustomModel(tf.keras.Model):
    def train_step(self, data):
        x, y = data  # Data structure depends on your model and on what you pass to fit()

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            # Compute the loss value configured in `compile()`
            loss = self.compute_loss(y=y, y_pred=y_pred)

        # Compute Gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update metrics
        for metric in self.metrics:
            if metric.name == 'loss':
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred)

        return {m.name: m.result() for m in self.metrics}

In [None]:
import numpy as np

# Construct and compile an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)

model = CustomModel(inputs, outputs)
model.compile(optimizer="adam", loss="mse", metrics=["mae"])

# Just use `fit` as usual
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))

model.fit(x, y, epochs=3)

Epoch 1/3
Epoch 2/3
Epoch 3/3


<keras.src.callbacks.History at 0x78a4648cabf0>

Let's go one level lower with an example that only uses `compile()` to configure the optimizer.

* We start by creating `Metric` instances to track our loss and an MAE score (in `__init__()`).
* We implement a custom `train_step()` that updates the state of these metrics (by calling `update_state()` on them), then queries them (via `result()`) to return their current average value, to be displayed by the progress bar and passed to any callback.


**Note** that we would need to call `reset_states()` on our metrics between epochs! Otherwise, calling `result()` would return an average since the start of training, whereas we usually work with per-epoch averages. Thankfully, the framework can do that for us: just list any metric you want to reset in the `metrics` property of the model. The model will call `reset_states()` on any object listed there at the beginning of each `fit()` epoch or at the beginning of a call to `evaluate()`.
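The difference between the two kinds of average is easy to see with a toy metric. `RunningMean` below is a hypothetical plain-Python stand-in for a stateful mean metric, and the per-batch losses are made up.

```python
class RunningMean:
    """Toy stateful mean: accumulates a total and a count."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update_state(self, value):
        self.total += value
        self.count += 1

    def result(self):
        return self.total / self.count if self.count else 0.0

    def reset_states(self):
        self.total = 0.0
        self.count = 0


# Made-up losses: two epochs with two batches each.
epoch_losses = [[4.0, 2.0], [1.0, 1.0]]

never_reset = RunningMean()
reset_each_epoch = RunningMean()

for batch_losses in epoch_losses:
    reset_each_epoch.reset_states()  # what the framework does for listed metrics
    for loss in batch_losses:
        never_reset.update_state(loss)
        reset_each_epoch.update_state(loss)

print(never_reset.result())       # 2.0: average since the start of training
print(reset_each_epoch.result())  # 1.0: average over the last epoch only
```

Without the reset, the good second epoch is masked by the poor first one.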

In [None]:
class CustomModel(tf.keras.Model):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.loss_tracker = keras.metrics.Mean(name='loss')
        self.mae_metric = keras.metrics.MeanAbsoluteError()
        # https://keras.io/api/metrics/

    def train_step(self, data):
        x, y = data  # Data structure depends on your model and on what you pass to fit()

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            # Compute custom Loss value
            # loss = self.compute_loss(y=y, y_pred=y_pred)
            loss = keras.losses.mean_squared_error(y, y_pred)
            # https://keras.io/api/losses/

        # Compute Gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Compute your custom metrics
        self.loss_tracker.update_state(loss)
        self.mae_metric.update_state(y, y_pred)
        return {'loss': self.loss_tracker.result(), 'mae': self.mae_metric.result()}

    @property
    def metrics(self):
        # We list our `Metric` objects here so that `reset_states()` can be
        # called automatically at the start of each epoch
        # or at the start of `evaluate()`.
        # If you don't implement this property, you have to call
        # `reset_states()` yourself at the time of your choosing.
        return [self.loss_tracker, self.mae_metric]

In [None]:
# Construct an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)

# We don't pass a loss or metrics here.
model.compile(optimizer="adam")

# Just use `fit` as usual -- you can use callbacks, etc.
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=5)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.src.callbacks.History at 0x78a464633dc0>

### Give sample_weight & class_weight to the model

If you want to support the `fit()` arguments `sample_weight` and `class_weight`, do the following:

* Unpack `sample_weight` from the `data` argument.
* Pass it to `compute_loss()` and `update_state()`.
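To get a feel for what the weights do numerically, here is a small NumPy sketch with made-up values. It assumes the usual conventions, which are worth checking against your Keras version: the loss divides the weighted sum by the number of samples, while a mean-style metric divides by the sum of the weights.

```python
import numpy as np

y = np.array([[0.0], [1.0], [2.0], [3.0]])
y_pred = np.array([[0.5], [1.0], [1.0], [5.0]])
sample_weight = np.array([1.0, 0.0, 2.0, 2.0])  # made-up weights

per_sample_mse = np.mean((y - y_pred) ** 2, axis=-1)   # [0.25, 0.0, 1.0, 4.0]
per_sample_mae = np.mean(np.abs(y - y_pred), axis=-1)  # [0.5, 0.0, 1.0, 2.0]

# Loss (assumed convention): weighted sum divided by the number of samples.
weighted_loss = np.sum(per_sample_mse * sample_weight) / len(y)

# Metric (assumed convention): weighted sum divided by the sum of the weights.
weighted_mae = np.sum(per_sample_mae * sample_weight) / np.sum(sample_weight)

print(weighted_loss, weighted_mae)
```

A weight of zero removes a sample from both quantities, and a weight of two counts it twice.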

In [None]:
class CustomModel(tf.keras.Model):
    def train_step(self, data):
        # Unpack data
        if len(data) == 3:
            x, y, sample_weight = data
        else:
            sample_weight = None
            x, y = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            # The loss function is configured in `compile()`.
            loss = self.compute_loss(
                y=y,
                y_pred=y_pred,
                sample_weight=sample_weight,
            )

        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update the metrics.
        # Metrics are configured in `compile()`.
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred, sample_weight=sample_weight)

        # Return a dict mapping metric names to current value.
        # Note that it will include the loss (tracked in self.metrics).
        return {m.name: m.result() for m in self.metrics}

In [None]:
# Construct and compile an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)

model = CustomModel(inputs, outputs)
model.compile(optimizer="adam", loss="mse", metrics=["mae"])

# You can now use sample_weight argument
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
sw = np.random.random((1000, 1))  # Sample weights
model.fit(x, y, sample_weight=sw, epochs=3)

Epoch 1/3
Epoch 2/3
Epoch 3/3


<keras.src.callbacks.History at 0x78a455fca080>

### Providing custom Evaluation steps

To customize what `evaluate()` does, override `test_step()` in the same way as `train_step()`, leaving out the gradient computation and the weight update.

In [None]:
class CustomModel(tf.keras.Model):
    def test_step(self, data):
        # Unpack data
        x, y = data

        # Compute predictions
        y_pred = self(x, training=False)

        # Update the metrics tracking the loss
        self.compute_loss(y=y, y_pred=y_pred)

        # Update the metrics
        for metric in self.metrics:
            if metric.name != "loss":
                metric.update_state(y, y_pred)
        # Return a dict mapping metric names to current value.
        # Note that it will include the loss (tracked in self.metrics).
        return {m.name: m.result() for m in self.metrics}

In [None]:
# Construct an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(loss="mse", metrics=["mae"])

# Evaluate with our custom test_step
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.evaluate(x, y)



0.741997480392456
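`evaluate()` reports values aggregated over all batches, which is why `test_step()` only updates metric state and returns the running result. The NumPy sketch below, with made-up predictions and an assumed batch size of 32, shows that accumulating a total and a count batch by batch reproduces the full-dataset MAE.

```python
import numpy as np

rng = np.random.default_rng(0)
y = rng.random((1000, 1))
y_pred = rng.random((1000, 1))  # made-up predictions

total, count = 0.0, 0
for start in range(0, len(y), 32):
    y_batch = y[start:start + 32]
    pred_batch = y_pred[start:start + 32]
    # Equivalent of update_state(): accumulate, do not average yet.
    total += np.sum(np.abs(y_batch - pred_batch))
    count += y_batch.size

streamed_mae = total / count             # equivalent of result()
full_mae = np.mean(np.abs(y - y_pred))   # computed on the whole dataset at once

print(streamed_mae, full_mae)
```

Averaging the per-batch means instead would give too much weight to the smaller last batch.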

### End-to-end GAN Example

![](https://camo.githubusercontent.com/c2f14b881d82a7ff68054cfc41c0152c7c5e2ba887fd62f0b8afcdfc02b77d1f/68747470733a2f2f7777772e74656e736f72666c6f772e6f72672f7475746f7269616c732f67656e657261746976652f696d616765732f67616e322e706e67)

In [36]:
from tensorflow.keras import layers

class GAN(tf.keras.Model):
    def __init__(self, input_shape, latent_dim):
        super().__init__()
        self.latent_dim = latent_dim
        self._input_shape = input_shape
        self.discriminator = self.discriminator_model(self._input_shape)
        self.generator = self.generator_model(self.latent_dim)
        self.d_loss_tracker = keras.metrics.Mean(name="d_loss")
        self.g_loss_tracker = keras.metrics.Mean(name="g_loss")

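    def call(self, latent_vectors):
        # Not used by train_step(). Assumption: calling the model directly
        # should generate images from latent vectors.
        return self.generator(latent_vectors)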

    def generator_model(self, latent_dim):
        return keras.Sequential(
        [
            keras.Input(shape=(latent_dim,)),
            # We want to generate 7 * 7 * 128 coefficients to reshape into a 7x7x128 map
            layers.Dense(7 * 7 * 128),
            layers.LeakyReLU(alpha=0.2),
            layers.Reshape((7, 7, 128)),
            layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
            layers.LeakyReLU(alpha=0.2),
            layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
            layers.LeakyReLU(alpha=0.2),
            layers.Conv2D(1, (7, 7), padding="same", activation="sigmoid"),
        ],
        name="generator",
        )

    def discriminator_model(self, input_shape):
        # Create the discriminator
        return keras.Sequential(
            [
                keras.Input(shape=input_shape),
                layers.Conv2D(64, (3, 3), strides=(2, 2), padding="same"),
                layers.LeakyReLU(alpha=0.2),
                layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same"),
                layers.LeakyReLU(alpha=0.2),
                layers.GlobalMaxPooling2D(),
                layers.Dense(1),
            ],
            name="discriminator",
        )

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

    def train_step(self, real_images):
        if isinstance(real_images, tuple):
            real_images = real_images[0]
        # Sample random points in the latent space
        batch_size = tf.shape(real_images)[0]  # Get batch size
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))

        # Decode them to fake images
        generated_images = self.generator(random_latent_vectors)

        # Combine them with real images
        combined_images = tf.concat([generated_images, real_images], axis=0)

        # Assemble labels discriminating real from fake images
        labels = tf.concat(
            [tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0
        )

        # Add random noise to the labels
        labels += 0.05 * tf.random.uniform(tf.shape(labels))

        # Train the discriminator
        with tf.GradientTape() as tape:
            predictions = self.discriminator(combined_images)
            d_loss = self.loss_fn(labels, predictions)

        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(zip(grads, self.discriminator.trainable_weights))

        # Sample random points in the latent space
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))

        # Assemble labels that say "all real images"
        misleading_labels = tf.zeros((batch_size, 1))

        # Train the generator
        # Do not update the weights of the discriminator
        with tf.GradientTape() as tape:
            predictions = self.discriminator(self.generator(random_latent_vectors))
            g_loss = self.loss_fn(misleading_labels, predictions)

        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))

        # Update metrics
        self.d_loss_tracker.update_state(d_loss)
        self.g_loss_tracker.update_state(g_loss)
        return {"d_loss": self.d_loss_tracker.result(),
                "g_loss": self.g_loss_tracker.result()}

In [37]:
INPUT_SHAPE = (28,28,1)
LATENT_DIM = 128
BATCH_SIZE = 64
LEARNING_RATE = 0.0003
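The label convention in the `train_step()` of the `GAN` class above is easy to get backwards, so here is a NumPy sketch of just that part, with a tiny assumed batch size. Generated images are labeled 1, real images are labeled 0, and the noise only ever increases the labels.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size = 4  # small value for illustration

# Same order as the concat in train_step(): generated first, real second.
labels = np.concatenate(
    [np.ones((batch_size, 1)), np.zeros((batch_size, 1))], axis=0
)

# One-sided noise: uniform samples in [0, 1) scaled by 0.05.
noisy_labels = labels + 0.05 * rng.random(labels.shape)

# For the generator step, generated images are labeled as if they were real (0).
misleading_labels = np.zeros((batch_size, 1))

print(noisy_labels.ravel())
```

The generator is rewarded when the discriminator assigns its images the "real" label, which is why `misleading_labels` is all zeros here.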

### Load and prepare dataset

In [38]:
import numpy as np

(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
all_digits = np.concatenate([x_train, x_test])
all_digits = all_digits.astype("float32") / 255.0
all_digits = np.reshape(all_digits, (-1, 28, 28, 1))

In [39]:
dataset = tf.data.Dataset.from_tensor_slices(all_digits)
# Cache before shuffling so that each epoch gets a fresh shuffle order
dataset = dataset.cache().shuffle(buffer_size=1024).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
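For a sense of scale, the arithmetic below works out how many batches this pipeline yields. MNIST has 60,000 training and 10,000 test images, and the training call further down restricts each epoch to `dataset.take(100)`.

```python
import math

num_images = 60_000 + 10_000  # MNIST train + test, concatenated above
batch_size = 64               # BATCH_SIZE

batches_per_full_pass = math.ceil(num_images / batch_size)
last_batch_size = num_images - (batches_per_full_pass - 1) * batch_size

steps_per_epoch = min(100, batches_per_full_pass)  # effect of take(100)
images_per_epoch = steps_per_epoch * batch_size

print(batches_per_full_pass, last_batch_size, images_per_epoch)
```

So each epoch in this notebook sees only a small fraction of the available digits, which keeps the demonstration fast.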

In [40]:
gan = GAN(input_shape=INPUT_SHAPE, latent_dim=LATENT_DIM)
gan.compile(
    d_optimizer=keras.optimizers.Adam(learning_rate=LEARNING_RATE),
    g_optimizer=keras.optimizers.Adam(learning_rate=LEARNING_RATE),
    loss_fn=keras.losses.BinaryCrossentropy(from_logits=True)
)
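Since the discriminator ends in `Dense(1)` without an activation, it outputs logits, which is why the loss is created with `from_logits=True`. The NumPy sketch below shows the standard numerically stable formula for binary cross-entropy on logits and checks it against the naive sigmoid-then-log version. `bce_from_logits` and `bce_naive` are hypothetical helpers, not the Keras implementation.

```python
import numpy as np


def bce_from_logits(labels, logits):
    """Stable binary cross-entropy, averaged over all elements."""
    per_element = (
        np.maximum(logits, 0.0)
        - logits * labels
        + np.log1p(np.exp(-np.abs(logits)))
    )
    return per_element.mean()


def bce_naive(labels, logits):
    """Same quantity via an explicit sigmoid; breaks down for large |logits|."""
    p = 1.0 / (1.0 + np.exp(-logits))
    return np.mean(-(labels * np.log(p) + (1.0 - labels) * np.log(1.0 - p)))


labels = np.array([[1.0], [0.0], [0.0], [1.0]])
logits = np.array([[2.0], [-1.0], [0.0], [3.0]])

stable = bce_from_logits(labels, logits)
naive = bce_naive(labels, logits)
print(stable, naive)

# The stable form stays finite even for an extreme logit.
extreme = bce_from_logits(np.array([[1.0]]), np.array([[1000.0]]))
print(extreme)
```

Working on logits avoids taking the logarithm of a probability that has been rounded to exactly 0 or 1.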

In [41]:
gan.fit(dataset.take(100), epochs=5)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.src.callbacks.History at 0x78e283d59660>

### Reference

[Customizing what happens in fit()](https://www.tensorflow.org/guide/keras/customizing_what_happens_in_fit)