# The One-Lecture Story of Optimization for Machine Learning

In this lecture, we'll go over a simple example that illustrates some of the important ideas of *optimization for machine learning.* 

*Optimization* is the mathematical study of techniques for finding points at which functions reach their maximum or minimum values. For example, the function $f(x) = x^2$ attains its minimum value, $0$, at the point $x = 0$. Optimization techniques are used to obtain similar insights for more complicated functions.
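
To make this concrete, here is a minimal sketch of one common optimization technique, gradient descent, applied to $f(x) = x^2$. The function names, the starting point, the step size, and the number of steps are all illustrative assumptions rather than anything fixed by the lecture.

```python
def f(x):
    # The function we want to minimize.
    return x ** 2

def grad_f(x):
    # Derivative of f: d/dx x^2 = 2x.
    return 2 * x

def gradient_descent(x0, step_size=0.1, n_steps=100):
    # Repeatedly take a small step in the direction opposite the derivative.
    x = x0
    for _ in range(n_steps):
        x = x - step_size * grad_f(x)
    return x

x_min = gradient_descent(x0=3.0)
print(x_min, f(x_min))  # both values end up very close to 0
```

Each step multiplies $x$ by $1 - 2 \cdot 0.1 = 0.8$, so the iterates shrink toward the minimizer $x = 0$.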

The reason that we care about optimization in the context of machine learning is that the "training" step of most machine learning methods is nothing other than optimizing a function. As you may remember, the prototypical (supervised) machine learning task has the following form:

> Find a model $f$ from a collection $\mathcal{M}$ of possible models such that $\mathcal{L}(f, X, y)$ is minimized.  

In this problem statement: 

- The *model* $f$ is some kind of function that takes in *predictor data* $X$ and spits out a prediction $\hat{y}$ of the *target* data $y$. 
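- $\mathcal{M}$ is the set of all possible models under consideration. Usually this set is parameterized, so that one chooses a model from $\mathcal{M}$ by choosing the values of one or more parameters. For example, the set of lines $f(x) = b_0 + b_1 x$ is parameterized by the intercept $b_0$ and the slope $b_1$.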
- $\mathcal{L}$ is a *loss function* that is adapted 
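
The ingredients above can be sketched for the example used in the rest of this lecture, assuming a linear model and a sum-of-squared-errors loss. The helper names `model` and `loss`, the random seed, and the use of `np.random.default_rng` are assumptions made for illustration; the data-generating values ($-0.2$, $0.5$, noise scale $0.1$) mirror the demo further down.

```python
import numpy as np

def model(X, b0, b1):
    # One member of the model family M: a line with intercept b0 and slope b1.
    return b0 + b1 * X

def loss(b0, b1, X, y):
    # Sum of squared differences between predictions and targets.
    return np.sum((model(X, b0, b1) - y) ** 2)

# Synthetic data: a noisy line with intercept -0.2 and slope 0.5.
rng = np.random.default_rng(0)
X = rng.random(100)
y = -0.2 + 0.5 * X + 0.1 * rng.standard_normal(100)

loss_true = loss(-0.2, 0.5, X, y)  # parameters used to generate the data
loss_zero = loss(0.0, 0.0, X, y)   # a poor guess
print(loss_true, loss_zero)
```

Choosing a model from $\mathcal{M}$ here means choosing the pair `(b0, b1)`, and training means searching for the pair that makes `loss` as small as possible.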

In [8]:
import numpy as np
import matplotlib.pyplot as plt

from matplotlib.animation import FuncAnimation
from IPython.display import HTML

In [9]:
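class UpdateLine:
    """Animation callback: performs one stochastic gradient descent step per frame.

    b0, b1 are the intercept and slope used to generate the data;
    b0_, b1_ are the initial guesses that get updated; eps is the step size.
    """

    def __init__(self, ax, b0, b1, b0_, b1_, eps = 0.01, n = 100):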
        self.ax0 = ax[0]
        self.ax1 = ax[1]
        
        self.b0 = b0
        self.b1 = b1
        
        self.b0_ = b0_
        self.b1_ = b1_
    
        self.n = n
        
        self.X = np.random.rand(n)
        self.y = self.b0 + self.b1*self.X + 0.1*np.random.randn(n)
        self.X = self.X.reshape(n,1)
        self.ax0.scatter(self.X, self.y, color = "grey", s = 4, zorder = 100)
        
        self.eps = eps
        self.point = self.ax0.scatter([], [], color = "red", zorder = 200)
        self.line, = self.ax0.plot([], [], 'k-')
        self.loss, = self.ax1.plot([], [], 'k-')
        
        self.x_space = np.linspace(0, 1, 10)
        
        self.t = []
        self.L = []
        
    def __call__(self, i):
        
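        # Pick one data point uniformly at random (stochastic gradient descent).
        j = np.random.randint(self.n)
        x, y = self.X[j,0], self.y[j]
        
        # Residual at the sampled point, computed once so that both updates
        # use the gradient evaluated at the same (b0_, b1_).
        r = self.b1_ * x + self.b0_ - y
        
        # Gradient of r**2 is 2*r with respect to b0_ and 2*r*x with respect to b1_.
        self.b0_ -= self.eps*2*r
        self.b1_ -= self.eps*2*r*x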
        
        self.point.set_offsets([[x, y]])
        
        self.line.set_data(self.x_space, self.b1_*self.x_space +self.b0_)
        
        self.t.append(i)
        
        # Total squared error of the current line over the full data set.
        L = np.sum((self.b1_*self.X[:,0] + self.b0_ - self.y)**2)
        self.L.append(L)
        
        self.loss.set_data(self.t, self.L)
        
        return [self.point, self.line, self.loss]

In [None]:
fig, ax = plt.subplots(1, 2, figsize = (7, 3))

ax[0].set_xlim(0, 1)
ax[0].set_ylim(-0.5, 0.5)

ax[1].set_xlim(0, 200)
ax[1].set_ylim(0, 5)
ax[1].grid(True)

ax[0].set(title = "Regression Problem", 
          xlabel = r"$x$",
          ylabel = r"$y$")

ax[1].set(title = "Current Loss",
          xlabel = "Iteration",
          ylabel = r"$\mathcal{L}$")

ud = UpdateLine(ax, -0.2, 0.5, 0, 0, eps = 0.1, n = 100)
anim = FuncAnimation(fig, ud, frames=200, interval=100, blit=True)

HTML(anim.to_jshtml())