In [None]:
import numpy as np

from ipywidgets import *
from bqplot import LinearScale
import bqplot.pyplot as plt

import warnings
warnings.simplefilter('ignore')

In [None]:
def gaussian(x_train, x_test, bw=1.):
    """
    Gaussian kernel weights between every test point and every train point.

    The pairwise differences are divided by the bandwidth `bw` before the
    standard normal density is evaluated. For 1-D inputs the result has one
    row per entry of `x_test` and one column per entry of `x_train`.
    """
    # column vector of test points broadcasts against the row of train points
    scaled_diff = (x_train - x_test[:, np.newaxis]) / bw
    norm_const = 1. / np.sqrt(2 * np.pi)
    return norm_const * np.exp(-.5 * scaled_diff ** 2)

def triangular(x_train, x_test, bw=1.):
    """
    Triangular kernel weights between every test point and every train point.

    The weight is `1 - |u|` for `|u| <= 1` and 0 otherwise, where `u` is the
    train/test difference divided by the bandwidth `bw`. For 1-D inputs the
    result has one row per entry of `x_test` and one column per entry of
    `x_train`.
    """
    x_train = np.asarray(x_train)
    x_test = np.asarray(x_test)
    # scale by the bandwidth so that the kernel support is [-bw, bw]
    abs_u = np.abs(x_train - x_test[:, np.newaxis]) / bw
    return np.where(abs_u <= 1, 1 - abs_u, 0)

def epanechnikov(x_train, x_test, bw=1.):
    """
    Epanechnikov kernel weights between every test point and every train point.

    The weight is `0.75 * (1 - u**2)` for `|u| <= 1` and 0 otherwise, where
    `u` is the train/test difference divided by the bandwidth `bw`. For 1-D
    inputs the result has one row per entry of `x_test` and one column per
    entry of `x_train`.
    """
    x_train = np.asarray(x_train)
    x_test = np.asarray(x_test)
    # scale by the bandwidth so that the kernel support is [-bw, bw]
    abs_u = np.abs(x_train - x_test[:, np.newaxis]) / bw
    return np.where(abs_u <= 1, .75 * (1 - abs_u ** 2), 0)

def cosine(x_train, x_test, bw=1.):
    """
    Cosine kernel weights between every test point and every train point.

    The weight is `(pi / 4) * cos(pi * u / 2)` for `|u| <= 1` and 0 otherwise,
    where `u` is the train/test difference divided by the bandwidth `bw`. For
    1-D inputs the result has one row per entry of `x_test` and one column per
    entry of `x_train`.
    """
    x_train = np.asarray(x_train)
    x_test = np.asarray(x_test)
    # scale by the bandwidth so that the kernel support is [-bw, bw]
    u = (x_train - x_test[:, np.newaxis]) / bw
    return np.where(np.abs(u) <= 1, (np.pi / 4) * np.cos((np.pi / 2) * u), 0)

def sigmoid(x_train, x_test, bw=1.):
    """
    Sigmoid kernel weights between every test point and every train point.

    The weight is `(2 / pi) / (exp(u) + exp(-u))`, where `u` is the train/test
    difference divided by the bandwidth `bw`. For 1-D inputs the result has
    one row per entry of `x_test` and one column per entry of `x_train`.
    """
    x_train = np.asarray(x_train)
    x_test = np.asarray(x_test)
    # scale by the bandwidth so that a larger bw gives a wider kernel
    u = (x_train - x_test[:, np.newaxis]) / bw
    return (2 / np.pi) * (1 / (np.exp(u) + np.exp(-u)))

# display label -> kernel function; the labels are what the kernel dropdown shows
kernels = {'Gaussian': gaussian,
           'Triangular': triangular,
           'Epanechnikov': epanechnikov,
           'Cosine': cosine,
           'Sigmoid': sigmoid}

In [None]:
def padded_val(x, eps=1e-3):
    """
    Round `x` away from zero to a whole number, after first pushing it
    `eps` further from zero so that an exact integer still gets padding.

    Positive values are rounded up; everything else (including 0) is
    rounded down.
    """
    if x > 0:
        return np.ceil(x + eps)
    return np.floor(x - eps)

def local_regression(x_train, y_train, x_test,
                     kernel=gaussian, bw=1., order=0):
    """
    Computes local regression with weights coming from the kernel function.

    Parameters
    ----------
    x_train, y_train : 1-D array-likes of equal length
        Training inputs and targets.
    x_test : 1-D array-like
        Points at which the regression is evaluated.
    kernel : callable
        Called as `kernel(x_train, x_test, bw=bw)`; must return one row of
        weights per test point and one column per train point. The weights
        are assumed to be non-negative.
    bw : float
        Bandwidth forwarded to the kernel.
    order : int
        0 gives the kernel-weighted average of `y_train`; a positive value
        fits a weighted polynomial of that degree around each test point.

    Returns
    -------
    1-D float array with one prediction per test point. A test point is NaN
    when the kernel gives it no usable support: zero total weight for
    `order == 0`, or no more than `order` train points with non-zero weight
    for a polynomial fit.
    """
    x_train = np.asarray(x_train, dtype=float)
    y_train = np.asarray(y_train, dtype=float)
    x_test = np.asarray(x_test, dtype=float)

    # compute the weights using the kernel function
    w = kernel(x_train, x_test, bw=bw)

    if order == 0:  # weighted average
        totals = np.sum(w, axis=1)
        y_test = np.full(totals.shape, np.nan)
        # only divide where there is weight; the rest stays NaN
        np.divide(np.dot(w, y_train), totals, out=y_test, where=totals > 0)
        return y_test

    # weighted polyfit
    # np.polyfit multiplies the *unsquared* residuals by `w`, so the square
    # root is passed to minimise sum(w_i * r_i ** 2), i.e. the same weighting
    # as the weighted average above
    sqrt_w = np.sqrt(w)
    y_test = np.full(x_test.shape, np.nan)
    for i, x0 in enumerate(x_test):
        # a degree-`order` fit needs more than `order` points with weight
        if np.count_nonzero(sqrt_w[i]) <= order:
            continue
        coeffs = np.polyfit(x_train, y_train, w=sqrt_w[i], deg=order)
        y_test[i] = np.polyval(coeffs, x0)
    return y_test

In [None]:
# generate some train/test data
# fixed seed so that Restart & Run All reproduces the same noisy sample
np.random.seed(0)
x_train = np.linspace(-5, 5, 100)
y_train = x_train ** 2 + np.random.randn(x_train.size) * 5
x_test = np.linspace(-10, 10, 200)

# extent of the training data, padded outwards to whole numbers for the axes
x0, x1 = np.min(x_train), np.max(x_train)
y0, y1 = np.min(y_train), np.max(y_train)

xmin, xmax, ymin, ymax = [padded_val(x) for x in (x0, x1, y0, y1)]

# per-axis options handed to the scatter mark below: axis labels for both
# axes and a tick format for Y
axes_options = {'x': {'label': 'X'},
                'y': {'tick_format': '0.0f', 'label': 'Y'}}

# create the bqplot figure; the marks created by the following plt.* calls
# are expected to be added to it (pyplot-style current figure — confirm
# against the bqplot docs)
fig = plt.figure(animation_duration=1000,
                 fig_margin=dict(top=60, left=40,
                                 bottom=30, right=10))
fig.layout.width = '800px'
fig.layout.height = '550px'

# fix both axes to the padded range of the training data computed above
plt.scales(scales={'x': LinearScale(min=xmin, max=xmax),
                   'y': LinearScale(min=ymin, max=ymax)})

# training points; enable_move / the 'click': 'add' interaction presumably
# let the user drag points and add new ones — confirm against the bqplot docs
scatter = plt.scatter(x_train, y_train, axes_options=axes_options,
                      colors=['red'], enable_move = True,
                      interactions = {'click': 'add'})
# regression curve; created with empty y values, which update_reg_line fills in
reg_line = plt.plot(x_test, [], 'g', stroke_width=5, opacities=[.5], interpolation='basis')

# widgets for hyper params
# the dropdown's value is the selected kernel function itself, because the
# options come from the label -> function mapping in `kernels`
kernel_dropdown = Dropdown(description='Kernel', 
                           options=kernels,
                           layout=Layout(width='200px'))

# continuous_update=False on both sliders — presumably the value (and thus
# the refit) only changes once the slider is released; confirm in ipywidgets docs
bw_slider = FloatSlider(description='Band Width', 
                        min=.1, max=10, step=.1, value=3,
                        continuous_update=False,
                        readout_format='.1f',
                        layout=Layout(width='290px'))

order_slider = IntSlider(description='Order',
                         min=0, max=10, step=1, value=0,
                         continuous_update=False,
                         layout=Layout(width='300px'))

reset_button = Button(description='Reset Points', button_style='success')

# stack the controls vertically with a 200px top margin
widgets_layout = VBox([kernel_dropdown, bw_slider, order_slider, reset_button])
widgets_layout.layout.margin = '200px 0px 0px 0px'

def update_reg_line(change):
    """
    Refit the regression curve from the current widget values and the
    current scatter points, and refresh the figure title.

    `change` is the notification payload supplied by the observer machinery;
    it is not used. Any exception raised while fitting or assigning the
    curve is printed rather than propagated.
    """
    settings = {'kernel': kernel_dropdown.value,
                'bw': bw_slider.value,
                'order': order_slider.value}
    fig.title = 'Local regression(bw={}, polynomial_order={})'.format(
        settings['bw'], settings['order'])
    try:
        # fit on whatever points the scatter mark currently holds
        reg_line.y = local_regression(scatter.x, scatter.y, x_test, **settings)
    except Exception as err:
        print(err)

def reset_points(*args):
    """
    Restore the scatter mark to the originally generated training data.

    Any positional arguments are accepted and ignored.
    """
    with scatter.hold_sync():
        # hold_sync will send trait updates
        # (x and y here) to front end in one trip
        # x is assigned before y on purpose: only the 'y' trait is observed
        # for refitting, so the refit fired by the y assignment already
        # sees the restored x values
        scatter.x = x_train
        scatter.y = y_train

# clicking the button restores the original points; reset_points ignores
# the button instance it is called with
reset_button.on_click(reset_points)

# refit whenever one of the hyper-parameter widgets changes value
for control in (kernel_dropdown, bw_slider, order_slider):
    control.observe(update_reg_line, 'value')

# refit whenever the scatter's y values change
scatter.observe(update_reg_line, 'y')

# draw the initial curve, then show the figure next to the controls
update_reg_line(None)
HBox([fig, widgets_layout])