# Location: notebooks/capstone_projects/20_physics_informed_nn.ipynb

## Physics-Informed Neural Networks (PINNs) in JAX

This capstone project implements Physics-Informed Neural Networks for solving partial differential equations (PDEs) by incorporating physical laws directly into the loss function using automatic differentiation.

## Introduction to PINNs

Physics-Informed Neural Networks combine the function approximation capabilities of neural networks with the physical constraints encoded in differential equations. This approach allows us to solve PDEs without requiring large datasets, instead using the physics itself as a regularizer.

```python
import jax
import jax.numpy as jnp
from jax import grad, vmap, jit, random
import matplotlib.pyplot as plt
from functools import partial
import time

# Basic PINN architecture
def create_pinn_model(layer_sizes, activation=jax.nn.tanh):
    """Build a fully connected network used as the PINN function approximator.

    Args:
        layer_sizes: Widths of every layer, input first and output last,
            e.g. ``[2, 50, 50, 1]``.
        activation: Nonlinearity applied after each hidden layer; the final
            layer is left affine (no activation).

    Returns:
        ``(init_params, forward)`` where ``init_params(key)`` returns a list of
        ``{'w': ..., 'b': ...}`` dicts (one per layer) and
        ``forward(params, x)`` evaluates the network on ``x``.
    """

    def init_params(key):
        """Return Glorot/Xavier-normal weights and zero biases for every layer."""
        layer_keys = random.split(key, len(layer_sizes) - 1)
        fan_pairs = zip(layer_sizes[:-1], layer_sizes[1:])
        return [
            {
                'w': random.normal(layer_key, (fan_in, fan_out))
                * jnp.sqrt(2.0 / (fan_in + fan_out)),
                'b': jnp.zeros(fan_out),
            }
            for layer_key, (fan_in, fan_out) in zip(layer_keys, fan_pairs)
        ]

    def forward(params, x):
        """Apply each hidden layer followed by ``activation``, then a linear output layer."""
        *hidden_layers, output_layer = params
        h = x
        for layer in hidden_layers:
            h = activation(h @ layer['w'] + layer['b'])
        return h @ output_layer['w'] + output_layer['b']

    return init_params, forward

# Smoke-test the basic architecture on a tiny batch of (x, t) pairs
init_pinn, pinn_forward = create_pinn_model([2, 50, 50, 50, 1])
params = init_pinn(random.PRNGKey(42))

test_input = jnp.array([[0.5, 0.3], [1.0, 0.7]])
test_output = pinn_forward(params, test_input)

# Parameter count of each layer: weight matrix entries plus bias entries
layer_param_counts = [layer['w'].size + layer['b'].size for layer in params]

print("PINN Architecture Test:")
print(f"Input shape: {test_input.shape}")
print(f"Output shape: {test_output.shape}")
print(f"Number of parameters per layer: {layer_param_counts}")
print(f"Total parameters: {sum(layer_param_counts)}")
```

## Solving the 1D Heat Equation

```python
def create_heat_equation_pinn(alpha=0.01):
    """Build a PINN for the 1D heat equation ``∂u/∂t = α ∂²u/∂x²`` on x ∈ [0, 1].

    Boundary conditions: u(0, t) = u(1, t) = 0.
    Initial condition:   u(x, 0) = sin(π x).

    Args:
        alpha: Thermal diffusivity α used by the residual and (by default) by
            the loss. The default matches the previously hard-coded value.

    Returns:
        ``(init_params, pinn_forward_with_derivatives, pinn_loss)``.
    """

    # Network input: [x, t]; output: [u(x, t)]
    init_params, forward = create_pinn_model([2, 64, 64, 64, 1])

    def pinn_forward_with_derivatives(params, x, t):
        """Return ``(u, u_t, u_x, u_xx)`` at a single scalar point (x, t)."""

        def u_net(x, t):
            """Scalar network approximation of u(x, t)."""
            inputs = jnp.array([x, t])
            return forward(params, inputs.reshape(1, -1))[0, 0]

        # One reverse pass yields u together with both first derivatives.
        u_val, (u_x, u_t) = jax.value_and_grad(u_net, argnums=(0, 1))(x, t)

        # ∂²u/∂x² by differentiating ∂u/∂x once more w.r.t. x
        u_xx = grad(grad(u_net, argnums=0), argnums=0)(x, t)

        return u_val, u_t, u_x, u_xx

    def heat_pde_residual(params, x, t, alpha=alpha):
        """PDE residual ``∂u/∂t - α ∂²u/∂x²`` at (x, t)."""
        _, u_t, _, u_xx = pinn_forward_with_derivatives(params, x, t)
        return u_t - alpha * u_xx

    def boundary_condition(params, t):
        """Return ``(u(0, t), u(1, t))``; both should be 0."""
        u_0 = pinn_forward_with_derivatives(params, 0.0, t)[0]
        u_1 = pinn_forward_with_derivatives(params, 1.0, t)[0]
        return u_0, u_1

    def initial_condition(params, x):
        """Initial-condition residual ``u(x, 0) - sin(π x)``."""
        u_val = pinn_forward_with_derivatives(params, x, 0.0)[0]
        return u_val - jnp.sin(jnp.pi * x)

    def pinn_loss(params, x_pde, t_pde, x_bc, t_bc, x_ic,
                  lambda_pde=1.0, lambda_bc=1.0, lambda_ic=1.0, alpha=alpha):
        """Weighted sum of PDE, boundary and initial-condition mean squared residuals.

        Args:
            params: Network parameters.
            x_pde, t_pde: Interior collocation points.
            x_bc: Accepted so all training sets travel as one tuple, but unused:
                both walls (x=0 and x=1) are evaluated at every time in ``t_bc``.
            t_bc: Boundary sample times.
            x_ic: Initial-condition sample positions (t = 0).
            lambda_pde, lambda_bc, lambda_ic: Weights of the three loss terms.
            alpha: Thermal diffusivity used in the PDE residual.

        Returns:
            ``(total_loss, info)`` where ``info`` holds ``pde_loss``,
            ``bc_loss``, ``ic_loss`` and ``total_loss``.
        """
        pde_residuals = vmap(
            lambda x, t: heat_pde_residual(params, x, t, alpha)
        )(x_pde, t_pde)
        pde_loss = jnp.mean(pde_residuals**2)

        u_left, u_right = vmap(lambda t: boundary_condition(params, t))(t_bc)
        bc_loss = jnp.mean(u_left**2) + jnp.mean(u_right**2)

        ic_residuals = vmap(lambda x: initial_condition(params, x))(x_ic)
        ic_loss = jnp.mean(ic_residuals**2)

        total_loss = (lambda_pde * pde_loss +
                      lambda_bc * bc_loss +
                      lambda_ic * ic_loss)

        return total_loss, {
            'pde_loss': pde_loss,
            'bc_loss': bc_loss,
            'ic_loss': ic_loss,
            'total_loss': total_loss
        }

    return init_params, pinn_forward_with_derivatives, pinn_loss

# Create heat equation PINN
init_heat_params, heat_forward, heat_loss_fn = create_heat_equation_pinn()
heat_params = init_heat_params(random.PRNGKey(123))

# Generate training data points
key = random.PRNGKey(456)
n_pde = 10000
n_bc = 100
n_ic = 100

# One independent subkey per sampled array. A key must not be reused after it
# has been split, and the default random.split(key) yields only two keys.
key_x_pde, key_t_pde, key_t_bc, key_x_ic = random.split(key, 4)

# PDE collocation points (interior)
x_pde = random.uniform(key_x_pde, (n_pde,)) * 1.0  # x ∈ [0, 1]
t_pde = random.uniform(key_t_pde, (n_pde,)) * 0.5  # t ∈ [0, 0.5]

# Boundary points. x_bc alternates between the walls x=0 and x=1 so that the
# training_data tuple built later is well-defined; the loss itself evaluates
# both walls at every sampled time in t_bc.
x_bc = jnp.where(jnp.arange(n_bc) % 2 == 0, 0.0, 1.0)
t_bc = random.uniform(key_t_bc, (n_bc,)) * 0.5

# Initial condition points (t=0 for all x)
x_ic = random.uniform(key_x_ic, (n_ic,)) * 1.0

print("Heat Equation PINN Setup:")
print(f"PDE collocation points: {n_pde}")
print(f"Boundary condition points: {n_bc}")
print(f"Initial condition points: {n_ic}")

# Test loss computation (boundary positions go in the x_bc slot)
loss, info = heat_loss_fn(heat_params, x_pde, t_pde, x_bc, t_bc, x_ic)
print("\nInitial losses:")
print(f"Total loss: {loss:.6f}")
print(f"PDE loss: {info['pde_loss']:.6f}")
print(f"BC loss: {info['bc_loss']:.6f}")
print(f"IC loss: {info['ic_loss']:.6f}")
```

## Training the Heat Equation PINN

```python
def train_heat_pinn(init_params, loss_fn, training_data, n_epochs=5000, lr=1e-3):
    """Train the heat-equation PINN with a hand-written, jitted Adam optimizer.

    Args:
        init_params: Initial parameter pytree.
        loss_fn: ``loss_fn(params, x_pde, t_pde, x_bc, t_bc, x_ic)`` returning
            ``(total_loss, info)``; ``info`` must contain ``'pde_loss'``,
            ``'bc_loss'`` and ``'ic_loss'``.
        training_data: Tuple ``(x_pde, t_pde, x_bc, t_bc, x_ic)``.
        n_epochs: Number of full-batch optimization steps.
        lr: Adam learning rate.

    Returns:
        ``(params, losses)``: the trained parameters and the total loss
        recorded at every step (evaluated before that step's update).
    """
    x_pde, t_pde, x_bc, t_bc, x_ic = training_data
    # jax.tree_map is deprecated (and removed in recent JAX releases).
    tree_map = jax.tree_util.tree_map

    def init_adam(params):
        """Zero first/second moment estimates and a step counter."""
        return {
            'm': tree_map(jnp.zeros_like, params),
            'v': tree_map(jnp.zeros_like, params),
            'step': 0
        }

    def adam_update(params, grads, opt_state, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
        """One Adam step with bias-corrected moment estimates."""
        step = opt_state['step'] + 1

        # Update biased first and second moment estimates
        m = tree_map(lambda m, g: beta1 * m + (1 - beta1) * g, opt_state['m'], grads)
        v = tree_map(lambda v, g: beta2 * v + (1 - beta2) * g**2, opt_state['v'], grads)

        # Bias correction
        m_hat = tree_map(lambda m: m / (1 - beta1**step), m)
        v_hat = tree_map(lambda v: v / (1 - beta2**step), v)

        new_params = tree_map(
            lambda p, m, v: p - lr * m / (jnp.sqrt(v) + eps),
            params, m_hat, v_hat
        )

        return new_params, {'m': m, 'v': v, 'step': step}

    @jit
    def train_step(params, opt_state):
        """Single training step: one forward/backward pass, then Adam."""
        # has_aux=True returns the info dict alongside the gradient, so the
        # loss is evaluated once instead of twice.
        (loss, info), grads = jax.value_and_grad(
            lambda p: loss_fn(p, x_pde, t_pde, x_bc, t_bc, x_ic), has_aux=True
        )(params)

        new_params, new_opt_state = adam_update(params, grads, opt_state, lr)
        return new_params, new_opt_state, loss, info

    params = init_params
    opt_state = init_adam(params)
    losses = []

    print("Training Heat Equation PINN:")
    print("Epoch | Total Loss | PDE Loss  | BC Loss   | IC Loss")
    print("-" * 55)

    for epoch in range(n_epochs):
        params, opt_state, loss, info = train_step(params, opt_state)
        losses.append(loss)

        if epoch % 500 == 0:
            print(f"{epoch:5d} | {loss:9.6f} | {info['pde_loss']:8.6f} | "
                  f"{info['bc_loss']:8.6f} | {info['ic_loss']:8.6f}")

    # Guard: with n_epochs=0 there is no final loss to report.
    if losses:
        print(f"\nFinal loss: {losses[-1]:.6f}")
    return params, losses

# Train the PINN on the sampled interior, boundary and initial points
training_data = x_pde, t_pde, x_bc, t_bc, x_ic
trained_params, training_losses = train_heat_pinn(
    heat_params,
    heat_loss_fn,
    training_data,
    n_epochs=2000,
    lr=1e-3,
)

# Closed-form reference solution for comparison
def analytical_heat_solution(x, t, alpha=0.01):
    """Exact solution of ∂u/∂t = α ∂²u/∂x² with u(x, 0) = sin(π x), u(0, t) = u(1, t) = 0.

    u(x, t) = sin(π x) · exp(-π² α t)
    """
    decay = jnp.exp(-jnp.pi**2 * alpha * t)
    return jnp.sin(jnp.pi * x) * decay

# Compare the trained PINN against the closed form on a coarse (x, t) grid
test_x = jnp.linspace(0, 1, 50)
test_t = jnp.array([0.0, 0.1, 0.2, 0.3])
separator = "-" * 45

print("\nTesting Trained PINN:")
print("x    | t    | PINN     | Analytical | Error")
print(separator)

for t_val in test_t:
    for x_val in test_x[::10]:  # every 10th spatial point
        pinn_val = heat_forward(trained_params, x_val, t_val)[0]
        analytical_val = analytical_heat_solution(x_val, t_val)
        error = abs(pinn_val - analytical_val)

        print(f"{x_val:.1f} | {t_val:.1f} | {pinn_val:8.5f} | {analytical_val:8.5f} | {error:.5f}")
    # Separate time slices, but not after the last one
    if t_val < test_t[-1]:
        print(separator)
```

## Solving the 2D Poisson Equation

```python
def create_poisson_equation_pinn():
    """Build a PINN for the 2D Poisson problem ∇²u = f(x, y) with u = 0 on the boundary.

    The source term is f(x, y) = -2π² sin(πx) sin(πy).

    Returns:
        ``(init_params, poisson_pinn_forward, poisson_loss)``.
    """
    # Network: input [x, y], output [u(x, y)]
    init_params, forward = create_pinn_model([2, 100, 100, 100, 1])

    def poisson_pinn_forward(params, x, y):
        """Evaluate ``(u, u_x, u_y, ∇²u)`` at one scalar point (x, y)."""

        def u_net(x, y):
            point = jnp.array([x, y]).reshape(1, -1)
            return forward(params, point)[0, 0]

        du_dx = grad(u_net, argnums=0)
        du_dy = grad(u_net, argnums=1)

        # Second derivatives: differentiate each first derivative once more
        u_xx = grad(du_dx, argnums=0)(x, y)
        u_yy = grad(du_dy, argnums=1)(x, y)

        return u_net(x, y), du_dx(x, y), du_dy(x, y), u_xx + u_yy

    def source_function(x, y):
        """Source term f(x, y) = -2π² sin(πx) sin(πy)."""
        return -2 * jnp.pi**2 * jnp.sin(jnp.pi * x) * jnp.sin(jnp.pi * y)

    def poisson_pde_residual(params, x, y):
        """Residual ∇²u - f(x, y) at (x, y)."""
        laplacian = poisson_pinn_forward(params, x, y)[3]
        return laplacian - source_function(x, y)

    def dirichlet_bc(params, x, y):
        """Network value u(x, y); the Dirichlet condition requires it to be 0."""
        return poisson_pinn_forward(params, x, y)[0]

    def poisson_loss(params, x_pde, y_pde, x_bc, y_bc, lambda_pde=1.0, lambda_bc=1.0):
        """Weighted sum of interior PDE and boundary mean squared residuals.

        Returns:
            ``(total_loss, info)`` with ``pde_loss``, ``bc_loss``, ``total_loss``.
        """
        interior = vmap(lambda x, y: poisson_pde_residual(params, x, y))(x_pde, y_pde)
        boundary = vmap(lambda x, y: dirichlet_bc(params, x, y))(x_bc, y_bc)

        pde_loss = jnp.mean(interior**2)
        bc_loss = jnp.mean(boundary**2)
        total_loss = lambda_pde * pde_loss + lambda_bc * bc_loss

        return total_loss, {
            'pde_loss': pde_loss,
            'bc_loss': bc_loss,
            'total_loss': total_loss
        }

    return init_params, poisson_pinn_forward, poisson_loss

# Create Poisson PINN
init_poisson_params, poisson_forward, poisson_loss_fn = create_poisson_equation_pinn()
poisson_params = init_poisson_params(random.PRNGKey(789))

# Generate training data for 2D domain [0,1] × [0,1]
key = random.PRNGKey(101112)
n_pde_2d = 10000
n_bc_2d = 1000

# Interior points. Both coordinates are drawn from children of split(key)[0]
# (split(key)[1] is used for the boundary points further down), so `key`
# itself is never consumed again after being split.
key_x_2d, key_y_2d = random.split(random.split(key)[0])
x_pde_2d = random.uniform(key_x_2d, (n_pde_2d,))
y_pde_2d = random.uniform(key_y_2d, (n_pde_2d,))

# Boundary points (edges of unit square)
def generate_boundary_points(key, n_points):
    """Sample exactly ``n_points`` points on the boundary of the unit square.

    Points are laid out edge by edge in the order bottom (y=0), top (y=1),
    left (x=0), right (x=1); the free coordinate on each edge is uniform in
    [0, 1). When ``n_points`` is not a multiple of 4, the remainder goes to
    the first edges, so no points are silently dropped.

    Args:
        key: JAX PRNG key.
        n_points: Total number of boundary points (must be >= 0).

    Returns:
        ``(x_bc, y_bc)``, each of shape ``(n_points,)``.

    Raises:
        ValueError: If ``n_points`` is negative.
    """
    if n_points < 0:
        raise ValueError(f"n_points must be non-negative, got {n_points}")

    keys = random.split(key, 4)
    base, extra = divmod(n_points, 4)
    n_bottom, n_top, n_left, n_right = (base + int(edge < extra) for edge in range(4))

    # Bottom edge: y = 0
    x_bottom = random.uniform(keys[0], (n_bottom,))
    y_bottom = jnp.zeros(n_bottom)

    # Top edge: y = 1
    x_top = random.uniform(keys[1], (n_top,))
    y_top = jnp.ones(n_top)

    # Left edge: x = 0
    x_left = jnp.zeros(n_left)
    y_left = random.uniform(keys[2], (n_left,))

    # Right edge: x = 1
    x_right = jnp.ones(n_right)
    y_right = random.uniform(keys[3], (n_right,))

    x_bc = jnp.concatenate([x_bottom, x_top, x_left, x_right])
    y_bc = jnp.concatenate([y_bottom, y_top, y_left, y_right])

    return x_bc, y_bc

x_bc_2d, y_bc_2d = generate_boundary_points(random.split(key)[1], n_bc_2d)

# Evaluate the untrained Poisson loss once as a sanity check
poisson_loss, poisson_info = poisson_loss_fn(
    poisson_params, x_pde_2d, y_pde_2d, x_bc_2d, y_bc_2d
)

print("2D Poisson Equation PINN Setup:")
print(f"Interior PDE points: {n_pde_2d}")
print(f"Boundary points: {len(x_bc_2d)}")
print(f"Initial total loss: {poisson_loss:.6f}")
for term_label, term_key in (("PDE", "pde_loss"), ("BC", "bc_loss")):
    print(f"Initial {term_label} loss: {poisson_info[term_key]:.6f}")

# Closed-form reference solution for the Poisson problem
def analytical_poisson_solution(x, y):
    """Exact solution u(x, y) = sin(πx) sin(πy).

    It satisfies ∇²u = -2π² sin(πx) sin(πy) and vanishes on the unit-square boundary.
    """
    sin_x = jnp.sin(jnp.pi * x)
    sin_y = jnp.sin(jnp.pi * y)
    return sin_x * sin_y

# Quick training (fewer epochs for demonstration)
def train_poisson_pinn(params, loss_fn, x_pde, y_pde, x_bc, y_bc, n_epochs=1000,
                       lr=1e-3, log_every=200):
    """Train the Poisson PINN with plain full-batch gradient descent.

    Performs ``n_epochs`` update steps and prints the losses every
    ``log_every`` steps.

    Args:
        params: Initial parameter pytree.
        loss_fn: ``loss_fn(params, x_pde, y_pde, x_bc, y_bc)`` returning
            ``(total_loss, info)`` with ``'pde_loss'`` and ``'bc_loss'`` in ``info``.
        x_pde, y_pde: Interior collocation points.
        x_bc, y_bc: Boundary points.
        n_epochs: Number of gradient-descent steps.
        lr: Step size.
        log_every: Print interval in steps (must be positive).

    Returns:
        The trained parameter pytree.
    """
    # jax.tree_map is deprecated (and removed in recent JAX releases).
    tree_map = jax.tree_util.tree_map

    @jit
    def train_step(params):
        """One forward/backward pass followed by a gradient-descent update."""
        (loss, info), grads = jax.value_and_grad(
            lambda p: loss_fn(p, x_pde, y_pde, x_bc, y_bc), has_aux=True
        )(params)
        new_params = tree_map(lambda p, g: p - lr * g, params, grads)
        return new_params, loss, info

    print("\nTraining 2D Poisson PINN:")
    current_params = params

    # Step on every epoch; only the logging is throttled.
    for epoch in range(n_epochs):
        current_params, loss, info = train_step(current_params)
        if epoch % log_every == 0:
            print(f"Epoch {epoch:4d}: Total={loss:.6f}, PDE={info['pde_loss']:.6f}, BC={info['bc_loss']:.6f}")

    return current_params

# Train Poisson PINN
trained_poisson_params = train_poisson_pinn(
    poisson_params,
    poisson_loss_fn,
    x_pde_2d,
    y_pde_2d,
    x_bc_2d,
    y_bc_2d,
    n_epochs=1000,
)

# Compare on an 11 x 11 grid, printing every other node along each axis
test_x_grid = jnp.linspace(0, 1, 11)
test_y_grid = jnp.linspace(0, 1, 11)

print("\n2D Poisson Solution Comparison (sample points):")
print("x    | y    | PINN      | Analytical | Error")
print("-" * 50)

for x_val in test_x_grid[::2]:
    for y_val in test_y_grid[::2]:
        pinn_val = poisson_forward(trained_poisson_params, x_val, y_val)[0]
        analytical_val = analytical_poisson_solution(x_val, y_val)
        error = abs(pinn_val - analytical_val)

        print(f"{x_val:.1f} | {y_val:.1f} | {pinn_val:9.6f} | {analytical_val:9.6f} | {error:.6f}")
```

## Advanced PINN: Navier-Stokes Equations

```python
def create_navier_stokes_pinn(Re=100.0):
    """Build a PINN for the 2D incompressible Navier-Stokes equations.

    The network maps (x, y, t) to (u, v, p): two velocity components and the
    pressure. Residuals are the continuity equation ``u_x + v_y = 0`` and the
    two non-dimensional momentum equations with Reynolds number ``Re``.

    Args:
        Re: Reynolds number used by the residuals and, by default, by the loss.
            The default matches the previously hard-coded value.

    Returns:
        ``(init_params, ns_forward, ns_loss)``.
    """
    # Input: [x, y, t]; output: [u, v, p]
    init_params, forward = create_pinn_model([3, 128, 128, 128, 128, 3])

    def ns_forward(params, x, y, t):
        """Return a dict of u, v, p and the derivatives the residuals need."""

        def uvp_net(x, y, t):
            inputs = jnp.array([x, y, t])
            output = forward(params, inputs.reshape(1, -1))[0]
            return output[0], output[1], output[2]  # u, v, p

        def component(idx):
            """Scalar function (x, y, t) -> idx-th network output."""
            return lambda x, y, t: uvp_net(x, y, t)[idx]

        u_fn, v_fn, p_fn = component(0), component(1), component(2)

        u, v, p = uvp_net(x, y, t)

        # One reverse pass per output yields all of its first derivatives.
        u_x, u_y, u_t = grad(u_fn, argnums=(0, 1, 2))(x, y, t)
        v_x, v_y, v_t = grad(v_fn, argnums=(0, 1, 2))(x, y, t)
        p_x, p_y = grad(p_fn, argnums=(0, 1))(x, y, t)

        # Second spatial derivatives of the velocity components
        u_xx = grad(grad(u_fn, argnums=0), argnums=0)(x, y, t)
        u_yy = grad(grad(u_fn, argnums=1), argnums=1)(x, y, t)
        v_xx = grad(grad(v_fn, argnums=0), argnums=0)(x, y, t)
        v_yy = grad(grad(v_fn, argnums=1), argnums=1)(x, y, t)

        return {
            'u': u, 'v': v, 'p': p,
            'u_x': u_x, 'u_y': u_y, 'u_t': u_t,
            'v_x': v_x, 'v_y': v_y, 'v_t': v_t,
            'p_x': p_x, 'p_y': p_y,
            'u_xx': u_xx, 'u_yy': u_yy,
            'v_xx': v_xx, 'v_yy': v_yy
        }

    def navier_stokes_residuals(params, x, y, t, Re=Re):
        """Return ``(continuity, momentum_u, momentum_v)`` residuals at (x, y, t)."""
        d = ns_forward(params, x, y, t)

        # Continuity equation: ∂u/∂x + ∂v/∂y = 0
        continuity = d['u_x'] + d['v_y']

        # ∂u/∂t + u∂u/∂x + v∂u/∂y = -∂p/∂x + (1/Re)(∂²u/∂x² + ∂²u/∂y²)
        momentum_u = (d['u_t'] +
                      d['u'] * d['u_x'] +
                      d['v'] * d['u_y'] +
                      d['p_x'] -
                      (1/Re) * (d['u_xx'] + d['u_yy']))

        # ∂v/∂t + u∂v/∂x + v∂v/∂y = -∂p/∂y + (1/Re)(∂²v/∂x² + ∂²v/∂y²)
        momentum_v = (d['v_t'] +
                      d['u'] * d['v_x'] +
                      d['v'] * d['v_y'] +
                      d['p_y'] -
                      (1/Re) * (d['v_xx'] + d['v_yy']))

        return continuity, momentum_u, momentum_v

    def ns_loss(params, x_pde, y_pde, t_pde, x_bc, y_bc, t_bc, u_bc, v_bc,
                lambda_pde=1.0, lambda_bc=1.0, Re=Re):
        """Navier-Stokes PINN loss: weighted PDE residual MSE plus velocity BC MSE.

        Args:
            params: Network parameters.
            x_pde, y_pde, t_pde: Interior collocation points.
            x_bc, y_bc, t_bc: Boundary sample points.
            u_bc, v_bc: Target velocities at the boundary points.
            lambda_pde, lambda_bc: Weights of the PDE and boundary terms.
            Re: Reynolds number used in the momentum residuals.

        Returns:
            ``(total_loss, info)`` with the individual loss terms in ``info``.
        """
        residuals = vmap(
            lambda x, y, t: navier_stokes_residuals(params, x, y, t, Re)
        )(x_pde, y_pde, t_pde)
        continuity_loss = jnp.mean(residuals[0]**2)
        momentum_u_loss = jnp.mean(residuals[1]**2)
        momentum_v_loss = jnp.mean(residuals[2]**2)
        pde_loss = continuity_loss + momentum_u_loss + momentum_v_loss

        def compute_bc_residuals(x, y, t, u_true, v_true):
            d = ns_forward(params, x, y, t)
            return (d['u'] - u_true)**2 + (d['v'] - v_true)**2

        bc_residuals = vmap(compute_bc_residuals)(x_bc, y_bc, t_bc, u_bc, v_bc)
        bc_loss = jnp.mean(bc_residuals)

        total_loss = lambda_pde * pde_loss + lambda_bc * bc_loss

        return total_loss, {
            'pde_loss': pde_loss,
            'continuity_loss': continuity_loss,
            'momentum_u_loss': momentum_u_loss,
            'momentum_v_loss': momentum_v_loss,
            'bc_loss': bc_loss,
            'total_loss': total_loss
        }

    return init_params, ns_forward, ns_loss

# Note: Navier-Stokes PINN is computationally intensive
# This is a simplified demonstration setup

print("\nNavier-Stokes PINN Setup:")
init_ns_params, ns_forward_fn, ns_loss_fn = create_navier_stokes_pinn()
ns_params = init_ns_params(random.PRNGKey(131415))

# Generate minimal training data for demonstration. Use one independent subkey
# per sampled array: the default random.split(key) yields only two keys, and a
# key must not be reused once it has been split.
key = random.PRNGKey(161718)
key_x_ns, key_y_ns, key_t_ns, key_x_bc_ns, key_t_bc_ns = random.split(key, 5)
n_pde_ns = 1000  # Reduced for demonstration
n_bc_ns = 100

# Domain: [0,1] × [0,1] × [0,0.1]
x_pde_ns = random.uniform(key_x_ns, (n_pde_ns,))
y_pde_ns = random.uniform(key_y_ns, (n_pde_ns,))
t_pde_ns = random.uniform(key_t_ns, (n_pde_ns,)) * 0.1

# Boundary conditions (simplified: no-slip walls)
x_bc_ns = random.uniform(key_x_bc_ns, (n_bc_ns,))
y_bc_ns = jnp.zeros(n_bc_ns)  # Bottom wall
t_bc_ns = random.uniform(key_t_bc_ns, (n_bc_ns,)) * 0.1
u_bc_ns = jnp.zeros(n_bc_ns)  # No-slip condition
v_bc_ns = jnp.zeros(n_bc_ns)

# Test NS loss computation
ns_loss, ns_info = ns_loss_fn(ns_params, x_pde_ns, y_pde_ns, t_pde_ns,
                              x_bc_ns, y_bc_ns, t_bc_ns, u_bc_ns, v_bc_ns)

print("Initial NS losses:")
print(f"Total loss: {ns_loss:.6f}")
print(f"PDE loss: {ns_info['pde_loss']:.6f}")
print(f"Continuity loss: {ns_info['continuity_loss']:.6f}")
print(f"Momentum U loss: {ns_info['momentum_u_loss']:.6f}")
print(f"Momentum V loss: {ns_info['momentum_v_loss']:.6f}")
print(f"BC loss: {ns_info['bc_loss']:.6f}")

# Test forward pass
test_derivatives = ns_forward_fn(ns_params, 0.5, 0.5, 0.05)
print(f"\nSample derivatives at (0.5, 0.5, 0.05):")
print(f"u: {test_derivatives['u']:.6f}, v: {test_derivatives['v']:.6f}, p: {test_derivatives['p']:.6f}")
```

## PINN with Adaptive Weights

```python
def create_adaptive_weight_pinn():
    """Build helpers for training a PINN with adaptively re-weighted loss terms.

    Returns:
        ``(train_with_adaptive_weights, adaptive_loss_weights)``.
    """
    # jax.tree_map is deprecated (and removed in recent JAX releases).
    tree_map = jax.tree_util.tree_map

    def adaptive_loss_weights(losses_history, method='grad_norm'):
        """Compute new loss weights from the most recent recorded loss terms.

        Args:
            losses_history: List of info dicts with ``'pde_loss'``,
                ``'bc_loss'`` and ``'ic_loss'`` entries; fewer than two entries
                yield unit weights.
            method: ``'grad_norm'`` or ``'max_loss'`` (see below).

        Returns:
            Dict with ``'pde'``, ``'bc'`` and ``'ic'`` weights.

        Raises:
            ValueError: For an unknown ``method`` (previously returned None).
        """
        if len(losses_history) < 2:
            return {'pde': 1.0, 'bc': 1.0, 'ic': 1.0}

        recent = losses_history[-1]
        pde_loss, bc_loss, ic_loss = recent['pde_loss'], recent['bc_loss'], recent['ic_loss']

        if method == 'grad_norm':
            # Despite the name, no gradients are involved: each term is weighted
            # inversely to its current loss value, then the weights are rescaled
            # to sum to 3. The 1e-8 avoids division by zero.
            pde_weight = 1.0 / (pde_loss + 1e-8)
            bc_weight = 1.0 / (bc_loss + 1e-8)
            ic_weight = 1.0 / (ic_loss + 1e-8)

            total_weight = pde_weight + bc_weight + ic_weight
            return {
                'pde': pde_weight / total_weight * 3,
                'bc': bc_weight / total_weight * 3,
                'ic': ic_weight / total_weight * 3
            }

        if method == 'max_loss':
            # Scale every term up to the magnitude of the largest one.
            max_loss = max([pde_loss, bc_loss, ic_loss])
            return {
                'pde': max_loss / (pde_loss + 1e-8),
                'bc': max_loss / (bc_loss + 1e-8),
                'ic': max_loss / (ic_loss + 1e-8)
            }

        raise ValueError(f"Unknown weighting method: {method!r}")

    def train_with_adaptive_weights(init_params, loss_fn, training_data,
                                    n_epochs=3000, lr=1e-3, adapt_freq=100):
        """Train a PINN with Adam while periodically re-weighting the loss terms.

        Args:
            init_params: Initial parameter pytree.
            loss_fn: ``loss_fn(params, x_pde, t_pde, x_bc, t_bc, x_ic,
                lambda_pde, lambda_bc, lambda_ic)`` returning ``(total, info)``.
            training_data: Tuple ``(x_pde, t_pde, x_bc, t_bc, x_ic)``.
            n_epochs: Number of optimization steps.
            lr: Adam learning rate.
            adapt_freq: Re-weight every ``adapt_freq`` steps; values <= 0
                disable adaptation.

        Returns:
            ``(params, losses_history, weight_history)``. Loss terms and the
            weights in effect are recorded every 50 steps and at every
            adaptation step.
        """
        x_pde, t_pde, x_bc, t_bc, x_ic = training_data

        params = init_params
        opt_state = {'m': tree_map(jnp.zeros_like, params),
                     'v': tree_map(jnp.zeros_like, params),
                     'step': 0}

        losses_history = []
        weight_history = []

        @jit
        def train_step_adaptive(params, opt_state, lambda_pde, lambda_bc, lambda_ic):
            def loss_fn_weighted(p):
                return loss_fn(p, x_pde, t_pde, x_bc, t_bc, x_ic,
                               lambda_pde, lambda_bc, lambda_ic)[0]

            loss, grads = jax.value_and_grad(loss_fn_weighted)(params)

            # Adam update (beta1=0.9, beta2=0.999, eps=1e-8)
            step = opt_state['step'] + 1
            m = tree_map(lambda m, g: 0.9 * m + 0.1 * g, opt_state['m'], grads)
            v = tree_map(lambda v, g: 0.999 * v + 0.001 * g**2, opt_state['v'], grads)

            m_hat = tree_map(lambda m: m / (1 - 0.9**step), m)
            v_hat = tree_map(lambda v: v / (1 - 0.999**step), v)

            new_params = tree_map(
                lambda p, m, v: p - lr * m / (jnp.sqrt(v) + 1e-8),
                params, m_hat, v_hat
            )

            new_opt_state = {'m': m, 'v': v, 'step': step}
            return new_params, new_opt_state, loss

        weights = {'pde': 1.0, 'bc': 1.0, 'ic': 1.0}

        print("Training with Adaptive Weights:")
        print("Epoch | Total Loss | PDE λ | BC λ  | IC λ")
        print("-" * 40)

        for epoch in range(n_epochs):
            params, opt_state, total_loss = train_step_adaptive(
                params, opt_state, weights['pde'], weights['bc'], weights['ic']
            )

            record = epoch % 50 == 0
            # Checked independently of `record` so that adapt_freq values that
            # are not multiples of 50 are honoured.
            adapt = adapt_freq > 0 and epoch > 0 and epoch % adapt_freq == 0

            if record or adapt:
                # Unweighted loss terms drive the adaptation.
                _, info = loss_fn(params, x_pde, t_pde, x_bc, t_bc, x_ic)
                losses_history.append(info)
                weight_history.append(weights.copy())

                if adapt:
                    weights = adaptive_loss_weights(losses_history)

                if epoch % 500 == 0:
                    print(f"{epoch:5d} | {total_loss:10.6f} | {weights['pde']:5.2f} | "
                          f"{weights['bc']:5.2f} | {weights['ic']:5.2f}")

        return params, losses_history, weight_history

    return train_with_adaptive_weights, adaptive_loss_weights

# Demonstrate training with adaptive loss weights
adaptive_train_fn, adaptive_weights_fn = create_adaptive_weight_pinn()

banner = "=" * 50
print("\n" + banner)
print("ADAPTIVE WEIGHT PINN TRAINING")
print(banner)

# A reduced subset of the heat-equation training points keeps the demo quick
small_heat_params = init_heat_params(random.PRNGKey(192021))
n_small_pde, n_small_edge = 1000, 50
small_training_data = (
    x_pde[:n_small_pde],
    t_pde[:n_small_pde],
    x_bc[:n_small_edge],
    t_bc[:n_small_edge],
    x_ic[:n_small_edge],
)

adaptive_params, adaptive_losses, adaptive_weights = adaptive_train_fn(
    small_heat_params,
    heat_loss_fn,
    small_training_data,
    n_epochs=1000,
    adapt_freq=200,
)

print("\nFinal adaptive weights:")
final_weights = adaptive_weights[-1]
for term_name, term_weight in final_weights.items():
    print(f"  {term_name}: {term_weight:.3f}")
```

## Multi-Scale PINN for Complex Domains

```python
def create_multiscale_pinn():
    """Bundle helpers for multi-scale PINNs.

    Returns:
        ``(multiscale_encoding, create_multiscale_network,
        create_domain_decomposition_pinn)``.
    """

    def multiscale_encoding(x, t, scales=(1, 2, 4, 8)):
        """Multi-scale Fourier feature encoding of a scalar (x, t) pair.

        For each scale s the features are sin(2πsx), cos(2πsx), sin(2πst),
        cos(2πst), so the result has length ``4 * len(scales)``.
        """
        features = []
        for scale in scales:
            angle_x = 2 * jnp.pi * scale * x
            angle_t = 2 * jnp.pi * scale * t
            features.extend([
                jnp.sin(angle_x), jnp.cos(angle_x),
                jnp.sin(angle_t), jnp.cos(angle_t),
            ])
        return jnp.array(features)

    def create_multiscale_network(base_dim=2, n_scales=4, hidden_dims=(128, 128, 128)):
        """Create an MLP that consumes [x, t] concatenated with their Fourier features.

        Args:
            base_dim: Number of raw inputs fed next to the features. The forward
                pass always prepends exactly [x, t], so this should be 2.
            n_scales: Number of geometric scales 1, 2, 4, ... used by the
                encoding. The encoding now follows this value; previously it
                was fixed at 4 scales, so any other value broke the input width.
            hidden_dims: Hidden-layer widths.

        Returns:
            ``(init_params, multiscale_forward)``.
        """
        # 2**k for k < n_scales; the default n_scales=4 gives (1, 2, 4, 8).
        scales = tuple(2 ** k for k in range(n_scales))
        input_dim = base_dim + 4 * n_scales  # sin/cos for each variable and scale
        layer_sizes = [input_dim, *hidden_dims, 1]

        init_params, forward = create_pinn_model(layer_sizes)

        def multiscale_forward(params, x, t):
            ms_features = multiscale_encoding(x, t, scales)
            full_input = jnp.concatenate([jnp.array([x, t]), ms_features])
            return forward(params, full_input.reshape(1, -1))[0, 0]

        return init_params, multiscale_forward

    def create_domain_decomposition_pinn(n_subdomains=4):
        """Domain-decomposition PINN: one small network per equal-width slab in x ∈ [0, 1].

        Returns:
            ``(init_dd_params, dd_forward, interface_continuity_loss)``.
        """
        subdomain_networks = [create_pinn_model([2, 64, 64, 1]) for _ in range(n_subdomains)]

        def init_dd_params(key):
            keys = random.split(key, n_subdomains)
            return [init_fn(k) for (init_fn, _), k in zip(subdomain_networks, keys)]

        def _subdomain_index(x):
            # floor() keeps this traceable under jit/vmap/grad (int() is not);
            # after clipping it matches int(x * n_subdomains) for every real x.
            idx = jnp.floor(x * n_subdomains).astype(jnp.int32)
            return jnp.clip(idx, 0, n_subdomains - 1)

        def _eval_subdomain(all_params, i, x, t):
            """Evaluate subdomain network ``i`` at one scalar point (x, t)."""
            _, forward_fn = subdomain_networks[i]
            return forward_fn(all_params[i], jnp.array([[x, t]]))[0, 0]

        def dd_forward(all_params, x, t):
            """Evaluate the network owning the subdomain that contains x."""
            branches = [
                (lambda xx, tt, idx=idx: _eval_subdomain(all_params, idx, xx, tt))
                for idx in range(n_subdomains)
            ]
            return jax.lax.switch(_subdomain_index(x), branches, x, t)

        def interface_continuity_loss(all_params, interfaces, t_points=None):
            """Penalize jumps in the solution across subdomain interfaces.

            Args:
                all_params: Per-subdomain parameter list from ``init_dd_params``.
                interfaces: x-locations where subdomain i meets subdomain i + 1;
                    only the first ``n_subdomains - 1`` entries are used.
                t_points: Times at which continuity is checked. Defaults to 11
                    evenly spaced times in [0, 1] — NOTE(review): match this to
                    the problem's time domain.

            Returns:
                Sum over interfaces of the mean squared difference between the
                left and right networks (0.0 when there are no interfaces).
            """
            if t_points is None:
                t_points = jnp.linspace(0.0, 1.0, 11)
            t_points = jnp.asarray(t_points)

            def _eval_batch(i, xs, ts):
                return vmap(lambda xx, tt: _eval_subdomain(all_params, i, xx, tt))(xs, ts)

            continuity_loss = 0.0
            for i, x_interface in enumerate(interfaces[: n_subdomains - 1]):
                x_col = jnp.full(t_points.shape, x_interface, dtype=t_points.dtype)
                left_val = _eval_batch(i, x_col, t_points)
                right_val = _eval_batch(i + 1, x_col, t_points)
                continuity_loss = continuity_loss + jnp.mean((left_val - right_val) ** 2)

            return continuity_loss

        return init_dd_params, dd_forward, interface_continuity_loss

    return multiscale_encoding, create_multiscale_network, create_domain_decomposition_pinn

# Exercise the multiscale helpers on a single (x, t) sample
ms_encoding, ms_network, dd_pinn = create_multiscale_pinn()

x_test, t_test = 0.3, 0.1
ms_features = ms_encoding(x_test, t_test)

print("Multiscale Features:")
print(f"Input: x={x_test}, t={t_test}")
print(f"Multiscale features shape: {ms_features.shape}")
print(f"Feature sample: {ms_features[:8]}")  # First 8 features

# Multiscale network evaluated at the same point
init_ms, forward_ms = ms_network()
ms_params = init_ms(random.PRNGKey(222324))
ms_output = forward_ms(ms_params, x_test, t_test)
print(f"Multiscale network output: {ms_output:.6f}")

# Two-subdomain decomposition evaluated at the same point
init_dd, forward_dd, interface_loss = dd_pinn(n_subdomains=2)
dd_params = init_dd(random.PRNGKey(252627))
dd_output = forward_dd(dd_params, x_test, t_test)
print(f"Domain decomposition output: {dd_output:.6f}")
```

## PINN Performance Analysis and Validation

```python
def create_pinn_analysis_tools():
    """Build helpers for analyzing and validating PINN solutions.

    All helpers evaluate the user-supplied callables point-by-point in plain
    Python, so any callable taking scalar ``(x, t)`` inputs works (no ``vmap``
    compatibility is required).

    Returns:
        tuple: ``(compute_error_metrics, conservation_analysis,
        residual_analysis, boundary_condition_analysis)``.
    """

    def _require_nonempty(values, name):
        """Raise a clear ValueError instead of failing inside an empty reduction."""
        if values.size == 0:
            raise ValueError(f"{name} must contain at least one (x, t) point")

    def compute_error_metrics(pinn_solution, analytical_solution, domain_points, eps=1e-8):
        """Compare a PINN solution against a reference solution.

        Args:
            pinn_solution: Callable ``(x, t) -> value`` for the PINN prediction.
            analytical_solution: Callable ``(x, t) -> value`` for the reference.
            domain_points: Iterable of ``(x, t)`` pairs to evaluate at.
            eps: Stabilizer added to the denominators of the relative errors.

        Returns:
            ``(metrics, absolute_error, relative_error)``. ``metrics`` holds
            ``l1_error`` (mean abs error), ``l2_error`` (root-mean-square error),
            ``linf_error`` (max abs error), ``mean_relative_error``,
            ``max_relative_error`` and ``relative_l2_error``
            (``||pred - ref||_2 / ||ref||_2``). The two error arrays are 1-D,
            one entry per point.

        Raises:
            ValueError: If ``domain_points`` is empty.
        """
        # Materialize once so generators are not exhausted by the first pass.
        points = list(domain_points)

        # ravel so a callable returning shape-(1,) values cannot silently
        # broadcast against a scalar-valued one into an (n, n) error matrix.
        pinn_vals = jnp.ravel(jnp.asarray([pinn_solution(x, t) for x, t in points]))
        analytical_vals = jnp.ravel(jnp.asarray([analytical_solution(x, t) for x, t in points]))
        _require_nonempty(pinn_vals, "domain_points")

        absolute_error = jnp.abs(pinn_vals - analytical_vals)
        # Pointwise relative error blows up (~|pred| / eps) wherever the
        # reference is ~0, e.g. on homogeneous Dirichlet boundaries;
        # relative_l2_error is the robust global alternative.
        relative_error = absolute_error / (jnp.abs(analytical_vals) + eps)

        metrics = {
            'l1_error': jnp.mean(absolute_error),
            'l2_error': jnp.sqrt(jnp.mean(absolute_error**2)),
            'linf_error': jnp.max(absolute_error),
            'mean_relative_error': jnp.mean(relative_error),
            'max_relative_error': jnp.max(relative_error),
            'relative_l2_error': jnp.linalg.norm(absolute_error)
            / (jnp.linalg.norm(analytical_vals) + eps),
        }

        return metrics, absolute_error, relative_error

    def conservation_analysis(pinn_forward_fn, params, domain_points, decimals=3):
        """Summarize the predicted solution on each time slice.

        Points are grouped by their time coordinate rounded to ``decimals``
        places. For each group, the mean, standard deviation and plain sum of
        the predicted ``u`` are reported. ``total`` is a raw sum over the
        sampled points, not a quadrature-weighted integral, so it is only a
        coarse proxy for a conserved spatial integral (up to the grid spacing
        on a uniform grid).

        Args:
            pinn_forward_fn: Callable ``(params, x, t)`` whose first output
                (index 0) is the predicted ``u``.
            params: Network parameters forwarded to ``pinn_forward_fn``.
            domain_points: Iterable of ``(x, t)`` pairs.
            decimals: Rounding precision used to bucket time values.

        Returns:
            dict mapping rounded ``t`` to ``{'mean', 'std', 'total'}``. The
            dict is empty when no points are given.
        """
        groups = {}
        for x, t in domain_points:
            u_val = pinn_forward_fn(params, x, t)[0]
            groups.setdefault(round(float(t), decimals), []).append(u_val)

        per_time_stats = {}
        for t_key, vals in groups.items():
            arr = jnp.asarray(vals)
            per_time_stats[t_key] = {
                'mean': jnp.mean(arr),
                'std': jnp.std(arr),
                'total': jnp.sum(arr),
            }

        return per_time_stats

    def residual_analysis(pinn_residual_fn, params, test_points):
        """Summarize PDE residuals over a set of test points.

        Args:
            pinn_residual_fn: Callable ``(params, x, t) -> residual``.
            params: Network parameters forwarded to ``pinn_residual_fn``.
            test_points: Iterable of ``(x, t)`` pairs.

        Returns:
            dict with ``mean_residual`` and ``max_residual`` (of |residual|),
            ``residual_std`` (std of the signed residuals) and the raw
            ``residuals`` array.

        Raises:
            ValueError: If ``test_points`` is empty.
        """
        residuals = jnp.asarray([pinn_residual_fn(params, x, t) for x, t in test_points])
        _require_nonempty(residuals, "test_points")
        abs_residuals = jnp.abs(residuals)

        return {
            'mean_residual': jnp.mean(abs_residuals),
            'max_residual': jnp.max(abs_residuals),
            'residual_std': jnp.std(residuals),
            'residuals': residuals,
        }

    def boundary_condition_analysis(pinn_forward_fn, params, boundary_points, true_bc_values, tol=1e-3):
        """Measure how well the predicted ``u`` matches boundary targets.

        Args:
            pinn_forward_fn: Callable ``(params, x, t)`` whose first output
                (index 0) is the predicted ``u``.
            params: Network parameters forwarded to ``pinn_forward_fn``.
            boundary_points: Iterable of ``(x, t)`` boundary pairs.
            true_bc_values: Target values, aligned with ``boundary_points``
                (array, list, or a scalar broadcast to every point).
            tol: Absolute error below which a point counts as "satisfied".

        Returns:
            dict with ``mean_bc_error``, ``max_bc_error`` and
            ``bc_satisfaction_rate`` (fraction of points with error < tol).

        Raises:
            ValueError: If ``boundary_points`` is empty.
        """
        predicted_bc = jnp.ravel(jnp.asarray(
            [pinn_forward_fn(params, x, t)[0] for x, t in boundary_points]
        ))
        _require_nonempty(predicted_bc, "boundary_points")
        # asarray: JAX array arithmetic does not accept plain Python lists.
        target_bc = jnp.ravel(jnp.asarray(true_bc_values))

        bc_errors = jnp.abs(predicted_bc - target_bc)

        return {
            'mean_bc_error': jnp.mean(bc_errors),
            'max_bc_error': jnp.max(bc_errors),
            'bc_satisfaction_rate': jnp.mean(bc_errors < tol),
        }

    return compute_error_metrics, conservation_analysis, residual_analysis, boundary_condition_analysis

# ---------------------------------------------------------------------------
# Validate the trained heat-equation PINN against the analytical solution.
# ---------------------------------------------------------------------------
error_metrics_fn, conservation_fn, residual_fn, bc_analysis_fn = create_pinn_analysis_tools()

# Uniform validation grid: 21 x-values on [0, 1] times 16 t-values on [0, 0.3].
validation_x = jnp.linspace(0, 1, 21)
validation_t = jnp.linspace(0, 0.3, 16)
validation_points = [(x, t) for x in validation_x for t in validation_t]

print(f"\n{'=' * 50}")
print("PINN VALIDATION ANALYSIS")
print("=" * 50)


def pinn_solution(x, t):
    """Return the PINN prediction u(x, t): the first of heat_forward's four outputs."""
    u_pred, _, _, _ = heat_forward(trained_params, x, t)
    return u_pred


# --- Accuracy against the closed-form solution ---
metrics, abs_errors, rel_errors = error_metrics_fn(
    pinn_solution, analytical_heat_solution, validation_points
)
print("Error Metrics vs Analytical Solution:")
for name, value in metrics.items():
    print(f"  {name}: {value:.6f}")


def heat_residual(params, x, t):
    """Heat-equation PDE residual at (x, t) with alpha=0.01 (presumably must match training)."""
    return heat_pde_residual(params, x, t, alpha=0.01)


# --- Physics consistency: PDE residuals on the validation grid ---
residual_stats = residual_fn(heat_residual, trained_params, validation_points)
print("\nPDE Residual Analysis:")
print(f"  Mean residual: {residual_stats['mean_residual']:.6f}")
print(f"  Max residual: {residual_stats['max_residual']:.6f}")
print(f"  Residual std: {residual_stats['residual_std']:.6f}")

# --- Boundary conditions: u = 0 on x = 0 and x = 1 at every validation time ---
boundary_points_val = [(x_b, t) for x_b in (0.0, 1.0) for t in validation_t]
boundary_true_vals = jnp.zeros(len(boundary_points_val))

bc_stats = bc_analysis_fn(heat_forward, trained_params, boundary_points_val, boundary_true_vals)
print("\nBoundary Condition Analysis:")
print(f"  Mean BC error: {bc_stats['mean_bc_error']:.6f}")
print(f"  Max BC error: {bc_stats['max_bc_error']:.6f}")
print(f"  BC satisfaction rate: {bc_stats['bc_satisfaction_rate']:.3f}")

# --- Overall verdict ---
print("\nSUMMARY:")
print(f"  Overall L2 error: {metrics['l2_error']:.6f}")
print(f"  Physics violation (residual): {residual_stats['mean_residual']:.6f}")
print(f"  Boundary compliance: {bc_stats['bc_satisfaction_rate']*100:.1f}%")

print(
    "  ✓ PINN solution appears accurate and physically consistent"
    if metrics['l2_error'] < 0.01 and residual_stats['mean_residual'] < 0.1
    else "  ⚠ PINN may need further training or architecture adjustment"
)
```

## Summary

This capstone project demonstrated Physics-Informed Neural Networks (PINNs) in JAX:

**Core PINN Concepts:**
- **Physics Integration**: Incorporating PDEs directly into loss functions
- **Automatic Differentiation**: Computing derivatives for PDE residuals
- **Multi-Objective Optimization**: Balancing PDE, boundary, and initial conditions
- **Collocation Methods**: Using scattered points instead of structured grids

**Implemented Solutions:**
- **1D Heat Equation**: Parabolic PDE with analytical validation
- **2D Poisson Equation**: Elliptic PDE with Dirichlet boundaries
- **Navier-Stokes**: Complex fluid dynamics system
- **Adaptive Weighting**: Dynamic loss component balancing

**Advanced Techniques:**
- **Multi-Scale Encoding**: Fourier features for multiple scales
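- **Domain Decomposition**: Splitting the domain across per-subdomain networks (the interface-continuity penalty is still a placeholder that always returns 0)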
- **Conservation Analysis**: Per-time-slice statistics (mean, std, sum) of the predicted solution as a coarse physical-consistency check
- **Error Metrics**: Comprehensive validation tools

**Key Advantages:**
- **Mesh-Free**: No grid discretization required
- **Flexible Boundaries**: Handle complex geometries easily
- **Data-Efficient**: Physics provides regularization
- **Inverse Problems**: Can solve for unknown parameters

**Applications:**
- **Engineering**: Heat transfer, fluid flow, structural mechanics
- **Physics**: Wave propagation, quantum mechanics, plasma physics
- **Finance**: Option pricing, risk modeling
- **Biology**: Population dynamics, epidemiology

**Best Practices:**
- Use adaptive loss weighting for balanced training
- Employ multi-scale features for complex phenomena
- Validate against analytical solutions when available
- Monitor PDE residuals and conservation laws
- Consider domain decomposition for large problems

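PINNs offer a mesh-free way to solve PDEs with neural networks by building the governing equations, boundary conditions, and initial conditions directly into the training loss. Note that the physics is enforced softly, through penalty terms, rather than exactly, so the PDE residuals and boundary errors should always be checked after training, as done in the validation section above.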