# Physics 494/594
## Gradient Descent

In [None]:
# %load ./include/header.py
# Shared notebook setup: core imports, plotting style and display options.
import numpy as np
import matplotlib.pyplot as plt
import sys
from tqdm import trange,tqdm
# make the course helper module `ml4s` (in ./include) importable
sys.path.append('./include')
import ml4s
%matplotlib inline
# render inline figures as SVG (vector graphics)
%config InlineBackend.figure_format = 'svg'
plt.style.use('./include/notebook.mplstyle')
# wrap printed numpy arrays at 120 characters instead of the default 75
np.set_printoptions(linewidth=120)
# presumably injects the course CSS into the notebook — see ml4s.set_css_style
ml4s.set_css_style('./include/bootstrap.css')
# hex color strings of the active style's color cycle, for consistent plot colors
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

\begin{equation}
\mathcal{C} = \frac{1}{2N} \sum_{n=1}^N  \left( F^{(n)} - y^{(n)} \right)^2 = \frac{1}{2N} \lvert \lvert \vec{F} - \vec{y}\rvert\rvert^2
\end{equation}

## Last Time

### [Notebook Link: 10_Model_Complexity_Regularization.ipynb](./10_Model_Complexity_Regularization.ipynb)

- Learn how linear regression can learn non-linear functions using feature maps.
- Understanding model complexity and the bias-variance tradeoff
- Introduction to Regularization

## Today

- Derive a general framework for optimizing functions of many parameters

Until now, we have always assumed that we can analytically find the minimum of our cost function: 

\begin{equation}
\boxed{
\mathcal{C} = \frac{1}{2N} \sum_{n=1}^N  \left( F^{(n)} - y^{(n)} \right)^2 = \frac{1}{2N} \lvert \lvert \vec{F} - \vec{y}\rvert\rvert^2
}
\end{equation}

which may not be possible in general.  We would like to devise a general algorithm for finding the minimum of $\mathcal{C}$.

Starting from a random point $\mathbf{w}_0$ in *weight* space, we would like to devise an update rule $\mathbf{w}_i \to \mathbf{w}_{i+1}$ such that the value of $\mathcal{C}(\mathbf{w}_i)$ decreases at every step.

<div class="span alert alert-warning">
    <strong>Note:</strong> the subscript $j$ in $\mathbf{w}_j$ labels a given set of weights $\mathbf{w}_j \equiv (w_{0,j},w_{1,j},\dots,w_{M-1,j}) \in \mathbb{R}^M$.
</div>

For now, let us consider a simple function of only two variables: $\mathbf{w}^{\sf T} = (w_0,w_1)$

\begin{equation}
f(\mathbf{w}) = \frac{1}{2} \mathbf{w}^{\top} \mathsf{A}\, \mathbf{w}
\end{equation}

where $\mathsf{A} \in \mathbb{R}^{2 \times 2}$ is a symmetric positive semi-definite matrix.  

In [None]:
def f(w,A):
    '''Quadratic bowl f(w) = (1/2) w^T A w.

    Parameters
    ----------
    w : array
        Weight vector (w_0, w_1, ...).
    A : array
        Square matrix defining the curvature of the bowl.

    Returns
    -------
    The value of the quadratic form at w (a scalar for a 1D w).
    '''
    # scale w first, then contract with A and w (same association as (1/2) * w.T @ A @ w)
    half_w_transposed = 0.5 * w.T
    return half_w_transposed @ A @ w

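We can think of this function as a quadratic bowl whose curvature is specified by the matrix $\mathsf{A}$.
This is evident in the contour plots of $f(\mathbf{w})$ for various $\mathsf{A}$ below.

Its minimum value is always $f(\mathbf{w}^*)=0$, attained at $\mathbf{w}^* = (0, 0)^{\sf T}$. If $\mathsf{A}$ is only semi-definite (e.g. the all-ones matrix), this minimum is not unique: every $\mathbf{w}$ with $\mathsf{A}\mathbf{w} = 0$ is also a minimizer.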

In [None]:
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.colors import LogNorm

def plot_function(grid_1d, func, contours=50, log_contours=False, exact=(0, 0)):
    '''Make a contour plot and a 3D surface plot of func over the region described by grid_1d.

    Parameters
    ----------
    grid_1d : 1D array-like
        Sample points used for both axes; the plotted region is the square
        grid_1d x grid_1d in the (w_0, w_1) plane.
    func : callable
        Maps a length-2 numpy array (w_0, w_1) to a scalar.
    contours : int, optional
        Number of contour levels in the linear-scale contour plot.
    log_contours : bool, optional
        If True, draw 35 logarithmically spaced contour levels between 1 and 1e5
        and show ln(func) on the surface plot instead of func.
    exact : sequence of two numbers, optional
        Location (w_0, w_1) of the known minimum, marked with a black star in
        both panels.

    Returns
    -------
    fig, ax, ax3d
        The figure, the 2D contour axes (left) and the 3D surface axes (right).
    '''
    # make the 2D grid; Z is forced to float so integer grids don't truncate func values
    X, Y = np.meshgrid(grid_1d, grid_1d, indexing='xy')
    Z = np.zeros_like(X, dtype=float)

    # numpy bonus exercise: can you think of a way to vectorize the following for-loop?
    n_rows, n_cols = X.shape
    for i in range(n_rows):
        for j in range(n_cols):
            Z[i, j] = func(np.array((X[i, j], Y[i, j])))  # compute function values

    fig = plt.figure(figsize=plt.figaspect(0.5))
    ax = fig.add_subplot(1, 2, 1)

    if not log_contours:
        ax.contour(X, Y, Z, contours, cmap='Spectral_r')
    else:
        ax.contour(X, Y, Z, levels=np.logspace(0, 5, 35), norm=LogNorm(), cmap='Spectral_r')

    ax.plot(*exact, '*', color='black')

    ax.set_xlabel(r'$w_0$')
    ax.set_ylabel(r'$w_1$')
    ax.set_aspect('equal')

    ax3d = fig.add_subplot(1, 2, 2, projection='3d')

    if log_contours:
        Z = np.log(Z)
        label = r'$\ln f(\mathbf{w})$'
    else:
        label = r'$f(\mathbf{w})$'

    ax3d.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap='Spectral_r',
                      linewidth=0, antialiased=True, rasterized=True)

    # mark the known minimum at (w_0, w_1) = (exact[0], exact[1]) on the surface
    # NOTE(review): in log mode the surface shows ln f but the star is drawn at f(exact)
    ax3d.plot([exact[0]], [exact[1]], [func(np.array(exact, dtype=float))], marker='*', ms=6,
              linestyle='-', color='k', lw=1, zorder=100)

    ax3d.set_xlabel(r'$w_0$', labelpad=8)
    ax3d.set_ylabel(r'$w_1$', labelpad=8)
    ax3d.set_zlabel(label, labelpad=8)

    return fig, ax, ax3d

In [None]:
import time
# Plot f for several matrices: the all-ones matrix and two matrices from
# ml4s.random_psd_matrix (presumably random positive semi-definite — see ml4s).
# NOTE: the loop variable A outlives this cell; the gradient-visualisation cell
# further down reuses the last A drawn here.
for A in [np.ones((2, 2)), ml4s.random_psd_matrix([2,2]), ml4s.random_psd_matrix([2,2])]:
    print(f'A={A}')
    fig,ax,ax3d = plot_function(np.linspace(-5,5,100),lambda x: f(x,A))
    plt.show()
    time.sleep(1.0)  # pause 1 s between figures; cosmetic only, and slows a full re-run

## Moving Down Hill

We want to solve the following optimization problem:
\begin{align}
\mathbf{w}^* &=\underset{\mathbf{w}}{\arg \min} \ f(\mathbf{w}) \\
& =  \underset{\mathbf{w}}{\arg \min} \ \frac{1}{2}\mathbf{w}^{\sf T} \mathsf{A}\, \mathbf{w}
\end{align}

Suppose we begin at a randomly chosen point $\mathbf{w}_0 = (w_{0,0},w_{1,0})^{\top}$ and propose a small change $\Delta \mathbf{w} = (\Delta w_0, \Delta w_1)^{\top}$ with $|\Delta w_j| \ll 1$.  Keeping only the first-order term of the Taylor expansion, the change in the function is:

\begin{align}
\Delta f &\equiv f(\mathbf{w}+\Delta\mathbf{w}) - f(\mathbf{w}) \\
& = \frac{\partial f}{\partial w_0} \Delta w_0 + \frac{\partial f}{\partial w_1} \Delta w_1 + \mathrm{O}\left(|\Delta \mathbf{w}|^2\right) \\
&= \nabla_w f \cdot \Delta \mathbf{w} + \mathrm{O}\left(|\Delta \mathbf{w}|^2\right)
\end{align}

where the gradient (direction and rate of fastest increase) is defined as:

\begin{equation}
\nabla_w f = \left(\frac{\partial f}{\partial w_0},\frac{\partial f}{\partial w_1} \right).
\end{equation}

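We can calculate this exactly for our simple convex function (this form assumes a symmetric $\mathsf{A}$; in general the gradient is $\frac{1}{2}(\mathsf{A} + \mathsf{A}^{\sf T})\mathbf{w}$):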

\begin{equation}
\nabla_w f(\mathbf{w}) = \nabla_w \left( \frac{1}{2}\mathbf{w}^{\sf T} \mathsf{A} \mathbf{w} \right) = \mathsf{A} \mathbf{w}
\end{equation}

which we can write a function for.

In [None]:
def df_dw_exact(w,A):
    '''Exact gradient of f(w, A) = (1/2) w^T A w with respect to w.

    The general result is (1/2)(A + A^T) w, which reduces to A w when A is
    symmetric (for a symmetric A the two expressions are bit-for-bit equal).
    Using the symmetrized form keeps this analytic gradient consistent with the
    autodiff gradient even when A is not exactly symmetric.

    Parameters
    ----------
    w : array
        Point at which to evaluate the gradient.
    A : array
        Square matrix of the quadratic form.

    Returns
    -------
    array
        The gradient vector, same shape as w.
    '''
    return 0.5 * (A + A.T) @ w

#### Visualize the direction and size of the gradient

In [None]:
# Pick a random starting point and evaluate the exact gradient there.
# Uses the matrix A left over from an earlier cell.
wₒ = np.random.uniform(low=-4,high=4,size=2)
df = df_dw_exact(wₒ,A)

# widen the plot window beyond [-5, 5] when needed so the tip of the gradient
# arrow (wₒ + ∇f) stays inside it, with a 0.5 margin
_min = min(-5.0,wₒ[0]+df[0]-0.5,wₒ[1]+df[1]-0.5)
_max = max(5.0,wₒ[0]+df[0]+0.5,wₒ[1]+df[1]+0.5)

# plot the surface and the sampling point
fig,ax,ax3d = plot_function(np.linspace(_min,_max,100),lambda x: f(x,A))
ax.plot(wₒ[0],wₒ[1],'x', ms=3)

# plot the gradients: an (unscaled) arrow from wₒ to wₒ + ∇f in the contour panel,
# and from (wₒ, f(wₒ)) to (wₒ + ∇f, f(wₒ + ∇f)) in the 3D panel
arrow_prop_dict = dict(mutation_scale=10, arrowstyle='-|>', color='k', fc='w', shrinkA=0, shrinkB=0)
ax.annotate('',xy=wₒ+df,xytext=wₒ,xycoords='data', textcoords='data',arrowprops=dict(arrowstyle="-|>", fc='w',
                             shrinkA=0,shrinkB=0))
# ml4s.Arrow3D presumably wraps a FancyArrowPatch for 3D axes — see ml4s
a = ml4s.Arrow3D([wₒ[0], wₒ[0]+df[0]], [wₒ[1], wₒ[1]+df[1]], [f(wₒ,A),f(wₒ+df,A)], **arrow_prop_dict)
ax3d.add_artist(a);

# caption (axes coordinates): placed above the contour panel, centred on its right edge
ax.text(1,1.1,rf'$w_0 = ({wₒ[0]:.3f},{wₒ[1]:.3f}),\; f(w_0) = {f(wₒ,A):.3f},\;  \nabla f(w_0) = ({df[0]:.3f},{df[1]:.3f})$', 
        fontsize=13, transform=ax.transAxes, ha='center');

We can therefore guarantee that, to first order, every step decreases the function ($\Delta f \le 0$) if we choose:

\begin{equation}
\Delta \mathbf{w} = -\eta \nabla_w f
\end{equation}

where $\eta$ is a small positive constant known as the **learning rate** such that

\begin{align}
\Delta f &= \nabla_w f \cdot (-\eta \nabla_w f) \\
&= -\eta ||\nabla_w f||^2 \\
&\le 0.
\end{align}

Thus we can implement an iterative minimization procedure:
\begin{equation}
\mathbf{w}_{i+1} \leftarrow \mathbf{w}_i - \eta \nabla_w f(\mathbf{w}_i).
\end{equation}

The value of $\eta$ controls the size of the step we take downhill. Our result will be dependent on the specific choice of $\eta$: 

* if it is too large our steps will oscillate and we may miss the minimum entirely;
* if it is too small, the minimization procedure will be very slow and may not converge within a reasonable number of steps. 

Ultimately we will want to use an **adaptive** learning rate.

### Jax and Autodiff

To perform this gradient descent procedure, we require the gradient of our function $f$.  In the case considered here we were able to compute the gradient exactly.  However, in many cases (including deep neural networks) deriving such gradients analytically by hand is impractical, and automatic differentiation packages are key!  

One of the most powerful of these packages is `jax`, which is **very cool** and does all kinds of things for us.  Check out the [autodiff cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html).


In [None]:
import jax.numpy as jnp # jax has it's own accelerated version of numpy
from jax import grad

In [None]:
# gradient of f with respect to its first positional argument w
# (jax.grad differentiates argument 0 when argnums is not given)
df_dw = grad(f)

In [None]:
# Compare the jax gradient with the analytic df_dw_exact at w = (1, 1).
# The explicit float dtype matters: jax's grad needs floating-point inputs.
wₒ = np.array([1.0,1.0],dtype=float)
print(f'JAX:   ∇f = {df_dw(wₒ,A)}')
print(f'Exact: ∇f = {df_dw_exact(wₒ,A)}')

<div class="span alert alert-danger">
<strong>Warning:</strong> <code>jax</code>'s <code>grad</code> only works for real-valued (floating-point, not integer) inputs!
</div>

In [None]:
# Demonstration: an integer-valued input. jax's grad rejects it with a TypeError;
# we catch and display the message so "Restart & Run All" does not stop here.
wₒ = np.array([1,1])
try:
    df_dw(wₒ,A)
except TypeError as err:
    print(f'TypeError: {err}')

In [None]:
# the same point as a floating-point array (np.ones defaults to float64): grad now works
wₒ = np.ones(2)
df_dw(wₒ, A)

Let's check for a range of values

In [None]:
# Check the autodiff gradient against the analytic one at random points for several A.
# NOTE: A leaks out of this loop and is used by later cells.
for A in [np.ones((2, 2)), ml4s.random_psd_matrix([2,2]), ml4s.random_psd_matrix([2,2])]:
    w = np.random.randn(2)
    grad_auto = df_dw(w, A)
    print(f'w = {w}\nA = {A}\nf(w,A) = {f(w,A):.3f}\ndf/dw = {grad_auto}\n')

    # unit test: every component must agree (a squared *sum* of differences would let
    # opposite-sign errors cancel). jax defaults to single precision unless x64 is
    # enabled, hence the float32-sized absolute tolerance.
    assert np.allclose(grad_auto, df_dw_exact(w, A), atol=1e-5), 'Problem with auto differentiation!'

### Other features of `jax`
`grad` defaults to taking the gradient with respect to the first parameter, but we can specify others via `argnums`:

In [None]:
# Differentiate f with respect to its second argument A instead of w.
# For f = (1/2) w^T A w the result is the matrix (1/2) w w^T.
df_dA = grad(f,argnums=1)
print(df_dA(wₒ,A))

## Performing Gradient Descent

Now that we know how to take gradients using `jax` we are ready to code up our algorithm.

\begin{equation}
\mathbf{w}_{i+1} \leftarrow \mathbf{w}_i  - \eta \nabla_w f(\mathbf{w}_i) \ .
\end{equation}

Before writing a full program, let's explore a little bit.

We initialize at a random point in the domain of our function.

In [None]:
# learning rate and a random starting point drawn uniformly from [-5, 5)^2
η = 0.5
w = np.random.uniform(-5, 5, size=2)
print(f'f(w) = {f(w,A):.3f}')

Perform the update step and check that we are always moving **downhill**.

In [None]:
# One gradient-descent step w <- w - η ∇f(w). The update is in place, so each
# re-execution of this cell takes one more step downhill.
w -= η*df_dw(w, A)
print(f'f(w) = {f(w,A):.3f}')

In [None]:
from IPython import display
# Animate gradient descent on f for a fixed matrix A, starting from w = (2.5, -4).
# seed=0 presumably makes A reproducible — confirm against ml4s.random_psd_matrix.
A = ml4s.random_psd_matrix([2,2], seed=0)
fig, ax, ax3d = plot_function(np.linspace(-5, 5, 100), lambda x: f(x, A))

# hyperparameters
η = 0.5                  # learning rate
w = np.array([2.5,-4.0]) # starting point (float array, updated in place below)
num_iter = 20            # number of gradient-descent steps

# mark the starting point with a large dot
ax.plot(*w, marker='.', color='k', ms=15)  

for i in range(num_iter):

    # we keep a copy of the previous version for plotting
    w_old = np.copy(w)
    
    # perform the GD update
    w += -η*df_dw(w, A)
    
    # plot: draw this step as a segment in the contour panel and on the surface
    ax.plot([w_old[0], w[0]], [w_old[1], w[1]], marker='.', linestyle='-', color='k',lw=1) 
    ax3d.plot([w_old[0], w[0]], [w_old[1], w[1]], [f(w_old,A),f(w,A)], marker='.', linestyle='-', color='k',lw=1, zorder=100)

    # i is 0-based, so the title shows w (and f(w)) after i+1 updates
    ax.set_title(f'$i={i}, w=[{w[0]:.2f},{w[1]:.2f}]$' + '\n' + f'$f(w) = {f(w,A):.6f}$', fontsize=14);
    # show the current frame; wait=True clears it only when the next output arrives,
    # giving a flicker-free animation
    display.display(fig)
    display.clear_output(wait=True)