In [None]:
%matplotlib inline
import math,sys,os,numpy as np
from numpy.random import random
from matplotlib import pyplot as plt, rcParams, animation, rc
from __future__ import print_function, division
from ipywidgets import interact, interactive, fixed
from ipywidgets.widgets import *

In [None]:
# Render matplotlib animations as HTML5 video when they are displayed in the notebook.
rc('animation', html='html5')
# Default figure size (width, height) for subsequent plots.
rcParams['figure.figsize'] = 3,3
# Show floats with 4 decimal places in IPython output.
%precision 4
# Use the same precision when printing NumPy arrays, with lines up to 100 characters wide.
np.set_printoptions(precision=4, linewidth=100)

In [None]:
# The true weights that the fit should recover:
# `a` multiplies x (slope) and `b` is the constant term (intercept).
a, b = 3.0, 8.0

In [None]:
N = 30  # sample count


def linear(a, b, x):
    """Evaluate the line ``a*x + b``.

    ``x`` may be a scalar or a NumPy array; for arrays the result is
    computed elementwise.
    """
    scaled = a * x
    return scaled + b

In [None]:
# the samples to fit
# Seed NumPy's global generator first so that a fresh "Restart & Run All"
# draws the same x values and therefore reproduces the same plots and losses.
np.random.seed(42)
x = random(N)          # N inputs drawn uniformly from [0, 1)
y = linear(a, b, x)    # noise-free targets lying on the true line a*x + b

In [None]:
# Plot the samples to fit, with labelled axes so the figure stands on its own.
plt.scatter(x, y)
plt.xlabel('x')
plt.ylabel('y');  # trailing ';' keeps the artist repr out of the cell output

In [None]:
# Sum of squared errors
def SSE(y, y_pred):
    """Return the sum of the squared elementwise differences ``y - y_pred``.

    The difference must support ``** 2`` and expose a ``.sum()`` method,
    as NumPy arrays do.
    """
    residual = y - y_pred
    return (residual ** 2).sum()

In [None]:
# Loss of ap and bp: predicted weights
def LOSS(y, ap, bp, x):
    """Return the sum of squared errors of the line ``ap*x + bp`` against ``y``."""
    y_pred = linear(ap, bp, x)
    return SSE(y, y_pred)

In [None]:
# Average loss (root of the mean squared error)
def AVG_LOSS(y, ap, bp, x):
    """Return the square root of ``LOSS(y, ap, bp, x)`` divided by ``len(x)``.

    This is the root-mean-square error of the predicted weights ``ap`` and
    ``bp`` over the samples.
    """
    n_samples = float(len(x))
    mean_squared_error = LOSS(y, ap, bp, x) / n_samples
    return np.sqrt(mean_squared_error)

In [None]:
# The starting point: initial guesses for the two weights
a_guess, b_guess = -1.0, 1.0
# Loss of the initial guess (last expression, so it is shown as the cell output)
AVG_LOSS(y, a_guess, b_guess, x)

In [None]:
# Learning rate: factor that scales each gradient step
lr = 1e-2

\\(f(m,b) =  \frac{1}{N} \sum_{i=1}^{N} (y_i - (mx_i + b))^2\\)

\\(f'(m,b)
= \begin{bmatrix}\frac{df}{dm} \\ \frac{df}{db}\end{bmatrix}
= \begin{bmatrix}\frac{1}{N} \sum -2x_i(y_i - (mx_i + b)) \\ \frac{1}{N} \sum -2(y_i - (mx_i + b))\end{bmatrix}\\)

In [None]:
# Gradient formulas: derivatives of the squared error w.r.t. a and b (written dy/da, dy/db as shorthand)
# d[(y-(a*x+b))**2,b] = 2 (b + a x - y)     = 2 (y_pred - y)
# d[(y-(a*x+b))**2,a] = 2 x (b + a x - y)   = x * dy/db

In [None]:
def update():
    """Take one gradient-descent step on the global ``a_guess`` and ``b_guess``.

    Reads the module-level samples ``x`` and ``y`` and the learning rate ``lr``,
    then rebinds ``a_guess`` and ``b_guess``. Returns nothing.
    """
    global a_guess, b_guess
    # error of the current fit on every sample
    residual = linear(a_guess, b_guess, x) - y
    # per-sample derivatives of the squared error w.r.t. b and a
    grad_b = 2 * residual
    grad_a = x * grad_b
    # move each weight against its mean gradient
    a_guess = a_guess - lr * grad_a.mean()
    b_guess = b_guess - lr * grad_b.mean()

In [None]:
# Build the figure once: the samples as a scatter plot plus the line for the
# current guess. `line` is kept so that animate() can update its y-data.
fig = plt.figure(figsize=(5, 4), dpi=100)
fig.gca().scatter(x, y)
(line,) = fig.gca().plot(x, linear(a_guess, b_guess, x))
# Close the figure; presumably this keeps the static frame out of the cell
# output so that only the animation is shown.
plt.close(fig)

def animate(i):
    """Animation callback: draw the current fit, then advance the training.

    The frame index ``i`` is accepted but not used. The line is redrawn from
    the current ``a_guess``/``b_guess`` before the updates run, so the effect
    of this call's updates becomes visible on the next frame.

    Returns a 1-tuple containing the updated line artist.
    """
    current_fit = linear(a_guess, b_guess, x)
    line.set_ydata(current_fit)
    # 10 gradient-descent steps per frame
    for _ in range(10):
        update()
    return (line,)

# Drive animate() with frame values 0..39, using a 100 ms delay between frames.
ani = animation.FuncAnimation(fig, func=animate, frames=np.arange(40), interval=100)
# Last expression of the cell: displays the animation.
ani