# Self-Adaptive PINN (SA-PINN) - PyTorch Implementation

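This notebook demonstrates the Self-Adaptive PINN (SA-PINN) concept in PyTorch: a physics-informed neural network is trained to solve a convection-diffusion equation, and the weights of its PDE, boundary-condition, and initial-condition loss terms are learnable parameters rather than hand-tuned constants.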

## Problem Definition:
- **PDE**: $\frac{\partial u}{\partial t} + \beta \frac{\partial u}{\partial x} = \nu \frac{\partial^2 u}{\partial x^2}$
- **Domain**: $x \in [-1, 1]$, $t \in [0, 1]$
- **Initial Condition**: $u(x, 0) = -\sin(\pi x)$
- **Boundary Conditions**: $u(-1, t) = u(1, t) = 0$
- **Parameters**: $\beta = 1.0$, $\nu = 0.01/\pi$

In [None]:
from typing import List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim

# Use the "whitegrid" seaborn theme for every figure in the notebook
sns.set_style("whitegrid")

# Run on the GPU when PyTorch reports one, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Seed the NumPy and PyTorch random number generators so runs are repeatable
np.random.seed(42)
torch.manual_seed(42)

## 1. Problem Parameters and Domain Setup

In [None]:
# PDE coefficients: beta multiplies the convection term, nu the diffusion term
beta, nu = 1.0, 0.01 / np.pi

# Extent of the rectangular space-time domain
(x_min, x_max), (t_min, t_max) = (-1.0, 1.0), (0.0, 1.0)

# Echo the configuration, one setting per line
print(
    f"Convection coefficient (beta): {beta}",
    f"Diffusion coefficient (nu): {nu:.6f}",
    f"Spatial domain: [{x_min}, {x_max}]",
    f"Time domain: [{t_min}, {t_max}]",
    sep="\n",
)

## 2. Data Generation Functions

In [None]:
def generate_domain_points(n_points: int) -> torch.Tensor:
    """Draw ``n_points`` uniformly random (x, t) points inside the domain.

    Each row of the result is ``[x, t]`` with ``x`` in ``[x_min, x_max)`` and
    ``t`` in ``[t_min, t_max)``.  The tensor has shape ``(n_points, 2)`` and is
    moved to ``device`` before being returned.
    """
    x_span = x_max - x_min
    t_span = t_max - t_min

    # x is sampled before t so the random stream is consumed in a fixed order
    xs = x_min + x_span * torch.rand(n_points, 1)
    ts = t_min + t_span * torch.rand(n_points, 1)

    return torch.cat((xs, ts), dim=1).to(device)

def generate_boundary_points(n_points: int) -> torch.Tensor:
    """Draw ``n_points`` random-time points on the two spatial boundaries.

    The first ``n_points // 2`` rows lie on ``x = x_min``; the remaining rows
    lie on ``x = x_max``.  Times are uniform in ``[t_min, t_max)``.  The result
    has shape ``(n_points, 2)`` with columns ``[x, t]`` and lives on ``device``.
    """
    t_span = t_max - t_min
    count_left = n_points // 2
    count_right = n_points - count_left  # takes the extra point when n_points is odd

    # Left wall (x = x_min); its times are drawn first
    left_t = t_min + t_span * torch.rand(count_left, 1)
    left_wall = torch.cat((torch.full((count_left, 1), x_min), left_t), dim=1)

    # Right wall (x = x_max)
    right_t = t_min + t_span * torch.rand(count_right, 1)
    right_wall = torch.cat((torch.full((count_right, 1), x_max), right_t), dim=1)

    return torch.cat((left_wall, right_wall), dim=0).to(device)

def generate_initial_points(n_points: int) -> torch.Tensor:
    """Draw ``n_points`` random points on the initial-time line ``t = 0``.

    Rows are ``[x, 0]`` with ``x`` uniform in ``[x_min, x_max)``.  The result
    has shape ``(n_points, 2)`` and is moved to ``device``.
    """
    xs = x_min + (x_max - x_min) * torch.rand(n_points, 1)
    # The time column is identically zero
    return torch.cat((xs, torch.zeros_like(xs)), dim=1).to(device)

def initial_condition(x: torch.Tensor) -> torch.Tensor:
    """Evaluate the initial profile ``-sin(pi * x)`` on column 0 of ``x``.

    Only the first column of ``x`` is read; the result keeps it as a column,
    so an ``(N, 2)`` input yields an ``(N, 1)`` output.
    """
    spatial = x[:, :1]
    return torch.neg(torch.sin(np.pi * spatial))

def boundary_condition(x: torch.Tensor) -> torch.Tensor:
    """Return the boundary target: zeros shaped like column 0 of ``x``.

    The values in ``x`` are not used; only the shape, dtype and device of its
    first column determine the result.
    """
    first_column = x[:, :1]
    return torch.zeros_like(first_column)

## 3. Neural Network Architecture

In [None]:
class SAPINN(nn.Module):
    """Self-Adaptive PINN: a tanh MLP plus one learnable weight per loss term.

    The network is a stack of ``nn.Linear`` layers with ``tanh`` between them
    and no activation after the last layer.  Three scalar ``nn.Parameter``
    objects -- ``lambda_pde``, ``lambda_bc`` and ``lambda_ic`` -- are
    registered on the module, so ``parameters()`` yields them together with
    the layer weights and biases.
    """

    def __init__(self, layers: Optional[List[int]] = None):
        """Build the linear layers and the adaptive loss weights.

        Args:
            layers: Layer widths from input to output.  ``None`` (the default)
                selects ``[2, 32, 32, 32, 32, 1]``.  A ``None`` sentinel is
                used instead of a list literal so that no mutable object is
                shared between calls.

        Raises:
            ValueError: If fewer than two widths are given, since at least an
                input and an output width are needed to build one layer.
        """
        super().__init__()

        if layers is None:
            layers = [2, 32, 32, 32, 32, 1]
        if len(layers) < 2:
            raise ValueError(
                "layers must contain at least an input and an output width"
            )

        # One nn.Linear per consecutive pair of widths
        self.layers = nn.ModuleList(
            nn.Linear(fan_in, fan_out)
            for fan_in, fan_out in zip(layers[:-1], layers[1:])
        )

        # Initialize weights using Xavier normal
        self.init_weights()

        # Learnable loss weights (SA-PINN key feature), all starting at 1.0
        self.lambda_pde = nn.Parameter(torch.tensor(1.0, dtype=torch.float32))
        self.lambda_bc = nn.Parameter(torch.tensor(1.0, dtype=torch.float32))
        self.lambda_ic = nn.Parameter(torch.tensor(1.0, dtype=torch.float32))

        print(f"SA-PINN initialized with {len(layers)} layers: {layers}")
        print(f"Learnable weights initialized: λ_pde={self.lambda_pde.item():.3f}, λ_bc={self.lambda_bc.item():.3f}, λ_ic={self.lambda_ic.item():.3f}")

    def init_weights(self):
        """Apply Xavier-normal weights and zero biases to every linear layer."""
        for layer in self.layers:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_normal_(layer.weight)
                nn.init.zeros_(layer.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the network.

        Every layer except the last is followed by ``tanh``; the last layer's
        output is returned as-is.
        """
        for layer in self.layers[:-1]:
            x = torch.tanh(layer(x))
        # Final layer without activation
        return self.layers[-1](x)

    def get_loss_weights(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return the adaptive weights as ``(lambda_pde, lambda_bc, lambda_ic)``.

        The parameter objects themselves are returned (not copies), so they
        stay attached to the autograd graph and expose ``.grad``.
        """
        return self.lambda_pde, self.lambda_bc, self.lambda_ic

## 4. Physics-Informed Loss Functions

In [None]:
def compute_derivatives(u: torch.Tensor, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Differentiate ``u`` with respect to the columns of ``x`` via autograd.

    Column 0 of ``x`` is treated as the spatial coordinate and column 1 as
    time.  Returns ``(du/dt, du/dx, d2u/dx2)``, each as a single column.  Both
    autograd calls build a graph, so the results can be differentiated again.
    """
    # Gradient of u with respect to both input columns at once
    (first_order,) = torch.autograd.grad(
        outputs=u,
        inputs=x,
        grad_outputs=torch.ones_like(u),
        create_graph=True,
        retain_graph=True,
    )
    du_dx = first_order[:, 0:1]
    du_dt = first_order[:, 1:2]

    # Differentiate du/dx once more and keep only its spatial column
    (second_order,) = torch.autograd.grad(
        outputs=du_dx,
        inputs=x,
        grad_outputs=torch.ones_like(du_dx),
        create_graph=True,
        retain_graph=True,
    )
    d2u_dx2 = second_order[:, 0:1]

    return du_dt, du_dx, d2u_dx2

def pde_residual(model: SAPINN, x: torch.Tensor) -> torch.Tensor:
    """Evaluate the PDE residual ``u_t + beta * u_x - nu * u_xx`` at ``x``.

    Note that ``x`` is switched to ``requires_grad=True`` in place so that the
    network output can be differentiated with respect to it.
    """
    x.requires_grad_(True)
    u_t, u_x, u_xx = compute_derivatives(model(x), x)

    convection = beta * u_x
    diffusion = nu * u_xx

    # Zero wherever the network satisfies the PDE exactly
    return u_t + convection - diffusion

def compute_losses(model: SAPINN, domain_points: torch.Tensor, 
                  boundary_points: torch.Tensor, initial_points: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Return the unweighted ``(pde_loss, bc_loss, ic_loss)`` mean-squared errors.

    * ``pde_loss``: mean squared PDE residual at ``domain_points``.
    * ``bc_loss``: mean squared gap between the model and ``boundary_condition``
      at ``boundary_points``.
    * ``ic_loss``: mean squared gap between the model and ``initial_condition``
      at ``initial_points``.
    """
    # Physics term: the residual itself is the error
    residual = pde_residual(model, domain_points)
    pde_loss = residual.pow(2).mean()

    # Boundary term
    bc_error = model(boundary_points) - boundary_condition(boundary_points)
    bc_loss = bc_error.pow(2).mean()

    # Initial-condition term
    ic_error = model(initial_points) - initial_condition(initial_points)
    ic_loss = ic_error.pow(2).mean()

    return pde_loss, bc_loss, ic_loss

def compute_total_loss(model: SAPINN, domain_points: torch.Tensor,
                      boundary_points: torch.Tensor, initial_points: torch.Tensor
                      ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Combine the three loss terms using the model's adaptive weights.

    Args:
        model: Network providing the predictions and the loss weights.
        domain_points: Points at which the PDE residual is evaluated.
        boundary_points: Points at which the boundary condition is enforced.
        initial_points: Points at which the initial condition is enforced.

    Returns:
        ``(total_loss, pde_loss, bc_loss, ic_loss)``, where ``total_loss`` is
        ``lambda_pde * pde_loss + lambda_bc * bc_loss + lambda_ic * ic_loss``
        and the other three entries are the unweighted components.
    """
    pde_loss, bc_loss, ic_loss = compute_losses(model, domain_points, boundary_points, initial_points)

    # Get adaptive weights
    lambda_pde, lambda_bc, lambda_ic = model.get_loss_weights()

    # Weighted total loss
    total_loss = lambda_pde * pde_loss + lambda_bc * bc_loss + lambda_ic * ic_loss

    return total_loss, pde_loss, bc_loss, ic_loss

## 5. Training Setup and Execution

In [None]:
# Build the network first: its weight initialisation draws from the random
# stream before any training points are sampled
model = SAPINN([2, 32, 32, 32, 32, 1]).to(device)

# Number of training points for the PDE, boundary and initial-condition terms
n_domain, n_boundary, n_initial = 2540, 80, 160

# Sample each training set once; the same points are reused every epoch
domain_points = generate_domain_points(n_domain)
boundary_points = generate_boundary_points(n_boundary)
initial_points = generate_initial_points(n_initial)

print("Training data generated:")
print("\n".join(
    f"  {label} points: {points.shape}"
    for label, points in (
        ("Domain", domain_points),
        ("Boundary", boundary_points),
        ("Initial", initial_points),
    )
))

# A single Adam optimizer over model.parameters(), which covers the layer
# weights and the three adaptive loss weights
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Training length and how often progress is printed
epochs, log_interval = 25000, 1000

# Per-epoch records, filled during training:
#   loss_history   -> [total, pde, bc, ic]
#   weight_history -> [lambda_pde, lambda_bc, lambda_ic]
loss_history, weight_history = [], []

In [None]:
# Training loop
#
# The network parameters are updated by gradient descent on the weighted
# loss, while the adaptive weights are updated by gradient *ascent*.
# The derivative of the total loss with respect to each lambda is the
# corresponding squared-error term, which is never negative, so descending on
# the lambdas would push them steadily towards (and then below) zero and
# eventually reward larger losses.
print("\nStarting SA-PINN training...")
print("=" * 50)

# Nothing inside the loop leaves training mode, so it is set once up front
model.train()

for epoch in range(epochs):
    optimizer.zero_grad()
    
    # Compute loss
    total_loss, pde_loss, bc_loss, ic_loss = compute_total_loss(
        model, domain_points, boundary_points, initial_points
    )
    
    # Backward pass
    total_loss.backward()
    
    # Flip the sign of the lambda gradients so the shared optimizer step
    # increases the adaptive weights while it decreases the network loss
    for adaptive_weight in model.get_loss_weights():
        if adaptive_weight.grad is not None:
            adaptive_weight.grad.neg_()
    
    optimizer.step()
    
    # Store history (loss values are the ones computed before this step)
    loss_history.append([
        total_loss.item(), pde_loss.item(), bc_loss.item(), ic_loss.item()
    ])
    
    # Store weight evolution (values after this step)
    lambda_pde, lambda_bc, lambda_ic = model.get_loss_weights()
    weight_history.append([
        lambda_pde.item(), lambda_bc.item(), lambda_ic.item()
    ])
    
    # Logging
    if (epoch + 1) % log_interval == 0:
        print(f"Epoch {epoch + 1:5d}/{epochs} | "
              f"Total: {total_loss.item():.2e} | "
              f"PDE: {pde_loss.item():.2e} | "
              f"BC: {bc_loss.item():.2e} | "
              f"IC: {ic_loss.item():.2e}")
        print(f"{'':17} | "
              f"λ_pde: {lambda_pde.item():.3f} | "
              f"λ_bc: {lambda_bc.item():.3f} | "
              f"λ_ic: {lambda_ic.item():.3f}")
        print("-" * 80)

print("\nTraining completed!")

# Convert to numpy for plotting: one row per epoch, one column per quantity
loss_history = np.array(loss_history)
weight_history = np.array(weight_history)

## 6. Results Visualization

In [None]:
# Plot loss history: three panels side by side
#   1. weighted total loss, 2. unweighted loss components, 3. adaptive weights
plt.figure(figsize=(15, 5))

# Total loss
plt.subplot(1, 3, 1)
plt.plot(loss_history[:, 0], 'b-', linewidth=2)
plt.xlabel('Epoch')
plt.ylabel('Total Loss')
plt.title('Total Weighted Loss')
plt.yscale('log')
plt.grid(True, alpha=0.3)

# Individual loss components
plt.subplot(1, 3, 2)
plt.plot(loss_history[:, 1], 'r-', label='PDE Loss', linewidth=2)
plt.plot(loss_history[:, 2], 'g-', label='BC Loss', linewidth=2)
plt.plot(loss_history[:, 3], 'b-', label='IC Loss', linewidth=2)
plt.xlabel('Epoch')
plt.ylabel('Loss Value')
plt.title('Individual Loss Components')
plt.yscale('log')
plt.legend()
plt.grid(True, alpha=0.3)

# Adaptive weights evolution
# Raw strings keep the backslash of the mathtext command literal; "\l" is not
# a valid escape sequence in a normal string and triggers a warning.
plt.subplot(1, 3, 3)
plt.plot(weight_history[:, 0], 'r-', label=r'$\lambda_{PDE}$', linewidth=2)
plt.plot(weight_history[:, 1], 'g-', label=r'$\lambda_{BC}$', linewidth=2)
plt.plot(weight_history[:, 2], 'b-', label=r'$\lambda_{IC}$', linewidth=2)
plt.xlabel('Epoch')
plt.ylabel('Weight Value')
plt.title('Evolution of Adaptive Weights')
plt.legend()
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Generate test data for visualization
model.eval()
with torch.no_grad():
    # Create a regular grid: 100 points in x by 50 points in t
    x_test = torch.linspace(x_min, x_max, 100)
    t_test = torch.linspace(t_min, t_max, 50)
    # indexing='ij' makes axis 0 follow x and axis 1 follow t
    X_test, T_test = torch.meshgrid(x_test, t_test, indexing='ij')
    
    # Flatten and create input tensor with columns [x, t]
    test_points = torch.stack([X_test.flatten(), T_test.flatten()], dim=1).to(device)
    
    # Get predictions and fold them back into the grid shape
    u_pred = model(test_points).cpu().numpy().reshape(X_test.shape)
    
    X_test_np = X_test.numpy()
    T_test_np = T_test.numpy()

# Visualization
plt.figure(figsize=(15, 5))

# 2D solution plot
plt.subplot(1, 3, 1)
contour = plt.contourf(X_test_np, T_test_np, u_pred, levels=50, cmap='RdYlBu')
plt.colorbar(contour)
plt.xlabel('x')
plt.ylabel('t')
plt.title('SA-PINN Solution u(x,t)')

# Solution at different times
plt.subplot(1, 3, 2)
# Five time slices spread from the first to the last grid time.  They are
# derived from the grid length so they stay valid if the resolution changes;
# for the 50-point grid above this yields [0, 12, 24, 36, 49].
time_indices = [int(k * (len(t_test) - 1) / 4) for k in range(5)]
for t_idx in time_indices:
    plt.plot(x_test.numpy(), u_pred[:, t_idx], 
             label=f't = {T_test_np[0, t_idx]:.2f}', linewidth=2)
plt.xlabel('x')
plt.ylabel('u(x,t)')
plt.title('Solution at Different Times')
plt.legend()
plt.grid(True, alpha=0.3)

# 3D surface plot
ax = plt.subplot(1, 3, 3, projection='3d')
surf = ax.plot_surface(X_test_np, T_test_np, u_pred, 
                      cmap='RdYlBu', alpha=0.9)
ax.set_xlabel('x')
ax.set_ylabel('t')
ax.set_zlabel('u(x,t)')
ax.set_title('3D Solution Surface')

plt.tight_layout()
plt.show()

## 7. Final Analysis and Summary

In [None]:
# Values recorded for the last epoch: the final row of each history array
final_total_loss, final_pde_loss, final_bc_loss, final_ic_loss = loss_history[-1]
final_lambda_pde, final_lambda_bc, final_lambda_ic = weight_history[-1]

banner = "=" * 60

# Header and run configuration
print(banner)
print("SA-PINN TRAINING SUMMARY")
print(banner)
print("Problem: Convection-Diffusion Equation")
print(f"Network Architecture: {[2, 32, 32, 32, 32, 1]}")
print(f"Training Epochs: {epochs:,}")
print(f"Training Points: {n_domain + n_boundary + n_initial:,}")
print()

# Loss values at the end of training
print("FINAL LOSS VALUES:")
print(
    f"  Total Loss: {final_total_loss:.4e}",
    f"  PDE Loss:   {final_pde_loss:.4e}",
    f"  BC Loss:    {final_bc_loss:.4e}",
    f"  IC Loss:    {final_ic_loss:.4e}",
    sep="\n",
)
print()

# Adaptive weights: the initial value shown is the fixed text 1.000
print("ADAPTIVE WEIGHT VALUES:")
print(
    f"  λ_pde (Initial → Final): 1.000 → {final_lambda_pde:.3f}",
    f"  λ_bc  (Initial → Final): 1.000 → {final_lambda_bc:.3f}",
    f"  λ_ic  (Initial → Final): 1.000 → {final_lambda_ic:.3f}",
    sep="\n",
)
print()

# Closing notes
print("KEY FEATURES OF SA-PINN:")
print(
    "✓ Automatic balance of loss components through learnable weights",
    "✓ No manual tuning of loss weight hyperparameters",
    "✓ Adaptive optimization balances competing physical constraints",
    "✓ Improved convergence for challenging PDEs with multiple scales",
    sep="\n",
)
print(banner)