In this tutorial notebook we'll try to understand local regression using interactive visualizations. Users can interact with the plot as follows:

* Update the kernel bandwidth and polynomial order using sliders
* The Gaussian kernel plot updates in response to changes in the `bandwidth` parameter
* The local regression plot updates in response to changes in the `bandwidth` and `order` parameters
* New points can be added by clicking on the local regression plot (thereby updating the regression fit)
* Existing points can be moved by dragging them with the mouse on the local regression plot (thereby updating the regression fit)
* By checking the 'Display std bands?' checkbox, users can see the +1/-1 local standard deviation bands

Fun things to try:
* Change the `bandwidth` parameter to understand its impact on the regression fit (low values tend to overfit, whereas high values increase the bias)
* Change the `order` parameter to understand its impact on the regression fit (order 1 fixes the linear bias in the dataset, order 2 fixes the convexity bias, etc.)
* Move the points in the scatter plot and see the impact of outliers and overfitting (small changes in the points can completely flip the regression curve)

In [None]:
import numpy as np

from ipywidgets import *
from bqplot import LinearScale
import bqplot.pyplot as plt

import warnings
warnings.simplefilter('ignore')

In [None]:
def gaussian_kernel(x_train, x_test, bw=1.):
    '''
    Evaluate an unnormalised Gaussian kernel between test and training points.

    For 1-D inputs the result is a 2-D array whose entry [i, j] equals
    exp(-0.5 * ((x_train[j] - x_test[i]) / bw) ** 2): one row per test point,
    one column per training point. The peak value is 1 (no 1/(bw*sqrt(2*pi))
    normalisation factor is applied).
    '''
    # turn x_test into a column so it broadcasts against the x_train row
    scaled_dist = (x_train - x_test[:, np.newaxis]) / bw
    return np.exp(-0.5 * np.square(scaled_dist))

In [None]:
def local_regression(x_train, y_train, x_test,
                     kernel=gaussian_kernel, bw=1., order=0):
    '''
    Compute a kernel-weighted local polynomial regression.

    For every point x0 in ``x_test`` a polynomial p of degree ``order`` is
    fitted to (x_train, y_train) by weighted least squares, minimising
    sum_j w[i, j] * (y_train[j] - p(x_train[j])) ** 2 with
    w = kernel(x_train, x_test, bw=bw), and p(x0) is returned.
    ``order=0`` reduces to the kernel-weighted average of y_train.

    Parameters
    ----------
    x_train, y_train : 1-D arrays of sample points (equal length).
    x_test : 1-D array of points at which the fit is evaluated.
    kernel : callable ``kernel(x_train, x_test, bw=...)``; expected to return
        non-negative weights with one row per test point and one column per
        training point (as ``gaussian_kernel`` does).
    bw : bandwidth forwarded to ``kernel``.
    order : degree of the local polynomial.

    Returns
    -------
    Float array of fitted values, one per entry of ``x_test``.
    '''
    # compute the weights using the kernel function
    w = kernel(x_train, x_test, bw=bw)

    if order == 0:  # weighted average
        return np.dot(w, y_train) / np.sum(w, axis=1)

    # np.polyfit applies `w` to the *unsquared* residuals (it minimises
    # sum((w * r) ** 2)). Passing sqrt(w) puts the kernel weights on the
    # squared residuals, so every order uses the same weighting as the
    # order-0 branch above.
    sqrt_w = np.sqrt(w)
    # force a float result: empty_like would inherit an integer dtype from an
    # integer x_test and silently truncate the fitted values
    y_test = np.empty_like(x_test, dtype=float)
    for i, x0 in enumerate(x_test):
        coeffs = np.polyfit(x_train, y_train, deg=order, w=sqrt_w[i])
        y_test[i] = np.polyval(coeffs, x0)
    return y_test

def local_std(x_train, y_train, x_test, y_bar=None,
              kernel=gaussian_kernel, bw=1., order=0):
    '''
    Compute a kernel-weighted local standard deviation around a fitted curve.

    For each test point i the result is
    sqrt(sum_j w[i, j] * (y_train[j] - y_bar[i]) ** 2 / sum_j w[i, j])
    where w = kernel(x_train, x_test, bw=bw).

    ``y_bar`` holds the fitted values at ``x_test``; when it is None they are
    computed with ``local_regression`` using the same kernel, bw and order.
    Note: for order >= 1 the deviations are still measured against the single
    fitted value y_bar[i], not against the local polynomial at each x_train[j].
    '''
    weights = kernel(x_train, x_test, bw=bw)
    if y_bar is None:
        y_bar = local_regression(x_train, y_train, x_test,
                                 kernel=kernel, bw=bw, order=order)

    # rows follow x_test, columns follow x_train
    squared_dev = (y_train - y_bar[:, np.newaxis]) ** 2
    local_var = (weights * squared_dev).sum(axis=1) / np.sum(weights, axis=1)
    return np.sqrt(local_var)

In [None]:
def padded_val(x, eps=1e-3):
    '''
    Round x to an integer-valued float for use as an axis limit.

    Positive values are nudged up by eps and rounded up (ceil); zero and
    negative values are nudged down by eps and rounded down (floor). Because
    of the nudge, integer inputs move one unit further from the origin side
    they are rounded toward (e.g. 2.0 -> 3.0, 0.0 -> -1.0).
    '''
    if x > 0:
        return np.ceil(x + eps)
    return np.floor(x - eps)

In [None]:
# generate some train/test data
# training data: a noisy parabola, y = x**2 plus Gaussian noise with std 5
# (np.random is not seeded, so the points differ on every run)
x_train = np.linspace(-5, 5, 200)
y_train = x_train ** 2 + np.random.randn(200) * 5
# evaluation grid is wider than the training range (-5..5), so the fit is
# also drawn where there is no data
x_test = np.linspace(-8, 8, 400)

# integer y-axis limits derived from the data extremes (see padded_val)
ymin, ymax = padded_val(np.min(y_train)), padded_val(np.max(y_train))

axes_options = {'x': {'label': 'X'},
                'y': {'tick_format': '0.0f', 'label': 'Y'}}

# --- local regression figure ---
# bqplot.pyplot attaches scales and marks to the most recently created
# figure, so the statements below must stay after plt.figure and in order.
reg_fig = plt.figure(animation_duration=1000)
reg_fig.layout.width = '900px'
reg_fig.layout.height = '600px'

plt.scales(scales={'x': LinearScale(min=-8, max=8),
                   'y': LinearScale(min=ymin - 5, max=ymax + 5)})
# training points: draggable (enable_move) and click-to-add; the x/y traits
# of this mark are observed by update_reg_line to refit
scatter = plt.scatter(x_train, y_train, axes_options=axes_options,
                      enable_move=True, stroke='black',
                      default_size=40,
                      interactions={'click': 'add'})

# regression curve over x_test; its y values are filled in by update_reg_line
reg_line = plt.plot(x_test, [], 'y', stroke_width=3, 
                    opacities=[.6], interpolation='basis')
# std band outline: x runs over x_test forward then backward, so y is set to
# the lower band followed by the reversed upper band
std_bands = plt.plot(np.hstack([x_test, x_test[::-1]]), [], 'y', 
                     fill='bottom', fill_opacities=[.2], stroke_width=0)

# --- kernel figure (becomes the current figure for the marks below) ---
kernel_fig = plt.figure(animation_duration=1000, title='Gaussian Kernel')
kernel_fig.layout.width = '500px'
kernel_fig.layout.height = '400px'

# gaussian_kernel peaks at exp(0) = 1, hence the fixed 0..1 y range
plt.scales(scales={'y': LinearScale(min=0, max=1)})
axes_options = {'x': {'label': 'X'}, 
                'y': {'tick_format': '.1f'}}
# kernel curve over x_train; its y values are filled in by update_kernel_plot
kernel_line = plt.plot(x_train, [], 'm', axes_options=axes_options, 
                       interpolation='basis')

# widgets for hyper params
# kernel bandwidth, passed as `bw` to local_regression / local_std
bw_slider = FloatSlider(description='Kernel Band Width', 
                        min=.1, max=10, step=.1, value=3,
                        continuous_update=False,
                        readout_format='.1f',
                        layout=Layout(width='350px'))

# degree of the local polynomial; 0 selects the kernel-weighted average
order_slider = IntSlider(description='Polynomial Order',
                         min=0, max=10, step=1, value=0,
                         continuous_update=False,
                         layout=Layout(width='350px'))

# restores the scatter to the original training data (handler wired below)
reset_button = Button(description='Reset Points', button_style='success')
reset_button.layout.margin = '0px 0px 0px 50px'

# toggles the +/-1 local std band (handled by display_std_bands)
band_checkbox = Checkbox(description='Display std bands?')
band_checkbox.layout.margin = '0px 0px 0px 30px'

# CSS margin order: top right bottom left
bw_slider.layout.margin = '60px 0px 0px 40px'

def update_reg_line(change):
    '''
    Refit the local regression on the current scatter points and redraw.

    Reads the bandwidth and order from the sliders, updates the figure title,
    stores the fitted values and local std in the module-level ``y_test`` and
    ``std`` (later reused by ``display_std_bands``), and refreshes the curve
    and, if the checkbox is ticked, the std band. ``change`` is ignored; it is
    the notification passed by ``observe`` (or None for a direct call).
    Any error during fitting is printed rather than raised.
    '''
    global y_test, std
    bw, order = bw_slider.value, order_slider.value
    reg_fig.title = 'Local regression(bw={}, polynomial_order={})'.format(bw, order)
    try:
        y_test = local_regression(scatter.x, scatter.y, x_test,
                                  bw=bw, order=order)
        std = local_std(scatter.x, scatter.y, x_test,
                        y_bar=y_test, bw=bw, order=order)

        reg_line.y = y_test
        # lower band, then the upper band reversed, to match std_bands' x path
        std_bands.y = (np.concatenate([y_test - std, (y_test + std)[::-1]])
                       if band_checkbox.value else [])
    except Exception as err:
        print(err)

def reset_points(*args):
    '''Restore the scatter mark to the original x_train / y_train data.

    Positional arguments are accepted and ignored.
    '''
    # defer observer callbacks until both coordinate arrays have been replaced
    with scatter.hold_trait_notifications():
        scatter.x, scatter.y = x_train, y_train

# the lambda discards the button instance passed by on_click
reset_button.on_click(lambda btn: reset_points())

# event handlers for widget traits
# refit whenever either hyper-parameter slider changes
# (note: the loop variable `w` remains as a module-level name)
for w in [bw_slider, order_slider]:
    w.observe(update_reg_line, 'value')

# refit whenever points are added or moved on the scatter
scatter.observe(update_reg_line, names=['x', 'y'])

def update_kernel_plot(*args):
    '''
    Redraw the kernel figure for the current bandwidth.

    Plots the Gaussian kernel centred at 0 over the x_train grid, using
    ``bw_slider.value`` as the bandwidth. Positional arguments (the
    notification from ``observe``, or None for a direct call) are ignored.
    '''
    bw = bw_slider.value
    # a single test point at 0 yields a (1, len(x_train)) array; squeeze to 1-D
    kernel_line.y = gaussian_kernel(x_train, np.array([0]), bw=bw).squeeze()

bw_slider.observe(update_kernel_plot, 'value')

def display_std_bands(*args):
    '''
    Show or hide the +/-1 local std band according to ``band_checkbox``.

    Uses the module-level ``y_test`` and ``std`` arrays last stored by
    ``update_reg_line``. The band y values are the lower curve followed by the
    reversed upper curve, matching the std_bands x path (x_test forward then
    backward). Positional arguments are ignored.
    '''
    if not band_checkbox.value:
        std_bands.y = []
        return
    lower, upper = y_test - std, y_test + std
    std_bands.y = np.concatenate([lower, upper[::-1]])

band_checkbox.observe(display_std_bands, 'value')

# draw the initial fit and kernel before the widgets are displayed
update_reg_line(None)
update_kernel_plot(None)

# layout: bandwidth slider above the kernel figure on the left; regression
# figure with the order slider, std checkbox and reset button on the right.
# As the cell's last expression, this HBox is what the notebook displays.
HBox([VBox([bw_slider, kernel_fig]),
      VBox([reg_fig, 
            HBox([order_slider, band_checkbox, reset_button], 
                 layout=Layout(margin='0px 0px 0px 50px'))])])