In [1]:
import tensorflow as tf
from tensorflow import keras
import tensorboard

### Defining the Model and Processing the Dataset

In [2]:
def get_mnist_model():
    """Build an uncompiled Keras functional model for flattened 28x28 inputs.

    Architecture: a 784-wide input, a 512-unit ReLU dense layer, dropout
    with rate 0.5, and a 10-unit softmax output layer.

    Returns:
        keras.Model: the assembled model; the caller is expected to compile it.
    """
    pixels = keras.Input(shape=(28 * 28,))
    hidden = keras.layers.Dense(512, activation="relu")(pixels)
    hidden = keras.layers.Dropout(0.5)(hidden)
    class_probs = keras.layers.Dense(10, activation="softmax")(hidden)
    return keras.Model(pixels, class_probs)

# Load MNIST, flatten every image to a 784-element float32 vector and
# rescale the pixel values by 1/255.
(images, labels), (test_images, test_labels) = keras.datasets.mnist.load_data()
images = images.reshape((60000, 28 * 28)).astype("float32") / 255
test_images = test_images.reshape((10000, 28 * 28)).astype("float32") / 255

# Hold out the first 10,000 training examples for validation; the rest
# are used for fitting.
val_images, val_labels = images[:10000], labels[:10000]
train_images, train_labels = images[10000:], labels[10000:]

### Defining a Custom Metric

Similar to model subclassing, we can implement a custom metric by subclassing the core `keras.metrics.Metric` class.

In [6]:
"""
Python Basic: Inheritance
Here our custom metric class inherits add_weight() from the Metric class; assign_add() is a method of the variables that add_weight() returns.
"""

class RootMeanSquaredError(keras.metrics.Metric):
    """Streaming RMSE between one-hot encoded labels and predicted scores.

    State is kept in two variables that accumulate across batches:
    ``mse_sum`` (running sum of squared errors over all samples and classes)
    and ``total_samples`` (number of samples seen). ``result()`` returns
    ``sqrt(mse_sum / total_samples)``.
    """

    def __init__(self, name="rmse", **kwargs):
        """Create the metric and its two zero-initialized state variables.

        Args:
            name: Metric name forwarded to the base class.
            **kwargs: Extra keyword arguments forwarded to the base class.
        """
        super().__init__(name=name, **kwargs)
        self.mse_sum = self.add_weight(name="mse_sum", initializer="zeros")
        self.total_samples = self.add_weight(name="total_samples", initializer="zeros", dtype="int32")

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Fold one batch of labels and predictions into the running totals.

        Args:
            y_true: Integer class indices, one per sample. Any shape holding
                exactly one index per sample (e.g. ``(batch,)`` or
                ``(batch, 1)``) is accepted.
            y_pred: Per-class scores indexed as ``(batch, num_classes)``.
            sample_weight: Accepted for API compatibility; currently ignored.
        """
        # Flatten the labels before one-hot encoding. Without this, labels
        # delivered as (batch, 1) would one-hot to (batch, 1, C) and the
        # subtraction below would silently broadcast to (batch, batch, C).
        class_ids = tf.cast(tf.reshape(y_true, [-1]), tf.int32)
        y_true = tf.one_hot(class_ids, depth=tf.shape(y_pred)[1])
        mse = tf.reduce_sum(tf.square(y_true - y_pred))
        self.mse_sum.assign_add(mse)
        num_samples = tf.shape(y_pred)[0]
        self.total_samples.assign_add(num_samples)

    def result(self):
        """Return the current RMSE, or 0 when no samples have been seen."""
        sample_count = tf.cast(self.total_samples, tf.float32)
        # divide_no_nan yields 0 instead of NaN for the 0/0 case that occurs
        # before the first update or right after reset_state().
        return tf.sqrt(tf.math.divide_no_nan(self.mse_sum, sample_count))

    def reset_state(self):
        """Zero both accumulators so a new accumulation can start."""
        self.mse_sum.assign(0.0)
        self.total_samples.assign(0)

In [12]:
import os
import datetime
# Ensure the root directory for the run logs exists; exist_ok makes this
# cell safe to re-run.
os.makedirs(name='logs/fit', exist_ok=True)

In [14]:
# Build and compile the classifier, tracking accuracy plus the custom RMSE metric.
model = get_mnist_model()
model.compile(
    optimizer="rmsprop",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy", RootMeanSquaredError()],
)

# Each run logs into its own timestamped sub-directory of logs/fit.
log_dir = f"logs/fit/{datetime.datetime.now():%Y%m%d-%H%M%S}"
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=log_dir,
    histogram_freq=1,
)

# Train for three epochs with validation, then score the held-out test set.
model.fit(
    train_images,
    train_labels,
    epochs=3,
    validation_data=(val_images, val_labels),
    callbacks=[tensorboard_callback],
)
test_metrics = model.evaluate(test_images, test_labels)

Epoch 1/3
Epoch 2/3
Epoch 3/3


In [None]:
%tensorboard