# Numerical Differentiation: From symbolic differentiation to automatic differentiation

In [None]:
import matplotlib.pyplot as plt

### Symbolic Differentiation

In [None]:
import sympy as sp

# Define a symbolic variable
x = sp.Symbol('x')

# Define a symbolic expression f(x) = sin(x) + x^2
f = sp.sin(x) + x**2

# Compute the derivative symbolically
df_dx = sp.diff(f, x)  # Derivative of f with respect to x (f is univariate, so this is the ordinary derivative)

# Print the result
print(df_dx)
# Output: 2*x + cos(x)

In [None]:
df_dx.subs(sp.Symbol("x"), 1).evalf()  # Substitute x=1 into the symbolic derivative, then evaluate numerically

### Numerical Differentiation

In [None]:
import numpy as np
from scipy import optimize

def f(x):
    """Return f(x) = x^2 + sin(x), evaluated with NumPy."""
    squared = x**2
    return squared + np.sin(x)

# Finite-difference approximation of f'(1.0) using a step of 1e-6
optimize.approx_fprime([1.0], f, epsilon=1e-6)

### Automatic Differentiation

In [None]:
import jax.numpy as jnp
import jax
from jax import grad, jit, vmap

def f(x):
    """Return f(x) = sin(x) + x^2, written with jax.numpy so JAX can trace it."""
    return x**2 + jnp.sin(x)

# grad(f) is a new function that evaluates f'(x) = cos(x) + 2x
f_p = grad(f)
print(f_p(1.0))

JAX also supports higher-order derivatives by composing `grad`:

In [None]:
# Second derivative: differentiate the first-derivative function once more
f_pp = grad(f_p)
print(f_pp(1.0))

Let's plot the function and its derivatives.

Note: `f` and its derivatives take scalar inputs (`grad` requires a scalar output), so we use `vmap` to vectorize them over an array of points.

In [None]:
x = jnp.linspace(-2, 2, 100)

# vmap maps each scalar-input function over every point of the grid
y, y_p, y_pp = (vmap(fn)(x) for fn in (f, f_p, f_pp))

for values, label in ((y, 'f(x)'), (y_p, "f'(x)"), (y_pp, "f''(x)")):
    plt.plot(x, values, label=label)
plt.grid()
plt.legend()

The same can be done in PyTorch with `torch.func`:

In [None]:
from torch import func, Tensor
import torch

def f(x: Tensor) -> Tensor:
    """Return f(x) = x^2 + sin(x) for a PyTorch tensor."""
    return x.sin() + x.pow(2)

x = torch.ones([])  # 0-d (scalar) tensor holding 1.0
grad_f = func.grad(f)  # functional transform: returns a function computing df/dx
grad_f(x)

A more ML-oriented example: a multi-layer perceptron (MLP) with the `tanh` activation function.

In [None]:
def predict(params, inputs):
    """Forward pass of a fully-connected network with tanh activations.

    Args:
        params: sequence of ``(W, b)`` pairs, one per layer, applied in order.
        inputs: input batch; its last dimension must match the first
            dimension of the first layer's ``W``.

    Returns:
        The activations of the last layer (``inputs`` itself if ``params``
        is empty).
    """
    activations = inputs
    for W, b in params:
        # Each layer consumes the previous layer's output (not the raw
        # inputs); otherwise only the last layer would have any effect.
        activations = jnp.tanh(jnp.dot(activations, W) + b)
    return activations

key = jax.random.key(0)
# Split the root key so that inputs, targets and weights use independent
# random streams; reusing `key` for every draw would, e.g., give every layer
# exactly the same weight matrix.
key_inputs, key_targets, key_params = jax.random.split(key, 3)
inputs = jax.random.normal(key_inputs, shape=(100, 10))
targets = jax.random.normal(key_targets, shape=(100, 1))

n_layers = 3
# Hidden width 10; the last layer outputs targets.shape[1] (= 1) values so the
# predictions have the same shape as `targets` instead of silently
# broadcasting (100, 10) against (100, 1) in the loss.
layer_sizes = [inputs.shape[1]] + [10] * (n_layers - 1) + [targets.shape[1]]
layer_keys = jax.random.split(key_params, n_layers)
params = [
    # 1/sqrt(fan_in) scaling keeps pre-activations O(1) so tanh does not saturate.
    (jax.random.normal(k, shape=(n_in, n_out)) / jnp.sqrt(n_in), jnp.zeros(n_out))
    for k, n_in, n_out in zip(layer_keys, layer_sizes[:-1], layer_sizes[1:])
]
outputs = predict(params, inputs)
outputs.shape

The loss function is the mean squared error:

In [None]:
def loss_fun(params, inputs, targets):
    """Mean squared error between the network predictions and ``targets``."""
    preds = predict(params, inputs)
    # Average (rather than sum) so the value is the stated MSE and does not
    # scale with the batch size.
    return jnp.mean((preds - targets) ** 2)

loss_fun(params, inputs, targets)

Question: could we compute the gradient of the `predict` function?

For example, using:

```python
grad_fn = grad(predict)
```

In [None]:
# grad(predict)(params, inputs)

In [None]:
# Differentiate w.r.t. the first positional argument (params; grad's default) and JIT-compile.
grad_fun = jit(grad(loss_fun))

grad_params = grad_fun(params, inputs, targets)
# The gradient mirrors the structure of params: inspect the first layer's (dW, db) pair.
(len(grad_params[0]),) + tuple(g.shape for g in grad_params[0])

### Jacobian with forward and reverse mode

In [None]:
from numpy.testing import assert_allclose
from jax import jacfwd, jacrev

# Random 3x3 matrix (not symmetrized) defining the linear map below
W = jax.random.normal(key, (3, 3))

def f(x):
    """Linear map x -> W x; its Jacobian is exactly W."""
    return W @ x

x = jax.random.normal(key, (3,))

In [None]:
J_f = jacfwd(f)  # Jacobian function built with forward-mode AD
assert_allclose(actual=J_f(x), desired=W)
J_f(x) - W  # expected to be all zeros

In [None]:
J_f = jacrev(f)  # Jacobian function built with reverse-mode AD
assert_allclose(actual=J_f(x), desired=W)
J_f(x) - W  # expected to be all zeros

In [None]:
def quadratic(x):
    """Quadratic form q(x) = 1/2 * x^T W x, using the module-level matrix W."""
    return 0.5 * (x.T @ W @ x)

In [None]:
grad_quadratic = grad(quadratic)
# Exercise: this assertion is expected to FAIL -- why?
# Hint: W is a random matrix that is never symmetrized.
assert_allclose(grad_quadratic(x), W @ x)  # Expected to fail: why?

In [None]:
# Hessian via forward-over-reverse AD. For q(x) = 1/2 x^T W x the Hessian is
# the symmetric part of W, 1/2 (W + W^T), which equals W only if W is symmetric.
H = jacfwd(jacrev(quadratic))
assert_allclose(H(x), 0.5 * (W + W.T), rtol=1e-6, atol=1e-6)

### A full example with Logistic Regression

In [None]:
from sklearn.datasets import load_iris

X, y = load_iris(return_X_y=True)
X = X[:, :2]  # Keep only the first two features so the decision boundary can be plotted in 2D
X = X[y != 2]  # Drop class 2 to obtain a binary problem (classes 0 and 1)
y = y[y != 2]

X = X.astype(np.float32)
y = y.astype(np.float32)
y[y == 0] = -1  # Convert labels {0, 1} to {-1, +1}

X = jnp.array(X)
y = jnp.array(y)

plt.scatter(X[:, 0], X[:, 1], c=y)

In [None]:
def loss_fn(params, X, y):
    """Average logistic loss for labels encoded as -1 / +1.

    Args:
        params: pair ``(w, b)`` of weight vector and bias.
        X: feature matrix, one row per sample.
        y: labels encoded as -1 / +1.

    Returns:
        mean(log(1 + exp(-y * (X @ w + b)))).
    """
    w, b = params
    logits = X @ w + b
    # logaddexp(0, z) == log(1 + exp(z)) but does not overflow to inf for large z.
    return jnp.mean(jnp.logaddexp(0.0, -y * logits))

w = jnp.zeros(X.shape[1])
b = jnp.zeros(1)

params = (w, b)

grad_fn = jit(grad(loss_fn))

# Gradient descent
lr = 0.05
for _ in range(10_000):
    grad_params = grad_fn(params, X, y)
    # Keep params a tuple so the pytree structure passed to the jitted
    # grad_fn stays the same on every call (a list would force a retrace).
    params = tuple(p - lr * g for p, g in zip(params, grad_params))

# Training accuracy of the fitted model: predicted sign vs. the +/-1 labels
accuracy = jnp.mean(jnp.sign(X @ params[0] + params[1]) == y)
print(f"training accuracy: {float(accuracy):.3f}")

params

In [None]:
plt.scatter(X[:, 0], X[:, 1], c=y)
# Decision boundary w0*x0 + w1*x1 + b = 0, solved for x1 along a grid of x0 values
xx = jnp.linspace(4, 8, 100)
yy = -(params[1] + params[0][0] * xx) / params[0][1]
plt.plot(xx, yy, 'r-')

- Exercise 1: Implement Newton's method for optimization.
- Exercise 2: Implement multi-class logistic regression (using softmax) on all 3 classes.