In [None]:
import time

import numpy as np
from sklearn.datasets import make_blobs

from ipywidgets import *
from bqplot import *
import bqplot.pyplot as plt

In [None]:
def padded_val(x, eps=1e-3):
    """Round a scalar away from zero to a whole number, after nudging it by `eps`.

    Strictly positive values are shifted up by `eps` and rounded up; every
    other value (zero included) is shifted down by `eps` and rounded down.
    A value that is already a whole number therefore still moves one full
    step further from zero.
    """
    if x > 0:
        # positive side: push up, then take the ceiling
        return np.ceil(x + eps)
    # zero or negative: push down, then take the floor
    return np.floor(x - eps)

In [None]:
# data setup
# create two classes separable by a line
N = 100
X, Y = make_blobs(n_samples=N, centers=2, random_state=0, cluster_std=.6)
# relabel the classes: 0 becomes +1 and 1 becomes -1, the labels the
# perceptron update below works with
Y[Y == 0] = -1
Y = -Y

# bounding box of the data: (x0, y0) are the column minima, (x1, y1) the maxima
x0, y0 = np.min(X, axis=0)
x1, y1 = np.max(X, axis=0)

# axis limits: each bound is rounded away from zero by padded_val
xmin, xmax, ymin, ymax = [padded_val(x) for x in (x0, x1, y0, y1)]

# random initial weights: beta[0] and beta[1] multiply the two features,
# beta[2] is the bias term
beta = np.random.rand(3)

# The separating line spans the raw data range in x. It is built from x0/x1
# directly so that the padded xmin/xmax above stay intact for the x scale
# (reassigning xmin/xmax here would silently drop the x-axis padding).
x = np.linspace(x0, x1, 50)
y = [0] * len(x)

fig = plt.figure(title='Perceptron Animation',
                 preserve_aspect=True,
                 animation_duration=1500,
                 layout=Layout(height='800px', width='1000px'))
plt.scales(scales={'color': OrdinalColorScale(colors=['#f3172d', 'limegreen'],
                                              domain=[-1, 1]),
                   'x': LinearScale(min=xmin, max=xmax),
                   'y': LinearScale(min=ymin, max=ymax)
                  })
axes_options  = {'x': {'label': 'X', 'tick_format': '.1f'},
                 'y': {'label': 'Y', 'tick_format': '.1f'},
                 'color': {'visible': False}}

# one marker per sample, coloured by its class label
scatt = plt.scatter(X[:, 0], X[:, 1], color=Y,
                    stroke='black', default_size=100,
                    unselected_style={'opacity': 0.2},
                    axes_options=axes_options,
                    selected=None)

# the separating line starts out flat at y = 0 until the animation updates it
sep_line = plt.plot(x, y, colors=['dodgerblue'], stroke_width=4)

# control buttons, stacked vertically and pushed down by a top margin
btn_layout = Layout(width='75px')
go_btn = Button(description='GO', button_style='success', layout=btn_layout)
reset_btn = Button(description='Reset', button_style='warning', layout=btn_layout)
btns = VBox([go_btn, reset_btn])
btns.layout.margin = '80px 0px 0px 0px'

def _misclassified(weights):
    """Return the indices of the samples in X whose signed score under `weights` is negative."""
    scores = Y * (weights[0] * X[:, 0] + weights[1] * X[:, 1] + weights[2])
    return np.where(scores < 0)[0]


def start_animation(learning_rate=1, delay=1.5):
    """Run the perceptron updates, redrawing the separating line after each step.

    While any sample is misclassified, one of them is picked at random,
    highlighted in the scatter plot, and used to update the global weight
    vector `beta`; `sep_line` is then redrawn from the new weights.

    Parameters
    ----------
    learning_rate : number, optional
        Step size applied to each weight update (default 1).
    delay : float, optional
        Seconds to sleep before each update (default 1.5, the same length as
        the figure's 1500 ms animation_duration).

    Notes
    -----
    The loop ends only when no sample is misclassified, so it does not
    terminate if the two classes cannot be separated by a line.
    """
    # `beta` is updated with an augmented assignment, so it must be declared global
    global beta

    # stochastic gradient descent
    misclassified_pts = _misclassified(beta)

    while len(misclassified_pts) > 0:
        # pick a random misclassified pt
        i = np.random.choice(misclassified_pts, 1)[0]
        time.sleep(delay)
        scatt.selected = [i]

        # update beta using gradient descent: the features and the bias
        # both move in the direction of the sample's label
        beta += learning_rate * np.append(Y[i] * X[i], Y[i])
        # solve beta[0] * x + beta[1] * y + beta[2] = 0 for y along the line's x grid
        sep_line.y = (-beta[2] - beta[0] * sep_line.x) / beta[1]

        # update misclassified pts since the line has changed
        misclassified_pts = _misclassified(beta)

    # converged: clear the highlight
    scatt.selected = None

def reset():
    """Restore the starting state so the animation can be run again.

    Draws a fresh random weight vector for the global `beta` and flattens the
    separating line back to y = 0. Without new weights, `beta` would still
    hold the converged solution and a following GO click would find nothing
    to correct.
    """
    global beta
    beta = np.random.rand(3)
    sep_line.y = [0] * len(sep_line.x)
    
# wire the buttons: each handler receives the clicked button and ignores it
go_btn.on_click(lambda btn: start_animation())
reset_btn.on_click(lambda btn: reset())
    
# last expression of the cell: the figure with the button column beside it
HBox([fig, btns])