# Train a model using a custom training loop to tackle the Fashion MNIST dataset
1. Display the epoch, iteration, mean training loss, and mean accuracy over each epoch (updated at each iteration), as well as the validation loss and accuracy at the end of each epoch.
2. Try using a different optimizer, with a different learning rate, for the upper layers than for the lower layers.

In [2]:
from collections import OrderedDict
import numpy as np
import tensorflow as tf
from tqdm.auto import trange

In [3]:
# Fetch the Fashion MNIST train/test arrays through the Keras dataset helper.
(X_train_full, y_train_full), (X_test, y_test) = (
    tf.keras.datasets.fashion_mnist.load_data()
)

# Cast the images to float32 and divide by 255 (presumably mapping 8-bit pixel
# intensities into [0, 1] -- confirm against the dataset documentation).
X_train_full = X_train_full.astype('float32') / 255.
X_test = X_test.astype('float32') / 255.

# Hold out the first 5,000 training examples for validation; train on the rest.
X_valid, y_valid = X_train_full[:5000], y_train_full[:5000]
X_train, y_train = X_train_full[5000:], y_train_full[5000:]

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
[1m29515/29515[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
[1m26421880/26421880[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
[1m5148/5148[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz
[1m4422102/4422102[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step


In [4]:
# Fix the random seed so that a fresh "Restart & Run All" is repeatable.
# Per the Keras documentation this single call seeds Python's `random`, NumPy
# and the TensorFlow backend; the NumPy part matters here because
# `random_batch` samples indices with `np.random.randint`.
tf.keras.utils.set_random_seed(42)

In [5]:
# The network is split into two sub-models so that each part can be trained
# by its own optimizer in the custom training loop.

# Lower block: flattens each 28x28 image and applies a 100-unit ReLU layer.
# The input shape is declared with an explicit `Input` object rather than an
# `input_shape=` argument on `Flatten`, which Keras 3 warns against for
# Sequential models.
lower_layers = tf.keras.Sequential([
    tf.keras.Input(shape=(28, 28)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(100, activation='relu'),
])

# Upper block: 10-way softmax output layer.
upper_layers = tf.keras.Sequential([
    tf.keras.layers.Dense(10, activation='softmax'),
])

# Full model: the two blocks chained end to end.
model = tf.keras.Sequential([lower_layers, upper_layers])

  super().__init__(**kwargs)
I0000 00:00:1768810593.539808      55 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 15513 MB memory:  -> device: 0, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:04.0, compute capability: 6.0


In [6]:
# One optimizer per sub-model, so each block gets its own update rule and
# step size: plain SGD with learning rate 1e-4 for `lower_layers`, and Nadam
# with learning rate 1e-3 for `upper_layers`. The training loop pairs each
# optimizer with the trainable variables of its block.
lower_optimizer = tf.keras.optimizers.SGD(learning_rate= 1e-4)
upper_optimizer = tf.keras.optimizers.Nadam(learning_rate= 1e-3)

In [7]:
# --- Training schedule ---------------------------------------------------
n_epochs = 5        # passes of `n_steps` iterations each
batch_size = 32     # examples per mini-batch
# Iterations per epoch: how many full batches fit in the training set.
n_steps = len(X_train) // batch_size

# --- Loss and running metrics --------------------------------------------
# Functional form of the loss; the training loop averages its output with
# `tf.reduce_mean`.
loss_fn = tf.keras.losses.sparse_categorical_crossentropy
# Running mean of the training loss; reset by the loop at the end of each epoch.
mean_loss = tf.keras.metrics.Mean()
# Additional streaming metrics updated at every training step.
metrics = [tf.keras.metrics.SparseCategoricalAccuracy()]

In [8]:
def random_batch(X, y, *, batch_size: int= 32) -> tuple:
    """Draw a random mini-batch of aligned examples and labels.

    Row indices are drawn uniformly from ``range(len(X))`` using NumPy's
    global random state. Draws are independent, so the same row may appear
    more than once in a batch.

    Args:
        X: Indexable array of examples (supports integer-array indexing).
        y: Array of labels, indexed with the same row indices as ``X``.
        batch_size: Number of rows to draw (keyword-only).

    Returns:
        A ``(X_batch, y_batch)`` tuple holding the selected rows of ``X``
        and ``y``.
    """
    n_samples = len(X)
    picks = np.random.randint(n_samples, size=batch_size)
    X_batch, y_batch = X[picks], y[picks]
    return X_batch, y_batch

In [9]:
# Custom training loop: an outer progress bar over epochs and, for each epoch,
# a transient inner bar over the training steps. The two sub-models are
# updated by their own optimizers from the same loss.
with trange(1, n_epochs + 1, desc='All epochs') as epochs:
    for epoch in epochs:
        # Latest training statistics. Bound before the step loop so the epoch
        # summary below can read it even if no step is executed.
        status = OrderedDict()

        # Progress bar for the steps of the current epoch (removed when done).
        with trange(1, n_steps + 1, desc=f'Epoch {epoch}/{n_epochs}', leave=False) as steps:
            for step in steps:
                # --------------------------------------------------
                # 1. Sample a random mini-batch
                # --------------------------------------------------
                # Pass `batch_size` explicitly so the batch size stays in
                # sync with `n_steps`, which is derived from the same value.
                X_batch, y_batch = random_batch(
                    X_train, y_train, batch_size=batch_size
                )

                # --------------------------------------------------
                # 2. Forward pass + loss computation
                # --------------------------------------------------
                # The tape is persistent because `gradient` is called once
                # per (sub-model, optimizer) pair below.
                with tf.GradientTape(persistent=True) as tape:
                    # Run the layers in training mode.
                    y_pred = model(X_batch, training=True)
                    # Primary task loss, averaged over the batch
                    main_loss = tf.reduce_mean(loss_fn(y_batch, y_pred))
                    # Add any extra losses registered on the model
                    # (e.g. regularization penalties)
                    loss = tf.add_n([main_loss] + model.losses)

                # --------------------------------------------------
                # 3. Backpropagation with a different optimizer per block
                # --------------------------------------------------
                for layers, optimizer in (
                    (lower_layers, lower_optimizer),
                    (upper_layers, upper_optimizer),
                ):
                    block_vars = layers.trainable_variables
                    grads = tape.gradient(loss, block_vars)
                    # Skip variables that received no gradient
                    grads_vars = [
                        (grad, var)
                        for grad, var in zip(grads, block_vars)
                        if grad is not None
                    ]
                    optimizer.apply_gradients(grads_vars)

                # A persistent tape keeps its resources until it is dropped.
                del tape

                # --------------------------------------------------
                # 4. Apply variable constraints (if any)
                # --------------------------------------------------
                for var in model.variables:
                    if var.constraint is not None:
                        var.assign(var.constraint(var))

                # --------------------------------------------------
                # 5. Update running training metrics and the step bar
                # --------------------------------------------------
                mean_loss.update_state(loss)
                status['loss'] = mean_loss.result().numpy()

                for metric in metrics:
                    metric.update_state(y_batch, y_pred)
                    status[metric.name] = metric.result().numpy()

                steps.set_postfix(status)

        # ------------------------------------------------------
        # 6. Validation phase (once per epoch, outside the step loop)
        # ------------------------------------------------------
        y_val_pred = model(X_valid, training=False)

        val_loss = tf.reduce_mean(loss_fn(y_valid, y_val_pred))
        val_acc = tf.reduce_mean(
            tf.keras.metrics.sparse_categorical_accuracy(y_valid, y_val_pred)
        )

        # Epoch summary: running training loss and metrics for the epoch,
        # followed by the validation loss and accuracy.
        epochs.set_postfix(
            {
                **status,
                'val_loss': val_loss.numpy(),
                'val_accuracy': val_acc.numpy(),
            }
        )

        # ------------------------------------------------------
        # 7. Reset the running metrics so the next epoch starts fresh
        # ------------------------------------------------------
        for metric in [mean_loss] + metrics:
            metric.reset_state()

All epochs:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 1/5:   0%|          | 0/1718 [00:00<?, ?it/s]

Epoch 2/5:   0%|          | 0/1718 [00:00<?, ?it/s]

Epoch 3/5:   0%|          | 0/1718 [00:00<?, ?it/s]

Epoch 4/5:   0%|          | 0/1718 [00:00<?, ?it/s]

Epoch 5/5:   0%|          | 0/1718 [00:00<?, ?it/s]