In [None]:
# === Environment Setup ===
import os, sys, math, time, random, json, textwrap, warnings
import numpy as np, pandas as pd, matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from IPython.display import display, Image, Markdown

# JAX is the core library for this chapter.
try:
    import jax
    import jax.numpy as jnp
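    from jax import grad, jacfwd, jacrev, hessian, jit, vmap
    from jax.scipy.stats import norm
    # Use 64-bit floats so that tight convergence tolerances are attainable
    jax.config.update("jax_enable_x64", True)
    JAX_AVAILABLE = True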
except ImportError:
    JAX_AVAILABLE = False
    # Define dummy functions if JAX is not available
    def jit(f): return f
    def grad(f): return f
    def hessian(f): return f
    def vmap(f): return f

# SymPy for symbolic differentiation
import sympy as sp

# Graphviz for visualizing computational graphs
try:
    from graphviz import Digraph
    GRAPHVIZ_AVAILABLE = True
except ImportError:
    GRAPHVIZ_AVAILABLE = False

# --- Configuration ---
plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams.update({'figure.dpi': 130, 'font.size': 12, 'axes.titlesize': 'x-large',
    'axes.labelsize': 'large', 'xtick.labelsize': 'medium', 'ytick.labelsize': 'medium'})
np.set_printoptions(suppress=True, linewidth=120, precision=8)
sp.init_printing(use_unicode=True)

# --- Utility Functions ---
def note(msg, **kwargs):
    display(Markdown(f"<div class='alert alert-info'>📝 {textwrap.fill(msg, width=100)}</div>"))
def sec(title):
    print(f"\n{100*'='}\n| {title.upper()} |\n{100*'='}")

note(f"Environment initialized. JAX available: {JAX_AVAILABLE}, Graphviz available: {GRAPHVIZ_AVAILABLE}")

# Part 2: Core Numerical Methods
## Chapter 2.3: Numerical and Automatic Differentiation

### Table of Contents
1.  [Numerical Differentiation via Finite Differences](#1.-Numerical-Differentiation-via-Finite-Differences)
    *   [1.1 Taylor Series Derivations](#1.1-Taylor-Series-Derivations)
    *   [1.2 The Error Trade-off](#1.2-The-Error-Trade-off)
2.  [Automatic Differentiation (AD)](#2.-Automatic-Differentiation-(AD))
    *   [2.1 The Computational Graph](#2.1-The-Computational-Graph)
    *   [2.2 Forward-Mode vs. Reverse-Mode AD](#2.2-Forward-Mode-vs.-Reverse-Mode-AD)
3.  [The JAX Framework](#3.-The-JAX-Framework)
    *   [3.1 Core Transformations: `grad`, `jit`, `vmap`](#3.1-Core-Transformations:-grad,-jit,-vmap)
    *   [3.2 The Constraint of Pure Functions](#3.2-The-Constraint-of-Pure-Functions)
4.  [Application 1: Utility Maximization](#4.-Application-1:-Utility-Maximization)
5.  [Application 2: Maximum Likelihood Estimation](#5.-Application-2:-Maximum-Likelihood-Estimation)
6.  [Chapter Summary](#6.-Chapter-Summary)
7.  [Exercises](#7.-Exercises)

### Introduction: The Three Ways to Differentiate Code
Differentiation is a cornerstone of economic analysis. When we have a function implemented in code, we face a choice between three fundamental approaches to compute its derivative:

1.  **Symbolic Differentiation**: The "pen-and-paper" method automated by computer algebra systems (like SymPy). It is exact but can be intractable for complex functions due to "expression swell."

2.  **Numerical Differentiation**: Approximating the derivative using function evaluations at discrete points (e.g., finite differences). It is easy to implement, but it only approximates the derivative and forces a difficult trade-off between truncation error and rounding error.

3.  **Automatic Differentiation (AD)**: A powerful technique that computes exact derivatives (to machine precision) by breaking down a function into elementary operations and applying the chain rule. It is the engine of modern machine learning and computational economics.

This notebook provides a deep dive into these methods, with a particular focus on demystifying and applying Automatic Differentiation, which has become the preferred method for most modern applications due to its accuracy and efficiency.
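To make the first two approaches concrete, here is a small sketch that differentiates an illustrative function, $f(x) = e^{\sin x}\ln x$ (our choice, not from the text), symbolically with SymPy and numerically with a central difference. The step `h=1e-5` is an assumed value. AD is developed in Section 2.

```python
import sympy as sp

x = sp.symbols("x")
f_expr = sp.exp(sp.sin(x)) * sp.log(x)   # illustrative test function

# 1. Symbolic: exact derivative expression, then compiled to a numeric function
df_expr = sp.diff(f_expr, x)
df_symbolic = sp.lambdify(x, df_expr, modules="math")

# 2. Numerical: central finite difference applied to a numeric version of f
f_num = sp.lambdify(x, f_expr, modules="math")

def central_diff(f, x0, h=1e-5):
    """Central difference approximation of f'(x0)."""
    return (f(x0 + h) - f(x0 - h)) / (2 * h)

x0 = 2.0
exact = df_symbolic(x0)
approx = central_diff(f_num, x0)
print("symbolic derivative:", df_expr)
print(f"symbolic value:       {exact:.12f}")
print(f"central difference:   {approx:.12f}  (abs error {abs(approx - exact):.2e})")
```

The symbolic result is exact but requires manipulating an expression tree; the finite difference only needs function values but carries a small approximation error.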

### 1. Numerical Differentiation via Finite Differences

Numerical differentiation approximates derivatives by evaluating a function at nearby points. The basis for these methods is the Taylor series expansion of a function $f(x)$ around a point $x_0$:
$$ f(x_0+h) = f(x_0) + hf'(x_0) + \frac{h^2}{2!}f''(x_0) + \frac{h^3}{3!}f'''(x_0) + \dots $$

#### 1.1 Taylor Series Derivations

By rearranging the Taylor series and truncating higher-order terms, we can derive the common finite difference formulas.

- **Forward Difference:** Solve the expansion for $f'(x_0)$:
  $$ f'(x_0) = \frac{f(x_0+h) - f(x_0)}{h} - \frac{h}{2}f''(x_0) - \dots = \frac{f(x_0+h) - f(x_0)}{h} + O(h) $$
  This formula is simple but has a truncation error of order $O(h)$, making it relatively inaccurate.

- **Central Difference:** Subtract the Taylor expansion for $f(x_0-h)$ from $f(x_0+h)$:
  $$ f(x_0+h) - f(x_0-h) = 2hf'(x_0) + \frac{2h^3}{6}f'''(x_0) + \dots $$
  $$ f'(x_0) = \frac{f(x_0+h) - f(x_0-h)}{2h} - \frac{h^2}{6}f'''(x_0) - \dots = \frac{f(x_0+h) - f(x_0-h)}{2h} + O(h^2) $$
  By canceling the even-powered terms, the central difference formula achieves a much smaller truncation error of order $O(h^2)$, making it significantly more accurate.
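The orders $O(h)$ and $O(h^2)$ can be checked empirically. The sketch below assumes $f(x) = \sin x$ at $x_0 = 1$ as a test case and halves `h` repeatedly: the forward-difference error should shrink by roughly a factor of 2 per halving and the central-difference error by roughly a factor of 4. The helper names `forward_diff` and `central_diff` are our own.

```python
import numpy as np

def forward_diff(f, x0, h):
    """Forward difference: truncation error O(h)."""
    return (f(x0 + h) - f(x0)) / h

def central_diff(f, x0, h):
    """Central difference: truncation error O(h^2)."""
    return (f(x0 + h) - f(x0 - h)) / (2 * h)

x0 = 1.0
exact = np.cos(x0)                      # d/dx sin(x) = cos(x)
steps = [0.1, 0.05, 0.025, 0.0125]
err_fwd = [abs(forward_diff(np.sin, x0, h) - exact) for h in steps]
err_cen = [abs(central_diff(np.sin, x0, h) - exact) for h in steps]

for h, ef, ec in zip(steps, err_fwd, err_cen):
    print(f"h = {h:<7} forward error = {ef:.3e}   central error = {ec:.3e}")

# Ratio of successive errors when h is halved: about 2 for O(h), about 4 for O(h^2)
ratios_fwd = [err_fwd[i] / err_fwd[i + 1] for i in range(len(steps) - 1)]
ratios_cen = [err_cen[i] / err_cen[i + 1] for i in range(len(steps) - 1)]
print("forward ratios:", np.round(ratios_fwd, 2))
print("central ratios:", np.round(ratios_cen, 2))
```

The step sizes are kept large enough that rounding error is negligible, so only truncation error shows up in the ratios.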

#### 1.2 The Error Trade-off

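The central challenge is choosing the step size `h`. A smaller `h` reduces the truncation error (the $O(h^p)$ term), but when `h` is too small the numerator subtracts two nearly equal numbers, and **catastrophic cancellation** (a form of rounding error) magnifies the tiny relative errors in each function value by a factor of roughly $1/h$. The total error therefore traces a characteristic U-shaped curve, and the optimal `h` balances these two opposing forces.

This trade-off is fundamental. The truncation error is a property of the mathematical formula (the approximation), while the rounding error is a property of finite-precision floating-point arithmetic. With machine epsilon $\epsilon_{\text{mach}}$, the forward-difference error behaves like $E(h) \approx \frac{h}{2}|f''(x_0)| + \frac{2\epsilon_{\text{mach}}|f(x_0)|}{h}$, which is minimized at $h^* \propto \sqrt{\epsilon_{\text{mach}}} \approx 1.5 \times 10^{-8}$ in double precision. For the central difference, $E(h) \approx \frac{h^2}{6}|f'''(x_0)| + \frac{\epsilon_{\text{mach}}|f(x_0)|}{h}$, which is minimized at $h^* \propto \epsilon_{\text{mach}}^{1/3} \approx 6 \times 10^{-6}$. These rules of thumb assume that $x_0$ and the derivatives of $f$ are of order one.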

![Error vs. Step Size for Finite Differences](../images/02-Numerical-Methods/finite_difference_error.png)
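The U-shaped curve can be reproduced with a short sweep over `h`. This is a minimal sketch that assumes $f(x) = e^x$ at $x_0 = 1$ (so the exact derivative is known) and a log-spaced grid of steps. The empirical minimizers should land near $\sqrt{\epsilon_{\text{mach}}}$ and $\epsilon_{\text{mach}}^{1/3}$, though the exact location of each minimum is noisy because rounding errors fluctuate from one `h` to the next.

```python
import numpy as np

def forward_diff(f, x0, h):
    return (f(x0 + h) - f(x0)) / h

def central_diff(f, x0, h):
    return (f(x0 + h) - f(x0 - h)) / (2 * h)

x0 = 1.0
exact = np.exp(x0)                    # f = exp, so f'(x0) = exp(x0)
hs = np.logspace(-15, -1, 141)        # step sizes from 1e-15 to 1e-1

err_fwd = np.abs(np.array([forward_diff(np.exp, x0, h) for h in hs]) - exact)
err_cen = np.abs(np.array([central_diff(np.exp, x0, h) for h in hs]) - exact)

eps = np.finfo(np.float64).eps
h_best_fwd = hs[np.argmin(err_fwd)]
h_best_cen = hs[np.argmin(err_cen)]
print(f"forward: empirical best h = {h_best_fwd:.1e}, sqrt(eps) = {np.sqrt(eps):.1e}, "
      f"min error = {err_fwd.min():.1e}")
print(f"central: empirical best h = {h_best_cen:.1e}, cbrt(eps) = {np.cbrt(eps):.1e}, "
      f"min error = {err_cen.min():.1e}")
```

Passing `hs`, `err_fwd`, and `err_cen` to `plt.loglog` draws a plot of the same kind as the figure above.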

### 2. Automatic Differentiation (AD)
AD was first developed in the 1950s and 1960s, but it was the rise of deep learning and of frameworks like Theano, TensorFlow, and PyTorch that brought it to the forefront of scientific computing. The backpropagation algorithm, a specific instance of reverse-mode AD, is the engine that trains modern neural networks.

**Automatic Differentiation (AD)** is a set of techniques that computes exact derivatives (to machine precision) without choosing a step size and without building a symbolic expression for the derivative. For a scalar-valued function, reverse-mode AD returns the full gradient at a cost that is a small constant multiple of the cost of evaluating the function, regardless of the number of inputs. AD works by decomposing a function into a **computational graph** of elementary operations (e.g., `+`, `*`, `sin`, `exp`) whose derivatives are known. It then repeatedly applies the chain rule to accumulate the derivative of the final output with respect to the inputs.

#### 2.1 The Computational Graph
Consider the function $y = f(x_1, x_2) = \ln(x_1) + x_1 x_2 - \sin(x_2)$. AD first breaks this down into a sequence of primitive operations, forming a graph.

![Computational Graph for a Simple Function](../images/02-Numerical-Methods/computational_graph.png)
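The graph corresponds to an evaluation trace: a straight-line sequence of primitive operations, each producing one intermediate node. The sketch below writes that trace out in plain Python. The labels `v1` to `v5` are our own and may not match the node labels in the figure, and the input point $(2, 5)$ is arbitrary.

```python
import math

def evaluation_trace(x1, x2):
    """Evaluate y = ln(x1) + x1*x2 - sin(x2) one primitive operation at a time."""
    v1 = math.log(x1)     # v1 = ln(x1)
    v2 = x1 * x2          # v2 = x1 * x2
    v3 = math.sin(x2)     # v3 = sin(x2)
    v4 = v1 + v2          # v4 = v1 + v2
    v5 = v4 - v3          # y  = v4 - v3
    return {"v1": v1, "v2": v2, "v3": v3, "v4": v4, "y": v5}

trace = evaluation_trace(2.0, 5.0)
for name, value in trace.items():
    print(f"{name:>2} = {value:.6f}")
```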

#### 2.2 Forward-Mode vs. Reverse-Mode AD

- **Forward-Mode AD** traverses the graph from inputs to outputs. It computes the derivative of each intermediate node with respect to *one input* at a time. This is done by carrying a "dual number" $(v, \dot{v})$ where $v$ is the value and $\dot{v} = \frac{\partial v}{\partial x_i}$ is the derivative. To get the full gradient, this requires one pass per input variable.
  - **Cost:** For $f: \mathbb{R}^n \to \mathbb{R}^m$, computing the full Jacobian costs $O(n) \cdot \text{cost}(f)$. Efficient for tall, skinny Jacobians ($n \ll m$).

- **Reverse-Mode AD (Backpropagation)** traverses the graph in reverse, from outputs to inputs. It first performs a forward pass to compute and store the value of every node. Then it starts from the final output, seeded with $\bar{y} = \frac{\partial y}{\partial y} = 1$, and propagates derivatives backward, computing the **adjoint** $\bar{v} = \frac{\partial y}{\partial v}$ of each intermediate node $v$. For a scalar output, a single backward pass yields the entire gradient.
  - **Cost:** For $f: \mathbb{R}^n \to \mathbb{R}^m$, computing the full Jacobian costs $O(m) \cdot \text{cost}(f)$, plus the memory needed to store the intermediate values from the forward pass. Efficient for short, fat Jacobians ($n \gg m$).
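Forward mode can be implemented in a few lines with dual numbers. This is a minimal, illustrative sketch: the `Dual` class and the helpers `d_log` and `d_sin` are hypothetical names and support only the operations our example needs. It propagates $(v, \dot{v})$ pairs through $y = \ln(x_1) + x_1 x_2 - \sin(x_2)$ and, as described above, needs one pass per input to assemble the gradient.

```python
import math
from dataclasses import dataclass


@dataclass
class Dual:
    """A dual number (val, dot): a value and its derivative w.r.t. one chosen input."""
    val: float
    dot: float

    def __add__(self, other):
        return Dual(self.val + other.val, self.dot + other.dot)

    def __sub__(self, other):
        return Dual(self.val - other.val, self.dot - other.dot)

    def __mul__(self, other):
        # Product rule: (uv)' = u'v + uv'
        return Dual(self.val * other.val, self.dot * other.val + self.val * other.dot)


def d_log(u):
    # Chain rule: (ln u)' = u' / u
    return Dual(math.log(u.val), u.dot / u.val)


def d_sin(u):
    # Chain rule: (sin u)' = cos(u) * u'
    return Dual(math.sin(u.val), math.cos(u.val) * u.dot)


def f(x1, x2):
    # y = ln(x1) + x1*x2 - sin(x2), built from dual-number primitives
    return d_log(x1) + x1 * x2 - d_sin(x2)


def forward_mode_gradient(x1_val, x2_val):
    # One forward pass per input: seed dot = 1 on the input being differentiated
    dy_dx1 = f(Dual(x1_val, 1.0), Dual(x2_val, 0.0)).dot
    dy_dx2 = f(Dual(x1_val, 0.0), Dual(x2_val, 1.0)).dot
    return dy_dx1, dy_dx2


g = forward_mode_gradient(2.0, 5.0)
print("forward-mode gradient:", g)
```

The results can be checked against the analytic partials $\partial y / \partial x_1 = 1/x_1 + x_2$ and $\partial y / \partial x_2 = x_1 - \cos(x_2)$.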

**Rule of Thumb:** For the typical case in economics and machine learning—functions with many inputs (parameters) and a single scalar output (e.g., a loss function, a likelihood, a utility value), where $n \gg m=1$—**reverse-mode AD is vastly more efficient**.
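Reverse mode on the same example function can be written out by hand. The sketch below (the function name `reverse_mode_gradient` is hypothetical) stores the intermediates during a forward pass, then sweeps backward through the adjoints $\bar{v} = \partial y / \partial v$, summing contributions wherever a variable feeds more than one node. A single backward pass returns both partial derivatives, which is the property that makes reverse mode attractive when $n \gg m = 1$.

```python
import math


def reverse_mode_gradient(x1, x2):
    """Value and gradient of y = ln(x1) + x1*x2 - sin(x2) via a hand-written reverse sweep."""
    # Forward pass: evaluate and store every intermediate node
    v1 = math.log(x1)
    v2 = x1 * x2
    v3 = math.sin(x2)
    v4 = v1 + v2
    y = v4 - v3

    # Backward pass: adjoints bar_v = dy/dv, seeded with bar_y = 1
    y_bar = 1.0
    v4_bar = y_bar * 1.0           # y = v4 - v3  =>  dy/dv4 = 1
    v3_bar = y_bar * -1.0          #              =>  dy/dv3 = -1
    v1_bar = v4_bar * 1.0          # v4 = v1 + v2
    v2_bar = v4_bar * 1.0
    # x1 feeds v1 and v2; x2 feeds v2 and v3: sum the contributions
    x1_bar = v1_bar * (1.0 / x1) + v2_bar * x2
    x2_bar = v2_bar * x1 + v3_bar * math.cos(x2)
    return y, (x1_bar, x2_bar)


y, (dy_dx1, dy_dx2) = reverse_mode_gradient(2.0, 5.0)
print(f"y = {y:.6f}, dy/dx1 = {dy_dx1:.6f}, dy/dx2 = {dy_dx2:.6f}")
```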

### 3. The JAX Framework

#### 3.1 Core Transformations: `grad`, `jit`, `vmap`
JAX provides a set of function transformations that are the key to its power.

- `grad(fun)`: Transforms a Python function `fun` that returns a scalar into a function that computes its gradient using reverse-mode AD.
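- `jit(fun)`: **Just-in-time compilation.** Traces a Python function and compiles it with XLA into optimized machine code. The first call pays the compilation cost; later calls with the same input shapes and dtypes reuse the compiled version, which can be dramatically faster because XLA fuses many small array operations together.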
- `vmap(fun)`: **Vectorizing map.** Transforms a function to map over batch dimensions automatically, without requiring manual loops. This is extremely powerful for batch processing and data parallelism.
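As a quick check of `grad` and `jit`, the sketch below differentiates the example function from Section 2.1 and compares the result with its analytic gradient. Composing `jax.jit(jax.grad(f))` is a common pattern; the test point $x = (2, 5)$ is an arbitrary choice.

```python
import jax
import jax.numpy as jnp


def f(x):
    """y = ln(x1) + x1*x2 - sin(x2) for x = (x1, x2); returns a scalar."""
    return jnp.log(x[0]) + x[0] * x[1] - jnp.sin(x[1])


# Compose transformations: reverse-mode gradient, then JIT compilation
grad_f = jax.jit(jax.grad(f))

x = jnp.array([2.0, 5.0])
g = grad_f(x)
g_analytic = jnp.array([1.0 / x[0] + x[1], x[0] - jnp.cos(x[1])])
print("JAX gradient:     ", g)
print("Analytic gradient:", g_analytic)
```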

In [None]:
sec("Vectorization with vmap")
if not JAX_AVAILABLE:
    note("JAX is not installed. Skipping this section.")
else:
    # A function that computes a weighted norm for a single vector
    def weighted_norm(weights, vector):
        return jnp.sqrt(jnp.sum(weights * vector**2))

    # Sample data: one set of weights, 5 different vectors
    key = jax.random.PRNGKey(0)
    weights = jnp.array([1., 2., 0.5])
    vectors = jax.random.normal(key, (5, 3))

    # Use vmap to apply the function to each vector in the batch
    # in_axes=(None, 0) means: don't map over 'weights', map over the first axis of 'vectors'
    batched_norm = vmap(weighted_norm, in_axes=(None, 0))
    results = batched_norm(weights, vectors)
    
    note("vmap allows us to process a batch of 5 vectors without writing a loop:")
    print(f"Input vectors shape: {vectors.shape}")
    print(f"Output results shape: {results.shape}")
    print(f"Results: {results}")

#### 3.2 The Constraint of Pure Functions
JAX's transformations (`jit`, `grad`, `vmap`, etc.) require that the functions they operate on be **pure**. A pure function has two properties:
1.  It has no side effects (e.g., it doesn't modify global state, print to the screen, or write to a file).
2.  Its output depends only on its inputs.

This functional programming model allows JAX to safely trace, transform, and compile the code. Applying `jit` or `grad` to an impure function can raise errors or, worse, silently produce wrong results.

For example, a function that modifies a global variable or prints to the screen is not pure. `jit` traces the function once for each new combination of input shapes and dtypes and then reuses the compiled code: side effects such as `print` run only during tracing, and any global values read inside the function are frozen at their trace-time values. This is why a functional programming style, in which everything a function needs is passed in as an argument, is encouraged when working with JAX.
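The following sketch illustrates both side-effect problems. The names `scaled_sum`, `scaled_sum_pure`, and `scale` are illustrative. The `print` fires only on the first call, and the compiled function keeps using the value of `scale` that was current when it was traced.

```python
import jax
import jax.numpy as jnp

scale = 2.0  # global state read inside the function: this makes it impure


@jax.jit
def scaled_sum(x):
    print("tracing scaled_sum")       # side effect: runs only while tracing
    return scale * jnp.sum(x)


@jax.jit
def scaled_sum_pure(x, scale):
    # Pure version: everything the function depends on is an argument
    return scale * jnp.sum(x)


x = jnp.ones(3)
first = scaled_sum(x)                 # traces (prints once) and bakes in scale = 2.0
scale = 10.0                          # mutate the global after tracing
second = scaled_sum(x)                # cached compiled code: still uses 2.0, no print
third = scaled_sum_pure(x, scale)     # uses the current value, 10.0
print(first, second, third)
```

Passing `scale` as an argument, as in `scaled_sum_pure`, makes the dependence explicit, so a new value is used on every call without retracing.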

### 4. Application 1: Utility Maximization
A classic economic problem is for a consumer to maximize utility subject to a budget constraint. We can set up the Lagrangian, use JAX to compute its gradient (the first-order conditions) and its Hessian (the Jacobian of those conditions), and solve the system of first-order conditions with Newton's method.

**Problem:** Maximize $U(c_1, c_2) = (c_1^\rho + c_2^\rho)^{1/\rho}$, with $\rho < 1$ and $\rho \neq 0$, subject to $p_1 c_1 + p_2 c_2 = M$.
The Lagrangian is: $\mathcal{L}(c_1, c_2, \lambda) = (c_1^\rho + c_2^\rho)^{1/\rho} - \lambda(p_1 c_1 + p_2 c_2 - M)$.
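Because CES utility is homogeneous of degree one, this problem also has a closed-form solution that can serve as a benchmark for the numerical result. A sketch, assuming $\rho < 1$ and $\rho \neq 0$: the first-order conditions imply $c_i \propto p_i^{1/(\rho-1)}$, the budget constraint pins down the scale, giving $c_i = p_i^{1/(\rho-1)} M / (p_1^{r} + p_2^{r})$ with $r = \rho/(\rho-1)$, and Euler's theorem gives $\lambda = U(c^*)/M$. The helper `ces_demand` is a hypothetical name.

```python
import numpy as np


def ces_demand(p1, p2, M, rho):
    """Closed-form Marshallian demand and multiplier for U = (c1^rho + c2^rho)^(1/rho)."""
    r = rho / (rho - 1.0)                  # exponent on prices in the denominator
    denom = p1**r + p2**r
    c1 = p1 ** (1.0 / (rho - 1.0)) * M / denom
    c2 = p2 ** (1.0 / (rho - 1.0)) * M / denom
    U = (c1**rho + c2**rho) ** (1.0 / rho)
    lam = U / M                            # U is homogeneous of degree 1, so lambda = U / M
    return c1, c2, lam


c1, c2, lam = ces_demand(p1=2.0, p2=3.0, M=100.0, rho=0.8)
print(f"c1 = {c1:.4f}, c2 = {c2:.4f}, lambda = {lam:.4f}")
```

If the Newton iteration in the next cell converges, its $(c_1, c_2, \lambda)$ should agree with these values to within the solver tolerance.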

In [None]:
sec("Utility Maximization with JAX and Newton's Method")
if not JAX_AVAILABLE:
    note("JAX is not installed. Skipping this section.")
else:
    def lagrangian(x, p1, p2, M, rho):
        c1, c2, lam = x[0], x[1], x[2]
        return (c1**rho + c2**rho)**(1/rho) - lam * (p1*c1 + p2*c2 - M)

    params = {'p1': 2.0, 'p2': 3.0, 'M': 100.0, 'rho': 0.8}
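    L = lambda x: lagrangian(x, **params)
    FOCs = jit(grad(L))               # gradient of L: the first-order conditions
    Jacobian_FOCs = jit(hessian(L))   # Jacobian of the FOCs = Hessian of L

    x_k = jnp.array([10.0, 10.0, 0.1])  # initial guess for (c1, c2, λ)
    for i in range(20):
        # Newton step: solve H(x_k) · step = -∇L(x_k)
        step = jnp.linalg.solve(Jacobian_FOCs(x_k), -FOCs(x_k))
        x_k = x_k + step
        # Relative tolerance, so the test is attainable in float32 as well as float64
        if jnp.linalg.norm(step) < 1e-6 * (1.0 + jnp.linalg.norm(x_k)):
            print(f"Converged in {i+1} iterations.")
            break
    else:
        print("Newton's method did not converge within 20 iterations.")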
            
    note(f"Optimal bundle: c1 = {x_k[0]:.2f}, c2 = {x_k[1]:.2f}, shadow price λ = {x_k[2]:.3f}")

### 5. Application 2: Maximum Likelihood Estimation
Maximum Likelihood Estimation (MLE) is a cornerstone of econometrics. The goal is to find the parameters $\theta$ that maximize the likelihood of observing a given dataset. This is equivalent to maximizing the log-likelihood function, $LL(\theta) = \sum_i \log f(y_i | \theta)$, where $f$ is the probability density function.

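We can use JAX to find the MLE parameters for a sample drawn from a normal distribution. We estimate $\theta = (\mu, \log\sigma)$, since optimizing over $\log\sigma$ keeps $\sigma = e^{\log\sigma}$ positive without imposing constraints. The gradient of the log-likelihood is the **score**, which is zero at the MLE. The inverse of the negative Hessian of the log-likelihood (the observed information), evaluated at the MLE, estimates the asymptotic **variance-covariance matrix** of the estimates; the square roots of its diagonal are the standard errors.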
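For the normal model the MLE is available in closed form, which gives a useful benchmark for the JAX-plus-SciPy estimates. The sketch below draws its own NumPy sample (the generator and seed are assumptions, so the numbers will differ from the JAX cell), computes $\hat\mu = \bar{y}$ and $\hat\sigma^2 = \frac{1}{n}\sum_i (y_i - \hat\mu)^2$, and uses the analytic observed information in $(\mu, \log\sigma)$, which gives standard errors $\hat\sigma/\sqrt{n}$ and $1/\sqrt{2n}$. A central second difference of the log-likelihood should reproduce the curvature $-2n$ in the $\log\sigma$ direction.

```python
import numpy as np

rng = np.random.default_rng(0)             # illustrative sample, not the JAX data
true_mu, true_sigma, n = 5.0, 2.0, 1000
y = rng.normal(true_mu, true_sigma, size=n)

# Closed-form MLE for a normal sample (divides by n, not n - 1)
mu_hat = y.mean()
sigma_hat = np.sqrt(np.mean((y - mu_hat) ** 2))

# Observed information in (mu, log sigma) at the MLE:
# -d2LL/dmu2 = n / sigma^2, -d2LL/d(log sigma)^2 = 2n, cross term = 0
se_mu = sigma_hat / np.sqrt(n)
se_log_sigma = 1.0 / np.sqrt(2 * n)


def normal_log_likelihood(mu, log_sigma, y):
    sigma = np.exp(log_sigma)
    return np.sum(-0.5 * np.log(2 * np.pi) - log_sigma - 0.5 * ((y - mu) / sigma) ** 2)


# Numerical check of the log-sigma curvature with a central second difference
h = 1e-3
s = np.log(sigma_hat)
d2_ss = (normal_log_likelihood(mu_hat, s + h, y)
         - 2 * normal_log_likelihood(mu_hat, s, y)
         + normal_log_likelihood(mu_hat, s - h, y)) / h**2

print(f"mu_hat = {mu_hat:.3f} (SE {se_mu:.3f}), sigma_hat = {sigma_hat:.3f} (SE of log sigma {se_log_sigma:.4f})")
print(f"numerical d2LL/d(log sigma)^2 = {d2_ss:.3f}, analytic = {-2 * n}")
```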

In [None]:
sec("Maximum Likelihood Estimation of a Normal Distribution")
if not JAX_AVAILABLE:
    note("JAX is not installed. Skipping this section.")
else:
    # 1. Generate sample data
    key = jax.random.PRNGKey(123)
    true_mu, true_sigma = 5.0, 2.0
    data = jax.random.normal(key, (1000,)) * true_sigma + true_mu

    # 2. Define the log-likelihood function
    def log_likelihood(params, y):
        mu, sigma = params[0], jnp.exp(params[1]) # Use log(sigma) for unconstrained optimization
        return jnp.sum(norm.logpdf(y, loc=mu, scale=sigma))

    # 3. Use JAX to get the gradient and hessian
    # We want to maximize, so we minimize the negative log-likelihood
    neg_LL = jit(lambda p, y: -log_likelihood(p, y))
    score_fn = jit(grad(neg_LL, argnums=0))
    hessian_fn = jit(hessian(neg_LL, argnums=0))

    # 4. Solve using a simple optimizer (e.g., BFGS from scipy)
    from scipy.optimize import minimize
    
    # SciPy's minimize expects a float objective and a float64 gradient array
    def objective(params, y): return float(neg_LL(params, y))
    def jacobian(params, y): return np.asarray(score_fn(params, y), dtype=np.float64)
    
    initial_params = np.array([0.0, 0.0]) # Initial guess for (mu, log(sigma))
    result = minimize(objective, initial_params, args=(data,), jac=jacobian, method='BFGS')
    mle_params = result.x
    mle_mu, mle_sigma = mle_params[0], np.exp(mle_params[1])

    # 5. Compute standard errors from the Hessian
    H = hessian_fn(mle_params, data)
    cov_matrix = jnp.linalg.inv(H)
    std_errors = jnp.sqrt(jnp.diag(cov_matrix))

    note(f"True parameters:      μ = {true_mu:.3f}, σ = {true_sigma:.3f}")
    note(f"MLE estimates:        μ̂ = {mle_mu:.3f}, σ̂ = {mle_sigma:.3f}")
    note(f"Standard Errors (μ, log σ): {std_errors[0]:.3f}, {std_errors[1]:.3f}")

### 6. Chapter Summary

- **Three Methods:** Differentiation can be done **symbolically** (exact, but prone to expression swell), **numerically** (simple, but only approximate), or **automatically** (exact to machine precision and efficient).
- **Numerical Differentiation:** Relies on Taylor series approximations (finite differences). Its accuracy is limited by a fundamental trade-off between **truncation error** (from the formula) and **rounding error** (from machine precision).
- **Automatic Differentiation (AD):** The modern standard. It achieves machine-precision derivatives by applying the chain rule to a computational graph. **Reverse-mode AD** is exceptionally efficient for the common many-to-one functions found in economics and machine learning.
- **JAX:** A powerful framework that combines a NumPy-like API with AD and JIT compilation. Its core transformations (`grad`, `hessian`, `jit`, `vmap`) enable high-performance computational science, but require writing **pure functions**.

### 7. Exercises

1.  **Richardson Extrapolation for the Second Derivative:** The central difference formula for the *second* derivative is $D_2(h) = \frac{f(x+h) - 2f(x) + f(x-h)}{h^2}$, which has $O(h^2)$ truncation error. Derive and implement the Richardson extrapolation for this formula to create an $O(h^4)$ accurate approximation. Test it on $f(x) = e^x$ at $x=1$.

2.  **Symbolic vs. Automatic:** For the utility maximization problem, use `SymPy` to derive the analytical expressions for the gradient and the Hessian of the Lagrangian. Use `sp.lambdify` to turn these into numerical functions and verify that their output matches the results from JAX.

3.  **Forward vs. Reverse Mode Jacobian:** In JAX, `jacfwd` computes Jacobians using forward-mode AD and `jacrev` uses reverse-mode AD. For a function $f: \mathbb{R}^{1000} \to \mathbb{R}^2$, which do you expect to be faster? For a function $g: \mathbb{R}^2 \to \mathbb{R}^{1000}$? Write two simple functions with these signatures and use `%timeit` to verify your hypothesis.

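4.  **Batch MLE with `vmap`:** Extend the MLE example. Suppose you have data from 10 different, independent experiments, so your data array now has shape `(10, 1000)`. Use `jax.vmap` to compute the MLE estimates for all 10 experiments in a single, efficient batch operation. Because `scipy.optimize.minimize` cannot be transformed by `vmap`, write the estimation step in pure JAX (for example, a fixed number of Newton steps built from `grad` and `hessian`) and map it over per-experiment parameters and data with `in_axes=(0, 0)`. Note that `in_axes=(None, 0)` would instead evaluate a single shared parameter vector on every experiment's data.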

5.  **Hessian of Utility:** For a CRRA utility function `u(c) = (c**(1-gamma))/(1-gamma)` with `gamma != 1`, the second derivative `u''(c)` enters the Arrow-Pratt coefficient of absolute risk aversion, $A(c) = -u''(c)/u'(c)$. Use JAX's `grad` twice (`grad(grad(u))`) to compute the second derivative. Evaluate it for different values of `gamma` and `c` and interpret the sign.