# In [None]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import os
import time

# Seed NumPy and PyTorch so that point sampling and weight initialisation
# repeat from run to run.
np.random.seed(42)
torch.manual_seed(42)

# Every figure and checkpoint below is written into this folder; an existing
# folder is reused.
os.makedirs('pinn_biharmonic_results_pytorch_final', exist_ok=True)

# Run on the GPU whenever CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# =============================================================================
# COMMON FUNCTIONS AND CLASSES
# =============================================================================

class BiharmonicPINN(nn.Module):
    """Fully connected network u_theta(x, y) used as the PINN ansatz.

    Args:
        layers: sequence of layer widths, e.g. [2, 84, 84, 84, 84, 1]; one
            nn.Linear is created between every consecutive pair of widths.
        activation: module applied after every hidden layer (never after the
            output layer). Defaults to a fresh nn.SiLU() for each model.
    """

    def __init__(self, layers, activation=None):
        super().__init__()
        self.layers = layers
        # Build the default per instance: a module object used as a default
        # argument is created once at definition time and would be shared
        # (and registered as a submodule) by every model built with it.
        self.activation = nn.SiLU() if activation is None else activation

        self.linears = nn.ModuleList(
            nn.Linear(n_in, n_out) for n_in, n_out in zip(layers[:-1], layers[1:])
        )
        self.init_weights()

    def init_weights(self):
        """Xavier-normal weights and zero biases for every linear layer."""
        for linear in self.linears:
            nn.init.xavier_normal_(linear.weight)
            nn.init.zeros_(linear.bias)

    def forward(self, x):
        """Map inputs whose last dimension is layers[0] to layers[-1] outputs."""
        last = len(self.linears) - 1
        for i, linear in enumerate(self.linears):
            x = linear(x)
            if i < last:  # no activation on the output layer
                x = self.activation(x)
        return x

def compute_derivatives(u, x, y):
    """Differentiate u(x, y) up to fourth order with automatic differentiation.

    Every derivative is built with create_graph=True, so losses assembled from
    them can be back-propagated into whatever produced ``u``.

    Returns:
        13-tuple (u, u_x, u_y, u_xx, u_yy, u_xy, u_xxx, u_xxy, u_xyy, u_yyy,
        u_xxxx, u_xxyy, u_yyyy).
    """
    def d(f, wrt):
        # Vector-Jacobian product with a ones vector; this equals the pointwise
        # derivative as long as each f[i] depends only on wrt[i].
        return torch.autograd.grad(f, wrt, grad_outputs=torch.ones_like(f),
                                   create_graph=True, retain_graph=True)[0]

    # Order 1
    u_x, u_y = d(u, x), d(u, y)
    # Order 2
    u_xx, u_yy, u_xy = d(u_x, x), d(u_y, y), d(u_x, y)
    # Order 3
    u_xxx, u_xxy = d(u_xx, x), d(u_xx, y)
    u_xyy, u_yyy = d(u_xy, y), d(u_yy, y)
    # Order 4 (only the terms needed by the biharmonic operator)
    u_xxxx, u_xxyy, u_yyyy = d(u_xxx, x), d(u_xxy, y), d(u_yyy, y)

    return (u, u_x, u_y, u_xx, u_yy, u_xy,
            u_xxx, u_xxy, u_xyy, u_yyy,
            u_xxxx, u_xxyy, u_yyyy)

def get_unit_normal(x, y, tol=1e-6):
    """Return the outward unit normal (n_x, n_y) at points on the boundary of (0,1)^2.

    A coordinate within ``tol`` of 0 or 1 marks a point as lying on that side:
        x ~ 0 -> (-1, 0),  x ~ 1 -> (1, 0),  y ~ 0 -> (0, -1),  y ~ 1 -> (0, 1).
    At corners the normal is not defined; there the two side normals are
    summed and rescaled to unit length (the diagonal direction), so every
    returned boundary normal really has length 1. Points that are not on the
    boundary get (0, 0).

    Args:
        x, y: coordinate tensors of the same shape.
        tol: distance from 0 / 1 within which a coordinate counts as on the boundary.

    Returns:
        (n_x, n_y): tensors shaped like x and y, without autograd history.
    """
    n_x = torch.zeros_like(x)
    n_y = torch.zeros_like(y)

    n_x[x <= tol] = -1.0        # left side   (x = 0)
    n_x[x >= 1.0 - tol] = 1.0   # right side  (x = 1)
    n_y[y <= tol] = -1.0        # bottom side (y = 0)
    n_y[y >= 1.0 - tol] = 1.0   # top side    (y = 1)

    # Corners hold (+-1, +-1) at this point; divide by the norm there. Sides
    # have norm 1 and interior points norm 0, so both are left unchanged.
    norm = torch.sqrt(n_x**2 + n_y**2)
    scale = torch.where(norm > 0, norm, torch.ones_like(norm))
    return n_x / scale, n_y / scale

def compute_normal_derivatives(x, y, u_x, u_y, u_xx, u_yy, u_xy, u_xxx, u_xxy, u_xyy, u_yyy):
    """Return (du/dn, d(Laplacian u)/dn) at boundary points of the unit square.

    The normal comes from get_unit_normal(x, y). The gradient of the Laplacian
    is assembled from third derivatives:
        d/dx (u_xx + u_yy) = u_xxx + u_xyy,   d/dy (u_xx + u_yy) = u_xxy + u_yyy.
    u_xx, u_yy and u_xy are accepted but not used.
    """
    normal_x, normal_y = get_unit_normal(x, y)

    grad_lap_x = u_xxx + u_xyy
    grad_lap_y = u_xxy + u_yyy

    du_dn = normal_x * u_x + normal_y * u_y
    dlap_dn = normal_x * grad_lap_x + normal_y * grad_lap_y
    return du_dn, dlap_dn

def compute_biharmonic(u_xxxx, u_xxyy, u_yyyy):
    """Assemble the biharmonic operator Δ²u = u_xxxx + 2 u_xxyy + u_yyyy.

    Accepts tensors (broadcast elementwise) or plain numbers.
    """
    mixed_term = 2.0 * u_xxyy
    return u_xxxx + mixed_term + u_yyyy

def compute_errors(u_pred, u_exact, x, y):
    """Return absolute and relative L2 / energy errors over a set of sample points.

    Norms are root-mean-square values over the samples. When (x, y) are
    uniform samples of the unit square (as in this file's callers), these
    estimate the L2(Ω) norms since |Ω| = 1:
        l2_error     = ||u_exact - u_pred||
        energy_error = ||u_exact - u_pred|| + ||grad(u_exact - u_pred)||
    The relative errors divide by the same norms of u_exact.

    Args:
        u_pred, u_exact: tensors evaluated at (x, y); both must depend on x
            and on y through autograd (torch.autograd.grad raises otherwise).
        x, y: coordinate tensors with requires_grad=True.

    Returns:
        (l2_error, energy_error, l2_relative, energy_relative) as 0-d tensors.
    """
    def first_derivative(out, wrt):
        # The gradients only feed error metrics, so no higher-order graph is
        # built. retain_graph=True because each output is differentiated
        # w.r.t. both x and y.
        return torch.autograd.grad(out, wrt, grad_outputs=torch.ones_like(out),
                                   create_graph=False, retain_graph=True)[0]

    l2_error = torch.sqrt(torch.mean((u_pred - u_exact)**2))

    u_pred_x = first_derivative(u_pred, x)
    u_pred_y = first_derivative(u_pred, y)
    u_exact_x = first_derivative(u_exact, x)
    u_exact_y = first_derivative(u_exact, y)

    grad_diff_norm = torch.sqrt(torch.mean((u_pred_x - u_exact_x)**2 + (u_pred_y - u_exact_y)**2))
    energy_error = l2_error + grad_diff_norm

    l2_norm_exact = torch.sqrt(torch.mean(u_exact**2))
    grad_norm_exact = torch.sqrt(torch.mean(u_exact_x**2 + u_exact_y**2))
    energy_norm_exact = l2_norm_exact + grad_norm_exact

    l2_relative = l2_error / l2_norm_exact
    energy_relative = energy_error / energy_norm_exact

    return l2_error, energy_error, l2_relative, energy_relative

def compute_final_errors(pinn, exact_solution, num_test_points=1000):
    """Evaluate the trained network on fresh uniform random points in (0,1)^2.

    Points are sampled directly on the module-level ``device`` as leaf tensors
    that require grad. The metrics are delegated to compute_errors, so the
    final errors use exactly the same definitions as the training-time ones.

    Args:
        pinn: model mapping an (N, 2) tensor of (x, y) to an (N, 1) tensor.
        exact_solution: callable exact_solution(x, y) returning u at the points.
        num_test_points: number of random evaluation points.

    Returns:
        (l2_error, energy_error, l2_relative, energy_relative) as Python floats.
    """
    x_test = torch.rand((num_test_points, 1), dtype=torch.float32, device=device, requires_grad=True)
    y_test = torch.rand((num_test_points, 1), dtype=torch.float32, device=device, requires_grad=True)

    u_pred = pinn(torch.cat([x_test, y_test], dim=1))
    u_exact = exact_solution(x_test, y_test)

    errors = compute_errors(u_pred, u_exact, x_test, y_test)
    return tuple(err.item() for err in errors)

# =============================================================================
# EXAMPLE 1: u = (1/(2π²)) sin(πx) sin(πy)
# =============================================================================

def run_example1():
    """Train, plot and evaluate the PINN for Example 1.

    Exact solution u = sin(pi x) sin(pi y) / (2 pi^2) on (0,1)^2, with source
    f = Δ²u = 2 pi^2 sin(pi x) sin(pi y). The network is trained on the
    residual of Δ²u = f at interior points and on the Cahn-Hilliard boundary
    data g1 = du/dn, g2 = d(Δu)/dn taken from the exact solution.

    These equations determine u only up to an additive constant (any constant
    c satisfies Δ²c = 0, dc/dn = 0 and d(Δc)/dn = 0). A mean-value term
    therefore ties the mean of the network over the interior points to that of
    the exact solution; without it the L2/energy errors measure an arbitrary
    offset rather than the approximation quality.

    Side effects: prints progress and writes plots plus a checkpoint into
    'pinn_biharmonic_results_pytorch_final/'.

    Returns:
        The trained BiharmonicPINN.
    """
    print("=" * 80)
    print("RUNNING EXAMPLE 1: u = (1/(2π²)) sin(πx) sin(πy)")
    print("=" * 80)

    out_dir = 'pinn_biharmonic_results_pytorch_final'

    def exact_solution1(x, y):
        return (1.0/(2.0*np.pi**2)) * torch.sin(np.pi*x) * torch.sin(np.pi*y)

    def source_term1(x, y):
        # f = Δ²u = 2π² sin(πx) sin(πy)
        return (2.0 * np.pi**2) * torch.sin(np.pi*x) * torch.sin(np.pi*y)

    # ------------------------------------------------------------------
    # Collocation points, created directly on `device` so they are leaf
    # tensors (calling .to(device) on a requires_grad CPU tensor produces a
    # non-leaf copy on the GPU).
    # ------------------------------------------------------------------
    N_int = 20000
    N_bc = 6000
    N_bc_side = N_bc // 4

    x_int = torch.rand((N_int, 1), dtype=torch.float32, device=device, requires_grad=True)
    y_int = torch.rand((N_int, 1), dtype=torch.float32, device=device, requires_grad=True)

    # Free coordinate along the bottom, top, left and right sides, in that order.
    t_bottom, t_top, t_left, t_right = (
        torch.rand((N_bc_side, 1), dtype=torch.float32, device=device) for _ in range(4)
    )
    zeros = torch.zeros((N_bc_side, 1), dtype=torch.float32, device=device)
    ones = torch.ones_like(zeros)
    x_bc = torch.cat([t_bottom, t_top, zeros, ones], dim=0).requires_grad_(True)
    y_bc = torch.cat([zeros, ones, t_left, t_right], dim=0).requires_grad_(True)

    # ------------------------------------------------------------------
    # Loop-invariant targets: evaluated once and detached instead of being
    # rebuilt through autograd on every training step.
    # ------------------------------------------------------------------
    with torch.no_grad():
        f_int = source_term1(x_int, y_int)
        u_mean_exact = exact_solution1(x_int, y_int).mean()

    u_exact_bc = exact_solution1(x_bc, y_bc)
    exact_bc_derivs = compute_derivatives(u_exact_bc, x_bc, y_bc)
    # Entries 1..9 are u_x ... u_yyy, the arguments compute_normal_derivatives expects.
    g1_bc, g2_bc = compute_normal_derivatives(x_bc, y_bc, *exact_bc_derivs[1:10])
    g1_bc, g2_bc = g1_bc.detach(), g2_bc.detach()
    del u_exact_bc, exact_bc_derivs

    # Model and optimiser
    layers = [2, 84, 84, 84, 84, 1]
    pinn = BiharmonicPINN(layers).to(device)
    optimizer = torch.optim.Adam(pinn.parameters(), lr=1e-3)

    epochs = 10000
    print_interval = 100

    lambda_int = 1.0
    lambda_bc = 100.0    # up-weight boundary residuals to stabilise training
    lambda_mean = 100.0  # mean-value constraint removing the constant null space

    loss_history = []
    int_loss_history = []
    bc_loss_history = []
    mean_loss_history = []
    l2_error_history = []
    energy_error_history = []

    def train_step():
        """One Adam step on the weighted PDE, boundary and mean-value losses."""
        optimizer.zero_grad()

        # Interior residual of Δ²u_θ = f.
        u_int = pinn(torch.cat([x_int, y_int], dim=1))
        *_, u_xxxx, u_xxyy, u_yyyy = compute_derivatives(u_int, x_int, y_int)
        loss_int = torch.mean((compute_biharmonic(u_xxxx, u_xxyy, u_yyyy) - f_int)**2)

        # Fix the additive constant that the boundary conditions leave free.
        loss_mean = (torch.mean(u_int) - u_mean_exact)**2

        # Boundary residuals of du/dn = g1 and d(Δu)/dn = g2.
        u_bc = pinn(torch.cat([x_bc, y_bc], dim=1))
        d_bc = compute_derivatives(u_bc, x_bc, y_bc)
        u_n_bc, laplacian_n_bc = compute_normal_derivatives(x_bc, y_bc, *d_bc[1:10])
        loss_bc = torch.mean((u_n_bc - g1_bc)**2) + torch.mean((laplacian_n_bc - g2_bc)**2)

        total_loss = lambda_int * loss_int + lambda_bc * loss_bc + lambda_mean * loss_mean
        total_loss.backward()
        optimizer.step()

        return total_loss.detach(), loss_int.detach(), loss_bc.detach(), loss_mean.detach()

    # Fixed test points for monitoring errors during training
    x_test_train = torch.rand((1000, 1), dtype=torch.float32, device=device, requires_grad=True)
    y_test_train = torch.rand((1000, 1), dtype=torch.float32, device=device, requires_grad=True)

    print("Starting training for Example 1...")
    start_time = time.time()

    for epoch in range(epochs + 1):
        total_loss, loss_int, loss_bc, loss_mean = train_step()

        if epoch % print_interval == 0:
            u_pred = pinn(torch.cat([x_test_train, y_test_train], dim=1))
            u_exact = exact_solution1(x_test_train, y_test_train)
            l2_error, energy_error, _, _ = compute_errors(u_pred, u_exact, x_test_train, y_test_train)

            loss_history.append(total_loss.item())
            int_loss_history.append(loss_int.item())
            bc_loss_history.append(loss_bc.item())
            mean_loss_history.append(loss_mean.item())
            l2_error_history.append(l2_error.item())
            energy_error_history.append(energy_error.item())

            print(f"Epoch {epoch:5d}: Total Loss = {total_loss.item():.2e}, "
                  f"Int Loss = {loss_int.item():.2e}, BC Loss = {loss_bc.item():.2e}, "
                  f"Mean Loss = {loss_mean.item():.2e}, "
                  f"L2 Error = {l2_error.item():.2e}, Energy Error = {energy_error.item():.2e}")

    training_time = time.time() - start_time
    print(f"Training completed in {training_time:.2f} seconds")

    # =========================================================================
    # PLOTTING AND SAVING RESULTS
    # =========================================================================
    n_plot = 100
    X_plot, Y_plot = np.meshgrid(np.linspace(0, 1, n_plot), np.linspace(0, 1, n_plot))
    x_flat = torch.tensor(X_plot.reshape(-1, 1), dtype=torch.float32, device=device)
    y_flat = torch.tensor(Y_plot.reshape(-1, 1), dtype=torch.float32, device=device)

    with torch.no_grad():
        u_pred_plot = pinn(torch.cat([x_flat, y_flat], dim=1)).cpu().numpy().reshape(n_plot, n_plot)
        u_exact_plot = exact_solution1(x_flat, y_flat).cpu().numpy().reshape(n_plot, n_plot)
    error_plot = np.abs(u_pred_plot - u_exact_plot)

    # Loss history (one entry per `print_interval` epochs)
    plt.figure(figsize=(10, 6))
    plt.semilogy(loss_history, label='Total Loss', linewidth=2)
    plt.semilogy(int_loss_history, label='Interior Loss', linewidth=2)
    plt.semilogy(bc_loss_history, label='Boundary Loss', linewidth=2)
    plt.semilogy(mean_loss_history, label='Mean-Value Loss', linewidth=2)
    plt.xlabel('Epoch (x100)')
    plt.ylabel('Loss')
    plt.title('Example 1: Training Loss History')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.savefig(f'{out_dir}/example1_loss_history.png', dpi=300, bbox_inches='tight')
    plt.close()

    # Error history
    plt.figure(figsize=(10, 6))
    plt.semilogy(l2_error_history, label='L2 Error', linewidth=2)
    plt.semilogy(energy_error_history, label='Energy Error', linewidth=2)
    plt.xlabel('Epoch (x100)')
    plt.ylabel('Error')
    plt.title('Example 1: Error History')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.savefig(f'{out_dir}/example1_error_history.png', dpi=300, bbox_inches='tight')
    plt.close()

    # (field, colormap, title, file stem, z-axis label) for the 2D and 3D plots
    field_plots = [
        (u_pred_plot, 'viridis', 'PINN Predicted Solution', 'predicted_solution', 'u(x,y)'),
        (u_exact_plot, 'viridis', 'Exact Solution', 'exact_solution', 'u(x,y)'),
        (error_plot, 'hot', 'Absolute Error', 'absolute_error', 'Error'),
    ]

    for field, cmap, title, stem, _ in field_plots:
        plt.figure(figsize=(8, 6))
        contour = plt.contourf(X_plot, Y_plot, field, levels=50, cmap=cmap)
        plt.colorbar(contour)
        plt.xlabel('x')
        plt.ylabel('y')
        plt.title(f'Example 1: {title}')
        plt.savefig(f'{out_dir}/example1_{stem}.png', dpi=300, bbox_inches='tight')
        plt.close()

    fig = plt.figure(figsize=(15, 5))
    for idx, (field, cmap, title, _, zlabel) in enumerate(field_plots, start=1):
        ax = fig.add_subplot(1, 3, idx, projection='3d')
        ax.plot_surface(X_plot, Y_plot, field, cmap=cmap, alpha=0.8)
        ax.set_title(title)
        ax.set_xlabel('x')
        ax.set_ylabel('y')
        ax.set_zlabel(zlabel)
    plt.tight_layout()
    plt.savefig(f'{out_dir}/example1_3d_comparison.png', dpi=300, bbox_inches='tight')
    plt.close(fig)

    # Save model
    torch.save({
        'model_state_dict': pinn.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, f'{out_dir}/example1_model.pth')

    # Final error evaluation on fresh random points
    print("Computing final errors...")
    l2_error_final, energy_error_final, l2_relative_final, energy_relative_final = compute_final_errors(
        pinn, exact_solution1
    )

    print("\n" + "="*50)
    print("FINAL RESULTS - EXAMPLE 1")
    print("="*50)
    print(f"L2 Error: {l2_error_final:.6e}")
    print(f"Energy Error: {energy_error_final:.6e}")
    print(f"Relative L2 Error: {l2_relative_final:.6e}")
    print(f"Relative Energy Error: {energy_relative_final:.6e}")

    return pinn

# =============================================================================
# EXAMPLE 2: u = x²y²(1-x)²(1-y)²
# =============================================================================

def run_example2():
    """Train, plot and evaluate the PINN for Example 2.

    Exact solution u = x^2 y^2 (1-x)^2 (1-y)^2 on (0,1)^2. The source
    f = Δ²u and the Cahn-Hilliard boundary data g1 = du/dn, g2 = d(Δu)/dn are
    obtained by automatic differentiation of the exact solution, once, before
    training.

    These equations determine u only up to an additive constant (any constant
    c satisfies Δ²c = 0, dc/dn = 0 and d(Δc)/dn = 0). A mean-value term
    therefore ties the mean of the network over the interior points to that of
    the exact solution; without it the L2/energy errors measure an arbitrary
    offset rather than the approximation quality.

    Side effects: prints progress and writes plots plus a checkpoint into
    'pinn_biharmonic_results_pytorch_final/'.

    Returns:
        The trained BiharmonicPINN.
    """
    print("\n" + "="*80)
    print("RUNNING EXAMPLE 2: u = x²y²(1-x)²(1-y)²")
    print("="*80)

    out_dir = 'pinn_biharmonic_results_pytorch_final'

    def exact_solution2(x, y):
        return (x**2) * (y**2) * ((1-x)**2) * ((1-y)**2)

    # ------------------------------------------------------------------
    # Collocation points, created directly on `device` so they are leaf
    # tensors (calling .to(device) on a requires_grad CPU tensor produces a
    # non-leaf copy on the GPU).
    # ------------------------------------------------------------------
    N_int = 20000
    N_bc = 6000
    N_bc_side = N_bc // 4

    x_int = torch.rand((N_int, 1), dtype=torch.float32, device=device, requires_grad=True)
    y_int = torch.rand((N_int, 1), dtype=torch.float32, device=device, requires_grad=True)

    # Free coordinate along the bottom, top, left and right sides, in that order.
    t_bottom, t_top, t_left, t_right = (
        torch.rand((N_bc_side, 1), dtype=torch.float32, device=device) for _ in range(4)
    )
    zeros = torch.zeros((N_bc_side, 1), dtype=torch.float32, device=device)
    ones = torch.ones_like(zeros)
    x_bc = torch.cat([t_bottom, t_top, zeros, ones], dim=0).requires_grad_(True)
    y_bc = torch.cat([zeros, ones, t_left, t_right], dim=0).requires_grad_(True)

    # ------------------------------------------------------------------
    # Loop-invariant targets: the fourth-order autograd pass over the exact
    # solution is done once here instead of on every training step.
    # ------------------------------------------------------------------
    u_exact_int = exact_solution2(x_int, y_int)
    *_, e_xxxx, e_xxyy, e_yyyy = compute_derivatives(u_exact_int, x_int, y_int)
    f_int = compute_biharmonic(e_xxxx, e_xxyy, e_yyyy).detach()
    u_mean_exact = u_exact_int.mean().detach()
    del u_exact_int, e_xxxx, e_xxyy, e_yyyy

    u_exact_bc = exact_solution2(x_bc, y_bc)
    exact_bc_derivs = compute_derivatives(u_exact_bc, x_bc, y_bc)
    # Entries 1..9 are u_x ... u_yyy, the arguments compute_normal_derivatives expects.
    g1_bc, g2_bc = compute_normal_derivatives(x_bc, y_bc, *exact_bc_derivs[1:10])
    g1_bc, g2_bc = g1_bc.detach(), g2_bc.detach()
    del u_exact_bc, exact_bc_derivs

    # Model and optimiser
    layers = [2, 84, 84, 84, 84, 1]
    pinn = BiharmonicPINN(layers).to(device)
    optimizer = torch.optim.Adam(pinn.parameters(), lr=1e-3)

    epochs = 10000
    print_interval = 100

    lambda_int = 1.0
    lambda_bc = 100.0    # up-weight boundary residuals to stabilise training
    lambda_mean = 100.0  # mean-value constraint removing the constant null space

    loss_history = []
    int_loss_history = []
    bc_loss_history = []
    mean_loss_history = []
    l2_error_history = []
    energy_error_history = []

    def train_step():
        """One Adam step on the weighted PDE, boundary and mean-value losses."""
        optimizer.zero_grad()

        # Interior residual of Δ²u_θ = f.
        u_int = pinn(torch.cat([x_int, y_int], dim=1))
        *_, u_xxxx, u_xxyy, u_yyyy = compute_derivatives(u_int, x_int, y_int)
        loss_int = torch.mean((compute_biharmonic(u_xxxx, u_xxyy, u_yyyy) - f_int)**2)

        # Fix the additive constant that the boundary conditions leave free.
        loss_mean = (torch.mean(u_int) - u_mean_exact)**2

        # Boundary residuals of du/dn = g1 and d(Δu)/dn = g2.
        u_bc = pinn(torch.cat([x_bc, y_bc], dim=1))
        d_bc = compute_derivatives(u_bc, x_bc, y_bc)
        u_n_bc, laplacian_n_bc = compute_normal_derivatives(x_bc, y_bc, *d_bc[1:10])
        loss_bc = torch.mean((u_n_bc - g1_bc)**2) + torch.mean((laplacian_n_bc - g2_bc)**2)

        total_loss = lambda_int * loss_int + lambda_bc * loss_bc + lambda_mean * loss_mean
        total_loss.backward()
        optimizer.step()

        return total_loss.detach(), loss_int.detach(), loss_bc.detach(), loss_mean.detach()

    # Fixed test points for monitoring errors during training
    x_test_train = torch.rand((1000, 1), dtype=torch.float32, device=device, requires_grad=True)
    y_test_train = torch.rand((1000, 1), dtype=torch.float32, device=device, requires_grad=True)

    print("Starting training for Example 2...")
    start_time = time.time()

    for epoch in range(epochs + 1):
        total_loss, loss_int, loss_bc, loss_mean = train_step()

        if epoch % print_interval == 0:
            u_pred = pinn(torch.cat([x_test_train, y_test_train], dim=1))
            u_exact = exact_solution2(x_test_train, y_test_train)
            l2_error, energy_error, _, _ = compute_errors(u_pred, u_exact, x_test_train, y_test_train)

            loss_history.append(total_loss.item())
            int_loss_history.append(loss_int.item())
            bc_loss_history.append(loss_bc.item())
            mean_loss_history.append(loss_mean.item())
            l2_error_history.append(l2_error.item())
            energy_error_history.append(energy_error.item())

            print(f"Epoch {epoch:5d}: Total Loss = {total_loss.item():.2e}, "
                  f"Int Loss = {loss_int.item():.2e}, BC Loss = {loss_bc.item():.2e}, "
                  f"Mean Loss = {loss_mean.item():.2e}, "
                  f"L2 Error = {l2_error.item():.2e}, Energy Error = {energy_error.item():.2e}")

    training_time = time.time() - start_time
    print(f"Training completed in {training_time:.2f} seconds")

    # =========================================================================
    # PLOTTING AND SAVING RESULTS
    # =========================================================================
    n_plot = 100
    X_plot, Y_plot = np.meshgrid(np.linspace(0, 1, n_plot), np.linspace(0, 1, n_plot))
    x_flat = torch.tensor(X_plot.reshape(-1, 1), dtype=torch.float32, device=device)
    y_flat = torch.tensor(Y_plot.reshape(-1, 1), dtype=torch.float32, device=device)

    with torch.no_grad():
        u_pred_plot = pinn(torch.cat([x_flat, y_flat], dim=1)).cpu().numpy().reshape(n_plot, n_plot)
        u_exact_plot = exact_solution2(x_flat, y_flat).cpu().numpy().reshape(n_plot, n_plot)
    error_plot = np.abs(u_pred_plot - u_exact_plot)

    # Loss history (one entry per `print_interval` epochs)
    plt.figure(figsize=(10, 6))
    plt.semilogy(loss_history, label='Total Loss', linewidth=2)
    plt.semilogy(int_loss_history, label='Interior Loss', linewidth=2)
    plt.semilogy(bc_loss_history, label='Boundary Loss', linewidth=2)
    plt.semilogy(mean_loss_history, label='Mean-Value Loss', linewidth=2)
    plt.xlabel('Epoch (x100)')
    plt.ylabel('Loss')
    plt.title('Example 2: Training Loss History')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.savefig(f'{out_dir}/example2_loss_history.png', dpi=300, bbox_inches='tight')
    plt.close()

    # Error history
    plt.figure(figsize=(10, 6))
    plt.semilogy(l2_error_history, label='L2 Error', linewidth=2)
    plt.semilogy(energy_error_history, label='Energy Error', linewidth=2)
    plt.xlabel('Epoch (x100)')
    plt.ylabel('Error')
    plt.title('Example 2: Error History')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.savefig(f'{out_dir}/example2_error_history.png', dpi=300, bbox_inches='tight')
    plt.close()

    # (field, colormap, title, file stem, z-axis label) for the 2D and 3D plots
    field_plots = [
        (u_pred_plot, 'viridis', 'PINN Predicted Solution', 'predicted_solution', 'u(x,y)'),
        (u_exact_plot, 'viridis', 'Exact Solution', 'exact_solution', 'u(x,y)'),
        (error_plot, 'hot', 'Absolute Error', 'absolute_error', 'Error'),
    ]

    for field, cmap, title, stem, _ in field_plots:
        plt.figure(figsize=(8, 6))
        contour = plt.contourf(X_plot, Y_plot, field, levels=50, cmap=cmap)
        plt.colorbar(contour)
        plt.xlabel('x')
        plt.ylabel('y')
        plt.title(f'Example 2: {title}')
        plt.savefig(f'{out_dir}/example2_{stem}.png', dpi=300, bbox_inches='tight')
        plt.close()

    fig = plt.figure(figsize=(15, 5))
    for idx, (field, cmap, title, _, zlabel) in enumerate(field_plots, start=1):
        ax = fig.add_subplot(1, 3, idx, projection='3d')
        ax.plot_surface(X_plot, Y_plot, field, cmap=cmap, alpha=0.8)
        ax.set_title(title)
        ax.set_xlabel('x')
        ax.set_ylabel('y')
        ax.set_zlabel(zlabel)
    plt.tight_layout()
    plt.savefig(f'{out_dir}/example2_3d_comparison.png', dpi=300, bbox_inches='tight')
    plt.close(fig)

    # Save model
    torch.save({
        'model_state_dict': pinn.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, f'{out_dir}/example2_model.pth')

    # Final error evaluation on fresh random points
    print("Computing final errors...")
    l2_error_final, energy_error_final, l2_relative_final, energy_relative_final = compute_final_errors(
        pinn, exact_solution2
    )

    print("\n" + "="*50)
    print("FINAL RESULTS - EXAMPLE 2")
    print("="*50)
    print(f"L2 Error: {l2_error_final:.6e}")
    print(f"Energy Error: {energy_error_final:.6e}")
    print(f"Relative L2 Error: {l2_relative_final:.6e}")
    print(f"Relative Energy Error: {energy_relative_final:.6e}")

    return pinn

# =============================================================================
# MAIN EXECUTION
# =============================================================================

if __name__ == "__main__":
    # Run header describing the problem setup.
    banner = (
        "Physics-Informed Neural Network for Biharmonic Problem (P4)",
        "Cahn-Hilliard Boundary Conditions - FINAL CORRECTED PyTorch Implementation",
        "Domain: Ω = (0,1)²",
        "Architecture: 4 hidden layers with 84 units each",
        "Training points: 20,000 interior, 6,000 boundary",
        "=" * 80,
    )
    for line in banner:
        print(line)

    # Train and evaluate both manufactured-solution examples in sequence.
    pinn1 = run_example1()
    pinn2 = run_example2()

    rule = "=" * 80
    print("\n" + rule)
    print("ALL COMPUTATIONS COMPLETED SUCCESSFULLY!")
    print("Results saved in 'pinn_biharmonic_results_pytorch_final' directory")
    print(rule)

# --- Captured notebook output (truncated after epoch 200) ---
# Using device: cuda
# Physics-Informed Neural Network for Biharmonic Problem (P4)
# Cahn-Hilliard Boundary Conditions - FINAL CORRECTED PyTorch Implementation
# Domain: Ω = (0,1)²
# Architecture: 4 hidden layers with 84 units each
# Training points: 20,000 interior, 6,000 boundary
# RUNNING EXAMPLE 1: u = (1/(2π²)) sin(πx) sin(πy)
# Starting training for Example 1...
# Epoch     0: Total Loss = 5.90e+02, Int Loss = 9.75e+01, BC Loss = 4.92e+00, L2 Error = 1.70e-02, Energy Error = 1.31e-01
# Epoch   100: Total Loss = 5.44e+01, Int Loss = 2.27e+01, BC Loss = 3.16e-01, L2 Error = 3.86e-01, Energy Error = 5.07e-01
# Epoch   200: Total Loss = 2.06e+01, Int Loss = 8.72e+00, BC Loss = 1.18e-01, L2 Error = 2.57e-01, Energy Error = 3.15e-01
