In [1]:
import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.layers as l
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.utils import plot_model

%matplotlib inline

In [2]:
# Fashion MNIST Dataset: 28x28 images with integer class labels (10 classes).
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
# Scale 0-255 pixel values to [0, 1] and append a trailing channel axis so the
# arrays match the (28, 28, 1) input shape declared by create_simple_classifier
# (without it Keras warns about an incompatible input shape).
x_train = tf.cast(x_train[..., np.newaxis], tf.float32) / 255.
x_test = tf.cast(x_test[..., np.newaxis], tf.float32) / 255.
# One-hot targets, as required by categorical_crossentropy and the Huber loss/metric.
y_train_cat = keras.utils.to_categorical(y_train, num_classes=10)
y_test_cat = keras.utils.to_categorical(y_test, num_classes=10)

In [3]:
def create_simple_classifier():
    """Build an untrained fully connected classifier for (28, 28, 1) inputs.

    The input is flattened and passed through two ReLU hidden layers
    (200 then 100 units), followed by a 10-unit softmax output layer.

    Returns:
        keras.Model: functional model mapping a (28, 28, 1) input to 10
        softmax class probabilities.
    """
    image = l.Input(shape=(28, 28, 1))
    hidden = image
    # Layers are created and then applied in the same order as a plain
    # chain of calls, so naming and seeded weight initialization match.
    for layer in (
        l.Flatten(),
        l.Dense(200, activation='relu'),
        l.Dense(100, activation='relu'),
    ):
        hidden = layer(hidden)
    probabilities = l.Dense(10, activation='softmax')(hidden)
    return keras.Model(inputs=[image], outputs=[probabilities])

In [4]:
# Baseline: built-in categorical cross-entropy, one epoch.
# Reseed first so each experiment starts from the same initialization.
tf.random.set_seed(42)
simple_clf = create_simple_classifier()
simple_clf.compile(
    loss='categorical_crossentropy',
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    metrics=['accuracy'],
)
simple_clf.fit(x_train, y_train_cat, epochs=1)



<tensorflow.python.keras.callbacks.History at 0x13da85860>

# Custom Loss

#### Let's make a custom Huber loss, which behaves like MSE for small errors and like MAE for large errors. We can implement it as a function or as a class.


In [5]:
# As a Function
def create_huber(threshold=1.0):
    """Return a Huber loss function with the given quadratic/linear threshold.

    The returned ``huber_fn(y_true, y_pred)`` computes, element-wise with
    ``e = y_true - y_pred``:

    * ``e**2 / 2`` where ``|e| < threshold`` (MSE-like), and
    * ``threshold * |e| - threshold**2 / 2`` elsewhere (MAE-like).

    No reduction is applied here.

    Args:
        threshold: Absolute error at which the loss switches from quadratic
            to linear.

    Returns:
        A callable ``huber_fn(y_true, y_pred)`` returning a tensor with the
        broadcast shape of its inputs.
    """
    def huber_fn(y_true, y_pred):
        y_pred = tf.convert_to_tensor(y_pred)
        # Like the built-in Keras losses, align the target dtype with the
        # predictions so integer or float64 targets don't fail on subtraction.
        y_true = tf.cast(y_true, y_pred.dtype)
        error = y_true - y_pred
        abs_error = tf.abs(error)
        squared_loss = tf.square(error) / 2
        linear_loss = threshold * abs_error - threshold**2 / 2
        is_small_error = abs_error < threshold

        return tf.where(is_small_error, squared_loss, linear_loss)
    return huber_fn

In [6]:
# As a Class
class HuberLoss(keras.losses.Loss):
    """Huber loss implemented as a serializable ``keras.losses.Loss`` subclass.

    With ``e = y_true - y_pred``, the loss is ``e**2 / 2`` where
    ``|e| < threshold`` and ``threshold * |e| - threshold**2 / 2`` elsewhere.
    ``call`` returns element-wise values; the ``Loss`` base class applies the
    configured reduction.

    Args:
        threshold: Absolute error at which the loss switches from quadratic
            to linear.
        **kwargs: Forwarded to ``keras.losses.Loss`` (e.g. ``name``,
            ``reduction``).
    """

    def __init__(self, threshold=1.0, **kwargs):
        super().__init__(**kwargs)
        self.threshold = threshold

    def call(self, y_true, y_pred):
        y_pred = tf.convert_to_tensor(y_pred)
        # Like the built-in Keras losses, align the target dtype with the
        # predictions so integer or float64 targets don't fail on subtraction.
        y_true = tf.cast(y_true, y_pred.dtype)
        error = y_true - y_pred
        abs_error = tf.abs(error)
        squared_loss = tf.square(error) / 2
        linear_loss = self.threshold * abs_error - self.threshold**2 / 2
        is_small_error = abs_error < self.threshold

        return tf.where(is_small_error, squared_loss, linear_loss)

    def get_config(self):
        # Include `threshold` so a saved model can rebuild this loss.
        base_config = super().get_config()
        return {**base_config, 'threshold': self.threshold}

In [7]:
# Same classifier, trained with the class-based HuberLoss (default threshold 1.0).
# Reseed first so each experiment starts from the same initialization.
tf.random.set_seed(42)
simple_clf = create_simple_classifier()
simple_clf.compile(
    loss=HuberLoss(),
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    metrics=['accuracy'],
)
simple_clf.fit(x_train, y_train_cat, epochs=1)



<tensorflow.python.keras.callbacks.History at 0x13e190cf8>

In [8]:
# Same classifier, trained with the function-based Huber loss (threshold 1).
# Reseed first so each experiment starts from the same initialization.
tf.random.set_seed(42)
simple_clf = create_simple_classifier()
simple_clf.compile(
    loss=create_huber(1),
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    metrics=['accuracy'],
)
simple_clf.fit(x_train, y_train_cat, epochs=1)



<tensorflow.python.keras.callbacks.History at 0x13e2b5d68>

# Custom Metric

#### Let's make a custom Huber metric

In [9]:
class HuberMetric(keras.metrics.Metric):
    """Running (optionally weighted) mean of the element-wise Huber loss.

    Across batches it accumulates ``total`` (the weighted sum of Huber values)
    and ``count`` (the sum of weights). ``result()`` returns
    ``total / count``. The state variables are created with ``add_weight``,
    so the base class ``reset_states()`` zeroes them.

    Args:
        threshold: Absolute error at which the Huber loss switches from
            quadratic to linear.
        **kwargs: Forwarded to ``keras.metrics.Metric`` (e.g. ``name``,
            ``dtype``).
    """

    def __init__(self, threshold=1.0, **kwargs):
        super().__init__(**kwargs)
        self.threshold = threshold
        self.huber_fn = create_huber(threshold)
        # Weighted sum of Huber values and sum of weights seen so far.
        self.total = self.add_weight("total", initializer='zeros')
        self.count = self.add_weight('count', initializer='zeros')

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_pred = tf.convert_to_tensor(y_pred)
        y_true = tf.cast(y_true, y_pred.dtype)
        values = self.huber_fn(y_true, y_pred)
        if sample_weight is None:
            weights = tf.ones_like(values)
        else:
            weights = tf.cast(sample_weight, values.dtype)
            # Per-sample weights (e.g. shape (batch,)) get trailing unit axes
            # so they broadcast over every element of `values`.
            missing_axes = tf.maximum(tf.rank(values) - tf.rank(weights), 0)
            weights = tf.reshape(
                weights,
                tf.concat(
                    [tf.shape(weights),
                     tf.ones(tf.expand_dims(missing_axes, 0), dtype=tf.int32)],
                    axis=0,
                ),
            )
            weights = tf.broadcast_to(weights, tf.shape(values))
        # Count exactly the elements that were summed. This stays correct even
        # when y_true broadcasts against y_pred.
        self.total.assign_add(
            tf.cast(tf.reduce_sum(values * weights), self.total.dtype))
        self.count.assign_add(tf.cast(tf.reduce_sum(weights), self.count.dtype))

    def result(self):
        # Return 0 rather than NaN before any batch has been accumulated.
        return tf.math.divide_no_nan(self.total, self.count)

    def get_config(self):
        # Include `threshold` so a saved model can rebuild this metric.
        base_config = super().get_config()
        return {**base_config, 'threshold': self.threshold}

In [10]:
# Huber loss for training, with the custom HuberMetric reported alongside accuracy.
# Reseed first so each experiment starts from the same initialization.
tf.random.set_seed(42)
simple_clf = create_simple_classifier()
simple_clf.compile(
    loss=create_huber(1),
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    metrics=[HuberMetric(), 'accuracy'],
)
simple_clf.fit(x_train, y_train_cat, epochs=1)



<tensorflow.python.keras.callbacks.History at 0x13e3cf710>

In [11]:
# Huber loss plus HuberMetric over three epochs, showing the metric is reset
# at the start of each epoch.
# Reseed first so each experiment starts from the same initialization.
tf.random.set_seed(42)
simple_clf = create_simple_classifier()
simple_clf.compile(
    loss=create_huber(1),
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    metrics=[HuberMetric(), 'accuracy'],
)
simple_clf.fit(x_train, y_train_cat, epochs=3)

Epoch 1/3
Epoch 2/3
Epoch 3/3


<tensorflow.python.keras.callbacks.History at 0x13e5009e8>

#### The base class has a default implementation of `reset_states()` that resets all state variables to 0; you can override it if your metric needs a different reset.

Here's the relevant code from the reset_states() function in the base class:  
`K.batch_set_value([(v, 0) for v in self.variables])`  
Reference: https://github.com/tensorflow/tensorflow/blob/v2.4.0/tensorflow/python/keras/metrics.py#L247-L253
