# Variational Physics-Informed Neural Network (VPINN) - PyTorch Implementation

This notebook implements a Variational Physics-Informed Neural Network (VPINN) in PyTorch and uses it to solve a 1D Poisson equation through its weak formulation.

## Problem Definition:
- **PDE**: $-u_{xx}(x) = \pi^2 \sin(\pi x)$
- **Domain**: $x \in [-1, 1]$
- **Boundary Conditions**: $u(-1) = 0$, $u(1) = 0$
- **Analytical Solution**: $u(x) = \sin(\pi x)$

## Key VPINN Concept:
Instead of directly enforcing the PDE residual, VPINN uses the **weak form** of the PDE:
$$\int_{\Omega} u_x v_x \, dx = \int_{\Omega} f v \, dx$$
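where $v$ are test functions that vanish on the boundary, $v(-1) = v(1) = 0$. The weak form follows from multiplying the PDE by $v$ and integrating by parts; the boundary term $[u_x v]_{-1}^{1}$ drops out because $v$ is zero at both end points.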

In [None]:
from typing import Callable, List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim

# Plot styling: use seaborn's "whitegrid" theme for every figure below
sns.set_style("whitegrid")

# Run on the GPU when PyTorch reports one, otherwise stay on the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Fix the PyTorch and NumPy seeds so repeated runs draw the same numbers
torch.manual_seed(seed=42)
np.random.seed(seed=42)

## 1. Problem Setup and Analytical Solution

In [None]:
# Left and right end points of the 1D domain [x_min, x_max]
x_min = -1.0
x_max = 1.0

def source_function(x: torch.Tensor) -> torch.Tensor:
    """Evaluate the right-hand side f(x) = π² sin(πx) element-wise."""
    amplitude = torch.pi ** 2
    return amplitude * torch.sin(x * torch.pi)

def analytical_solution(x: torch.Tensor) -> torch.Tensor:
    """Evaluate the closed-form solution u(x) = sin(πx) element-wise."""
    phase = torch.pi * x
    return phase.sin()

def boundary_condition(x: torch.Tensor) -> torch.Tensor:
    """Return the boundary target for u(-1) = u(1) = 0.

    The result is a tensor of zeros with the shape, dtype and device of the
    first column of ``x``.
    """
    first_column = x[:, :1]
    return torch.zeros_like(first_column)

# Echo the problem statement so the cell output is self-describing
print(
    f"Domain: [{x_min}, {x_max}]",
    "PDE: -u_xx = π² sin(πx)",
    "Analytical solution: u(x) = sin(πx)",
    sep="\n",
)

## 2. Test Functions for Variational Formulation

Test functions $v(x)$ must vanish on the boundary (i.e., $v(-1) = v(1) = 0$) so that the boundary term produced by integration by parts drops out of the weak form.

In [None]:
class TestFunctions:
    """Polynomial test functions and their first derivatives for the weak form.

    Both test functions evaluate to zero at x = -1 and x = 1.
    """

    @staticmethod
    def v_0(x: torch.Tensor) -> torch.Tensor:
        """Even test function v₀(x) = 1 - x²."""
        return 1 - torch.pow(x, 2)

    @staticmethod
    def v_0_x(x: torch.Tensor) -> torch.Tensor:
        """First derivative of v₀: v₀'(x) = -2x."""
        return torch.mul(x, -2)

    @staticmethod
    def v_1(x: torch.Tensor) -> torch.Tensor:
        """Odd test function v₁(x) = x(1 - x²)."""
        bump = 1 - torch.pow(x, 2)
        return x * bump

    @staticmethod
    def v_1_x(x: torch.Tensor) -> torch.Tensor:
        """First derivative of v₁: v₁'(x) = 1 - 3x²."""
        three_x_squared = torch.pow(x, 2) * 3
        return 1 - three_x_squared

    @classmethod
    def get_test_functions(cls) -> List[Tuple[Callable, Callable]]:
        """Return the (function, derivative) pairs in the order v₀, v₁."""
        functions = (cls.v_0, cls.v_1)
        derivatives = (cls.v_0_x, cls.v_1_x)
        return list(zip(functions, derivatives))

# Sample the domain on a regular grid and fetch the (v, v') pairs to draw
x_plot = torch.linspace(x_min, x_max, 100).unsqueeze(1)
test_funcs = TestFunctions.get_test_functions()
x_plot_np = x_plot.numpy()  # shared x-axis data for both panels

plt.figure(figsize=(12, 4))

# Left panel: the test functions themselves
plt.subplot(1, 2, 1)
for i, (v, v_x) in enumerate(test_funcs):
    curve = v(x_plot).numpy()
    plt.plot(x_plot_np, curve, label=f'$v_{i}(x)$', linewidth=2)
plt.xlabel('x')
plt.ylabel('v(x)')
plt.title('Test Functions')
plt.legend()
plt.grid(True, alpha=0.3)

# Right panel: their first derivatives
plt.subplot(1, 2, 2)
for i, (v, v_x) in enumerate(test_funcs):
    slope = v_x(x_plot).numpy()
    plt.plot(x_plot_np, slope, label=f"$v'_{i}(x)$", linewidth=2)
plt.xlabel('x')
plt.ylabel("v'(x)")
plt.title('Test Function Derivatives')
plt.legend()
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"Number of test functions: {len(test_funcs)}")
print("Test functions satisfy boundary conditions: v(-1) = v(1) = 0")

## 3. Data Generation Functions

In [None]:
def generate_domain_points(n_points: int) -> torch.Tensor:
    """Draw random collocation points between x_min and x_max.

    Uniform samples from ``torch.rand`` are rescaled to the domain and returned
    as an ``(n_points, 1)`` tensor on ``device``.
    """
    unit_samples = torch.rand(n_points, 1)
    width = x_max - x_min
    return (x_min + width * unit_samples).to(device)

def generate_boundary_points(n_points: int) -> torch.Tensor:
    """Build an ``(n_points, 1)`` tensor of end points on ``device``.

    Rows holding x_min come first, followed by rows holding x_max. The left
    end receives ``n_points // 2`` copies and the right end the remainder, so
    an odd count places the extra point at x_max.
    """
    n_left = n_points // 2
    counts_and_values = ((n_left, x_min), (n_points - n_left, x_max))
    columns = [torch.full((count, 1), value) for count, value in counts_and_values]
    return torch.cat(columns, dim=0).to(device)

def generate_test_points(n_points: int) -> torch.Tensor:
    """Return ``n_points`` evenly spaced points from x_min to x_max (inclusive).

    The grid is shaped ``(n_points, 1)`` and placed on ``device``.
    """
    grid = torch.linspace(x_min, x_max, n_points)
    return grid[:, None].to(device)

## 4. Neural Network Architecture

In [None]:
class VPINN(nn.Module):
    """Fully connected tanh network used as the trial solution of the VPINN.

    Args:
        layers: Width of every layer, input width first and output width last,
            e.g. ``[1, 20, 20, 20, 1]``. ``None`` selects that default. At
            least two entries are required (one input and one output width).

    Raises:
        ValueError: If fewer than two layer widths are supplied.
    """

    def __init__(self, layers: Optional[List[int]] = None):
        super().__init__()

        # Build the default per call rather than sharing one mutable default
        # list between every instance.
        if layers is None:
            layers = [1, 20, 20, 20, 1]
        # Fail early: with fewer than two widths there is no Linear layer and
        # forward() would otherwise die later with an opaque IndexError.
        if len(layers) < 2:
            raise ValueError(
                f"layers must contain at least an input and an output width, got {layers}"
            )

        # One Linear map per consecutive pair of widths
        self.layers = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in zip(layers[:-1], layers[1:])]
        )

        # Initialize weights using Xavier normal
        self.init_weights()

        print(f"VPINN initialized with {len(layers)} layers: {layers}")

    def init_weights(self):
        """Re-initialize every Linear layer: Xavier-normal weights, zero biases."""
        for layer in self.layers:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_normal_(layer.weight)
                nn.init.zeros_(layer.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply tanh after every layer except the last, which stays linear."""
        for layer in self.layers[:-1]:
            x = torch.tanh(layer(x))
        # Final layer without activation
        return self.layers[-1](x)

## 5. Variational Loss Functions

The key innovation of VPINN is using the **weak form** of the PDE instead of the strong form. For our Poisson equation:

**Strong form**: $-u_{xx} = f$

**Weak form**: $\int_{\Omega} u_x v_x \, dx = \int_{\Omega} f v \, dx$ for all test functions $v$

In [None]:
def compute_derivative(u: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Differentiate ``u`` with respect to ``x`` using autograd.

    A tensor of ones is back-propagated through ``u``; this equals du/dx when
    every ``u[i]`` depends only on ``x[i]``. The graph is kept
    (``create_graph=True``) so the result can itself be differentiated.
    """
    ones = torch.ones_like(u)
    (u_x,) = torch.autograd.grad(
        outputs=u,
        inputs=x,
        grad_outputs=ones,
        create_graph=True,
        retain_graph=True,
    )
    return u_x

def variational_residual(
    model: VPINN,
    x: torch.Tensor,
    test_functions: List[Tuple[Callable, Callable]],
) -> torch.Tensor:
    """Return the summed squared weak-form residuals over all test functions.

    For each pair ``(v, v_x)`` the batch mean of ``u_x * v_x - f * v`` is the
    residual of that test function, where ``u = model(x)`` and ``f`` is the
    source term. The return value is the sum of the squared residuals, a
    scalar tensor.

    Note: ``x`` is switched to ``requires_grad=True`` in place so that du/dx
    can be obtained through autograd.
    """
    x.requires_grad_(True)
    u_x = compute_derivative(model(x), x)
    f_val = source_function(x)

    def weak_residual(v: Callable, v_x: Callable) -> torch.Tensor:
        # Batch mean of the weak-form integrand for one test function
        v_val = v(x)
        v_x_val = v_x(x)
        return (u_x * v_x_val - f_val * v_val).mean()

    per_function = [weak_residual(v, v_x) for v, v_x in test_functions]
    return torch.stack(per_function).pow(2).sum()

def compute_losses(
    model: VPINN,
    domain_points: torch.Tensor,
    boundary_points: torch.Tensor,
    test_functions: List[Tuple[Callable, Callable]],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return the pair ``(variational loss, boundary loss)``.

    The variational loss is the weak-form residual evaluated on
    ``domain_points``; the boundary loss is the mean squared difference
    between the network output at ``boundary_points`` and the prescribed
    boundary values.
    """
    # Weak-form (PDE) term on the interior points
    vpde_loss = variational_residual(model, domain_points, test_functions)

    # Mean squared mismatch against the boundary values
    bc_mismatch = model(boundary_points) - boundary_condition(boundary_points)
    bc_loss = bc_mismatch.pow(2).mean()

    return vpde_loss, bc_loss

## 6. Training Setup and Execution

In [None]:
# Network and the test functions that define the weak-form constraints.
# The model is created before any points are sampled, so the seeded random
# stream is consumed in the same order on every run.
model = VPINN([1, 20, 20, 20, 1]).to(device)
test_functions = TestFunctions.get_test_functions()

# Training-set sizes
n_domain = 500  # interior samples used for the Monte Carlo estimate
n_boundary = 2  # split evenly between the two ends of the interval

domain_points = generate_domain_points(n_domain)
boundary_points = generate_boundary_points(n_boundary)

print(
    "Training data generated:",
    f"  Domain points: {domain_points.shape}",
    f"  Boundary points: {boundary_points.shape}",
    f"  Test functions: {len(test_functions)}",
    sep="\n",
)

# Adam over all network parameters
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

# Training schedule: total number of epochs and how often to log
epochs = 15000
log_interval = 1000

# One [total, variational, boundary] row is appended per epoch
loss_history = []

In [None]:
# Training loop
print("\nStarting VPINN training...")
print("=" * 60)

# The mode never changes inside the loop, so it is set once up front.
model.train()

# Rows are collected in a cell-local list. Appending to `loss_history` and then
# rebinding it to an ndarray made this cell fail on a second run, because an
# ndarray has no append method.
history_rows = []

for epoch in range(epochs):
    optimizer.zero_grad()

    # Compute losses
    vpde_loss, bc_loss = compute_losses(model, domain_points, boundary_points, test_functions)
    total_loss = vpde_loss + bc_loss

    # Backward pass and parameter update
    total_loss.backward()
    optimizer.step()

    # Convert each loss to a Python float once and reuse it for both the
    # history row and the log line.
    total_val = total_loss.item()
    vpde_val = vpde_loss.item()
    bc_val = bc_loss.item()
    history_rows.append([total_val, vpde_val, bc_val])

    # Logging
    if (epoch + 1) % log_interval == 0:
        print(f"Epoch {epoch + 1:5d}/{epochs} | "
              f"Total: {total_val:.2e} | "
              f"VPDE: {vpde_val:.2e} | "
              f"BC: {bc_val:.2e}")

print("\nTraining completed!")

# History as an (n_epochs, 3) array with columns: total, variational, boundary.
# The reshape keeps the array two-dimensional even when no epoch was run.
loss_history = np.array(history_rows, dtype=float).reshape(-1, 3)

## 7. Results Analysis and Visualization

In [None]:
# Loss curves on a logarithmic y-axis: total loss on the left, components on the right
plt.figure(figsize=(12, 4))

# Left panel: column 0 of the history (total loss)
plt.subplot(1, 2, 1)
total_curve = loss_history[:, 0]
plt.plot(total_curve, 'b-', label='Total Loss', linewidth=2)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Total Loss History')
plt.yscale('log')
plt.grid(True, alpha=0.3)
plt.legend()

# Right panel: column 1 (variational loss) and column 2 (boundary loss)
plt.subplot(1, 2, 2)
component_curves = (
    (1, 'r-', 'Variational PDE Loss'),
    (2, 'g-', 'Boundary Loss'),
)
for column, line_format, curve_label in component_curves:
    plt.plot(loss_history[:, column], line_format, label=curve_label, linewidth=2)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Individual Loss Components')
plt.yscale('log')
plt.legend()
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Evaluate the trained network on an evenly spaced grid (no gradients needed)
model.eval()
with torch.no_grad():
    x_test = generate_test_points(100)
    u_pred = model(x_test).cpu().numpy()
    u_true = analytical_solution(x_test).cpu().numpy()
    x_test_np = x_test.cpu().numpy()

# Error measures: relative L2 norm and point-wise absolute error
l2_error = np.linalg.norm(u_true - u_pred) / np.linalg.norm(u_true)
error = np.abs(u_true - u_pred)

# Source term on the same grid, for the third panel
f_vals = source_function(x_test).cpu().numpy()

plt.figure(figsize=(15, 5))

# Panel 1: prediction against the analytical solution
plt.subplot(1, 3, 1)
plt.plot(x_test_np, u_true, 'b-', label='Analytical Solution', linewidth=3)
plt.plot(x_test_np, u_pred, 'r--', label='VPINN Prediction', linewidth=2)
plt.xlabel('x')
plt.ylabel('u(x)')
plt.title('VPINN vs. Analytical Solution')
plt.legend()
plt.grid(True, alpha=0.3)

# Panel 2: absolute error, with the relative L2 error in the title
plt.subplot(1, 3, 2)
plt.plot(x_test_np, error, 'g-', linewidth=2)
plt.xlabel('x')
plt.ylabel('|Error|')
plt.title(f'Absolute Error\nL2 Relative Error: {l2_error:.2e}')
plt.grid(True, alpha=0.3)

# Panel 3: source term next to the analytical solution
plt.subplot(1, 3, 3)
plt.plot(x_test_np, f_vals, 'm-', label='Source f(x)', linewidth=2)
plt.plot(x_test_np, u_true, 'b-', label='Solution u(x)', linewidth=2)
plt.xlabel('x')
plt.ylabel('Value')
plt.title('Source Function and Solution')
plt.legend()
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 8. Variational Form Verification

Let's verify that our solution satisfies the weak form by checking the residuals for each test function.

In [None]:
# Verify weak form residuals
model.eval()

# Use a dense set of points for accurate integration
x_verify = generate_test_points(1000)
x_verify.requires_grad_(True)

# The forward pass and the derivative must run with gradient tracking ON:
# inside torch.no_grad() the output carries no graph and torch.autograd.grad
# raises a RuntimeError. The derivative is detached once it has been computed.
u = model(x_verify)
u_x = compute_derivative(u, x_verify).detach()

# Everything below is plain evaluation, so gradient tracking is switched off.
with torch.no_grad():
    f_val = source_function(x_verify)

    print("Verification of Weak Form Residuals:")
    print("=" * 40)

    for i, (v, v_x) in enumerate(test_functions):
        v_val = v(x_verify)
        v_x_val = v_x(x_verify)

        # Weak-form residual estimated as the mean of (u_x * v_x - f * v) over
        # the grid, i.e. the same normalisation the training loss uses.
        integrand = u_x * v_x_val - f_val * v_val
        residual = torch.mean(integrand).item()

        print(f"Test function v_{i}: Residual = {residual:.2e}")

    print("\nSmall residuals indicate good satisfaction of the weak form!")

## 9. Comparison: Strong Form vs. Weak Form

Let's compare the residuals when evaluated using both strong and weak formulations.

In [None]:
# Compare strong form vs weak form residuals
model.eval()

x_comp = generate_test_points(100)
x_comp.requires_grad_(True)

# Derivatives need gradient tracking ON: inside torch.no_grad() the network
# output has no graph and torch.autograd.grad raises a RuntimeError.
u = model(x_comp)
u_x = compute_derivative(u, x_comp)  # keeps its graph so it can be differentiated again
u_xx = torch.autograd.grad(u_x, x_comp, torch.ones_like(u_x))[0]

# Pure evaluation from here on; results computed under no_grad carry no graph
# and can be converted to NumPy directly.
with torch.no_grad():
    # Strong form residual: -u_xx - f
    f_val = source_function(x_comp)
    strong_residual = -u_xx - f_val

    # Weak form integrand for the first test function
    v, v_x = test_functions[0]
    v_val = v(x_comp)
    v_x_val = v_x(x_comp)
    weak_integrand = u_x * v_x_val - f_val * v_val

# x_comp still requires grad, so it must be detached before .numpy()
x_comp_np = x_comp.detach().cpu().numpy()

plt.figure(figsize=(12, 4))

# Strong form residual
plt.subplot(1, 2, 1)
plt.plot(x_comp_np, strong_residual.cpu().numpy(), 'r-', linewidth=2)
plt.xlabel('x')
plt.ylabel('Residual')
plt.title('Strong Form Residual: -u_xx - f')
plt.grid(True, alpha=0.3)

# Weak form verification for first test function
plt.subplot(1, 2, 2)
plt.plot(x_comp_np, weak_integrand.cpu().numpy(), 'b-', linewidth=2)
plt.xlabel('x')
plt.ylabel('Integrand')
plt.title('Weak Form Integrand: u_x*v_x - f*v')
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"Strong form RMS residual: {torch.sqrt(torch.mean(strong_residual**2)).item():.2e}")
print(f"Weak form integral residual: {torch.mean(weak_integrand).item():.2e}")

## 10. Final Summary

In [None]:
# Final analysis: the last row of the history holds [total, variational, boundary]
final_total_loss = loss_history[-1, 0]
final_vpde_loss = loss_history[-1, 1]
final_bc_loss = loss_history[-1, 2]

# Read the layer widths from the trained model instead of repeating a literal,
# so the report cannot drift from the network that was actually trained.
architecture = [model.layers[0].in_features] + [layer.out_features for layer in model.layers]

print("=" * 70)
print("VARIATIONAL PINN (VPINN) TRAINING SUMMARY")
print("=" * 70)
print("Problem: 1D Poisson Equation with Variational Formulation")
print(f"Network Architecture: {architecture}")
print(f"Training Epochs: {epochs:,}")
print(f"Domain Points: {n_domain:,} (for Monte Carlo integration)")
print(f"Test Functions: {len(test_functions)}")
print()
print("FINAL RESULTS:")
print(f"  Total Loss: {final_total_loss:.4e}")
print(f"  Variational PDE Loss: {final_vpde_loss:.4e}")
print(f"  Boundary Loss: {final_bc_loss:.4e}")
print(f"  L2 Relative Error: {l2_error:.4e}")
print()
print("KEY FEATURES OF VPINN:")
print("✓ Uses weak formulation instead of strong form PDE residual")
print("✓ Employs test functions that satisfy boundary conditions")
print("✓ Monte Carlo integration for variational integrals")
print("✓ Better numerical stability for certain PDE types")
print("✓ Natural incorporation of boundary conditions in weak form")
print("✓ Can handle problems where strong form is difficult to enforce")
print("=" * 70)