# 📘 Notebook 3: Automatic Differentiation - Computing Gradients Like Magic

Welcome to the "calculus" chapter of JAX! Don't worry - you don't need to remember your calculus classes. JAX does the math for you!

## 🎯 What You'll Learn (30-40 minutes)

By the end of this notebook, you'll understand:
- ✅ What automatic differentiation is (and why it's amazing)
- ✅ How to compute gradients with `jax.grad()`
- ✅ Multi-variable functions and partial derivatives
- ✅ Getting both function value and gradient together
- ✅ Jacobian and Hessian matrices (we'll explain these!)
- ✅ Higher-order derivatives (gradients of gradients)
- ✅ Practical application: Gradient descent

## 🤔 What is Automatic Differentiation?

### The Problem: Manual Calculus is Hard
Machine learning requires computing **derivatives** (gradients) of complex functions. Doing this by hand is:
- Time-consuming 
- Error-prone
- Tedious
- Nearly impossible for neural networks with millions of parameters

### The Solution: Automatic Differentiation (Autodiff)
JAX **automatically** computes exact derivatives for you! (Exact up to floating-point rounding, unlike finite-difference approximations.)

```python
import jax

# You write the FORWARD function:
def f(x):
    return x ** 2 + 2 * x + 1

# JAX computes the DERIVATIVE automatically:
df_dx = jax.grad(f)

print(df_dx(3.0))  # 8.0 (which is 2*3 + 2)
```

No manual calculus needed! 🎉

### Why This Matters for Machine Learning
**Training a neural network is a loop built on gradients:**
- Forward pass: Compute predictions
- Backward pass: Compute gradients (how to improve)
- Update: Adjust weights using gradients

JAX handles the backward pass automatically!
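A minimal sketch of that forward/backward/update loop, assuming a hypothetical one-weight model `prediction = w * x` and a tiny made-up dataset generated with `w = 2` (neither comes from this notebook). `jax.value_and_grad` performs the forward and backward passes in one call; the update line is plain gradient descent.

```python
import jax
import jax.numpy as jnp

# Hypothetical one-weight model, for illustration only
def loss_fn(w, x, y):
    prediction = w * x                      # forward pass: compute predictions
    return jnp.mean((prediction - y) ** 2)  # scalar loss (mean squared error)

x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([2.0, 4.0, 6.0])  # toy data generated with w = 2
w = 0.0
learning_rate = 0.1

# value_and_grad returns the loss (forward) and d(loss)/dw (backward) in one call
loss_and_grad = jax.value_and_grad(loss_fn)

for step in range(50):
    loss, grad_w = loss_and_grad(w, x, y)
    w = w - learning_rate * grad_w  # update: step against the gradient

print(float(w))  # converges to about 2.0
```

The learning rate of 0.1 is just a value that happens to converge quickly for this toy data; real models need it tuned.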

## 📚 Key Concepts (Don't Worry, We'll Explain!)

### 1. Gradient (First Derivative)
**What is it?** How fast a function's output changes as you change its input (its slope at a point).

**Example:** If `f(x) = x²`, then `f'(x) = 2x`
- At x=3: gradient is 6 (near x=3, f changes about 6 times as fast as x, so nudging x by 0.01 raises f by about 0.06)
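To make "slope at a point" concrete, this small check compares `jax.grad` with a central finite difference (a numerical approximation using an arbitrarily chosen step `h`) and with the whole-unit change `f(4) - f(3)`, which is 7 rather than 6 because the slope keeps increasing between 3 and 4.

```python
import jax

def f(x):
    return x ** 2

exact = jax.grad(f)(3.0)  # autodiff: f'(3) = 6

# Central finite difference with a small step h (a numerical approximation)
h = 1e-3
approx = (f(3.0 + h) - f(3.0 - h)) / (2 * h)

# A full unit step overshoots the instantaneous slope: f(4) - f(3) = 7
print(float(exact), approx, f(4.0) - f(3.0))
```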

### 2. Partial Derivatives
**What is it?** Gradient with respect to ONE variable in a multi-variable function.

**Example:** If `f(x,y) = x² + 3xy + y²`
- ∂f/∂x = 2x + 3y (how f changes when x changes, y fixed)
- ∂f/∂y = 3x + 2y (how f changes when y changes, x fixed)
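If the variables are packed into a single array, `jax.grad` returns all the partial derivatives at once, as an array with the same shape as the input. A short sketch, assuming we store (x, y) in one array `v` (a packing choice made here for illustration):

```python
import jax
import jax.numpy as jnp

def f(v):
    # v packs both variables: v[0] = x, v[1] = y
    x, y = v[0], v[1]
    return x**2 + 3*x*y + y**2

grad_f = jax.grad(f)
print(grad_f(jnp.array([2.0, 1.0])))  # [7. 8.], i.e. [2x + 3y, 3x + 2y] at (2, 1)
```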

### 3. Jacobian Matrix
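**What is it?** All partial derivatives of a vector-valued function, arranged in a matrix: one row per output, one column per input.

**When?** When your function returns multiple outputs.

**Example:** `f(x) = [x², x³, sin(x)]` → the Jacobian holds 3 derivatives (a 3×1 matrix: three outputs, one input)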
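With several inputs and several outputs the Jacobian becomes a genuine matrix. This sketch uses a made-up function `f(x, y) = [xy, x + y, sin(x)]` (chosen only for illustration) to show the shape `jax.jacobian` returns, then repeats the scalar-input example from the text, where the 3×1 Jacobian comes back as a length-3 array.

```python
import jax
import jax.numpy as jnp

def f(v):
    # Two inputs packed in v = [x, y]; three outputs
    x, y = v[0], v[1]
    return jnp.array([x * y, x + y, jnp.sin(x)])

J = jax.jacobian(f)(jnp.array([2.0, 1.0]))
print(J.shape)  # (3, 2): one row per output, one column per input
print(J)        # rows: [y, x], [1, 1], [cos(x), 0] evaluated at (2, 1)

# The scalar-input example from the text: one input, three outputs
def g(x):
    return jnp.array([x**2, x**3, jnp.sin(x)])

J_scalar = jax.jacobian(g)(2.0)
print(J_scalar.shape)  # (3,): the 3x1 Jacobian, stored as a length-3 array
```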

### 4. Hessian Matrix  
**What is it?** Matrix of second partial derivatives (derivatives of derivatives) of a scalar-valued function.

**When?** For optimization algorithms that use second-order information.
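A minimal sketch, reusing the function `f(x,y) = x² + 3xy + y²` from the partial-derivatives example: its Hessian is constant, with 2 on the diagonal (∂²f/∂x², ∂²f/∂y²) and 3 off the diagonal (∂²f/∂x∂y). The last lines check that the Hessian equals the Jacobian of the gradient, i.e. differentiating twice.

```python
import jax
import jax.numpy as jnp

def f(v):
    # Scalar output, so the Hessian is a 2x2 matrix
    x, y = v[0], v[1]
    return x**2 + 3*x*y + y**2

point = jnp.array([2.0, 1.0])
H = jax.hessian(f)(point)
print(H)  # [[2. 3.]
          #  [3. 2.]]

# The Hessian is the Jacobian of the gradient: differentiate grad(f) once more
H_again = jax.jacobian(jax.grad(f))(point)
print(jnp.allclose(H, H_again))  # True
```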

**Don't worry!** JAX computes these automatically - you don't need to understand the math deeply!

## 🎓 What's in This Notebook?

This notebook has **8 comprehensive examples**:

1. **Basic gradient** - Single variable functions
2. **Multi-variable gradient** - Functions with multiple inputs
3. **value_and_grad** - Get function value and gradient together (efficient!)
4. **Jacobian** - Vector-valued functions
5. **Hessian** - Second derivatives
6. **Higher-order derivatives** - Gradients of gradients
7. **Multiple arguments** - Gradients w.r.t. different parameters
8. **Practical example** - Linear regression with gradient descent

## 🚀 Prerequisites

Before starting this notebook, you should:
- ✅ Complete Notebook 1 (JAX Basics)
- ✅ Understand what a function is
- ✅ Know basic math (addition, multiplication, powers)
- ❌ **Don't need**: Deep calculus knowledge (JAX handles it!)

## 💡 Key Takeaway

**You write the forward function. JAX computes exact gradients automatically.**

This is the secret sauce that makes modern machine learning possible!

Let's see autodiff in action! 📈

# =============================================================================
# AUTOMATIC DIFFERENTIATION - BASICS
# =============================================================================

import jax
import jax.numpy as jnp
import numpy as np

print("=" * 70)
print("COMPUTING GRADIENTS WITH JAX")
print("=" * 70)

# -----------------------------------------------------------------------------
# Example 1: Basic Gradient - Single Variable
# -----------------------------------------------------------------------------
print("\n1️⃣  BASIC GRADIENT")
print("-" * 70)

def f(x):
    """Simple quadratic function: f(x) = x^2"""
    return x ** 2

# Create gradient function
df_dx = jax.grad(f)

# Evaluate at different points
x = 3.0
print(f"Function: f(x) = x^2")
print(f"f({x}) = {f(x)}")
print(f"f'({x}) = {df_dx(x)}")
print(f"Expected: 2*{x} = {2*x} ✓")

# -----------------------------------------------------------------------------
# Example 2: Multi-variable Function
# -----------------------------------------------------------------------------
print("\n2️⃣  MULTI-VARIABLE GRADIENT")
print("-" * 70)

def g(x, y):
    """Function with two inputs: g(x,y) = x^2 + 3xy + y^2"""
    return x**2 + 3*x*y + y**2

# grad() computes gradient w.r.t. FIRST argument by default
dg_dx = jax.grad(g)

# To get gradient w.r.t. other arguments, use argnums
dg_dy = jax.grad(g, argnums=1)

x, y = 2.0, 1.0
print(f"Function: g(x,y) = x^2 + 3xy + y^2")
print(f"g({x}, {y}) = {g(x, y)}")
print(f"∂g/∂x at ({x},{y}) = {dg_dx(x, y)}")
print(f"∂g/∂y at ({x},{y}) = {dg_dy(x, y)}")
print(f"Expected: ∂g/∂x = 2x + 3y = {2*x + 3*y} ✓")
print(f"Expected: ∂g/∂y = 3x + 2y = {3*x + 2*y} ✓")

# -----------------------------------------------------------------------------
# Example 3: value_and_grad - Get Both Function Value and Gradient
# -----------------------------------------------------------------------------
print("\n3️⃣  VALUE AND GRADIENT TOGETHER")
print("-" * 70)

def loss(params):
    """Typical ML loss function"""
    return jnp.sum(params ** 2)

# This is super useful for optimization - get loss and gradient in one call
val_and_grad_fn = jax.value_and_grad(loss)

params = jnp.array([1.0, 2.0, 3.0])
value, gradient = val_and_grad_fn(params)

print(f"Parameters: {params}")
print(f"Loss value: {value}")
print(f"Gradient: {gradient}")
print(f"Expected gradient: 2*params = {2*params} ✓")

# -----------------------------------------------------------------------------
# Example 4: Gradients of Vector Functions - Jacobian
# -----------------------------------------------------------------------------
print("\n4️⃣  JACOBIAN (for vector-valued functions)")
print("-" * 70)

def vector_func(x):
    """Returns a vector: [x^2, x^3, sin(x)]"""
    return jnp.array([x**2, x**3, jnp.sin(x)])

# Jacobian computes all partial derivatives
jacobian_fn = jax.jacobian(vector_func)

x = 2.0
jac = jacobian_fn(x)
print(f"Function: f(x) = [x^2, x^3, sin(x)]")
print(f"Jacobian at x={x}:")
print(f"  df₁/dx = 2x = {jac[0]:.4f}")
print(f"  df₂/dx = 3x^2 = {jac[1]:.4f}")
print(f"  df₃/dx = cos(x) = {jac[2]:.4f}")

# -----------------------------------------------------------------------------
# Example 5: Second Derivatives - Hessian
# -----------------------------------------------------------------------------
print("\n5️⃣  HESSIAN (second derivatives)")
print("-" * 70)

def h(x):
    """Scalar function of a 2-vector: h(x, y) = x^2 + y^2"""
    return x[0]**2 + x[1]**2

# Hessian: matrix of second partial derivatives of the scalar function h
hessian_fn = jax.hessian(h)

x = jnp.array([1.0, 2.0])
hess = hessian_fn(x)
print(f"Function: f(x,y) = x^2 + y^2")
print(f"Hessian matrix at {x}:")
print(hess)
print("(Second derivatives: diagonal is 2, off-diagonal is 0)")

# -----------------------------------------------------------------------------
# Example 6: Higher-Order Derivatives
# -----------------------------------------------------------------------------
print("\n6️⃣  HIGHER-ORDER DERIVATIVES")
print("-" * 70)

def cubic(x):
    """f(x) = x^3"""
    return x ** 3

# First derivative
first_deriv = jax.grad(cubic)
# Second derivative (gradient of gradient)
second_deriv = jax.grad(first_deriv)
# Third derivative
third_deriv = jax.grad(second_deriv)

x = 2.0
print(f"Function: f(x) = x^3")
print(f"f({x}) = {cubic(x)}")
print(f"f'({x}) = {first_deriv(x)} (expected: 3x^2 = {3*x**2})")
print(f"f''({x}) = {second_deriv(x)} (expected: 6x = {6*x})")
print(f"f'''({x}) = {third_deriv(x)} (expected: 6)")

# -----------------------------------------------------------------------------
# Example 7: Gradient with Respect to Multiple Arguments
# -----------------------------------------------------------------------------
print("\n7️⃣  GRADIENTS W.R.T. MULTIPLE ARGUMENTS")
print("-" * 70)

def model(weights, bias, x):
    """Simple linear model: y = w*x + b"""
    return jnp.dot(weights, x) + bias

# Get gradients w.r.t. first two arguments (weights and bias)
grad_fn = jax.grad(model, argnums=(0, 1))

w = jnp.array([1.0, 2.0, 3.0])
b = 0.5
x = jnp.array([1.0, 1.0, 1.0])

grad_w, grad_b = grad_fn(w, b, x)
print(f"Model: y = w·x + b")
print(f"Gradient w.r.t. weights: {grad_w}")
print(f"Gradient w.r.t. bias: {grad_b}")
print(f"(dy/dw = x, dy/db = 1)")

# -----------------------------------------------------------------------------
# Example 8: Practical Example - Gradient Descent
# -----------------------------------------------------------------------------
print("\n8️⃣  GRADIENT DESCENT OPTIMIZATION")
print("-" * 70)

def loss_fn(params, x, y):
    """MSE loss for linear regression"""
    prediction = params[0] * x + params[1]
    return jnp.mean((prediction - y) ** 2)

# Training data: y = 2x + 1 with some noise
np.random.seed(42)
x_data = jnp.array(np.linspace(0, 10, 20))
y_data = 2 * x_data + 1 + np.random.randn(20) * 0.5

# Initialize parameters
params = jnp.array([0.0, 0.0])  # [slope, intercept]

# Training loop
learning_rate = 0.01
grad_fn = jax.grad(loss_fn)

print("Training linear regression with gradient descent...")
print(f"True parameters: slope=2.0, intercept=1.0")
print(f"Initial params: {params}")

for step in range(100):
    grads = grad_fn(params, x_data, y_data)
    params = params - learning_rate * grads
    
    if step % 20 == 0:
        current_loss = loss_fn(params, x_data, y_data)  # avoid shadowing the earlier loss() function
        print(f"Step {step:3d}: loss={current_loss:.4f}, params={params}")

final_loss = loss_fn(params, x_data, y_data)
print(f"\nFinal params: slope={params[0]:.3f}, intercept={params[1]:.3f}")
print(f"Final loss: {final_loss:.4f}")

print("\n" + "=" * 70)
print("KEY POINTS - AUTOMATIC DIFFERENTIATION")
print("=" * 70)
print("""
✅ jax.grad() computes exact derivatives (up to floating-point rounding), not finite-difference approximations
✅ Works with Python functions built from JAX operations - no manual chain rule needed
✅ jax.grad() uses efficient reverse-mode autodiff (backpropagation) and expects a scalar output
✅ Can compute higher-order derivatives by composing grad()
✅ value_and_grad() gives both function value and gradient
✅ jacobian() and hessian() handle vector-valued functions and second derivatives
✅ argnums parameter controls which arguments to differentiate
✅ Combine with jax.jit to compile gradient computations for speed
""")

COMPUTING GRADIENTS WITH JAX

1️⃣  BASIC GRADIENT
----------------------------------------------------------------------
Function: f(x) = x^2
f(3.0) = 9.0
f'(3.0) = 6.0
Expected: 2*3.0 = 6.0 ✓

2️⃣  MULTI-VARIABLE GRADIENT
----------------------------------------------------------------------
Function: g(x,y) = x^2 + 3xy + y^2
g(2.0, 1.0) = 11.0
∂g/∂x at (2.0,1.0) = 7.0
∂g/∂y at (2.0,1.0) = 8.0
Expected: ∂g/∂x = 2x + 3y = 7.0 ✓
Expected: ∂g/∂y = 3x + 2y = 8.0 ✓

3️⃣  VALUE AND GRADIENT TOGETHER
----------------------------------------------------------------------
Parameters: [1. 2. 3.]
Loss value: 14.0
Gradient: [2. 4. 6.]
Expected gradient: 2*params = [2. 4. 6.] ✓

4️⃣  JACOBIAN (for vector-valued functions)
----------------------------------------------------------------------
Function: f(x) = [x^2, x^3, sin(x)]
Jacobian at x=2.0:
  df₁/dx = 2x = 4.0000
  df₂/dx = 3x^2 = 12.0000
  df₃/dx = cos(x) = -0.4161

5️⃣  HESSIAN (second deriv