![alt text](./images/TrainedModel.png "Image produced by callback")


## Keras Training Visualization

This notebook uses a Keras callback and Matplotlib to display an animated graph of a model being trained.<br />

The model trained is a linear regression (modeled as a single-node neural network). The callback function in the cell titled "Function that draws and updates the graph" is generic and can be used for any neural network. The key limitation is that it only supports models with a single input and a single output.<br />

The basic strategy is to hook into the end of each mini-batch of training. On each call, the callback runs the model on (a sample of) the dataset and plots the results and, if desired, the current mean squared error.<br />

### Import Requirements
Note that this uses the Qt5 backend (`%matplotlib qt5`), which draws in a separate window rather than inline. Inline animation is possible, but it's a bit more limited.

In [1]:
import numpy as np
import pandas as pd
%matplotlib qt5
import matplotlib.pyplot as plt

### Load Data
This is publicly available data from the Lending Club (http://bit.ly/2LC6wth) on the performance of loans that they issued from 2017 to present. The interest rate is provided in the data. I computed the total losses from fields in their file, did a bit of cleanup, and pickled the result to the file loaded below. My data set is restricted to loans that have defaulted (aka "charged off").

In [9]:
data = pd.read_pickle('data/lend_club_ir_v_losses.pkl')
data.head()

| | int_rate | loss_pct |
|---:|---:|---:|
| 1 | 0.1527 | 0.768256 |
| 8 | 0.2128 | 0.937043 |
| 9 | 0.1269 | 0.823038 |
| 12 | 0.1349 | 0.810327 |
| 14 | 0.1065 | 0.392143 |


### Select Features and Labels

In [3]:
features = data['int_rate'].values
labels = data['loss_pct'].values

### Split into test and training datasets

In [4]:
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(features.reshape(-1, 1), labels, random_state=42)
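
Keras's `Dense` layer expects a 2-D input of shape (samples, features), which is why the 1-D `features` array is reshaped before splitting; `X_train.shape[1]` is then 1 and is used as `input_dim` when the model is built. A small illustration using the first few `int_rate` values from `data.head()`:

```python
import numpy as np

# first three int_rate values shown by data.head() above
features = np.array([0.1527, 0.2128, 0.1269])

# -1 tells NumPy to infer the number of rows; 1 makes a single feature column
X = features.reshape(-1, 1)

print(features.shape, X.shape)  # (3,) (3, 1)
```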

### Function that determines how often to update the graph
Designed specifically for the single-input linear model below, this function requests an update whenever the weight or bias has changed by more than a specified absolute threshold since the last update (default 0.01 for each). A minimum update frequency (default: at least once every 500 batches) can also be specified.

In [5]:
from keras.layers import Dense

# returns a closure (callback) that decides how often the graph is redrawn
def get_frequency_callback(**kwargs):

    # parameters
    weight_threshold = kwargs.get('weight_threshold', 0.01)
    bias_threshold = kwargs.get('bias_threshold', 0.01)
    min_frequency = kwargs.get('min_frequency', 500)

    # declared variables that will be retained between invocations of the callback
    layer = None
    w_prev = 0
    b_prev = 0
    batch_prev = 0
    
    # the callback that will actually make the decision to update or not
    def frequency_callback(model, X, y, tot_batches):

        nonlocal layer, w_prev, b_prev, batch_prev
        
        # get the model layer containing the weight and bias
        if layer is None:
            layer = model.get_layer('output')
        
        # get the current value of the weight and bias
        w = layer.get_weights()[0][0][0]
        b = layer.get_weights()[1][0]

        # assume change was too small for an update
        display = False
        
        # if change in weight or bias exceeds relevant threshold, or it's been too long since the last update 
        if (np.abs(w - w_prev) > weight_threshold or np.abs(b - b_prev) > bias_threshold) \
            or tot_batches - batch_prev > min_frequency:
            
            # update on this iteration
            display = True
            
            # keep track of weight and bias from last update
            w_prev = w
            b_prev = b
            
            # keep track of how long since the last update
            batch_prev = tot_batches
        
        return display
    
    # return closure
    return frequency_callback

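To see the decision rule in isolation, here is a Keras-free sketch. `make_threshold_gate` is a hypothetical variant of `get_frequency_callback` that takes the current weight and bias directly instead of reading them from the layer named `'output'`; the comparison logic is the same, including starting from a previous weight and bias of 0. The thresholds match the ones passed in the training cell further down.

```python
def make_threshold_gate(weight_threshold=0.01, bias_threshold=0.01, min_frequency=500):
    """Hypothetical Keras-free variant of get_frequency_callback.

    Takes the current weight and bias as arguments instead of reading a model layer."""
    w_prev, b_prev, batch_prev = 0.0, 0.0, 0

    def gate(w, b, tot_batches):
        nonlocal w_prev, b_prev, batch_prev
        moved = abs(w - w_prev) > weight_threshold or abs(b - b_prev) > bias_threshold
        stale = tot_batches - batch_prev > min_frequency
        if moved or stale:
            # remember the parameters and batch number at this redraw
            w_prev, b_prev, batch_prev = w, b, tot_batches
            return True
        return False

    return gate


gate = make_threshold_gate(weight_threshold=0.03, bias_threshold=0.02, min_frequency=500)
decisions = [
    gate(-1.0, 0.0, 1),      # True: weight moved 1.0 from the initial 0
    gate(-0.99, 0.005, 2),   # False: both changes are below their thresholds
    gate(-0.95, 0.005, 3),   # True: weight moved 0.05 since the last redraw
    gate(-0.95, 0.005, 504), # True: more than 500 batches since the last redraw
]
print(decisions)  # [True, False, True, True]
```

Because the first call compares against 0, the initial `Constant(-1.0)` weight always triggers a first redraw.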


### Function that draws and updates the graph
The callback is meant to run on Keras's *on_batch_end* event. It decides whether an update is needed and, if so, redraws the graph along with the related loss data.<br />

The enclosing function does some setup and creates a closure over the data that must be retained between calls or known in advance, because Keras passes very little information to the callback (just the batch index within the current epoch and a `logs` dictionary of metrics such as `loss`).<br />
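
The setup-once, retain-state pattern can be sketched without Keras. Keras calls `on_batch_end(batch, logs)` with `batch` counted from 0 within the current epoch, so a running total across epochs has to live in the closure, which is what `tot_batches` does in `redraw`. `make_loss_recorder` below is a hypothetical stand-in that records the loss every few batches instead of drawing:

```python
def make_loss_recorder(every=10):
    """Hypothetical factory with the same shape as get_redraw: setup runs once,
    and the returned callback keeps state between calls."""
    tot_batches = 0   # running count across epochs; Keras's `batch` restarts at 0
    history = []      # (total batch number, loss) pairs that would have been drawn

    def on_batch_end(batch, logs):
        nonlocal tot_batches   # rebinding an enclosing variable needs nonlocal
        tot_batches += 1
        if tot_batches % every == 0:
            history.append((tot_batches, logs['loss']))

    return on_batch_end, history


callback, history = make_loss_recorder(every=3)

# simulate two epochs of four batches each, calling the callback the way Keras would
for epoch in range(2):
    for batch in range(4):
        callback(batch, {'loss': 1.0 / (epoch * 4 + batch + 1)})

print([t for t, _ in history])  # [3, 6]
```

In the notebook, the real callback is wrapped with `LambdaCallback(on_batch_end=...)` so Keras can call it.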

This function should be reusable for graphing any single-input, single-output model in Keras; I have also used it, for example, with multi-layer neural networks.<br />

Expect this to be *super slow*. Especially for simple models, running and graphing the model costs much more than training on a single batch, so expect Keras to warn that the batch-end callback is slow.<br />

This is a toy for learning and investigation, so performance is not a primary concern, but it still needs to be usable. Several settings have a big impact on performance:<br />
<ol>
<li>The **sparsity** options reduce the number of data points used in the scatter plot (`scatter_sparsity`) and in running the model on each pass (`graph_sparsity`).</li>
<li>The **frequency** option determines how often the graph is updated. Here I use the function option to update only when the model parameters have changed enough to justify an update, but that function is specific to the model being trained.</li>
<li>Turning off the **error display** slightly reduces the number of computations, but significantly reduces the amount of drawing.</li>
</ol>
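
The `graph_sparsity` step in `get_redraw` sorts the points by their single feature (so the prediction draws as a line rather than a zigzag) and then keeps every n-th point of the sorted order. A toy version on hypothetical random data:

```python
import numpy as np

# hypothetical toy data: 10 samples with one feature column, shaped like X_train / y_train
rng = np.random.default_rng(0)
X_in = rng.uniform(0.05, 0.30, size=(10, 1))
y_in = rng.uniform(0.0, 1.0, size=10)

graph_sparsity = 3

# indices that sort the single feature column, so the prediction plots as a line
ix = X_in.argsort(axis=0)[:, 0]
# thin the sorted order: keep every graph_sparsity-th point
ix = ix[::graph_sparsity]

# index both arrays with the same ix so each x stays paired with its label
X, y = X_in[ix], y_in[ix]
print(X.shape, y.shape)  # (4, 1) (4,)
```

With the default `graph_sparsity` of 1000 and the 255,480 training rows reported in the training log, each redraw runs `model.predict` on only 256 points.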

There is also an option (`display_mode='file'`) to write each update to a numbered PNG file (named from the `filepath` prefix); the images can then be combined into an animation.
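
With `display_mode='file'`, frames are saved as `'%s-%05d.png' % (filepath, tot_batches)`, e.g. `images/batch-00000.png`. One way to combine them, assuming Pillow is installed (it is a dependency of recent Matplotlib releases but is not otherwise used here), is sketched below; `frames_to_gif` is a hypothetical helper, not part of the notebook.

```python
import glob
from PIL import Image  # assumption: Pillow is available


def frames_to_gif(prefix='images/batch', out_path='training.gif', ms_per_frame=100):
    """Combine '<prefix>-NNNNN.png' frames into one looping animated GIF.

    Returns the number of frames written."""
    # %05d zero-padding makes alphabetical order match batch order (up to 99,999)
    paths = sorted(glob.glob(prefix + '-*.png'))
    if not paths:
        raise FileNotFoundError('no frames match ' + prefix + '-*.png')
    frames = [Image.open(p).convert('RGB') for p in paths]
    # save_all/append_images write a multi-frame GIF; loop=0 repeats forever
    frames[0].save(out_path, save_all=True, append_images=frames[1:],
                   duration=ms_per_frame, loop=0)
    return len(frames)
```

After a run with `display_mode='file'`, `frames_to_gif('images/batch', 'training.gif')` would write the animation. This loads every frame into memory, which is fine for the few hundred frames the frequency gating typically produces but not for tens of thousands.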

In [6]:
from keras.callbacks import LambdaCallback

import matplotlib as mpl
from matplotlib import gridspec

# prevents overflow when displaying a lot of data points
mpl.rcParams['agg.path.chunksize'] = 100000

# do some set up and return a closure around the redraw callback for Keras 
def get_redraw(X_in, y_in, model, batch_size, epochs, **kwargs):

    ## PROCESS KEYWORD ARGUMENTS
    
    # plot dimensions
    left = kwargs.get('left', X_in.min())
    right = kwargs.get('right', X_in.max())
    bottom = kwargs.get('bottom', y_in.min())
    top = kwargs.get('top', y_in.max())

    # how much data to use in graph
    
    # ... scatter plot sparsity (0 = no scatter plot)
    #     scatter is only drawn once, but it can be a lot of data, both computationally and visually
    scatter_sparsity = kwargs.get('scatter_sparsity', 5)

    # ... graph sparsity
    #     keeping the graph sparse improves performance
    graph_sparsity = kwargs.get('graph_sparsity', 1000)

    # whether to display error
    show_err = kwargs.get('show_err', True)
    
    # .. and level of smoothing to apply to error, if so (needs to be an odd number)
    err_smoothing = kwargs.get('err_smoothing', 101)
    
    # how frequently (in batches) to update the graph
    frequency = kwargs.get('frequency', 10)
    if callable(frequency):
        # if a function is provided, it will be called every batch and asked for a True/False response
        frequency_mode = 'function'
    elif np.isscalar(frequency):
        # if a number is provided, updates will be done every [frequency] batches
        frequency_mode = 'scalar'
    else:
        # for array-like setting, update when frequency[batch number] is True
        frequency_mode = 'array'
        
    # figure size
    figure_size = kwargs.get('figure_size', (15, 10))
    
    # text labels
    title = kwargs.get('title', None)
    x_label = kwargs.get('x_label', None)
    y_label = kwargs.get('y_label', None)
    
    # tick formatters
    x_tick_formatter = kwargs.get('x_tick_formatter', None)
    y_tick_formatter = kwargs.get('y_tick_formatter', None)
    
    # loss scale (depends on loss function)
    loss_scale = kwargs.get('loss_scale', 1.0)
    
    # display legend
    show_legend = kwargs.get('show_legend', True)
    
    # write to screen or file?
    display_mode = kwargs.get('display_mode', 'screen')
    filepath = kwargs.get('filepath', 'images/batch')
    
    
    ## PREP DATA FOR QUICKER DISPLAY
    
    # parallel sort feature and label arrays for plotting
    ix = X_in.argsort(axis=0)[:,0]

    # ... reducing number of points used in graph
    ix = ix[::graph_sparsity]
    
    X = X_in[ix]
    y = y_in[ix]
    
    # keep track of total number of batches seen
    tot_batches = 0
    batches_per_epoch = np.ceil(len(X_in) / batch_size)
    
    # scale for loss plot = total number of batches that will be run
    max_batches = epochs * batches_per_epoch
    
    ## DRAW BACKGROUND COMPONENTS
    
    # set the figure size and layout
    fig = plt.figure(figsize=figure_size)
    grd = gridspec.GridSpec(ncols=3, nrows=2)
    
    # graphs for data/error, total loss over time, and per-batch loss
    ax_main = fig.add_subplot(grd[:2, :2])
    ax_loss = fig.add_subplot(grd[:1, 2:])
    ax_params = fig.add_subplot(grd[1:, 2:])

    # data boundaries on main graph
    ax_main.set_xlim(left, right)
    ax_main.set_ylim(bottom, top)
    
    # titles and labels on main graph
    if title:
        ax_main.set_title(title, size=14, fontweight='bold', y=1.03)
    
    if x_label:
        ax_main.set_xlabel(x_label, size=12, fontweight='bold')
        
    if y_label:
        ax_main.set_ylabel(y_label, size=12, fontweight='bold')
                 
    # tick formatting on main graph
    if x_tick_formatter:
        ax_main.xaxis.set_major_formatter(x_tick_formatter)
        
    if y_tick_formatter:
        ax_main.yaxis.set_major_formatter(y_tick_formatter)
        
    # draw a scatter plot of the training data on main graph
    if scatter_sparsity > 0:
        ax_main.scatter(X_in[::scatter_sparsity], y_in[::scatter_sparsity], marker='.', c='silver', s=1, alpha=0.5, zorder=10)

    # set titles and labels on loss plots
    ax_loss.set_title("Total Loss", size=11, fontweight='bold', y=0.9)
    ax_loss.set_xlabel("Batch", size=9, fontweight='bold')
    ax_loss.set_ylabel("Loss", size=9, fontweight='bold')

    ax_params.set_title("Batch Loss", size=11, fontweight='bold', y=0.9)
    ax_params.set_xlabel("Batch", size=9, fontweight='bold')
    ax_params.set_ylabel("Loss", size=9, fontweight='bold')

    # set scale of loss plots
    # x axes are logarithmic because progress slows over the course of training
    # (batch numbers start at 1, so no clipping of non-positive values is needed)
    ax_loss.set_xscale('log')
    ax_loss.set_xlim(1, max_batches)
    ax_loss.set_ylim(0, loss_scale)        

    ax_params.set_xscale('log')
    ax_params.set_xlim(1, max_batches)
    ax_params.set_ylim(0, loss_scale)        

    if display_mode == 'file':
        plt.savefig("%s-%05d.png" %(filepath, 0))
    
    # declare components that will be retained between calls
    first_pass = True
    y_pred_line = None
    err_line_u = None
    err_line_d = None
    fill_between = None
    
    # RETURN A CALLBACK FUNCTION usable by keras with closure around fixed arguments
    def redraw(batch, logs):
        
        # let Python know that outside scope variables will be used
        nonlocal first_pass, y_pred_line, err_line_u, err_line_d, fill_between, tot_batches

        # keep track of total number of batches seen
        tot_batches += 1
                
        # update graph at the requested frequency
        
        if frequency_mode == 'scalar':
            if tot_batches % frequency != 0:
                return
        elif frequency_mode == 'array':
            if not frequency[tot_batches]:
                return    
        
        if frequency_mode == 'function':
            if not frequency(model, X, y, tot_batches):
                return
        
        # run the model in its current state of training to get the prediction so far
        y_pred = model.predict(X).reshape(-1)
        
    
        if show_err:
            
            # compute the error relative to each training label
            err = np.square(y - y_pred.reshape(-1))

            # smooth error with a moving average 
            if err_smoothing > 1:
                err = np.convolve(err, np.ones((err_smoothing,))/err_smoothing, mode='same')

        # first time through, draw the dynamic portions
        if first_pass:

            # draw the current prediction of the model
            y_pred_line = ax_main.plot(X, y_pred, '-', color='steelblue', lw=4, label='model', zorder=15)[0]

            if show_err:

                # draw the error around the prediction
                err_line_u = ax_main.plot(X, y_pred + err, '-', alpha=0.6, lw=0.5, color='steelblue', label='err', zorder=3)[0]
                err_line_d = ax_main.plot(X, y_pred - err, '-', alpha=0.6, lw=0.5, color='steelblue', zorder=3)[0]

            if display_mode == 'screen':
                plt.show()

            first_pass = False

        # on subsequent calls, update the dynamic portions
        else:

            # draw the current prediction of the model
            y_pred_line.set_ydata(y_pred)

            # update the error around the prediction
            if show_err:
                err_line_u.set_ydata(y_pred + err)
                err_line_d.set_ydata(y_pred - err)

        if show_err:
            
            # shade in the area between the error lines
            if fill_between:
                fill_between.remove()

            fill_between = ax_main.fill_between(X.reshape(-1), y_pred + err, y_pred - err, color='steelblue', alpha=0.2, zorder=0)

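        # add points to loss graphs; total loss is the MSE over the sampled points,
        # computed from the unsmoothed error so it also works when show_err is False
        tot_loss = np.mean(np.square(y - y_pred))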
        ax_loss.scatter([tot_batches], [tot_loss], s=5, c='steelblue')            
        ax_params.scatter([tot_batches], [logs['loss']], s=5, c='steelblue')            
            
        if show_legend:
            ax_main.legend()        

        if display_mode == 'screen':

            # push changes to screen
            fig.canvas.draw()
            fig.canvas.flush_events()

        elif display_mode == 'file':

            # save changes to image file        
            plt.savefig("%s-%05d.png" % (filepath, tot_batches))
    
    # return the closure around the callback that Keras will use
    return redraw
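
`redraw` smooths the squared error with `np.convolve` before drawing the error band. A short sketch of that moving average (the helper name `moving_average` is hypothetical) shows two properties: an odd window keeps each average centered on its own point, which is why `err_smoothing` should be odd, and because `mode='same'` treats values beyond the ends as zeros, the smoothed error is pulled toward zero at the left and right edges of the sorted data, narrowing the shaded band there.

```python
import numpy as np


def moving_average(err, window):
    """Centered moving average, as used for err_smoothing in redraw().

    mode='same' returns an array the same length as err; values outside
    the array are treated as zeros."""
    return np.convolve(err, np.ones(window) / window, mode='same')


# a single spike is spread evenly over the window
spike = np.array([0.0, 0.0, 3.0, 0.0, 0.0])
print(moving_average(spike, 3))  # approximately [0, 1, 1, 1, 0]

# constant error: the two end values drop to 2/3 because of the zero padding
flat = np.ones(5)
print(moving_average(flat, 3))
```

For non-negative errors, the same edge effect means the mean of the smoothed curve is never higher than the mean squared error of the unsmoothed values.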


### Create and compile the model

In [7]:
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import SGD
from keras.initializers import Constant

# initializing to a negative slope so that the training is more interesting
initializer = Constant(-1.0)

# linear model is an ANN with 1 input node, 1 output node, and linear activation
model = Sequential()
model.add(Dense(1, input_dim=X_train.shape[1], kernel_initializer=initializer, 
                activation='linear', name='output'))

model.compile(loss='mean_squared_error', optimizer='sgd')

epochs = 25
batch_size = 128
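
A single `Dense(1)` unit with linear activation computes ŷ = w·x + b, so training it with mean squared error and plain SGD amounts to the loop sketched below. This is an illustrative NumPy re-implementation on hypothetical synthetic data, not Keras's own code. It starts from w = -1 like the `Constant(-1.0)` initializer and b = 0 (Keras's default bias initializer is zeros), and it uses a learning rate of 0.1 rather than the 0.01 default of Keras's SGD so that it converges quickly. Note that the gradient for w is scaled by x, so with interest rates of roughly 0.1 to 0.2 (as in `data.head()`) the slope receives much smaller updates than the bias.

```python
import numpy as np


def sgd_linear_fit(x, y, w=-1.0, b=0.0, lr=0.1, epochs=50, batch_size=32, seed=0):
    """Mini-batch SGD for y_hat = w * x + b under mean squared error.

    Illustrative NumPy re-implementation, not Keras's own code."""
    rng = np.random.default_rng(seed)
    n = len(x)
    for _ in range(epochs):
        order = rng.permutation(n)              # reshuffle every epoch
        for start in range(0, n, batch_size):
            idx = order[start:start + batch_size]
            resid = w * x[idx] + b - y[idx]     # prediction error on this mini-batch
            # gradients of mean(resid ** 2) with respect to w and b;
            # the w gradient is scaled by the feature values
            grad_w = 2.0 * np.mean(resid * x[idx])
            grad_b = 2.0 * np.mean(resid)
            w -= lr * grad_w
            b -= lr * grad_b
    return w, b


# hypothetical synthetic data: true slope 2.0, intercept 0.5, small noise
rng = np.random.default_rng(1)
x = rng.uniform(0.0, 1.0, 1000)
y = 2.0 * x + 0.5 + rng.normal(0.0, 0.05, 1000)

w, b = sgd_linear_fit(x, y)
print(round(float(w), 2), round(float(b), 2))  # close to 2.0 and 0.5
```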


### Build the callback function that will be passed to Keras and train the model

In [8]:
# get closured redraw callback function
# this will also draw the background for the graph
cb_redraw = get_redraw( X_train, y_train, model, batch_size, epochs,
                        frequency=get_frequency_callback(weight_threshold=0.03, bias_threshold=0.02),
                        scatter_sparsity=3, show_err=True, err_smoothing=51,
                        title="Linear Regression of Losses vs. Interest Rate",
                        x_label="Interest Rate",
                        y_label="Total Loss (% of Funded Amount)",
                        x_tick_formatter=mpl.ticker.PercentFormatter(xmax=1),
                        y_tick_formatter=mpl.ticker.PercentFormatter(xmax=1),
                        loss_scale=0.8, display_mode='screen')

# wrap callback function in Keras structure, to be called after each batch
redraw_callback = LambdaCallback(on_batch_end=cb_redraw)

# train the model, passing the Keras-wrapped callback function
model.fit(X_train, y_train, epochs=epochs, batch_size=batch_size, callbacks=[redraw_callback])

Epoch 1/25
   256/255480 [..............................] - ETA: 35:36 - loss: 0.7103
   512/255480 [..............................] - ETA: 25:05 - loss: 0.6706
   768/255480 [..............................] - ETA: 21:34 - loss: 0.6463
  1536/255480 [..............................] - ETA: 18:10 - loss: 0.5801
  2304/255480 [..............................] - ETA: 16:56 - loss: 0.5219
  2560/255480 [..............................] - ETA: 16:46 - loss: 0.5052
  2816/255480 [..............................] - ETA: 16:34 - loss: 0.4916
  3072/255480 [..............................] - ETA: 16:22 - loss: 0.4780
Epoch 2/25
Epoch 3/25
Epoch 4/25
Epoch 5/25
Epoch 6/25
Epoch 7/25
Epoch 8/25
Epoch 9/25
Epoch 10/25
Epoch 11/25
Epoch 12/25
Epoch 13/25
Epoch 14/25
Epoch 15/25
Epoch 16/25
Epoch 17/25
Epoch 18/25
Epoch 19/25
Epoch 20/25
Epoch 21/25
Epoch 22/25
Epoch 23/25
Epoch 24/25
Epoch 25/25
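
As an alternative to the weight-based function, `get_redraw` also accepts an array-like `frequency`, redrawing when `frequency[tot_batches]` is true. Since the loss plots use a logarithmic batch axis, one option is to space redraws geometrically so they fall evenly on those axes. `log_spaced_schedule` is a hypothetical helper sketching that:

```python
import numpy as np


def log_spaced_schedule(max_batches, n_updates=200):
    """Hypothetical helper: boolean mask for get_redraw's array-style `frequency`.

    redraw() looks up frequency[tot_batches], and tot_batches counts from 1,
    so the mask is indexed by total batch number and has max_batches + 1 entries."""
    mask = np.zeros(int(max_batches) + 1, dtype=bool)
    # geometrically spaced batch numbers; duplicates among early batches collapse
    steps = np.unique(np.geomspace(1, max_batches, n_updates).astype(int))
    mask[steps] = True
    return mask


# 25 epochs x 1,996 batches (255,480 training rows / batch size 128, rounded up)
schedule = log_spaced_schedule(25 * 1996)
print(len(schedule), schedule.sum())  # 49901 entries, at most 200 of them True
```

It could be passed as `frequency=log_spaced_schedule(epochs * np.ceil(len(X_train) / batch_size))`, which matches the `max_batches` value `get_redraw` computes internally.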