In [None]:
from tensorflow import keras
from tensorflow.keras import layers, Model

In [None]:
from tensorflow.keras.datasets import mnist
def get_mnist_model():
    """Build an uncompiled dense classifier for flattened 28x28 images.

    The network is: input of 28 * 28 features -> Dense(512, relu) ->
    Dropout(0.5) -> Dense(10, softmax).

    Returns:
        A new `keras.Model`; the caller is responsible for calling `compile()`.
    """
    flat_pixels = keras.Input(shape=(28 * 28,))
    hidden = layers.Dense(512, activation="relu")(flat_pixels)
    hidden = layers.Dropout(0.5)(hidden)
    class_probs = layers.Dense(10, activation="softmax")(hidden)
    return keras.Model(flat_pixels, class_probs)

In [None]:
# Load MNIST as (train images, train labels), (test images, test labels).
(images, labels), (test_images, test_labels) = mnist.load_data()

# Flatten each 28x28 image into a 784-element vector and rescale the pixel
# values by 1/255. Using -1 lets NumPy infer the number of samples instead of
# hard-coding the dataset sizes.
images = images.reshape((-1, 28 * 28)).astype("float32") / 255
test_images = test_images.reshape((-1, 28 * 28)).astype("float32") / 255

# Hold out the first NUM_VAL_SAMPLES training examples for validation; the
# remainder is used for training.
NUM_VAL_SAMPLES = 10000
train_images, val_images = images[NUM_VAL_SAMPLES:], images[:NUM_VAL_SAMPLES]
train_labels, val_labels = labels[NUM_VAL_SAMPLES:], labels[:NUM_VAL_SAMPLES]

In [None]:
# Build a fresh, uncompiled model so training starts from newly initialized weights.
model = get_mnist_model()

In [None]:
# Configure training: RMSprop optimizer, sparse categorical cross-entropy
# (integer labels), and accuracy as the monitored metric.
model.compile(
    optimizer="rmsprop",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)

# Train for three epochs, evaluating on the held-out validation split after each.
model.fit(
    train_images,
    train_labels,
    epochs=3,
    validation_data=(val_images, val_labels),
)

# Compute loss/metrics on the test set, then get the per-class outputs for it.
test_metrics = model.evaluate(test_images, test_labels)
predictions = model.predict(test_images)

Epoch 1/3
Epoch 2/3
Epoch 3/3


A Keras metric is a subclass of the keras.metrics.Metric class. Like layers, a metric has an internal state stored in TensorFlow variables. Unlike layers, these variables
aren’t updated via backpropagation, so you have to write the state-update logic yourself, which happens in the update_state() method.

 For example, here’s a simple custom metric that measures the root mean squared
error (RMSE).

Meanwhile, you also need to expose a way to reset the metric state without having to
reinstantiate it—this enables the same metric objects to be used across different
epochs of training or across both training and evaluation. You do this with the
`reset_state()` method.

In [None]:
import tensorflow as tf
class RootMeanSquaredError(keras.metrics.Metric):
    """Streaming root mean squared error for integer labels vs. class scores.

    Labels are one-hot encoded to the width of `y_pred`, the squared
    differences are summed over every element of the batch, and the running
    sum is divided by the number of samples seen (not the number of elements)
    before taking the square root.

    State:
        mse_sum: running sum of squared errors across all batches.
        total_samples: running count of samples (int32).
    """

    def __init__(self, name="rmse", **kwargs):
        # State variables are created with add_weight(), as for layers.
        super().__init__(name=name, **kwargs)
        self.mse_sum = self.add_weight(name="mse_sum", initializer="zeros")
        self.total_samples = self.add_weight(
            name="total_samples", initializer="zeros", dtype="int32")

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate the squared error of one batch.

        Args:
            y_true: integer class labels for the batch. Any shape with one
                label per sample is accepted; it is flattened to (batch,).
            y_pred: per-class predictions of shape (batch, num_classes).
            sample_weight: accepted for API compatibility but ignored.
        """
        # Flatten labels before one-hot encoding. Depending on the Keras
        # version, labels may arrive as (batch, 1) rather than (batch,); the
        # one-hot tensor would then be (batch, 1, classes) and broadcast
        # against (batch, classes) predictions to (batch, batch, classes),
        # silently inflating the error sum. The cast also guarantees an index
        # dtype that tf.one_hot accepts.
        y_true = tf.reshape(tf.cast(y_true, tf.int32), [-1])
        y_true = tf.one_hot(y_true, depth=tf.shape(y_pred)[1])
        squared_error_sum = tf.reduce_sum(tf.square(y_true - y_pred))
        self.mse_sum.assign_add(squared_error_sum)
        num_samples = tf.shape(y_pred)[0]
        self.total_samples.assign_add(num_samples)

    def result(self):
        """Return the current RMSE, or 0.0 if no samples have been seen yet."""
        sample_count = tf.cast(self.total_samples, tf.float32)
        # Guard the denominator: with zero samples mse_sum is also zero, so
        # this yields 0.0 instead of NaN from 0/0.
        return tf.sqrt(self.mse_sum / tf.maximum(sample_count, 1.0))

    def reset_state(self):
        """Zero the accumulators so the same instance can be reused."""
        self.mse_sum.assign(0.)
        self.total_samples.assign(0)


Custom metrics can be used just like built-in ones: pass an instance in the `metrics` list given to `compile()`.

In [None]:
# Start from a fresh model and track the custom RMSE metric next to accuracy.
model = get_mnist_model()
model.compile(
    optimizer="rmsprop",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy", RootMeanSquaredError()],
)

# Train for three epochs with per-epoch validation.
model.fit(
    train_images,
    train_labels,
    epochs=3,
    validation_data=(val_images, val_labels),
)

# Evaluate loss and both metrics on the test set.
test_metrics = model.evaluate(test_images, test_labels)

Epoch 1/3
Epoch 2/3
Epoch 3/3
