In [1]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras

In [2]:
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Single seed for every source of randomness, so Restart & Run All reproduces the
# same data splits and (given the same op-creation order) the same initial weights
# drawn by the unseeded 'uniform' initializer used in the custom layers below.
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Load the California housing regression dataset.
housing = fetch_california_housing()

# Hold out a test set first, then split the remainder into train / validation.
# Targets are reshaped into a column vector (n_samples, 1) to match the models'
# single output unit.
x_train_full, x_test, y_train_full, y_test = train_test_split(
    housing.data, housing.target.reshape(-1, 1), random_state=SEED)
x_train, x_valid, y_train, y_valid = train_test_split(
    x_train_full, y_train_full, random_state=SEED)

# Standardize with statistics fitted on the training split only, so nothing
# from the validation or test sets leaks into preprocessing.
scaler = StandardScaler()
x_train_scaled = scaler.fit_transform(x_train)
x_valid_scaled = scaler.transform(x_valid)
x_test_scaled = scaler.transform(x_test)

# Per-sample feature shape (every axis except the batch axis).
input_shape = x_train.shape[1:]

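Using TF Functions with tf.keras (or Not)

The `print()` calls in the loss, metric, layer and model below are tracing probes: Python only executes them while TensorFlow traces the code into a graph (a TF Function), or when the code runs eagerly. Compare how often they fire in the default graph mode, with `dynamic=True`, and with `run_eagerly=True`.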

In [3]:
def my_mse(y_true, y_pred):
    """Custom loss: mean of the squared errors over every element.

    The print() is a tracing probe; it only runs when Python executes this
    function (while TensorFlow traces it, or on every call in eager mode).
    """
    print('Tracing loss my_mse()')
    errors = y_pred - y_true
    squared_errors = tf.square(errors)
    return tf.reduce_mean(squared_errors)

In [4]:
def my_mae(y_true, y_pred):
    """Custom metric: mean of the absolute errors over every element.

    The print() is a tracing probe; it only runs when Python executes this
    function (while TensorFlow traces it, or on every call in eager mode).
    """
    print('Tracing metric my_mae()')
    errors = y_pred - y_true
    absolute_errors = tf.abs(errors)
    return tf.reduce_mean(absolute_errors)

In [5]:
class MyDense(keras.layers.Layer):
    """Minimal fully connected layer computing ``activation(x @ kernel + bias)``.

    Args:
        units: number of output units.
        activation: anything accepted by ``keras.activations.get`` (a name such
            as 'relu', a callable, or None for no activation).
        **kwargs: forwarded to ``keras.layers.Layer`` (e.g. ``name``, ``dtype``).
    """

    def __init__(self, units, activation=None, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.activation = keras.activations.get(activation)

    def build(self, input_shape):
        """Create the kernel and bias once the input feature size is known."""
        # Size the kernel from the LAST axis (the feature axis), not axis 1: for
        # rank-2 inputs they coincide, but axis 1 is wrong if inputs carry extra
        # leading axes, e.g. (batch, time, features).
        n_features = input_shape[-1]
        self.kernel = self.add_weight(name='kernel',
                                      shape=(n_features, self.units),
                                      initializer='uniform',
                                      trainable=True)
        self.biases = self.add_weight(name='bias',
                                      shape=(self.units,),
                                      initializer='zeros',
                                      trainable=True)
        super().build(input_shape)

    def call(self, x):
        # Tracing probe: runs whenever Python executes call() (once per trace in
        # graph mode, on every call in eager mode).
        print('Tracing MyDense.call()')
        return self.activation(x @ self.kernel + self.biases)

    def get_config(self):
        """Return the constructor arguments so the layer can be saved and re-created."""
        base_config = super().get_config()
        return {**base_config,
                'units': self.units,
                'activation': keras.activations.serialize(self.activation)}

In [6]:
class MyModel(keras.models.Model):
    """Subclassed wide & deep regressor built from ``MyDense`` layers.

    Two hidden layers form the deep path; their output is concatenated with the
    raw input (the wide path) and fed to a single linear output unit.

    Args:
        units: width of each hidden layer (default 30, as before).
        activation: activation of the hidden layers (default 'relu', as before).
        **kwargs: forwarded to ``keras.models.Model`` (e.g. ``dynamic=True``).
    """

    def __init__(self, units=30, activation='relu', **kwargs):
        super().__init__(**kwargs)
        self.hidden1 = MyDense(units, activation=activation)
        self.hidden2 = MyDense(units, activation=activation)
        # Trailing underscore presumably avoids clashing with Keras's own
        # `Model.output` attribute.
        self.output_ = MyDense(1)

    def call(self, input):
        # Tracing probe: runs whenever Python executes call().
        print('Tracing MyModel.call()')
        hidden1 = self.hidden1(input)
        hidden2 = self.hidden2(hidden1)
        # Wide path: the raw input skips the hidden layers and joins the deep features.
        concat = keras.layers.concatenate([input, hidden2])
        return self.output_(concat)

In [7]:
# Default (graph) mode: Keras turns call(), the loss and the metric into TF
# Functions, so the tracing prints only appear while those graphs are traced.
model = MyModel()
model.compile(
    optimizer='nadam',
    loss=my_mse,
    metrics=[my_mae],
)
model.fit(
    x=x_train_scaled,
    y=y_train,
    validation_data=(x_valid_scaled, y_valid),
    epochs=2,
)

Epoch 1/2
Tracing MyModel.call()
Tracing MyDense.call()
Tracing MyDense.call()
Tracing MyDense.call()
Tracing loss my_mse()
Tracing metric my_mae()
Tracing MyModel.call()
Tracing MyDense.call()
Tracing MyDense.call()
Tracing MyDense.call()
Tracing loss my_mse()
Tracing metric my_mae()
Tracing MyDense.call()
Tracing MyDense.call()
Tracing MyDense.call()
Tracing loss my_mse()
Tracing metric my_mae()
Epoch 2/2


<tensorflow.python.keras.callbacks.History at 0x1dd4d126c48>

In [8]:
# Test-set loss (my_mse) and metric (my_mae).
model.evaluate(x=x_test_scaled, y=y_test)



[0.43518805503845215, 0.46883660554885864]

In [9]:
# dynamic=True: the model runs eagerly, so the tracing prints fire on every
# call. Only 64 samples are used to keep the printed output short.
model = MyModel(dynamic=True)
model.compile(optimizer='nadam', loss=my_mse, metrics=[my_mae])
model.fit(
    x=x_train_scaled[:64],
    y=y_train[:64],
    validation_data=(x_valid_scaled[:64], y_valid[:64]),
    epochs=1,
    verbose=0,
)

Tracing MyModel.call()
Tracing MyDense.call()
Tracing MyDense.call()
Tracing MyDense.call()
Tracing loss my_mse()
Tracing metric my_mae()
Tracing MyModel.call()
Tracing MyDense.call()
Tracing MyDense.call()
Tracing MyDense.call()
Tracing loss my_mse()
Tracing metric my_mae()
Tracing MyModel.call()
Tracing MyDense.call()
Tracing MyDense.call()
Tracing MyDense.call()
Tracing loss my_mse()
Tracing metric my_mae()
Tracing MyModel.call()
Tracing MyDense.call()
Tracing MyDense.call()
Tracing MyDense.call()
Tracing loss my_mse()
Tracing metric my_mae()


<tensorflow.python.keras.callbacks.History at 0x1dd4f6c1848>

In [10]:
# Eager evaluation on a 64-sample slice: the prints fire again on every call.
model.evaluate(x=x_test_scaled[:64], y=y_test[:64], verbose=0)

Tracing MyModel.call()
Tracing MyDense.call()
Tracing MyDense.call()
Tracing MyDense.call()
Tracing loss my_mse()
Tracing metric my_mae()
Tracing MyModel.call()
Tracing MyDense.call()
Tracing MyDense.call()
Tracing MyDense.call()
Tracing loss my_mse()
Tracing metric my_mae()


[5.525728702545166, 2.065046787261963]

In [11]:
# run_eagerly=True: a regular (non-dynamic) model is executed eagerly, so, as
# with dynamic=True, the tracing prints fire on every call.
model = MyModel()
model.compile(
    optimizer='nadam',
    loss=my_mse,
    metrics=[my_mae],
    run_eagerly=True,
)
model.fit(
    x=x_train_scaled[:64],
    y=y_train[:64],
    validation_data=(x_valid_scaled[:64], y_valid[:64]),
    epochs=1,
    verbose=0,
)

Tracing MyModel.call()
Tracing MyDense.call()
Tracing MyDense.call()
Tracing MyDense.call()
Tracing loss my_mse()
Tracing metric my_mae()
Tracing MyModel.call()
Tracing MyDense.call()
Tracing MyDense.call()
Tracing MyDense.call()
Tracing loss my_mse()
Tracing metric my_mae()
Tracing MyModel.call()
Tracing MyDense.call()
Tracing MyDense.call()
Tracing MyDense.call()
Tracing loss my_mse()
Tracing metric my_mae()
Tracing MyModel.call()
Tracing MyDense.call()
Tracing MyDense.call()
Tracing MyDense.call()
Tracing loss my_mse()
Tracing metric my_mae()


<tensorflow.python.keras.callbacks.History at 0x1dd4f757c88>

In [12]:
# Eager evaluation on a 64-sample slice of the test set.
model.evaluate(x=x_test_scaled[:64], y=y_test[:64], verbose=0)

Tracing MyModel.call()
Tracing MyDense.call()
Tracing MyDense.call()
Tracing MyDense.call()
Tracing loss my_mse()
Tracing metric my_mae()
Tracing MyModel.call()
Tracing MyDense.call()
Tracing MyDense.call()
Tracing MyDense.call()
Tracing loss my_mse()
Tracing metric my_mae()


[5.651527404785156, 2.0681631565093994]

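Custom Optimizers

A custom optimizer subclasses `keras.optimizers.Optimizer`. It registers its hyperparameters with `_set_hyper()`, creates one slot variable per model variable in `_create_slots()`, applies the update in `_resource_apply_dense()`, and reports its settings in `get_config()`.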

In [13]:
class MyMomentumOptimizer(keras.optimizers.Optimizer):
    """Momentum optimizer written against the Keras optimizer subclassing hooks.

    For every model variable ``w`` with gradient ``g`` it keeps a momentum slot
    ``m`` and applies (``beta`` = momentum, ``lr`` = rate from ``_decayed_lr``):

        m <- beta * m - (1 - beta) * g
        w <- w + lr * m

    Args:
        learning_rate: step size. The legacy ``lr=`` keyword is accepted as an alias.
        momentum: decay factor ``beta`` of the momentum accumulator.
        name: name of the optimizer.
        **kwargs: forwarded to the base ``Optimizer`` constructor.
    """

    def __init__(self, learning_rate=0.001, momentum=0.9, name='MyMomentumOptimizer', **kwargs):
        super().__init__(name, **kwargs)
        # `learning_rate` is a named parameter, so it can never appear in **kwargs
        # and looking it up there would always miss. The only way a rate reaches
        # kwargs is the legacy `lr=` alias, so honour it instead of ignoring it.
        self._set_hyper('learning_rate', kwargs.get('lr', learning_rate))
        # `_initial_decay` is populated by the base-class constructor; registering
        # it as a hyperparameter lets get_config() report it.
        self._set_hyper('decay', self._initial_decay)
        self._set_hyper('momentum', momentum)

    def _create_slots(self, var_list):
        """Create one 'momentum' slot variable for each model variable."""
        for var in var_list:
            self.add_slot(var, 'momentum')

    @tf.function
    def _resource_apply_dense(self, grad, var):
        """Apply one momentum step to ``var`` from its dense gradient ``grad``.

        Returns the variable update so the caller can group or depend on it.
        """
        var_dtype = var.dtype.base_dtype
        lr_t = self._decayed_lr(var_dtype)
        momentum_var = self.get_slot(var, 'momentum')
        momentum_hyper = self._get_hyper('momentum', var_dtype)
        # Use the value returned by assign() so the weight update explicitly
        # depends on the freshly updated momentum.
        new_momentum = momentum_var.assign(
            momentum_var * momentum_hyper - (1. - momentum_hyper) * grad)
        return var.assign_add(new_momentum * lr_t)

    def _resource_apply_sparse(self, grad, var, indices=None):
        """Sparse gradients are not supported.

        The sparse hook is invoked with ``(grad, var, indices)``; accepting
        ``indices`` ensures callers see this NotImplementedError rather than a
        TypeError about the number of arguments.
        """
        raise NotImplementedError('MyMomentumOptimizer does not support sparse gradients')

    def get_config(self):
        """Return the base config plus this optimizer's hyperparameters."""
        base_config = super().get_config()
        return {
            **base_config,
            'learning_rate': self._serialize_hyperparameter('learning_rate'),
            'decay': self._serialize_hyperparameter('decay'),
            'momentum': self._serialize_hyperparameter('momentum'),
        }

In [14]:
# Plain linear regression (a single Dense unit) trained with the custom optimizer.
# Reuse the feature shape derived from the data rather than hard-coding 8, so
# the model stays correct if the feature set changes.
model = keras.models.Sequential([keras.layers.Dense(1, input_shape=input_shape)])
model.compile(loss='mse', optimizer=MyMomentumOptimizer())
model.fit(x_train_scaled, y_train, epochs=5)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<tensorflow.python.keras.callbacks.History at 0x1dd4f634a48>