# Regularization Network

A quickly hacked regularization network example, based on https://scikit-learn.org/stable/auto_examples/gaussian_process/plot_gpr_noisy_targets.html by Vincent Dubourg, Jake Vanderplas, and Jan Hendrik Metzen.

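Gaussian process regression and a regularization network give the same prediction when they use the same kernel. The regularization network's fit equals the GP posterior mean, with the GP noise level `alpha` playing the role of the regularization parameter. I therefore simply adapted their Gaussian process example to demonstrate regularization networks. This is why the model variable in the code is called `gp`.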

In [None]:
import numpy as np
from matplotlib import pyplot as plt

from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF

%matplotlib inline
import ipywidgets as widgets
from ipywidgets import interactive, fixed

In [None]:
# Generate some data: the target function x*sin(x) plus zero-mean Gaussian noise.
# Seeding the global NumPy generator makes the noisy observations reproducible.
np.random.seed(42)


def f(x):
    """Return the noise-free target function ``x * sin(x)``, elementwise."""
    return x * np.sin(x)


# Training inputs as a column vector of shape (6, 1); scikit-learn expects 2-D X.
X = np.atleast_2d([1., 3., 5., 6., 7., 8.]).T
# Noisy targets of shape (6,). The noise is zero-mean so the observations scatter
# around f(x) instead of all lying above it, matching the Gaussian noise model
# assumed by GP regression / the regularization network.
noise_std = 1.0
y = f(X).ravel()
y += np.random.normal(0.0, noise_std, size=y.shape)

In [None]:
def make_plot(X, y, log_scale_length, log_alpha):
    """Fit a regularization network to (X, y) and plot its prediction.

    The network is implemented with scikit-learn's GaussianProcessRegressor and
    an RBF kernel; its predictive mean is the regularization-network solution.
    The plot also shows the module-level target function ``f`` and the
    observations.

    Parameters
    ----------
    X : array-like
        Training inputs as a 2-D column array (the notebook passes shape (6, 1)).
    y : array-like
        Training targets, one per row of ``X``.
    log_scale_length : float
        Natural log of the RBF kernel length scale.
    log_alpha : float
        Natural log of the regularization strength, passed as ``alpha`` to
        GaussianProcessRegressor.
    """
    # Dense grid on [0, 10] for evaluating the true function and the prediction.
    x = np.atleast_2d(np.linspace(0, 10, 1000)).T

    # optimizer=None keeps the length scale chosen with the slider. With the
    # default optimizer, fit() would re-tune it by maximizing the log-marginal
    # likelihood, and the slider would only set the optimizer's starting point.
    kernel = RBF(length_scale=np.exp(log_scale_length))
    gp = GaussianProcessRegressor(
        kernel=kernel, alpha=np.exp(log_alpha), optimizer=None
    )
    gp.fit(X, y)

    # Only the mean is needed: a regularization network has no predictive variance.
    y_pred = gp.predict(x)

    plt.figure()
    # Raw string: "\," and "\s" are LaTeX commands, not Python escape sequences.
    plt.plot(x, f(x), 'r:', label=r'$f(x) = x\,\sin(x)$')
    plt.plot(X, y, 'r.', markersize=10, label='Observations')
    plt.plot(x, y_pred, 'b-', label='Prediction')
    plt.xlabel('$x$')
    plt.ylabel('$f(x)$')
    plt.ylim(-10, 20)
    plt.legend(loc='upper left')
    plt.show()

In [None]:
# Both sliders work in natural-log space, so a linear slider sweeps the positive
# length scale and regularization strength over several orders of magnitude.
# continuous_update=False refits only when the slider is released, not while dragging.
scale_length_slider = widgets.FloatSlider(
    min=-6,
    max=2,
    step=0.1,
    value=-1.,
    description="log scale-length",
    continuous_update=False,
)
regularization_slider = widgets.FloatSlider(
    min=-3,
    max=3,
    step=0.1,
    value=0,
    description="log regularization",
    continuous_update=False,
)
# The data are pinned with fixed(...) so interactive() builds no widgets for them.
interactive_plot = interactive(
    make_plot,
    X=fixed(X),
    y=fixed(y),
    log_scale_length=scale_length_slider,
    log_alpha=regularization_slider,
)

In [None]:
# Display the interactive widget: Jupyter renders the last expression of a cell.
interactive_plot