# File: notebooks/02_linear_algebra/07_numerical_stability.ipynb

## JAX Linear Algebra: Numerical Stability

This notebook explores numerical stability in linear algebra computations. We'll cover condition numbers, numerical precision, stable algorithms, and techniques for handling ill-conditioned problems. Understanding numerical stability is crucial for reliable scientific computing and machine learning applications.

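Two related ideas are at play:

- **Conditioning** is a property of the problem. It describes whether small perturbations in the input data lead to small changes in the output.
- **Numerical stability** is a property of the algorithm. It describes whether the round-off errors the algorithm introduces stay within what the problem's conditioning already implies.

Together they determine whether a computation remains reliable in the presence of round-off errors and data uncertainty.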

## Setting Up the Environment

```python
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, lax
from jax import random
import numpy as np
import matplotlib.pyplot as plt
from typing import Tuple, Optional
import functools

# JAX computes in float32 unless double precision is opted into; the
# ill-conditioned examples in this notebook need float64, so switch it on
# here, before any arrays are created.
jax.config.update("jax_enable_x64", True)
print("JAX version: " + jax.__version__)
```

## Understanding Condition Numbers

### Matrix Condition Number Analysis

```python
def condition_number_analysis():
    """Show how the condition number amplifies right-hand-side perturbations.

    For each target condition number κ, build ``A = U diag(σ) Vᵀ`` with random
    orthogonal ``U``, ``V`` and singular values ``σ`` log-spaced in ``[1, κ]``.
    Then solve ``A x = b`` once for the exact right-hand side and once for a
    slightly perturbed one. Perturbation theory gives

        ||Δx|| / ||x||  <=  κ(A) · ||Δb|| / ||b||,

    so the printed amplification factor should not exceed κ(A), apart from
    round-off. The amplification factor is the relative change of x divided
    by the relative change of b.
    """
    key = random.PRNGKey(42)
    n = 10

    # Create matrices with different condition numbers
    condition_numbers = [1e1, 1e3, 1e6, 1e9, 1e12]

    # Give every random draw its own subkey. Reusing one key for U and
    # x_true, or for V and the perturbation, would correlate them.
    key_u, key_v, key_x, key_b = random.split(key, 4)

    # Shared by all cases so that only the singular-value spread changes.
    U = random.orthogonal(key_u, n)
    V = random.orthogonal(key_v, n)
    x_true = random.normal(key_x, (n,))
    direction = random.normal(key_b, (n,))
    direction = direction / jnp.linalg.norm(direction)

    print("Condition Number Analysis:")
    print("=" * 50)

    for cond_target in condition_numbers:
        # Logarithmically spaced singular values from 1 to cond_target
        singular_values = jnp.logspace(0, jnp.log10(cond_target), n)
        A = U @ jnp.diag(singular_values) @ V.T
        actual_cond = jnp.linalg.cond(A)

        b = A @ x_true

        # Perturb b by a fixed *relative* amount, about 100 ulps. An absolute
        # 1e-14 would be rounded away entirely once ||b|| is large, leaving
        # b_perturbed == b.
        rel_size = 100 * jnp.finfo(b.dtype).eps
        b_perturbed = b + rel_size * jnp.linalg.norm(b) * direction
        # Measure the perturbation that actually survived rounding.
        delta_b = b_perturbed - b

        # Solve both systems
        x_computed = jnp.linalg.solve(A, b)
        x_perturbed = jnp.linalg.solve(A, b_perturbed)

        rel_error = jnp.linalg.norm(x_computed - x_true) / jnp.linalg.norm(x_true)
        rel_input_change = jnp.linalg.norm(delta_b) / jnp.linalg.norm(b)
        rel_output_change = (jnp.linalg.norm(x_perturbed - x_computed)
                             / jnp.linalg.norm(x_computed))
        amplification = rel_output_change / rel_input_change

        print(f"Target κ: {cond_target:.0e}, Actual κ: {actual_cond:.2e}")
        print(f"  Relative solve error: {rel_error:.2e}")
        print(f"  Perturbation amplification: {amplification:.2e}")
        # The amplification factor itself is bounded by κ(A).
        print(f"  Theoretical bound (κ): {actual_cond:.2e}")
        print()

condition_number_analysis()
```

### Hilbert Matrix Example

```python
def hilbert_matrix_stability():
    """Solve Hilbert systems of increasing size with three different solvers.

    The Hilbert matrix ``H[i, j] = 1 / (i + j + 1)`` becomes numerically
    singular very quickly as n grows. For each size the right-hand side is
    ``H @ ones(n)``, so the exact solution is known. The function prints the
    absolute error of:

    * a direct LU solve,
    * a least-squares solve,
    * an SVD solve that zeroes singular values at or below ``1e-12``.
    """

    def hilbert_matrix(size):
        """Return the size×size Hilbert matrix with entries 1 / (i + j + 1)."""
        rows, cols = jnp.meshgrid(jnp.arange(size), jnp.arange(size))
        return 1.0 / (rows + cols + 1)

    def svd_truncated_solve(mat, rhs, cutoff=1e-12):
        """Pseudo-inverse solve, discarding singular values <= cutoff."""
        left, sigma, right_t = jnp.linalg.svd(mat)
        inv_sigma = jnp.where(sigma > cutoff, 1 / sigma, 0)
        projected = inv_sigma[:, None] * (left.T @ rhs[:, None])
        return right_t.T @ projected[:, 0]

    print("Hilbert Matrix Stability Analysis:")
    print("=" * 40)

    for size in (5, 8, 10, 12):
        H = hilbert_matrix(size)
        x_exact = jnp.ones(size)
        rhs = H @ x_exact

        # Insertion order fixes the order of the printed lines.
        solutions = {
            "Direct solve": jnp.linalg.solve(H, rhs),
            "Least squares": jnp.linalg.lstsq(H, rhs, rcond=None)[0],
            "SVD solve": svd_truncated_solve(H, rhs),
        }

        print(f"n = {size}: κ(H) = {jnp.linalg.cond(H):.2e}")
        for label, x_hat in solutions.items():
            print(f"  {label} error: {jnp.linalg.norm(x_hat - x_exact):.2e}")
        print()

hilbert_matrix_stability()
```

## Stable Algorithms for Common Operations

### Numerically Stable Summation

```python
def kahan_sum(arr):
    """Return the Kahan (compensated) sum of the 1-D array ``arr``.

    A compensation term carries the low-order bits lost in each addition into
    the next one. As a result the accumulated error stays at a few ulps
    instead of growing with the number of terms, as it can in naive
    left-to-right summation. An empty array sums to 0.
    """
    arr = jnp.asarray(arr)

    def step(carry, x):
        total, comp = carry
        y = x - comp
        t = total + y
        # (t - total) is what was actually added; subtracting y leaves the
        # rounding error of this addition, which is removed next step.
        comp = (t - total) - y
        return (t, comp), None  # no per-step output is needed

    zero = jnp.zeros((), dtype=arr.dtype)
    (total, _), _ = lax.scan(step, (zero, zero), arr)
    return total


def sequential_sum(arr):
    """Sum ``arr`` strictly left to right.

    ``jnp.sum`` lowers to an XLA reduction whose evaluation order is not
    specified, so it cannot be used to demonstrate order effects.
    """
    arr = jnp.asarray(arr)
    total, _ = lax.scan(lambda acc, x: (acc + x, None),
                        jnp.zeros((), dtype=arr.dtype), arr)
    return total


def magnitude_sorted_sum(arr):
    """Sum ``arr`` left to right in order of increasing absolute value."""
    arr = jnp.asarray(arr)
    # Reorder the values themselves (signs included) by magnitude.
    return sequential_sum(arr[jnp.argsort(jnp.abs(arr))])


def stable_summation():
    """Compare standard, Kahan and magnitude-sorted summation.

    The data mixes 1000 values near 1e10 with 1000 tiny values of about
    1e-10, in random order. ``math.fsum`` supplies the correctly rounded sum
    of the stored values, and each method's absolute error is reported
    against it.
    """
    import math

    # Create data that challenges floating-point arithmetic
    key = random.PRNGKey(0)
    key_large, key_small, key_perm = random.split(key, 3)

    # Large numbers with small differences, plus tiny values
    large_vals = 1e10 + random.normal(key_large, (1000,)) * 1e-5
    small_vals = random.normal(key_small, (1000,)) * 1e-10

    mixed_data = jnp.concatenate([large_vals, small_vals])
    mixed_data = random.permutation(key_perm, mixed_data)

    exact = math.fsum(mixed_data.tolist())

    results = {
        "Standard sum": jnp.sum(mixed_data),
        "Kahan sum": kahan_sum(mixed_data),
        "Sorted sum": magnitude_sorted_sum(mixed_data),
    }

    print("Stable Summation Comparison:")
    print(f"Exact sum (math.fsum): {exact:.10e}")
    for label, value in results.items():
        value = float(value)
        print(f"{label}: {value:.10e} (abs. error {abs(value - exact):.2e})")
    difference = float(results["Kahan sum"]) - float(results["Standard sum"])
    print(f"Difference (Kahan - Standard): {difference:.2e}")

stable_summation()
```

### Stable Log-Sum-Exp and Softmax

```python
def logsumexp_unstable(x):
    """Evaluate log(sum(exp(x))) literally.

    This overflows once any entry exceeds about 709 in float64 (about 88 in
    float32).
    """
    return jnp.log(jnp.sum(jnp.exp(x)))


def logsumexp_stable(x):
    """Overflow-safe log(sum(exp(x))) over all elements of ``x``.

    Uses the identity logsumexp(x) = m + log(sum(exp(x - m))) with m = max(x),
    so the largest exponential evaluated is exp(0) = 1. If m is not finite
    (for example, every entry is -inf), the shift is replaced by 0 to avoid
    inf - inf = nan. The result is then -inf or +inf as appropriate.
    """
    x = jnp.asarray(x)
    x_max = jnp.max(x)
    shift = jnp.where(jnp.isfinite(x_max), x_max, 0.0)
    return shift + jnp.log(jnp.sum(jnp.exp(x - shift)))


def softmax_unstable(x):
    """Evaluate exp(x) / sum(exp(x)) literally; large inputs produce inf / inf = nan."""
    exp_x = jnp.exp(x)
    return exp_x / jnp.sum(exp_x)


def softmax_stable(x):
    """Softmax of ``x`` after shifting by max(x), so no exponential exceeds 1."""
    x = jnp.asarray(x)
    exp_x = jnp.exp(x - jnp.max(x))
    return exp_x / jnp.sum(exp_x)


def stable_matrix_operations():
    """Contrast naive and max-shifted log-sum-exp / softmax on large inputs.

    Under default settings, JAX follows IEEE-754 and does not raise on
    overflow: exp() of a too-large argument returns inf. Overflow is therefore
    detected by checking the result for non-finite values rather than with
    try/except.
    """
    # exp(800) and exp(900) overflow even float64
    large_values = jnp.array([700.0, 800.0, 900.0])

    print("Log-Sum-Exp Stability:")
    unstable_result = logsumexp_unstable(large_values)
    if bool(jnp.isfinite(unstable_result)):
        print(f"Unstable result: {unstable_result}")
    else:
        print(f"Unstable version: overflow/underflow (got {unstable_result})")

    stable_result = logsumexp_stable(large_values)
    jax_result = jax.nn.logsumexp(large_values)
    print(f"Stable result: {stable_result:.6f}")
    print(f"JAX result: {jax_result:.6f}")

    # Test softmax
    large_logits = jnp.array([1000.0, 1001.0, 999.0])

    print("\nSoftmax Stability:")
    unstable_softmax = softmax_unstable(large_logits)
    if bool(jnp.all(jnp.isfinite(unstable_softmax))):
        print(f"Unstable softmax: {unstable_softmax}")
    else:
        print(f"Unstable softmax: overflow (got {unstable_softmax})")

    stable_softmax = softmax_stable(large_logits)
    jax_softmax = jax.nn.softmax(large_logits)
    print(f"Stable softmax: {stable_softmax}")
    print(f"JAX softmax: {jax_softmax}")

stable_matrix_operations()
```

## Regularization Techniques

### Ridge Regression and Regularization

```python
def regularization_demo():
    """Show how ridge (Tikhonov) regularization tames a poorly conditioned solve.

    The function solves the ridge normal equations (XᵀX + λI) β = Xᵀy for
    several λ and prints:

    * cond(XᵀX + λI),
    * the training MSE,
    * the distance ||β - β_true||.

    Adding λI raises every eigenvalue of XᵀX by λ. The condition number
    (σ_max² + λ) / (σ_min² + λ) therefore shrinks as λ grows, at the price of
    a larger training error.
    """
    key = random.PRNGKey(42)
    # One independent subkey per random draw (never reuse the parent key).
    key_x, key_beta, key_noise = random.split(key, 3)

    # Nearly square design: 45 features vs. 50 samples. X has full column rank,
    # but its smallest singular value is small relative to its largest, and
    # forming XᵀX squares that ratio.
    n_samples, n_features = 50, 45
    X = random.normal(key_x, (n_samples, n_features))
    true_beta = random.normal(key_beta, (n_features,))
    y = X @ true_beta + 0.1 * random.normal(key_noise, (n_samples,))

    # Normal equations: X^T X β = X^T y
    XtX = X.T @ X
    Xty = X.T @ y

    print("Regularization for Ill-conditioned Problems:")
    print(f"Condition number of X^T X: {jnp.linalg.cond(XtX):.2e}")

    # Different regularization levels
    lambda_values = [0, 1e-6, 1e-3, 1e-1, 1.0]

    for lam in lambda_values:
        # Ridge regression: (X^T X + λI) β = X^T y
        regularized_matrix = XtX + lam * jnp.eye(n_features)
        beta_ridge = jnp.linalg.solve(regularized_matrix, Xty)

        # Compute metrics
        train_error = jnp.mean((X @ beta_ridge - y)**2)
        param_error = jnp.linalg.norm(beta_ridge - true_beta)
        cond_regularized = jnp.linalg.cond(regularized_matrix)

        print(f"λ = {lam:6.0e}: cond = {cond_regularized:8.2e}, "
              f"train_err = {train_error:.4f}, param_err = {param_error:.4f}")

regularization_demo()
```

### Singular Value Truncation

```python
def svd_regularization():
    """Denoise a low-rank matrix by truncating its SVD.

    A 20×15 matrix of rank exactly 5 is corrupted with Gaussian noise of
    scale 0.1. It is then reconstructed from its leading r singular triplets
    for several r. For each r the function reports the Frobenius distance to
    the clean matrix and the storage ratio r·(m + n) / (m·n).
    """
    key = random.PRNGKey(0)

    rank = 5
    m, n = 20, 15

    # Ground truth: the product of two Gaussian factors has rank `rank`.
    left_factor = random.normal(key, (m, rank))
    right_factor = random.normal(random.split(key)[1], (n, rank))
    A_clean = left_factor @ right_factor.T
    A_noisy = A_clean + 0.1 * random.normal(random.split(key, 3)[2], (m, n))

    print("SVD-based Regularization:")
    print(f"True rank: {rank}, Matrix shape: {A_noisy.shape}")

    U, s, Vt = jnp.linalg.svd(A_noisy, full_matrices=False)
    print(f"Singular values: {s[:8]}")

    def rank_r_approximation(r):
        """Keep only the r largest singular triplets of the noisy matrix."""
        return U[:, :r] @ jnp.diag(s[:r]) @ Vt[:r, :]

    for r in (3, 5, 7, 10, min(m, n)):
        err = jnp.linalg.norm(rank_r_approximation(r) - A_clean, 'fro')
        ratio = r * (m + n) / (m * n)
        print(f"Rank {r:2d}: error = {err:.4f}, "
              f"compression = {ratio:.3f}")

svd_regularization()
```

## Precision and Round-off Error Analysis

### Float32 vs Float64 Comparison

```python
def precision_comparison():
    """Solve the same ill-conditioned system in float32 and in float64.

    ``A = U diag(σ) Uᵀ``, with σ log-spaced in [1, 1e6] so that κ(A) ≈ 1e6,
    and ``b = A @ ones`` are built once in float64. The float32 run solves
    the rounded copy (``A`` and ``b`` cast to float32). Both solutions are
    then measured in float64:

    * forward error: ||x - 1||,
    * backward error: ||A x - b||, against the float64 problem.

    JAX's x64 mode is meant to be set at start-up. Rather than switching
    ``jax_enable_x64`` off and back on around the float32 run, which is
    fragile and leaves x64 disabled if an exception escapes in between, this
    keeps x64 enabled and selects the working precision with explicit dtypes.
    """
    # Keep double precision on; this matches the notebook-wide setting.
    jax.config.update("jax_enable_x64", True)

    key = random.PRNGKey(42)

    # Create a moderately ill-conditioned problem
    n = 20
    U = random.orthogonal(key, n)
    singular_values = jnp.logspace(0, 6, n)  # Condition number ~1e6
    A = U @ jnp.diag(singular_values) @ U.T

    x_true = jnp.ones(n)
    b = A @ x_true

    def solve_errors(dtype):
        """Solve in ``dtype``; return (forward, backward) errors in A's dtype."""
        x = jnp.linalg.solve(A.astype(dtype), b.astype(dtype)).astype(A.dtype)
        forward_error = jnp.linalg.norm(x - x_true)
        backward_error = jnp.linalg.norm(A @ x - b)
        return forward_error, backward_error

    print("Precision Comparison (Float32 vs Float64):")
    print("=" * 50)

    forward_32, backward_32 = solve_errors(jnp.float32)
    forward_64, backward_64 = solve_errors(jnp.float64)
    cond = jnp.linalg.cond(A)

    print(f"Matrix condition number: {cond:.2e}")
    print(f"Float32 - Forward error: {forward_32:.2e}, Backward error: {backward_32:.2e}")
    print(f"Float64 - Forward error: {forward_64:.2e}, Backward error: {backward_64:.2e}")
    print(f"Improvement factor: {forward_32/forward_64:.1f}x")

precision_comparison()
```

## Error Analysis and Bounds

### Forward and Backward Error Analysis

```python
def error_analysis():
    """Compare forward and backward error for a perturbed linear solve.

    For each target κ, build a symmetric ``A = U diag(σ) Uᵀ``, with σ
    log-spaced in [1, κ]. Then perturb ``b = A x_true`` by a relative amount
    of about 1e-12 and solve.

    * The backward error (relative residual against the *unperturbed* b)
      stays at the size of the data perturbation.
    * The forward error can grow up to κ times larger.

    "Bound tightness" is the forward error divided by that κ-bound.
    """
    key = random.PRNGKey(123)
    n = 15

    # Independent subkeys: drawing U and x_true from the same key correlates them.
    key_u, key_x, key_p = random.split(key, 3)

    # Shared by all cases so that only the singular-value spread changes.
    U = random.orthogonal(key_u, n)
    x_true = random.normal(key_x, (n,))
    noise = random.normal(key_p, (n,))

    # Create test problems with different condition numbers
    condition_numbers = [1e2, 1e6, 1e10]

    print("Forward vs Backward Error Analysis:")
    print("=" * 40)

    for cond_target in condition_numbers:
        # Create matrix with target condition number
        singular_values = jnp.logspace(0, jnp.log10(cond_target), n)
        A = U @ jnp.diag(singular_values) @ U.T

        b_exact = A @ x_true

        # Add small perturbation to simulate round-off / data error
        perturbation = 1e-12 * jnp.linalg.norm(b_exact) * noise
        b_perturbed = b_exact + perturbation

        # Solve perturbed system
        x_computed = jnp.linalg.solve(A, b_perturbed)

        # Forward error: ||x_computed - x_true|| / ||x_true||
        forward_error = jnp.linalg.norm(x_computed - x_true) / jnp.linalg.norm(x_true)

        # Backward error: ||A*x_computed - b|| / ||b||, measured against the
        # exact b, so it reflects the size of the data perturbation.
        backward_error = jnp.linalg.norm(A @ x_computed - b_exact) / jnp.linalg.norm(b_exact)

        # Theoretical bound: forward error <= κ · relative data perturbation
        actual_cond = jnp.linalg.cond(A)
        data_perturbation = jnp.linalg.norm(perturbation) / jnp.linalg.norm(b_exact)
        theoretical_bound = actual_cond * data_perturbation

        print(f"κ = {actual_cond:.2e}")
        print(f"  Data perturbation: {data_perturbation:.2e}")
        print(f"  Forward error: {forward_error:.2e}")
        print(f"  Backward error: {backward_error:.2e}")
        print(f"  Theoretical bound: {theoretical_bound:.2e}")
        print(f"  Bound tightness: {forward_error/theoretical_bound:.2f}")
        print()

error_analysis()
```

## Iterative Refinement

### Improving Solution Accuracy

```python
def iterative_refinement():
    """Mixed-precision iterative refinement of ``A x = b``.

    Classic iterative refinement factors A once, in a low working precision
    (float32 here), and then repeats three steps:

    1. compute the residual ``r = b - A x`` in higher precision (A's dtype,
       which is float64 when x64 is enabled);
    2. solve ``A δ = r`` with the existing low-precision LU factors;
    3. update ``x ← x + δ``.

    Computing the residual more accurately than the solve is what lets the
    iteration recover the digits the low-precision factorization lost.
    Repeating the solve in the same precision as the residual, and
    re-factoring A every time, gains essentially nothing.

    Iteration stops once the normwise backward error
    ||b - A x||∞ / (||A||∞ ||x||∞ + ||b||∞) falls below n·eps of A's dtype.
    """
    from jax.scipy.linalg import lu_factor, lu_solve

    key = random.PRNGKey(456)
    key_a, key_x = random.split(key)

    # Create moderately ill-conditioned system
    n = 20
    A = random.normal(key_a, (n, n))
    # A Aᵀ + 1e-8 I is symmetric positive definite; squaring A squares its
    # condition number.
    A = A @ A.T + 1e-8 * jnp.eye(n)

    x_true = random.normal(key_x, (n,))
    b = A @ x_true

    norm = jnp.linalg.norm
    A_norm = norm(A, jnp.inf)
    b_norm = norm(b, jnp.inf)

    def backward_error(x):
        """Normwise relative backward error of x, computed in A's dtype."""
        r = b - A @ x
        return norm(r, jnp.inf) / (A_norm * norm(x, jnp.inf) + b_norm)

    # Factor once in float32; every correction below reuses these factors.
    lu_and_piv = lu_factor(A.astype(jnp.float32))

    def solve_low(rhs):
        """Solve A y = rhs with the float32 LU factors; return y in A's dtype."""
        return lu_solve(lu_and_piv, rhs.astype(jnp.float32)).astype(A.dtype)

    print("Iterative Refinement:")
    print(f"Condition number: {jnp.linalg.cond(A):.2e}")
    print(f"Working precision: float32 LU, residual precision: {A.dtype}")

    # Initial (low-precision) solution
    x = solve_low(b)
    print(f"Initial error: {norm(x - x_true):.2e}, "
          f"backward error = {backward_error(x):.2e}")

    tol = n * jnp.finfo(A.dtype).eps
    for iteration in range(5):
        # Residual of the current iterate, in A's (higher) precision
        r = b - A @ x

        # Correction from the cheap low-precision factors
        x = x + solve_low(r)

        error = norm(x - x_true)
        eta = backward_error(x)
        print(f"Iter {iteration+1}: error = {error:.2e}, backward error = {eta:.2e}")

        if eta < tol:
            break

iterative_refinement()
```

## Practical Stability Guidelines

### Choosing Stable Algorithms

```python
def algorithm_stability_guide():
    """Print a cheat sheet of stable algorithm choices and a stability checklist.

    Each scenario name is dot-padded to 30 columns and followed by the
    recommended ordering of methods. Then come six checklist items, each
    prefixed with a check mark.
    """
    # dict preserves insertion order, which fixes the printed order.
    recommendations = {
        "Well-conditioned SPD system": "Cholesky > LU > QR",
        "Ill-conditioned system": "SVD with truncation > QR > LU",
        "Overdetermined system": "QR factorization > Normal equations",
        "Underdetermined system": "SVD > QR with pivoting",
        "Large sparse system": "Iterative methods with preconditioning",
        "Eigenvalue problems": "eigh for symmetric > eig for general",
        "Optimization": "L-BFGS > Newton > Gradient descent",
    }
    checklist = (
        "Check condition numbers before solving",
        "Use appropriate precision (float64 for ill-conditioned problems)",
        "Consider regularization for ill-posed problems",
        "Implement stable algorithms (avoid normal equations)",
        "Monitor residuals and backward error",
        "Use iterative refinement when necessary",
    )

    print("Algorithm Stability Guidelines:")
    print("=" * 35)
    for scenario in recommendations:
        print(f"{scenario:.<30} {recommendations[scenario]}")

    print("\nNumerical Stability Checklist:")
    for item in checklist:
        print("✓ " + item)

algorithm_stability_guide()
```

## Summary

In this notebook, we've explored critical aspects of numerical stability:

**Key Concepts:**

1. **Condition Numbers**: Measure sensitivity to perturbations
2. **Stable Algorithms**: Methods that control error propagation  
3. **Regularization**: Techniques for ill-posed problems
4. **Precision Effects**: Float32 vs Float64 trade-offs
5. **Error Analysis**: Forward vs backward error measurement

**Stability Techniques:**
- Use SVD for ill-conditioned problems
- Apply regularization to ill-posed and underdetermined systems
- Choose appropriate algorithms based on matrix properties
- Monitor condition numbers and residuals
- Apply iterative refinement when needed

**JAX-Specific Considerations:**
- Enable float64 precision for critical computations
- Use built-in stable functions (logsumexp, softmax)
- Leverage automatic differentiation for error analysis
- Consider numerical stability in custom VJP/JVP implementations

**Best Practices:**
- Always check condition numbers before solving
- Use backward error as a stability indicator
- Prefer QR over normal equations for least squares
- Apply regularization judiciously
- Validate results with multiple methods

**Next Steps:**
- The next notebook will cover neural network implementations
- We'll apply these stability concepts to ML algorithms
- Understanding numerical stability is crucial for reliable deep learning

Numerical stability is fundamental to reliable scientific computing and forms the foundation for robust machine learning algorithms.