# Tutorial 1: Training a PINN on the Poisson Equation

**Learning Objectives:**
- Understand what Physics-Informed Neural Networks (PINNs) are
- Learn how PINNs solve partial differential equations (PDEs)
- Train a PINN to solve the 2D Poisson equation
- Visualize and validate the solution
- Extract activations for interpretability analysis

**Estimated Time:** 30-40 minutes

---

## Table of Contents
1. [Introduction to PINNs](#1-Introduction-to-PINNs)
2. [The Poisson Equation](#2-The-Poisson-Equation)
3. [Setup and Imports](#3-Setup-and-Imports)
4. [Creating the PINN Model](#4-Creating-the-PINN-Model)
5. [Defining the Problem](#5-Defining-the-Problem)
6. [Training the PINN](#6-Training-the-PINN)
7. [Visualizing Results](#7-Visualizing-Results)
8. [Extracting Activations](#8-Extracting-Activations)
9. [Summary and Next Steps](#9-Summary-and-Next-Steps)

## 1. Introduction to PINNs

### What are Physics-Informed Neural Networks?

**Physics-Informed Neural Networks (PINNs)** are neural networks that learn to solve partial differential equations (PDEs) by incorporating the physics of the problem directly into the loss function.

**Traditional approach:**
- Discretize the domain (finite differences, finite elements)
- Solve a large system of equations
- Requires mesh generation, can be complex

**PINN approach:**
- Neural network approximates the solution: $u(x) \approx NN(x)$
- Loss function combines:
  - **PDE residual**: How well does the network satisfy the PDE?
  - **Boundary conditions**: Does it match boundary values?
  - **Initial conditions**: Does it match initial state (for time-dependent PDEs)?
- Meshfree, flexible, can handle complex geometries

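**Key Advantage:** Automatic differentiation gives derivatives of the network output with respect to its inputs that are exact up to floating-point precision, with no finite-difference truncation error.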
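
To make this concrete, here is a minimal sketch in plain PyTorch (independent of this project's `src` modules) that differentiates $\sin(\pi x)\sin(\pi y)$ with `torch.autograd.grad` and compares the result against hand-derived derivatives. The variable names are illustrative only.

```python
import math
import torch

# Random points in the unit square; requires_grad lets autograd track them.
torch.manual_seed(0)
xy = torch.rand(5, 2, dtype=torch.float64, requires_grad=True)
x, y = xy[:, 0], xy[:, 1]

# Function to differentiate: u(x, y) = sin(pi x) sin(pi y)
u = torch.sin(math.pi * x) * torch.sin(math.pi * y)

# Gradient of u with respect to the inputs, shape (5, 2).
# create_graph=True keeps the graph so we can differentiate a second time.
grad_u = torch.autograd.grad(u.sum(), xy, create_graph=True)[0]
u_x = grad_u[:, 0]

# Second derivative d2u/dx2 from a second autograd pass.
u_xx = torch.autograd.grad(u_x.sum(), xy)[0][:, 0]

# Hand-derived references
u_x_exact = math.pi * torch.cos(math.pi * x) * torch.sin(math.pi * y)
u_xx_exact = -math.pi**2 * u

print("max |u_x  - exact|:", (u_x - u_x_exact).abs().max().item())
print("max |u_xx - exact|:", (u_xx - u_xx_exact).abs().max().item())
```

Summing `u` before calling `grad` works here because each output depends only on its own input row, so the gradient of the sum contains the per-point derivatives.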

### The PINN Training Process

```
1. Sample collocation points (x, y) in the domain
2. Forward pass: u = NN(x, y)
3. Compute derivatives: ∂u/∂x, ∂²u/∂x², etc. (via autograd)
4. Evaluate PDE residual: N[u] = 0
5. Compute loss: L = w_pde * ||N[u]||² + w_bc * ||u - u_bc||²
6. Backpropagate and update weights
7. Repeat until convergence
```
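
The steps above can be sketched as a short, self-contained PyTorch loop. This is not the `train_pinn` implementation used later in the tutorial; it is a minimal illustration under simple assumptions (a small `nn.Sequential` network, uniform random sampling, the Poisson source term introduced in the next section, and only 200 steps), and all helper names are hypothetical.

```python
import math
import torch

torch.manual_seed(0)

# Small tanh network mapping (x, y) -> u; sizes are arbitrary for this sketch.
net = torch.nn.Sequential(
    torch.nn.Linear(2, 32), torch.nn.Tanh(),
    torch.nn.Linear(32, 32), torch.nn.Tanh(),
    torch.nn.Linear(32, 1),
)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

def source(xy):
    # f(x, y) = -2 pi^2 sin(pi x) sin(pi y), returned with shape (N, 1)
    return -2 * math.pi**2 * torch.sin(math.pi * xy[:, 0:1]) * torch.sin(math.pi * xy[:, 1:2])

def laplacian(u, xy):
    # First derivatives, keeping the graph so we can differentiate again
    grad_u = torch.autograd.grad(u.sum(), xy, create_graph=True)[0]
    u_xx = torch.autograd.grad(grad_u[:, 0].sum(), xy, create_graph=True)[0][:, 0:1]
    u_yy = torch.autograd.grad(grad_u[:, 1].sum(), xy, create_graph=True)[0][:, 1:2]
    return u_xx + u_yy

def sample_boundary(n):
    # n random points on each of the four edges of the unit square
    t = torch.rand(n, 1)
    zeros, ones = torch.zeros(n, 1), torch.ones(n, 1)
    edges = [torch.cat(p, dim=1) for p in
             [(t, zeros), (t, ones), (zeros, t), (ones, t)]]
    return torch.cat(edges, dim=0)

w_pde, w_bc = 1.0, 1.0
losses = []
for step in range(200):
    # 1. Sample collocation points
    xy_int = torch.rand(256, 2, requires_grad=True)
    xy_bc = sample_boundary(32)
    # 2. Forward pass
    u_int = net(xy_int)
    # 3-4. Derivatives via autograd and PDE residual r = laplacian(u) - f
    residual = laplacian(u_int, xy_int) - source(xy_int)
    # 5. Weighted loss (the boundary target is u = 0)
    loss = w_pde * residual.pow(2).mean() + w_bc * net(xy_bc).pow(2).mean()
    # 6. Backpropagate and update weights
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"loss: {losses[0]:.3f} -> {losses[-1]:.3f}")
```

Note that `create_graph=True` is also needed on the second-derivative calls, because the loss must be differentiated once more with respect to the network weights.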

## 2. The Poisson Equation

### Problem Statement

The **2D Poisson equation** is an elliptic PDE that appears in many physics problems:

$$\nabla^2 u = f(x, y) \quad \text{on } \Omega = [0,1]^2$$

with boundary conditions:

$$u(x, y) = g(x, y) \quad \text{on } \partial\Omega$$

where:
- $\nabla^2 u = \frac{\partial^2 u}{\partial x^2} + \frac{\partial^2 u}{\partial y^2}$ is the Laplacian
- $f(x, y)$ is the source term
- $g(x, y)$ specifies boundary values

### Our Test Case

We'll use a **manufactured solution** approach:

**Analytical solution:** $u(x, y) = \sin(\pi x) \sin(\pi y)$

**Source term:** $f(x, y) = -2\pi^2 \sin(\pi x) \sin(\pi y)$

**Boundary conditions:** $u = 0$ on all boundaries (Dirichlet)

This allows us to:
1. Verify PDE is satisfied: $\nabla^2 u = f$
2. Compute exact error: $||u_{PINN} - u_{analytical}||$
3. Validate our implementation
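
For reference, the source term follows from differentiating the analytical solution twice in each direction:

$$\frac{\partial^2 u}{\partial x^2} = -\pi^2 \sin(\pi x) \sin(\pi y), \qquad \frac{\partial^2 u}{\partial y^2} = -\pi^2 \sin(\pi x) \sin(\pi y)$$

$$\nabla^2 u = \frac{\partial^2 u}{\partial x^2} + \frac{\partial^2 u}{\partial y^2} = -2\pi^2 \sin(\pi x) \sin(\pi y) = f(x, y)$$

The boundary condition holds because $\sin(0) = \sin(\pi) = 0$, so $u$ vanishes whenever $x$ or $y$ equals 0 or 1.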

## 3. Setup and Imports

First, let's import the necessary libraries and check our environment.

In [None]:
# Standard imports
import sys
import os

# Add project root to path (adjust if needed)
# If running from notebooks/, need to go up one level
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

# Scientific computing
import numpy as np
import matplotlib.pyplot as plt
import torch

# Our PINN modules
from src.models import MLP
from src.problems import PoissonProblem
from src.training import train_pinn, PINNTrainer
from src.interpretability import extract_activations_from_model

# Check PyTorch version and device
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

print("\n✓ All imports successful!")

## 4. Creating the PINN Model

### MLP Architecture

Our PINN is a standard **Multi-Layer Perceptron (MLP)** with:
- **Input**: $(x, y)$ coordinates (2D)
- **Hidden layers**: 4 layers × 64 neurons each
- **Activation**: $\tanh$ (smooth, bounded)
- **Output**: Predicted solution $u(x, y)$ (1D)

The architecture looks like:
```
Input (2) → Linear(64) → tanh → Linear(64) → tanh → Linear(64) → tanh → Linear(64) → tanh → Linear(1)
```

**Why tanh?**
- Smooth and differentiable (important for computing derivatives)
- Bounded output helps with training stability
- Commonly used in the PINN literature
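
The `MLP` class lives in this project's `src.models`, and its internals are not shown here. As a rough point of comparison, the same layer layout can be written with plain `torch.nn` building blocks. The sketch below is an assumed equivalent, not the project's implementation, and it also checks the parameter count by hand.

```python
import torch
from torch import nn

hidden_dims = [64, 64, 64, 64]
dims = [2] + hidden_dims  # input width followed by the hidden widths

layers = []
for d_in, d_out in zip(dims[:-1], dims[1:]):
    layers += [nn.Linear(d_in, d_out), nn.Tanh()]
layers.append(nn.Linear(dims[-1], 1))  # linear output layer, no activation
plain_mlp = nn.Sequential(*layers)

n_params = sum(p.numel() for p in plain_mlp.parameters())
# (2*64 + 64) + 3*(64*64 + 64) + (64*1 + 1) = 192 + 12480 + 65
print(n_params)  # 12737

out = plain_mlp(torch.rand(10, 2))
print(out.shape)  # torch.Size([10, 1])
```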

In [None]:
# Create the MLP model
model = MLP(
    input_dim=2,                    # (x, y) coordinates
    hidden_dims=[64, 64, 64, 64],   # 4 hidden layers, 64 neurons each
    output_dim=1,                   # Scalar solution u(x, y)
    activation="tanh"               # Smooth activation function
)

# Print model summary
print(model)
print(f"\nTotal parameters: {model.get_parameters_count():,}")

# Verify activation extraction works (for interpretability)
test_input = torch.randn(10, 2)  # Batch of 10 points
test_output = model(test_input)
activations = model.get_activations()

print(f"\nActivation extraction test:")
print(f"  Input shape: {test_input.shape}")
print(f"  Output shape: {test_output.shape}")
print(f"  Captured activations: {list(activations.keys())}")
print(f"  Layer 0 shape: {activations['layer_0'].shape}")

## 5. Defining the Problem

Now let's create the `PoissonProblem` instance, which provides:
- Analytical solution for validation
- Source term $f(x, y)$
- Boundary conditions
- Collocation point sampling
- PDE residual computation

In [None]:
# Create problem instance
problem = PoissonProblem(
    domain_bounds=[(0.0, 1.0), (0.0, 1.0)]  # [0,1] × [0,1] square domain
)

print(f"Problem: {problem}")
print(f"Domain: {problem.domain_bounds}")
print(f"Input dimension: {problem.input_dim}")
print(f"Output dimension: {problem.output_dim}")

### Visualize the Analytical Solution

Let's visualize the true solution $u(x,y) = \sin(\pi x)\sin(\pi y)$ that we're trying to learn.

In [None]:
# Create a grid for visualization
x = np.linspace(0, 1, 100)
y = np.linspace(0, 1, 100)
X, Y = np.meshgrid(x, y)

# Flatten grid to evaluate analytical solution
xy_grid = torch.tensor(np.stack([X.flatten(), Y.flatten()], axis=1), dtype=torch.float32)
u_analytical = problem.analytical_solution(xy_grid).detach().numpy().reshape(100, 100)

# Plot
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Heatmap
im = axes[0].pcolormesh(X, Y, u_analytical, cmap='viridis', shading='auto')
axes[0].set_xlabel('x')
axes[0].set_ylabel('y')
axes[0].set_title('Analytical Solution: u(x,y) = sin(πx)sin(πy)')
axes[0].set_aspect('equal')
plt.colorbar(im, ax=axes[0])

# Cross-section near y=0.5 (row 50 of the 100-point grid is y = 50/99 ≈ 0.505)
axes[1].plot(x, u_analytical[50, :], 'b-', linewidth=2, label='y=0.5')
axes[1].set_xlabel('x')
axes[1].set_ylabel('u(x, 0.5)')
axes[1].set_title('Cross-section at y=0.5')
axes[1].grid(True, alpha=0.3)
axes[1].legend()

plt.tight_layout()
plt.show()

print(f"Solution range: [{u_analytical.min():.4f}, {u_analytical.max():.4f}]")
print(f"Solution near center ({x[50]:.3f}, {y[50]:.3f}): {u_analytical[50, 50]:.4f}")

### Sample Collocation Points

Unlike traditional numerical methods that use a fixed grid, PINNs sample **collocation points** where we enforce the PDE.

We use two sampling strategies:
- **Interior points**: Latin Hypercube Sampling (more even coverage than independent uniform random sampling)
- **Boundary points**: Uniform sampling on domain edges
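
How `sample_interior_points` and `sample_boundary_points` are implemented is not shown in this tutorial. The NumPy sketch below illustrates the two ideas under simple assumptions: Latin Hypercube Sampling splits each axis into `n` equal strata and places exactly one point in each stratum per axis, and boundary sampling (interpreted here as uniform random) draws points along each of the four edges. The function names are hypothetical.

```python
import numpy as np

def latin_hypercube(n, dim, rng):
    """Return n points in [0, 1)^dim with one point per stratum along every axis."""
    # Each column is a random permutation of the strata 0..n-1 ...
    strata = np.stack([rng.permutation(n) for _ in range(dim)], axis=1)
    # ... plus a random offset inside the stratum, scaled back to [0, 1)
    return (strata + rng.random((n, dim))) / n

def unit_square_boundary(n_per_edge, rng):
    """Return 4 * n_per_edge points drawn uniformly on the edges of [0, 1]^2."""
    t = rng.random((4, n_per_edge))
    zeros, ones = np.zeros(n_per_edge), np.ones(n_per_edge)
    edges = [
        np.stack([t[0], zeros], axis=1),  # bottom edge: y = 0
        np.stack([t[1], ones], axis=1),   # top edge: y = 1
        np.stack([zeros, t[2]], axis=1),  # left edge: x = 0
        np.stack([ones, t[3]], axis=1),   # right edge: x = 1
    ]
    return np.concatenate(edges, axis=0)

rng = np.random.default_rng(42)
interior = latin_hypercube(1000, 2, rng)
boundary = unit_square_boundary(50, rng)
print(interior.shape, boundary.shape)  # (1000, 2) (200, 2)
```

With independent uniform sampling, some strata would be empty and others crowded. The permutation step is what spreads the points so that each of the 1000 intervals along $x$ and along $y$ receives exactly one sample.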

In [None]:
# Sample collocation points
n_interior = 1000
n_boundary = 50  # Per edge

interior_points = problem.sample_interior_points(n_interior, method='lhs')
boundary_points = problem.sample_boundary_points(n_boundary)

print(f"Sampled {interior_points.shape[0]} interior points")
print(f"Sampled {boundary_points.shape[0]} boundary points")

# Visualize sampling
plt.figure(figsize=(8, 8))
plt.scatter(interior_points[:, 0], interior_points[:, 1], 
            c='blue', s=1, alpha=0.5, label='Interior (LHS)')
plt.scatter(boundary_points[:, 0], boundary_points[:, 1], 
            c='red', s=10, alpha=0.8, label='Boundary')
plt.xlabel('x')
plt.ylabel('y')
plt.title('Collocation Points Sampling')
plt.legend()
plt.xlim(-0.05, 1.05)
plt.ylim(-0.05, 1.05)
plt.gca().set_aspect('equal')
plt.grid(True, alpha=0.3)
plt.show()

## 6. Training the PINN

### Training Configuration

The loss function combines:

$$L = w_{pde} \cdot L_{pde} + w_{bc} \cdot L_{bc}$$

where:
- $L_{pde} = \frac{1}{N_{int}} \sum_{i=1}^{N_{int}} |\nabla^2 u(x_i, y_i) - f(x_i, y_i)|^2$ (PDE residual)
- $L_{bc} = \frac{1}{N_{bc}} \sum_{i=1}^{N_{bc}} |u(x_i^{bc}, y_i^{bc}) - g(x_i^{bc}, y_i^{bc})|^2$ (boundary error)

**Training parameters:**
- **Optimizer**: Adam (lr=1e-3)
- **Epochs**: 5,000 (quick training for tutorial)
- **Collocation points**: Resampled every epoch (reduces overfitting to one fixed set of points)
- **Validation**: Compute relative L2 error every 100 epochs
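
The exact metric computed by the trainer is not shown here. A common definition of the relative L2 error, expressed in percent to match the thresholds used later in this tutorial, is $100 \cdot \|u_{pred} - u_{true}\|_2 / \|u_{true}\|_2$. The helper name below is hypothetical.

```python
import numpy as np

def relative_l2_error_percent(u_pred, u_true):
    """Relative L2 error in percent: 100 * ||u_pred - u_true|| / ||u_true||."""
    u_pred = np.asarray(u_pred, dtype=float).ravel()
    u_true = np.asarray(u_true, dtype=float).ravel()
    return 100.0 * np.linalg.norm(u_pred - u_true) / np.linalg.norm(u_true)

# Reference values: the manufactured solution on a 101 x 101 grid
x = np.linspace(0, 1, 101)
X, Y = np.meshgrid(x, x)
u_true = np.sin(np.pi * X) * np.sin(np.pi * Y)

# A prediction that is uniformly 1% too large has a relative error of about 1%
print(relative_l2_error_percent(1.01 * u_true, u_true))
```

Because the error is normalized by the size of the true solution, it can be compared across problems whose solutions have different magnitudes.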

In [None]:
# Training configuration
config = {
    # Optimizer settings
    "optimizer": "adam",
    "lr": 1e-3,
    
    # Training duration
    "n_epochs": 5000,
    
    # Collocation points
    "n_interior": 1000,
    "n_boundary": 50,
    
    # Loss weights
    "loss_weights": {
        "pde": 1.0,   # PDE residual weight
        "bc": 1.0,    # Boundary condition weight
        "ic": 0.0,    # Initial condition (not used for Poisson)
    },
    
    # Device
    "device": device,
    
    # Sampling and validation
    "resample_every": 1,      # Resample collocation points every epoch
    "validate_every": 100,    # Compute validation error every 100 epochs
    "print_every": 500,       # Print progress every 500 epochs
}

print("Training Configuration:")
print(f"  Optimizer: {config['optimizer']} (lr={config['lr']})")
print(f"  Epochs: {config['n_epochs']}")
print(f"  Interior points: {config['n_interior']}")
print(f"  Boundary points: {config['n_boundary'] * 4} ({config['n_boundary']} per edge)")
print(f"  Device: {config['device']}")
print(f"  Loss weights: PDE={config['loss_weights']['pde']}, BC={config['loss_weights']['bc']}")

### Run Training

Now let's train the PINN! This will take a few minutes (faster on GPU).

In [None]:
print("=" * 80)
print("Starting PINN Training")
print("=" * 80)
print()

# Train the model
trained_model, history = train_pinn(model, problem, config)

print()
print("=" * 80)
print("Training Complete!")
print("=" * 80)

### Training Summary

In [None]:
# Print summary statistics
print("\nTraining Summary:")
print("-" * 40)
print(f"Initial loss:        {history['loss_total'][0]:.6f}")
print(f"Final loss:          {history['loss_total'][-1]:.6f}")
print(f"Loss reduction:      {(1 - history['loss_total'][-1] / history['loss_total'][0]) * 100:.2f}%")
print(f"\nInitial L2 error:    {history['relative_l2_error'][0]:.4f}%")
print(f"Final L2 error:      {history['relative_l2_error'][-1]:.4f}%")
print(f"Error reduction:     {(1 - history['relative_l2_error'][-1] / history['relative_l2_error'][0]) * 100:.2f}%")

if history['relative_l2_error'][-1] < 1.0:
    print("\n✓ Excellent! Achieved <1% relative L2 error")
elif history['relative_l2_error'][-1] < 5.0:
    print("\n✓ Good! Error <5% (increase epochs for better accuracy)")
else:
    print("\n⚠ Error still high (try more epochs or more collocation points)")

## 7. Visualizing Results

Let's visualize the training history and compare the PINN solution to the analytical solution.

### Training History

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(18, 4))

# Total loss (log scale)
axes[0].semilogy(history['epoch'], history['loss_total'], 'b-', linewidth=2)
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Total Loss (log scale)')
axes[0].set_title('Training Loss Convergence')
axes[0].grid(True, alpha=0.3)

# Loss decomposition
axes[1].semilogy(history['epoch'], history['loss_pde'], 'r-', linewidth=2, label='PDE residual', alpha=0.7)
axes[1].semilogy(history['epoch'], history['loss_bc'], 'g-', linewidth=2, label='Boundary', alpha=0.7)
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Loss Components (log scale)')
axes[1].set_title('Loss Decomposition')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

# Relative L2 error
axes[2].plot(history['epoch'], history['relative_l2_error'], 'purple', linewidth=2)
axes[2].axhline(y=1.0, color='red', linestyle='--', linewidth=1, label='1% target')
axes[2].set_xlabel('Epoch')
axes[2].set_ylabel('Relative L2 Error (%)')
axes[2].set_title('Validation Error')
axes[2].legend()
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

### Solution Comparison

Compare PINN solution vs analytical solution on a dense grid.

In [None]:
# Evaluate PINN on dense grid
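trained_model.eval()
# Run on whichever device the model ended up on, then move the result back to the CPU
model_device = next(trained_model.parameters()).device
with torch.no_grad():
    u_pinn = trained_model(xy_grid.to(model_device)).cpu().numpy().reshape(100, 100)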

# Compute error
error = np.abs(u_pinn - u_analytical)

# Create comparison plot
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# PINN solution
im1 = axes[0].pcolormesh(X, Y, u_pinn, cmap='viridis', shading='auto')
axes[0].set_xlabel('x')
axes[0].set_ylabel('y')
axes[0].set_title('PINN Solution')
axes[0].set_aspect('equal')
plt.colorbar(im1, ax=axes[0])

# Analytical solution
im2 = axes[1].pcolormesh(X, Y, u_analytical, cmap='viridis', shading='auto')
axes[1].set_xlabel('x')
axes[1].set_ylabel('y')
axes[1].set_title('Analytical Solution')
axes[1].set_aspect('equal')
plt.colorbar(im2, ax=axes[1])

# Absolute error
im3 = axes[2].pcolormesh(X, Y, error, cmap='Reds', shading='auto')
axes[2].set_xlabel('x')
axes[2].set_ylabel('y')
axes[2].set_title('Absolute Error')
axes[2].set_aspect('equal')
plt.colorbar(im3, ax=axes[2])

plt.tight_layout()
plt.show()

print(f"\nError Statistics:")
print(f"  Mean absolute error: {error.mean():.6f}")
print(f"  Max absolute error:  {error.max():.6f}")
print(f"  Std of error:        {error.std():.6f}")

### Cross-Sections

Examine 1D slices to see how well PINN matches the analytical solution.

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(18, 4))

# Horizontal slice at y=0.5
axes[0].plot(x, u_analytical[50, :], 'b-', linewidth=2, label='Analytical', alpha=0.7)
axes[0].plot(x, u_pinn[50, :], 'r--', linewidth=2, label='PINN', alpha=0.7)
axes[0].set_xlabel('x')
axes[0].set_ylabel('u(x, 0.5)')
axes[0].set_title('Horizontal Slice (y=0.5)')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Vertical slice at x=0.5
axes[1].plot(y, u_analytical[:, 50], 'b-', linewidth=2, label='Analytical', alpha=0.7)
axes[1].plot(y, u_pinn[:, 50], 'r--', linewidth=2, label='PINN', alpha=0.7)
axes[1].set_xlabel('y')
axes[1].set_ylabel('u(0.5, y)')
axes[1].set_title('Vertical Slice (x=0.5)')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

# Diagonal slice
diag_analytical = np.diag(u_analytical)
diag_pinn = np.diag(u_pinn)
axes[2].plot(x, diag_analytical, 'b-', linewidth=2, label='Analytical', alpha=0.7)
axes[2].plot(x, diag_pinn, 'r--', linewidth=2, label='PINN', alpha=0.7)
axes[2].set_xlabel('x=y')
axes[2].set_ylabel('u(x, x)')
axes[2].set_title('Diagonal Slice (x=y)')
axes[2].legend()
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 8. Extracting Activations

Now let's extract the neural activations for interpretability analysis. This is a key step for mechanistic interpretability!

We'll:
1. Extract activations on a dense grid
2. Save to HDF5 file for efficient access
3. Visualize activation patterns of individual neurons
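
The internals of `extract_activations_from_model` are not shown in this tutorial. One common way to capture intermediate activations in PyTorch is a forward hook (`register_forward_hook`). The sketch below applies it to a stand-in `nn.Sequential` network and keys the results as `layer_0`, `layer_1`, ... to mirror the naming used in this tutorial. Treat it as an illustration of the idea, not the project's implementation; in particular, recording the post-`tanh` values is an assumption.

```python
import torch
from torch import nn

torch.manual_seed(0)
# Stand-in network (hypothetical); the trained model would be used in practice
net = nn.Sequential(
    nn.Linear(2, 64), nn.Tanh(),
    nn.Linear(64, 64), nn.Tanh(),
    nn.Linear(64, 1),
)

activations = {}

def make_hook(name):
    def hook(module, inputs, output):
        # Store a detached CPU copy so the hook does not keep the graph alive
        activations[name] = output.detach().cpu()
    return hook

# Hook every Tanh so we record the post-activation values of each hidden layer
tanh_layers = [m for m in net if isinstance(m, nn.Tanh)]
handles = [m.register_forward_hook(make_hook(f"layer_{i}"))
           for i, m in enumerate(tanh_layers)]

# Dense 100 x 100 grid over the unit square, flattened to shape (10000, 2)
axis = torch.linspace(0, 1, 100)
Y, X = torch.meshgrid(axis, axis, indexing="ij")
grid = torch.stack([X.flatten(), Y.flatten()], dim=1)

with torch.no_grad():
    net(grid)

for h in handles:
    h.remove()  # remove hooks once the activations are captured

print({k: tuple(v.shape) for k, v in activations.items()})
# {'layer_0': (10000, 64), 'layer_1': (10000, 64)}
```

The grid is flattened in the same row-major order as `np.meshgrid(x, y)` followed by `.flatten()`, so a column of activations can be reshaped to `(100, 100)` and plotted against `X` and `Y`.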

In [None]:
# Extract activations on 100x100 grid
print("Extracting activations on dense grid...")

activation_store = extract_activations_from_model(
    model=trained_model,
    domain_bounds=[(0, 1), (0, 1)],
    grid_resolution=100,
    save_path="../data/activations/tutorial_poisson.h5",
    batch_size=1000
)

print("\n✓ Activation extraction complete!")

# Display metadata
metadata = activation_store.get_metadata()
print("\nActivation Store Metadata:")
print(f"  Grid resolution: {metadata['grid_resolution']}")
print(f"  Total points: {metadata['n_points']}")
print(f"  Input dimension: {metadata['input_dim']}")
print(f"  Layers: {metadata['layer_names']}")

### Visualize Neuron Activations

Let's look at what individual neurons in the first layer have learned.

In [None]:
# Load layer 0 activations
layer_0_acts = activation_store.load_layer("layer_0")  # Shape: (10000, 64)
coords = activation_store.load_coordinates()           # Shape: (10000, 2)

print(f"Layer 0 activations shape: {layer_0_acts.shape}")
print(f"Coordinates shape: {coords.shape}")

# Visualize 4 neurons
fig, axes = plt.subplots(2, 2, figsize=(12, 12))
axes = axes.flatten()

neuron_indices = [0, 5, 15, 31]  # Pick some interesting neurons

for idx, neuron_idx in enumerate(neuron_indices):
    # Get activations for this neuron
    neuron_acts = layer_0_acts[:, neuron_idx].reshape(100, 100)
    
    # Plot
    im = axes[idx].pcolormesh(X, Y, neuron_acts, cmap='RdBu_r', shading='auto')
    axes[idx].set_xlabel('x')
    axes[idx].set_ylabel('y')
    axes[idx].set_title(f'Layer 0, Neuron {neuron_idx}\n(mean={neuron_acts.mean():.3f}, std={neuron_acts.std():.3f})')
    axes[idx].set_aspect('equal')
    plt.colorbar(im, ax=axes[idx])

plt.tight_layout()
plt.show()

print("\n💡 Observation: Different neurons learn different spatial features!")
print("   - Some respond to gradients (horizontal, vertical, diagonal)")
print("   - Some activate in specific regions (corners, edges, center)")
print("   - This is the foundation for mechanistic interpretability!")

### Layer Summary Visualization

Get an overview of what an entire layer has learned.

In [None]:
# Visualize 16 neurons from layer 0
fig, axes = plt.subplots(4, 4, figsize=(16, 16))

for i in range(16):
    row = i // 4
    col = i % 4
    
    neuron_acts = layer_0_acts[:, i].reshape(100, 100)
    
    im = axes[row, col].pcolormesh(X, Y, neuron_acts, cmap='RdBu_r', shading='auto')
    axes[row, col].set_title(f'Neuron {i}', fontsize=10)
    axes[row, col].set_aspect('equal')
    axes[row, col].set_xticks([])
    axes[row, col].set_yticks([])

plt.suptitle('Layer 0: First 16 Neurons', fontsize=16, y=0.995)
plt.tight_layout()
plt.show()

print("\n💡 Each neuron specializes in detecting different spatial patterns!")
print("   This diversity is crucial for the network to represent complex solutions.")

## 9. Summary and Next Steps

### What We've Learned

In this tutorial, we:

✅ **Understood PINNs**: How neural networks solve PDEs using physics-informed loss functions

✅ **Implemented the Poisson equation**: Defined a manufactured solution test case

✅ **Trained a PINN**: Used automatic differentiation to compute PDE residuals

✅ **Validated results**: Compared PINN solution vs analytical solution

✅ **Extracted activations**: Stored neural activations for interpretability analysis

✅ **Visualized neuron patterns**: Saw how different neurons learn different spatial features

### Key Insights

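1. **Automatic differentiation is key**: PyTorch computes derivatives of the network that are exact up to floating-point precision, with no finite-difference stencil
2. **Collocation point sampling matters**: Latin Hypercube Sampling covers the domain more evenly than independent uniform random sampling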
3. **Loss decomposition helps**: Separating PDE and BC losses allows fine-tuning
4. **Neurons specialize**: Different neurons learn different spatial patterns

### Next Steps

Now that you understand the basics, you can:

1. **Experiment with hyperparameters**:
   - Try different network architectures (more layers, more neurons)
   - Test different activation functions (relu, gelu, sin)
   - Adjust loss weights

2. **Solve different PDEs**:
   - Time-dependent: Heat equation
   - Nonlinear: Burgers equation
   - Wave propagation: Helmholtz equation

3. **Deep dive into interpretability**:
   - Train probing classifiers to detect derivatives
   - Perform activation patching experiments
   - Identify computational circuits

4. **Explore advanced architectures**:
   - Modified Fourier Networks (better for high-frequency solutions)
   - Attention-Enhanced PINNs (for multi-scale problems)

### Resources

- **Documentation**: See `README.md` and `CLAUDE.md`
- **Demo scripts**: Check `demos/demo_*.py` files for more examples
- **Tests**: Inspect `tests/` for usage patterns
- **Next tutorial**: `02_heat_equation.ipynb` (coming soon!)

---

**Congratulations!** You've successfully trained your first PINN! 🎉

## Exercises (Optional)

Try these challenges to deepen your understanding:

### Exercise 1: Architecture Exploration
Train PINNs with different architectures and compare:
- Shallow (2 layers × 128 neurons)
- Deep (8 layers × 32 neurons)
- Wide (4 layers × 128 neurons)

**Question**: Which architecture achieves the best error? Why?

### Exercise 2: Activation Function Comparison
Train PINNs with different activation functions:
- `tanh` (smooth, bounded)
- `relu` (piecewise linear)
- `sin` (periodic)

**Question**: How does activation choice affect convergence speed and final error?
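
One thing worth checking before running this comparison: the Poisson residual needs second derivatives of the network output, and a ReLU network is piecewise linear in its inputs, so those second derivatives are zero wherever they are defined. The sketch below (plain PyTorch, with a hypothetical helper name and a small random network) compares $\partial^2 u / \partial x^2$ for `tanh` and `relu`.

```python
import torch
from torch import nn

def u_xx(activation, seed=0):
    """Second x-derivative of a small random MLP at 100 random points."""
    torch.manual_seed(seed)
    net = nn.Sequential(
        nn.Linear(2, 32), activation,
        nn.Linear(32, 32), activation,
        nn.Linear(32, 1),
    )
    xy = torch.rand(100, 2, requires_grad=True)
    u = net(xy)
    grad_u = torch.autograd.grad(u.sum(), xy, create_graph=True)[0]
    # allow_unused guards against autograd reporting "no dependence" as None
    second = torch.autograd.grad(grad_u[:, 0].sum(), xy, allow_unused=True)[0]
    if second is None:
        return torch.zeros(100)
    return second[:, 0]

print("tanh max |u_xx|:", u_xx(nn.Tanh()).abs().max().item())
print("relu max |u_xx|:", u_xx(nn.ReLU()).abs().max().item())
```

With `relu`, the Laplacian term in the PDE loss carries no information about the network weights, which is one reason smooth activations such as `tanh` or `sin` are the usual choice for second-order PDEs.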

### Exercise 3: Sampling Strategy
Compare different sampling methods:
- Latin Hypercube Sampling (LHS)
- Uniform Random Sampling
- Grid Sampling

**Question**: Which provides fastest convergence? Why?

### Exercise 4: Custom PDE
Define and solve your own Poisson equation with a different manufactured solution:
- Try: $u(x,y) = x^2 + y^2$
- Derive the corresponding source term $f(x,y)$
- Train a PINN and validate

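**Hint**: For $u = x^2 + y^2$, compute $\nabla^2 u$ analytically to get $f$. Note that this $u$ is no longer zero on the boundary, so the Dirichlet data must be set to $g(x, y) = x^2 + y^2$ on $\partial\Omega$.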

### Exercise 5: Neuron Analysis
Analyze what specific neurons have learned:
- Identify neurons that respond to horizontal gradients
- Find neurons that activate in specific regions
- Correlate neuron patterns with solution features

**Question**: Can you find a neuron that computes $\partial u/\partial x$?
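
A single neuron rarely encodes a derivative exactly, so a gentler starting point is a linear probe: fit a least-squares map from a layer's activations to the analytical $\partial u/\partial x = \pi \cos(\pi x) \sin(\pi y)$ and check how much variance it explains. The sketch below uses random `tanh` features as a stand-in for real activations (an assumption for illustration only, not data from a trained PINN). With the tutorial's data you would pass the array returned by `activation_store.load_layer(...)` instead. The helper name is hypothetical.

```python
import numpy as np

def linear_probe_r2(acts, target):
    """Fit target ~ acts @ w + b by least squares and return the in-sample R^2."""
    design = np.hstack([acts, np.ones((acts.shape[0], 1))])  # append bias column
    coef, *_ = np.linalg.lstsq(design, target, rcond=None)
    residual = target - design @ coef
    return 1.0 - residual.var() / target.var()

rng = np.random.default_rng(0)

# Grid coordinates, shape (2500, 2)
axis = np.linspace(0, 1, 50)
X, Y = np.meshgrid(axis, axis)
coords = np.stack([X.ravel(), Y.ravel()], axis=1)

# Stand-in "activations": 64 random tanh features of (x, y), NOT a trained PINN
W = rng.normal(size=(2, 64))
b = rng.normal(size=64)
acts = np.tanh(coords @ W + b)

# Probe target: analytical du/dx of u = sin(pi x) sin(pi y)
du_dx = np.pi * np.cos(np.pi * coords[:, 0]) * np.sin(np.pi * coords[:, 1])

r2 = linear_probe_r2(acts, du_dx)
print(f"probe R^2: {r2:.4f}")
```

An $R^2$ close to 1 means the derivative is linearly decodable from the layer. To look for a single neuron instead, compute the correlation of each activation column with `du_dx` and inspect the largest values.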

---

Good luck! 🚀