# Optimization using Auto-Differentiation

Purpose: To familiarize yourself with the `ad` code, and use it in a simple gradient-based optimization task.

In [None]:
import numpy as np
from ad import *
import matplotlib.pyplot as plt
%load_ext autoreload
%autoreload 2

## Using `Var` and `Operation` objects

In [None]:
# Two leaf variables of the expression graph
X = Var(2.)
A = Var(3.)
# The overloaded operators combine them into the expression F = X^2 * A
F = X**2 * A
print(F)    # prints the expression itself
print(F())  # calling F gives its value; 2^2 * 3 = 12 is expected

In [None]:
# With X = 2 and A = 3 as set above:
F.zero_grad()  # reset grads to zero
F.backward()   # propagate derivatives down through graph
print(A.grad)  # dF/dA = X^2, so 4 is expected
print(X.grad)  # dF/dX = 2*X*A, so 12 is expected

<div class="alert alert-block alert-info">
    You can even implement your own function (and its derivative) using the template below.
</div>

In [None]:
def your_function(a):
    '''
     y = your_function(a)

     Applies the custom YourFunction operation to a.

     a is a Var
     y is a Var such that y.val = your_function(a.val)
    '''
    op = YourFunction([a])  # the operation takes its inputs as a list
    return op()             # calling the operation produces the output Var

class YourFunction(Operation):
    '''
     Template for a user-defined operation with a single input.

     Two lines are meant to be edited (both are marked with an arrow):
     the forward computation in __call__ and the matching derivative
     in backward.
    '''

    def __init__(self, args):
        # Keep the list of inputs; only args[0] is used by this template.
        self.args = args

    def __call__(self):
        # Forward pass. Compute the output value from the input's value
        # using NumPy. The placeholder passes the input value through
        # unchanged.
        out_val = self.args[0].val  # <================= your function here
        out = Var(out_val)   # wrap the result in a new Var
        out.creator = self   # remember which operation created it
        return out

    def __repr__(self):
        return f'your_function({self.args[0]!r})'

    def backward(self, s=1.):
        # Backward pass. Compute the derivative of the function with
        # respect to its input (using NumPy), scale it by the incoming
        # factor s, and pass the product on to the input's backward.
        #
        # NOTE: the placeholder derivative is all zeros, which does not
        # match the pass-through placeholder in __call__ (whose derivative
        # would be all ones). Replace both marked lines together.
        local_deriv = np.zeros_like(self.args[0].val)  # <========= your derivative here
        self.args[0].backward(s * local_deriv)


In [None]:
# Apply the custom operation to A, then print the resulting expression
# followed by its value.
y = your_function(A)
print(f'{y} =')
print(y())

## Optimize a function

### Build the expression
Let's encode the mathematical expression,
$$
F = 8x^4 + 4x^3 - 28x^2 - 24x
$$
and then find its minimum.

In [None]:
# x: the independent variable, starting at 0
X = Var(0.0)

# Coefficients of the polynomial, from the x^4 term down to the x term
A, B, C, D = Var(8.0), Var(4.0), Var(-28.0), Var(-24.0)

In [None]:
# Build the formula using the functions from the `ad` module.
# Each term is built as its own sub-expression; F is their sum.
term1 = A * X**4  # 8 x^4
term2 = B * X**3  # 4 x^3
term3 = C * X**2  # -28 x^2
term4 = D * X  # -24 x
F = term1 + term2 + term3 + term4

In [None]:
# Or you can build it all at once using one of these
#F = A * X**4 + B * X**3 + C * X**2 + D * X
#F = Var(8.)*X**4 + Var(4)*X**3 + Var(-28.)*X**2 + Var(-24)*X

In [None]:
# Display the expression (the notebook shows the value of a cell's last line)
F

<div class="alert alert-block alert-info">
    Try using the AD code on your own expression.
</div>

### Evaluate the expression

In [None]:
# With X = 0.0 as defined above, every term vanishes, so 0 is expected.
F.val  # for the given X-value as it was set up

In [None]:
# Can also use functional notation: calling F gives its value
# (presumably the same number as F.val)
F()

In [None]:
# Let's choose a different x-value
# At x = 1 the polynomial is 8 + 4 - 28 - 24 = -40.
X.set(1.)
F.evaluate()  # Causes recomputation of the whole graph

In [None]:
# Display the expression again, now that X has been changed
F

### Gradients

In [None]:
# dF/dx = 32x^3 + 12x^2 - 56x - 24, which is -36 at x = 1.
F.zero_grad()   # sets all gradients to zero
F.backward()   # propagates the derivatives down through the graph
X.grad

In [None]:
# Same steps at x = -2: re-evaluate the graph, clear the old gradients,
# then propagate the derivatives again.
X.set(-2)
F.evaluate()
F.zero_grad()
F.backward()
print(X.grad)  # expected: 32(-8) + 12(4) - 56(-2) - 24 = -120

### Gradient-Descent Optimization

In [None]:
# *** YOU MIGHT HAVE TO ADJUST THESE 3 PARAMETERS ***
x = 0.4     # initial guess at the minimizer
kappa = 0.005  # gradient step multiplier
n_iters = 20  # number of gradient-descent steps

# for plotting
xh = []  # x-value visited at each step
fh = []  # F evaluated at that x-value

for n in range(n_iters):
    # Load the current guess into the graph and recompute F
    X.set(x)
    F.evaluate()

    # Record values for plotting
    xh.append(X())
    fh.append(F())

    # Compute gradients
    F.zero_grad()
    F.backward()
    
    # Gradient step: move x against the gradient, scaled by kappa
    x -= kappa*X.grad
    
    # This prints the point the step was taken from: X and F are only
    # updated with the new x at the top of the next iteration.
    print(X(), F())

In [None]:
xx = np.linspace(-2,2,100)  # Choose a bunch of x values
yy = []
# Evaluate expression for each x value
# NOTE(review): the loop variable `x` reuses the name of the gradient-descent
# iterate, so that value is overwritten when this cell runs; X is also left
# set to the last entry of xx afterwards.
for x in xx:
    X.set(x)
    F.evaluate()
    yy.append(F())
# Plot the graph, and the optimization iterates
# (blue curve: F over xx; red dots: the recorded descent steps)
plt.plot(xx,yy);
plt.plot(xh, fh, 'r.');