# 2.2: 2D Poisson Equation - PyTorch Implementation

This notebook solves the 2D Poisson equation using pure PyTorch instead of DeepXDE.
The Poisson equation is a canonical example of an elliptic PDE.

### Problem Definition:
- **PDE**: $\frac{\partial^2 u}{\partial x^2} + \frac{\partial^2 u}{\partial y^2} = f(x, y)$
- **Source function**: $f(x, y) = -2\pi^2 \sin(\pi x) \sin(\pi y)$
- **Domain**: $x \in [-1, 1]$, $y \in [-1, 1]$
- **Boundary Conditions**: $u(x, y) = 0$ on the entire boundary
- **Analytical Solution**: $u(x, y) = \sin(\pi x) \sin(\pi y)$

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from mpl_toolkits.mplot3d import Axes3D
import time

# Select the compute device: the GPU when CUDA is available, otherwise the CPU.
# `device` is a notebook-wide global; the sampling helpers below use it as their
# default `device` argument and the model is moved onto it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
# Apply seaborn's "whitegrid" theme to the matplotlib figures drawn later.
sns.set_style("whitegrid")

## 1. Neural Network Architecture

In [None]:
class PoissonPINN(nn.Module):
    """Fully connected tanh network approximating u(x, y) for the 2D Poisson problem.

    The network is a stack of linear layers with a Tanh between every pair of
    consecutive linear layers:

        Linear(2, hidden_dim) -> [Tanh -> Linear(hidden_dim, hidden_dim)] * num_layers
        -> Tanh -> Linear(hidden_dim, 1)

    Args:
        hidden_dim: Width of every hidden linear layer.
        num_layers: Number of hidden_dim -> hidden_dim layers that follow the
            input layer.
    """

    def __init__(self, hidden_dim=50, num_layers=3):
        super().__init__()

        # Layer widths from input (x, y) to the scalar output u(x, y).
        # max(..., 0) keeps the input and output layers even when num_layers <= 0.
        widths = [2] + [hidden_dim] * (max(num_layers, 0) + 1) + [1]

        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            if modules:
                # Activation sits between linear layers, never after the last one.
                modules.append(nn.Tanh())
            modules.append(nn.Linear(fan_in, fan_out))

        self.network = nn.Sequential(*modules)

        # Re-initialise every linear layer (Glorot/Xavier normal weights, zero biases).
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise ``module`` if it is a linear layer; leave other modules untouched."""
        if not isinstance(module, nn.Linear):
            return
        nn.init.xavier_normal_(module.weight)
        nn.init.zeros_(module.bias)

    def forward(self, x):
        """Evaluate the network on ``x`` and return its output."""
        return self.network(x)

# Create model
# Instantiate the network (width 50, three hidden-to-hidden layers) and move it to
# the selected device. `model` is the notebook-wide instance used by every later cell.
model = PoissonPINN(hidden_dim=50, num_layers=3).to(device)
# Report the total number of parameter elements, then the layer structure.
print(f"Model parameters: {sum(p.numel() for p in model.parameters())}")
print(model)

## 2. Domain Sampling Functions

In [None]:
def sample_domain_points(num_points, device=device):
    """Draw uniformly random collocation points in the square [-1, 1] x [-1, 1].

    Args:
        num_points: Number of points to draw.
        device: Device on which the tensor is created.

    Returns:
        Tensor of shape (num_points, 2); column 0 holds x, column 1 holds y.
    """
    def _uniform_column():
        # torch.rand samples [0, 1); scale and shift it onto [-1, 1).
        return torch.rand(num_points, 1, device=device) * 2 - 1

    # x is drawn before y so the random stream is consumed in a fixed order.
    xs = _uniform_column()
    ys = _uniform_column()
    return torch.cat((xs, ys), dim=1)

def sample_boundary_points(num_points, device=device):
    """Sample random points on the four edges of the square [-1, 1] x [-1, 1].

    Points are split as evenly as possible between the edges, in the order
    bottom (y = -1), top (y = 1), left (x = -1), right (x = 1). When
    ``num_points`` is not a multiple of 4 the leftover points go to the first
    edges in that order, so exactly ``num_points`` points are returned rather
    than silently dropping the remainder.

    Args:
        num_points: Total number of boundary points to draw (non-negative).
        device: Device on which the tensors are created.

    Returns:
        Tensor of shape (num_points, 2), edges concatenated in the order above;
        column 0 holds x, column 1 holds y.

    Raises:
        ValueError: If ``num_points`` is negative.
    """
    if num_points < 0:
        raise ValueError(f"num_points must be non-negative, got {num_points}")

    per_side, leftover = divmod(num_points, 4)

    # (index of the coordinate held fixed, value it is fixed at) for each edge.
    edges = (
        (1, -1.0),  # Bottom boundary: y = -1
        (1, 1.0),   # Top boundary: y = 1
        (0, -1.0),  # Left boundary: x = -1
        (0, 1.0),   # Right boundary: x = 1
    )

    segments = []
    for edge_index, (fixed_axis, fixed_value) in enumerate(edges):
        count = per_side + (1 if edge_index < leftover else 0)
        # One rand call per edge, in edge order, keeps the random stream identical
        # to the even-split case when num_points is a multiple of 4.
        free = torch.rand(count, 1, device=device) * 2 - 1
        fixed = torch.full((count, 1), fixed_value, device=device)
        columns = [fixed, free] if fixed_axis == 0 else [free, fixed]
        segments.append(torch.cat(columns, dim=1))

    return torch.cat(segments, dim=0)

# Sample points
# Fixed training set: the same interior and boundary points are reused for every
# training step later in the notebook.
num_domain = 2500
num_boundary = 100

domain_points = sample_domain_points(num_domain, device)
boundary_points = sample_boundary_points(num_boundary, device)

print(f"Domain points: {domain_points.shape}")
print(f"Boundary points: {boundary_points.shape}")

# Visualize sampling
plt.figure(figsize=(8, 8))
# Copy to host memory as NumPy arrays for matplotlib.
domain_np = domain_points.cpu().numpy()
boundary_np = boundary_points.cpu().numpy()
# Only the first 500 interior points are drawn; all boundary points are drawn.
plt.scatter(domain_np[:500, 0], domain_np[:500, 1], alpha=0.3, s=1, label='Domain', color='blue')
plt.scatter(boundary_np[:, 0], boundary_np[:, 1], alpha=0.7, s=10, label='Boundary', color='red')
# Axis limits slightly wider than the domain so the boundary points are not clipped.
plt.xlim([-1.1, 1.1])
plt.ylim([-1.1, 1.1])
plt.xlabel('x')
plt.ylabel('y')
plt.title('Sampling Points Distribution')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

## 3. Automatic Differentiation Utilities

In [None]:
def compute_laplacian(u, points):
    """Return the Laplacian ∇²u = ∂²u/∂x² + ∂²u/∂y² via automatic differentiation.

    Args:
        u: Tensor computed from ``points`` (e.g. the network output u(x, y));
            it must be attached to the autograd graph of ``points``.
        points: Input tensor whose columns 0 and 1 are x and y; must require grad.

    Returns:
        laplacian: ∇²u, one value per row (sliced to a single column).
    """
    def _grad_wrt_points(output):
        # create_graph=True keeps the result differentiable, which is needed both
        # for the second derivatives and for back-propagating a loss built on them.
        return torch.autograd.grad(
            outputs=output,
            inputs=points,
            grad_outputs=torch.ones_like(output),
            create_graph=True,
            retain_graph=True
        )[0]

    # First derivatives: column 0 is ∂u/∂x, column 1 is ∂u/∂y.
    first = _grad_wrt_points(u)

    # Unmixed second derivatives: differentiate each first derivative again and
    # keep only the column of the same coordinate (∂²u/∂x², then ∂²u/∂y²).
    second = [
        _grad_wrt_points(first[:, axis:axis + 1])[:, axis:axis + 1]
        for axis in range(2)
    ]

    return second[0] + second[1]

## 4. Problem Definition Functions

In [None]:
def source_function(points):
    """Evaluate the source term f(x,y) = -2π²sin(πx)sin(πy).

    Args:
        points: Tensor whose columns 0 and 1 hold x and y.

    Returns:
        Single-column tensor with f evaluated at every row of ``points``.
    """
    amplitude = -2 * np.pi**2
    sin_x = torch.sin(np.pi * points[:, :1])
    sin_y = torch.sin(np.pi * points[:, 1:2])
    return amplitude * sin_x * sin_y

def analytical_solution(points):
    """Evaluate the closed-form solution u(x,y) = sin(πx)sin(πy).

    Args:
        points: Tensor whose columns 0 and 1 hold x and y.

    Returns:
        Single-column tensor with u evaluated at every row of ``points``.
    """
    sin_x = torch.sin(np.pi * points[:, :1])
    sin_y = torch.sin(np.pi * points[:, 1:2])
    return sin_x * sin_y

def boundary_condition(points):
    """Return the Dirichlet target u = 0 for every boundary point.

    Args:
        points: Tensor with one boundary point per row.

    Returns:
        Zero tensor of shape (number of rows, 1) on the same device as ``points``.
    """
    row_count = points.shape[0]
    return torch.zeros((row_count, 1), device=points.device)

# Test functions
# Sanity check at two hand-verifiable points: the source term vanishes at the
# origin (sin(0) = 0) and the analytical solution equals 1 at (0.5, 0.5)
# (sin(π/2)·sin(π/2) = 1).
test_points = torch.tensor([[0.0, 0.0], [0.5, 0.5]], device=device)
print(f"Source at (0,0): {source_function(test_points)[0].item():.6f}")
print(f"Analytical at (0.5,0.5): {analytical_solution(test_points)[1].item():.6f}")

## 5. Loss Functions

In [None]:
def compute_losses(model, domain_points, boundary_points):
    """Compute all loss components for the Poisson equation.

    Args:
        model: Network mapping a batch of (x, y) points to u(x, y).
        domain_points: Interior collocation points where the PDE residual
            ∇²u - f(x,y) is penalised. The tensor itself is left untouched.
        boundary_points: Boundary points where u = 0 is enforced.

    Returns:
        Dict of scalar tensors:
            'pde'      - mean squared PDE residual over ``domain_points``
            'boundary' - mean squared boundary mismatch over ``boundary_points``
            'total'    - sum of the two (unweighted)
    """
    # Differentiate with respect to a detached local leaf rather than flipping
    # requires_grad on the caller's tensor: the caller's data is not mutated, and
    # backward() does not accumulate an unused .grad on it at every step.
    collocation = domain_points.detach().requires_grad_(True)

    # 1. PDE Loss: ∇²u - f(x,y) = 0
    # The Laplacian is obtained through autograd, which needs a recorded graph
    # from `collocation` to the output. enable_grad() guarantees that even when
    # the caller evaluates the losses inside torch.no_grad().
    with torch.enable_grad():
        u_domain = model(collocation)
        laplacian = compute_laplacian(u_domain, collocation)

    # The source term is a fixed function of the coordinates, so it is evaluated
    # on detached inputs and contributes no autograd graph of its own.
    source = source_function(domain_points.detach())

    pde_residual = laplacian - source
    loss_pde = torch.mean(pde_residual**2)

    # 2. Boundary Loss: u = 0 on boundary
    u_boundary = model(boundary_points)
    bc_target = boundary_condition(boundary_points)
    loss_boundary = torch.mean((u_boundary - bc_target)**2)

    # Total loss
    total_loss = loss_pde + loss_boundary

    return {
        'total': total_loss,
        'pde': loss_pde,
        'boundary': loss_boundary
    }

# Test loss computation
# Do not wrap this call in torch.no_grad(): the PDE loss differentiates the network
# output with torch.autograd.grad, which raises a RuntimeError when the forward pass
# was recorded without a graph. Only .item() values are printed, so no gradient
# step is taken here.
losses = compute_losses(model, domain_points, boundary_points)
print("Initial losses:")
for key, value in losses.items():
    print(f"  {key}: {value.item():.6f}")

## 6. Training

In [None]:
def train_model(model, domain_points, boundary_points,
                epochs=20000, lr=1e-3, print_every=1000):
    """Train the PINN with Adam on a fixed set of collocation points.

    Args:
        model: Network to optimise (updated in place).
        domain_points: Interior points passed to ``compute_losses``.
        boundary_points: Boundary points passed to ``compute_losses``.
        epochs: Number of optimisation steps.
        lr: Adam learning rate.
        print_every: A progress line is printed every ``print_every`` epochs.

    Returns:
        List with one dict per epoch holding 'epoch', 'total', 'pde' and
        'boundary' (the loss values as Python floats).
    """
    optimizer = optim.Adam(model.parameters(), lr=lr)
    loss_history = []

    print(f"Starting training for {epochs} epochs...")
    start_time = time.time()

    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()

        # Forward pass, back-propagation of the combined loss, parameter update.
        losses = compute_losses(model, domain_points, boundary_points)
        losses['total'].backward()
        optimizer.step()

        # Store plain floats so the history holds no reference to the graph.
        record = {name: losses[name].item() for name in ('total', 'pde', 'boundary')}
        loss_history.append({'epoch': epoch, **record})

        if (epoch + 1) % print_every != 0:
            continue

        # The timer restarts after each report, so "Time" covers only the epochs
        # since the previous progress line.
        elapsed_time = time.time() - start_time
        print(f"Epoch {epoch+1:5d}/{epochs} | "
              f"Total: {record['total']:.2e} | "
              f"PDE: {record['pde']:.2e} | "
              f"BC: {record['boundary']:.2e} | "
              f"Time: {elapsed_time:.1f}s")
        start_time = time.time()

    return loss_history

# Train the model
# Runs 20000 Adam steps on the fixed point sets sampled above; `loss_history`
# (one record per epoch) is kept as a notebook global for the plots and the
# statistics printed later.
loss_history = train_model(model, domain_points, boundary_points, 
                          epochs=20000, lr=1e-3, print_every=1000)

## 7. Visualization and Validation

In [None]:
def plot_loss_history(loss_history):
    """Plot the total, PDE and boundary loss curves side by side on log-scaled axes.

    Args:
        loss_history: List of per-epoch dicts with the keys 'epoch', 'total',
            'pde' and 'boundary', as returned by ``train_model``.
    """
    # (history key, line format, panel title) for each of the three panels.
    panels = (
        ('total', 'b-', 'Total Loss'),
        ('pde', 'r-', 'PDE Loss'),
        ('boundary', 'g-', 'Boundary Loss'),
    )

    # Extract every series before any figure is created.
    epochs = [record['epoch'] for record in loss_history]
    curves = {
        key: [record[key] for record in loss_history]
        for key, _, _ in panels
    }

    plt.figure(figsize=(15, 5))

    for position, (key, line_format, title) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.plot(epochs, curves[key], line_format, linewidth=2)
        plt.yscale('log')
        plt.title(title)
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

# Plot loss history
# Draws the three loss curves recorded by train_model (total, PDE, boundary).
plot_loss_history(loss_history)

In [None]:
def visualize_solution(model, resolution=100):
    """Compare the PINN prediction with the analytical solution on a regular grid.

    Evaluates ``model`` on a ``resolution`` x ``resolution`` grid covering
    [-1, 1] x [-1, 1], draws 3D surface plots and filled contour plots of the
    prediction, the analytical solution and their absolute difference, and
    prints the error metrics. The model is switched to eval mode.

    Args:
        model: Trained network mapping (x, y) points to u(x, y).
        resolution: Number of grid points along each axis.

    Returns:
        Tuple of floats (l2_error, l2_relative, max_error): root-mean-square
        error, that error divided by the root-mean-square of the analytical
        solution, and the largest absolute error on the grid.
    """
    model.eval()

    # Regular evaluation grid; 'ij' indexing keeps axis 0 as x and axis 1 as y.
    axis = torch.linspace(-1, 1, resolution)
    X, Y = torch.meshgrid(axis, axis, indexing='ij')
    grid_points = torch.stack([X.flatten(), Y.flatten()], dim=1).to(device)

    # Network prediction (no graph needed) and reference solution, both on the CPU
    # and reshaped back onto the grid.
    with torch.no_grad():
        predicted = model(grid_points).cpu().reshape(resolution, resolution)
    exact = analytical_solution(grid_points).cpu().reshape(resolution, resolution)

    # Error metrics
    difference = predicted - exact
    error = torch.abs(difference)
    l2_error = torch.sqrt(torch.mean(difference**2))
    l2_relative = l2_error / torch.sqrt(torch.mean(exact**2))
    max_error = torch.max(error)

    # NumPy views for matplotlib; the three fields share colormaps across both figures.
    X_np, Y_np = X.numpy(), Y.numpy()
    fields = (predicted.numpy(), exact.numpy(), error.numpy())
    colormaps = ('hot', 'hot', 'viridis')

    # Figure 1: 3D surfaces
    surface_titles = (
        'PINN Solution',
        'Analytical Solution',
        f'Absolute Error\nL2: {l2_error:.2e}',
    )
    z_labels = ('u(x,y)', 'u(x,y)', '|Error|')

    fig = plt.figure(figsize=(18, 5))
    panels = zip(fields, colormaps, surface_titles, z_labels)
    for column, (values, cmap, title, z_label) in enumerate(panels, start=1):
        ax = fig.add_subplot(1, 3, column, projection='3d')
        ax.plot_surface(X_np, Y_np, values, cmap=cmap, alpha=0.8)
        ax.set_title(title)
        ax.set_xlabel('x')
        ax.set_ylabel('y')
        ax.set_zlabel(z_label)

    plt.suptitle("2D Poisson Equation - PyTorch PINN Results", fontsize=16)
    # Leave room at the top of the figure for the suptitle.
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()

    # Figure 2: filled contours
    contour_titles = (
        'PINN Solution (Contour)',
        'Analytical Solution (Contour)',
        'Absolute Error (Contour)',
    )

    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    for ax, values, cmap, title in zip(axes, fields, colormaps, contour_titles):
        filled = ax.contourf(X_np, Y_np, values, levels=20, cmap=cmap)
        ax.set_title(title)
        ax.set_xlabel('x')
        ax.set_ylabel('y')
        plt.colorbar(filled, ax=ax)

    plt.tight_layout()
    plt.show()

    print(f"L2 Absolute Error: {l2_error:.6e}")
    print(f"L2 Relative Error: {l2_relative:.6e}")
    print(f"Max Absolute Error: {max_error:.6e}")

    return l2_error.item(), l2_relative.item(), max_error.item()

# Visualize results
# Keep the three returned error metrics (floats) as notebook globals; `l2_rel` is
# printed again in the final statistics.
l2_abs, l2_rel, max_abs = visualize_solution(model)

## 8. Detailed Analysis and Validation

In [None]:
# Test model at specific points
model.eval()
# Hand-picked probes: the centre, one point per quadrant, and two points on the
# boundary, where the analytical solution is zero (up to floating-point rounding).
test_points = torch.tensor([
    [0.0, 0.0],   # Center
    [0.5, 0.5],   # First quadrant
    [-0.5, 0.5],  # Second quadrant
    [-0.5, -0.5], # Third quadrant
    [0.5, -0.5],  # Fourth quadrant
    [1.0, 0.0],   # Boundary point
    [0.0, 1.0],   # Boundary point
], device=device)

# Plain evaluation only, so no autograd graph is recorded.
with torch.no_grad():
    u_pred_test = model(test_points)
    u_analytical_test = analytical_solution(test_points)

print("\n=== Point-wise Validation ===")
print("Point\t\tPINN\t\tAnalytical\tError")
print("-" * 60)
# One table row per probe: coordinates, prediction, reference value, absolute error.
for i, point in enumerate(test_points):
    x, y = point[0].item(), point[1].item()
    pred = u_pred_test[i].item()
    true = u_analytical_test[i].item()
    error = abs(pred - true)
    print(f"({x:5.1f}, {y:5.1f})\t{pred:10.6f}\t{true:10.6f}\t{error:.2e}")

# Summary of the run: model size, number of recorded epochs, the losses of the
# last recorded epoch, and the grid-based relative L2 error computed earlier.
print(f"\n=== Final Model Statistics ===")
print(f"Model parameters: {sum(p.numel() for p in model.parameters())}")
print(f"Training epochs: {len(loss_history)}")
print(f"Final total loss: {loss_history[-1]['total']:.2e}")
print(f"Final PDE loss: {loss_history[-1]['pde']:.2e}")
print(f"Final boundary loss: {loss_history[-1]['boundary']:.2e}")
print(f"L2 relative error: {l2_rel:.6e}")

## 9. Comparison with DeepXDE

This PyTorch implementation demonstrates:

### **Advantages of Pure PyTorch:**
1. **Full Control**: Direct access to all components of the training process
2. **Transparency**: Clear understanding of how gradients are computed
3. **Flexibility**: Easy to modify loss functions, architectures, or training strategies
4. **Debugging**: Ability to inspect intermediate computations
5. **Custom Features**: Simple to add problem-specific enhancements

### **Key Implementation Details:**
- **Laplacian Computation**: Manual implementation using `torch.autograd.grad`
- **Domain Sampling**: Custom functions for interior and boundary points
- **Loss Construction**: Explicit combination of PDE and boundary losses
- **Training Loop**: Standard PyTorch optimization with detailed monitoring

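The results should closely match those of the DeepXDE implementation (small differences are expected because network initialization and point sampling are random), while providing greater insight into the underlying mathematics and computation.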