![Animated graph of a neural network being trained](./TrainingANN.png "Image produced by callback")


## Keras Training Visualization

This notebook uses a Keras callback and Matplotlib to display an animated graph of a model being trained.<br />

The model being trained is a multi-layer neural network. The callback in the section titled "Function that draws and updates the graph" is generic and can be used with any neural network, with one key limitation: the model must have a single input and a single output.<br />

The basic strategy is to hook into training after each mini-batch. The callback runs the model, in its current state, on the dataset and plots the predictions and (if desired) the current squared error around them.<br />
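
To make the strategy concrete without depending on Keras, here is a minimal sketch of the same idea using a hand-rolled mini-batch loop. The names `train_with_hook` and `record_mse`, the linear model, and the learning rate are all hypothetical stand-ins for illustration; in this notebook Keras's training loop and `LambdaCallback` play these roles.

```python
import numpy as np

def train_with_hook(X, y, batch_size, epochs, on_batch_end, lr=0.1, seed=0):
    """Fit y ~ w*x + b by mini-batch gradient descent, calling a hook after each batch."""
    rng = np.random.default_rng(seed)
    w, b = 0.0, 0.0
    for epoch in range(epochs):
        order = rng.permutation(len(X))
        for start in range(0, len(X), batch_size):
            idx = order[start:start + batch_size]
            resid = (w * X[idx] + b) - y[idx]
            # gradient step on the mean squared error of this batch
            w -= lr * 2 * np.mean(resid * X[idx])
            b -= lr * 2 * np.mean(resid)
            # the hook sees the model in its current state of training
            on_batch_end(lambda x: w * x + b)
    return w, b

# toy data: a noiseless line
X_demo = np.linspace(-1, 1, 200)
y_demo = 3.0 * X_demo + 0.5

mse_history = []

def record_mse(predict):
    # run the current model on the whole dataset, as the plotting callback does
    mse_history.append(float(np.mean(np.square(y_demo - predict(X_demo)))))

w_fit, b_fit = train_with_hook(X_demo, y_demo, batch_size=20, epochs=20,
                               on_batch_end=record_mse)
```

The hook here only records a number, but the same slot is where a redraw would go. Evaluating the whole dataset after every batch is also what makes this approach slow.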

### Import Requirements
Note that this uses the Qt5 backend for display rather than inline plots. Inline animation is possible, but it is more limited.

In [1]:
import numpy as np
import pandas as pd
%matplotlib qt5
import matplotlib.pyplot as plt

### Create Data

This is a toy dataset made to show how a neural network can fit complex functions.

In [2]:
# evenly spaced input values
x = np.linspace(0, 12, 10000).reshape(-1,1)

# function that maps a single feature value to its noiseless label
def f(x):
    # zero, then sin for a bit, then zero again
    if x < np.pi:
        return 0.0
    elif x < 3 * np.pi:
        return np.sin(x)
    else:
        return 0.0

# compute a label for each feature value
# (iterate over the flattened array so f receives scalars, not 1-element arrays)
y = np.array([f(v) for v in x.ravel()])

# create labels that are normally distributed around the curve
noise = np.random.randn(len(y)) * 0.5

features = x
labels = y + noise
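
As an aside, the same piecewise curve can be computed without a Python loop. This is a minimal sketch, assuming the same interval boundaries as `f`; the names `x_demo`, `inside`, `y_vec`, and `y_loop` are only for illustration and are not used elsewhere in the notebook.

```python
import numpy as np

x_demo = np.linspace(0, 12, 10000)

# sin(x) on [pi, 3*pi), zero elsewhere, evaluated for the whole array at once
inside = (x_demo >= np.pi) & (x_demo < 3 * np.pi)
y_vec = np.where(inside, np.sin(x_demo), 0.0)

# reference: the same piecewise rule applied one value at a time
y_loop = np.array([np.sin(v) if np.pi <= v < 3 * np.pi else 0.0 for v in x_demo])
```

Both arrays should agree. The vectorized form mainly matters if the dataset is made much larger.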

## Take a quick look at the function and dataset created

In [3]:
plt.figure(figsize=(15, 10))
plt.ylim(-2.5, 2.5)
plt.plot(x, y, lw=3)
plt.scatter(features, labels, s=1, alpha=0.5, c='gray')

<matplotlib.collections.PathCollection at 0x1182f0a20>

### Split into test and training datasets

In [4]:
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(features.reshape(-1, 1), labels, random_state=42)

## Function that draws and updates the graph
The callback function should be run *on_batch_end*. It determines whether an update is due and, if so, redraws the graph along with the related loss plots.<br />

The enclosing function does some setup and creates a closure over the data that must be retained between calls or known in advance, because Keras passes very little information to the callback.<br />
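
The closure pattern can be sketched in isolation. This is a minimal illustration, not the plotting function itself: `make_every_n` and `fired_at` are hypothetical names, and the inner function only mimics the `(batch, logs)` signature that Keras uses for `on_batch_end`.

```python
def make_every_n(n):
    """Return a Keras-style on_batch_end function that fires every n-th batch."""
    total = 0          # survives between calls because the inner function closes over it
    fired_at = []      # record of the batch counts at which an update happened

    def on_batch_end(batch, logs=None):
        nonlocal total  # rebinding an enclosing variable requires nonlocal
        total += 1
        if total % n == 0:
            fired_at.append(total)

    return on_batch_end, fired_at

callback, fired_at = make_every_n(3)
for epoch in range(2):
    for batch in range(5):          # Keras restarts the batch index each epoch
        callback(batch, {"loss": 0.0})

print(fired_at)  # [3, 6, 9]
```

Because the `batch` argument restarts at zero every epoch, the closure keeps its own running count, which is the same reason the function below tracks the total number of batches itself.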

This function should be reusable for graphing any single-input single-output model in Keras. I have used it, e.g., for multi-layer neural networks.<br />

Expect this to be *super slow*. Especially for simple models, the cost of running and graphing the model will be significantly higher than the cost of training on a single batch. You will get warnings from Keras about slowness.<br />

This is a toy for learning and investigating, so performance is not a primary concern, but it does need to be usable. Several things can have a big impact on performance.<br />
<ol>
<li>The **sparsity** options reduce the number of data points used in the scatter plot and in running the model on each pass.</li>
<li>The **frequency** option determines how often to update the graph. I've used the function form to update only when the change to the model parameters is big enough to justify an update, but that logic is specific to the model being trained.</li>
<li>Turning off the **error display** slightly reduces the number of computations, but significantly reduces the amount of drawing.</li>
</ol>

There is also an option to write the updates to files as individual images, which can then be used to create an animation.
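
As an illustration of the function form of the **frequency** option, here is a hedged sketch of one possible trigger: redraw only when the weights have moved by at least some distance since the last redraw. The names `make_weight_change_trigger` and `FakeModel` and the threshold value are assumptions for this example. The only things taken from the notebook are the call signature `frequency(model, X, y, tot_batches)` and the Keras method `model.get_weights()`.

```python
import numpy as np

def make_weight_change_trigger(threshold):
    """Return a frequency function that fires when the weights have moved far enough."""
    last_drawn = None  # flattened weights at the time of the last redraw

    def should_update(model, X, y, tot_batches):
        nonlocal last_drawn
        current = np.concatenate([np.ravel(w) for w in model.get_weights()])
        if last_drawn is None or np.linalg.norm(current - last_drawn) >= threshold:
            last_drawn = current
            return True
        return False

    return should_update

class FakeModel:
    """Hypothetical stand-in exposing only get_weights(), for demonstration."""
    def __init__(self):
        self.w = np.zeros(3)

    def get_weights(self):
        return [self.w.copy()]

fake = FakeModel()
trigger = make_weight_change_trigger(threshold=1.0)
decisions = []
for step in range(1, 7):
    fake.w += 0.4                      # pretend one batch of training moved the weights
    decisions.append(trigger(fake, None, None, step))

print(decisions)  # [True, False, True, False, True, False]
```

A suitable threshold depends on the size of the model and the optimizer's step sizes, which is why this kind of trigger has to be tuned per model.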

In [5]:
from keras.callbacks import LambdaCallback

import matplotlib as mpl
from matplotlib import gridspec

# prevents overflow when displaying a lot of data points
mpl.rcParams['agg.path.chunksize'] = 100000

# do some set up and return a closure around the redraw callback for Keras 
def get_redraw(X_in, y_in, model, batch_size, epochs, **kwargs):

    ## PROCESS KEYWORD ARGUMENTS
    
    # plot dimensions
    left = kwargs.get('left', X_in.min())
    right = kwargs.get('right', X_in.max())
    bottom = kwargs.get('bottom', y_in.min())
    top = kwargs.get('top', y_in.max())

    # how much data to use in graph
    
    # ... scatter plot sparsity (0 = no scatter plot)
    #     scatter is only drawn once, but it can be a lot of data, both computationally and visually
    scatter_sparsity = kwargs.get('scatter_sparsity', 5)

    # ... graph sparsity
    #     keeping the graph sparse improves performance
    graph_sparsity = kwargs.get('graph_sparsity', 1000)

    # whether to display error
    show_err = kwargs.get('show_err', True)
    
    # .. and level of smoothing to apply to error, if so (needs to be an odd number)
    err_smoothing = kwargs.get('err_smoothing', 101)
    
    # how frequently (in batches) to update the graph
    frequency = kwargs.get('frequency', 10)
    if callable(frequency):
        # if a function is provided, it will be called every batch and asked for a True/False response
        frequency_mode = 'function'
    elif np.isscalar(frequency):
        # if a number is provided, updates will be done every [frequency] batches
        frequency_mode = 'scalar'
    else:
        # for array-like setting, update when frequency[batch number] is True
        frequency_mode = 'array'
        
    # figure size
    figure_size = kwargs.get('figure_size', (15, 10))
    
    # text labels
    title = kwargs.get('title', None)
    x_label = kwargs.get('x_label', None)
    y_label = kwargs.get('y_label', None)
    
    # tick formatters
    x_tick_formatter = kwargs.get('x_tick_formatter', None)
    y_tick_formatter = kwargs.get('y_tick_formatter', None)
    
    # loss scale (depends on loss function)
    loss_scale = kwargs.get('loss_scale', 1.0)
    
    # display legend
    show_legend = kwargs.get('show_legend', True)
    
    # write to screen or file?
    display_mode = kwargs.get('display_mode', 'screen')
    filepath = kwargs.get('filepath', 'images/batch')
    
    
    ## PREP DATA FOR QUICKER DISPLAY
    
    # parallel sort feature and label arrays for plotting
    ix = X_in.argsort(axis=0)[:,0]

    # ... reducing number of points used in graph
    ix = ix[::graph_sparsity]
    
    X = X_in[ix]
    y = y_in[ix]
    
    # keep track of total number of batches seen
    tot_batches = 0
    batches_per_epoch = np.ceil(len(X_in) / batch_size)
    
    # scale for loss plot = total number of batches that will be run
    max_batches = epochs * batches_per_epoch
    
    ## DRAW BACKGROUND COMPONENTS
    
    # set the figure size and layout
    fig = plt.figure(figsize=figure_size)
    grd = gridspec.GridSpec(ncols=3, nrows=2)
    
    # graphs for data/error, total loss over time, and per-batch loss (drawn on ax_params)
    ax_main = fig.add_subplot(grd[:2, :2])
    ax_loss = fig.add_subplot(grd[:1, 2:])
    ax_params = fig.add_subplot(grd[1:, 2:])

    # data boundaries on main graph
    ax_main.set_xlim(left, right)
    ax_main.set_ylim(bottom, top)
    
    # titles and labels on main graph
    if title:
        ax_main.set_title(title, size=14, fontweight='bold', y=1.03)
    
    if x_label:
        ax_main.set_xlabel(x_label, size=12, fontweight='bold')
        
    if y_label:
        ax_main.set_ylabel(y_label, size=12, fontweight='bold')
                 
    # tick formatting on main graph
    if x_tick_formatter:
        ax_main.xaxis.set_major_formatter(x_tick_formatter)
        
    if y_tick_formatter:
        ax_main.yaxis.set_major_formatter(y_tick_formatter)
        
    # draw a scatter plot of the training data on main graph
    if scatter_sparsity > 0:
        ax_main.scatter(X_in[::scatter_sparsity], y_in[::scatter_sparsity], marker='.', c='silver', s=1, alpha=0.5, zorder=10)

    # set titles and labels on loss plots
    ax_loss.set_title("Total Loss", size=11, fontweight='bold', y=0.9)
    ax_loss.set_xlabel("Batch", size=9, fontweight='bold')
    ax_loss.set_ylabel("Loss", size=9, fontweight='bold')

    ax_params.set_title("Batch Loss", size=11, fontweight='bold', y=0.9)
    ax_params.set_xlabel("Batch", size=9, fontweight='bold')
    ax_params.set_ylabel("Loss", size=9, fontweight='bold')

    # set scale of loss plots
    # x axes are logarithmic because progress slows over the course of training
    # (note: Matplotlib 3.3 and later spell the nonposx keyword as nonpositive)
    ax_loss.set_xscale('log', nonposx='clip')
    ax_loss.set_xlim(1, max_batches)
    ax_loss.set_ylim(0, loss_scale)        

    ax_params.set_xscale('log', nonposx='clip')
    ax_params.set_xlim(1, max_batches)
    ax_params.set_ylim(0, loss_scale)        

    if display_mode == 'file':
        plt.savefig("%s-%05d.png" %(filepath, 0))
    
    # declare components that will be retained between calls
    first_pass = True
    y_pred_line = None
    err_line_u = None
    err_line_d = None
    fill_between = None
    
    # RETURN A CALLBACK FUNCTION usable by keras with closure around fixed arguments
    def redraw(batch, logs):
        
        # let Python know that outside scope variables will be used
        nonlocal first_pass, y_pred_line, err_line_u, err_line_d, fill_between, tot_batches

        # keep track of total number of batches seen
        tot_batches += 1
                
        # update graph at the requested frequency
        
        if frequency_mode == 'scalar':
            if tot_batches % frequency != 0:
                return
        elif frequency_mode == 'array':
            if not frequency[tot_batches]:
                return    
        
        if frequency_mode == 'function':
            if not frequency(model, X, y, tot_batches):
                return
        
        # run the model in its current state of training to get the prediction so far
        y_pred = model.predict(X).reshape(-1)
        
    
        if show_err:
            
            # compute the error relative to each training label
            err = np.square(y - y_pred.reshape(-1))

            # smooth error with a moving average 
            if err_smoothing > 1:
                err = np.convolve(err, np.ones((err_smoothing,))/err_smoothing, mode='same')

        # first time through, draw the dynamic portions
        if first_pass:

            # draw the current prediction of the model
            y_pred_line = ax_main.plot(X, y_pred, '-', color='steelblue', lw=4, label='model', zorder=15)[0]

            if show_err:

                # draw the error around the prediction
                err_line_u = ax_main.plot(X, y_pred + err, '-', alpha=0.6, lw=0.5, color='steelblue', label='err', zorder=3)[0]
                err_line_d = ax_main.plot(X, y_pred - err, '-', alpha=0.6, lw=0.5, color='steelblue', zorder=3)[0]

            if display_mode == 'screen':
                plt.show()

            first_pass = False

        # on subsequent calls, update the dynamic portions
        else:

            # draw the current prediction of the model
            y_pred_line.set_ydata(y_pred)

            # update the error around the prediction
            if show_err:
                err_line_u.set_ydata(y_pred + err)
                err_line_d.set_ydata(y_pred - err)

        if show_err:
            
            # shade in the area between the error lines
            if fill_between:
                fill_between.remove()

            fill_between = ax_main.fill_between(X.reshape(-1), y_pred + err, y_pred - err, color='steelblue', alpha=0.2, zorder=0)

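        # add points to loss graphs
        # (total loss is computed directly from the predictions, so it is also defined when show_err is False)
        tot_loss = np.mean(np.square(y - y_pred))
        ax_loss.scatter([tot_batches], [tot_loss], s=5, c='steelblue')
        ax_params.scatter([tot_batches], [logs['loss']], s=5, c='steelblue')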
            
        if show_legend:
            ax_main.legend()        

        if display_mode == 'screen':

            # push changes to screen
            fig.canvas.draw()
            fig.canvas.flush_events()

        elif display_mode == 'file':

            # save changes to image file        
            plt.savefig("%s-%05d.png" % (filepath, tot_batches))
    
    # return the closure around the callback that Keras will use
    return redraw
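
One detail of the error smoothing is worth seeing in isolation. The function above smooths with `np.convolve(..., mode='same')`, which pads with zeros, so values near both ends of the array are pulled toward zero. A minimal sketch follows; `moving_average`, `flat`, and `smoothed` are illustrative names only.

```python
import numpy as np

def moving_average(values, window):
    """Centered moving average with the same length as the input (window should be odd)."""
    kernel = np.ones(window) / window
    return np.convolve(values, kernel, mode='same')

flat = np.ones(9)
smoothed = moving_average(flat, 5)

# interior points keep their value; the two points at each end are averaged with padding zeros
print(np.round(smoothed, 2))  # [0.6 0.8 1.  1.  1.  1.  1.  0.8 0.6]
```

With a large `err_smoothing` window, the error band drawn near the left and right edges of the plot will therefore look narrower than the true error there.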



## Create and compile the model

In [6]:
from keras.models import Sequential
from keras.layers import Dense

model = Sequential()
model.add(Dense(200, input_dim=features.shape[1], activation='relu'))
model.add(Dense(200, activation='relu'))
model.add(Dense(200, activation='relu'))
model.add(Dense(1))

model.compile(loss='mean_squared_error', optimizer='adagrad')

epochs = 75
batch_size = 128
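
For a sense of scale, the number of batches and redraws implied by these settings can be worked out by hand. This sketch assumes `train_test_split` was left at its default 25% test share, which matches the 7500 training samples shown in the training output; the `_demo` names are illustrative only.

```python
import math

n_samples = 10000                     # points generated by np.linspace above
n_train = int(n_samples * 0.75)       # assumes train_test_split's default 25% test share
batch_size_demo = 128
epochs_demo = 75

batches_per_epoch = math.ceil(n_train / batch_size_demo)
max_batches = epochs_demo * batches_per_epoch
redraws_at_20 = max_batches // 20     # number of redraws with frequency=20

print(batches_per_epoch, max_batches, redraws_at_20)  # 59 4425 221
```

So a run with `frequency=20` redraws the figure a couple of hundred times, each time running the model over every plotted point.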


## Build the callback function that will be passed to Keras and train the model

In [7]:
from keras.callbacks import LambdaCallback

# get the redraw callback (a closure over the data and plot settings)
# calling get_redraw also draws the static background of the graph
cb_redraw = get_redraw( X_train, y_train, model, batch_size, epochs,
                        frequency=20, graph_sparsity=1,
                        scatter_sparsity=1, show_err=True, err_smoothing=201,
                        title="Neural Network Fitting Complex Function",
                        x_label="x",
                        y_label="f(x)",
                        loss_scale=0.8, display_mode='screen')

# wrap callback function in Keras structure, to be called after each batch
redraw_callback = LambdaCallback(on_batch_end=cb_redraw)

# train the model, passing the Keras-wrapped callback function
model.fit(X_train, y_train, epochs=epochs, batch_size=batch_size, callbacks=[redraw_callback])

Epoch 1/75
Epoch 2/75
 256/7500 [>.............................] - ETA: 8s - loss: 0.5236



Epoch 3/75
 384/7500 [>.............................] - ETA: 5s - loss: 0.4908



Epoch 4/75
Epoch 5/75
Epoch 6/75
Epoch 7/75
Epoch 8/75
Epoch 9/75
Epoch 10/75
Epoch 11/75
Epoch 12/75
Epoch 13/75
Epoch 14/75
Epoch 15/75
Epoch 16/75
Epoch 17/75
Epoch 18/75
Epoch 19/75
Epoch 20/75
Epoch 21/75
Epoch 22/75
 256/7500 [>.............................] - ETA: 8s - loss: 0.3786



Epoch 23/75
 384/7500 [>.............................] - ETA: 5s - loss: 0.3960



Epoch 24/75
Epoch 25/75
Epoch 26/75
Epoch 27/75
Epoch 28/75
Epoch 29/75
Epoch 30/75
Epoch 31/75
Epoch 32/75
Epoch 33/75
Epoch 34/75
Epoch 35/75
Epoch 36/75
Epoch 37/75
Epoch 38/75
Epoch 39/75
Epoch 40/75
Epoch 41/75
Epoch 42/75
 256/7500 [>.............................] - ETA: 9s - loss: 0.2740



Epoch 43/75
 384/7500 [>.............................] - ETA: 6s - loss: 0.2871



Epoch 44/75
Epoch 45/75
Epoch 46/75
Epoch 47/75
Epoch 48/75
Epoch 49/75
Epoch 50/75
Epoch 51/75
Epoch 52/75
Epoch 53/75
Epoch 54/75
Epoch 55/75
Epoch 56/75
Epoch 57/75
Epoch 58/75
Epoch 59/75
Epoch 60/75
Epoch 61/75
Epoch 62/75
 256/7500 [>.............................] - ETA: 10s - loss: 0.2664



Epoch 63/75
 384/7500 [>.............................] - ETA: 6s - loss: 0.2660



Epoch 64/75
Epoch 65/75
Epoch 66/75
Epoch 67/75
Epoch 68/75
Epoch 69/75
Epoch 70/75
Epoch 71/75
Epoch 72/75
Epoch 73/75
Epoch 74/75
Epoch 75/75


<keras.callbacks.History at 0x1a354d2240>