# PyTorch PINN Template

This notebook provides a comprehensive template for implementing Physics-Informed Neural Networks (PINNs) using pure PyTorch, replacing the DeepXDE-based implementations in the course.

## Key Differences from DeepXDE:
- PDE residuals are assembled explicitly with PyTorch's automatic differentiation (`torch.autograd.grad`)
- Custom sampling of collocation points and boundary points
- Explicit loss function construction and weighting
- More control over training process and optimization strategies

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.optimize import minimize
import time

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Notebook-wide plotting defaults.
sns.set_style("whitegrid")
plt.rcParams["figure.figsize"] = (10, 6)

## 1. Neural Network Architecture

Define the neural network that will approximate the solution.

In [None]:
class PINN(nn.Module):
    """Fully connected network used as the PINN solution ansatz u_theta(x).

    Layout: ``Linear(input_dim, hidden_dim)``, then ``num_hidden_layers - 1``
    repetitions of ``[activation, Linear(hidden_dim, hidden_dim)]``, then
    ``[activation, Linear(hidden_dim, output_dim)]``. That gives
    ``num_hidden_layers`` hidden layers of width ``hidden_dim``.

    Args:
        input_dim: Number of input coordinates (e.g. 2 for (x, y) or (x, t)).
        hidden_dim: Width of every hidden layer.
        output_dim: Number of predicted output fields.
        num_hidden_layers: Number of hidden layers; must be >= 1.
        activation: Activation module inserted after every hidden layer.
            Defaults to a fresh ``nn.Tanh()`` per model. The same instance is
            reused at every position in this model's network.

    Raises:
        ValueError: If ``num_hidden_layers`` is smaller than 1.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_hidden_layers=3, activation=None):
        super().__init__()

        if num_hidden_layers < 1:
            # Previously values < 1 silently produced a single hidden layer.
            raise ValueError(f"num_hidden_layers must be >= 1, got {num_hidden_layers}")
        if activation is None:
            # Create the default per instance instead of sharing one module
            # object (a mutable default argument) across every PINN.
            activation = nn.Tanh()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.num_hidden_layers = num_hidden_layers

        layers = [nn.Linear(input_dim, hidden_dim)]
        for _ in range(num_hidden_layers - 1):
            layers += [activation, nn.Linear(hidden_dim, hidden_dim)]
        layers += [activation, nn.Linear(hidden_dim, output_dim)]

        self.network = nn.Sequential(*layers)

        # Xavier/Glorot-normal weights and zero biases for every Linear layer.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise a single ``nn.Linear`` module in place; other modules are ignored."""
        if isinstance(module, nn.Linear):
            nn.init.xavier_normal_(module.weight)
            nn.init.zeros_(module.bias)

    def forward(self, x):
        """Evaluate the network on a batch of input points ``x``."""
        return self.network(x)

# Example instantiation: (x, y) -> u with three hidden layers of width 50.
model = PINN(
    input_dim=2,
    hidden_dim=50,
    output_dim=1,
    num_hidden_layers=3,
).to(device)
print(f"Model parameters: {sum(param.numel() for param in model.parameters() if param.requires_grad)}")
print(model)

## 2. Domain Sampling

Functions to sample points from different parts of the domain.

In [None]:
def sample_domain_2d(bounds, num_points, device=device):
    """Draw ``num_points`` points uniformly at random inside a 2D rectangle.

    Args:
        bounds: ``[[x_min, x_max], [y_min, y_max]]``.
        num_points: How many points to draw.
        device: Device on which the samples are created.

    Returns:
        torch.Tensor of shape ``(num_points, 2)`` whose columns are ``(x, y)``.
    """
    # One uniform column per axis, drawn x first and then y.
    columns = [
        torch.rand(num_points, 1, device=device) * (axis[1] - axis[0]) + axis[0]
        for axis in (bounds[0], bounds[1])
    ]
    return torch.cat(columns, dim=1)

def sample_boundary_2d(bounds, num_points_per_side, device=device):
    """Draw random points on the four edges of a 2D rectangle.

    On each edge one coordinate is fixed to the edge's value and the other is
    uniform over the edge's extent.

    Args:
        bounds: ``[[x_min, x_max], [y_min, y_max]]``.
        num_points_per_side: Number of points drawn on each of the four edges.
        device: Device on which the samples are created.

    Returns:
        torch.Tensor of shape ``(4 * num_points_per_side, 2)``. Rows are
        ordered left edge, right edge, bottom edge, top edge.
    """
    x_min, x_max = bounds[0]
    y_min, y_max = bounds[1]
    n = num_points_per_side

    def uniform(lo, hi):
        return torch.rand(n, 1, device=device) * (hi - lo) + lo

    # Random draws happen in the same order the edges are concatenated.
    y_left = uniform(y_min, y_max)
    y_right = uniform(y_min, y_max)
    x_bottom = uniform(x_min, x_max)
    x_top = uniform(x_min, x_max)

    edges = [
        torch.cat([torch.full_like(y_left, x_min), y_left], dim=1),
        torch.cat([torch.full_like(y_right, x_max), y_right], dim=1),
        torch.cat([x_bottom, torch.full_like(x_bottom, y_min)], dim=1),
        torch.cat([x_top, torch.full_like(x_top, y_max)], dim=1),
    ]
    return torch.cat(edges, dim=0)

def sample_time_domain(x_bounds, t_bounds, num_domain, num_boundary, num_initial, device=device):
    """Sample collocation, boundary and initial points for a 1D-in-space time-dependent problem.

    All returned tensors have two columns ordered ``(x, t)``.

    Args:
        x_bounds: ``[x_min, x_max]``.
        t_bounds: ``[t_min, t_max]``.
        num_domain: Number of interior collocation points.
        num_boundary: Total number of spatial-boundary points. They are split
            between ``x = x_min`` (``num_boundary // 2`` points) and
            ``x = x_max`` (the rest), so odd counts are supported.
        num_initial: Number of initial-condition points at ``t = t_min``.
        device: Device on which the samples are created.

    Returns:
        dict: ``'domain'`` of shape ``(num_domain, 2)``, ``'boundary'`` of shape
        ``(num_boundary, 2)`` and ``'initial'`` of shape ``(num_initial, 2)``.
    """
    x_min, x_max = x_bounds[0], x_bounds[1]
    t_min, t_max = t_bounds[0], t_bounds[1]

    # Interior points
    x_domain = torch.rand(num_domain, 1, device=device) * (x_max - x_min) + x_min
    t_domain = torch.rand(num_domain, 1, device=device) * (t_max - t_min) + t_min
    domain_points = torch.cat([x_domain, t_domain], dim=1)

    # Spatial boundary points at random times. Splitting as left + right ==
    # num_boundary keeps the row count equal to t_boundary's. The old
    # 2 * (n // 2) split made torch.cat fail for odd n.
    t_boundary = torch.rand(num_boundary, 1, device=device) * (t_max - t_min) + t_min
    num_left = num_boundary // 2
    num_right = num_boundary - num_left
    # Explicit dtype: torch.full with integer bounds would otherwise give int64.
    x_boundary = torch.cat([
        torch.full((num_left, 1), x_min, device=device, dtype=t_boundary.dtype),
        torch.full((num_right, 1), x_max, device=device, dtype=t_boundary.dtype),
    ], dim=0)
    boundary_points = torch.cat([x_boundary, t_boundary], dim=1)

    # Initial condition points (t = t_min)
    x_initial = torch.rand(num_initial, 1, device=device) * (x_max - x_min) + x_min
    t_initial = torch.full((num_initial, 1), t_min, device=device, dtype=x_initial.dtype)
    initial_points = torch.cat([x_initial, t_initial], dim=1)

    return {
        'domain': domain_points,
        'boundary': boundary_points,
        'initial': initial_points
    }

# Example usage: sample the square [-1, 1] x [-1, 1]
bounds = [[-1, 1], [-1, 1]]
domain_points = sample_domain_2d(bounds, num_points=1000)
boundary_points = sample_boundary_2d(bounds, num_points_per_side=50)

print(f"Domain points shape: {domain_points.shape}")
print(f"Boundary points shape: {boundary_points.shape}")

## 3. Automatic Differentiation Utilities

Helper functions to compute derivatives using PyTorch's automatic differentiation.

In [None]:
def compute_gradient(u, x, create_graph=True):
    """Return d(sum of u)/dx, i.e. the per-sample gradient for row-wise u.

    Args:
        u: Tensor computed from ``x`` (typically shape ``(batch, 1)``).
        x: Input tensor with ``requires_grad=True``.
        create_graph: Keep a graph of the result so it can be differentiated
            again (needed for higher-order derivatives and for backprop).

    Returns:
        torch.Tensor with the same shape as ``x``.
    """
    # Seeding with ones sums the per-sample outputs; because rows are
    # independent, row i of the result is the gradient of u[i] w.r.t. x[i].
    (du_dx,) = torch.autograd.grad(
        outputs=u,
        inputs=x,
        grad_outputs=torch.ones_like(u),
        create_graph=create_graph,
        retain_graph=True,
    )
    return du_dx

def compute_laplacian_2d(u, x, create_graph=True):
    """Compute the 2D Laplacian u_xx + u_yy by automatic differentiation.

    Args:
        u: Output tensor of shape ``(batch_size, 1)`` computed from ``x``.
        x: Input tensor of shape ``(batch_size, 2)`` with ``requires_grad=True``;
            column 0 is x and column 1 is y.
        create_graph: Whether the returned Laplacian keeps a computation graph
            (needed to backpropagate a PDE residual into the model weights).

    Returns:
        torch.Tensor of shape ``(batch_size, 1)``.
    """
    # The first derivatives are differentiated a second time below, so they
    # must always carry a graph. Passing create_graph=False here made the
    # second autograd.grad call fail.
    grad = compute_gradient(u, x, create_graph=True)
    u_x = grad[:, 0:1]
    u_y = grad[:, 1:2]

    # Only the second derivatives honour the caller's create_graph choice.
    u_xx = compute_gradient(u_x, x, create_graph=create_graph)[:, 0:1]
    u_yy = compute_gradient(u_y, x, create_graph=create_graph)[:, 1:2]

    return u_xx + u_yy

def compute_time_derivative(u, x, create_graph=True):
    """Return du/dt, taking the LAST column of ``x`` as the time coordinate.

    Args:
        u: Output tensor of shape ``(batch_size, 1)`` computed from ``x``.
        x: Input tensor of shape ``(batch_size, d)`` with ``requires_grad=True``.
        create_graph: Whether the result keeps a computation graph.

    Returns:
        torch.Tensor of shape ``(batch_size, 1)``.
    """
    # Slice with -1: keeps the column dimension, giving (batch, 1) instead of (batch,).
    return compute_gradient(u, x, create_graph=create_graph)[:, -1:]

# Sanity check: differentiate the example model at random inputs and inspect the shapes.
x_test = torch.randn(10, 2, device=device, requires_grad=True)
u_test = model(x_test)
grad_test = compute_gradient(u_test, x_test)
print(
    f"Input shape: {x_test.shape}",
    f"Output shape: {u_test.shape}",
    f"Gradient shape: {grad_test.shape}",
    sep="\n",
)

## 4. Loss Functions

Define various loss components for different types of PDEs.

In [None]:
class PINNLoss:
    """Container for the loss functions of common PINN problems.

    Each ``*_loss`` method returns a dict holding the individual mean-squared
    components and their weighted sum under ``'total'``.

    Args:
        model: Network mapping a batch of input points to predictions.
        loss_weights: Optional mapping from component name ('pde', 'bc', 'ic',
            'data') to weight. A component missing from the mapping is weighted
            1.0, so e.g. ``{'pde': 1.0, 'bc': 10.0}`` also works with losses
            that have an 'ic' term. The mapping is stored as-is (not copied).
    """

    def __init__(self, model, loss_weights=None):
        self.model = model
        self.loss_weights = loss_weights or {'pde': 1.0, 'bc': 1.0, 'ic': 1.0, 'data': 1.0}

    def _weight(self, name):
        """Return the weight for loss component ``name``; 1.0 when not configured."""
        return self.loss_weights.get(name, 1.0)

    def poisson_2d_loss(self, domain_points, boundary_points, source_function, boundary_function):
        """Loss for the 2D Poisson equation ∇²u = f(x, y) with Dirichlet data g.

        Args:
            domain_points: Interior collocation points, shape ``(N, 2)``. Note:
                ``requires_grad`` is switched on in place on this tensor.
            boundary_points: Boundary points, shape ``(M, 2)``.
            source_function: Callable returning f at the given points. It should
                return shape ``(N, 1)``; an ``(N,)`` result would broadcast
                against the ``(N, 1)`` Laplacian to ``(N, N)``.
            boundary_function: Callable returning g at the given points (same
                shape convention).

        Returns:
            dict with keys 'total', 'pde', 'bc'.
        """
        domain_points.requires_grad_(True)

        # PDE residual loss
        u_domain = self.model(domain_points)
        laplacian = compute_laplacian_2d(u_domain, domain_points)
        source = source_function(domain_points)
        pde_residual = laplacian - source
        loss_pde = torch.mean(pde_residual**2)

        # Boundary condition loss
        u_boundary = self.model(boundary_points)
        bc_target = boundary_function(boundary_points)
        loss_bc = torch.mean((u_boundary - bc_target)**2)

        total_loss = self._weight('pde') * loss_pde + self._weight('bc') * loss_bc

        return {
            'total': total_loss,
            'pde': loss_pde,
            'bc': loss_bc
        }

    def heat_equation_loss(self, points_dict, initial_function, boundary_function, alpha=1.0):
        """Loss for the 1D heat equation ∂u/∂t = α ∂²u/∂x².

        Args:
            points_dict: Dict with 'domain', 'boundary' and 'initial' point
                tensors whose columns are ``(x, t)``. Column 0 is used as x
                and the last column as t. ``requires_grad`` is switched on in
                place on the 'domain' tensor.
            initial_function: Callable giving the target u(x, t_min) at the
                initial points.
            boundary_function: Callable giving the boundary target at the
                boundary points.
            alpha: Thermal diffusivity.

        Returns:
            dict with keys 'total', 'pde', 'bc', 'ic'.
        """
        domain_points = points_dict['domain']
        boundary_points = points_dict['boundary']
        initial_points = points_dict['initial']

        domain_points.requires_grad_(True)

        # PDE residual loss. One first-order gradient provides both u_x and
        # u_t; previously the same gradient was computed twice.
        u_domain = self.model(domain_points)
        grad = compute_gradient(u_domain, domain_points)
        u_x = grad[:, 0:1]
        u_t = grad[:, -1:]
        u_xx = compute_gradient(u_x, domain_points)[:, 0:1]

        pde_residual = u_t - alpha * u_xx
        loss_pde = torch.mean(pde_residual**2)

        # Boundary condition loss
        u_boundary = self.model(boundary_points)
        bc_target = boundary_function(boundary_points)
        loss_bc = torch.mean((u_boundary - bc_target)**2)

        # Initial condition loss
        u_initial = self.model(initial_points)
        ic_target = initial_function(initial_points)
        loss_ic = torch.mean((u_initial - ic_target)**2)

        # Missing weights default to 1.0 instead of raising KeyError (e.g. a
        # weights dict written for a Poisson problem without 'ic').
        total_loss = (self._weight('pde') * loss_pde +
                      self._weight('bc') * loss_bc +
                      self._weight('ic') * loss_ic)

        return {
            'total': total_loss,
            'pde': loss_pde,
            'bc': loss_bc,
            'ic': loss_ic
        }

    def data_loss(self, data_points, data_values):
        """Mean-squared misfit between model predictions and observations.

        Args:
            data_points: Observation locations.
            data_values: Observed values, same shape as the model output.

        Returns:
            torch.Tensor: Scalar data-fitting loss (unweighted).
        """
        u_pred = self.model(data_points)
        return torch.mean((u_pred - data_values)**2)

# Example loss setup: boundary and initial terms weighted 10x relative to the PDE residual.
loss_computer = PINNLoss(model, loss_weights=dict(pde=1.0, bc=10.0, ic=10.0))
print("Loss computer initialized")

## 5. Training Utilities

Training functions with support for different optimizers and strategies.

In [None]:
class PINNTrainer:
    """Training driver for a PINN model: Adam, L-BFGS, or Adam followed by L-BFGS.

    Every loss evaluation is appended to ``self.loss_history`` as a dict with
    keys ``'epoch'``, ``'optimizer'`` (``'Adam'`` or ``'L-BFGS'``) and one
    value per entry of the loss dictionary returned by ``loss_function``
    (tensors are converted with ``.item()``).
    """

    def __init__(self, model, loss_computer):
        self.model = model
        # Kept for the caller's convenience; training uses only the
        # ``loss_function`` closures passed to the train_* methods.
        self.loss_computer = loss_computer
        self.loss_history = []

    @staticmethod
    def _history_entry(epoch, optimizer_name, loss_dict):
        """Build one history record, converting tensor losses to Python numbers."""
        return {
            'epoch': epoch,
            'optimizer': optimizer_name,
            **{k: v.item() if torch.is_tensor(v) else v for k, v in loss_dict.items()}
        }

    def train_adam(self, loss_function, epochs=10000, lr=1e-3, print_every=1000):
        """Train with the Adam optimizer.

        Args:
            loss_function: Zero-argument callable returning a loss dict with a
                differentiable scalar under ``'total'``.
            epochs: Number of optimisation steps.
            lr: Learning rate.
            print_every: Print progress every this many epochs; 0 or None
                disables printing.

        Returns:
            list: The (shared, cumulative) loss history.
        """
        optimizer = optim.Adam(self.model.parameters(), lr=lr)

        print(f"Starting Adam training for {epochs} epochs...")
        start_time = time.time()

        for epoch in range(epochs):
            optimizer.zero_grad()

            loss_dict = loss_function()
            total_loss = loss_dict['total']

            total_loss.backward()
            optimizer.step()

            # The recorded values are the losses evaluated before this step.
            self.loss_history.append(self._history_entry(epoch, 'Adam', loss_dict))

            if print_every and (epoch + 1) % print_every == 0:
                elapsed_time = time.time() - start_time
                print(f"Epoch {epoch+1:5d}/{epochs} | "
                      f"Loss: {total_loss.item():.6e} | "
                      f"Time: {elapsed_time:.1f}s")
                # Reported time covers one print interval, not the whole run.
                start_time = time.time()

        return self.loss_history

    def train_lbfgs(self, loss_function, max_iter=1000, line_search_fn='strong_wolfe'):
        """Train with PyTorch's L-BFGS optimizer.

        Args:
            loss_function: Zero-argument callable returning a loss dict with a
                differentiable scalar under ``'total'``.
            max_iter: Maximum number of L-BFGS iterations.
            line_search_fn: ``'strong_wolfe'`` (default) or ``None``. Without a
                line search, L-BFGS takes fixed unit steps, which can diverge
                on PINN losses.

        Returns:
            list: The (shared, cumulative) loss history. One L-BFGS entry is
            recorded per closure evaluation, including line-search evaluations,
            so its ``'epoch'`` counts evaluations rather than iterations.
        """
        print(f"Starting L-BFGS training for up to {max_iter} iterations...")

        optimizer = optim.LBFGS(
            self.model.parameters(),
            max_iter=max_iter,
            tolerance_grad=1e-7,
            tolerance_change=1e-9,
            history_size=100,
            line_search_fn=line_search_fn
        )

        # Continue numbering after any earlier L-BFGS run. Count once here
        # instead of rescanning the whole history on every closure call.
        eval_count = sum(1 for h in self.loss_history if h['optimizer'] == 'L-BFGS')

        def closure():
            nonlocal eval_count
            optimizer.zero_grad()
            loss_dict = loss_function()
            total_loss = loss_dict['total']
            self.loss_history.append(self._history_entry(eval_count, 'L-BFGS', loss_dict))
            eval_count += 1
            total_loss.backward()
            return total_loss

        optimizer.step(closure)

        return self.loss_history

    def train_two_stage(self, loss_function, adam_epochs=10000, adam_lr=1e-3,
                        lbfgs_max_iter=1000, print_every=1000,
                        lbfgs_line_search_fn='strong_wolfe'):
        """Two-stage training: Adam exploration followed by L-BFGS refinement.

        Args:
            loss_function: Zero-argument callable returning a loss dict.
            adam_epochs: Number of Adam steps.
            adam_lr: Adam learning rate.
            lbfgs_max_iter: Maximum L-BFGS iterations.
            print_every: Print frequency for the Adam stage.
            lbfgs_line_search_fn: Line search used by the L-BFGS stage.

        Returns:
            list: Complete loss history.
        """
        print("=== Two-Stage Training ===")
        print("Stage 1: Adam optimization")
        self.train_adam(loss_function, adam_epochs, adam_lr, print_every)

        print("\nStage 2: L-BFGS fine-tuning")
        self.train_lbfgs(loss_function, lbfgs_max_iter, line_search_fn=lbfgs_line_search_fn)

        return self.loss_history

    def plot_loss_history(self):
        """Plot the total loss on a log scale and print the final loss components.

        L-BFGS entries are drawn after the Adam entries, offset by the number
        of Adam records.
        """
        if not self.loss_history:
            print("No training history available")
            return

        adam_history = [h for h in self.loss_history if h['optimizer'] == 'Adam']
        lbfgs_history = [h for h in self.loss_history if h['optimizer'] == 'L-BFGS']

        plt.figure(figsize=(12, 8))

        if adam_history:
            adam_epochs = [h['epoch'] for h in adam_history]
            adam_losses = [h['total'] for h in adam_history]
            plt.plot(adam_epochs, adam_losses, 'b-', label='Adam', linewidth=2)

        if lbfgs_history:
            lbfgs_start = len(adam_history)
            lbfgs_epochs = [lbfgs_start + h['epoch'] for h in lbfgs_history]
            lbfgs_losses = [h['total'] for h in lbfgs_history]
            plt.plot(lbfgs_epochs, lbfgs_losses, 'r-', label='L-BFGS', linewidth=2)

        plt.xlabel('Iteration')
        plt.ylabel('Total Loss')
        plt.title('PINN Training Loss History')
        plt.yscale('log')
        plt.legend()
        plt.grid(True, alpha=0.3)
        plt.show()

        final_loss = self.loss_history[-1]
        print(f"\nFinal Loss: {final_loss['total']:.6e}")
        if 'pde' in final_loss:
            print(f"PDE Loss: {final_loss['pde']:.6e}")
        if 'bc' in final_loss:
            print(f"BC Loss: {final_loss['bc']:.6e}")
        if 'ic' in final_loss:
            print(f"IC Loss: {final_loss['ic']:.6e}")

# Example trainer wrapping the example model and its loss container.
trainer = PINNTrainer(model=model, loss_computer=loss_computer)
print("Trainer initialized")

## 6. Example: 2D Poisson Equation

Complete example solving the 2D Poisson equation: ∇²u = f(x,y)

In [None]:
# Problem setup: 2D Poisson equation
#   PDE:     ∇²u = -2π² sin(πx) sin(πy)
#   Domain:  [-1, 1] × [-1, 1]
#   BC:      u = 0 on the boundary
#   Exact:   u(x, y) = sin(πx) sin(πy)

# Problem parameters
bounds = [[-1, 1], [-1, 1]]
num_domain, num_boundary = 2000, 200

# Collocation points: interior samples, plus num_boundary split evenly over the four edges.
domain_points = sample_domain_2d(bounds, num_domain, device)
boundary_points = sample_boundary_2d(bounds, num_boundary // 4, device)

# Define source function and boundary conditions
def source_function(points):
    """Poisson source f(x, y) = -2π² sin(πx) sin(πy), returned with shape (N, 1).

    Columns 0 and 1 of ``points`` are x and y.
    """
    sin_x = torch.sin(np.pi * points[:, 0:1])
    sin_y = torch.sin(np.pi * points[:, 1:2])
    return -2 * np.pi**2 * sin_x * sin_y

def boundary_function(points):
    """Homogeneous Dirichlet data g(x, y) = 0.

    Returns an ``(N, 1)`` zero tensor on the same device and with the same
    dtype as ``points``. Previously it used the notebook-global ``device`` and
    the default dtype, which could mismatch the inputs.
    """
    return points.new_zeros((points.shape[0], 1))

def analytical_solution(points):
    """Exact Poisson solution u(x, y) = sin(πx) sin(πy), returned with shape (N, 1).

    Columns 0 and 1 of ``points`` are x and y.
    """
    # Multiply sin(π·x) and sin(π·y) across the first two columns.
    return torch.sin(np.pi * points[:, :2]).prod(dim=1, keepdim=True)

# Fresh network, loss container (boundary weighted 10x) and trainer for the Poisson problem.
poisson_model = PINN(input_dim=2, hidden_dim=50, output_dim=1, num_hidden_layers=3).to(device)
poisson_loss_computer = PINNLoss(poisson_model, loss_weights=dict(pde=1.0, bc=10.0))
poisson_trainer = PINNTrainer(poisson_model, poisson_loss_computer)

# Define loss function
def poisson_loss():
    """Zero-argument closure: Poisson loss dict at the fixed, pre-sampled points."""
    return poisson_loss_computer.poisson_2d_loss(
        domain_points=domain_points,
        boundary_points=boundary_points,
        source_function=source_function,
        boundary_function=boundary_function,
    )

# Train: Adam for exploration, then L-BFGS refinement.
print("Training 2D Poisson equation solver...")
history = poisson_trainer.train_two_stage(
    poisson_loss,
    adam_epochs=5000,
    adam_lr=1e-3,
    lbfgs_max_iter=500,
    print_every=1000,
)

# Training curve and final loss components.
poisson_trainer.plot_loss_history()

## 7. Visualization and Validation

Functions to visualize results and compute errors.

In [None]:
def visualize_2d_solution(model, bounds, analytical_solution=None, resolution=100):
    """Plot a trained 2D model as a 3D surface, optionally next to an exact solution.

    When ``analytical_solution`` is given, three panels are drawn: the exact
    solution, the model prediction and the absolute error. The L2, relative L2
    and max absolute errors on the grid are also printed. Otherwise a single
    prediction surface is drawn.

    Args:
        model: Trained model mapping ``(N, 2)`` points to ``(N, 1)`` values.
        bounds: Domain bounds ``[[x_min, x_max], [y_min, y_max]]``.
        analytical_solution: Optional callable giving the exact solution at
            ``(N, 2)`` points.
        resolution: Number of grid points per axis.
    """
    # Evaluate on the model's own device rather than the notebook-global
    # `device`; fall back to CPU for parameter-free models.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device('cpu')

    # Regular grid, flattened into an (resolution**2, 2) batch for the model.
    x = torch.linspace(bounds[0][0], bounds[0][1], resolution)
    y = torch.linspace(bounds[1][0], bounds[1][1], resolution)
    X, Y = torch.meshgrid(x, y, indexing='ij')
    X_np, Y_np = X.numpy(), Y.numpy()
    xy_test = torch.stack([X.flatten(), Y.flatten()], dim=1).to(model_device)

    with torch.no_grad():
        u_pred_grid = model(xy_test).cpu().reshape(resolution, resolution)

    if analytical_solution is not None:
        fig, axes = plt.subplots(1, 3, figsize=(18, 5), subplot_kw={'projection': '3d'})

        with torch.no_grad():
            u_analytical_grid = analytical_solution(xy_test).cpu().reshape(resolution, resolution)
        error = torch.abs(u_pred_grid - u_analytical_grid)

        panels = [
            (u_analytical_grid, 'viridis', 'Analytical Solution'),
            (u_pred_grid, 'viridis', 'PINN Solution'),
            (error, 'hot', 'Absolute Error'),
        ]
        for ax, (values, cmap, title) in zip(axes, panels):
            surf = ax.plot_surface(X_np, Y_np, values.numpy(), cmap=cmap)
            ax.set_title(title)
            ax.set_xlabel('x')
            ax.set_ylabel('y')
            fig.colorbar(surf, ax=ax, shrink=0.6)

        # Grid-based error metrics
        l2_error = torch.sqrt(torch.mean(error**2))
        l2_relative = l2_error / torch.sqrt(torch.mean(u_analytical_grid**2))
        max_error = torch.max(error)

        print(f"L2 Error: {l2_error:.6e}")
        print(f"L2 Relative Error: {l2_relative:.6e}")
        print(f"Max Absolute Error: {max_error:.6e}")

    else:
        fig = plt.figure(figsize=(10, 8))
        ax = fig.add_subplot(111, projection='3d')

        surf = ax.plot_surface(X_np, Y_np, u_pred_grid.numpy(), cmap='viridis')
        ax.set_title('PINN Solution')
        ax.set_xlabel('x')
        ax.set_ylabel('y')

        plt.colorbar(surf)

    plt.tight_layout()
    plt.show()

def compute_metrics(u_pred, u_true):
    """Compute absolute L2 (RMS), relative L2 and L-infinity errors.

    Both inputs are flattened before comparison. Otherwise an ``(N, 1)``
    prediction against an ``(N,)`` reference would broadcast to an ``(N, N)``
    difference and give silently wrong metrics.

    Args:
        u_pred: Predicted values (tensor or array-like).
        u_true: Reference values with the same number of elements.

    Returns:
        dict: ``{'l2_absolute', 'l2_relative', 'l_inf'}`` as Python floats.
        ``l2_relative`` is the RMS error divided by the RMS of ``u_true``, so it
        is inf/nan when ``u_true`` is identically zero.

    Raises:
        ValueError: If the inputs contain different numbers of elements.
    """
    u_pred = torch.as_tensor(u_pred)
    u_true = torch.as_tensor(u_true)
    if u_pred.numel() != u_true.numel():
        raise ValueError(
            f"u_pred has {u_pred.numel()} elements but u_true has {u_true.numel()}"
        )
    diff = u_pred.reshape(-1) - u_true.reshape(-1)
    u_true = u_true.reshape(-1)

    l2_error = torch.sqrt(torch.mean(diff**2))
    l2_relative = l2_error / torch.sqrt(torch.mean(u_true**2))
    l_inf_error = torch.max(torch.abs(diff))

    return {
        'l2_absolute': l2_error.item(),
        'l2_relative': l2_relative.item(),
        'l_inf': l_inf_error.item()
    }

# Compare the trained Poisson model against the exact solution.
print("Visualizing Poisson equation results...")
visualize_2d_solution(poisson_model, bounds, analytical_solution=analytical_solution)

## 8. Advanced Features

Additional utilities for more complex PINN implementations.

In [None]:
class AdvancedPINNFeatures:
    """Stateless helpers for common PINN extensions.

    Covers loss balancing, causal time weighting, gradient matching and
    random Fourier feature embedding.
    """

    @staticmethod
    def adaptive_weights(loss_dict, alpha=0.5, max_weight=1000.0, min_weight=0.1):
        """Magnitude-balancing loss weights.

        Every component ``k`` (all keys except ``'total'``) gets
        ``w_k = clamp(mean(L) / (L_k + 1e-10), min_weight, max_weight)``.
        This pulls each ``w_k * L_k`` towards the mean loss, so SMALLER losses
        receive LARGER weights. It is a simple magnitude heuristic; no gradient
        norms are involved (unlike GradNorm).

        Args:
            loss_dict: Mapping name -> loss (Python number or scalar tensor).
            alpha: Currently unused; kept for backward compatibility.
            max_weight: Upper clamp for every weight.
            min_weight: Lower clamp for every weight (previously hard-coded 0.1).

        Returns:
            dict: name -> float weight. With fewer than two components every
            weight is 1.0.
        """
        losses = {k: v for k, v in loss_dict.items() if k != 'total'}

        if len(losses) < 2:
            return {k: 1.0 for k in losses}

        # Work on detached Python floats: avoids building a tensor from
        # graph-carrying (possibly GPU) scalars just to read their values.
        values = {
            k: float(v.detach()) if torch.is_tensor(v) else float(v)
            for k, v in losses.items()
        }
        avg_loss = sum(values.values()) / len(values)

        return {
            k: min(max(avg_loss / (v + 1e-10), min_weight), max_weight)
            for k, v in values.items()
        }

    @staticmethod
    def causal_weighting(points, current_time, time_window=0.1):
        """Causal weights that down-weight points later than ``current_time``.

        ``w = exp(-clamp((t - current_time) / time_window, 0, 10))``: 1 for
        ``t <= current_time``, decaying exponentially afterwards, floored at
        ``exp(-10)``.

        Args:
            points: Point tensor whose LAST column is time.
            current_time: Time coordinate up to which points get full weight.
            time_window: Decay length scale of the weights past ``current_time``.

        Returns:
            torch.Tensor: Weights of shape ``(N, 1)``.
        """
        t = points[:, -1]
        weights = torch.exp(-torch.clamp((t - current_time) / time_window, 0, 10))
        return weights.unsqueeze(-1)

    @staticmethod
    def gradient_enhanced_loss(model, data_points, data_values, grad_points, grad_values):
        """Data-fitting loss plus a term matching observed input-gradients.

        NOTE(review): this matches *observed* gradients of u. The gPINN of
        Yu et al. instead penalises gradients of the PDE residual. The name is
        kept for compatibility.

        Args:
            model: PINN model.
            data_points: Points where solution values are known.
            data_values: Known solution values.
            grad_points: Points where gradients are known. ``requires_grad``
                is switched on in place on this tensor.
            grad_values: Known gradients, same shape as ``grad_points``.

        Returns:
            torch.Tensor: Sum of the two mean-squared errors.
        """
        u_pred = model(data_points)
        data_loss = torch.mean((u_pred - data_values)**2)

        grad_points.requires_grad_(True)
        u_grad_pred = model(grad_points)
        grad_pred = compute_gradient(u_grad_pred, grad_points)
        grad_loss = torch.mean((grad_pred - grad_values)**2)

        return data_loss + grad_loss

    @staticmethod
    def fourier_feature_embedding(x, num_frequencies=10, sigma=1.0, B=None):
        """Random Fourier features ``[sin(2π x B), cos(2π x B)]``.

        Args:
            x: Input coordinates of shape ``(..., d)``.
            num_frequencies: Number of frequency columns when ``B`` is sampled.
            sigma: Standard deviation of the sampled frequencies.
            B: Optional fixed frequency matrix of shape ``(d, m)``. When omitted,
                a NEW random ``B ~ N(0, sigma²)`` is drawn on every call. The
                same ``x`` is then embedded differently each time, so pass a
                fixed ``B`` when embedding inside a model.

        Returns:
            torch.Tensor: Features of shape ``(..., 2 * m)`` (sin block, then cos).
        """
        if B is None:
            B = torch.randn(x.shape[-1], num_frequencies, device=x.device, dtype=x.dtype) * sigma
        else:
            B = B.to(device=x.device, dtype=x.dtype)

        x_proj = 2 * np.pi * x @ B
        return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)

# Example: Using Fourier features
class FourierPINN(nn.Module):
    """PINN whose inputs are first mapped to random Fourier features.

    ``gamma(x) = [sin(2π x B), cos(2π x B)]``, where ``B ~ N(0, sigma²)`` has
    shape ``(input_dim, num_frequencies)``. B is sampled ONCE at construction
    and stored as a non-trainable buffer, so it follows ``.to(device)`` and is
    saved in ``state_dict()``. Previously a new B was drawn on every forward
    pass, making the model non-deterministic and untrainable.
    The embedding is followed by three hidden layers of width ``hidden_dim``.

    Args:
        input_dim: Number of input coordinates.
        hidden_dim: Width of the hidden layers.
        output_dim: Number of predicted fields.
        num_frequencies: Number of random frequencies (feature dim is twice this).
        sigma: Standard deviation of the random frequencies.
        activation: Activation module used after every hidden layer; the same
            instance is reused at each position. Defaults to ``nn.Tanh()``. A
            smooth activation is used because ReLU networks are piecewise linear
            and contribute no curvature to the second derivatives in PDE
            residuals. Pass ``nn.ReLU()`` to restore the old behaviour.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_frequencies=10, sigma=1.0,
                 activation=None):
        super().__init__()

        self.num_frequencies = num_frequencies
        self.sigma = sigma
        if activation is None:
            activation = nn.Tanh()

        # Fixed random frequency matrix (not a Parameter, so not trained).
        self.register_buffer('B', torch.randn(input_dim, num_frequencies) * sigma)

        # sin and cos blocks double the embedded dimension.
        fourier_dim = 2 * num_frequencies

        self.network = nn.Sequential(
            nn.Linear(fourier_dim, hidden_dim),
            activation,
            nn.Linear(hidden_dim, hidden_dim),
            activation,
            nn.Linear(hidden_dim, hidden_dim),
            activation,
            nn.Linear(hidden_dim, output_dim)
        )

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Xavier-normal weights and zero biases for ``nn.Linear`` modules."""
        if isinstance(module, nn.Linear):
            nn.init.xavier_normal_(module.weight)
            nn.init.zeros_(module.bias)

    def forward(self, x):
        """Embed ``x`` of shape ``(..., input_dim)`` with the fixed B, then apply the MLP."""
        x_proj = 2 * np.pi * x @ self.B
        fourier_features = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
        return self.network(fourier_features)

print("Advanced PINN features loaded")

# Example: a Fourier-feature PINN for 2D inputs with 10 random frequencies.
fourier_model = FourierPINN(input_dim=2, hidden_dim=64, output_dim=1, num_frequencies=10).to(device)
print(f"Fourier PINN created with {sum(param.numel() for param in fourier_model.parameters())} parameters")

## 9. Template Usage Guide

This template provides a complete framework for implementing PINNs in PyTorch. Here's how to use it for different types of problems:

### For Time-Dependent PDEs (e.g., Heat Equation):
1. Use `sample_time_domain()` to generate points
2. Use `heat_equation_loss()` method from `PINNLoss` (1D in space: input points are ordered `(x, t)`, with time in the last column)
3. Include initial conditions in your loss function

### For Elliptic PDEs (e.g., Poisson Equation):
1. Use `sample_domain_2d()` and `sample_boundary_2d()`
2. Use `poisson_2d_loss()` method from `PINNLoss`
3. Focus on boundary conditions

### For Inverse Problems:
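1. Register unknown PDE coefficients as `nn.Parameter` attributes of the model, so that `model.parameters()` (and therefore the optimizer) includes them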
2. Use `data_loss()` method for observation data
3. Include parameter regularization if needed

### For Advanced Techniques:
1. Use `FourierPINN` for high-frequency problems
2. Apply `gradient_enhanced_loss()` when derivative data is available
3. Use `adaptive_weights()` for multi-objective optimization

### Training Strategy:
1. Start with `train_adam()` for initial exploration
2. Use `train_two_stage()` for best accuracy
3. Monitor loss components individually
4. Adjust loss weights based on problem requirements

This template replaces the DeepXDE dependency while providing equivalent functionality with more flexibility and control over the implementation.