# 4.2: Two-Stage Optimization Strategy - PyTorch Implementation

This notebook demonstrates a two-stage optimization strategy (Adam followed by L-BFGS) for training Physics-Informed Neural Networks (PINNs) in pure PyTorch. We solve a simple ODE and compare several optimization approaches.

### Problem Definition:
- **ODE**: $\frac{dy}{dx} + y = 0$
- **Initial Condition**: $y(0) = 1$
- **Domain**: $x \in [0, 5]$
- **Analytical Solution**: $y(x) = e^{-x}$
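
For reference, the analytical solution follows from separation of variables (a short derivation, with $C$ and $A$ as integration constants introduced here):

$$
\frac{dy}{dx} = -y
\;\Rightarrow\; \frac{dy}{y} = -dx
\;\Rightarrow\; \ln|y| = -x + C
\;\Rightarrow\; y(x) = A e^{-x},
\qquad y(0) = 1 \;\Rightarrow\; A = 1 .
$$

Substituting back gives $\frac{dy}{dx} + y = -e^{-x} + e^{-x} = 0$, so $y(x) = e^{-x}$ satisfies both the ODE and the initial condition.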

### Optimization Strategies to Compare:
1. **Adam Only**: Standard adaptive optimizer
2. **L-BFGS Only**: Quasi-Newton method
3. **Two-Stage**: Adam (exploration) + L-BFGS (refinement)
4. **Adaptive Two-Stage**: Adam with learning rate scheduling, switching to L-BFGS on a loss threshold or a plateau

### Key Insights:
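- Adam is robust in the early phase, when the iterate is far from a minimum and the loss landscape is poorly conditioned
- L-BFGS uses curvature information and typically reaches a lower final loss once the iterate is near a minimum
- The two-stage approach uses Adam to reach that region and L-BFGS to refine the result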
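
The hand-off between the two optimizers can be sketched on a toy problem. The block below is a minimal illustration, not the notebook's training code: the least-squares objective, the names `A`, `b`, `w`, and the `strong_wolfe` line search are assumptions chosen for the example (the notebook's L-BFGS wrapper uses a fixed `lr=1` without a line search).

```python
import torch

torch.manual_seed(0)

# Toy least-squares objective (hypothetical stand-in for the PINN loss)
A = torch.randn(50, 5)
b = torch.randn(50, 1)
w = torch.zeros(5, 1, requires_grad=True)

def loss_fn():
    return torch.mean((A @ w - b) ** 2)

# Stage 1: a short Adam run to move into a reasonable region
adam = torch.optim.Adam([w], lr=1e-2)
for _ in range(200):
    adam.zero_grad()
    loss = loss_fn()
    loss.backward()
    adam.step()
loss_after_adam = loss_fn().item()

# Stage 2: L-BFGS refinement, starting from the Adam result
lbfgs = torch.optim.LBFGS([w], lr=1.0, max_iter=50, line_search_fn="strong_wolfe")

def closure():
    # L-BFGS re-evaluates the loss itself, so it needs a closure
    lbfgs.zero_grad()
    loss = loss_fn()
    loss.backward()
    return loss

lbfgs.step(closure)
loss_after_lbfgs = loss_fn().item()

print(f"after Adam:   {loss_after_adam:.6f}")
print(f"after L-BFGS: {loss_after_lbfgs:.6f}")
```

Both optimizers act on the same parameter tensor, so the second stage simply continues from wherever the first one stopped.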

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import time
from scipy.optimize import minimize
from copy import deepcopy

# Set device and style
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
sns.set_style("whitegrid")
plt.rcParams['figure.figsize'] = (12, 8)

# Set seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)

## 1. Neural Network and Problem Setup
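
The setup in this section builds the ODE residual from `torch.autograd.grad`. As a minimal sketch of that mechanism, the block below uses the known solution $e^{-x}$ as a stand-in for the network output (an assumption made only for this example) and checks that the residual $\frac{dy}{dx} + y$ vanishes.

```python
import torch

# Sample points; requires_grad lets autograd differentiate with respect to x
x = torch.linspace(0, 5, 11).view(-1, 1).requires_grad_(True)

# Stand-in for the network output: the known solution y = exp(-x)
y = torch.exp(-x)

# Each y[i] depends only on x[i], so a grad_outputs tensor of ones
# returns the pointwise derivative dy/dx at every sample
dy_dx = torch.autograd.grad(
    y, x, grad_outputs=torch.ones_like(y), create_graph=True
)[0]

residual = dy_dx + y
print(f"max |residual| = {residual.abs().max().item():.2e}")
```

`create_graph=True` keeps the derivative differentiable, which is what allows a loss built from `dy_dx` to be backpropagated to the network weights.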

In [None]:
class PINN(nn.Module):
    """Physics-Informed Neural Network for ODE solving"""
    
    def __init__(self, hidden_dim=32, num_layers=3):
        super(PINN, self).__init__()
        
        layers = []
        layers.append(nn.Linear(1, hidden_dim))  # Input: x
        
        for _ in range(num_layers):
            layers.append(nn.Tanh())
            layers.append(nn.Linear(hidden_dim, hidden_dim))
        
        layers.append(nn.Tanh())
        layers.append(nn.Linear(hidden_dim, 1))  # Output: y
        
        self.network = nn.Sequential(*layers)
        self.init_weights()
    
    def init_weights(self):
        """Initialize weights using Xavier/Glorot normal initialization"""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_normal_(module.weight)
                nn.init.zeros_(module.bias)
    
    def forward(self, x):
        return self.network(x)
    
    def get_weights_vector(self):
        """Get all parameters as a single vector (for L-BFGS)"""
        return torch.cat([p.view(-1) for p in self.parameters()])
    
    def set_weights_vector(self, weights):
        """Set all parameters from a single vector (for L-BFGS)"""
        start = 0
        with torch.no_grad():
            for param in self.parameters():
                end = start + param.numel()
                # Copy in place so each parameter keeps its own storage
                param.copy_(weights[start:end].view(param.shape))
                start = end
    
    def clone_weights(self):
        """Create a detached copy of the current weights"""
        return {name: param.detach().clone() for name, param in self.named_parameters()}
    
    def load_weights(self, weight_dict):
        """Load weights from a dictionary"""
        with torch.no_grad():
            for name, param in self.named_parameters():
                # Copy in place so optimizers holding these parameters stay valid
                param.copy_(weight_dict[name])

def analytical_solution(x):
    """Analytical solution: y(x) = exp(-x)"""
    return torch.exp(-x)

def compute_loss(model, x_domain, x0, y0):
    """Compute the composite PINN loss"""
    # Enable gradients for automatic differentiation
    x_domain.requires_grad_(True)
    
    # ODE residual loss: dy/dx + y = 0
    y_pred = model(x_domain)
    dy_dx = torch.autograd.grad(
        y_pred, x_domain,
        grad_outputs=torch.ones_like(y_pred),
        create_graph=True
    )[0]
    
    ode_residual = dy_dx + y_pred
    loss_ode = torch.mean(ode_residual**2)
    
    # Initial condition loss: y(0) = 1
    y0_pred = model(x0)
    loss_ic = torch.mean((y0_pred - y0)**2)
    
    # Total loss
    total_loss = loss_ode + loss_ic
    
    return total_loss, loss_ode, loss_ic

# Problem setup
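# Create the points on the target device first, then enable gradients,
# so that x_domain is a leaf tensor on every device
x_domain = torch.linspace(0, 5, 100, device=device).view(-1, 1).requires_grad_(True)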
x0 = torch.tensor([[0.0]], device=device)
y0 = torch.tensor([[1.0]], device=device)

print(f"Domain points: {x_domain.shape}")
print(f"Initial condition: y({x0.item():.1f}) = {y0.item():.1f}")

## 2. Optimization Strategy Classes
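
One detail matters for the L-BFGS wrapper in this section: a single `optimizer.step(closure)` call may evaluate the closure several times, up to `max_eval` (which defaults to `max_iter * 1.25`). The sketch below counts those evaluations on a toy quadratic; the objective and the `calls` counter are assumptions for illustration only.

```python
import torch

w = torch.tensor([3.0, -2.0], requires_grad=True)
opt = torch.optim.LBFGS([w], lr=1.0, max_iter=20)

calls = {"n": 0}  # mutable counter updated inside the closure

def closure():
    calls["n"] += 1
    opt.zero_grad()
    loss = torch.sum((w - 1.0) ** 2)  # toy quadratic with minimum at w = (1, 1)
    loss.backward()
    return loss

opt.step(closure)  # one step() call, possibly many closure evaluations
print(f"closure evaluations in one step(): {calls['n']}")
print(f"w after one step(): {w.detach().tolist()}")
```

Because the tracker records a step inside the closure, the iteration counts reported for L-BFGS phases are closure evaluations, not `step()` calls.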

In [None]:
class OptimizationTracker:
    """Track optimization progress and metrics"""
    
    def __init__(self, model, x_test=None):
        self.model = model
        self.x_test = x_test if x_test is not None else torch.linspace(0, 5, 200).view(-1, 1).to(device)
        self.history = []
        self.start_time = None
        self.phase_times = []
    
    def start_tracking(self, phase_name):
        """Start tracking a new optimization phase"""
        self.start_time = time.time()
        self.current_phase = phase_name
    
    def record_step(self, epoch, total_loss, ode_loss, ic_loss, optimizer_name=None):
        """Record a single optimization step"""
        with torch.no_grad():
            # Compute L2 error against analytical solution
            y_pred = self.model(self.x_test)
            y_true = analytical_solution(self.x_test)
            l2_error = torch.sqrt(torch.mean((y_pred - y_true)**2)).item()
            l2_relative = l2_error / torch.sqrt(torch.mean(y_true**2)).item()
        
        elapsed_time = time.time() - self.start_time if self.start_time else 0
        
        self.history.append({
            'epoch': epoch,
            'phase': getattr(self, 'current_phase', 'unknown'),
            'optimizer': optimizer_name or getattr(self, 'current_phase', 'unknown'),
            'total_loss': total_loss,
            'ode_loss': ode_loss,
            'ic_loss': ic_loss,
            'l2_error': l2_error,
            'l2_relative': l2_relative,
            'time': elapsed_time
        })
    
    def end_phase(self):
        """End current optimization phase"""
        if self.start_time:
            phase_time = time.time() - self.start_time
            self.phase_times.append(phase_time)
            self.start_time = None
            return phase_time
        return 0
    
    def get_final_metrics(self):
        """Get final performance metrics"""
        if not self.history:
            return {}
        
        final_record = self.history[-1]
        return {
            'final_loss': final_record['total_loss'],
            'final_l2_error': final_record['l2_error'],
            'final_l2_relative': final_record['l2_relative'],
            'total_epochs': len(self.history),
            'total_time': sum(self.phase_times)
        }

class AdamOptimizer:
    """Adam optimizer implementation"""
    
    def __init__(self, model, lr=1e-3, betas=(0.9, 0.999), eps=1e-8):
        self.model = model
        self.optimizer = optim.Adam(model.parameters(), lr=lr, betas=betas, eps=eps)
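        # `verbose` is deprecated in recent PyTorch releases, so it is omitted;
        # the current learning rate is printed in the training log instead
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='min', factor=0.8, patience=1000
        )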
    
    def train(self, x_domain, x0, y0, epochs=10000, print_every=1000, tracker=None):
        """Train with Adam optimizer"""
        if tracker:
            tracker.start_tracking('Adam')
        
        print(f"Starting Adam training for {epochs} epochs...")
        
        for epoch in range(epochs):
            self.model.train()
            self.optimizer.zero_grad()
            
            total_loss, ode_loss, ic_loss = compute_loss(self.model, x_domain, x0, y0)
            total_loss.backward()
            self.optimizer.step()
            self.scheduler.step(total_loss)
            
            if tracker:
                tracker.record_step(epoch, total_loss.item(), ode_loss.item(), 
                                  ic_loss.item(), 'Adam')
            
            if (epoch + 1) % print_every == 0:
                lr = self.optimizer.param_groups[0]['lr']
                print(f"  Epoch {epoch+1:5d} | Loss: {total_loss.item():.2e} | "
                      f"ODE: {ode_loss.item():.2e} | IC: {ic_loss.item():.2e} | LR: {lr:.1e}")
        
        if tracker:
            phase_time = tracker.end_phase()
            print(f"Adam training completed in {phase_time:.1f}s")
        
        return tracker.history if tracker else []

class LBFGSOptimizer:
    """L-BFGS optimizer implementation using PyTorch's L-BFGS"""
    
    def __init__(self, model, lr=1, max_iter=20, tolerance_grad=1e-7, tolerance_change=1e-9):
        self.model = model
        self.optimizer = optim.LBFGS(
            model.parameters(),
            lr=lr,
            max_iter=max_iter,
            tolerance_grad=tolerance_grad,
            tolerance_change=tolerance_change,
            history_size=100
        )
        self.current_epoch = 0
        self.tracker = None
    
    def train(self, x_domain, x0, y0, max_iterations=1000, print_every=50, tracker=None):
        """Train with L-BFGS optimizer"""
        self.tracker = tracker
        self.current_epoch = 0
        
        if tracker:
            tracker.start_tracking('L-BFGS')
        
        print(f"Starting L-BFGS training for up to {max_iterations} iterations...")
        
        def closure():
            self.optimizer.zero_grad()
            total_loss, ode_loss, ic_loss = compute_loss(self.model, x_domain, x0, y0)
            total_loss.backward()
            
            if self.tracker:
                self.tracker.record_step(self.current_epoch, total_loss.item(), 
                                       ode_loss.item(), ic_loss.item(), 'L-BFGS')
            
            if self.current_epoch % print_every == 0:
                print(f"  Iter {self.current_epoch:4d} | Loss: {total_loss.item():.2e} | "
                      f"ODE: {ode_loss.item():.2e} | IC: {ic_loss.item():.2e}")
            
            self.current_epoch += 1
            return total_loss
        
        # Run L-BFGS optimization.
        # Each step() call runs up to max_iter inner L-BFGS iterations.
        max_iter = self.optimizer.param_groups[0]['max_iter']
        for i in range(max(1, max_iterations // max_iter)):
            self.optimizer.step(closure)
            
            # Check for convergence. compute_loss differentiates with respect
            # to x, so it must run with autograd enabled (not under no_grad).
            current_loss, _, _ = compute_loss(self.model, x_domain, x0, y0)
            if current_loss.item() < 1e-10:
                print(f"  Early convergence achieved at iteration {self.current_epoch}")
                break
        
        if tracker:
            phase_time = tracker.end_phase()
            print(f"L-BFGS training completed in {phase_time:.1f}s")
        
        return tracker.history if tracker else []

class TwoStageOptimizer:
    """Two-stage optimization: Adam followed by L-BFGS"""
    
    def __init__(self, model, adam_lr=1e-3, lbfgs_lr=1):
        self.model = model
        self.adam_optimizer = AdamOptimizer(model, lr=adam_lr)
        self.lbfgs_optimizer = LBFGSOptimizer(model, lr=lbfgs_lr)
    
    def train(self, x_domain, x0, y0, adam_epochs=10000, lbfgs_iterations=500, 
              print_every=1000, tracker=None):
        """Two-stage training"""
        print("=== TWO-STAGE OPTIMIZATION ===")
        print("Stage 1: Adam optimization (exploration phase)")
        
        # Stage 1: Adam
        self.adam_optimizer.train(x_domain, x0, y0, adam_epochs, print_every, tracker)
        
        print("\nStage 2: L-BFGS optimization (refinement phase)")
        
        # Stage 2: L-BFGS
        self.lbfgs_optimizer.train(x_domain, x0, y0, lbfgs_iterations, 
                                 max(print_every//20, 10), tracker)
        
        print("Two-stage optimization completed!")
        
        return tracker.history if tracker else []

print("Optimization classes initialized!")

## 3. Comparative Study: Different Optimization Strategies

In [None]:
def run_optimization_comparison():
    """Run comprehensive comparison of optimization strategies"""
    
    results = {}
    
    # Test configurations
    strategies = {
        'Adam Only': {
            'class': AdamOptimizer,
            'params': {'lr': 1e-3},
            'train_params': {'epochs': 15000, 'print_every': 2000}
        },
        'L-BFGS Only': {
            'class': LBFGSOptimizer,
            'params': {'lr': 1, 'max_iter': 20},
            'train_params': {'max_iterations': 500, 'print_every': 50}
        },
        'Two-Stage': {
            'class': TwoStageOptimizer,
            'params': {'adam_lr': 1e-3, 'lbfgs_lr': 1},
            'train_params': {'adam_epochs': 10000, 'lbfgs_iterations': 300, 'print_every': 2000}
        }
    }
    
    print("Starting Optimization Strategy Comparison...")
    print("=" * 60)
    
    for strategy_name, config in strategies.items():
        print(f"\n\nTesting Strategy: {strategy_name}")
        print("-" * 40)
        
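        # Create a fresh model for each strategy, re-seeding first so that
        # every strategy starts from the same initial weights
        torch.manual_seed(42)
        model = PINN(hidden_dim=32, num_layers=3).to(device)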
        
        # Initialize tracker
        tracker = OptimizationTracker(model)
        
        # Create optimizer
        optimizer = config['class'](model, **config['params'])
        
        # Train
        start_time = time.time()
        history = optimizer.train(x_domain, x0, y0, **config['train_params'], tracker=tracker)
        total_time = time.time() - start_time
        
        # Collect results
        final_metrics = tracker.get_final_metrics()
        final_metrics['strategy'] = strategy_name
        final_metrics['history'] = history
        final_metrics['model_state'] = model.clone_weights()
        
        results[strategy_name] = final_metrics
        
        print(f"\nFinal Results for {strategy_name}:")
        print(f"  Final Loss: {final_metrics['final_loss']:.2e}")
        print(f"  L2 Error: {final_metrics['final_l2_error']:.2e}")
        print(f"  L2 Relative: {final_metrics['final_l2_relative']:.2e}")
        print(f"  Total Time: {final_metrics['total_time']:.1f}s")
        print(f"  Total Epochs/Iterations: {final_metrics['total_epochs']}")
    
    return results

# Run the comparison
comparison_results = run_optimization_comparison()

## 4. Detailed Analysis and Visualization
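
The two-stage history restarts its `epoch` counter at 0 when the L-BFGS phase begins, so plots need a global step axis. A minimal sketch of building one, using a hypothetical four-record history in the same shape as the tracker's records:

```python
# Hypothetical two-phase history (values are illustrative, not measured)
history = [
    {"epoch": 0, "optimizer": "Adam", "total_loss": 1.0},
    {"epoch": 1, "optimizer": "Adam", "total_loss": 0.5},
    {"epoch": 0, "optimizer": "L-BFGS", "total_loss": 0.1},
    {"epoch": 1, "optimizer": "L-BFGS", "total_loss": 0.01},
]

# Use the position in the history as a global step so the axis never restarts
steps = list(range(len(history)))
losses = [h["total_loss"] for h in history]

# First global step that belongs to the L-BFGS phase (None if there is none)
transition = next(
    (i for i, h in enumerate(history) if h["optimizer"] == "L-BFGS"), None
)

print(steps, transition)
```

Plotting `losses` against `steps` and marking `transition` gives one continuous curve per strategy.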

In [None]:
def plot_optimization_comparison(results):
    """Create comprehensive visualization of optimization comparison"""
    
    fig, axes = plt.subplots(2, 3, figsize=(20, 12))
    
    # Colors for different strategies
    colors = {'Adam Only': 'blue', 'L-BFGS Only': 'red', 'Two-Stage': 'green'}
    
    # 1. Loss convergence
    ax = axes[0, 0]
    for strategy, result in results.items():
        history = result['history']
        epochs = [h['epoch'] for h in history]
        losses = [h['total_loss'] for h in history]
        
        # Handle two-stage plotting
        if strategy == 'Two-Stage':
            # Split into Adam and L-BFGS phases
            adam_history = [h for h in history if h['optimizer'] == 'Adam']
            lbfgs_history = [h for h in history if h['optimizer'] == 'L-BFGS']
            
            if adam_history:
                adam_epochs = [h['epoch'] for h in adam_history]
                adam_losses = [h['total_loss'] for h in adam_history]
                ax.plot(adam_epochs, adam_losses, color=colors[strategy], linewidth=2, alpha=0.7)
            
            if lbfgs_history:
                # Continue epoch counting from Adam phase
                offset = len(adam_history) if adam_history else 0
                lbfgs_epochs = [h['epoch'] + offset for h in lbfgs_history]
                lbfgs_losses = [h['total_loss'] for h in lbfgs_history]
                ax.plot(lbfgs_epochs, lbfgs_losses, color=colors[strategy], linewidth=3, 
                       linestyle='--', label=f'{strategy} (L-BFGS phase)')
                
                # Mark transition point
                if adam_history:
                    transition_x = len(adam_history)
                    transition_y = adam_losses[-1]
                    ax.axvline(x=transition_x, color=colors[strategy], linestyle=':', alpha=0.5)
                    ax.plot(transition_x, transition_y, 'o', color=colors[strategy], 
                           markersize=8, label=f'{strategy} (transition)')
        else:
            ax.plot(epochs, losses, color=colors[strategy], linewidth=2, label=strategy)
    
    ax.set_yscale('log')
    ax.set_xlabel('Iteration')
    ax.set_ylabel('Total Loss')
    ax.set_title('Loss Convergence Comparison')
    ax.legend()
    ax.grid(True, alpha=0.3)
    
    # 2. L2 Error evolution
    ax = axes[0, 1]
    for strategy, result in results.items():
        history = result['history']
        # Use the position in the history as the x-axis: recorded epochs
        # restart at 0 in the L-BFGS phase of the two-stage run
        steps = range(len(history))
        l2_errors = [h['l2_error'] for h in history]
        ax.plot(steps, l2_errors, color=colors[strategy], linewidth=2, label=strategy)
    
    ax.set_yscale('log')
    ax.set_xlabel('Iteration')
    ax.set_ylabel('L2 Error')
    ax.set_title('L2 Error Evolution')
    ax.legend()
    ax.grid(True, alpha=0.3)
    
    # 3. Final performance comparison
    ax = axes[0, 2]
    strategies = list(results.keys())
    final_losses = [results[s]['final_loss'] for s in strategies]
    final_l2_errors = [results[s]['final_l2_error'] for s in strategies]
    
    x_pos = np.arange(len(strategies))
    width = 0.35
    
    ax.bar(x_pos - width/2, np.log10(final_losses), width, 
           label='log₁₀(Final Loss)', alpha=0.7, color='lightblue')
    ax.bar(x_pos + width/2, np.log10(final_l2_errors), width, 
           label='log₁₀(L2 Error)', alpha=0.7, color='lightcoral')
    
    ax.set_xlabel('Strategy')
    ax.set_ylabel('log₁₀(Error)')
    ax.set_title('Final Performance Comparison')
    ax.set_xticks(x_pos)
    ax.set_xticklabels(strategies, rotation=45)
    ax.legend()
    ax.grid(True, alpha=0.3)
    
    # 4. Solution accuracy visualization
    ax = axes[1, 0]
    x_test = torch.linspace(0, 5, 200).view(-1, 1).to(device)
    y_true = analytical_solution(x_test).cpu().numpy()
    
    ax.plot(x_test.cpu().numpy(), y_true, 'k-', linewidth=3, label='Analytical', alpha=0.8)
    
    for strategy, result in results.items():
        # Create a temporary model to load weights
        temp_model = PINN(hidden_dim=32, num_layers=3).to(device)
        temp_model.load_weights(result['model_state'])
        
        with torch.no_grad():
            y_pred = temp_model(x_test).cpu().numpy()
        
        ax.plot(x_test.cpu().numpy(), y_pred, color=colors[strategy], 
               linewidth=2, linestyle='--', label=f'{strategy}', alpha=0.8)
    
    ax.set_xlabel('x')
    ax.set_ylabel('y(x)')
    ax.set_title('Solution Comparison')
    ax.legend()
    ax.grid(True, alpha=0.3)
    
    # 5. Training time comparison
    ax = axes[1, 1]
    training_times = [results[s]['total_time'] for s in strategies]
    bars = ax.bar(strategies, training_times, color=[colors[s] for s in strategies], alpha=0.7)
    
    # Add time labels on bars
    for bar, time_val in zip(bars, training_times):
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height + height*0.01,
                f'{time_val:.1f}s', ha='center', va='bottom', fontweight='bold')
    
    ax.set_ylabel('Training Time (seconds)')
    ax.set_title('Training Time Comparison')
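    # Rotate the existing tick labels; set_xticklabels without set_xticks
    # triggers a warning in recent Matplotlib versions
    ax.tick_params(axis='x', rotation=45)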
    ax.grid(True, alpha=0.3)
    
    # 6. Efficiency analysis (accuracy vs time)
    ax = axes[1, 2]
    for strategy, result in results.items():
        x_val = result['total_time']
        y_val = -np.log10(result['final_l2_error'])  # Higher is better (lower error)
        ax.scatter(x_val, y_val, color=colors[strategy], s=200, alpha=0.7, 
                  edgecolors='black', linewidth=2, label=strategy)
        
        # Add strategy labels
        ax.annotate(strategy, (x_val, y_val), xytext=(5, 5), 
                   textcoords='offset points', fontsize=10, fontweight='bold')
    
    ax.set_xlabel('Training Time (seconds)')
    ax.set_ylabel('-log₁₀(L2 Error) [Higher = Better]')
    ax.set_title('Efficiency Analysis: Accuracy vs Time')
    ax.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()

# Create comprehensive visualization
plot_optimization_comparison(comparison_results)

## 5. Detailed Performance Analysis
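
The convergence analysis in this section asks when the loss first drops below a set of thresholds. A compact sketch of that lookup, with a hypothetical helper `first_step_below` and an illustrative loss trace (not measured data):

```python
def first_step_below(losses, threshold):
    """Return the index of the first loss strictly below threshold, or None."""
    return next((i for i, loss in enumerate(losses) if loss < threshold), None)

# Illustrative loss trace
losses = [1.0, 3e-2, 5e-3, 8e-5, 2e-5]

crossings = {t: first_step_below(losses, t) for t in (1e-2, 1e-4, 1e-6)}
for threshold, step in crossings.items():
    status = f"step {step}" if step is not None else "not achieved"
    print(f"Loss < {threshold:.0e}: {status}")
```

Returning `None` for an unreached threshold keeps "never converged" distinct from "converged at step 0".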

In [None]:
def analyze_optimization_performance(results):
    """Detailed analysis of optimization performance"""
    
    print("\n" + "="*80)
    print("                   OPTIMIZATION STRATEGY ANALYSIS")
    print("="*80)
    
    # Create summary table
    print(f"\n{'Strategy':<15} {'Final Loss':<12} {'L2 Error':<12} {'L2 Relative':<12} {'Time (s)':<10} {'Epochs':<8}")
    print("-" * 80)
    
    for strategy, result in results.items():
        print(f"{strategy:<15} {result['final_loss']:<12.2e} "
              f"{result['final_l2_error']:<12.2e} {result['final_l2_relative']:<12.2e} "
              f"{result['total_time']:<10.1f} {result['total_epochs']:<8}")
    
    # Find best performer in each category
    best_accuracy = min(results.items(), key=lambda x: x[1]['final_l2_error'])
    best_speed = min(results.items(), key=lambda x: x[1]['total_time'])
    best_efficiency = min(results.items(), key=lambda x: x[1]['final_l2_error'] * x[1]['total_time'])
    
    print(f"\n{'PERFORMANCE RANKINGS:':<30}")
    print(f"{'Best Accuracy:':<20} {best_accuracy[0]} (L2 Error: {best_accuracy[1]['final_l2_error']:.2e})")
    print(f"{'Fastest Training:':<20} {best_speed[0]} ({best_speed[1]['total_time']:.1f}s)")
    print(f"{'Best Efficiency:':<20} {best_efficiency[0]} (error×time: {best_efficiency[1]['final_l2_error'] * best_efficiency[1]['total_time']:.2e})")
    
    # Convergence analysis
    print(f"\n{'CONVERGENCE ANALYSIS:':<30}")
    for strategy, result in results.items():
        history = result['history']
        if len(history) < 10:
            continue
            
        # Find when loss drops below certain thresholds
        thresholds = [1e-2, 1e-4, 1e-6]
        convergence_epochs = {}
        
        for threshold in thresholds:
            for i, h in enumerate(history):
                if h['total_loss'] < threshold:
                    convergence_epochs[threshold] = i
                    break
            else:
                convergence_epochs[threshold] = None
        
        print(f"\n{strategy}:")
        for threshold, epoch in convergence_epochs.items():
            if epoch is not None:
                print(f"  Loss < {threshold:.0e}: epoch {epoch}")
            else:
                print(f"  Loss < {threshold:.0e}: not achieved")
    
    # Two-stage analysis
    if 'Two-Stage' in results:
        print(f"\n{'TWO-STAGE DETAILED ANALYSIS:':<30}")
        two_stage_history = results['Two-Stage']['history']
        
        adam_history = [h for h in two_stage_history if h['optimizer'] == 'Adam']
        lbfgs_history = [h for h in two_stage_history if h['optimizer'] == 'L-BFGS']
        
        if adam_history and lbfgs_history:
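            adam_final_loss = adam_history[-1]['total_loss']
            # Tracker times restart at zero when each phase starts, so the last
            # record of a phase already holds that phase's duration
            adam_time = adam_history[-1]['time']
            lbfgs_improvement = adam_final_loss / lbfgs_history[-1]['total_loss']
            lbfgs_time = lbfgs_history[-1]['time']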
            
            print(f"  Adam Phase:")
            print(f"    Final loss: {adam_final_loss:.2e}")
            print(f"    Time: {adam_time:.1f}s")
            print(f"    Epochs: {len(adam_history)}")
            
            print(f"  L-BFGS Phase:")
            print(f"    Final loss: {lbfgs_history[-1]['total_loss']:.2e}")
            print(f"    Improvement factor: {lbfgs_improvement:.1f}x")
            print(f"    Time: {lbfgs_time:.1f}s")
            print(f"    Iterations: {len(lbfgs_history)}")
            
            print(f"  Overall Benefit:")
            print(f"    L-BFGS reduced the loss by a factor of {lbfgs_improvement:.1f}")
            print(f"    using {lbfgs_time/adam_time*100:.1f}% of the Adam phase time")
    
    print("\n" + "="*80)
    
    # Recommendations
    print("\nRECOMMENDATIONS:")
    print("" + "-"*20)
    
    if best_accuracy[0] == 'Two-Stage':
        print("✓ Two-Stage optimization achieves the highest accuracy")
    
    if best_efficiency[0] == 'Two-Stage':
        print("✓ Two-Stage optimization provides the best accuracy-time tradeoff")
    
    adam_only_result = results.get('Adam Only', {})
    two_stage_result = results.get('Two-Stage', {})
    
    if (adam_only_result and two_stage_result and 
        two_stage_result['final_l2_error'] < adam_only_result['final_l2_error'] * 0.1):
        print("✓ L-BFGS refinement provides significant accuracy improvement")
    
    print("\nBest practices:")
    print("  • Use Adam for initial exploration (handles complex loss landscapes)")
    print("  • Follow with L-BFGS for high-precision convergence")
    print("  • Monitor loss plateaus to determine optimal transition point")
    print("  • L-BFGS is most effective when starting from a good Adam solution")
    
# Run detailed analysis
analyze_optimization_performance(comparison_results)

## 6. Advanced Two-Stage Strategy with Adaptive Switching
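
The switching rule used in this section combines a loss threshold with a patience counter. The sketch below isolates that logic in a small class; `PlateauDetector` and its method names are hypothetical, and the loss trace is illustrative.

```python
class PlateauDetector:
    """Signal a switch when the loss is low enough or has stopped improving."""

    def __init__(self, patience=3, min_improvement=1e-6, switch_threshold=1e-4):
        self.patience = patience
        self.min_improvement = min_improvement
        self.switch_threshold = switch_threshold
        self.best_loss = float("inf")
        self.stagnation = 0

    def update(self, loss):
        """Record one loss value; return the switch reason, or None to continue."""
        if loss < self.best_loss - self.min_improvement:
            self.best_loss = loss
            self.stagnation = 0
        else:
            self.stagnation += 1
        if loss < self.switch_threshold:
            return "loss threshold"
        if self.stagnation >= self.patience:
            return "plateau"
        return None

detector = PlateauDetector(patience=3)
trace = [1.0, 0.5, 0.4, 0.4, 0.4, 0.4]  # losses that stall at 0.4
reasons = [detector.update(loss) for loss in trace]
print(reasons)
```

Note that `min_improvement` is an absolute margin: once the loss is of the same order as that margin, further relative progress no longer resets the counter, so the threshold and the margin should be chosen together.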

In [None]:
class AdaptiveTwoStageOptimizer:
    """Advanced two-stage optimizer with adaptive switching criteria"""
    
    def __init__(self, model, adam_lr=1e-3, lbfgs_lr=1, patience=500, 
                 min_improvement=1e-6, switch_threshold=1e-4):
        self.model = model
        self.adam_lr = adam_lr
        self.lbfgs_lr = lbfgs_lr
        self.patience = patience
        self.min_improvement = min_improvement
        self.switch_threshold = switch_threshold
        
    def train(self, x_domain, x0, y0, max_epochs=20000, print_every=1000):
        """Adaptive two-stage training with intelligent switching"""
        
        tracker = OptimizationTracker(self.model)
        
        print("=== ADAPTIVE TWO-STAGE OPTIMIZATION ===")
        print(f"Switch criteria: loss < {self.switch_threshold:.0e} OR plateau for {self.patience} epochs")
        
        # Stage 1: Adam with adaptive switching
        print("\nStage 1: Adam optimization with adaptive monitoring...")
        tracker.start_tracking('Adaptive-Adam')
        
        optimizer_adam = optim.Adam(self.model.parameters(), lr=self.adam_lr)
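        # `verbose` is deprecated in recent PyTorch releases, so it is omitted;
        # the current learning rate appears in the progress log below
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer_adam, mode='min', factor=0.5, patience=200
        )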
        
        loss_history = []
        best_loss = float('inf')
        stagnation_count = 0
        switch_epoch = None
        
        for epoch in range(max_epochs):
            self.model.train()
            optimizer_adam.zero_grad()
            
            total_loss, ode_loss, ic_loss = compute_loss(self.model, x_domain, x0, y0)
            total_loss.backward()
            optimizer_adam.step()
            scheduler.step(total_loss)
            
            current_loss = total_loss.item()
            loss_history.append(current_loss)
            
            tracker.record_step(epoch, current_loss, ode_loss.item(), 
                              ic_loss.item(), 'Adaptive-Adam')
            
            # Check for improvement
            if current_loss < best_loss - self.min_improvement:
                best_loss = current_loss
                stagnation_count = 0
            else:
                stagnation_count += 1
            
            # Print progress
            if (epoch + 1) % print_every == 0:
                lr = optimizer_adam.param_groups[0]['lr']
                print(f"  Epoch {epoch+1:5d} | Loss: {current_loss:.2e} | "
                      f"Best: {best_loss:.2e} | Stagnation: {stagnation_count} | LR: {lr:.1e}")
            
            # Check switching criteria
            should_switch = (
                current_loss < self.switch_threshold or  # Low enough loss
                stagnation_count >= self.patience        # Plateaued
            )
            
            if should_switch:
                switch_epoch = epoch
                switch_reason = "loss threshold" if current_loss < self.switch_threshold else "plateau detected"
                print(f"\n  → Switching to L-BFGS at epoch {epoch+1} (reason: {switch_reason})")
                print(f"  → Adam final loss: {current_loss:.2e}")
                break
        
        adam_time = tracker.end_phase()
        
        # Stage 2: L-BFGS refinement
        if switch_epoch is not None:
            print(f"\nStage 2: L-BFGS refinement...")
            tracker.start_tracking('Adaptive-LBFGS')
            
            optimizer_lbfgs = optim.LBFGS(
                self.model.parameters(),
                lr=self.lbfgs_lr,
                max_iter=20,
                tolerance_grad=1e-9,
                tolerance_change=1e-12,
                history_size=100
            )
            
            lbfgs_epoch = 0
            
            def closure():
                nonlocal lbfgs_epoch
                optimizer_lbfgs.zero_grad()
                total_loss, ode_loss, ic_loss = compute_loss(self.model, x_domain, x0, y0)
                total_loss.backward()
                
                tracker.record_step(switch_epoch + lbfgs_epoch, total_loss.item(), 
                                  ode_loss.item(), ic_loss.item(), 'Adaptive-LBFGS')
                
                if lbfgs_epoch % 50 == 0:
                    print(f"  L-BFGS Iter {lbfgs_epoch:3d} | Loss: {total_loss.item():.2e}")
                
                lbfgs_epoch += 1
                return total_loss
            
            # Run L-BFGS optimization
            for i in range(25):  # Max L-BFGS steps
                optimizer_lbfgs.step(closure)
                
                # Check for convergence. compute_loss differentiates with
                # respect to x, so it cannot be wrapped in torch.no_grad().
                current_loss, _, _ = compute_loss(self.model, x_domain, x0, y0)
                if current_loss.item() < 1e-12:
                    print(f"  → L-BFGS loss fell below 1e-12 at step {i+1}")
                    break
            
            lbfgs_time = tracker.end_phase()
            
            # Final assessment
            final_loss, _, _ = compute_loss(self.model, x_domain, x0, y0)
            improvement_factor = loss_history[switch_epoch] / final_loss.item()
            
            print(f"\n=== ADAPTIVE OPTIMIZATION RESULTS ===")
            print(f"Adam Phase: {switch_epoch+1:,} epochs, {adam_time:.1f}s")
            print(f"L-BFGS Phase: {lbfgs_epoch} iterations, {lbfgs_time:.1f}s")
            print(f"Final improvement: {improvement_factor:.1f}x better than Adam")
            print(f"Total time: {adam_time + lbfgs_time:.1f}s")
            
        else:
            print(f"\nAdam training completed without switching (max epochs reached)")
        
        return tracker

# Test the adaptive optimizer
print("Testing Adaptive Two-Stage Optimizer...")
adaptive_model = PINN(hidden_dim=32, num_layers=3).to(device)
adaptive_optimizer = AdaptiveTwoStageOptimizer(
    adaptive_model, 
    adam_lr=1e-3, 
    patience=1000, 
    switch_threshold=1e-5
)

adaptive_tracker = adaptive_optimizer.train(x_domain, x0, y0, max_epochs=15000, print_every=1000)

## 7. Key Insights and Best Practices
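
The error metrics reported in this section are a root-mean-square (RMS) error, its relative counterpart, and the maximum error. The sketch below computes the same three quantities with the standard library only; the perturbed `y_pred` is a hypothetical approximation used in place of a trained model.

```python
import math

n = 1000
xs = [5 * i / (n - 1) for i in range(n)]
y_true = [math.exp(-x) for x in xs]
# Hypothetical approximation: the true solution plus a small smooth perturbation
y_pred = [y + 1e-3 * math.sin(x) for x, y in zip(xs, y_true)]

diffs = [abs(p - t) for p, t in zip(y_pred, y_true)]
l2_abs = math.sqrt(sum(d * d for d in diffs) / n)            # RMS error
l2_rel = l2_abs / math.sqrt(sum(t * t for t in y_true) / n)  # relative to RMS of y_true
l_inf = max(diffs)                                           # worst-case error

print(f"L2 abs: {l2_abs:.2e} | L2 rel: {l2_rel:.2e} | Linf: {l_inf:.2e}")
```

Because $e^{-x}$ has an RMS value below 1 on $[0, 5]$, the relative error is larger than the absolute one, and the maximum error bounds the RMS error from above.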

In [None]:
def demonstrate_solution_quality():
    """Demonstrate the final solution quality"""
    
    # Create test points
    x_test = torch.linspace(0, 5, 1000).view(-1, 1).to(device)
    y_true = analytical_solution(x_test)
    
    # Get predictions from adaptive model
    with torch.no_grad():
        y_pred = adaptive_model(x_test)
    
    # Calculate comprehensive error metrics
    l2_error = torch.sqrt(torch.mean((y_pred - y_true)**2))
    l2_relative = l2_error / torch.sqrt(torch.mean(y_true**2))
    l_inf_error = torch.max(torch.abs(y_pred - y_true))
    
    # Point-wise error analysis
    pointwise_error = torch.abs(y_pred - y_true)
    mean_error = torch.mean(pointwise_error)
    std_error = torch.std(pointwise_error)
    
    print("\n" + "="*60)
    print("           FINAL SOLUTION QUALITY ASSESSMENT")
    print("="*60)
    
    print(f"\nERROR METRICS:")
    print(f"  L2 Absolute Error:     {l2_error.item():.2e}")
    print(f"  L2 Relative Error:     {l2_relative.item():.2e}")
    print(f"  L∞ (Maximum) Error:    {l_inf_error.item():.2e}")
    print(f"  Mean Pointwise Error:  {mean_error.item():.2e}")
    print(f"  Std Pointwise Error:   {std_error.item():.2e}")
    
    # Test specific points
    test_points = torch.tensor([[0.0], [1.0], [2.0], [3.0], [4.0], [5.0]], device=device)
    
    print(f"\nPOINT-WISE VALIDATION:")
    print(f"{'x':<6} {'True':<12} {'Predicted':<12} {'Error':<12} {'Rel Error':<10}")
    print("-" * 60)
    
    with torch.no_grad():
        for x_val in test_points:
            y_true_pt = analytical_solution(x_val).item()
            y_pred_pt = adaptive_model(x_val).item()
            error_pt = abs(y_true_pt - y_pred_pt)
            rel_error_pt = error_pt / abs(y_true_pt) * 100 if abs(y_true_pt) > 1e-10 else 0
            
            print(f"{x_val.item():<6.1f} {y_true_pt:<12.6f} {y_pred_pt:<12.6f} "
                  f"{error_pt:<12.2e} {rel_error_pt:<10.3f}%")
    
    # Check initial condition satisfaction
    with torch.no_grad():
        y0_pred = adaptive_model(torch.tensor([[0.0]], device=device))
        ic_error = abs(y0_pred.item() - 1.0)
    
    print(f"\nINITIAL CONDITION CHECK:")
    print(f"  y(0) = 1.0 (required)")
    print(f"  y(0) = {y0_pred.item():.8f} (predicted)")
    print(f"  Error: {ic_error:.2e}")
    print(f"  Status: {'✓ SATISFIED' if ic_error < 1e-6 else '⚠ NEEDS IMPROVEMENT'}")
    
    # Physical behavior check
    print(f"\nPHYSICAL BEHAVIOR VALIDATION:")
    
    # Check if solution is decreasing (as expected for y' + y = 0)
    x_check = torch.linspace(0, 5, 100).view(-1, 1).to(device)
    x_check.requires_grad_(True)
    y_check = adaptive_model(x_check)
    dy_dx = torch.autograd.grad(y_check, x_check, 
                               grad_outputs=torch.ones_like(y_check),
                               create_graph=False)[0]
    
    # The exact slope is -e^{-x}, so |y'(5)| ≈ 6.7e-3; keep the tolerance below that
    is_decreasing = torch.all(dy_dx < 1e-3)
    print(f"  Solution monotonically decreasing: {'✓ YES' if is_decreasing else '✗ NO'}")
    print(f"  Maximum derivative: {torch.max(dy_dx).item():.2e} (should be negative)")
    
    # ODE satisfaction check
    ode_residual = dy_dx + y_check
    max_residual = torch.max(torch.abs(ode_residual)).item()
    mean_residual = torch.mean(torch.abs(ode_residual)).item()
    
    print(f"  ODE residual (dy/dx + y):")
    print(f"    Maximum |residual|: {max_residual:.2e}")
    print(f"    Mean |residual|:    {mean_residual:.2e}")
    print(f"    ODE satisfaction:   {'✓ EXCELLENT' if max_residual < 1e-6 else '✓ GOOD' if max_residual < 1e-4 else '⚠ NEEDS IMPROVEMENT'}")
    
    print("\n" + "="*60)
    
    # Overall assessment
    overall_success = (
        l2_relative.item() < 1e-6 and
        ic_error < 1e-6 and
        max_residual < 1e-6
    )
    
    print(f"OVERALL ASSESSMENT: {'🎉 OUTSTANDING SUCCESS!' if overall_success else '⚠ One or more 1e-6 targets not met; see the metrics above'}")
    
    if overall_success:
        print("The two-stage optimization met every 1e-6 accuracy target.")
    
    return {
        'l2_error': l2_error.item(),
        'l2_relative': l2_relative.item(),
        'linf_error': l_inf_error.item(),
        'ic_error': ic_error,
        'max_residual': max_residual,
        'success': overall_success
    }

# Demonstrate final solution quality
final_metrics = demonstrate_solution_quality()

## 8. Summary and Best Practices

### Key Findings from Two-Stage Optimization:

1. **Adam Exploration Phase:**
   - Well suited to navigating complex, non-convex loss landscapes
   - Per-parameter adaptive step sizes handle gradients of very different magnitudes
   - Tolerant of initialization and of moderate changes to the learning rate
   - Reaches the neighborhood of a good solution quickly, though progress tends to slow near the optimum

2. **L-BFGS Refinement Phase:**
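   - Typically reaches a lower final loss than Adam alone
   - Approximates second-order (curvature) information from a history of recent gradients
   - Can drive the loss close to floating-point limits on simple problems such as this ODE
   - Most effective when started from a good initial point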

3. **Two-Stage Benefits:**
   - Combines exploration capability with precision
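   - Often improves final accuracy by one to two orders of magnitude over Adam alone (the adaptive run above prints the measured improvement factor)
   - Tends to converge more consistently than either optimizer used alone, though results remain problem-dependent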
   - Computational overhead of L-BFGS is justified by accuracy gains
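
The division of labor summarized above can be condensed into a minimal sketch on the same ODE. The small `net`, the 500 Adam steps, and the L-BFGS tolerances below are illustrative assumptions, not the settings used by the notebook's `PINN` class or its optimizers.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Small stand-in network (hypothetical; not the notebook's PINN class)
net = nn.Sequential(
    nn.Linear(1, 16), nn.Tanh(),
    nn.Linear(16, 16), nn.Tanh(),
    nn.Linear(16, 1),
)

# Collocation points on [0, 5] and the initial-condition point x = 0
x = torch.linspace(0, 5, 64).view(-1, 1).requires_grad_(True)
x0 = torch.zeros(1, 1)

def pinn_loss():
    """Mean squared ODE residual (dy/dx + y) plus squared initial-condition error."""
    y = net(x)
    dy_dx = torch.autograd.grad(
        y, x, grad_outputs=torch.ones_like(y), create_graph=True
    )[0]
    residual = dy_dx + y
    ic_error = net(x0) - 1.0
    return (residual ** 2).mean() + (ic_error ** 2).mean()

# Stage 1: Adam for coarse exploration (step count is an assumption)
adam = torch.optim.Adam(net.parameters(), lr=1e-3)
for _ in range(500):
    adam.zero_grad()
    loss = pinn_loss()
    loss.backward()
    adam.step()
adam_loss = pinn_loss().item()

# Stage 2: L-BFGS refinement; tolerances here are illustrative
lbfgs = torch.optim.LBFGS(
    net.parameters(),
    lr=1.0,
    max_iter=200,
    tolerance_grad=1e-9,
    tolerance_change=1e-12,
    history_size=50,
    line_search_fn="strong_wolfe",
)

def closure():
    # L-BFGS re-evaluates the loss several times per step, so it needs a closure
    lbfgs.zero_grad()
    loss = pinn_loss()
    loss.backward()
    return loss

lbfgs.step(closure)
final_loss = pinn_loss().item()

print(f"Loss after Adam: {adam_loss:.3e}, after L-BFGS: {final_loss:.3e}")
```

The closure is the main structural difference from the Adam loop: a single `lbfgs.step(closure)` call runs up to `max_iter` inner iterations, each of which may evaluate the loss more than once during the line search.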

### Implementation Best Practices:

1. **Switch Timing:**
   - Monitor loss plateaus (e.g., no meaningful improvement for 500-1000 epochs)
   - Switch once the loss drops below a problem-specific threshold
   - If a scheduler is in use, treat a learning rate that has decayed to its floor as a sign that Adam has stalled

2. **Hyperparameter Guidelines:**
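   - Adam LR: Start with 1e-3 and let a scheduler reduce it as training progresses
   - L-BFGS: Keep most defaults, but tune `tolerance_grad` and `tolerance_change`; in PyTorch, `line_search_fn="strong_wolfe"` often helps stability
   - Patience: 500-2000 epochs, depending on problem complexity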

3. **Monitoring and Diagnostics:**
   - Track both loss and solution accuracy metrics
   - Monitor gradient norms and parameter updates
   - Validate physics constraint satisfaction
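
The switch-timing rules above (a plateau with a patience window, or a loss threshold) can be expressed as a small helper. This is a sketch only: `PlateauDetector`, `min_rel_improvement`, and the example loss values are hypothetical names and numbers chosen for illustration, and the class is not the logic inside `AdaptiveTwoStageOptimizer`.

```python
class PlateauDetector:
    """Signal a switch when the loss stalls for `patience` epochs or falls below a threshold."""

    def __init__(self, patience=1000, min_rel_improvement=1e-3, loss_threshold=None):
        self.patience = patience
        self.min_rel_improvement = min_rel_improvement  # fraction of best loss counted as progress
        self.loss_threshold = loss_threshold            # optional absolute switch threshold
        self.best = float("inf")
        self.stale_epochs = 0

    def update(self, loss):
        """Record one epoch's loss; return True when it is time to switch optimizers."""
        if self.loss_threshold is not None and loss < self.loss_threshold:
            return True
        if loss < self.best * (1.0 - self.min_rel_improvement):
            # Meaningful improvement: remember it and reset the stale counter
            self.best = loss
            self.stale_epochs = 0
        else:
            self.stale_epochs += 1
        return self.stale_epochs >= self.patience


# Toy loss curve: fast progress, then a plateau
detector = PlateauDetector(patience=3, min_rel_improvement=0.01)
losses = [1.0, 0.5, 0.499, 0.498, 0.4979]
switch_epoch = next(
    (epoch for epoch, loss in enumerate(losses) if detector.update(loss)), None
)
print(f"Switch to L-BFGS after epoch index: {switch_epoch}")
```

Using a relative improvement margin rather than a strict "new minimum" test keeps tiny, noise-level decreases from resetting the patience counter indefinitely.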

### When to Use Two-Stage Optimization:
- High-accuracy requirements (engineering applications)
- Complex PDE problems with multiple constraints
- Inverse problems requiring parameter precision
- When computational budget allows for longer training

This PyTorch implementation provides full control over the optimization process, enabling researchers to adapt the strategy for their specific PINN applications.