# File: notebooks/01_fundamentals/03_custom_vjp_jvp.ipynb

## JAX Fundamentals: Custom VJP and JVP

Welcome to the third notebook in the JAX-NSL series! In this notebook, we'll dive deep into custom Vector-Jacobian Products (VJP) and Jacobian-Vector Products (JVP) in JAX. These are advanced autodiff concepts that allow you to define custom differentiation rules for functions, which is essential for implementing special operations, numerical algorithms, and optimization techniques that require custom gradient behavior.

Understanding VJP and JVP is crucial for implementing custom layers, special mathematical functions, and optimized gradient computations that go beyond JAX's automatic differentiation capabilities.

## Setting Up the Environment

```python
import jax
import jax.numpy as jnp
from jax import grad, jacfwd, jacrev, vmap, jvp, vjp
from jax import custom_vjp, custom_jvp
from jax import random, lax
import numpy as np
import matplotlib.pyplot as plt
from typing import Callable, Tuple, Any
import functools

# Enable 64-bit precision
jax.config.update("jax_enable_x64", True)

print(f"JAX version: {jax.__version__}")
print(f"Devices available: {jax.devices()}")
```

## Understanding VJP and JVP

### Vector-Jacobian Products (VJP) - Reverse Mode

```python
# Basic VJP example
def simple_function(x):
    """Map a scalar x to the 3-vector [x^2, sin(x), exp(x)]."""
    components = (x**2, jnp.sin(x), jnp.exp(x))
    return jnp.array(components)

# Compute VJP manually
def demonstrate_vjp():
    """Run a reverse-mode VJP on simple_function and compare it with a hand-derived value."""
    x = 2.0
    # Cotangent vector that multiplies the Jacobian from the left
    v = jnp.array([1.0, 0.5, 0.1])

    # vjp evaluates the function and hands back a pullback closure
    y, pullback = vjp(simple_function, x)

    # The pullback returns one cotangent per input; there is a single input here
    (vjp_result,) = pullback(v)

    print(f"Input x: {x}")
    print(f"Function output y: {y}")
    print(f"Cotangent vector v: {v}")
    print(f"VJP result (v^T * J): {vjp_result}")

    # Hand-derived Jacobian [2x, cos(x), exp(x)] evaluated at the same x
    manual_jacobian = jnp.array([2 * x, jnp.cos(x), jnp.exp(x)])
    manual_vjp = jnp.dot(v, manual_jacobian)
    print(f"Manual VJP calculation: {manual_vjp}")
    print(f"Results match: {jnp.allclose(vjp_result, manual_vjp)}")

demonstrate_vjp()
```

### Jacobian-Vector Products (JVP) - Forward Mode

```python
# Basic JVP example  
def demonstrate_jvp():
    """Run a forward-mode JVP on simple_function and compare it with a hand-derived value."""
    x = 2.0
    # Tangent: the input-space direction along which the derivative is taken
    v = 1.5

    # jvp returns the primal output together with J * v
    y, jvp_result = jvp(simple_function, (x,), (v,))

    print(f"Input x: {x}")
    print(f"Tangent vector v: {v}")
    print(f"Function output y: {y}")
    print(f"JVP result (J * v): {jvp_result}")

    # Hand-derived Jacobian [2x, cos(x), exp(x)]; with a scalar input,
    # J * v is simply the Jacobian scaled by v
    manual_jacobian = jnp.array([2 * x, jnp.cos(x), jnp.exp(x)])
    manual_jvp = v * manual_jacobian
    print(f"Manual JVP calculation: {manual_jvp}")
    print(f"Results match: {jnp.allclose(jvp_result, manual_jvp)}")

demonstrate_jvp()
```

## Custom VJP Implementation

### Basic Custom VJP

```python
# Example 1: Custom square function with custom VJP
@custom_vjp
def custom_square(x):
    """Return x squared; the reverse-mode rule is supplied by hand below."""
    return x * x

def custom_square_fwd(x):
    """Forward rule: produce the output and stash x as the residual."""
    saved_input = x
    return x * x, saved_input

def custom_square_bwd(residual, cotangent):
    """Backward rule: scale the incoming cotangent by d(x^2)/dx = 2x."""
    local_slope = 2 * residual
    # One entry per positional argument of custom_square
    return (local_slope * cotangent,)

# Attach the forward/backward pair to custom_square
custom_square.defvjp(custom_square_fwd, custom_square_bwd)

# Test the custom function
# Evaluate the function and its gradient at a sample point
x = 3.0
result = custom_square(x)
grad_result = grad(custom_square)(x)

# Report the value, the gradient from the custom rule, and the analytic 2x
print(
    f"custom_square({x}) = {result}",
    f"grad(custom_square)({x}) = {grad_result}",
    f"Expected gradient: {2 * x}",
    sep="\n",
)
```

### Advanced Custom VJP Example: Matrix Square Root

```python
# Custom matrix square root with stable gradients
@custom_vjp
def matrix_sqrt(A):
    """Return a matrix square root of A built from its symmetric eigendecomposition.

    Intended for symmetric positive definite inputs: `jnp.linalg.eigh` is used,
    and eigenvalues are clamped to at least 1e-12 before taking square roots.
    """
    spectrum, basis = jnp.linalg.eigh(A)
    # Clamp so tiny or slightly negative eigenvalues do not produce NaNs
    clamped = jnp.maximum(spectrum, 1e-12)
    return basis @ jnp.diag(jnp.sqrt(clamped)) @ basis.T

def matrix_sqrt_fwd(A):
    """Forward rule: return the root and keep (A, root) as residuals."""
    root = matrix_sqrt(A)
    return root, (A, root)

def matrix_sqrt_bwd(residuals, cotangent):
    """Backward rule: solve the Sylvester equation X Y + Y X = G for Y.

    X is the saved square root and G the incoming cotangent; Y is returned
    as the cotangent with respect to A.
    """
    _, X = residuals
    size = X.shape[0]
    identity = jnp.eye(size)

    # vec(XY + YX) = (I⊗X + X⊗I) vec(Y), so the matrix equation becomes a
    # single dense linear solve. Fine for small matrices; the system is n^2 x n^2.
    operator = jnp.kron(identity, X) + jnp.kron(X, identity)
    flat_solution = jnp.linalg.solve(operator, cotangent.flatten())

    return (flat_solution.reshape(size, size),)

# Attach the forward/backward pair to matrix_sqrt
matrix_sqrt.defvjp(matrix_sqrt_fwd, matrix_sqrt_bwd)

# Test with a simple symmetric positive definite matrix
A = jnp.array([[4.0, 1.0], [1.0, 2.0]])
X = matrix_sqrt(A)

# Show the input, the root, and how closely the root squares back to A
print(
    f"Matrix A:\n{A}",
    f"Matrix square root X:\n{X}",
    f"Verification X @ X:\n{X @ X}",
    f"Error: {jnp.max(jnp.abs(X @ X - A))}",
    sep="\n",
)

# Scalar objective so that grad exercises the custom backward rule
def trace_matrix_sqrt(A):
    """Return the trace of matrix_sqrt(A)."""
    return jnp.trace(matrix_sqrt(A))

grad_trace = grad(trace_matrix_sqrt)(A)
print(f"Gradient of trace(sqrt(A)):\n{grad_trace}")
```

## Custom JVP Implementation

### Basic Custom JVP

```python
# Example: Custom exponential with custom JVP
@custom_jvp
def custom_exp(x):
    """Exponential whose forward-mode derivative rule is written by hand."""
    return jnp.exp(x)

@custom_exp.defjvp
def custom_exp_jvp(primals, tangents):
    """JVP rule for custom_exp: return (exp(x), exp(x) * x_dot)."""
    (x,) = primals
    (x_dot,) = tangents

    primal_out = jnp.exp(x)
    # exp is its own derivative, so the output tangent is the primal output
    # scaled by the input tangent
    tangent_out = primal_out * x_dot

    return primal_out, tangent_out

# Test custom JVP
x, tangent = 1.0, 0.5

# Push the tangent through the hand-written rule
result, jvp_result = jvp(custom_exp, (x,), (tangent,))
print(
    f"custom_exp({x}) = {result}",
    f"JVP with tangent {tangent}: {jvp_result}",
    sep="\n",
)

# Reference: the built-in rule for jnp.exp (jvp returns a (primal, tangent) pair)
auto_jvp = jvp(jnp.exp, (x,), (tangent,))
print(f"Automatic JVP: {auto_jvp}")
```

### Advanced Custom JVP: Iterative Algorithm

```python
# Custom JVP for iterative square root computation
@functools.partial(custom_jvp, nondiff_argnums=(1,))
def iterative_sqrt(x, num_iters=10):
    """Approximate sqrt(x) with a fixed number of Newton-Raphson steps.

    Args:
        x: Value to take the square root of (intended for positive inputs;
            the update divides by the running estimate).
        num_iters: Number of Newton steps. Declared non-differentiable so it
            stays a plain Python int and can drive the Python ``for`` loop.

    Returns:
        The estimate after ``num_iters`` applications of
        ``estimate = 0.5 * (estimate + x / estimate)``, starting from ``x / 2``.
    """
    estimate = x / 2.0  # Initial guess

    for _ in range(num_iters):
        estimate = 0.5 * (estimate + x / estimate)

    return estimate

@iterative_sqrt.defjvp
def iterative_sqrt_jvp(num_iters, primals, tangents):
    """JVP rule for iterative_sqrt using the analytic derivative of sqrt.

    Non-differentiable arguments are passed first; ``primals`` and
    ``tangents`` then hold only the differentiable argument ``x``.
    """
    (x,) = primals
    (x_dot,) = tangents

    # Forward pass
    sqrt_x = iterative_sqrt(x, num_iters)

    # d/dx[sqrt(x)] = 1 / (2 * sqrt(x)); avoids differentiating through the loop
    sqrt_x_dot = x_dot / (2 * sqrt_x)

    return sqrt_x, sqrt_x_dot

# Test iterative sqrt
x = 9.0
tangent = 1.0

# num_iters is closed over rather than passed as a primal: jax.jvp requires a
# float0 tangent for integer primals, so an integer tangent of 0 is rejected.
result, jvp_result = jvp(lambda value: iterative_sqrt(value, 10), (x,), (tangent,))
print(f"iterative_sqrt({x}) = {result}")
print(f"JVP result: {jvp_result}")
print(f"True sqrt: {jnp.sqrt(x)}")
print(f"True derivative: {1/(2*jnp.sqrt(x))}")
```

## Combining Custom VJP and JVP

### Bi-directional Custom Rules

```python
# Function with a hand-written derivative usable in both forward and reverse mode.
# NOTE: one function cannot carry both decorators. The object returned by
# custom_vjp has no ``defjvp``, and custom_vjp functions reject forward-mode
# differentiation. A custom_jvp rule alone is enough: JAX derives reverse mode
# from the (linear) JVP rule.
def _special_function_derivative(x):
    """Return d/dx [x * softplus(x)] = softplus(x) + x * sigmoid(x)."""
    return jax.nn.softplus(x) + x * jax.nn.sigmoid(x)

@custom_jvp
def special_function(x):
    """f(x) = x * log(1 + exp(x)) - numerically stable version.

    Uses softplus(x) = log(1 + exp(x)) so large x does not overflow exp(x).
    """
    return x * jax.nn.softplus(x)

# Custom JVP (also used by JAX to build the reverse-mode rule)
@special_function.defjvp
def special_function_jvp(primals, tangents):
    """JVP rule: return (f(x), f'(x) * x_dot)."""
    x, = primals
    x_dot, = tangents

    # Forward pass
    y = special_function(x)

    return y, _special_function_derivative(x) * x_dot

# Explicit reverse-mode rule, registered on a separate reverse-only wrapper
def special_function_fwd(x):
    """Forward rule: return f(x) and keep x as the residual."""
    return special_function(x), x

def special_function_bwd(x, cotangent):
    """Backward rule: scale the cotangent by f'(x)."""
    return (_special_function_derivative(x) * cotangent,)

@custom_vjp
def special_function_reverse_only(x):
    """Same value as special_function, but with a hand-written VJP only.

    Forward-mode differentiation (jvp, jacfwd) is not available on this variant.
    """
    return special_function(x)

special_function_reverse_only.defvjp(special_function_fwd, special_function_bwd)

# Test both directions
x, tangent = 2.0, 1.5

# Forward mode: primal value and directional derivative along `tangent`
y, jvp_result = jvp(special_function, (x,), (tangent,))
print(f"JVP test: f({x}) = {y}, JVP = {jvp_result}")

# Reverse mode: gradient via grad
grad_result = grad(special_function)(x)
print(f"VJP test: grad(f)({x}) = {grad_result}")

# For a scalar function, JVP divided by the tangent should equal the gradient
print(f"JVP/VJP consistency: {jnp.allclose(jvp_result/tangent, grad_result)}")
```

## Practical Applications

### Custom Activation Function

```python
# Custom activation function with optimized gradients
@custom_vjp
def swish_custom(x):
    """Swish activation, x * sigmoid(x), with a hand-written backward rule."""
    return x * jax.nn.sigmoid(x)

def swish_fwd(x):
    """Forward rule: return the activation and keep (x, sigmoid(x)) for the backward rule."""
    gate = jax.nn.sigmoid(x)
    return x * gate, (x, gate)

def swish_bwd(residuals, cotangent):
    """Backward rule: reuse the saved sigmoid instead of recomputing it."""
    x, gate = residuals

    # With s = sigmoid(x): d/dx[x * s] = s + x * s * (1 - s) = s * (1 + x * (1 - s))
    local_slope = gate * (1 + x * (1 - gate))

    return (local_slope * cotangent,)

swish_custom.defvjp(swish_fwd, swish_bwd)

# Compare with automatic differentiation
def swish_auto(x):
    """Plain Swish, x * sigmoid(x), differentiated by JAX's built-in rules."""
    gate = jax.nn.sigmoid(x)
    return x * gate

# Test both versions
x_test = jnp.linspace(-5, 5, 100)

# Evaluate both implementations on the same grid
y_custom = swish_custom(x_test)
y_auto = swish_auto(x_test)

# Element-wise derivatives: vmap lifts the scalar gradient over the grid
grad_custom = vmap(grad(swish_custom))(x_test)
grad_auto = vmap(grad(swish_auto))(x_test)

# Largest absolute disagreement in values and in gradients
print(f"Forward pass error: {jnp.max(jnp.abs(y_custom - y_auto))}")
print(f"Gradient error: {jnp.max(jnp.abs(grad_custom - grad_auto))}")
```

### Implicit Function Differentiation

```python
# Solve implicit equation F(x, y) = 0 and differentiate through the solution
@custom_vjp
def implicit_solve(x):
    """Return y satisfying y^3 + y - x = 0, found by Newton-Raphson.

    Runs exactly 10 Newton steps starting from y = x / 2.
    """
    y = x / 2  # Initial guess
    for _ in range(10):
        # Newton step: subtract F(y) / F'(y) with F(y) = y^3 + y - x
        y = y - (y**3 + y - x) / (3 * y**2 + 1)

    return y

def implicit_solve_fwd(x):
    """Forward rule: return the root and keep (x, root) as residuals."""
    root = implicit_solve(x)
    return root, (x, root)

def implicit_solve_bwd(residuals, cotangent):
    """Backward rule from the implicit function theorem.

    Differentiating F(x, y) = y^3 + y - x = 0 gives dF/dx + dF/dy * dy/dx = 0,
    hence dy/dx = -dF/dx / dF/dy = 1 / (3*y^2 + 1). The Newton iterations are
    never differentiated.
    """
    _, y = residuals

    partial_x = -1.0
    partial_y = 3 * y**2 + 1
    dy_dx = -partial_x / partial_y

    return (dy_dx * cotangent,)

implicit_solve.defvjp(implicit_solve_fwd, implicit_solve_bwd)

# Test the implicit solver
x_vals = jnp.array([1.0, 2.0, 5.0, 10.0])

for x in x_vals:
    # Root of y^3 + y - x = 0 and its sensitivity to x
    y = implicit_solve(x)
    dy_dx = grad(implicit_solve)(x)

    # How far the returned root is from satisfying the equation
    residual = y**3 + y - x

    print(f"x = {x:4.1f}: y = {y:6.3f}, residual = {residual:8.2e}, dy/dx = {dy_dx:6.3f}")
```

### Efficient Linear System Solver

```python
# Custom linear solver with efficient VJP
@custom_vjp
def solve_linear_system(A, b):
    """Solve A x = b for x with ``jnp.linalg.solve``.

    Args:
        A: Square coefficient matrix of shape (n, n).
        b: Right-hand side, either a vector of shape (n,) or a matrix of
            shape (n, k) holding k right-hand sides as columns.

    Returns:
        The solution x, with the same shape as ``b``.
    """
    return jnp.linalg.solve(A, b)

def solve_fwd(A, b):
    """Forward rule: return x and keep only what the backward rule needs."""
    x = jnp.linalg.solve(A, b)
    # b itself is not needed by the backward rule, so it is not stored
    return x, (A, x)

def solve_bwd(residuals, x_cotangent):
    """Backward rule for the linear solve, using one extra solve with A^T.

    From x = A^(-1) b:  dx = -A^(-1) dA x + A^(-1) db.
    With lambda = A^(-T) x_cotangent this gives
    dL/dA = -lambda x^T and dL/db = lambda.
    """
    A, x = residuals

    # Solve A^T lambda = x_cotangent instead of forming A^(-1)
    lambda_vec = jnp.linalg.solve(A.T, x_cotangent)

    if x.ndim == 1:
        # Single right-hand side: outer product of two vectors
        dA = -jnp.outer(lambda_vec, x)
    else:
        # Several right-hand sides as columns: sum of the per-column outer products
        dA = -(lambda_vec @ x.T)

    # dL/db = lambda
    db = lambda_vec

    return dA, db

solve_linear_system.defvjp(solve_fwd, solve_bwd)

# Test linear solver
key = random.PRNGKey(42)
n = 5

# Split once up front so the matrix and the right-hand side are drawn from
# independent keys; a key that has already been consumed should not be reused.
key_A, key_b = random.split(key)

# Generate a well-conditioned positive definite matrix
A_base = random.normal(key_A, (n, n))
A = A_base @ A_base.T + jnp.eye(n)  # Make positive definite
b = random.normal(key_b, (n,))

# Solve system
x = solve_linear_system(A, b)

print(f"Solution x: {x}")
print(f"Residual ||Ax - b||: {jnp.linalg.norm(A @ x - b)}")

# Test gradient computation
def objective(A, b):
    """Sum of squares of the solution of A x = b."""
    x = solve_linear_system(A, b)
    return jnp.sum(x**2)

grad_A, grad_b = grad(objective, argnums=(0, 1))(A, b)
print(f"Gradient w.r.t. A shape: {grad_A.shape}")
print(f"Gradient w.r.t. b shape: {grad_b.shape}")
```

## Debugging and Validation

### Gradient Checking

```python
def finite_difference_check(func, x, eps=1e-5):
    """Compare grad(func)(x) against central finite differences.

    Args:
        func: Scalar-valued function of a 1-D array.
        x: Point at which to compare; perturbed one component at a time.
        eps: Step size of the central difference.

    Returns:
        Tuple (analytical_grad, numerical_grad, error, relative_error), where
        error is the 2-norm of the difference and relative_error divides it by
        the norm of the analytical gradient plus 1e-8.
    """
    # Gradient from automatic differentiation
    analytical_grad = grad(func)(x)

    def central_difference(index):
        """Estimate the partial derivative along one coordinate."""
        raised = func(x.at[index].add(eps))
        lowered = func(x.at[index].add(-eps))
        return (raised - lowered) / (2 * eps)

    # Gradient from finite differences, one coordinate at a time
    numerical_grad = jnp.array([central_difference(i) for i in range(len(x))])

    error = jnp.linalg.norm(analytical_grad - numerical_grad)
    # The small constant guards against division by a zero gradient norm
    relative_error = error / (jnp.linalg.norm(analytical_grad) + 1e-8)

    return analytical_grad, numerical_grad, error, relative_error

# Test custom functions
def test_function(x):
    """Scalar objective: sum of swish_custom over all elements of x."""
    return swish_custom(x).sum()

x_test = jnp.array([1.0, -0.5, 2.0])
anal_grad, num_grad, abs_error, rel_error = finite_difference_check(test_function, x_test)

print(
    f"Analytical gradient: {anal_grad}",
    f"Numerical gradient: {num_grad}",
    f"Absolute error: {abs_error:.2e}",
    f"Relative error: {rel_error:.2e}",
    sep="\n",
)

# Pass/fail verdict based on the relative error
print("✓ Gradient check passed" if rel_error < 1e-5 else "✗ Gradient check failed")
```

### VJP/JVP Consistency Check

```python
def check_vjp_jvp_consistency(func, x, v, eps=1e-10):
    """Check that forward- and reverse-mode derivatives of ``func`` agree at ``x``.

    With J the Jacobian of ``func`` at ``x``, forward mode yields J v and
    reverse mode yields J^T w. They are adjoint to each other, so for any
    tangent v (input space) and cotangent w (output space):

        <w, J v> == <J^T w, v>

    Args:
        func: Function of one array argument that supports both jvp and vjp.
        x: Point at which to differentiate.
        v: Tangent vector; must have the same shape and dtype as ``x``.
        eps: Tolerance on the difference of the two inner products, scaled by
            max(1, |inner products|) so it is absolute for small values.

    Returns:
        True when the two inner products agree within the tolerance.
    """
    # Forward mode: J v
    y, jvp_result = jvp(func, (x,), (v,))

    # The cotangent lives in output space, so it is drawn with the output's
    # shape rather than reusing the input-space tangent.
    key = random.PRNGKey(0)
    cotangent = random.normal(key, y.shape, dtype=y.dtype)

    # Reverse mode: J^T w
    _, vjp_fn = vjp(func, x)
    vjp_result = vjp_fn(cotangent)[0]

    print(f"Function output shape: {y.shape}")
    print(f"JVP result shape: {jvp_result.shape}")
    print(f"VJP result shape: {vjp_result.shape}")

    # Element-wise product + sum works for scalars and arrays of any rank
    inner_product_1 = jnp.sum(cotangent * jvp_result)  # <w, J v>
    inner_product_2 = jnp.sum(vjp_result * v)          # <J^T w, v>

    consistency_error = jnp.abs(inner_product_1 - inner_product_2)
    scale = jnp.maximum(
        1.0, jnp.maximum(jnp.abs(inner_product_1), jnp.abs(inner_product_2))
    )

    print(f"Inner product 1 (w^T (J v)): {inner_product_1}")
    print(f"Inner product 2 ((J^T w)^T v): {inner_product_2}")
    print(f"Consistency error: {consistency_error}")
    return bool(consistency_error <= eps * scale)

# Test consistency for our custom functions
x = jnp.array([1.0, 2.0])
v = jnp.array([0.5, -0.3])

# The check needs a function that supports BOTH modes. custom_square only has
# a custom_vjp rule, and JAX cannot apply forward mode (jvp) to custom_vjp
# functions, so custom_exp (a custom_jvp function) is used instead. The point
# and the tangent are both scalars so their shapes match.
print("Testing VJP/JVP consistency for custom exp:")
is_consistent = check_vjp_jvp_consistency(custom_exp, x[0], v[0])
print(f"Consistent: {is_consistent}\n")
```

## Performance Comparison

### Benchmarking Custom vs Automatic Differentiation

```python
import time

def benchmark_differentiation():
    """Time batched gradient evaluation for the automatic and custom Swish.

    For each implementation the batched gradient ``vmap(grad(func))`` is built
    once, warmed up, and then timed over 100 calls on 1000 random inputs.
    """
    # Setup test data
    key = random.PRNGKey(42)
    x = random.normal(key, (1000,))

    # Functions to test
    funcs = {
        'swish_auto': lambda x: x * jax.nn.sigmoid(x),
        'swish_custom': swish_custom
    }

    for name, func in funcs.items():
        # Build once so the warmup exercises exactly what is timed
        batched_grad = vmap(grad(func))

        # Warmup
        for _ in range(10):
            batched_grad(x).block_until_ready()

        # Benchmark gradient computation. JAX dispatches work asynchronously,
        # so each result is awaited; otherwise only dispatch time is measured.
        start_time = time.time()
        for _ in range(100):
            batched_grad(x).block_until_ready()
        end_time = time.time()

        print(f"{name}: {(end_time - start_time)*1000:.2f} ms for 100 iterations")

benchmark_differentiation()
```

## Summary

In this notebook, we've explored advanced automatic differentiation concepts in JAX:

**Key Concepts:**

1. **VJP (Vector-Jacobian Products)**: Reverse-mode differentiation for efficient backpropagation
2. **JVP (Jacobian-Vector Products)**: Forward-mode differentiation for directional derivatives
3. **Custom VJP**: Implementing custom backward passes with `@custom_vjp`
4. **Custom JVP**: Implementing custom forward-mode differentiation rules with `@custom_jvp` (JAX derives reverse mode from them automatically)
5. **Implicit Differentiation**: Differentiating through iterative algorithms and implicit equations

**Practical Applications:**
- Custom activation functions with optimized gradients
- Matrix operations with numerically stable derivatives
- Linear system solvers with efficient backpropagation
- Implicit function differentiation for constrained optimization

**Best Practices:**
- Use custom VJP/JVP for numerical stability and performance
- Save minimal residual data in forward passes
- Validate custom derivatives with finite difference checks
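- Check VJP/JVP consistency for correctness (only possible for functions that support both modes, such as `custom_jvp` functions; `custom_vjp` functions cannot be differentiated in forward mode)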

**When to Use Custom Rules:**
- Numerical stability issues with automatic differentiation
- Performance optimization for specific operations
- Implementing algorithms with known analytical derivatives
- Differentiating through implicit functions or iterative solvers

**Next Steps:**
- The next notebook will cover control flow and scan operations
- We'll explore how JAX handles differentiation through loops and conditionals
- Understanding these concepts enables implementation of RNNs and iterative algorithms

Custom VJP and JVP are powerful tools for advanced JAX programming, enabling fine-grained control over differentiation behavior while maintaining the benefits of automatic differentiation.