In [2]:
import tensorflow as tf
import numpy as np
from tensorflow import keras

In [3]:
x = np.array([-1.0, 0.0, 1.0, 2.0, 3.0, 4.0], dtype=float)
y = np.array([-2.0, -1.0, 1.0, 3.0, 5.0, 7.0], dtype=float)

# Loss as a Function 

In [6]:
def huber_loss_threshold(threshold):
    '''
    Custom loss function with hyperparameter
    '''
    def huber_loss(y_true, y_pred):
        error = y_true - y_pred
        is_small_error = tf.abs(error) < threshold # condition to return true or false
        small_error_loss = tf.square(error) / 2 
        big_error_loss = threshold * (tf.abs(error) - (0.5 * threshold))
        return tf.where(is_small_error, small_error_loss, big_error_loss)
    return huber_loss

In [7]:
model = tf.keras.Sequential([keras.layers.Dense(units=1, input_shape=[1])])
model.compile(optimizer='sgd', loss=huber_loss_threshold(threshold=1)) # Use CUSTOM Loss function
model.fit(x, y, epochs=10, verbose = 0)

<tensorflow.python.keras.callbacks.History at 0x13adc96d0>

In [8]:
print(model.predict([10.0]))

[[15.710443]]


# Loss as a Class

In [9]:
from tensorflow.keras.losses import Loss

In [12]:
class huber_loss_threshold_class(Loss):
    '''
    Custom loss function CLASS with hyperparameter
    '''
    threshold = 1 
    def __init__(self, threshold):
        super().__init__()
        self.threshold = threshold
    
    def call(self, y_true, y_pred):
        '''
        Note that calling threshold is now self.threshold 
        '''
        error = y_true - y_pred 
        is_small_error = tf.abs(error) < self.threshold # condition to return true or false
        small_error_loss = tf.square(error) / 2 
        big_error_loss = self.threshold * (tf.abs(error) - (0.5 * self.threshold))
        return tf.where(is_small_error, small_error_loss, big_error_loss)

In [13]:
model2 = tf.keras.Sequential([keras.layers.Dense(units=1, input_shape=[1])])
model2.compile(optimizer='sgd', loss=huber_loss_threshold_class(threshold=1)) # Use CUSTOM Loss function
model2.fit(x, y, epochs=10, verbose = 0)

<tensorflow.python.keras.callbacks.History at 0x13b57b370>

In [14]:
print(model2.predict([10.0]))

[[-3.0951614]]
