# File: notebooks/02_linear_algebra/05_matrix_ops.ipynb

## JAX Linear Algebra: Matrix Operations

Welcome to the linear algebra section of the JAX-NSL series! This notebook covers fundamental matrix operations in JAX, including matrix multiplication, decompositions, eigenvalue problems, and advanced tensor operations. We'll explore both the high-level `jax.numpy.linalg` interface and lower-level operations for performance-critical applications.

Linear algebra is the backbone of scientific computing and machine learning. JAX's linear algebra operations are designed to be fast, differentiable, and compatible with JAX transformations like JIT compilation and vectorization.

## Setting Up the Environment

```python
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, lax
from jax import random
import numpy as np
import matplotlib.pyplot as plt
from typing import Tuple, Optional
import functools
import time

# Enable 64-bit precision for numerical stability
jax.config.update("jax_enable_x64", True)

print(f"JAX version: {jax.__version__}")
print(f"Available devices: {jax.devices()}")
```

## Basic Matrix Operations

### Matrix Creation and Properties

```python
# Create various types of matrices
def create_test_matrices():
    """Build a small catalogue of example matrices, one per structural class.

    Returns:
        A dict mapping a descriptive name to a ``jax.numpy`` array. Every
        entry is 4x4 except ``'rectangular'``, which is 4x3. The entries are
        inserted in the order: random, rectangular, symmetric,
        positive_definite, orthogonal, diagonal.
    """
    seed_key = random.PRNGKey(42)
    _, rect_key = random.split(seed_key)

    # One square Gaussian matrix is the raw material for the structured ones.
    square = random.normal(seed_key, (4, 4))

    # The Q factor of a QR factorization has orthonormal columns; the
    # triangular factor is not needed here.
    orth_factor, _ = jnp.linalg.qr(square)

    return {
        'random': square,
        'rectangular': random.normal(rect_key, (4, 3)),
        # M + M^T is symmetric by construction.
        'symmetric': square + square.T,
        # M M^T is positive semi-definite; adding I makes it strictly definite.
        'positive_definite': square @ square.T + jnp.eye(4),
        'orthogonal': orth_factor,
        'diagonal': jnp.diag(jnp.array([1.0, 2.0, 3.0, 4.0])),
    }

matrices = create_test_matrices()

# Report a few scalar properties of every example matrix.
for name, mat in matrices.items():
    is_square = mat.shape[0] == mat.shape[1]

    print(f"{name.capitalize()} Matrix:")
    print(f"  Shape: {mat.shape}")
    print(f"  Condition number: {jnp.linalg.cond(mat):.2f}")
    print(f"  Frobenius norm: {jnp.linalg.norm(mat, 'fro'):.2f}")
    # The determinant only exists for square matrices. Keep the label on the
    # fallback line so the reader can tell which quantity is unavailable.
    if is_square:
        print(f"  Determinant: {jnp.linalg.det(mat):.2e}")
    else:
        print("  Determinant: N/A (not square)")
    print()
```

### Matrix Multiplication Variants

```python
# Different matrix multiplication approaches
def matrix_multiplication_demo():
    """Walk through several equivalent ways of multiplying matrices in JAX.

    Prints the shape of a plain product, the largest element-wise gap between
    the two groupings of a three-matrix chain, and a comparison of the
    broadcasting ``@`` operator against an explicit ``vmap`` over the leading
    batch axis.

    Returns:
        The three random operands ``(A, B, C)`` with shapes (100, 50),
        (50, 80) and (80, 30).
    """
    root_key = random.PRNGKey(123)
    second_key = random.split(root_key)[1]

    A = random.normal(root_key, (100, 50))
    B = random.normal(second_key, (50, 80))
    C = random.normal(random.split(root_key, 3)[2], (80, 30))

    print("Matrix shapes: A(100,50), B(50,80), C(80,30)")

    # Plain product; jnp.dot(A, B) and jnp.matmul(A, B) give the same result
    # for 2-D operands.
    product_ab = A @ B
    print(f"A @ B shape: {product_ab.shape}")

    # The same chain evaluated with both groupings. The printed gap is the
    # largest absolute element-wise difference between the two results.
    left_grouped = product_ab @ C
    right_grouped = A @ (B @ C)
    grouping_gap = jnp.max(jnp.abs(left_grouped - right_grouped))
    print(f"Chain multiplication difference: {grouping_gap:.2e}")

    # Stacks of five 10x10 matrices: ``@`` multiplies matching batch entries.
    stack_a = random.normal(root_key, (5, 10, 10))
    stack_b = random.normal(second_key, (5, 10, 10))
    broadcast_product = stack_a @ stack_b
    print(f"Batch multiplication shape: {broadcast_product.shape}")

    # The same batched product, written as a single-pair function that vmap
    # lifts over the leading axis of both arguments.
    def multiply_pair(left, right):
        return left @ right

    mapped_product = vmap(multiply_pair)(stack_a, stack_b)
    print(f"Vmap result matches: {jnp.allclose(broadcast_product, mapped_product)}")

    return A, B, C

A, B, C = matrix_multiplication_demo()
```

### Einstein Summation (Einsum)

```python
# Advanced tensor operations using einsum
def einsum_examples():
    """Compare ``jnp.einsum`` spellings against their conventional equivalents.

    Covers the dot product, matrix-vector product, matrix-matrix product and
    trace, then two operations on a 4-D tensor (an axis sum and a two-index
    contraction). Each comparison is printed; nothing is returned.

    Every operand is created with a shape that matches the subscripts it is
    used with: the matrix-vector example needs a length-4 vector for the
    (3, 4) matrix, and the ``'ijkl,jl->ik'`` contraction needs a (4, 6)
    matrix for the (3, 4, 5, 6) tensor.
    """
    key = random.PRNGKey(0)

    # Operands for the basic examples.
    a = random.normal(key, (3,))
    b = random.normal(random.split(key)[1], (3,))
    A = random.normal(random.split(key, 3)[2], (3, 4))
    B = random.normal(random.split(key, 4)[3], (4, 5))
    # Length must equal A's column count for A @ v to be defined.
    v = random.normal(random.split(key, 5)[4], (4,))

    print("Einsum Examples:")

    # Dot product: a · b
    dot1 = jnp.dot(a, b)
    dot2 = jnp.einsum('i,i->', a, b)
    print(f"Dot product: {dot1:.4f} == {dot2:.4f}")

    # Matrix-vector product: A @ v
    mv1 = A @ v
    mv2 = jnp.einsum('ij,j->i', A, v)
    print(f"Matrix-vector max diff: {jnp.max(jnp.abs(mv1 - mv2)):.2e}")

    # Matrix multiplication: A @ B
    mm1 = A @ B
    mm2 = jnp.einsum('ij,jk->ik', A, B)
    print(f"Matrix mult max diff: {jnp.max(jnp.abs(mm1 - mm2)):.2e}")

    # Trace of A A^T: repeating both indices sums the diagonal of the product
    # without forming it.
    trace1 = jnp.trace(A @ A.T)
    trace2 = jnp.einsum('ij,ji->', A, A.T)
    print(f"Trace: {trace1:.4f} == {trace2:.4f}")

    # Higher-rank tensor with axes (i, j, k, l) of sizes (3, 4, 5, 6).
    T = random.normal(key, (3, 4, 5, 6))

    # Indices missing from the output (i and k, i.e. axes 0 and 2) are summed.
    sum_02 = jnp.einsum('ijkl->jl', T)
    print(f"Tensor sum shape: {sum_02.shape}")

    # Contract T with a matrix over the shared j and l indices. The matrix
    # must therefore have shape (4, 6) to match T's axes 1 and 3.
    W = random.normal(random.split(key, 6)[5], (4, 6))
    contraction = jnp.einsum('ijkl,jl->ik', T, W)
    print(f"Tensor contraction shape: {contraction.shape}")

einsum_examples()
```

## Matrix Decompositions

### QR Decomposition

```python
# QR decomposition and applications
def qr_decomposition_demo():
    """Factor a tall random matrix with QR and reuse the factors for least squares.

    Prints the factor shapes for the reduced and complete modes, the
    reconstruction and orthogonality residuals of the reduced factors, and the
    gap between the QR-based least-squares solution and ``jnp.linalg.lstsq``.

    Returns:
        ``(Q, R, A, b)``: the reduced factors, the (6, 4) matrix that was
        factored, and the length-6 right-hand side.
    """
    seed = random.PRNGKey(42)

    # Six rows, four columns: an overdetermined system.
    A = random.normal(seed, (6, 4))
    b = random.normal(random.split(seed)[1], (6,))

    # Reduced mode keeps only as many Q columns as A has; complete mode
    # returns a square Q.
    Q, R = jnp.linalg.qr(A, mode='reduced')
    Q_complete, R_complete = jnp.linalg.qr(A, mode='complete')

    print("QR Decomposition:")
    print(f"A shape: {A.shape}")
    print(f"Q shape (reduced): {Q.shape}, R shape: {R.shape}")
    print(f"Q shape (full): {Q_complete.shape}, R shape: {R_complete.shape}")

    # Q R should reproduce A.
    reconstruction_gap = jnp.max(jnp.abs(A - Q @ R))
    print(f"Reconstruction error: {reconstruction_gap:.2e}")

    # Q^T Q should be the identity of size (number of Q columns).
    gram = Q.T @ Q
    orthogonality_gap = jnp.max(jnp.abs(gram - jnp.eye(Q.shape[1])))
    print(f"Q orthogonality error: {orthogonality_gap:.2e}")

    # Least squares through the factors: solve R x = Q^T b.
    projected_rhs = Q.T @ b
    x_qr = jnp.linalg.solve(R, projected_rhs)

    # Reference solution from the library least-squares routine.
    x_reference = jnp.linalg.lstsq(A, b, rcond=None)[0]

    print(f"Least squares solution difference: {jnp.max(jnp.abs(x_qr - x_reference)):.2e}")

    return Q, R, A, b

Q, R, A, b = qr_decomposition_demo()
```

### SVD (Singular Value Decomposition)

```python
# Singular Value Decomposition
def svd_demo():
    """Compute the SVD of a rank-3 matrix and examine truncated approximations.

    The (5, 4) test matrix is the product ``U_true @ diag(S_true) @ V_true.T``
    of random factors with three columns each, so its rank is at most 3.
    Because those random factors are not orthonormal, the entries of
    ``S_true`` act as scale factors and need not equal the singular values
    that are printed.

    Returns:
        ``(U, S, Vt, A)``: the thin SVD factors and the matrix itself.
    """
    seed = random.PRNGKey(123)

    rank = 3
    left_factor = random.normal(seed, (5, rank))
    right_factor = random.normal(random.split(seed)[1], (4, rank))
    scales = jnp.array([10.0, 5.0, 1.0])
    A = left_factor @ jnp.diag(scales) @ right_factor.T
    n_rows, n_cols = A.shape

    print("SVD Analysis:")
    print(f"Original matrix shape: {A.shape}")
    print(f"True rank: {rank}")

    # Thin SVD: U is (5, 4), S holds 4 values, Vt is (4, 4).
    U, S, Vt = jnp.linalg.svd(A, full_matrices=False)

    print(f"SVD shapes: U{U.shape}, S{S.shape}, Vt{Vt.shape}")
    print(f"Singular values: {S}")

    # U diag(S) Vt should reproduce A.
    rebuilt = U @ jnp.diag(S) @ Vt
    print(f"Reconstruction error: {jnp.max(jnp.abs(A - rebuilt)):.2e}")

    # Keep only the leading k singular triplets and measure what is lost.
    for k in range(1, 4):
        truncated = U[:, :k] @ jnp.diag(S[:k]) @ Vt[:k, :]
        error = jnp.linalg.norm(A - truncated, 'fro')
        # Stored numbers for the two factor blocks relative to the dense
        # matrix (the k singular values themselves are not counted).
        compression_ratio = (k * (n_rows + n_cols)) / (n_rows * n_cols)
        print(f"Rank-{k} approximation: error={error:.4f}, compression={compression_ratio:.2f}")

    return U, S, Vt, A

U, S, Vt, A_svd = svd_demo()
```

### Eigenvalue Decomposition

```python
# Eigenvalue and eigenvector computation
def eigenvalue_demo():
    """Compute and verify eigen-decompositions of a symmetric and a general matrix.

    For a 4x4 symmetric positive definite matrix this prints the eigenvalues,
    the ratio of the last to the first eigenvalue returned by ``eigh``, the
    reconstruction residual, and the residual of each eigenpair. It then
    prints the eigenvalues of a general 3x3 matrix obtained with ``eig``.

    Returns:
        ``(eigvals, eigvecs, A)`` for the symmetric positive definite matrix.
    """
    seed = random.PRNGKey(456)

    # G G^T is symmetric positive semi-definite; adding I makes it definite.
    gaussian = random.normal(seed, (4, 4))
    A = gaussian @ gaussian.T + jnp.eye(4)

    print("Eigenvalue Analysis:")

    # eigh is the routine for symmetric input.
    eigvals, eigvecs = jnp.linalg.eigh(A)

    print(f"Matrix shape: {A.shape}")
    print(f"Eigenvalues: {eigvals}")
    print(f"Condition number: {eigvals[-1] / eigvals[0]:.2f}")

    # V diag(w) V^T should reproduce A.
    rebuilt = eigvecs @ jnp.diag(eigvals) @ eigvecs.T
    print(f"Eigendecomposition reconstruction error: {jnp.max(jnp.abs(A - rebuilt)):.2e}")

    # Each column v of eigvecs should satisfy A v = w v for its eigenvalue w.
    # Rows of eigvecs.T are the columns of eigvecs.
    for i, (eigenvalue, eigenvector) in enumerate(zip(eigvals, eigvecs.T)):
        residual = A @ eigenvector - eigenvalue * eigenvector
        error = jnp.max(jnp.abs(residual))
        print(f"Eigenpair {i} error: {error:.2e}")

    # A non-symmetric matrix goes through eig, whose output may be complex.
    B = random.normal(seed, (3, 3))
    eigvals_general, _ = jnp.linalg.eig(B)
    print(f"General matrix eigenvalues (can be complex): {eigvals_general}")

    return eigvals, eigvecs, A

eigvals, eigvecs, A_eigen = eigenvalue_demo()
```

## Advanced Matrix Operations

### Matrix Functions

```python
# Matrix functions: exp, log, sqrt, etc.
def matrix_functions_demo():
    """Evaluate exp, sqrt and log of a symmetric matrix through its eigenvalues.

    Each matrix function is obtained by applying the scalar function to the
    eigenvalues returned by ``eigh`` and recombining with the eigenvectors.
    The square root and logarithm clamp eigenvalues from below at 1e-12
    before applying the scalar function. Traces and two round-trip residuals
    are printed.

    Returns:
        ``(A_exp, A_sqrt, A_log)`` for the 3x3 test matrix.
    """
    seed = random.PRNGKey(789)

    # Identity plus a small positive semi-definite term.
    gaussian = random.normal(seed, (3, 3))
    A = 0.1 * (gaussian @ gaussian.T) + jnp.eye(3)

    print("Matrix Functions:")
    print(f"Matrix A condition number: {jnp.linalg.cond(A):.2f}")

    def spectral_map(M, scalar_fn):
        """Return V diag(scalar_fn(w)) V^T where (w, V) = eigh(M)."""
        w, V = jnp.linalg.eigh(M)
        return V @ jnp.diag(scalar_fn(w)) @ V.T

    def clamped(scalar_fn):
        """Wrap scalar_fn so its argument is floored at 1e-12 first."""
        return lambda w: scalar_fn(jnp.maximum(w, 1e-12))

    A_exp = spectral_map(A, jnp.exp)
    A_sqrt = spectral_map(A, clamped(jnp.sqrt))
    A_log = spectral_map(A, clamped(jnp.log))

    print(f"Original matrix trace: {jnp.trace(A):.4f}")
    print(f"Matrix exp trace: {jnp.trace(A_exp):.4f}")
    print(f"Matrix sqrt trace: {jnp.trace(A_sqrt):.4f}")

    # Squaring the square root should give A back.
    sqrt_residual = jnp.max(jnp.abs(A_sqrt @ A_sqrt - A))
    print(f"Matrix sqrt verification error: {sqrt_residual:.2e}")

    # Exponentiating the logarithm should give A back.
    round_trip = spectral_map(A_log, jnp.exp)
    exp_log_residual = jnp.max(jnp.abs(round_trip - A))
    print(f"exp(log(A)) verification error: {exp_log_residual:.2e}")

    return A_exp, A_sqrt, A_log

A_exp, A_sqrt, A_log = matrix_functions_demo()
```

### Kronecker Products and Vectorization

```python
# Kronecker products and vectorization operations
def kronecker_demo():
    """Demonstrate Kronecker products and the vec identity that links them to AXB.

    Prints the shape of A ⊗ B, the residual of the mixed-product property
    (A ⊗ B)(C ⊗ D) = (AC) ⊗ (BD), and the residual of
    vec(AXB) = (B^T ⊗ A) vec(X).

    The vec identity only holds when vec stacks the *columns* of its argument.
    ``flatten()`` walks an array row by row, so the vectorization is taken on
    the transpose, which yields the column-stacked vector.

    Returns:
        The (8, 6) Kronecker product of the (2, 3) matrix A and the (4, 2)
        matrix B.
    """
    def vec(M):
        """Stack the columns of M into one vector (column-major order)."""
        return M.T.flatten()

    key = random.PRNGKey(0)

    A = random.normal(key, (2, 3))
    B = random.normal(random.split(key)[1], (4, 2))

    print("Kronecker Products:")
    print(f"A shape: {A.shape}, B shape: {B.shape}")

    # Kronecker product A ⊗ B
    kron_AB = jnp.kron(A, B)
    print(f"A ⊗ B shape: {kron_AB.shape}")

    # Second pair of factors for the mixed-product property.
    C = random.normal(random.split(key, 3)[2], (3, 2))
    D = random.normal(random.split(key, 4)[3], (2, 3))

    # (A ⊗ B)(C ⊗ D) = (AC) ⊗ (BD) whenever AC and BD are both defined.
    if A.shape[1] == C.shape[0] and B.shape[1] == D.shape[0]:
        left_side = kron_AB @ jnp.kron(C, D)
        right_side = jnp.kron(A @ C, B @ D)
        kron_property_error = jnp.max(jnp.abs(left_side - right_side))
        print(f"Kronecker product property error: {kron_property_error:.2e}")

    # Vectorization
    X = random.normal(key, (3, 4))
    vec_X = vec(X)
    print(f"Matrix X shape: {X.shape}, vectorized shape: {vec_X.shape}")

    # Relationship: vec(AXB) = (B^T ⊗ A) vec(X), with column-stacking vec
    A_small = random.normal(key, (2, 3))
    B_small = random.normal(random.split(key)[1], (4, 5))
    AXB = A_small @ X @ B_small

    vec_AXB_direct = vec(AXB)
    vec_AXB_kron = jnp.kron(B_small.T, A_small) @ vec_X

    vectorization_error = jnp.max(jnp.abs(vec_AXB_direct - vec_AXB_kron))
    print(f"Vectorization identity error: {vectorization_error:.2e}")

    return kron_AB

kron_result = kronecker_demo()
```

## Numerical Linear Algebra Considerations

### Conditioning and Stability

```python
# Numerical stability analysis
def stability_analysis():
    """Show how ill-conditioning affects the accuracy of linear solves.

    Builds a 5x5 matrix from two random orthogonal factors and prescribed
    singular values spanning 1e6 down to 1e-6, forms a right-hand side from a
    known solution of all ones, and prints:

    * the condition number of the matrix,
    * the number of singular values above the truncation cutoff,
    * the error of ``jnp.linalg.solve`` and of an SVD pseudo-inverse solve,
    * how much a 1e-10-scale random perturbation of the matrix moves the
      direct solution, relative to the norm of the perturbation.

    Nothing is returned.
    """
    key = random.PRNGKey(42)
    # Independent subkeys: index 1 of a 2-way split for V, index 2 of a 3-way
    # split for the perturbation, so the perturbation does not share U's key.
    v_key = random.split(key)[1]
    perturbation_key = random.split(key, 3)[2]

    print("Numerical Stability Analysis:")

    # Ill-conditioned matrix with a prescribed, widely spread spectrum.
    n = 5
    U = random.orthogonal(key, n)
    V = random.orthogonal(v_key, n)
    singular_values = jnp.array([1e6, 1e3, 1e0, 1e-3, 1e-6])
    A_ill = U @ jnp.diag(singular_values) @ V.T

    print(f"Condition number: {jnp.linalg.cond(A_ill):.2e}")

    # Build b from a known solution so the error of each method is measurable.
    x_true = jnp.ones(n)
    b = A_ill @ x_true

    # Method 1: direct solve
    x_direct = jnp.linalg.solve(A_ill, b)
    error_direct = jnp.linalg.norm(x_direct - x_true)

    # Method 2: SVD pseudo-inverse, zeroing reciprocals of singular values at
    # or below the cutoff. The prescribed singular values all exceed 1e-10.
    U_svd, S_svd, Vt_svd = jnp.linalg.svd(A_ill)
    tolerance = 1e-10
    keep = S_svd > tolerance
    rank = jnp.sum(keep)

    S_inv = jnp.where(keep, 1.0 / S_svd, 0.0)
    x_svd = Vt_svd.T @ (S_inv * (U_svd.T @ b))
    error_svd = jnp.linalg.norm(x_svd - x_true)

    print(f"Effective rank: {rank}/{n}")
    print(f"Direct solve error: {error_direct:.2e}")
    print(f"SVD solve error: {error_svd:.2e}")

    # Perturbation analysis: change in the solution per unit of perturbation.
    perturbation = 1e-10 * random.normal(perturbation_key, A_ill.shape)
    x_perturbed = jnp.linalg.solve(A_ill + perturbation, b)
    perturbation_effect = (
        jnp.linalg.norm(x_perturbed - x_direct) / jnp.linalg.norm(perturbation)
    )

    print(f"Perturbation amplification: {perturbation_effect:.2e}")

stability_analysis()
```

### Performance Benchmarking

```python
# Performance comparison of matrix operations
def performance_benchmark(sizes=(64, 128, 256, 512, 1024), repeats=10):
    """Time matmul, eigh, SVD and QR on random square matrices.

    JAX dispatches work asynchronously, so a call can return before the
    result has been computed. Every timed call is therefore wrapped in
    ``jax.block_until_ready`` so the clock stops only once the result
    exists. Each operation is also run once, untimed, before measuring so
    that one-off first-call overhead is not counted.

    Args:
        sizes: Matrix dimensions ``n`` to benchmark; each produces one
            printed row for n x n operands.
        repeats: Number of timed matrix multiplications to average over.
            The decompositions are timed once each.

    Raises:
        ValueError: If ``repeats`` is smaller than 1.

    One formatted line per size is printed; nothing is returned.
    """
    if repeats < 1:
        raise ValueError("repeats must be at least 1")

    def timed(fn, *args, runs=1):
        """Return the mean wall-clock seconds of ``fn(*args)`` over ``runs`` calls."""
        # Untimed warm-up; blocking here also guarantees the inputs are ready.
        jax.block_until_ready(fn(*args))
        start = time.perf_counter()
        for _ in range(runs):
            jax.block_until_ready(fn(*args))
        return (time.perf_counter() - start) / runs

    print("Performance Benchmarking:")

    for n in sizes:
        key = random.PRNGKey(42)
        A = random.normal(key, (n, n))
        B = random.normal(random.split(key)[1], (n, n))
        # eigh expects symmetric input; A A^T is symmetric by construction.
        A_sym = A @ A.T

        matmul_time = timed(jnp.matmul, A, B, runs=repeats)
        eigen_time = timed(jnp.linalg.eigh, A_sym)
        svd_time = timed(jnp.linalg.svd, A)
        qr_time = timed(jnp.linalg.qr, A)

        print(f"n={n:4d}: MatMul={matmul_time*1000:6.2f}ms, "
              f"Eigen={eigen_time*1000:6.2f}ms, "
              f"SVD={svd_time*1000:6.2f}ms, "
              f"QR={qr_time*1000:6.2f}ms")

# Run benchmark (may take a moment)
performance_benchmark()
```

## Specialized Matrix Operations

### Block Matrix Operations

```python
# Block matrix operations
def block_matrix_demo():
    """Assemble a 2x2 block matrix and invert it through the Schur complement.

    Four random blocks of shapes (2, 2), (2, 3), (3, 2) and (3, 3) are joined
    into one 5x5 matrix. If both the top-left block and its Schur complement
    have condition numbers below 1e12, the block-wise inverse is formed and
    its residual against the identity is printed.

    Returns:
        The assembled 5x5 block matrix.
    """
    seed = random.PRNGKey(0)

    A11 = random.normal(seed, (2, 2))
    A12 = random.normal(random.split(seed, 2)[1], (2, 3))
    A21 = random.normal(random.split(seed, 3)[2], (3, 2))
    A22 = random.normal(random.split(seed, 4)[3], (3, 3))

    # Join the blocks: each inner list is one block row.
    A_block = jnp.block([[A11, A12], [A21, A22]])

    print("Block Matrix Operations:")
    print(f"Block matrix shape: {A_block.shape}")
    print(f"A11 shape: {A11.shape}, A12 shape: {A12.shape}")
    print(f"A21 shape: {A21.shape}, A22 shape: {A22.shape}")

    def schur_inverse_blocks(P, Q, R, S):
        """Return the four blocks of [[P, Q], [R, S]]^-1 (P must be invertible)."""
        P_inv = jnp.linalg.inv(P)
        complement_inv = jnp.linalg.inv(S - R @ P_inv @ Q)

        # P^-1 Q (S - R P^-1 Q)^-1 appears in both top blocks of the inverse.
        coupling = P_inv @ Q @ complement_inv

        top_left = P_inv + coupling @ R @ P_inv
        top_right = -coupling
        bottom_left = -(complement_inv @ R @ P_inv)
        return top_left, top_right, bottom_left, complement_inv

    # Only attempt the inverse when both matrices that get inverted are
    # reasonably conditioned.
    schur_complement = A22 - A21 @ jnp.linalg.inv(A11) @ A12
    safe_to_invert = (
        jnp.linalg.cond(A11) < 1e12
        and jnp.linalg.cond(schur_complement) < 1e12
    )
    if not safe_to_invert:
        return A_block

    inv_11, inv_12, inv_21, inv_22 = schur_inverse_blocks(A11, A12, A21, A22)
    A_block_inv = jnp.block([[inv_11, inv_12], [inv_21, inv_22]])

    # A A^-1 should be the identity.
    residual = A_block @ A_block_inv - jnp.eye(A_block.shape[0])
    identity_error = jnp.max(jnp.abs(residual))
    print(f"Block inverse verification error: {identity_error:.2e}")

    return A_block

A_block = block_matrix_demo()
```

## Summary

In this notebook, we've explored fundamental matrix operations in JAX:

**Key Operations Covered:**

1. **Basic Operations**: Matrix multiplication, einsum, batch operations
2. **Decompositions**: QR, SVD, eigenvalue decomposition
3. **Matrix Functions**: Exponential, square root, logarithm via eigendecomposition
4. **Advanced Operations**: Kronecker products, vectorization, block matrices

**Numerical Considerations:**
- Condition number analysis for stability
- Perturbation sensitivity
- Choosing appropriate algorithms for different matrix properties
- Performance characteristics of different operations

**JAX-Specific Features:**
- Automatic differentiation through linear algebra operations
- JIT compilation for performance
- Vectorization with vmap for batch processing
- GPU/TPU acceleration

**Best Practices:**
- Use appropriate decompositions for matrix properties (eigh for symmetric, SVD for general)
- Consider numerical stability for ill-conditioned problems
- Leverage JAX transformations for performance
- Choose efficient algorithms based on matrix structure

**Next Steps:**
- The next notebook will cover iterative solvers
- We'll explore conjugate gradient, GMRES, and other iterative methods
- Understanding direct methods enables comparison with iterative approaches

Matrix operations form the computational backbone of scientific computing and machine learning. JAX's implementation provides both high performance and seamless integration with automatic differentiation, making it ideal for research and production applications.