In [6]:
from ipywidgets import interact
from mlp.data_providers import CCPPDataProvider
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline
def error(outputs, targets):
    """Return half the summed squared error, averaged over the batch.

    Computes ``sum((outputs - targets) ** 2) / (2 * N)`` where
    ``N = outputs.shape[0]``: the squared differences are summed over every
    element (all rows and all output dimensions) and divided by twice the
    number of rows. ``error_grad`` returns the gradient of this quantity
    with respect to ``outputs``.

    Summing the element-wise squares directly equals the trace of
    ``diff.T @ diff``. It avoids materialising a K x K (or, under
    broadcasting, N x N) matrix only to read its diagonal.

    Args:
        outputs: array of model outputs, 1D ``(N,)`` or 2D ``(N, K)``.
        targets: array of target values, broadcast against ``outputs``.

    Returns:
        Scalar error value.
    """
    diff = outputs - targets
    return np.sum(diff * diff) / (2 * outputs.shape[0])
    
def fprop(inputs, weights, biases):
    """Forward-propagate inputs through an affine layer.

    Computes ``inputs @ weights.T + biases``. Each row of ``weights`` holds
    the input weights for one output unit, and ``biases`` is added via
    broadcasting.

    Args:
        inputs: array of shape ``(N, D)``.
        weights: array of shape ``(K, D)``.
        biases: array broadcastable to ``(N, K)``, e.g. shape ``(K,)``.

    Returns:
        Array of layer outputs, shape ``(N, K)``.
    """
    transposed_weights = weights.T
    linear_part = np.dot(inputs, transposed_weights)
    return linear_part + biases

def error_grad(outputs, targets):
    """Gradient of the batch-averaged half sum-of-squares error w.r.t. outputs.

    Computes ``(outputs - targets) / N`` with ``N = outputs.shape[0]``.

    Args:
        outputs: array of model outputs.
        targets: array of target values, broadcast against ``outputs``.

    Returns:
        Array of the same (broadcast) shape as ``outputs - targets``.
    """
    batch_size = outputs.shape[0]
    residuals = outputs - targets
    return residuals / batch_size

def grads_wrt_params(inputs, grads_wrt_outputs):
    """Back-propagate output gradients to the affine layer's parameters.

    Args:
        inputs: array of layer inputs, shape ``(N, D)``.
        grads_wrt_outputs: array of gradients w.r.t. layer outputs,
            shape ``(N, K)``.

    Returns:
        Tuple ``(grads_wrt_weights, grads_wrt_biases)``. The weight gradient
        is ``grads_wrt_outputs.T @ inputs`` with shape ``(K, D)``. The bias
        gradient is the column sum of ``grads_wrt_outputs`` with shape
        ``(K,)``.
    """
    grads_wrt_weights = grads_wrt_outputs.T @ inputs
    # Each bias receives the summed gradient over all rows of the batch.
    grads_wrt_biases = grads_wrt_outputs.sum(axis=0)
    return grads_wrt_weights, grads_wrt_biases

def setup_figure():
    """Create the figure and the two axes used to visualise training.

    Returns:
        Tuple ``(fig, ax1, ax2)``:

        * ``ax1`` is a 3D axes occupying the left half of the figure, for
          (weight 1, weight 2, bias) trajectories. Its x, y and z limits are
          all fixed to (-2, 2).
        * ``ax2`` is a 2D axes on the right, for batch errors. Its y axis is
          log-scaled and its top/right spines are hidden.
    """
    fig = plt.figure(figsize=(12, 6))
    traj_ax = fig.add_axes([0., 0., 0.5, 1.], projection='3d')
    err_ax = fig.add_axes([0.6, 0.1, 0.4, 0.8])

    # Error panel: keep only the left/bottom frame and use a log y axis.
    for side in ('right', 'top'):
        err_ax.spines[side].set_visible(False)
    err_ax.yaxis.set_ticks_position('left')
    err_ax.xaxis.set_ticks_position('bottom')
    err_ax.set_yscale('log')

    # Trajectory panel: the same fixed range on all three axes.
    for set_limits in (traj_ax.set_xlim, traj_ax.set_ylim, traj_ax.set_zlim):
        set_limits((-2, 2))

    # Titles and axis labels.
    traj_ax.set_title('Parameter trajectories over training')
    traj_ax.set_xlabel('Weight 1')
    traj_ax.set_ylabel('Weight 2')
    traj_ax.set_zlabel('Bias')
    err_ax.set_title('Batch errors over training')
    err_ax.set_xlabel('Batch update number')
    err_ax.set_ylabel('Batch error')
    return fig, traj_ax, err_ax

def _run_gradient_descent(data_provider, weights, biases, n_epochs,
                          learning_rate):
    """Run mini-batch gradient descent from every initialisation.

    ``weights`` (shape ``(n_inits, 1, 2)``) and ``biases`` (shape
    ``(n_inits, 1)``) are updated in place, one initialisation at a time.

    Returns:
        ``(weights_traj, biases_traj, errors_traj)``:

        * ``weights_traj[i, s]`` and ``biases_traj[i, s]`` are the parameters
          of initialisation ``i`` before update ``s``. Index 0 holds the
          initial values.
        * ``errors_traj[i, s]`` is the batch error computed at update ``s``.
    """
    n_inits = weights.shape[0]
    n_batches = data_provider.num_batches
    n_steps = n_epochs * n_batches
    # NaN-fill rather than np.empty: if the provider ever yields fewer than
    # num_batches batches, unfilled slots stay NaN (not drawn) instead of
    # plotting uninitialised memory.
    weights_traj = np.full((n_inits, n_steps + 1) + weights.shape[1:], np.nan)
    biases_traj = np.full((n_inits, n_steps + 1) + biases.shape[1:], np.nan)
    errors_traj = np.full((n_inits, n_steps), np.nan)
    weights_traj[:, 0] = weights
    biases_traj[:, 0] = biases
    for i in range(n_inits):
        for e in range(n_epochs):
            for b, (inputs, targets) in enumerate(data_provider):
                step = e * n_batches + b
                outputs = fprop(inputs, weights[i], biases[i])
                errors_traj[i, step] = error(outputs, targets)
                grad_wrt_outputs = error_grad(outputs, targets)
                weights_grad, biases_grad = grads_wrt_params(
                    inputs, grad_wrt_outputs)
                # weights[i] / biases[i] are views, so these update in place.
                weights[i] -= learning_rate * weights_grad
                biases[i] -= learning_rate * biases_grad
                weights_traj[i, step + 1] = weights[i]
                biases_traj[i, step + 1] = biases[i]
    return weights_traj, biases_traj, errors_traj


def _plot_trajectories(ax1, ax2, weights_traj, biases_traj, errors_traj):
    """Draw each initialisation's parameter path (ax1) and error curve (ax2)."""
    n_inits, n_steps = errors_traj.shape
    # One colour per initialisation, spread evenly over the jet colormap.
    colors = plt.cm.jet(np.linspace(0, 1, n_inits))
    steps = np.arange(n_steps)
    for i, color in enumerate(colors):
        ax1.plot(
            weights_traj[i, :, 0, 0],
            weights_traj[i, :, 0, 1],
            biases_traj[i, :, 0],
            '-', c=color, lw=2)
        ax2.plot(steps, errors_traj[i], c=color)


def visualise_training(n_epochs=1, batch_size=200, log_lr=-1., n_inits=5,
                       w_scale=1., b_scale=1., elev=30., azim=0.):
    """Train a 2-input affine model by gradient descent and plot the runs.

    Loads ``CCPPDataProvider`` with input columns 0 and 1 (unshuffled). Draws
    ``n_inits`` random initialisations from a fixed seed (1234) and trains
    each one. Then plots the (weight 1, weight 2, bias) trajectories in 3D
    and the per-batch errors on a log scale.

    Args:
        n_epochs: number of passes over the data per initialisation.
        batch_size: batch size passed to the data provider.
        log_lr: base-10 log of the learning rate (``lr = 10 ** log_lr``).
        n_inits: number of random initialisations to train and plot.
        w_scale: weights are drawn uniformly from ``[-w_scale, w_scale]``.
        b_scale: biases are drawn uniformly from ``[-b_scale, b_scale]``.
        elev: elevation passed to ``view_init`` for the 3D axes.
        azim: azimuth passed to ``view_init`` for the 3D axes.
    """
    _, ax1, ax2 = setup_figure()
    # Fixed seed so every redraw starts from the same initial parameters.
    rng = np.random.RandomState(1234)
    data_provider = CCPPDataProvider(
        input_dims=[0, 1],
        batch_size=batch_size,
        shuffle_order=False,
    )
    learning_rate = 10 ** log_lr
    # Draw weights before biases to keep the original RNG stream order.
    weights = rng.uniform(-w_scale, w_scale, (n_inits, 1, 2))
    biases = rng.uniform(-b_scale, b_scale, (n_inits, 1))
    weights_traj, biases_traj, errors_traj = _run_gradient_descent(
        data_provider, weights, biases, n_epochs, learning_rate)
    _plot_trajectories(ax1, ax2, weights_traj, biases_traj, errors_traj)
    ax1.view_init(elev, azim)
    plt.show()

# Interactive controls for visualise_training, listed in signature order.
# Each range is a (min, max) or (min, max, step) tuple.
w = interact(
    visualise_training,
    n_epochs=(1, 5),
    batch_size=(100, 1000, 100),
    log_lr=(-3., 1.),
    n_inits=(1, 10),
    w_scale=(0., 2.),
    b_scale=(0., 2.),
    elev=(-90, 90, 2),
    azim=(-180, 180, 2),
)

# Stretch every generated control to the full width of the cell.
for control in w.widget.children:
    control.layout.width = '100%'

interactive(children=(IntSlider(value=1, description='n_epochs', max=5, min=1), IntSlider(value=200, descripti…