<a href="https://colab.research.google.com/github/GDS-Education-Community-of-Practice/DSECOP/blob/main/Automatic_Differentiation/02_automatic_differentiation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Automatic differentiation for parameter estimation

## Automatic differentiation

In [None]:
import numpy as np
import scipy.integrate as si
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML

Previously, we used a simple derivative approximation via “finite
differences”, as shown in **Equation 6**. This is known to be
an approximate method in that it requires choosing a value for $h$.
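To see how the choice of $h$ matters, here is a small sketch (not from the previous notebook) that approximates $\frac{d}{dx}\sin(x)$ at $x = 1$ with a forward difference and compares it with the exact value $\cos(1)$. The helper name `forward_difference` is illustrative only:

```python
import math

def forward_difference(func, x, h):
    """Forward-difference approximation of func'(x) with step size h."""
    return (func(x + h) - func(x)) / h

exact = math.cos(1.0)  # exact derivative of sin(x) at x = 1
for h in (1e-1, 1e-2, 1e-3, 1e-4):
    approx = forward_difference(math.sin, 1.0, h)
    print(f"h = {h:.0e}: error = {abs(approx - exact):.2e}")
```

For these step sizes the error shrinks roughly in proportion to $h$. Making $h$ much smaller eventually makes the error grow again, because subtracting two nearly equal floating-point numbers loses precision, so there is no single "safe" choice of $h$.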

Many modern machine learning tools instead avoid this approximation by
computing exact derivatives of each computer operation as it is carried
out. This is called “automatic differentiation”; the variant that carries
derivatives forward alongside the values, which we build here, is called
“forward mode automatic differentiation”.

To better understand how this works, consider a “dual number” made up of
two parts, a value $u$ and a derivative $v$. The number can be written
simply as $(u,v)$.

**Examples:**

-   We could write a variable $x$ as $(x,1)$ since the derivative
    $dx/dx = 1$
-   We could write a function $\sin(x)$ as $(\sin(x),\cos(x))$ since the
    derivative $d/dx\sin(x) = \cos(x)$
-   We could write a function $\sin(\cos(x))$ as
    $(\sin(\cos(x)),-\cos(\cos(x))\sin(x))$ since the derivative
    $d/dx\sin(\cos(x)) = -\cos(\cos(x))\sin(x)$ by the chain rule

To avoid working out deeply nested derivatives like the last example by
hand, we can redefine common functions to take dual numbers as input and
return dual numbers as output.

**Examples:**

-   The function $f(x) = x^3$ can be redefined as
    $f(u,v) = (u^3, 3u^2v)$ since the derivative
    $df/dx = 3x^2\frac{dx}{dx}$ by the chain rule
-   The function $f(x) = \sin(x)$ can be redefined as
    $f(u,v) = (\sin(u), \cos(u)v)$ since the derivative
    $df/dx = \cos(x)\frac{dx}{dx}$ by the chain rule

*Notice that derivatives always include the chain rule in case the value
$u$ is actually a function of $x$.*

This is all a little bit funny, really just rewriting common
mathematical notation in an alternative format. However, if we can write
programs with this dual number logic, then we will be both computing
values and their derivatives at the same time! Additionally, the
derivatives computed will not be approximations, but the exact
derivatives of the expression (up to the rounding ability of the
computer).
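Before writing a full class, here is a minimal sketch of this idea using plain `(value, derivative)` tuples. The helper names `dual_sin` and `dual_cos` are illustrative only; each applies the rules above, including the chain-rule factor $v$, and composing them reproduces the $\sin(\cos(x))$ example without differentiating by hand:

```python
import math

def dual_cos(d):
    """cos applied to a dual number (u, v): returns (cos(u), -sin(u) * v)."""
    u, v = d
    return (math.cos(u), -math.sin(u) * v)

def dual_sin(d):
    """sin applied to a dual number (u, v): returns (sin(u), cos(u) * v)."""
    u, v = d
    return (math.sin(u), math.cos(u) * v)

x = 0.7
# Start from (x, 1) because dx/dx = 1, then apply cos and sin in turn
value, derivative = dual_sin(dual_cos((x, 1.0)))
expected = -math.cos(math.cos(x)) * math.sin(x)  # chain rule worked out by hand
print(value, derivative, expected)
```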

### Simple examples

In order to write code with this “dual number” format, we can create a
Python “class” for dual numbers. For those not familiar, “classes” in
programming are just a structure that can store data and functions. So
for example, our class will store the data `value` and `derivative` as
the elements of the dual number. It will also store functions for common
operations like adding or subtracting dual numbers. In fact, the only
core operations of dual numbers are:

1.  Addition: $(u,v) + (r,s) = (u+r, v+s)$
2.  Subtraction: $(u,v) - (r,s) = (u-r, v-s)$
3.  Multiplication: $(u,v) \times (r,s) = (ur, us + rv)$ (by the product
    rule of calculus)
4.  Division: $(u,v) / (r,s) = (u / r, (rv - us) / r^2)$ (by the
    quotient rule)
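One common way to see where the product and quotient rules come from (a standard view of dual numbers, not needed for the rest of this notebook) is to write $(u,v)$ as $u + v\varepsilon$, where $\varepsilon$ is a formal symbol with $\varepsilon^2 = 0$. Ordinary algebra then produces the derivative terms on its own:

$$
\begin{align*}
(u + v\varepsilon)(r + s\varepsilon) &= ur + (us + rv)\varepsilon + vs\,\varepsilon^2 = ur + (us + rv)\varepsilon \\
\frac{u + v\varepsilon}{r + s\varepsilon} &= \frac{(u + v\varepsilon)(r - s\varepsilon)}{r^2 - s^2\varepsilon^2} = \frac{ur + (rv - us)\varepsilon}{r^2} = \frac{u}{r} + \frac{rv - us}{r^2}\varepsilon
\end{align*}
$$

The coefficients of $\varepsilon$ are exactly the product and quotient rules in items 3 and 4.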

We can repeat these operations for a dual number operated with a regular
number:

1.  Addition: $(u,v) + r = (u+r, v)$
2.  Subtraction: $(u,v) - r = (u-r, v)$
3.  Multiplication: $(u,v) \times r = (ur, vr)$
4.  Division: $(u,v) / r = (u / r, v / r)$
5.  Exponent: $(u,v)^r = (u^r, ru^{(r-1)}v)$

*Note that we do not define raising a dual number to a dual-number
power, because that rule is messier.*

We can also add some common mathematical functions:

1.  Exponential: $e^{(u,v)} = (e^u, e^uv)$
2.  Logarithm: $\log(u,v) = (\log(u), v/u)$
3.  Sine: $\sin(u,v) = (\sin(u), \cos(u)v)$
4.  Cosine: $\cos(u,v) = (\cos(u), -\sin(u)v)$

To get comfortable with a Python dual number class, let’s begin by
creating a class that stores a value and a derivative:

In [None]:
class Dual:
    def __init__(self, value, derivative):
        self.value = value
        self.derivative = derivative

Notice that the class stores the value and the derivative as attributes
of `self`, the particular dual number being created.

We could make one of these numbers by simply writing:

In [None]:
dual = Dual(1, 0)

We can then add an addition function `__add__` that gets called whenever
we have `Dual + Dual`:

In [None]:
class Dual:
    def __init__(self, value, derivative):
        self.value = value
        self.derivative = derivative

    def __add__(self, other):
        value = self.value + other.value
        derivative = self.derivative + other.derivative
        return Dual(value, derivative)

We can try this out by running:

In [None]:
dual1 = Dual(1, 0)
dual2 = Dual(2, 0)
dual1 + dual2

Instead of the sum, this displays only a generic description of the
object, so we need a way to print our dual numbers. We can add that
with the `__repr__` method:

In [None]:
class Dual:
    def __init__(self, value, derivative):
        self.value = value
        self.derivative = derivative

    def __add__(self, other):
        value = self.value + other.value
        derivative = self.derivative + other.derivative
        return Dual(value, derivative)

    def __repr__(self):
        return "u = {}, du/dx = {}".format(self.value, self.derivative)

Now,

In [None]:
print(dual1 + dual2)

Unfortunately, our `Dual` class does not yet allow us to add dual
numbers to regular numbers. To add this functionality to our `__add__`
function, we need to check whether the input is a dual number or
not:

In [None]:
class Dual:
    def __init__(self, value, derivative):
        self.value = value
        self.derivative = derivative

    def __add__(self, other):
        if isinstance(other, Dual):
            value = self.value + other.value
            derivative = self.derivative + other.derivative
            return Dual(value, derivative)
        else:
            return Dual(self.value+other, self.derivative)

    def __repr__(self):
        return "u = {}, du/dx = {}".format(self.value, self.derivative)

We can check it with:

In [None]:
dual1 = Dual(1, 0)
number1 = 2
print(dual1 + number1)

Unfortunately, our `__add__` function doesn’t work for
`number1 + dual1`. To add this functionality, we need to add a
`__radd__` function:

In [None]:
class Dual:
    def __init__(self, value, derivative):
        self.value = value
        self.derivative = derivative

    def __add__(self, other):
        if isinstance(other, Dual):
            value = self.value + other.value
            derivative = self.derivative + other.derivative
            return Dual(value, derivative)
        else:
            return Dual(self.value+other, self.derivative)

    def __radd__(self, other):
        return self + other

    def __repr__(self):
        return "u = {}, du/dx = {}".format(self.value, self.derivative)

Now,

In [None]:
dual1 = Dual(1, 0)
number1 = 2
print(number1 + dual1)

Great! We now have a dual number class that can do addition!
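Why does `number1 + dual1` need `__radd__`? Python first calls the left operand's method, `int.__add__`, which returns the special value `NotImplemented` when it does not recognize the other type; only then does Python try the right operand's `__radd__`. A small standalone sketch with a hypothetical `Wrapper` class shows the sequence:

```python
class Wrapper:
    """Hypothetical class that can add plain numbers to itself."""
    def __init__(self, value):
        self.value = value

    def __add__(self, other):
        # Handles `self + other`
        return Wrapper(self.value + other)

    def __radd__(self, other):
        # Handles `other + self` after other's __add__ gave up
        return Wrapper(other + self.value)

print((2).__add__(Wrapper(1)))  # NotImplemented: int does not know Wrapper
print((2 + Wrapper(1)).value)   # 3, computed by Wrapper.__radd__
```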

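Following that example, we can add the remaining arithmetic operations
from earlier in this section to our class. The comments number them 1 to
9 (dual with dual, dual with a number, then the power rule); the special
functions are left for Exercise 1. These can be written as follows: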

In [None]:
class Dual:
    # When creating a dual number, require a value and derivative
    def __init__(self,value,derivative):
        self.value = value
        self.derivative = derivative

    # Operations with other dual numbers
    def __add__(self, other):
        if isinstance(other, Dual):
            # 1. dual + dual
            value = self.value + other.value
            derivative = self.derivative + other.derivative
            return Dual(value, derivative)
        else:
            # 5. dual + number
            return Dual(self.value + other, self.derivative)

    def __sub__(self, other):
        if isinstance(other, Dual):
            # 2. dual - dual
            value = self.value - other.value
            derivative = self.derivative - other.derivative
            return Dual(value, derivative)
        else:
            # 6. dual - number
            return Dual(self.value - other, self.derivative)

    def __mul__(self, other):
        if isinstance(other, Dual):
            # 3. dual * dual
            value = self.value * other.value
            derivative = self.value * other.derivative + other.value * self.derivative
            return Dual(value, derivative)
        else:
            # 7. dual * number
            return Dual(self.value * other, self.derivative * other)

    def __truediv__(self, other):
        if isinstance(other, Dual):
            # 4. dual / dual
            value = self.value / other.value
            derivative = (other.value * self.derivative - self.value * other.derivative) / other.value**2
            return Dual(value, derivative)
        else:
            # 8. dual / number
            return Dual(self.value / other, self.derivative / other)

    def __pow__(self, n):
        # 9. dual**number
        value = self.value ** n
        derivative = self.derivative * n * self.value ** (n-1)
        return Dual(value, derivative)

    # In case the operations are called backwards, ie number + dual
    def __radd__(self, other):
        return self + other

    def __rsub__(self, other):
        return -self + other

    def __rmul__(self, other):
        return self * other

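    def __rtruediv__(self, other):
        # number / dual: (r / u, -r v / u^2) by the quotient rule
        return Dual(other / self.value, -other * self.derivative / self.value**2)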

    # For negating the number
    def __neg__(self):
        return self * -1

    # For printing the number
    def __repr__(self):
        return "y = {}, dy/dx = {}".format(self.value, self.derivative)
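Division with a regular number on the left, $r/(u,v)$, is not among the rules listed earlier. Applying the quotient rule to $r/u$ gives $r/(u,v) = (r/u, -rv/u^2)$. Below is a standalone sketch of just that rule in a hypothetical, stripped-down class `RDivDual` (it implements only `__rtruediv__`), so the rule can be checked in isolation:

```python
class RDivDual:
    """Hypothetical minimal dual number that only supports `number / dual`."""
    def __init__(self, value, derivative):
        self.value = value
        self.derivative = derivative

    def __rtruediv__(self, other):
        # r / (u, v) = (r / u, -r * v / u**2) by the quotient rule
        return RDivDual(other / self.value, -other * self.derivative / self.value**2)

# f(x) = 1/x at x = 2: f = 0.5 and f' = -1/x^2 = -0.25
y = 1 / RDivDual(2.0, 1.0)
print(y.value, y.derivative)  # 0.5 -0.25
```

Returning only `other / self.value`, a plain number, would silently drop the derivative.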

----
#### Exercise 1

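The `Dual` class currently implements common operations such as `*`, `+`, `-` and `/`, but it lacks special functions such as `exp`, `log`, `sin` and `cos`. Redefine the `Dual` class so that it includes all the previously implemented operations and also `exp`, `log`, `sin` and `cos` methods, using the rules from earlier in this section. Name the methods exactly this way so that NumPy calls such as `np.sin(x)` in the examples below work on a `Dual`.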

**Solution:**

----
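The examples below call NumPy functions such as `np.sin(x)` and `np.exp(x)` directly on a `Dual`. This relies on a NumPy behavior: when a ufunc like `np.sin` receives a generic Python object, it calls that object's method of the same name (here `x.sin()`). A quick check with a hypothetical `Angle` class, which is not part of the notebook:

```python
import math
import numpy as np

class Angle:
    """Hypothetical object that knows how to take its own sine."""
    def __init__(self, radians):
        self.radians = radians

    def sin(self):
        return math.sin(self.radians)

# np.sin falls back to calling Angle.sin()
print(np.sin(Angle(0.5)))
```

This is why the methods added in Exercise 1 need the names `exp`, `log`, `sin` and `cos`.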

Now let’s see how this class can be used to compute some simple
derivatives:

Consider the function $f(x) = x^2 + 2x - 3$ with derivative
$f'(x) = 2x + 2$. With this info, $f(1) = 0$ and $f'(1) = 4$:

In [None]:
# f(x) = x^2 + 2x - 3
# df/dx = 2x + 2
# f(1) = 0, df/dx(1) = 4
x = Dual(1,1)
y = x**2 + 2*x - 3
print(y)

Wow! Just by defining the input `x` as a `Dual` with derivative 1 (since
$dx/dx = 1$), we could write our regular Python code and automatically
get the derivative!

Let’s try this with a couple other examples:

In [None]:
# f(x) = x^7 - 12x
# df/dx = 7x^6 - 12
# f(1) = -11, df/dx(1) = -5
x = Dual(1,1)
y = x**7 - 12*x
print(y)

In [None]:
# f(x) = sin(x)^2
# df/dx = 2sin(x)*cos(x)
# f(pi/4) = 1/2, df/dx(pi/4) = 1
x = Dual(np.pi/4,1)
y = np.sin(x)**2
print(y)

How about something that is messy enough that we really don’t want to
work it out?

In [None]:
# f(x) = exp(sin(x^2 - (3-cos(x^7 + 10exp(x)))^2))
# df/dx = terribly messy
# f(1) = 0.50898, df/dx(1) = 8.62062
x = Dual(1,1)
y = np.exp(np.sin(x**2 - (3 - np.cos(x**7 + 10*np.exp(x)))**2))
print(y)

Very automatic.

## Parameter estimation with automatic differentiation

Recall that the previous notebook illustrated a simple method for
estimating the parameters of the equation for an RLC circuit using finite
differences and gradient descent. The equation is:

<span id="eq-rlc-diffeq">$$ L Q''(t) + RQ'(t) + \frac{Q(t)}{C} = E(t)  \qquad(1)$$</span>

To review, it used the following code to simulate the system:

In [None]:
t0 = 0                           # Starting time
tf = 200                         # Final time
R = 1                            # Parameters
L = 10
C = 2
E = 60
ps = (R,L,C,E)
y0_0 = np.zeros(2)               # Starting conditions, Q(t0) = D(t0) = 0
times = np.linspace(t0, tf, 200) # Times to collect simulation at

# dy/dt = f(t,y,R,L,C,E)
# Assuming y[0] = D(t) and y[1] = Q(t)
def f(t,y,R,L,C,E):
    return np.array([-R*y[0]/L - y[1]/(L*C) + E/L, y[0]])

approx_solution = si.solve_ivp(f, (t0, tf), y0_0, args=ps, t_eval=times)

It then used the following to estimate the parameters:

In [None]:
# First define how we calculate Q over time with simulation
def Q(R,L):
    return si.solve_ivp(f, (t0, tf), y0_0, args=(R,L,C,E), t_eval=times).y[1,:]

def loss(R,L,Qhat):
    Q_RL = Q(R,L)
    result = 0
    N = len(Q_RL)
    for i in range(len(Q_RL)):
        result += (Q_RL[i] - Qhat[i])**2
    return result / N

def dQ(R,L,h):
    Q_RL = Q(R,L)
    dQdR = (Q(R+h,L) - Q_RL) / h
    dQdL = (Q(R,L+h) - Q_RL) / h
    return np.array([dQdR,dQdL])

def dloss(Qhat,R,L,h):
    Q_RL = Q(R,L)
    dQ_RL = dQ(R,L,h)
    N = len(Q_RL)
    return 2*np.sum((Q_RL - Qhat) * dQ_RL, axis=1) / N

def gradient_descent(p0, df, alpha=0.1, max_iter=100):
    pstar = p0
    all_pstars = []
    for n in range(max_iter):
        pstar = pstar - alpha*df(pstar)
        all_pstars.append(pstar)
    return pstar, np.array(all_pstars)

def plot_RL_descent(descent_RLs):
    # Map out the landscape
    max_R = max(abs(descent_RLs[:,0].max()+.5),5); min_R = min(abs(descent_RLs[:,0].min()-.5), .5)
    max_L = max(abs(descent_RLs[:,1].max()+.5),10.5); min_L = min(abs(descent_RLs[:,1].min()-.5), .5)
    Rs = np.linspace(min_R, max_R, 30)
    Ls = np.linspace(min_L, max_L, 30)
    Qs = np.zeros((30,30))
    for i in range(30):
        for j in range(30):
            Qs[j,i] = loss(Rs[i],Ls[j],data)

    fig = plt.figure()
    contours = plt.contourf(Rs,Ls,Qs,levels=30)
    plots = [
        plt.scatter(1,10,c='orange',s=100,zorder=2,label="Correct"),
        plt.scatter(descent_RLs[0][0],descent_RLs[0][1],c='r',s=100,zorder=3,label="Gradient descent"),
    ]
    plt.legend()
    # Attach the colorbar to the loss contours, not to a scatter plot
    plt.colorbar(contours, label="Loss")
    plt.xlabel("$R$"); plt.ylabel("$L$")
    def anim_func(i):
        plots[1].set_offsets([descent_RLs[i][0],descent_RLs[i][1]])
        return plots

    anim = FuncAnimation(fig, anim_func, frames=range(len(descent_RLs)), interval=100, blit=True)
    plt.close()
    return HTML(anim.to_jshtml())

# Initial guess
RL_0 = np.array([8, 5])
data = approx_solution.y[1,:]
wrapped_dloss = lambda RL: dloss(data,RL[0],RL[1],1e-4)

# Iteratively find the best parameters to match `data`
best_RL, descent_RLs = gradient_descent(RL_0, wrapped_dloss, 3e-3, 100)
plot_RL_descent(descent_RLs)

This example uses finite differences to find the gradient with respect
to the parameters $R$ and $L$, and the `scipy.integrate.solve_ivp`
function to simulate [Equation 1](#eq-rlc-diffeq) for the different
values of $R$ and $L$. We would now like to use our new automatic
differentiation with the `Dual` number class for $R$ and $L$, but
unfortunately we cannot use `solve_ivp` for this: it converts the state
of the system to floating-point (or complex) NumPy arrays internally, so
our `Dual` objects and their derivatives would not survive. Thus, we will
need to implement our own simple numerical solver for differential equations.

Comically, the simplest simulator for differential equations is the
Euler method, which itself uses finite differences. Writing the state as
$y(t) = (Q'(t), Q(t))$, as in the function `f` above, the method follows
this procedure:

1.  Start with an initial condition $y(t_0)$
2.  Approximate the derivative in time with a finite difference: $$
    \begin{align*}
     \frac{dy}{dt}(t) \approx \frac{y(t+h) - y(t)}{h} &= f(t,y(t),R,L,C,E) \\
     y(t+h) &= y(t) + hf(t,y(t),R,L,C,E)
    \end{align*}
    $$
3.  Take steps of size $h$ from starting time $t_0$ to end time $t_N$,
    collecting the solution $y(t_i)$, and hence $Q(t_i)$, at all points along the way

We can write a simple Python function to do this iteration (with dual
numbers) as follows:

In [None]:
def euler(f, y0, times, R, L, C, E):
    h = times[1]-times[0]
    ys = []
    ys.append(y0)
    for i in range(len(times)-1):
        ys.append(ys[i] + h*f(times[i], ys[i], R, L, C, E))
    return np.array(ys).T
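To get a feel for how crude the Euler method is, here is a separate sketch on a test problem with a known solution, $dy/dt = -y$ with $y(0) = 1$, so that $y(1) = e^{-1}$. The helper `euler_scalar` is hypothetical and handles a single scalar state. Halving the step size roughly halves the error, the signature of a first-order method:

```python
import math

def euler_scalar(g, y0, t_end, n_steps):
    """Hypothetical scalar Euler integrator for dy/dt = g(t, y) from t = 0 to t_end."""
    h = t_end / n_steps
    t, y = 0.0, y0
    for _ in range(n_steps):
        y = y + h * g(t, y)  # y(t + h) = y(t) + h f(t, y(t))
        t += h
    return y

decay = lambda t, y: -y  # dy/dt = -y, exact solution y(t) = exp(-t)
exact = math.exp(-1.0)
steps = (10, 20, 40, 80)
errors = [abs(euler_scalar(decay, 1.0, 1.0, n) - exact) for n in steps]
ratios = [errors[i] / errors[i + 1] for i in range(len(steps) - 1)]
print(errors)
print(ratios)  # each ratio is close to 2: halving h halves the error
```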

We can now make one of our parameters $R$ into a dual number and the
result of our Euler method simulation will include both the values of
$Q$ at all our time points and the derivative $\frac{dQ}{dR}$ at all
those points:

In [None]:
R = Dual(1,1)
L = 10
y0 = np.array([Dual(0,0), Dual(0,0)])  # Q(t0) = D(t0) = 0 regardless of R, so their derivatives are 0

dual_solution = euler(f, y0, times, R, L, C, E)

print(dual_solution[:,:10])

This now allows us to rewrite our loss gradient function without needing
to do finite differences:
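For reference, with $N$ time points $t_i$ and data $\hat{Q}_i$, the `loss` function above computes a mean squared error, and `dloss` applies the chain rule to it:

$$
\begin{align*}
\mathcal{L}(R, L) &= \frac{1}{N}\sum_{i=1}^{N}\big(Q(t_i; R, L) - \hat{Q}_i\big)^2 \\
\frac{\partial \mathcal{L}}{\partial R} &= \frac{2}{N}\sum_{i=1}^{N}\big(Q(t_i; R, L) - \hat{Q}_i\big)\frac{\partial Q}{\partial R}(t_i; R, L)
\end{align*}
$$

and similarly for $L$. The dual-number simulation supplies $\frac{\partial Q}{\partial R}(t_i)$ at every time point, which replaces the finite-difference `dQ`.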

----
#### Exercise 2

Rewrite the `dloss` function to use the dual solutions instead of finite differences.

Note that it might be easier to write a `dualQ` function (to replace the `Q` function we have been using) that returns both the solution `Q` and the derivatives `dQdR` and `dQdL`. Also note that in order to get the derivative with respect to `L`, you will have to use `euler` with only `L` as a `Dual` value. To get the derivative with respect to `R`, you will have to use only `R` as a `Dual` number.

**Solution:**

*Note that, unfortunately, we need to run a separate simulation for each
derivative $\frac{dQ}{dR}$ and $\frac{dQ}{dL}$. This is computationally more
expensive than it needs to be. We could avoid it by storing an array of
derivatives (one entry per parameter) in each dual number, but that would
require some rewriting of our `Dual` class that is beyond the scope of what
we are doing. If we were to make that improvement, we wouldn't even need a
`dloss` function. Instead, we could compute the loss itself with dual numbers
and directly pull out the values $\frac{d\mathcal{L}}{dR}$ and
$\frac{d\mathcal{L}}{dL}$ from the result.*

----
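To make the note above concrete, here is a minimal sketch of the improvement it describes: a hypothetical `MultiDual` class whose derivative part is a NumPy array with one entry per parameter. Seeding $R$ with $(1, 0)$ and $L$ with $(0, 1)$ gives both partial derivatives from a single evaluation. It implements only addition and multiplication, enough for a toy function $g(R, L) = RL + R^2$ with $\partial g/\partial R = L + 2R$ and $\partial g/\partial L = R$:

```python
import numpy as np

class MultiDual:
    """Hypothetical dual number whose derivative part is a vector (one entry per parameter)."""
    def __init__(self, value, grad):
        self.value = value
        self.grad = np.asarray(grad, dtype=float)

    def __add__(self, other):
        if isinstance(other, MultiDual):
            return MultiDual(self.value + other.value, self.grad + other.grad)
        return MultiDual(self.value + other, self.grad)

    __radd__ = __add__

    def __mul__(self, other):
        if isinstance(other, MultiDual):
            # Product rule, applied to every partial derivative at once
            return MultiDual(self.value * other.value,
                             self.value * other.grad + other.value * self.grad)
        return MultiDual(self.value * other, self.grad * other)

    __rmul__ = __mul__

# Seed each parameter with its own unit vector: R -> (1, 0), L -> (0, 1)
R = MultiDual(2.0, [1.0, 0.0])
L = MultiDual(3.0, [0.0, 1.0])

g = R * L + R * R  # g(R, L) = R L + R^2
print(g.value, g.grad)  # 10.0 [7. 2.]
```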

We can now use this new automatic differentiation loss to do gradient
descent:

In [None]:
RL_0 = np.array([Dual(8,1), Dual(5,1)])
data = approx_solution.y[1,:]
wrapped_dloss = lambda RL: dloss(data,RL[0],RL[1])
_, descent_RLs_dual = gradient_descent(RL_0, wrapped_dloss, 3e-3, 100)

descent_RLs_values = np.array([np.array([RL[0].value,RL[1].value]) for RL in descent_RLs_dual])

plot_RL_descent(descent_RLs_values)

### Improving the estimation

Note that the resulting parameter estimates do not arrive at the correct parameter values, though they are approaching them.
How can this be, given that we now have more accurate derivatives?

To answer this question, we need to remember that automatic differentiation gives exact derivatives *of the computational procedure*.
This means that it gives us exact derivatives of the Euler method's result.
But we know that the Euler method is a crude approximation built on finite differences!
In comparison, the default method used by `solve_ivp` (`RK45`, an adaptive Runge-Kutta method that adjusts its step size to keep the error small) is far superior in accuracy and behavior.

Thus, the key to improving our results is to improve the model that we use to match the data.
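We can check this claim on a toy problem, separate from the RLC circuit. For $dy/dt = -ky$ with $y(0) = 1$, $n$ Euler steps of size $h$ give exactly $y_n = (1 - hk)^n$, whose derivative with respect to $k$ is $-nh(1 - hk)^{n-1}$, while the true solution $y(t) = e^{-kt}$ has derivative $-te^{-kt}$. The sketch below uses a hypothetical stripped-down `MiniDual` class (only `+`, `*` and negation) with $k$ as the dual number. At $k = 1$, $t = 1$ the result matches the Euler formula, not the true derivative:

```python
import math

class MiniDual:
    """Hypothetical stripped-down dual number supporting +, * and negation."""
    def __init__(self, value, derivative):
        self.value = value
        self.derivative = derivative

    def __add__(self, other):
        if isinstance(other, MiniDual):
            return MiniDual(self.value + other.value, self.derivative + other.derivative)
        return MiniDual(self.value + other, self.derivative)

    __radd__ = __add__

    def __mul__(self, other):
        if isinstance(other, MiniDual):
            return MiniDual(self.value * other.value,
                            self.value * other.derivative + other.value * self.derivative)
        return MiniDual(self.value * other, self.derivative * other)

    __rmul__ = __mul__

    def __neg__(self):
        return self * -1

def euler_decay(k, y0, t_end, n_steps):
    """Euler's method for dy/dt = -k y, returning y(t_end)."""
    h = t_end / n_steps
    y = y0
    for _ in range(n_steps):
        y = y + h * (-k * y)
    return y

n, h = 10, 0.1
k = MiniDual(1.0, 1.0)  # differentiate with respect to k (dk/dk = 1)
y_end = euler_decay(k, 1.0, 1.0, n)

euler_formula = -n * h * (1 - h * 1.0) ** (n - 1)  # d/dk (1 - hk)^n at k = 1
true_derivative = -1.0 * math.exp(-1.0)            # d/dk e^(-kt) at k = 1, t = 1
print(y_end.derivative, euler_formula, true_derivative)
```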

With this in mind, let's see if we can improve our parameter estimation results by using a better simulation.
A commonly used and much more accurate simulation method is the classical fourth-order Runge-Kutta method, "RK4".
The details of the method are beyond the scope of this module, but it can be written almost exactly like the Euler method, with just a few more operations per time step, as shown below:

In [None]:
def rk4(f, y0, times, R, L, C, E):
    h = times[1] - times[0]
    ys = []
    ys.append(y0)
    for i in range(len(times)-1):
        k1 = f(times[i], ys[i], R, L, C, E)
        k2 = f(times[i] + h/2, ys[i] + h*k1/2, R, L, C, E)
        k3 = f(times[i] + h/2, ys[i] + h*k2/2, R, L, C, E)
        k4 = f(times[i] + h, ys[i] + h*k3, R, L, C, E)
        ys.append(ys[i] + (h/6)*(k1 + 2*k2 + 2*k3 + k4))
    return np.array(ys).T
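To see how much more accurate RK4 is, here is a separate sketch on the test problem $dy/dt = -y$, $y(0) = 1$, whose exact value at $t = 1$ is $e^{-1}$. The helper `rk4_scalar` is hypothetical and mirrors the `rk4` function above for a single scalar state. For this problem, $n$ Euler steps of size $h = 1/n$ give exactly $(1 - h)^n$, so the Euler error can be computed directly. Halving the step size should cut the RK4 error by roughly $2^4 = 16$:

```python
import math

def rk4_scalar(g, y0, t_end, n_steps):
    """Hypothetical scalar RK4 integrator for dy/dt = g(t, y) from t = 0 to t_end."""
    h = t_end / n_steps
    t, y = 0.0, y0
    for _ in range(n_steps):
        k1 = g(t, y)
        k2 = g(t + h/2, y + h*k1/2)
        k3 = g(t + h/2, y + h*k2/2)
        k4 = g(t + h, y + h*k3)
        y = y + (h/6)*(k1 + 2*k2 + 2*k3 + k4)
        t += h
    return y

decay = lambda t, y: -y  # exact solution y(t) = exp(-t)
exact = math.exp(-1.0)
steps = (10, 20, 40, 80)
rk4_errors = [abs(rk4_scalar(decay, 1.0, 1.0, n) - exact) for n in steps]
euler_errors = [abs((1 - 1 / n) ** n - exact) for n in steps]  # Euler gives (1 - h)^n here
rk4_ratios = [rk4_errors[i] / rk4_errors[i + 1] for i in range(len(steps) - 1)]
for n, e_err, r_err in zip(steps, euler_errors, rk4_errors):
    print(f"{n:3d} steps: Euler error {e_err:.1e}, RK4 error {r_err:.1e}")
print(rk4_ratios)  # each ratio is close to 2**4 = 16
```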

Using this method, we can now redefine our `dualQ` computation and see much improved results:

In [None]:
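def dualQ(R,L):
    # One simulation per parameter: only the parameter we differentiate by stays a Dual
    # (uses the global dual initial condition y0 defined earlier)
    dual_Rsolution = rk4(f, y0, times, R, L.value, C, E)
    dual_Lsolution = rk4(f, y0, times, R.value, L, C, E)

    # Row 1 of each solution holds Q(t) as Dual numbers
    Q = np.array([q.value for q in dual_Rsolution[1]])
    dQdR = np.array([q.derivative for q in dual_Rsolution[1]])
    dQdL = np.array([q.derivative for q in dual_Lsolution[1]])
    return Q, np.array([dQdR, dQdL])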

_, descent_RLs_dual = gradient_descent(RL_0, wrapped_dloss, 30*1e-4, 100)
descent_RLs_values = np.array([np.array([RL[0].value,RL[1].value]) for RL in descent_RLs_dual])
plot_RL_descent(descent_RLs_values)

We can now see that our parameter estimation lands right on top of the correct values. An excellent improvement!

## Using Python libraries

The demonstration we have gone through above has shown how easy it is to get started with forward mode automatic differentiation in Python. There really isn't any magic to it; it is just carefully defining functions to include the chain rule.

However, there are other forms of automatic differentiation, such as:

- Reverse mode: instead of computing the derivatives alongside the values at each step, keep track of the operations as they are done (in a list, or "on a tape"), then, when you reach the end of the computation, go backwards through them applying the chain rule until you reach the parameters of interest. Having all the operations written out allows for some optimization (we may be able to do several at once). More importantly, a single backward pass gives the derivative of one output with respect to every input, while forward mode needs one pass per input. Reverse mode is therefore generally cheaper in computation time when there are many parameters and a single loss, but more expensive in memory, because intermediate values must be stored for the backward pass.
- Intermediate representation (IR): when code in a language like Python is run on a computer, it starts as the form you write, but is ultimately boiled down to 1s and 0s to send to your hardware. In between those two stages are usually a couple of "intermediate representations" of the code. Some of the most cutting-edge automatic differentiation tools work at this IR stage (which may already have been optimized) to determine the operations and therefore the gradient. Because it analyzes the whole program before computing derivatives, it is more similar to reverse mode than forward mode, although IR-based tools such as JAX can produce both forward-mode (`jax.jvp`) and reverse-mode (`jax.grad`) derivatives.

Below, we will explore some very common Python packages that implement automatic differentiation. These all generally implement reverse mode (which is cheaper for neural networks, their main focus). They are capable and well tested and can be used across many different applications.
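Before turning to these libraries, here is a minimal sketch of the reverse-mode idea described above. The `Var` class, `var_sin` and `backward` are hypothetical and written only for illustration: each operation records its inputs and the local partial derivatives, and `backward` then sweeps through the recorded graph from the output back to the inputs, applying the chain rule. For $f(x, y) = xy + \sin(x)$ we expect $\partial f/\partial x = y + \cos(x)$ and $\partial f/\partial y = x$:

```python
import math

class Var:
    """Hypothetical reverse-mode node: a value plus links to the nodes it was computed from."""
    def __init__(self, value, parents=()):
        self.value = value
        self.parents = parents  # tuple of (parent node, local partial derivative)
        self.grad = 0.0         # filled in by backward()

    def __add__(self, other):
        return Var(self.value + other.value, ((self, 1.0), (other, 1.0)))

    def __mul__(self, other):
        return Var(self.value * other.value, ((self, other.value), (other, self.value)))

def var_sin(x):
    return Var(math.sin(x.value), ((x, math.cos(x.value)),))

def backward(output):
    # List the recorded nodes so that each node comes after all of its parents
    order, seen = [], set()
    def visit(node):
        if id(node) not in seen:
            seen.add(id(node))
            for parent, _ in node.parents:
                visit(parent)
            order.append(node)
    visit(output)
    # Backward sweep: each node passes its gradient on to its parents (chain rule)
    output.grad = 1.0
    for node in reversed(order):
        for parent, local in node.parents:
            parent.grad += local * node.grad

x, y = Var(2.0), Var(3.0)
f = x * y + var_sin(x)  # the forward pass records the graph
backward(f)
print(x.grad, y.grad)  # 3 + cos(2) and 2
```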

### PyTorch
In PyTorch, the key idea is to define all your objects as `torch` tensors instead of `Dual` numbers, marking the inputs you want derivatives for with `requires_grad=True`. To repeat an example from above:

In [None]:
import torch

# f(x) = exp(sin(x^2 - (3-cos(x^7 + 10exp(x)))^2))
# df/dx = terribly messy
# f(1) = 0.50898, df/dx(1) = 8.62062

# Use all pytorch objects and operations
x = torch.ones(1, requires_grad=True)
y = torch.exp(torch.sin(x**2 - (3 - torch.cos(x**7 + 10*torch.exp(x)))**2))

# Print the actual computed value
print(y)

# Do a "backward pass" of reverse mode autodiff and print the gradient
y.backward()
print(x.grad)

Really not that far off from what we did before! Let's see if we can adjust our parameter estimation codes to make use of this. The basic idea here is to replace any `numpy` calls with `torch` calls.

In [None]:
######## Just replacing `np.array` with `torch.stack` in most places ############
def torchf(t,y,R,L,C,E):
    return torch.stack([-R*y[0]/L - y[1]/(L*C) + E/L, y[0]])

def torchrk4(f, y0, times, R, L, C, E):
    h = times[1] - times[0]
    ys = []
    ys.append(y0)
    for i in range(len(times)-1):
        k1 = f(times[i], ys[i], R, L, C, E)
        k2 = f(times[i] + h/2, ys[i] + h*k1/2, R, L, C, E)
        k3 = f(times[i] + h/2, ys[i] + h*k2/2, R, L, C, E)
        k4 = f(times[i] + h, ys[i] + h*k3, R, L, C, E)
        step = ys[i] + (h/6)*(k1 + 2*k2 + 2*k3 + k4)
        ys.append(step)
    return torch.stack(ys)

def torch_descent(p0, df, alpha=0.1, max_iter=100):
    pstar = p0
    all_pstars = []
    for n in range(max_iter):
        pstar = pstar - alpha*df(pstar)
        all_pstars.append(pstar)
    return pstar, torch.stack(all_pstars)

def torchloss(R,L,Qhat):
    Q_RL = torchQ(R,L)
    result = 0
    N = len(Q_RL)
    for i in range(len(Q_RL)):
        result += (Q_RL[i] - Qhat[i])**2
    return result / N

############# More changes because of different gradient form  ###########
def torchQ(R,L):
    Q = torchrk4(torchf, y0, times, R, L, C, E)
    return Q[:,1]

def torchdloss(Qhat,R,L):
    Q_loss = torchloss(R, L, Qhat)
    # Differentiate the loss directly with respect to the R and L tensors passed in,
    # so no gradients accumulate on a global tensor between iterations
    dR, dL = torch.autograd.grad(Q_loss, (R, L))
    return torch.stack([dR, dL])

wrapped_torchdloss = lambda RL: torchdloss(data,RL[0],RL[1])

We also need to redefine some of our initial conditions and simulation settings as `torch` tensors:

In [None]:
t0 = 0
tf = 200
C = 2
E = 60
ps = (R,L,C,E)
y0 = torch.zeros(2)
times = torch.linspace(t0, tf, 200) # Times to collect simulation at

data = torch.tensor(approx_solution.y[1,:])

In [None]:
tRL_0 = torch.tensor([8.0, 5.0], requires_grad=True)
_, descent_RL_values = torch_descent(tRL_0, wrapped_torchdloss, 30*1e-4, 100)
descent_RL_values = descent_RL_values.detach().numpy()
plot_RL_descent(descent_RL_values)

### JAX

JAX is a modern automatic differentiation library that works on an intermediate representation (IR) of your code. This makes it a little different to work with than dual numbers or the backward pass of PyTorch. Because it analyzes the code itself, it makes it very easy to define exactly the derivative you are looking for: instead of working with the variables themselves, you transform functions into their derivative functions, then plug the values in.

To really illustrate this, let's visit our previous example:

In [None]:
from jax import numpy as jnp
from jax import grad, jit

# f(x) = exp(sin(x^2 - (3-cos(x^7 + 10exp(x)))^2))
# df/dx = terribly messy
# f(1) = 0.50898, df/dx(1) = 8.62062

# Define a function and find its gradient (derivative) function
x = 1.0
y = lambda x: jnp.exp(jnp.sin(x**2 - (3 - jnp.cos(x**7 + 10*jnp.exp(x)))**2))
dydx = grad(y)

# Print the actual computed value
print(y(x))

# Evaluate the derivative dy/dx at x = 1
print(dydx(x))

See how intuitive that was! Just like when we write out the math, we can write out the functions and find their derivatives without needing to put any values in.

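*Note that we replaced our `numpy` `np` variable with `jax.numpy` (`jnp`). Because `jax.numpy` mirrors most of the NumPy API, most of the code we wrote before using NumPy should work with JAX with few changes. One notable exception is in-place assignment such as `a[0] = 1`, since JAX arrays are immutable.*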

Let's do it for our parameter estimation:

In [None]:
data = jnp.array(approx_solution.y[1,:])
y0 = jnp.zeros(2)
times = jnp.linspace(t0, tf, 200)

RL_0 = jnp.array([8.0, 5.0])

In [None]:
######## Just replacing `np.array` with `jnp.array` in most places ############
def jaxf(t,y,R,L,C,E):
    return jnp.array([-R*y[0]/L - y[1]/(L*C) + E/L, y[0]])

def jaxrk4(f, y0, times, R, L, C, E):
    h = times[1] - times[0]
    ys = []
    ys.append(y0)
    for i in range(len(times)-1):
        k1 = f(times[i], ys[i], R, L, C, E)
        k2 = f(times[i] + h/2, ys[i] + h*k1/2, R, L, C, E)
        k3 = f(times[i] + h/2, ys[i] + h*k2/2, R, L, C, E)
        k4 = f(times[i] + h, ys[i] + h*k3, R, L, C, E)
        step = ys[i] + (h/6)*(k1 + 2*k2 + 2*k3 + k4)
        ys.append(step)
    return jnp.array(ys)

def jax_descent(p0, df, alpha=0.1, max_iter=100):
    pstar = p0
    all_pstars = []
    for n in range(max_iter):
        pstar = pstar - alpha*df(pstar)
        all_pstars.append(pstar)
    return pstar, jnp.array(all_pstars)

@jit
def jaxloss(R,L,Qhat):
    Q_RL = jaxQ(R,L)
    result = 0
    N = len(Q_RL)
    for i in range(len(Q_RL)):
        result += (Q_RL[i] - Qhat[i])**2
    return result / N

############# More changes because of different gradient form  ###########
def jaxQ(R,L):
    Q = jaxrk4(jaxf, y0, times, R, L, C, E)
    return Q[:,1]

# Define our gradient function
wrapped_loss = lambda RL: jaxloss(RL[0],RL[1],data)
jaxdloss = grad(wrapped_loss)

The under-the-hood analysis of the code takes a significant amount of time up front. To avoid repeating it, we decorated `jaxloss` with `jit`, which compiles the function "just in time" on its first call and reuses the compiled version afterwards. We should still run the gradient once to get things compiled:

In [None]:
jaxdloss(RL_0)

In [None]:
_, descent_RLs_jax = jax_descent(RL_0, jaxdloss, 30*1e-4, 100)
plot_RL_descent(descent_RLs_jax)
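Finally, to connect back to the dual numbers at the start of this notebook: JAX also offers forward mode through `jax.jvp`, which pushes a value and a tangent (derivative) through a function together, much like evaluating it on `Dual(1, 1)`. A short sketch, reusing the messy example function (redefined here as `f` so the block stands alone), compares it with the reverse-mode `jax.grad`:

```python
import jax
import jax.numpy as jnp

def f(x):
    # The "terribly messy" example from above
    return jnp.exp(jnp.sin(x**2 - (3 - jnp.cos(x**7 + 10*jnp.exp(x)))**2))

x = jnp.array(1.0)
# Forward mode: push the pair (x, dx/dx = 1) through f, like evaluating f on Dual(1, 1)
value, derivative = jax.jvp(f, (x,), (jnp.ones_like(x),))
# Reverse mode for comparison
reverse_derivative = jax.grad(f)(x)
print(value, derivative, reverse_derivative)
```

For a scalar function like this, both modes give the same derivative; forward mode needs one pass per input, while reverse mode gives the derivatives for all inputs in one backward pass.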