In [2]:
import tensorflow as tf

In [3]:
from tensorflow import keras

In [47]:
class RootMeanSquarredError(keras.metrics.Metric):
    """Streaming root-mean-squared error between one-hot labels and predictions.

    Labels arrive as sparse integer class ids and are one-hot encoded to match
    the class dimension of ``y_pred``. Over all batches seen since the last
    reset, the metric reports::

        sqrt( sum over samples and classes of (one_hot(y_true) - y_pred) ** 2
              / number of samples )

    Squared errors are summed across classes for each sample and then averaged
    over samples, not over every individual element.

    NOTE: ``sample_weight`` is accepted for API compatibility but is ignored.
    The (misspelled) class name is kept so existing references keep working.
    """

    def __init__(self, name="rmse", **kwargs):
        super().__init__(name=name, **kwargs)
        # Running sum of squared errors over every batch since the last reset.
        self.mse_sum = self.add_weight(name="mse_sum", initializer="zeros")
        # Running count of samples (rows of y_pred) since the last reset.
        self.total_samples = self.add_weight(
            name="total_samples", initializer="zeros", dtype="int32"
        )

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate the squared error and sample count of one batch.

        Args:
            y_true: integer class ids, shape ``(batch,)`` or ``(batch, 1)``.
            y_pred: per-class predictions, shape ``(batch, num_classes)``.
            sample_weight: ignored.
        """
        y_pred = tf.cast(y_pred, self.mse_sum.dtype)
        # Flatten so (batch, 1) labels don't one-hot to (batch, 1, C) and then
        # broadcast against y_pred into a (batch, batch, C) error tensor; cast so
        # float-typed labels are still valid one_hot indices.
        labels = tf.cast(tf.reshape(y_true, [-1]), tf.int32)
        y_true_one_hot = tf.one_hot(
            labels, depth=tf.shape(y_pred)[1], dtype=y_pred.dtype
        )
        batch_sq_error = tf.reduce_sum(tf.square(y_true_one_hot - y_pred))
        # assign_add (not assign): the metric must aggregate over every batch
        # since the last reset, not just report the most recent batch.
        self.mse_sum.assign_add(batch_sq_error)
        self.total_samples.assign_add(tf.shape(y_pred)[0])

    def result(self):
        """Return sqrt(accumulated squared error / accumulated sample count)."""
        count = tf.cast(self.total_samples, self.mse_sum.dtype)
        # divide_no_nan yields 0 instead of NaN before any batch has been seen.
        return tf.sqrt(tf.math.divide_no_nan(self.mse_sum, count))

    def reset_state(self):
        """Zero both accumulators."""
        self.mse_sum.assign(0.0)
        self.total_samples.assign(0)

In [48]:
# Single metric instance shared with model.compile below; its display name is "rmse".
RMSE = RootMeanSquarredError(name="rmse")

In [49]:
def get_model(input_dim=28 * 28, hidden_units=512, dropout_rate=0.5, num_classes=10):
    """Build an uncompiled MLP classifier for flattened images.

    Architecture: Dense(hidden_units, relu) -> Dropout(dropout_rate)
    -> Dense(num_classes, softmax).

    Args:
        input_dim: length of each flattened input vector (784 for 28x28 images).
        hidden_units: width of the hidden Dense layer.
        dropout_rate: fraction of hidden activations dropped during training.
        num_classes: number of output classes (softmax units).

    Returns:
        An uncompiled ``keras.Model`` mapping ``(batch, input_dim)`` inputs to
        ``(batch, num_classes)`` softmax outputs.
    """
    # `shape` must be a tuple: `(28*28)` is just the int 784, not a 1-tuple.
    inputs = keras.Input(shape=(input_dim,))
    features = keras.layers.Dense(hidden_units, activation="relu")(inputs)
    dropped = keras.layers.Dropout(dropout_rate)(features)
    outputs = keras.layers.Dense(num_classes, activation="softmax")(dropped)
    return keras.Model(inputs=inputs, outputs=outputs)

In [55]:
# Fresh, untrained model; re-run this cell before compile/fit to restart training.
model: keras.Model = get_model()

In [56]:
model.compile(
    optimizer="rmsprop",
    # Sparse variant: labels are integer class ids, not one-hot vectors.
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy", RMSE],
)

In [57]:
# Take the dataset module from the same `tensorflow.keras` package used to build
# the model, instead of a separately installed standalone `keras`, so the two
# cannot drift to mismatched versions.
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

In [58]:
# Flatten each 28x28 image into a 784-vector and rescale uint8 pixels by 1/255.
# `-1` lets reshape infer the sample count instead of hard-coding 60000/10000.
# The dtype guard makes the cell safe to re-run: once an array is float32 it has
# already been scaled, so it is not divided by 255 a second time.
if train_images.dtype == "uint8":
    train_images = train_images.reshape((-1, 28 * 28)).astype("float32") / 255
if test_images.dtype == "uint8":
    test_images = test_images.reshape((-1, 28 * 28)).astype("float32") / 255

In [59]:
# NOTE(review): the test split doubles as validation data here, so the val_*
# numbers reported per epoch are test-set scores.
model.fit(
    x=train_images,
    y=train_labels,
    epochs=10,
    validation_data=(test_images, test_labels),
)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.callbacks.History at 0x28e0087e140>