# File: notebooks/01_fundamentals/02_autodiff_basics.ipynb

## JAX Fundamentals: Automatic Differentiation Basics

Welcome to the second notebook in the JAX-NSL series! Automatic differentiation (autodiff) is one of JAX's most powerful features, enabling efficient gradient computation for machine learning and scientific computing. In this notebook, we'll explore JAX's autodiff capabilities from basic gradients to more advanced concepts like Jacobians and Hessians.

Automatic differentiation in JAX is implemented as a program transformation: JAX takes a Python function and automatically generates a new function that computes its derivative. This capability is fundamental to optimization algorithms, neural network training, and any scientific computing application that requires gradients.

## Setting Up the Environment

```python
import jax
import jax.numpy as jnp
from jax import grad, jacfwd, jacrev, hessian, vmap
from jax import random
import numpy as np
import matplotlib.pyplot as plt
from typing import Callable, Tuple
import functools

# Turn on the `jax_enable_x64` flag so the examples below can use 64-bit
# floating point for better numerical accuracy.
jax.config.update("jax_enable_x64", True)

# Report the installed JAX version so that printed results can be reproduced.
print(f"JAX version: {jax.__version__}")
```

## Basic Gradient Computation

### Single-Variable Functions

```python
# Define simple functions for gradient computation
def f1(x):
    """Return the square of ``x``, i.e. f(x) = x^2."""
    squared = x**2
    return squared

def f2(x):
    """Return f(x) = sin(x) + cos(x)."""
    sine_term = jnp.sin(x)
    cosine_term = jnp.cos(x)
    return sine_term + cosine_term

def f3(x):
    """Return f(x) = exp(x) * log(x + 1)."""
    growth = jnp.exp(x)
    log_term = jnp.log(x + 1)
    return growth * log_term

# Build the derivative functions with jax.grad.
# Expected closed forms: f1' = 2x and f2' = cos(x) - sin(x); f3' is the
# product-rule derivative of exp(x) * log(x + 1).
grad_f1, grad_f2, grad_f3 = grad(f1), grad(f2), grad(f3)

# Points at which each function and its derivative are evaluated
x_vals = jnp.array([1.0, 2.0, 3.0])

for x in x_vals:
    print(f"x = {x}")
    # One output line per (label, function, derivative) triple.
    for label, fn, dfn in (("f1", f1, grad_f1),
                           ("f2", f2, grad_f2),
                           ("f3", f3, grad_f3)):
        print(f"  {label}(x) = {fn(x):.4f}, {label}'(x) = {dfn(x):.4f}")
    print()
```

### Multi-Variable Functions

```python
# Multi-variable function examples
def g1(x):
    """Sum of squares: f(x, y) = x^2 + y^2, with the inputs packed in vector ``x``."""
    squares = x**2
    return jnp.sum(squares)

def g2(x):
    """Return f(x, y) = x*y + sin(x) * cos(y) for the vector input ``x = [x, y]``."""
    first, second = x[0], x[1]
    product_term = first * second
    trig_term = jnp.sin(first) * jnp.cos(second)
    return product_term + trig_term

def g3(x):
    """Rosenbrock function f(x, y) = (1-x)^2 + 100*(y-x^2)^2 for ``x = [x, y]``."""
    first, second = x[0], x[1]
    offset_term = (1 - first)**2
    valley_term = 100 * (second - first**2)**2
    return offset_term + valley_term

# Gradient functions for the three multi-variable examples
grad_g1, grad_g2, grad_g3 = grad(g1), grad(g2), grad(g3)

# Points (as length-2 vectors) at which the gradients are evaluated
test_points = [jnp.array(coords)
               for coords in ([1.0, 1.0], [0.5, 2.0], [2.0, 1.5])]

for i, point in enumerate(test_points):
    print(f"Point {i+1}: {point}")
    # One output line per (label, gradient function) pair.
    for label, gradient_fn in (("g1", grad_g1),
                               ("g2", grad_g2),
                               ("g3", grad_g3)):
        print(f"  {label} gradient: {gradient_fn(point)}")
    print()
```

## Jacobians: Handling Vector-Valued Functions

### Forward vs Reverse Mode

```python
# Vector-valued function example
def vector_function(x):
    """Map a 2-vector ``x = [x, y]`` to a 3-vector (R^2 -> R^3)."""
    first, second = x[0], x[1]
    component_1 = first**2 + second                  # x^2 + y
    component_2 = first * second                     # xy
    component_3 = jnp.sin(first) + jnp.cos(second)   # sin(x) + cos(y)
    return jnp.array([component_1, component_2, component_3])

# Forward-mode Jacobian (efficient for tall Jacobians, i.e. functions with
# more outputs than inputs -- as here, where vector_function maps R^2 -> R^3)
jac_forward = jacfwd(vector_function)

# Reverse-mode Jacobian (efficient for wide Jacobians, i.e. functions with
# more inputs than outputs)
jac_reverse = jacrev(vector_function)

# Evaluate both Jacobian functions at the same input point.
x = jnp.array([1.0, 2.0])
jac_fwd_result = jac_forward(x)
jac_rev_result = jac_reverse(x)

# Both modes compute the same matrix of partial derivatives, so the final
# check is expected to print True (up to floating-point tolerance).
print(f"Input point: {x}")
print(f"Function output: {vector_function(x)}")
print(f"Forward-mode Jacobian:\n{jac_fwd_result}")
print(f"Reverse-mode Jacobian:\n{jac_rev_result}")
print(f"Results match: {jnp.allclose(jac_fwd_result, jac_rev_result)}")
```

### When to Use Forward vs Reverse Mode

```python
# Efficiency demonstration
def tall_jacobian_function(x):
    """Map R^2 -> R^10 (more outputs than inputs).

    Output ``i`` (for i = 0..9) is ``x[0]**i + x[1]**(i + 1)``.
    """
    outputs = []
    for power in range(10):
        outputs.append(x[0]**power + x[1]**(power + 1))
    return jnp.array(outputs)

def wide_jacobian_function(x):
    """Map R^10 -> R^2 (more inputs than outputs).

    Returns the sum of squares of all entries and the product of the
    first five entries.
    """
    sum_of_squares = jnp.sum(x**2)
    leading_product = jnp.prod(x[:5])
    return jnp.array([sum_of_squares, leading_product])

# For tall Jacobians, forward mode is more efficient.
# tall_jacobian_function has 2 inputs and 10 outputs.
x_small = jnp.array([1.0, 2.0])
jac_tall = jacfwd(tall_jacobian_function)(x_small)
print(f"Tall Jacobian shape: {jac_tall.shape}")

# For wide Jacobians, reverse mode is more efficient.
# wide_jacobian_function has 10 inputs and 2 outputs.
x_large = jnp.ones(10)
jac_wide = jacrev(wide_jacobian_function)(x_large)
print(f"Wide Jacobian shape: {jac_wide.shape}")
```

## Hessians: Second-Order Derivatives

### Computing Hessian Matrices

```python
# Functions for Hessian computation
def quadratic_form(x):
    """Evaluate f(x) = x^T A x + b^T x + c for fixed A, b and c.

    A is the symmetric positive definite matrix [[2, 1], [1, 3]],
    b = [1, -1] and c = 5, so ``x`` must be a 2-vector.
    """
    A = jnp.array([[2, 1], [1, 3]])  # Positive definite matrix
    b = jnp.array([1, -1])
    c = 5
    quadratic_term = jnp.dot(x, jnp.dot(A, x))
    linear_term = jnp.dot(b, x)
    return quadratic_term + linear_term + c

def rosenbrock_2d(x):
    """Standard 2-D Rosenbrock function (1-x)^2 + 100*(y-x^2)^2 for ``x = [x, y]``."""
    distance_from_one = (1 - x[0])**2
    valley_penalty = 100 * (x[1] - x[0]**2)**2
    return distance_from_one + valley_penalty

# Build functions that return the Hessian (matrix of second partial
# derivatives) of each scalar-valued example.
hess_quad = hessian(quadratic_form)
hess_rosen = hessian(rosenbrock_2d)

# Evaluate at test points
x1 = jnp.array([1.0, 1.0])
x2 = jnp.array([0.0, 0.0])

# For f(x) = x^T A x + b^T x + c the Hessian is A + A^T, which does not
# depend on x; with the symmetric A used in quadratic_form that is 2A.
print("Quadratic Form Hessians:")
print(f"At {x1}: \n{hess_quad(x1)}")
print(f"At {x2}: \n{hess_quad(x2)}")
print("Note: Hessian is constant for quadratic functions")

# The Rosenbrock function is not quadratic, so its Hessian changes with x.
print("\nRosenbrock Hessians:")
print(f"At {x1}: \n{hess_rosen(x1)}")
print(f"At {x2}: \n{hess_rosen(x2)}")
```

### Analyzing Function Curvature

```python
def analyze_critical_point(func, point):
    """Classify a candidate critical point of ``func`` using its Hessian.

    Prints the point, the gradient norm, the Hessian eigenvalues and a
    classification (local minimum, local maximum, saddle point,
    inconclusive, or not a critical point).

    Args:
        func: Scalar-valued function of a 1-D array.
        point: 1-D array at which to analyze ``func``.

    Returns:
        Tuple ``(gradient, hessian_matrix)`` evaluated at ``point``.
    """
    gradient = grad(func)(point)
    hessian_matrix = hessian(func)(point)

    # A point is treated as critical only if the gradient (nearly) vanishes.
    grad_norm = jnp.linalg.norm(gradient)

    # The Hessian of a twice-differentiable scalar function is symmetric,
    # so use the symmetric eigenvalue routine: it returns real eigenvalues,
    # whereas the general `eigvals` returns complex values whose ordering
    # comparisons are not meaningful for this test.
    eigenvals = jnp.linalg.eigvalsh(hessian_matrix)

    print(f"Point: {point}")
    print(f"Gradient norm: {grad_norm:.6f}")
    print(f"Hessian eigenvalues: {eigenvals}")

    if grad_norm < 1e-6:  # Close to critical point
        # Eigenvalues within rounding error of zero count as zero, so that
        # a singular Hessian is not misclassified because of noise.
        zero_tol = (eigenvals.size * jnp.finfo(eigenvals.dtype).eps
                    * jnp.max(jnp.abs(eigenvals)))
        is_positive = eigenvals > zero_tol
        is_negative = eigenvals < -zero_tol

        if jnp.all(is_positive):
            print("Classification: Local minimum")
        elif jnp.all(is_negative):
            print("Classification: Local maximum")
        elif jnp.any(is_positive) and jnp.any(is_negative):
            print("Classification: Saddle point")
        else:
            # At least one zero eigenvalue and no sign change: the
            # second-derivative test cannot decide.
            print("Classification: Inconclusive (singular Hessian)")
    else:
        print("Not a critical point")

    return gradient, hessian_matrix

# Test with Rosenbrock function
print("Analyzing Rosenbrock function:")
# The global minimum is at (1, 1): both squared terms of rosenbrock_2d
# vanish there. The returned (gradient, Hessian) pair is not used here;
# only the printed analysis matters.
analyze_critical_point(rosenbrock_2d, jnp.array([1.0, 1.0]))
print()

# Test at saddle point of another function
def saddle_function(x):
    """Return x^2 - y^2 for ``x = [x, y]``; curves up along x and down along y."""
    upward = x[0]**2
    downward = x[1]**2
    return upward - downward

# The gradient of x^2 - y^2 is (2x, -2y), which vanishes at the origin, so
# the origin is a critical point of saddle_function.
print("Analyzing saddle function at origin:")
analyze_critical_point(saddle_function, jnp.array([0.0, 0.0]))
```

## Gradient-Based Optimization

### Simple Gradient Descent

```python
def gradient_descent(func, x0, learning_rate=0.01, num_steps=100, tolerance=1e-6):
    """Minimize ``func`` with fixed-step gradient descent.

    Args:
        func: Scalar-valued function to minimize.
        x0: Starting point.
        learning_rate: Step size multiplying the gradient.
        num_steps: Maximum number of updates.
        tolerance: Stop once the gradient norm used for an update falls
            below this value (the update itself is still applied).

    Returns:
        Tuple ``(x, history)`` with the final iterate and an array stacking
        the starting point and every iterate.
    """
    derivative = grad(func)
    current = x0
    trajectory = [current]

    step = 0
    while step < num_steps:
        slope = derivative(current)
        current = current - learning_rate * slope
        trajectory.append(current)

        # Convergence is judged on the gradient at the pre-update point.
        if jnp.linalg.norm(slope) < tolerance:
            print(f"Converged at step {step}")
            break
        step += 1

    return current, jnp.array(trajectory)

# Optimize the quadratic function, overriding the default learning rate
# of 0.01 with 0.1.
x0 = jnp.array([5.0, 5.0])
solution, path = gradient_descent(quadratic_form, x0, learning_rate=0.1)

# `path` also contains the starting point, hence the "- 1" when counting
# the number of update steps.
print(f"Starting point: {x0}")
print(f"Final solution: {solution}")
print(f"Final function value: {quadratic_form(solution):.6f}")
print(f"Number of steps: {len(path)-1}")

# Verify against analytical solution for quadratic form
# For f(x) = x^T A x + b^T x + c with symmetric A, the gradient 2*A*x + b
# vanishes at x* = -0.5 * A^(-1) * b.
# A and b repeat the constants hard-coded inside quadratic_form.
A = jnp.array([[2, 1], [1, 3]])
b = jnp.array([1, -1])
analytical_solution = -0.5 * jnp.linalg.solve(A, b)
print(f"Analytical solution: {analytical_solution}")
print(f"Error: {jnp.linalg.norm(solution - analytical_solution):.6f}")
```

### Newton's Method

```python
def newton_method(func, x0, num_steps=20, tolerance=1e-8):
    """Minimize ``func`` with Newton's method using automatic differentiation.

    Each iteration solves ``H @ step = g`` for the Newton step, where ``g``
    and ``H`` are the gradient and Hessian of ``func`` at the current point.

    Args:
        func: Scalar-valued function of a 1-D array.
        x0: Starting point (1-D array).
        num_steps: Maximum number of Newton updates.
        tolerance: Stop once the gradient norm used for an update falls
            below this value (the update itself is still applied).

    Returns:
        Tuple ``(x, history)`` with the final iterate and an array stacking
        the starting point and every accepted iterate. If the linear solve
        produces a non-finite step, iteration stops and the last finite
        iterate is returned.
    """
    grad_func = grad(func)
    hess_func = hessian(func)

    x = x0
    history = [x]

    for step in range(num_steps):
        gradient = grad_func(x)
        hessian_matrix = hess_func(x)

        # Newton step: x_{k+1} = x_k - H^{-1} * g
        newton_step = jnp.linalg.solve(hessian_matrix, gradient)

        # The solve does not raise on a singular Hessian; it yields inf/nan
        # entries instead, so a failed solve must be detected explicitly
        # rather than with try/except.
        if not jnp.all(jnp.isfinite(newton_step)):
            print(f"Hessian not invertible at step {step}")
            break

        x = x - newton_step
        history.append(x)

        # Convergence is judged on the gradient at the pre-update point.
        if jnp.linalg.norm(gradient) < tolerance:
            print(f"Converged at step {step}")
            break

    return x, jnp.array(history)

# Compare with gradient descent on Rosenbrock, starting both optimizers
# from the same point. Note that this rebinds the module-level `x0`.
x0 = jnp.array([0.0, 0.0])

# Newton's method (default: at most 20 steps)
newton_sol, newton_path = newton_method(rosenbrock_2d, x0)

# Gradient descent with smaller learning rate and a budget of 1000 steps
gd_sol, gd_path = gradient_descent(rosenbrock_2d, x0, learning_rate=0.001, num_steps=1000)

# Each path also contains the starting point, hence the "- 1" when
# counting steps.
print(f"Newton's method solution: {newton_sol}")
print(f"Newton's method steps: {len(newton_path)-1}")
print(f"Gradient descent solution: {gd_sol}")
print(f"Gradient descent steps: {len(gd_path)-1}")
```

## Advanced Autodiff Patterns

### Higher-Order Derivatives

```python
# Computing higher-order derivatives
def polynomial(x):
    """Evaluate the quartic x^4 - 2x^3 + x^2 - 5x + 3."""
    quartic = x**4
    cubic = 2*x**3
    quadratic = x**2
    linear = 5*x
    return quartic - cubic + quadratic - linear + 3

# First through fourth derivatives, obtained by applying grad repeatedly:
# each grad call differentiates the function returned by the previous one.
first_deriv = grad(polynomial)
second_deriv = grad(first_deriv)
third_deriv = grad(second_deriv)  
fourth_deriv = grad(third_deriv)

# For x^4 - 2x^3 + x^2 - 5x + 3 the derivatives are:
#   f'(x)    = 4x^3 - 6x^2 + 2x - 5   ->  f'(2)    = 7
#   f''(x)   = 12x^2 - 12x + 2        ->  f''(2)   = 26
#   f'''(x)  = 24x - 12               ->  f'''(2)  = 36
#   f''''(x) = 24                     ->  f''''(2) = 24
x = 2.0
print(f"At x = {x}:")
print(f"f(x) = {polynomial(x)}")
print(f"f'(x) = {first_deriv(x)}")
print(f"f''(x) = {second_deriv(x)}")
print(f"f'''(x) = {third_deriv(x)}")
print(f"f''''(x) = {fourth_deriv(x)}")

# Analytical verification for x^4 - 2x^3 + x^2 - 5x + 3
# f'(x) = 4x^3 - 6x^2 + 2x - 5
print(f"\nAnalytical f'({x}) = {4*x**3 - 6*x**2 + 2*x - 5}")
```

### Batched Gradients with `vmap`

```python
# Using vmap for efficient batch gradients
def batch_function(x_batch):
    """Return the sum of squares of each row of ``x_batch`` (reduces over axis 1)."""
    squared = x_batch**2
    return jnp.sum(squared, axis=1)

# Compute gradient for each input in batch: grad differentiates the
# per-example sum of squares, and vmap maps that gradient function over
# the rows of the batch.
batch_grad = vmap(grad(lambda x: jnp.sum(x**2)))

# Test with batch data (3 examples with 2 features each)
x_batch = jnp.array([[1.0, 2.0], 
                     [3.0, 4.0], 
                     [5.0, 6.0]])

# The gradient of sum(x**2) is 2*x, so each output row should be twice
# the corresponding input row.
gradients = batch_grad(x_batch)
print(f"Batch inputs:\n{x_batch}")
print(f"Gradients:\n{gradients}")
```

### Gradients Through Control Flow

```python
# JAX can differentiate through control flow (with limitations)
def conditional_function(x):
    """Return x^2 where ``x > 0`` and -x^2 elsewhere, selected with ``jnp.where``."""
    squared = x**2
    return jnp.where(x > 0, squared, -squared)

def iterative_function(x, n_iters=5):
    """Apply the update ``r <- r + 0.1 * sin(r)`` ``n_iters`` times, starting from ``x``."""
    value = x
    remaining = n_iters
    while remaining > 0:
        value = value + 0.1 * jnp.sin(value)
        remaining -= 1
    return value

# Compute gradients. For iterative_function only `x` is passed below, so
# `n_iters` keeps its default of 5.
grad_conditional = grad(conditional_function)
grad_iterative = grad(iterative_function)

# Test points on both sides of the branch at x = 0 (this rebinds the
# module-level `x_vals`).
x_vals = jnp.array([-1.0, 0.0, 1.0, 2.0])

# At x = 0 the condition `x > 0` is False, so the -x**2 branch is selected.
print("Conditional function gradients:")
for x in x_vals:
    print(f"x = {x:4.1f}: grad = {grad_conditional(x):6.3f}")

print(f"\nIterative function gradient at x=1.0: {grad_iterative(1.0):.6f}")
```

## Practical Applications

### Linear Regression with Automatic Differentiation

```python
# Generate synthetic data
key = random.PRNGKey(42)
n_samples, n_features = 100, 3

# Split the key once and give each random draw its own subkey. A JAX PRNG
# key must not be consumed more than once, so the features and the noise
# are drawn from separate subkeys instead of using `key` for one draw and
# splitting it again for the other.
data_key, noise_key = random.split(key)

X = random.normal(data_key, (n_samples, n_features))
true_weights = jnp.array([2.0, -1.5, 3.0])
# Targets are a linear function of the features plus Gaussian noise with
# standard deviation 0.1.
y = X @ true_weights + 0.1 * random.normal(noise_key, (n_samples,))

def mse_loss(weights, X, y):
    """Mean squared error between the linear predictions ``X @ weights`` and ``y``."""
    residuals = X @ weights - y
    return jnp.mean(residuals**2)

def train_linear_regression(X, y, learning_rate=0.01, num_steps=100):
    """Fit linear-regression weights by gradient descent on the MSE loss.

    Args:
        X: Feature matrix; one weight is learned per column.
        y: Target values.
        learning_rate: Step size multiplying the gradient.
        num_steps: Number of gradient-descent updates.

    Returns:
        Tuple ``(weights, losses)`` with the final weights and an array of
        the loss measured before each update.
    """
    def objective(w):
        # Loss as a function of the weights only; X and y are closed over.
        return mse_loss(w, X, y)

    objective_grad = grad(objective)

    # Start from all-zero weights.
    w = jnp.zeros(X.shape[1])
    loss_trace = []

    for step in range(num_steps):
        current_loss = objective(w)
        w = w - learning_rate * objective_grad(w)
        loss_trace.append(current_loss)

        # Report progress every 20 steps.
        if step % 20 == 0:
            print(f"Step {step}: Loss = {current_loss:.6f}")

    return w, jnp.array(loss_trace)

# Train the model with the default learning rate (0.01) and step count (100)
learned_weights, loss_history = train_linear_regression(X, y)

# Compare the learned weights with the weights used to generate the data.
print(f"\nTrue weights: {true_weights}")
print(f"Learned weights: {learned_weights}")
print(f"Weight error: {jnp.linalg.norm(learned_weights - true_weights):.6f}")
```

### Logistic Regression

```python
# Generate classification data
def generate_classification_data(key, n_samples=200):
    """Generate a noisy, nearly linearly separable 2-D binary dataset.

    Args:
        key: JAX PRNG key; it is split internally so that the features and
            the label noise are drawn from independent subkeys.
        n_samples: Number of examples to generate.

    Returns:
        Tuple ``(X, y)`` where ``X`` has shape ``(n_samples, 2)`` and ``y``
        holds float32 labels in {0, 1}.
    """
    # A JAX PRNG key must not be consumed more than once, so split once
    # and use one subkey per random draw.
    feature_key, noise_key = random.split(key)

    X = random.normal(feature_key, (n_samples, 2))
    # Create linearly separable classes with some noise
    true_weights = jnp.array([1.5, -2.0])
    bias = 0.5
    noise = 0.1 * random.normal(noise_key, (n_samples,))
    logits = X @ true_weights + bias + noise
    y = (logits > 0).astype(jnp.float32)
    return X, y

# Fixed seed so that the generated dataset is reproducible (this rebinds
# the module-level `key`); the default of 200 samples is used.
key = random.PRNGKey(123)
X_cls, y_cls = generate_classification_data(key)

def sigmoid(x):
    """Logistic sigmoid 1 / (1 + exp(-x)), evaluated in a numerically stable way.

    The naive formula overflows in ``exp(-x)`` for very negative ``x``,
    which makes its gradient NaN; ``jax.nn.sigmoid`` avoids that.

    Args:
        x: Scalar or array input.

    Returns:
        Array of the same shape as ``x`` with values in [0, 1].
    """
    x = jnp.asarray(x)
    # Keep accepting integer inputs, as the naive formula did, by
    # promoting them to floating point first.
    if not jnp.issubdtype(x.dtype, jnp.inexact):
        x = x.astype(float)
    return jax.nn.sigmoid(x)

def logistic_loss(params, X, y):
    """Binary cross-entropy loss of a logistic-regression model.

    The loss is computed directly from the logits with ``log_sigmoid``
    rather than from clipped probabilities. Clipping at ``1 - 1e-15`` has
    no effect in float32 (the bound rounds to 1.0, so the loss can become
    infinite) and it zeroes the gradient for saturated predictions.

    Args:
        params: Tuple ``(weights, bias)``.
        X: Feature matrix with one row per example.
        y: Binary labels (0 or 1), one per example.

    Returns:
        Scalar mean binary cross-entropy over all examples.
    """
    weights, bias = params
    logits = X @ weights + bias
    # log(sigmoid(z)) and log(1 - sigmoid(z)) = log(sigmoid(-z)), both
    # evaluated without forming the probabilities explicitly.
    log_prob_positive = jax.nn.log_sigmoid(logits)
    log_prob_negative = jax.nn.log_sigmoid(-logits)
    return -jnp.mean(y * log_prob_positive + (1 - y) * log_prob_negative)

def train_logistic_regression(X, y, learning_rate=0.1, num_steps=200):
    """Fit logistic-regression parameters by gradient descent.

    Args:
        X: Feature matrix; one weight is learned per column.
        y: Binary labels.
        learning_rate: Step size multiplying the gradients.
        num_steps: Number of gradient-descent updates.

    Returns:
        Tuple ``(params, losses)`` where ``params`` is ``(weights, bias)``
        and ``losses`` holds the loss measured before each update.
    """
    loss_grad = grad(logistic_loss)

    # Start from zero weights and zero bias.
    params = (jnp.zeros(X.shape[1]), 0.0)
    loss_trace = []

    for step in range(num_steps):
        current_loss = logistic_loss(params, X, y)
        # The gradient has the same (weights, bias) structure as params.
        d_weights, d_bias = loss_grad(params, X, y)

        current_weights, current_bias = params
        params = (current_weights - learning_rate * d_weights,
                  current_bias - learning_rate * d_bias)

        loss_trace.append(current_loss)

        # Report progress every 50 steps.
        if step % 50 == 0:
            print(f"Step {step}: Loss = {current_loss:.6f}")

    return params, jnp.array(loss_trace)

# Train logistic regression with the default learning rate (0.1) and step
# count (200), then unpack the (weights, bias) tuple.
final_params, cls_losses = train_logistic_regression(X_cls, y_cls)
final_weights, final_bias = final_params

print(f"Final weights: {final_weights}")
print(f"Final bias: {final_bias:.4f}")

# Compute accuracy on the training data: predict class 1 when the
# predicted probability exceeds 0.5, then take the fraction of
# predictions that match the labels.
logits = X_cls @ final_weights + final_bias
predictions = (sigmoid(logits) > 0.5).astype(jnp.float32)
accuracy = jnp.mean(predictions == y_cls)
print(f"Accuracy: {accuracy:.4f}")
```

## Performance Tips and Gotchas

### Efficient Gradient Computation

```python
# Avoid recomputing gradients unnecessarily
def inefficient_training_step(params, data):
    """Return the gradient of sum((params - data)^2), rebuilding the gradient function on every call."""
    def squared_error(p):
        return jnp.sum((p - data)**2)

    # BAD (on purpose): a new gradient function is created for each call.
    return grad(squared_error)(params)

def efficient_training_step(grad_fn, params, data):
    """Evaluate an already-built gradient function ``grad_fn`` at ``(params, data)``."""
    gradient = grad_fn(params, data)
    return gradient

# Define gradient function once
def loss_with_data(params, data):
    """Sum of squared differences between ``params`` and ``data``."""
    residual = params - data
    return jnp.sum(residual**2)

# Built once and reused for every training step.
efficient_grad_fn = grad(loss_with_data)

# Example usage
key = random.PRNGKey(0)
# A JAX PRNG key must not be consumed more than once, so split it and use
# one subkey per random draw.
params_key, data_key = random.split(key)
params = random.normal(params_key, (1000,))
data = random.normal(data_key, (1000,))

# Both give the same result; the efficient version avoids rebuilding the
# gradient function on every call.
grad1 = inefficient_training_step(params, data)
grad2 = efficient_training_step(efficient_grad_fn, params, data)
print(f"Gradients match: {jnp.allclose(grad1, grad2)}")
```

## Summary

In this notebook, we've explored JAX's automatic differentiation capabilities:

**Key Concepts:**

1. **Basic Gradients**: Using `jax.grad` for scalar-valued functions
2. **Jacobians**: Forward (`jacfwd`) vs reverse (`jacrev`) mode for vector-valued functions  
3. **Hessians**: Second-order derivatives for optimization and analysis
4. **Higher-order derivatives**: Composing `grad` multiple times
5. **Batch gradients**: Using `vmap` for efficient batch processing

**Practical Applications:**
- Gradient descent optimization
- Newton's method
- Linear and logistic regression
- Function analysis and critical point classification

**Performance Best Practices:**
- Choose appropriate differentiation mode (forward vs reverse)
- Avoid recomputing gradient functions
- Use vectorization with `vmap` when possible
- Consider numerical stability in loss functions

**Next Steps:**
- The next notebook will cover custom VJP/JVP for advanced differentiation
- We'll explore how to implement custom derivatives for special functions
- Understanding the mathematical foundations will help with more complex autodiff scenarios

Automatic differentiation is fundamental to modern machine learning and scientific computing, and JAX's implementation provides both ease of use and high performance for research and production applications.