# File: notebooks/02_linear_algebra/06_iterative_solvers.ipynb

## JAX Linear Algebra: Iterative Solvers

This notebook explores iterative methods for solving linear systems and optimization problems in JAX. We'll implement conjugate gradient, GMRES, gradient descent variants, and Newton-type methods. These methods are essential for large-scale problems where direct factorization is computationally prohibitive.

Iterative solvers are particularly important in scientific computing and machine learning, where we often deal with large sparse systems or optimization problems that benefit from warm starts and approximate solutions.

## Setting Up the Environment

```python
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, lax
from jax import random
import numpy as np
import matplotlib.pyplot as plt
from typing import Tuple, Callable
import functools

# Switch JAX to 64-bit floats: the solver demos below request tolerances as
# tight as 1e-10, which single precision cannot resolve.
jax.config.update("jax_enable_x64", True)
print(f"JAX version: {jax.__version__}")
```

## Conjugate Gradient Method

### Basic CG Implementation

```python
def conjugate_gradient(A, b, x0=None, max_iters=None, tol=1e-6):
    """Solve ``A x = b`` for symmetric positive-definite ``A`` with CG.

    Args:
        A: Square system matrix (only ``A @ v`` products are used).
        b: Right-hand side vector.
        x0: Starting guess; a zero vector of length ``len(b)`` when omitted.
        max_iters: Iteration cap; ``len(b)`` when omitted.
        tol: Iteration stops once the residual 2-norm is no longer above this.

    Returns:
        Tuple ``(x, iterations)``: the final iterate and the number of CG
        steps that were executed.
    """
    size = len(b)
    x0 = jnp.zeros(size) if x0 is None else x0
    max_iters = size if max_iters is None else max_iters

    def keep_going(carry):
        # Continue while the residual is still large and the budget remains.
        _, resid, _, steps_done = carry
        return (jnp.linalg.norm(resid) > tol) & (steps_done < max_iters)

    def advance(carry):
        x_cur, resid, direction, steps_done = carry

        # Exact line minimisation along the current search direction.
        A_dir = A @ direction
        step = jnp.dot(resid, resid) / jnp.dot(direction, A_dir)

        x_next = x_cur + step * direction
        resid_next = resid - step * A_dir

        # Fletcher-Reeves style coefficient keeps directions A-conjugate.
        ratio = jnp.dot(resid_next, resid_next) / jnp.dot(resid, resid)
        direction_next = resid_next + ratio * direction

        return x_next, resid_next, direction_next, steps_done + 1

    # The first search direction is the initial residual itself.
    start_resid = b - A @ x0
    start_carry = (x0, start_resid, start_resid, 0)

    solution, _, _, steps_taken = lax.while_loop(keep_going, advance, start_carry)
    return solution, steps_taken

# Test CG method
key = random.PRNGKey(42)
n = 100

# Derive independent subkeys instead of sampling from `key` and then splitting
# it again: a key that is split should not also be consumed for a draw.
# (key_b equals random.split(key)[1], so `b` is unchanged by this.)
key_A, key_b = random.split(key)

# Create SPD matrix: G @ G.T is symmetric positive semi-definite and adding
# the identity shifts every eigenvalue up by one, making it positive definite.
A_base = random.normal(key_A, (n, n))
A = A_base @ A_base.T + jnp.eye(n)
b = random.normal(key_b, (n,))

# Solve with CG
x_cg, iters = conjugate_gradient(A, b, tol=1e-10)

# Verify solution
residual = jnp.linalg.norm(A @ x_cg - b)
print(f"CG iterations: {iters}")
print(f"Residual norm: {residual:.2e}")

# Compare with direct solve
x_direct = jnp.linalg.solve(A, b)
solution_error = jnp.linalg.norm(x_cg - x_direct)
print(f"Solution error vs direct: {solution_error:.2e}")
```

### Preconditioned CG

```python
def preconditioned_cg(A, b, M=None, x0=None, max_iters=None, tol=1e-6):
    """Solve ``A x = b`` (``A`` SPD) with preconditioned conjugate gradients.

    The loop is a fixed-length ``lax.scan`` of ``max_iters`` steps; once the
    residual norm drops to ``tol`` the remaining steps leave the state
    untouched, so the returned history has a static shape.

    Args:
        A: Square SPD system matrix.
        b: Right-hand side vector.
        M: Preconditioner matrix; each iteration solves ``M z = r``. The
            identity (no preconditioning) is used when omitted.
        x0: Starting guess; zeros when omitted.
        max_iters: Number of scan steps (Python int); ``len(b)`` when omitted.
        tol: Absolute threshold on the residual 2-norm.

    Returns:
        Tuple ``(x, iterations, residuals)`` where ``iterations`` counts the
        PCG steps that were actually performed and ``residuals`` has shape
        ``(max_iters,)`` holding the residual norm after every scan step
        (the converged value is repeated once iteration has stopped).
    """
    n = len(b)
    if max_iters is None:
        max_iters = n
    if M is None:
        M = jnp.eye(n)  # No preconditioning

    # Work in one floating dtype so both lax.cond branches and the scan carry
    # have identical types regardless of how the inputs were created.
    if x0 is None:
        dtype = jnp.result_type(A, b, M, 1.0)
        x0 = jnp.zeros(n, dtype=dtype)
    else:
        dtype = jnp.result_type(A, b, M, x0, 1.0)
        x0 = jnp.asarray(x0, dtype=dtype)
    A = jnp.asarray(A, dtype=dtype)
    b = jnp.asarray(b, dtype=dtype)
    M = jnp.asarray(M, dtype=dtype)

    def apply_preconditioner(r):
        # Solve M z = r.
        return jnp.linalg.solve(M, r)

    def pcg_step(state):
        # z is carried in the state so the preconditioner is applied only
        # once per iteration.
        x, r, z, p, _ = state

        Ap = A @ p
        rz = jnp.dot(r, z)
        alpha = rz / jnp.dot(p, Ap)

        x_new = x + alpha * p
        r_new = r - alpha * Ap

        z_new = apply_preconditioner(r_new)
        beta = jnp.dot(r_new, z_new) / rz
        p_new = z_new + beta * p

        return x_new, r_new, z_new, p_new, jnp.linalg.norm(r_new)

    # Initial state: the first search direction is the preconditioned residual.
    r0 = b - A @ x0
    z0 = apply_preconditioner(r0)
    res0 = jnp.linalg.norm(r0)
    initial_state = (x0, r0, z0, z0, res0)

    def scan_body(carry, _):
        state, converged, iter_count = carry

        # Freeze the state once converged; the residual norm travels inside
        # the state so the skip branch can return it unchanged.
        new_state = lax.cond(converged, lambda s: s, pcg_step, state)
        residual_norm = new_state[4]

        # Only count steps that actually ran.
        stepped = jnp.logical_not(converged).astype(iter_count.dtype)
        new_carry = (new_state, converged | (residual_norm <= tol), iter_count + stepped)
        return new_carry, residual_norm

    init_carry = (initial_state, res0 <= tol, jnp.zeros((), dtype=jnp.int32))
    final_carry, residuals = lax.scan(scan_body, init_carry, None, length=max_iters)

    final_state, _, final_iters = final_carry
    return final_state[0], final_iters, residuals

# Jacobi preconditioning: M keeps only the main diagonal of A
main_diagonal = jnp.diag(A)
diag_A = jnp.diag(main_diagonal)
pcg_outputs = preconditioned_cg(A, b, tol=1e-10, M=diag_A)
x_pcg, iters_pcg, residuals = pcg_outputs

pcg_final_residual = jnp.linalg.norm(A @ x_pcg - b)
print(f"Preconditioned CG iterations: {iters_pcg}")
print(f"Final residual: {pcg_final_residual:.2e}")
```

## GMRES (Generalized Minimal Residual)

### GMRES Implementation

```python
def gmres(A, b, x0=None, m=None, max_iters=100, tol=1e-6):
    """Restarted GMRES(m) for a general square system ``A x = b``.

    Each cycle builds an ``m``-dimensional Krylov basis with the Arnoldi
    process (modified Gram-Schmidt) and picks the update that minimises the
    residual over that subspace. The outer loop is plain Python, so this
    function is meant to be called eagerly (not under ``jit``).

    Args:
        A: Square system matrix.
        b: Right-hand side vector.
        x0: Starting guess; zeros when omitted.
        m: Krylov subspace size per restart cycle; ``min(len(b), 50)`` when
            omitted.
        max_iters: Controls the number of restart cycles, which is
            ``max_iters // m + 1``.
        tol: Absolute threshold on the residual 2-norm.

    Returns:
        Tuple ``(x, total_iters)`` where ``total_iters`` is ``m`` times the
        number of cycles that were run.
    """
    n = len(b)
    if m is None:
        m = min(n, 50)  # Restart parameter

    # Use one floating dtype throughout so loop carries keep a fixed type.
    if x0 is None:
        x0 = jnp.zeros(n, dtype=jnp.result_type(A, b, 1.0))
    else:
        x0 = jnp.asarray(x0, dtype=jnp.result_type(A, b, x0, 1.0))

    def arnoldi_process(A, v, m):
        """Build an orthonormal Krylov basis V and Hessenberg matrix H."""
        n = len(v)
        eps = float(jnp.finfo(v.dtype).eps)
        V = jnp.zeros((n, m + 1), dtype=v.dtype)
        H = jnp.zeros((m + 1, m), dtype=v.dtype)
        V = V.at[:, 0].set(v / jnp.linalg.norm(v))

        def arnoldi_step(i, carry):
            V, H = carry
            w = A @ V[:, i]
            w_scale = jnp.linalg.norm(w)

            def orthogonalize(j, inner):
                # Modified Gram-Schmidt against basis vector j. `i` is a
                # traced loop index, so a Python `range(i + 1)` cannot be
                # used; run over all m columns and mask those beyond i.
                w, H = inner
                v_j = V[:, j]
                h_ji = jnp.where(j <= i, jnp.dot(w, v_j), 0.0)
                return w - h_ji * v_j, H.at[j, i].set(h_ji)

            w, H = lax.fori_loop(0, m, orthogonalize, (w, H))

            # Happy breakdown: the Krylov space stopped growing. Judge it
            # relative to |A v_i| and store an exact zero vector so later
            # columns of H are zero instead of amplified round-off.
            h_next = jnp.linalg.norm(w)
            alive = h_next > 100.0 * eps * w_scale
            v_next = jnp.where(alive, w / jnp.where(alive, h_next, 1.0), 0.0)

            H = H.at[i + 1, i].set(jnp.where(alive, h_next, 0.0))
            V = V.at[:, i + 1].set(v_next)
            return (V, H)

        V, H = lax.fori_loop(0, m, arnoldi_step, (V, H))
        return V, H

    def gmres_solve(A, b, x0, m, tol):
        """Run one restart cycle starting from x0."""
        r0 = b - A @ x0
        beta = jnp.linalg.norm(r0)

        if beta < tol:
            return x0, 0

        V, H = arnoldi_process(A, r0, m)

        # Minimise ||beta * e1 - H y||. lstsq returns the minimum-norm
        # solution, which stays finite when H is rank deficient after a
        # breakdown (a triangular solve would divide by zero there).
        rhs = jnp.zeros(m + 1, dtype=H.dtype).at[0].set(beta)
        y = jnp.linalg.lstsq(H, rhs)[0]

        return x0 + V[:, :m] @ y, m

    # GMRES with restart
    x = x0
    total_iters = 0

    for _ in range(max_iters // m + 1):
        x, iters = gmres_solve(A, b, x, m, tol)
        total_iters += iters

        if jnp.linalg.norm(b - A @ x) < tol:
            break

    return x, total_iters

# Test GMRES
# Derive fresh subkeys from `key` rather than sampling with `key` itself, so
# this problem is statistically independent of any other draw made from it.
key_A_nonsym, key_b_nonsym = random.split(random.fold_in(key, 1))

A_nonsym = random.normal(key_A_nonsym, (50, 50))  # Non-symmetric matrix
# Make strictly diagonally dominant: raise each diagonal entry by the absolute
# sum of its row plus one, so |a_ii| exceeds the off-diagonal row sum.
row_abs_sums = jnp.sum(jnp.abs(A_nonsym), axis=1)
A_nonsym = A_nonsym + jnp.diag(row_abs_sums + 1.0)
b_nonsym = random.normal(key_b_nonsym, (50,))

x_gmres, iters_gmres = gmres(A_nonsym, b_nonsym, m=20, tol=1e-8)

print(f"GMRES iterations: {iters_gmres}")
print(f"GMRES residual: {jnp.linalg.norm(A_nonsym @ x_gmres - b_nonsym):.2e}")
```

## Gradient Descent Methods

### Steepest Descent

```python
def gradient_descent(A, b, x0=None, learning_rate=None, max_iters=1000, tol=1e-6):
    """Fixed-step gradient descent on the least-squares form of ``A x = b``.

    Every step moves along ``-A.T @ (A x - b)``. The scan always executes
    ``max_iters`` steps; convergence is determined afterwards from the
    recorded residual history.

    Args:
        A: System matrix.
        b: Right-hand side vector.
        x0: Starting guess; zeros when omitted.
        learning_rate: Step size. When omitted it is set to
            ``2 / (lambda_max + lambda_min)`` of ``A.T @ A``.
        max_iters: Number of steps to run (Python int).
        tol: Threshold on ``||A x - b||`` used to report the iteration count.

    Returns:
        Tuple ``(x, iterations, residuals)``: the iterate after ``max_iters``
        steps, the index of the first recorded residual below ``tol`` (or
        ``max_iters`` if none is), and the residual norms, where
        ``residuals[i]`` is measured before step ``i`` is applied.
    """
    n = len(b)

    # Keep the scan carry in one floating dtype.
    if x0 is None:
        x0 = jnp.zeros(n, dtype=jnp.result_type(A, b, 1.0))
    else:
        x0 = jnp.asarray(x0, dtype=jnp.result_type(A, b, x0, 1.0))

    # Optimal learning rate for quadratic: 2 / (λ_max + λ_min)
    if learning_rate is None:
        # A.T @ A is symmetric, so use the symmetric eigen-solver: it returns
        # real eigenvalues. The general `eigvals` returns complex values,
        # which would make the step size and the scan carry complex.
        eigvals = jnp.linalg.eigvalsh(A.T @ A)
        learning_rate = 2.0 / (jnp.max(eigvals) + jnp.min(eigvals))

    def grad_step(x, _):
        residual = A @ x - b
        gradient = A.T @ residual
        return x - learning_rate * gradient, jnp.linalg.norm(residual)

    x, residuals = lax.scan(grad_step, x0, None, length=max_iters)

    # Find convergence point
    below_tol = residuals < tol
    actual_iters = jnp.where(jnp.any(below_tol), jnp.argmax(below_tol), max_iters)

    return x, actual_iters, residuals

# Run fixed-step gradient descent on the SPD test system (A, b)
gd_outputs = gradient_descent(A, b, tol=1e-8, max_iters=500)
x_gd, iters_gd, residuals_gd = gd_outputs

gd_final_residual = jnp.linalg.norm(A @ x_gd - b)
print(f"Gradient descent iterations: {iters_gd}")
print(f"Final residual: {gd_final_residual:.2e}")
```

### Conjugate Gradient vs Gradient Descent Comparison

```python
def compare_methods():
    """Print CG and gradient-descent results for SPD systems of growing condition number.

    For each target condition number a 30x30 matrix ``U diag(s) U^T`` is built
    from one shared random orthogonal ``U`` and log-spaced values ``s`` in
    ``[1, cond_num]``; both solvers are run on it against one shared
    right-hand side, and their iteration counts and final residual norms are
    printed. Reads the module-level PRNG ``key``.
    """
    condition_numbers = [1e1, 1e3, 1e5]

    # Use separate subkeys for the basis and the right-hand side; drawing both
    # from the same key would couple them. Neither depends on the condition
    # number, so they are generated once outside the loop.
    key_basis, key_rhs = random.split(random.fold_in(key, 2))
    U = random.orthogonal(key_basis, 30)
    b_test = random.normal(key_rhs, (30,))

    for cond_num in condition_numbers:
        print(f"\nCondition number: {cond_num:.0e}")

        # Create matrix with specific condition number
        singular_values = jnp.logspace(0, jnp.log10(cond_num), 30)
        A_test = U @ jnp.diag(singular_values) @ U.T

        # CG
        x_cg, iters_cg = conjugate_gradient(A_test, b_test, tol=1e-8, max_iters=100)

        # Gradient descent
        x_gd, iters_gd, _ = gradient_descent(A_test, b_test, tol=1e-8, max_iters=500)

        print(f"  CG iterations: {iters_cg}")
        print(f"  GD iterations: {iters_gd}")
        print(f"  CG residual: {jnp.linalg.norm(A_test @ x_cg - b_test):.2e}")
        print(f"  GD residual: {jnp.linalg.norm(A_test @ x_gd - b_test):.2e}")

compare_methods()
```

## Newton-Type Methods

### Newton-Raphson for Nonlinear Systems

```python
def newton_raphson_system(F, J, x0, max_iters=20, tol=1e-8):
    """Find a root of ``F`` with Newton's method.

    The scan always performs ``max_iters`` Newton updates; the iteration
    count is read off the recorded residual history afterwards.

    Args:
        F: Callable mapping a vector ``x`` to the residual vector ``F(x)``.
        J: Callable returning the Jacobian matrix of ``F`` at ``x``.
        x0: Starting point.
        max_iters: Number of Newton updates to run.
        tol: Threshold on ``||F(x)||`` used to report the iteration count.

    Returns:
        Tuple ``(x, iterations)``: the iterate after ``max_iters`` updates and
        the index of the first recorded ``||F(x)||`` below ``tol`` (or
        ``max_iters`` if none is). Each recorded norm is taken before the
        corresponding update.
    """

    def iterate(current, _unused):
        residual_vec = F(current)
        jac = J(current)

        # Newton correction: J(x) * delta = -F(x)
        correction = jnp.linalg.solve(jac, -residual_vec)
        return current + correction, jnp.linalg.norm(residual_vec)

    step_indices = jnp.arange(max_iters)
    root, residual_history = lax.scan(iterate, x0, step_indices)

    # First index whose residual is under tol, falling back to max_iters.
    below_tol = residual_history < tol
    first_hit = jnp.argmax(below_tol)
    n_iters = jnp.where(jnp.any(below_tol), first_hit, max_iters)

    return root, n_iters

# Example system: the unit circle x^2 + y^2 = 1 intersected with the line x - y = 0.5
def nonlinear_system(x):
    """Return the residuals of the circle and line equations at ``x = (x, y)``."""
    first, second = x[0], x[1]
    circle_residual = first**2 + second**2 - 1
    line_residual = first - second - 0.5
    return jnp.array([circle_residual, line_residual])

def jacobian(x):
    """Return the 2x2 Jacobian of the circle/line system at ``x = (x, y)``."""
    first, second = x[0], x[1]
    circle_row = [2*first, 2*second]  # d/dx, d/dy of x^2 + y^2 - 1
    line_row = [1, -1]                # d/dx, d/dy of x - y - 0.5
    return jnp.array([circle_row, line_row])

# Start from (1, 0), a point on the unit circle
x0 = jnp.array([1.0, 0.0])
newton_outputs = newton_raphson_system(nonlinear_system, jacobian, x0)
x_newton, iters_newton = newton_outputs

newton_final_residual = jnp.linalg.norm(nonlinear_system(x_newton))
print(f"Newton iterations: {iters_newton}")
print(f"Solution: {x_newton}")
print(f"Residual: {newton_final_residual:.2e}")
```

## Optimization Methods

### L-BFGS

```python
def lbfgs(objective, grad_fn, x0, m=10, max_iters=100, tol=1e-6):
    """Minimise ``objective`` with limited-memory BFGS and backtracking line search.

    The loop is plain Python (it branches on concrete values), so this
    function is meant to be called eagerly rather than under ``jit``.

    Args:
        objective: Callable returning the scalar objective at ``x``.
        grad_fn: Callable returning the gradient of ``objective`` at ``x``.
        x0: Starting point.
        m: Maximum number of stored correction pairs.
        max_iters: Maximum number of quasi-Newton iterations.
        tol: Iteration stops when the gradient 2-norm is below this value.

    Returns:
        Tuple ``(x, iterations)``: the final iterate and the number of
        iterations that were completed.
    """

    def two_loop_direction(history, grad_vec):
        """Return -H @ grad_vec using the stored (s, y, rho) pairs, oldest first."""
        q = grad_vec

        # First loop: newest pair to oldest.
        alphas = []
        for s_i, y_i, rho_i in reversed(history):
            alpha_i = rho_i * jnp.dot(s_i, q)
            alphas.append(alpha_i)
            q = q - alpha_i * y_i

        # Scale by the initial inverse-Hessian estimate from the newest pair.
        if history:
            s_last, y_last, _ = history[-1]
            q = (jnp.dot(s_last, y_last) / jnp.dot(y_last, y_last)) * q

        # Second loop: oldest pair to newest. `alphas` was filled newest-first,
        # so it is reversed to pair each alpha with the pair that produced it.
        for (s_i, y_i, rho_i), alpha_i in zip(history, reversed(alphas)):
            beta_i = rho_i * jnp.dot(y_i, q)
            q = q + s_i * (alpha_i - beta_i)

        return -q  # Return search direction

    c1 = 1e-4  # Armijo sufficient-decrease constant
    x = x0
    grad_prev = grad_fn(x)
    history = []  # (s, y, rho) correction pairs, oldest first
    iterations = 0

    for _ in range(max_iters):
        if jnp.linalg.norm(grad_prev) < tol:
            break

        # Compute search direction
        p = two_loop_direction(history, grad_prev)
        grad_dot_p = jnp.dot(grad_prev, p)

        # Safeguard: if round-off produced a non-descent (or NaN) direction,
        # drop the history and fall back to steepest descent.
        if not bool(grad_dot_p < 0):
            history = []
            p = -grad_prev
            grad_dot_p = jnp.dot(grad_prev, p)

        # Line search (simple backtracking)
        alpha = 1.0
        f_current = objective(x)
        for _ in range(20):  # Max line search steps
            if objective(x + alpha * p) <= f_current + c1 * alpha * grad_dot_p:
                break
            alpha *= 0.5

        x_new = x + alpha * p
        grad_new = grad_fn(x_new)

        # Store the new pair only when the curvature condition y.s > 0 holds;
        # otherwise rho = 1 / (y.s) would be infinite or negative and the
        # implicit Hessian approximation would stop being positive definite.
        s = x_new - x
        y = grad_new - grad_prev
        curvature = jnp.dot(y, s)
        if bool(curvature > 1e-10 * jnp.linalg.norm(s) * jnp.linalg.norm(y)):
            history.append((s, y, 1.0 / curvature))
            if len(history) > m:
                del history[0]

        x = x_new
        grad_prev = grad_new
        iterations += 1

    return x, iterations

# Rosenbrock function used as the L-BFGS test objective
def rosenbrock(x):
    """Evaluate ``(1 - x0)^2 + 100 * (x1 - x0^2)^2`` for a 2-vector ``x``."""
    offset_term = (1 - x[0])**2
    valley_term = 100 * (x[1] - x[0]**2)**2
    return offset_term + valley_term

# Gradient of the objective via automatic differentiation
rosenbrock_grad = grad(rosenbrock)

x0_opt = jnp.array([-1.0, 1.0])
lbfgs_outputs = lbfgs(rosenbrock, rosenbrock_grad, x0_opt)
x_lbfgs, iters_lbfgs = lbfgs_outputs

lbfgs_final_value = rosenbrock(x_lbfgs)
print(f"L-BFGS iterations: {iters_lbfgs}")
print(f"Solution: {x_lbfgs}")
print(f"Function value: {lbfgs_final_value:.2e}")
```

## Summary

In this notebook, we've implemented and compared various iterative solvers:

**Linear Solvers:**
1. **Conjugate Gradient**: Method of choice for SPD systems; the number of iterations needed grows like O(√κ), where κ is the condition number of A
2. **Preconditioned CG**: Improved convergence with preconditioning  
3. **GMRES**: General non-symmetric systems, restart for memory efficiency
4. **Gradient Descent**: Simple but slower convergence

**Nonlinear Solvers:**
1. **Newton-Raphson**: Quadratic convergence for well-behaved systems
2. **L-BFGS**: Quasi-Newton method for optimization problems

**Key Insights:**
- CG outperforms gradient descent on SPD systems, and the gap widens as the condition number κ grows (roughly O(√κ) iterations for CG versus O(κ) for gradient descent)
- Preconditioning dramatically improves convergence rates
- GMRES handles non-symmetric systems effectively
- L-BFGS balances memory usage with Newton-like convergence

**Performance Considerations:**
- Condition number strongly affects convergence rates
- Restarts in GMRES prevent memory growth
- A backtracking line search enforces sufficient decrease at each step, which makes the optimizer far more robust to poorly scaled search directions
- JAX's autodiff enables easy implementation of gradients and Jacobians

**Next Steps:**
- The next notebook covers numerical stability considerations
- Understanding when to use direct vs iterative methods
- Implementing custom preconditioners for specific problem types

Iterative methods are essential for large-scale computational problems where direct methods become prohibitively expensive.