In [1]:
import tensorflow as tf
import numpy as np

# Seed TensorFlow's global random generator before anything else is created.
tf.random.set_seed(42)

# load_data() returns ((train images, train labels), (test images, test labels)).
(x_train_full, y_train_full), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()

# The last 5000 training examples are held out for validation; the rest are used for training.
# Images and labels are both cast to float32.
train_part = slice(None, -5000)
valid_part = slice(-5000, None)
x_train = x_train_full[train_part].astype(np.float32)
y_train = y_train_full[train_part].astype(np.float32)
x_valid = x_train_full[valid_part].astype(np.float32)
y_valid = y_train_full[valid_part].astype(np.float32)

In [2]:
class Normalization(tf.keras.layers.Layer):
    """Layer normalization over the last axis with a learnable scale and offset.

    Every slice along the final axis is centred on its own mean and divided by
    ``sqrt(variance + eps)``; the result is multiplied by ``alpha`` and shifted
    by ``beta``.
    """

    def __init__(self, eps = 1e-4, **kwargs):
        """Keep the smoothing term ``eps``; remaining keyword arguments go to the base layer."""
        super().__init__(**kwargs)
        self.eps = eps

    def build(self, batch_input_shape):
        """Create ``alpha`` (initialised to ones) and ``beta`` (zeros), sized like the last input axis."""
        feature_shape = batch_input_shape[-1:]
        self.alpha = self.add_weight(name = "alpha", shape = feature_shape, initializer = "ones")
        self.beta = self.add_weight(name = "beta", shape = feature_shape, initializer = "zeros")
        super().build(batch_input_shape)

    def compute_output_shape(self, batch_input_shape):
        """The layer is shape-preserving: the output shape equals the input shape."""
        return batch_input_shape

    def call(self, x):
        """Normalize ``x`` along its last axis, then apply the learned scale and offset."""
        # keepdims=True keeps the reduced axis so the statistics broadcast against x.
        mean, variance = tf.nn.moments(x, axes = -1, keepdims = True)
        centered = x - mean
        spread = tf.sqrt(variance + self.eps)
        return self.alpha * centered / spread + self.beta

    def get_config(self):
        """Return the base layer configuration extended with ``eps``."""
        config = dict(super().get_config())
        config["eps"] = self.eps
        return config

In [3]:
custom_layer_norm = Normalization()
keras_layer_norm = tf.keras.layers.LayerNormalization()

# Run both layers on the same input and report the mean absolute difference
# between their outputs; the final expression is the cell's displayed value.
reference_output = keras_layer_norm(x_train)
custom_output = custom_layer_norm(x_train)
tf.reduce_mean(tf.keras.losses.mean_absolute_error(reference_output, custom_output))

<tf.Tensor: shape=(), dtype=float32, numpy=3.650889e-06>

In [4]:
np.random.seed(42)

def data_batch(x, y, batch_size = 32):
    """Draw a random mini-batch from two index-aligned arrays.

    Positions are sampled with ``np.random.randint`` (i.e. with replacement),
    so the same row can appear more than once in a batch.

    Args:
        x: array of inputs, indexed along its first axis.
        y: array of targets, indexed with the same positions as ``x``.
        batch_size: number of positions to draw.

    Returns:
        A ``(x_batch, y_batch)`` tuple selected with the same positions.
    """
    picked = np.random.randint(len(x), size = batch_size)
    batch_x = x[picked]
    batch_y = y[picked]
    return batch_x, batch_y

def status_bar(step, total, loss, metrics = None):
    """Print a one-line progress report that overwrites itself in place.

    Args:
        step: current step number.
        total: total number of steps.
        loss: object exposing ``name`` and ``result()``; shown first.
        metrics: optional list of further objects with ``name`` and ``result()``.

    The line starts with a carriage return so successive calls redraw the same
    line; a newline is only emitted once ``step`` is no longer below ``total``.
    """
    tracked = [loss] + (metrics or [])
    parts = []
    for tracker in tracked:
        parts.append(f"{tracker.name}:{tracker.result():.4f}")
    summary = " - ".join(parts)
    if step < total:
        terminator = ""
    else:
        terminator = "\n"
    print(f"\r{step}/{total} - " + summary, end = terminator)

# L2 weight penalty shared by the Dense layers below; Keras exposes the resulting
# penalty terms through model.losses.
regularizer = tf.keras.regularizers.l2(0.01)

# Flatten 28x28 images -> custom layer normalization -> hidden layer -> class probabilities.
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape = [28, 28]),
    Normalization(),
    tf.keras.layers.Dense(32, activation = "relu", kernel_initializer = "he_normal", kernel_regularizer = regularizer),
    # Fashion-MNIST has 10 classes, so the softmax needs 10 units. With a single
    # unit softmax always outputs 1.0 and sparse categorical cross-entropy cannot
    # score labels 1..9.
    tf.keras.layers.Dense(10, kernel_regularizer = regularizer, activation = "softmax")
])

In [5]:
# Loop sizing: an epoch consists of as many whole batches as fit in the training set.
total_epochs, batch_size = 10, 32
total_steps = len(x_train) // batch_size

# Objective and optimizer used by the manual training loop.
loss_fn = tf.keras.losses.sparse_categorical_crossentropy
optimizer = tf.keras.optimizers.Nadam(learning_rate = 1e-3)

# Running trackers: average loss plus the list of extra metrics to report.
mean_loss = tf.keras.metrics.Mean()
metrics = [tf.keras.metrics.SparseCategoricalAccuracy()]

In [None]:
# Manual training loop: one random batch per step, one optimizer for the whole model.
for epoch in range(1, total_epochs + 1):
    # The epoch count is `total_epochs`; `n_epochs` was never defined and raised NameError.
    print(f"Epoch {epoch}/{total_epochs}")
    for step in range(1, total_steps + 1):
        # Pass batch_size explicitly so the batch drawn matches the size used for total_steps.
        x_batch, y_batch = data_batch(x_train, y_train, batch_size)
        with tf.GradientTape() as tape:
            y_pred = model(x_batch, training = True)
            loss = tf.reduce_mean(loss_fn(y_batch, y_pred))
            # model.losses holds the regularization penalties; add them to the data loss.
            total_loss = tf.add_n([loss] + model.losses)

        # Differentiate the full objective. Using `loss` here would leave the L2
        # penalties out of the gradients even though they are reported.
        gradients = tape.gradient(total_loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        # Update the running averages shown in the progress line.
        mean_loss(total_loss)
        for metric in metrics:
            metric(y_batch, y_pred)

        status_bar(step, total_steps, mean_loss, metrics)

    # Start each epoch with fresh running statistics.
    for metric in [mean_loss] + metrics:
        metric.reset_states()

In [6]:
# Keyword arguments shared by every regularized hidden Dense layer below.
hidden_kwargs = dict(activation = "relu", kernel_initializer = "he_normal", kernel_regularizer = regularizer)

# Lower half: flatten the image, normalize, one hidden layer.
lower_layers = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape = [28, 28]),
    Normalization(),
    tf.keras.layers.Dense(32, **hidden_kwargs)
])

# Upper half: one more hidden layer followed by the 10-way softmax output.
upper_layers = tf.keras.models.Sequential([
    tf.keras.layers.Dense(32, **hidden_kwargs),
    tf.keras.layers.Dense(10, activation = "softmax")
])

# The full model simply chains the two halves.
model = tf.keras.models.Sequential([lower_layers, upper_layers])

# Each half is trained by its own optimizer.
upper_optimizer = tf.keras.optimizers.RMSprop(learning_rate = 1e-3)
lower_optimizer = tf.keras.optimizers.Nadam(learning_rate = 1e-3)

In [7]:
from tqdm.notebook import trange
from collections import OrderedDict

# Training loop with tqdm progress bars, where the lower and upper halves of the
# model are updated by different optimizers from the same forward pass.
with trange(1, total_epochs + 1, desc = "All epochs") as epochs:
    for epoch in epochs:
        with trange(1, total_steps + 1, desc = f"Epoch {epoch}/{total_epochs}") as steps:
            # Created once per epoch so it is always bound when the validation
            # figures are added after the step loop; keys are overwritten each step.
            status = OrderedDict()
            for step in steps:
                # Pass batch_size explicitly so the batch drawn matches the size used for total_steps.
                x_batch, y_batch = data_batch(x_train, y_train, batch_size)
                # persistent=True because tape.gradient is called once per sub-model.
                with tf.GradientTape(persistent = True) as tape:
                    y_pred = model(x_batch, training = True)
                    loss = tf.reduce_mean(loss_fn(y_batch, y_pred))
                    # model.losses holds the regularization penalties; add them to the data loss.
                    total_loss = tf.add_n([loss] + model.losses)
                # Loop names differ from the global `optimizer` so that variable is not overwritten.
                for sub_model, sub_optimizer in ((lower_layers, lower_optimizer), (upper_layers, upper_optimizer)):
                    # Differentiate the full objective; using `loss` would ignore the L2 penalties.
                    gradients = tape.gradient(total_loss, sub_model.trainable_variables)
                    sub_optimizer.apply_gradients(zip(gradients, sub_model.trainable_variables))
                # A persistent tape keeps its resources until it is deleted.
                del tape
                # Re-apply any weight constraints after the update.
                for variable in model.variables:
                    if variable.constraint is not None:
                        variable.assign(variable.constraint(variable))
                mean_loss(total_loss)
                status["loss"] = mean_loss.result().numpy()
                for metric in metrics:
                    metric(y_batch, y_pred)
                    status[metric.name] = metric.result().numpy()
                steps.set_postfix(status)
            # End of epoch: evaluate on the validation set and append to the progress bar.
            y_pred = model(x_valid)
            status["val_loss"] = np.mean(loss_fn(y_valid, y_pred))
            status["val_acc"] = np.mean(tf.keras.metrics.sparse_categorical_accuracy(tf.constant(y_valid, dtype = np.float32),
                                                                                    y_pred))
            steps.set_postfix(status)
        # Start each epoch with fresh running statistics.
        for metric in [mean_loss] + metrics:
            metric.reset_states()

All epochs:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 1/10:   0%|          | 0/1718 [00:00<?, ?it/s]

Epoch 2/10:   0%|          | 0/1718 [00:00<?, ?it/s]

Epoch 3/10:   0%|          | 0/1718 [00:00<?, ?it/s]

Epoch 4/10:   0%|          | 0/1718 [00:00<?, ?it/s]

Epoch 5/10:   0%|          | 0/1718 [00:00<?, ?it/s]

Epoch 6/10:   0%|          | 0/1718 [00:00<?, ?it/s]

Epoch 7/10:   0%|          | 0/1718 [00:00<?, ?it/s]

Epoch 8/10:   0%|          | 0/1718 [00:00<?, ?it/s]

Epoch 9/10:   0%|          | 0/1718 [00:00<?, ?it/s]

Epoch 10/10:   0%|          | 0/1718 [00:00<?, ?it/s]