# Building custom training loops

## Importing libraries

In [1]:
import pandas as pd
import numpy as np
import time
from ipynb.fs.full.Useful_funcs import data_pipeline, pre_model, create_huber # Custom funcs for data processing, modelling, compiling and training
import tensorflow as tf
from sklearn.datasets import fetch_california_housing
from tensorflow import keras
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.activations import selu, relu, elu
from tensorflow.keras.initializers import lecun_normal, he_normal
from tensorflow.keras.optimizers import Nadam
from tensorflow.keras.losses import mean_squared_error, mse
from tensorflow.keras.regularizers import l2
from tensorflow.keras.metrics import Mean, MeanAbsoluteError

## Loading the dataset

In [2]:
# Download (or load from the local scikit-learn cache) the California housing dataset.
housing = fetch_california_housing()

# data_pipeline is a custom helper from Useful_funcs; it returns nine arrays in this order.
# NOTE(review): presumably it does the train/valid/test split and feature scaling — confirm in Useful_funcs.
(
    x_train, x_train_scaled,
    x_valid, x_valid_scaled,
    x_test, x_test_scaled,
    y_train, y_valid, y_test,
) = data_pipeline(housing)

## Building the model

In [3]:
# NOTE(review): pre_model is a custom helper imported from Useful_funcs; presumably it resets
# Keras state and/or random seeds before a model is built — confirm in Useful_funcs.
pre_model()

- We will first build a simple model. There is no need to compile it, since we will handle the training loop manually.

In [4]:
# Two-layer regression network; both layers carry an L2 penalty on their kernels,
# which surfaces later through model.losses.
# No input shape is given here, so the weights are created on the first call of the model.
model = Sequential([
    Dense(30, activation = elu, kernel_initializer = he_normal, kernel_regularizer = l2(0.05)), # Hidden layer
    Dense(1, kernel_regularizer = l2(0.05)), # Single linear output unit
])

- Define a function that randomly samples a batch of instances from the training set.

In [5]:
def random_batch(x, y, batch_size = 32):
    """Return a random batch of matching rows from x and y.

    batch_size row indexes are drawn independently from [0, len(x)), so the same
    row may appear more than once in a batch. The same indexes are applied to
    both x and y, which keeps features and targets aligned.
    """
    n_rows = len(x)
    picked = np.random.randint(0, n_rows, size = batch_size) # Random row indexes
    x_batch = x[picked]
    y_batch = y[picked]
    return x_batch, y_batch

- Define a function that displays the training status: the current iteration, the total number of iterations, the mean loss since the start of the epoch, and any other metrics.

In [6]:
def print_status_bar(iteration, total, loss, metrics = None):
    """Print 'iteration/total - name : value - ...' on the current output line.

    loss and every entry of metrics must expose .name and .result().
    While iteration < total the line is left open (no newline) so the next call
    overwrites it; once iteration reaches total the line is terminated.
    """
    tracked = [loss]
    if metrics:
        tracked = tracked + metrics
    pieces = []
    for m in tracked:
        pieces.append('{} : {:.4f}'.format(m.name, m.result()))
    summary = ' - '.join(pieces)
    if iteration < total:
        end = ''
    else:
        end = '\n'
    # '\r' moves the cursor back to the start of the line before printing
    print('\r{}/{} - '.format(iteration, total) + summary, end = end)

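- `{:.4f}` formats a float with 4 digits after the decimal point.
- Using `\r` together with `end = ''` moves the cursor back to the start of the line without emitting a newline, so each update overwrites the previous one and the status bar always stays on the same line.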

In [7]:
# Dry run of the status bar with synthetic values (no model involved).
mean_loss = Mean(name = 'loss')
mean_square = Mean(name = 'mean_square')
n_demo_steps = 50
for i in range(1, n_demo_steps + 1):
    loss = 1 / i # Fake loss that shrinks as the iteration count grows
    mean_loss(loss) # Calling a metric with a value updates its running mean
    mean_square(i ** 2)
    print_status_bar(i, n_demo_steps, mean_loss, [mean_square])
    time.sleep(0.05) # Short pause so the in-place refresh is visible

50/50 - loss : 0.0900 - mean_square : 858.5000


- Let's define a fancier status update with a progress bar.

In [8]:
def progress_bar(iteration, total, size = 30): # Size of bar, not batch size.
    """Render a text progress bar such as '3500/10000 [=>....]'.

    iteration : current position; printed as-is in the 'iteration/total' prefix.
    total     : position at which the bar is considered complete.
    size      : number of characters between the square brackets.

    The bar is '=' for completed cells, a single head character ('>' while
    running, '=' once iteration >= total) and '.' for the remaining cells.
    The filled portion is clamped to [0, total] so the bar keeps exactly `size`
    characters even if iteration overshoots total or is negative. A
    non-positive total is drawn as a completed bar instead of dividing by zero.
    """
    if total <= 0:
        # Nothing to wait for: show a full bar rather than raising ZeroDivisionError
        running = False
        p = size - 1
    else:
        running = iteration < total
        clamped = min(max(iteration, 0), total) # Keep the fill inside the bar
        p = (size - 1) * clamped // total # Cells filled before the head character
    c = '>' if running else '='
    fmt = '{}/{} [{}]' # Format for the status bar
    params = [iteration, total, '=' * p + c + '.' * (size - p - 1)]
    return fmt.format(*params)

In [9]:
# Quick check: 3500 of 10000 rendered on a 6-character bar
progress_bar(3500, 10000, size = 6)

'3500/10000 [=>....]'

In [25]:
def print_status_bar(iteration, total, loss, metrics = None):
    """Print a progress bar followed by the current loss and metric values.

    iteration, total : progress counters, forwarded to progress_bar().
    loss             : object exposing .name and .result().
    metrics          : optional list of further objects with the same interface.

    While iteration < total the line is left open (no newline) so the next call
    overwrites it; once iteration reaches total the line is terminated.
    """
    summary = ' - '.join(['{}: {:.4f}'.format(m.name, m.result()) for m in [loss] + (metrics or [])])
    end = '' if iteration < total else '\n'
    # ' - ' only separates the bar from the metrics; nothing is appended after the
    # last metric, so the line no longer ends with a dangling separator.
    print('\r{} - {}'.format(progress_bar(iteration, total), summary), end = end)

In [26]:
# Same synthetic dry run as before, now exercising the progress-bar version of print_status_bar.
mean_loss = Mean(name = 'loss')
mean_square = Mean(name = 'mean_square')
n_demo_steps = 50
for i in range(1, n_demo_steps + 1):
    loss = 1 / i # Fake, steadily shrinking loss value
    mean_loss(loss) # Update the running mean of the fake loss
    mean_square(i ** 2) # Update the running mean of i squared
    print_status_bar(i, n_demo_steps, mean_loss, [mean_square])
    time.sleep(0.05) # Short pause so the in-place refresh is visible



In [30]:
# NOTE(review): pre_model is a custom helper from Useful_funcs; presumably it resets Keras
# state and/or random seeds. `model` was already created in an earlier cell, so whatever is
# reset here happens after model creation — confirm that this ordering is intended.
pre_model()

- Let's now define the hyperparameters to be used.

In [31]:
n_epochs = 5 # Number of passes of the outer training loop
batch_size = 32 # Instances per gradient step
n_steps = len(x_train) // batch_size # Gradient steps per epoch (integer division, remainder ignored)
optimizer = Nadam(learning_rate = 0.01) # `learning_rate` is the supported keyword; `lr` is a deprecated alias
loss_fn = mean_squared_error # Per-instance loss; averaged over the batch in the training loop
mean_loss = Mean() # Running mean of the total loss within an epoch
metrics = [MeanAbsoluteError()] # Extra metrics reported next to the loss

In [34]:
# Custom training loop: n_epochs epochs of n_steps gradient steps each.
for epoch in range(1, n_epochs + 1):
    print('Epoch {}/{}'.format(epoch, n_epochs)) # Printing epoch status
    for step in range(1, n_steps + 1):
        # Pass the configured batch_size explicitly so the sampled batch always agrees with
        # n_steps and with the progress counter below (random_batch defaults to 32 otherwise).
        x_batch, y_batch = random_batch(x_train_scaled, y_train, batch_size) # Randomly picking a batch from training data
        with tf.GradientTape() as tape: # Taping forward prop
            # training = True makes layers that behave differently during training run in training mode
            y_pred = model(x_batch, training = True) # Predicting the labels
            # NOTE(review): assumes y_batch has the same shape as y_pred (e.g. (batch, 1)); a 1-D
            # y_batch would broadcast against the column of predictions — confirm what data_pipeline returns.
            main_loss = tf.reduce_mean(loss_fn(y_batch, y_pred)) # Calculating MSE
            loss = tf.add_n([main_loss] + model.losses) # Adding MSE with other model losses
        gradients = tape.gradient(loss, model.trainable_variables) # Computing the gradients
        optimizer.apply_gradients(zip(gradients, model.trainable_variables)) # Using the optimizer for Gradient Descent
        for variable in model.variables:
            if variable.constraint is not None:
                variable.assign(variable.constraint(variable)) # Applying constraints to the variables if any
        mean_loss(loss) # Updating the running mean of the loss
        for metric in metrics:
            metric(y_batch, y_pred) # Updating the metrics
        print_status_bar(step * batch_size, len(y_train), mean_loss, metrics) # Printing the status bar for steps
    print_status_bar(len(y_train), len(y_train), mean_loss, metrics) # Printing the status bar for the epochs
    for metric in [mean_loss] + metrics:
        # NOTE(review): newer Keras releases name this method reset_state() — confirm against the installed version
        metric.reset_states() # Resetting the metrics so each epoch reports its own averages

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


- We created two nested loops, one for the epochs and the other one for the steps within each epoch.
- Then we sample a random batch from the training data. These batches are not mutually exclusive: because sampling is random, some instances may be sampled repeatedly and some may not be sampled at all.
- Inside the `tf.GradientTape()` block we make a prediction for one batch by calling the model like a function, and we compute the loss. The total loss is the sum of the MSE loss and the regularization losses (one per layer in this case). Since the MSE function returns one loss value per instance, we take the mean across the entire batch.
- We ask the tape to compute the gradients of the loss with respect to each trainable variable. Then we pass them to the optimizer, which applies them to perform a Gradient Descent step.
- We calculate the mean loss and the metrics over the current epoch and we display the status bar.
- In the end after each epoch we display the final status bar and reset the metric values.