In [112]:
import tensorflow as tf
import numpy as np
from tqdm import trange
from collections import OrderedDict
import time

In [113]:
tf.random.set_seed(60)

In [114]:
X = tf.linspace(0.0, 10.0, 1000)
X = tf.reshape(X, (-1, 1)) 
alpha = np.random.rand(X.shape[-1])
beta = np.random.rand(X.shape[-1])

In [115]:
class LayerNormalization(tf.keras.layers.Layer):
    def __init__(self, small = 10**(-5) ,**kwargs):
        super().__init__(**kwargs)
        self.small = small

    def build(self, shape):
        self.alpha = self.add_weight(initializer="ones", shape=shape[-1:], name="alpha")
        self.beta = self.add_weight(initializer="zeros", shape=shape[-1:], name="beta")

    def call(self, x):
        mean, variance = tf.nn.moments(x, axes=-1,keepdims=True)
        return self.alpha * (x - mean)/tf.sqrt(self.alpha + self.small) + self.beta

    def get_config(self):
        base = super().get_config()
        return {**base, 'small': self.small}

In [116]:
customLayerNorm = LayerNormalization()
LayerNorm = tf.keras.layers.LayerNormalization()

In [117]:
customLayerNorm.build(X.shape)
LayerNorm.build(X.shape)
customLayerNorm.set_weights([alpha, beta])
LayerNorm.set_weights([alpha, beta])

In [118]:
mae = tf.keras.losses.MeanAbsoluteError()
loss = mae(customLayerNorm(X), LayerNorm(X))
tf.reduce_mean(loss)

<tf.Tensor: shape=(), dtype=float32, numpy=8.172988827936933e-07>

In [119]:
(X_train_full, y_train_full), (X_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
X_train_full = X_train_full.astype(np.float32) / 255.
X_train, X_val = X_train_full[:50000], X_train_full[50000:]
y_train, y_val = y_train_full[:50000], y_train_full[50000:]
X_test = X_test.astype(np.float32) / 255.

In [120]:
lower_layers = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=[28, 28]),
    tf.keras.layers.Dense(200, activation="relu"),
    tf.keras.layers.Dense(200, activation="relu"),
])
upper_layers = tf.keras.Sequential([
    tf.keras.layers.Dense(10, activation="softmax"),
])
model = tf.keras.Sequential([
    lower_layers, upper_layers
])

In [121]:
lower_optimizer = tf.keras.optimizers.SGD(learning_rate=0.01,  momentum=0.99, nesterov=True)
upper_optimizer = tf.keras.optimizers.Nadam(learning_rate=1e-4)

In [122]:
n_epochs = 1
batch_size = 50
n_steps = len(X_train) // batch_size
loss = tf.keras.losses.sparse_categorical_crossentropy
mean_loss = tf.keras.metrics.Mean()
metrics = [tf.keras.metrics.SparseCategoricalAccuracy()]

In [123]:
def random_batch(X, y, batch_size=32):
    idx = np.random.randint(len(X), size=batch_size)
    return X[idx], y[idx]

In [124]:
with trange(1, n_epochs + 1, desc="All epochs") as epochs:
    for epoch in epochs:
        with trange(1, n_steps + 1, desc=f"Epoch {epoch}/{n_epochs}") as steps:
            for step in steps:
                time.sleep(0.001)
                X_batch, y_batch = random_batch(X_train, y_train)
                with tf.GradientTape(persistent=True) as tape:
                    y_pred = model(X_batch)
                    main_loss = tf.reduce_mean(loss(y_batch, y_pred))
                    total_loss = tf.add_n([main_loss] + model.losses)
                for layers, optimizer in ((lower_layers, lower_optimizer), (upper_layers, upper_optimizer)):
                    gradients = tape.gradient(total_loss, layers.trainable_variables)
                    optimizer.apply_gradients(zip(gradients, layers.trainable_variables))
                del tape
                for variable in model.variables:
                    if variable.constraint is not None:
                        variable.assign(variable.constraint(variable))                    
                status = OrderedDict()
                mean_loss(total_loss)
                status["loss"] = mean_loss.result().numpy()
                for metric in metrics:
                    metric(y_batch, y_pred)
                    status[metric.name] = metric.result().numpy()
                steps.set_postfix(status, refresh=False)
            y_pred = model(X_val)
            status["val_loss"] = np.mean(loss(y_val, y_pred))
            status["val_accuracy"] = np.mean(tf.keras.metrics.sparse_categorical_accuracy(
                tf.constant(y_val, dtype=np.float32), y_pred))
            steps.set_postfix(status)
        for metric in [mean_loss] + metrics:
            metric.reset_state()

All epochs:   0%|          | 0/1 [00:00<?, ?it/s]
Epoch 1/1:   0%|          | 0/1000 [00:00<?, ?it/s][A
Epoch 1/1:   0%|          | 1/1000 [00:00<02:46,  6.00it/s, loss=2.42, sparse_categorical_accuracy=0.0625][A
Epoch 1/1:   0%|          | 3/1000 [00:00<01:23, 11.93it/s, loss=2.37, sparse_categorical_accuracy=0.104] [A
Epoch 1/1:   1%|          | 6/1000 [00:00<00:58, 17.01it/s, loss=2.27, sparse_categorical_accuracy=0.177][A
Epoch 1/1:   1%|          | 9/1000 [00:00<00:49, 20.18it/s, loss=2.16, sparse_categorical_accuracy=0.278][A
Epoch 1/1:   1%|          | 12/1000 [00:00<00:50, 19.48it/s, loss=2.08, sparse_categorical_accuracy=0.341][A
Epoch 1/1:   2%|▏         | 15/1000 [00:00<00:46, 21.17it/s, loss=1.98, sparse_categorical_accuracy=0.39] [A
Epoch 1/1:   2%|▏         | 18/1000 [00:00<00:45, 21.38it/s, loss=1.86, sparse_categorical_accuracy=0.434][A
Epoch 1/1:   2%|▏         | 21/1000 [00:01<00:44, 21.88it/s, loss=1.79, sparse_categorical_accuracy=0.449][A
Epoch 1/1:   2%|▏