# Structural Simulation of a Pressure Vessel with FEniCS and a PyTorch AI Surrogate

This notebook demonstrates an end-to-end workflow for the structural analysis of a pressure vessel, combining high-fidelity finite element analysis (FEA) with a fast AI-based surrogate model built using PyTorch.

**Problem:** We are analyzing a hollow cylinder with hemispherical end caps, a common geometry for pressure vessels. The vessel is subjected to internal pressure. Our goal is to understand the resulting displacement and stress.

**Workflow:**
1.  **Mesh Generation:** We will create a 3D mesh of the pressure vessel using `mshr`, a component of the FEniCS library.
2.  **FEniCS Solve (Ground Truth):** We will solve the linear elasticity equations on this mesh using FEniCS to find the displacement field under a given internal pressure. This serves as our high-fidelity, "ground truth" data.
3.  **AI Surrogate (POD + NN with PyTorch):**
    *   We will generate a dataset by running the FEniCS solver for various pressure values.
    *   **Proper Orthogonal Decomposition (POD)** will be used to reduce the dimensionality of our simulation results.
    *   A **Neural Network (NN)** will be trained with PyTorch to map input pressure directly to the low-dimensional POD representation. This NN becomes our surrogate model.
4.  **Error Analysis:** We will evaluate the accuracy of the surrogate model by comparing its predictions to the ground truth from the FEniCS solver.
5.  **ParaView Export:** We will export the results (ground truth, NN prediction, and error) to XDMF files for detailed 3D visualization in ParaView, an open-source data analysis and visualization application.

### Step 0: Setup and Dependencies

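First, we import all the necessary modules. Make sure FEniCS, PyTorch, and the other libraries are installed in your environment. This notebook uses the legacy FEniCS (DOLFIN) API together with `mshr`, so it will not run unchanged on FEniCSx/DOLFINx.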
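Before running the import cell, it can help to check which of the required packages are visible to the current kernel. The sketch below uses only the standard library; the module list is an assumption that mirrors the top-level imports in the next cell.

```python
import importlib.util

# Top-level modules imported by this notebook (assumed list, mirrors the import cell)
REQUIRED = ["fenics", "mshr", "numpy", "matplotlib", "sklearn", "torch", "tqdm"]

def check_modules(names):
    """Return {module_name: True/False} depending on whether the module can be found.

    find_spec only locates the module; it does not import it.
    """
    return {name: importlib.util.find_spec(name) is not None for name in names}

status = check_modules(REQUIRED)
for name, ok in status.items():
    print(f"{name:12s} {'OK' if ok else 'MISSING'}")
```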

In [None]:
# FEniCS for the simulation
from fenics import *
import mshr

# Standard scientific Python libraries
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

# PyTorch for the Neural Network
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# For creating a progress bar
from tqdm.notebook import tqdm

# Set Matplotlib to plot nicely
%matplotlib inline
plt.style.use('default')

### Step 1: Mesh Generation

We define the geometry of the pressure vessel—a cylinder with spherical caps—and generate a 3D tetrahedral mesh.
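A simple sanity check on the generated mesh is its volume. The wall consists of a hollow cylinder of length `L` plus two hemispherical shells, which together form one full spherical shell, so the exact volume is easy to write down. The sketch below computes it; in legacy FEniCS the mesh volume can be obtained with `assemble(Constant(1.0) * dx(domain=mesh))`, and the two numbers should differ only slightly because `mshr` approximates the curved surfaces with flat facets.

```python
import math

def vessel_wall_volume(R_inner, R_outer, L):
    """Exact wall volume of a cylinder (length L) closed by hemispherical caps.

    The two hemispherical caps together form one full spherical shell.
    """
    cylinder_shell = math.pi * (R_outer**2 - R_inner**2) * L
    spherical_shell = 4.0 / 3.0 * math.pi * (R_outer**3 - R_inner**3)
    return cylinder_shell + spherical_shell

V_exact = vessel_wall_volume(1.0, 1.1, 4.0)
print(f"Exact wall volume: {V_exact:.4f}")
```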

In [None]:
# --- Geometric Parameters ---
R_inner = 1.0  # Inner radius of the cylinder/sphere
R_outer = 1.1  # Outer radius (determines thickness)
L = 4.0        # Length of the cylindrical part
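# Roughly the number of cells across the domain diameter. With a 0.1 wall, a value of 30
# gives cells comparable to the wall thickness; increase it for more accurate stresses (at higher cost).
mesh_resolution = 30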

# --- Define the Geometry Components ---
# Outer boundary
cylinder_outer = mshr.Cylinder(Point(0, 0, -L/2), Point(0, 0, L/2), R_outer, R_outer)
sphere_top_outer = mshr.Sphere(Point(0, 0, L/2), R_outer)
sphere_bottom_outer = mshr.Sphere(Point(0, 0, -L/2), R_outer)

# Inner boundary (to be subtracted)
cylinder_inner = mshr.Cylinder(Point(0, 0, -L/2), Point(0, 0, L/2), R_inner, R_inner)
sphere_top_inner = mshr.Sphere(Point(0, 0, L/2), R_inner)
sphere_bottom_inner = mshr.Sphere(Point(0, 0, -L/2), R_inner)


# --- Combine and Subtract to Create the Final Hollow Geometry ---
vessel_outer = cylinder_outer + sphere_top_outer + sphere_bottom_outer
vessel_inner = cylinder_inner + sphere_top_inner + sphere_bottom_inner
vessel_domain = vessel_outer - vessel_inner

# --- Generate the Mesh ---
mesh = mshr.generate_mesh(vessel_domain, mesh_resolution)

print(f"Mesh generated with {mesh.num_vertices()} vertices and {mesh.num_cells()} cells.")

# --- Plot the Mesh ---
plt.figure(figsize=(10, 8))
plot(mesh, title="Generated 3D Mesh of the Pressure Vessel")
plt.show()

### Step 2: FEniCS Solve (The High-Fidelity Model)

Here, we define and solve the linear elasticity problem. This involves setting up the material properties, boundary conditions, and the weak form of the governing equations. This function will be our "ground truth" generator.
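For reference, the problem set up in the next cell can be written in weak form as below, where $\Gamma_{\mathrm{in}}$ is the inner surface, $\mathbf{n}$ the outward unit normal of the solid, and $p$ the internal pressure; the supports only remove rigid-body motion. With $E = 200\,\mathrm{GPa}$ and $\nu = 0.3$ the Lamé parameters evaluate to $\mu \approx 76.9\,\mathrm{GPa}$ and $\lambda \approx 115.4\,\mathrm{GPa}$.

```latex
\begin{aligned}
&\text{Find } \mathbf{u} \in V \text{ such that for all } \mathbf{v} \in V:\\
&\int_\Omega \boldsymbol{\sigma}(\mathbf{u}) : \boldsymbol{\varepsilon}(\mathbf{v})\,\mathrm{d}x
  = \int_{\Gamma_{\mathrm{in}}} (-p\,\mathbf{n}) \cdot \mathbf{v}\,\mathrm{d}s,\\
&\boldsymbol{\varepsilon}(\mathbf{u}) = \tfrac{1}{2}\left(\nabla\mathbf{u} + \nabla\mathbf{u}^{T}\right),\qquad
 \boldsymbol{\sigma}(\mathbf{u}) = \lambda\,\mathrm{tr}\big(\boldsymbol{\varepsilon}(\mathbf{u})\big)\,\mathbf{I} + 2\mu\,\boldsymbol{\varepsilon}(\mathbf{u}),\\
&\mu = \frac{E}{2(1+\nu)},\qquad \lambda = \frac{E\nu}{(1+\nu)(1-2\nu)}.
\end{aligned}
```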

In [None]:
# --- Material Parameters (Structural Steel) ---
# Defined once so the solver and the stress post-processing share the same constitutive law.
E = Constant(200e9)  # Young's modulus in Pa
nu = Constant(0.3)   # Poisson's ratio
mu = E / (2 * (1 + nu))
lmbda = E * nu / ((1 + nu) * (1 - 2 * nu))

def epsilon(u):
    """Small-strain tensor."""
    return 0.5 * (nabla_grad(u) + nabla_grad(u).T)

def sigma(u):
    """Cauchy stress for an isotropic linear elastic material (Hooke's law)."""
    return lmbda * tr(epsilon(u)) * Identity(len(u)) + 2 * mu * epsilon(u)

def solve_pressure_vessel(mesh, pressure_value):
    """
    Solves the linear elasticity problem for the pressure vessel.
    
    Args:
        mesh: The FEniCS mesh object.
        pressure_value: Internal pressure in Pa (float or Constant).
        
    Returns:
        u_solution: The displacement solution (a FEniCS Function).
    """
    # --- Define Function Space (continuous P1 vector field) ---
    V = VectorFunctionSpace(mesh, 'Lagrange', 1)

    # --- Mark the Inner (Pressurized) Surface ---
    # The distance from a point to the axis segment |z| <= L/2 equals R_inner on the
    # inner surface and R_outer on the outer one, so we threshold at the mid-radius.
    inner_surface = CompiledSubDomain(
        'on_boundary && sqrt(x[0]*x[0] + x[1]*x[1] + pow(fmax(fabs(x[2]) - half_L, 0.0), 2)) < R_mid',
        half_L=L / 2, R_mid=0.5 * (R_inner + R_outer))
    boundary_markers = MeshFunction('size_t', mesh, mesh.topology().dim() - 1, 0)
    inner_surface.mark(boundary_markers, 1)
    ds_marked = Measure('ds', domain=mesh, subdomain_data=boundary_markers)

    # --- Supports: Remove the Six Rigid-Body Modes ---
    # The internal pressure is self-equilibrated, so a statically determinate set of
    # point constraints suffices; by symmetry they carry (almost) no reaction force.
    coords = mesh.coordinates()
    p_bottom = coords[np.argmin(coords[:, 2])]  # lowest vertex (near the bottom pole)
    p_top = coords[np.argmax(coords[:, 2])]     # highest vertex (near the top pole)
    p_side = coords[np.argmax(coords[:, 0])]    # outermost vertex in +x

    def at_vertex(p):
        # Pointwise BCs test individual dof coordinates, so match the vertex directly (no on_boundary)
        return (f'near(x[0], {p[0]:.15g}, 1e-8) && near(x[1], {p[1]:.15g}, 1e-8)'
                f' && near(x[2], {p[2]:.15g}, 1e-8)')

    bcs = [
        # 3 translations: fix all components at the bottom vertex
        DirichletBC(V, Constant((0.0, 0.0, 0.0)), at_vertex(p_bottom), method='pointwise'),
        # Rotations about x and y: fix u_x and u_y at the top vertex
        DirichletBC(V.sub(0), Constant(0.0), at_vertex(p_top), method='pointwise'),
        DirichletBC(V.sub(1), Constant(0.0), at_vertex(p_top), method='pointwise'),
        # Rotation about z: fix u_y at the +x side vertex
        DirichletBC(V.sub(1), Constant(0.0), at_vertex(p_side), method='pointwise'),
    ]

    # --- Variational Problem ---
    u = TrialFunction(V)
    v = TestFunction(V)
    a = inner(sigma(u), epsilon(v)) * dx

    # The pressure acts as a traction on the inner surface (marker 1). FacetNormal n points
    # out of the solid, i.e. into the cavity, so the traction -p*n pushes the wall outward.
    n = FacetNormal(mesh)
    L_form = dot(-pressure_value * n, v) * ds_marked(1)

    # --- Solve ---
    u_solution = Function(V, name="Displacement")
    solve(a == L_form, u_solution, bcs)
    
    return u_solution

# --- Calculate von Mises Stress --- (Helper function)
def get_von_mises_stress(u):
    """Returns the von Mises stress of displacement field u, projected onto P1."""
    s = sigma(u) - (1. / 3) * tr(sigma(u)) * Identity(len(u))  # Deviatoric stress
    von_mises = sqrt(3. / 2 * inner(s, s))
    # Project the cell-wise constant stress onto a continuous P1 space for plotting and export
    W = FunctionSpace(u.function_space().mesh(), 'P', 1)
    return project(von_mises, W)

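# --- Test the solver with a sample pressure ---
sample_pressure = 10e6 # 10 MPa
print("Running FEniCS solver for a sample pressure...")
u_truth = solve_pressure_vessel(mesh, Constant(sample_pressure))
print("Solver finished.")

# Plot the displacement magnitude at the mesh vertices.
# compute_vertex_values returns [all u_x, all u_y, all u_z]; reshape to (n_vertices, 3).
u_vertex = u_truth.compute_vertex_values(mesh).reshape(3, -1).T
u_magnitude = np.linalg.norm(u_vertex, axis=1)
coords = mesh.coordinates()

fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(projection='3d')
sc = ax.scatter(coords[:, 0], coords[:, 1], coords[:, 2], c=u_magnitude, cmap='viridis', s=1)
ax.set_title("Displacement Magnitude from FEniCS Solver")
fig.colorbar(sc, ax=ax, label='|u| [m]')
plt.show()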
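As a rough check on the FEniCS result for the 10 MPa sample pressure, the classical thin-wall (membrane) formulas for a closed cylinder give hoop stress $pR/t$, axial stress $pR/(2t)$, and radial displacement $R(\sigma_\theta - \nu\sigma_z)/E$. The sketch below evaluates them with $R$ = `R_inner`; since $t/R = 0.1$ is only moderately thin, agreement with the cylindrical section of the FEniCS solution (away from the cap junctions) should be expected only to within several percent, and using the mean radius instead shifts the estimate by about 5%.

```python
import math

def thin_wall_cylinder(p, R, t, E, nu):
    """Membrane estimates for a closed thin-walled cylinder under internal pressure p.

    Assumes t << R, neglects the radial stress and end-cap bending effects.
    Returns (hoop stress, axial stress, von Mises stress, radial displacement).
    """
    sigma_hoop = p * R / t
    sigma_axial = p * R / (2 * t)
    # von Mises for a plane-stress state (radial stress taken as zero)
    sigma_vm = math.sqrt(sigma_hoop**2 - sigma_hoop * sigma_axial + sigma_axial**2)
    # Radial displacement = radius * hoop strain (Hooke's law with Poisson coupling)
    delta_r = R * (sigma_hoop - nu * sigma_axial) / E
    return sigma_hoop, sigma_axial, sigma_vm, delta_r

s_h, s_a, s_vm, d_r = thin_wall_cylinder(p=10e6, R=1.0, t=0.1, E=200e9, nu=0.3)
print(f"hoop {s_h/1e6:.0f} MPa, axial {s_a/1e6:.0f} MPa, "
      f"von Mises {s_vm/1e6:.1f} MPa, radial displacement {d_r*1e3:.3f} mm")
```

These numbers (about 100 MPa hoop stress, 87 MPa von Mises, 0.43 mm radial displacement) can be compared with `get_von_mises_stress(u_truth)` and the displacement plot above.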

### Step 3: POD + NN Surrogate Model (PyTorch)

Now we build the AI surrogate. This involves an "offline" phase (data generation and training) and an "online" phase (fast prediction).
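Because the problem above is linear elasticity with a load proportional to `p`, every snapshot should be (up to round-off) `p` times a single displacement field. We would therefore expect the POD energy criterion used below to keep a single mode, with its coefficient varying linearly in `p`. The sketch checks this on a synthetic stand-in: `u_unit` and `n_dof` are hypothetical placeholders for the FEniCS unit-pressure field and dof count, and a 1-D least-squares fit replaces the NN purely as a baseline the NN should at least match.

```python
import numpy as np

rng = np.random.default_rng(0)
n_dof = 300                            # hypothetical stand-in for the FEniCS dof count
u_unit = rng.standard_normal(n_dof)    # hypothetical displacement field for p = 1 Pa
pressures = np.linspace(1e6, 25e6, 50)

# Offline: for a linear problem, each snapshot is u(p) = p * u_unit
snapshot_matrix = np.outer(u_unit, pressures)          # (n_dof, n_snapshots)
U, S, Vh = np.linalg.svd(snapshot_matrix, full_matrices=False)
energy = np.cumsum(S**2) / np.sum(S**2)
n_modes = int(np.argmax(energy >= 0.99999) + 1)        # same criterion as the notebook
pod_basis = U[:, :n_modes]
coeffs = snapshot_matrix.T @ pod_basis                 # (n_snapshots, n_modes)

# Online: the coefficient is a straight line through the origin in p,
# so a least-squares slope reconstructs an unseen pressure
slope = np.linalg.lstsq(pressures[:, None], coeffs[:, 0], rcond=None)[0][0]
p_new = 18e6
u_rec = pod_basis[:, 0] * (slope * p_new)
rel_err = np.linalg.norm(u_rec - p_new * u_unit) / np.linalg.norm(p_new * u_unit)
print(f"modes kept: {n_modes}, relative reconstruction error: {rel_err:.2e}")
```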

#### 3.1 Offline Phase: Data Generation and Training

In [None]:
# --- 1. Generate Snapshot Data ---
n_snapshots = 50
pressure_range = np.linspace(1e6, 25e6, n_snapshots) # Pressures from 1 to 25 MPa
pressure_range_normalized = pressure_range / np.max(pressure_range)

# Store the results (displacement vectors) in a "snapshot matrix"
snapshots = []
print(f"Generating {n_snapshots} snapshots...")
for p in tqdm(pressure_range):
    u_sol = solve_pressure_vessel(mesh, Constant(p))
    snapshots.append(u_sol.vector().get_local())

snapshot_matrix = np.array(snapshots).T
print(f"Snapshot matrix shape: {snapshot_matrix.shape}") # (Degrees of Freedom, Num Snapshots)

# --- 2. Perform POD (using SVD) ---
# SVD: U * S * Vh = snapshot_matrix
# U contains the POD modes (our new basis)
U, S, Vh = np.linalg.svd(snapshot_matrix, full_matrices=False)

# Choose number of modes to keep. We can look at the "energy" captured.
cumulative_energy = np.cumsum(S**2) / np.sum(S**2)
n_modes = np.argmax(cumulative_energy >= 0.99999) + 1 # Keep modes that capture 99.999% of energy

print(f"Number of POD modes to capture 99.999% energy: {n_modes}")

# Truncate the basis
pod_basis = U[:, :n_modes]
print(f"Shape of truncated POD basis: {pod_basis.shape}")

# --- 3. Project Snapshots onto POD Basis to get Coefficients ---
pod_coefficients = snapshot_matrix.T @ pod_basis
print(f"Shape of POD coefficients: {pod_coefficients.shape}")

# --- 4. Train a Neural Network with PyTorch ---
# The NN will learn the map: pressure -> pod_coefficients

# Prepare the data
X = pressure_range_normalized.reshape(-1, 1)
y = pod_coefficients

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Convert to PyTorch Tensors
X_train_t = torch.tensor(X_train, dtype=torch.float32)
y_train_t = torch.tensor(y_train, dtype=torch.float32)
X_test_t = torch.tensor(X_test, dtype=torch.float32)
y_test_t = torch.tensor(y_test, dtype=torch.float32)

# Define the PyTorch model
class SurrogateNet(nn.Module):
    def __init__(self, input_size, output_size):
        super(SurrogateNet, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_size, 32),
            nn.ReLU(),
            nn.Linear(32, 32),
            nn.ReLU(),
            nn.Linear(32, output_size)
        )
    
    def forward(self, x):
        return self.layers(x)

model = SurrogateNet(input_size=1, output_size=n_modes)
print(model)

# Define loss function and optimizer
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Training loop
epochs = 500
train_losses = []
test_losses = []

print("\nTraining the PyTorch surrogate...")
for epoch in range(epochs):
    model.train()
    
    # Forward pass
    y_pred = model(X_train_t)
    loss = loss_fn(y_pred, y_train_t)
    
    # Backward and optimize
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    # Evaluate on test data
    model.eval()
    with torch.no_grad():
        y_test_pred = model(X_test_t)
        test_loss = loss_fn(y_test_pred, y_test_t)
    
    train_losses.append(loss.item())
    test_losses.append(test_loss.item())
    
    if (epoch + 1) % 100 == 0:
        print(f'Epoch [{epoch+1}/{epochs}], Train Loss: {loss.item():.6f}, Test Loss: {test_loss.item():.6f}')
print("Training complete.")

# Plot training history
plt.figure(figsize=(10, 5))
plt.plot(train_losses, label='Training Loss')
plt.plot(test_losses, label='Test Loss')
plt.title('NN Model Training History')
plt.xlabel('Epoch')
plt.ylabel('MSE Loss')
plt.yscale('log')
plt.legend()
plt.grid(True)
plt.show()
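The network above is trained on raw POD coefficients, whose scale depends on the displacement magnitudes and the number of dofs. One optional refinement, not part of the original workflow, is to standardize the targets before training and undo the scaling on the NN output inside `predict_surrogate`. A minimal NumPy sketch follows; `CoeffScaler` is a hypothetical helper that does the same job as scikit-learn's `StandardScaler`.

```python
import numpy as np

class CoeffScaler:
    """Hypothetical helper: per-mode standardization of POD coefficients."""

    def fit(self, y):
        self.mean_ = y.mean(axis=0)
        self.std_ = y.std(axis=0)
        self.std_[self.std_ == 0] = 1.0  # guard against constant columns
        return self

    def transform(self, y):
        return (y - self.mean_) / self.std_

    def inverse_transform(self, y_scaled):
        return y_scaled * self.std_ + self.mean_

# Example with made-up coefficients for one mode
y = np.array([[0.01], [0.05], [0.12]])
scaler = CoeffScaler().fit(y)
y_s = scaler.transform(y)
print(y_s.ravel())
```

In the notebook, one would fit the scaler on `y_train` only, train on the transformed targets, and apply `inverse_transform` to the model output before multiplying by `pod_basis`.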

#### 3.2 Online Phase: Creating the Surrogate Predictor

In [None]:
def predict_surrogate(pressure_value):
    """
    Uses the trained PyTorch NN and POD basis to predict the displacement field.
    
    Args:
        pressure_value: A single pressure value.
        
    Returns:
        u_pred: A FEniCS function for the predicted displacement.
    """
    model.eval() # Set the model to evaluation mode
    
    # Normalize input pressure and convert to tensor
    p_norm = pressure_value / np.max(pressure_range)
    p_tensor = torch.tensor([[p_norm]], dtype=torch.float32)
    
    with torch.no_grad():
        # 1. Use NN to predict POD coefficients
        predicted_coeffs = model(p_tensor)
    
    # Convert back to numpy
    predicted_coeffs_np = predicted_coeffs.numpy()
    
    # 2. Reconstruct the full displacement vector from the basis
    u_vector_pred = pod_basis @ predicted_coeffs_np.T
    
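    # 3. Create a FEniCS function to hold the result (same space and dof ordering as the solver)
    V = VectorFunctionSpace(mesh, 'Lagrange', 1)
    u_pred = Function(V, name="Predicted Displacement")
    u_pred.vector().set_local(u_vector_pred.flatten())
    u_pred.vector().apply("insert")  # finalize the vector after setting local values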
    
    return u_pred

# --- Test the surrogate predictor ---
test_pressure = 18e6 # Not one of the 50 training pressures, but inside the 1-25 MPa range
print(f"Using surrogate to predict displacement for P = {test_pressure/1e6:.1f} MPa...")
u_pred = predict_surrogate(test_pressure)
print("Prediction complete.")

### Step 4: Error Analysis

Let's quantify the surrogate's accuracy by comparing its prediction to the full FEniCS solution for the same test pressure.
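The two error measures used below can be sketched in plain NumPy. The relative error is computed on the dof vectors (a Euclidean vector norm, not the function-space L2 norm; legacy FEniCS also offers `errornorm` for the latter), and the per-vertex magnitudes rely on `compute_vertex_values` returning one flat, component-major array. The tiny two-vertex example is made up to show the layout.

```python
import numpy as np

def vertex_error_magnitudes(u_true_vals, u_pred_vals, gdim=3):
    """Per-vertex error magnitudes from flat, component-major vertex arrays
    laid out as [all u_x, all u_y, all u_z]."""
    diff = (u_true_vals - u_pred_vals).reshape(gdim, -1).T  # (n_vertices, gdim)
    return np.linalg.norm(diff, axis=1)

def relative_l2_error(u_true_vec, u_pred_vec):
    """Relative error ||u_true - u_pred||_2 / ||u_true||_2 of two dof vectors."""
    return np.linalg.norm(u_true_vec - u_pred_vec) / np.linalg.norm(u_true_vec)

# Two vertices with true displacements (1, 0, 0) and (0, 2, 0);
# the prediction has a 0.1 error in u_z at the second vertex.
u_true = np.array([1.0, 0.0,  0.0, 2.0,  0.0, 0.0])  # u_x values, u_y values, u_z values
u_pred = np.array([1.0, 0.0,  0.0, 2.0,  0.0, 0.1])
print(vertex_error_magnitudes(u_true, u_pred), relative_l2_error(u_true, u_pred))
```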

In [None]:
# --- Get the ground truth solution for the test pressure ---
print("Running FEniCS solver for the test pressure to get ground truth...")
u_true_test = solve_pressure_vessel(mesh, Constant(test_pressure))
print("Done.")

# --- Calculate the error field ---
# The error is the difference between the true and predicted displacement vectors
error_vector = u_true_test.vector() - u_pred.vector()
error_func = Function(u_true_test.function_space(), name="Error Field")
error_func.vector().set_local(error_vector.get_local())
error_func.vector().apply("insert")  # finalize the vector after setting local values

# --- Calculate Global Relative Error ---
error_norm = norm(error_vector, 'l2')
true_norm = norm(u_true_test.vector(), 'l2')
relative_error = error_norm / true_norm
print(f"\nGlobal Relative L2 Error: {relative_error:.2%}")

# --- Plot a Histogram of Nodal Errors ---
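# compute_vertex_values returns [all u_x, all u_y, all u_z]; reshape to (n_vertices, 3)
error_magnitudes = np.linalg.norm(error_func.compute_vertex_values(mesh).reshape(3, -1).T, axis=1)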

plt.figure(figsize=(10, 5))
plt.hist(error_magnitudes, bins=50, color='orangered', alpha=0.7)
plt.title('Histogram of Nodal Error Magnitudes')
plt.xlabel('Magnitude of Displacement Error')
plt.ylabel('Frequency')
plt.grid(True)
plt.show()

# --- Create a 3D Scatter Plot of the Error ---
fig = plt.figure(figsize=(12, 9))
ax = fig.add_subplot(projection='3d')

coords = mesh.coordinates()
p = ax.scatter(coords[:, 0], coords[:, 1], coords[:, 2], c=error_magnitudes, cmap='viridis', s=1)
ax.set_title('3D Scatter Plot of Error Field')
fig.colorbar(p, ax=ax, label='Error Magnitude')
ax.view_init(elev=20., azim=-65)
plt.show()

### Step 5: ParaView Exports

Finally, we export the key results—ground truth, surrogate prediction, and the error field—to files that can be opened in ParaView for high-quality, interactive 3D visualization.

In [None]:
# --- Get stress for truth and prediction ---
stress_truth = get_von_mises_stress(u_true_test)
stress_truth.rename("von Mises Stress", "stress")

stress_pred = get_von_mises_stress(u_pred)
stress_pred.rename("Predicted von Mises Stress", "stress")

# --- Export Ground Truth Solution ---
print("Exporting ground truth solution to ground_truth.xdmf...")
with XDMFFile("ground_truth.xdmf") as xdmf:
    xdmf.parameters["flush_output"] = True
    xdmf.parameters["functions_share_mesh"] = True
    xdmf.write(u_true_test, 0.0)
    xdmf.write(stress_truth, 0.0)

# --- Export Surrogate Prediction ---
print("Exporting surrogate prediction to surrogate_pred.xdmf...")
with XDMFFile("surrogate_pred.xdmf") as xdmf:
    xdmf.parameters["flush_output"] = True
    xdmf.parameters["functions_share_mesh"] = True
    xdmf.write(u_pred, 0.0)
    xdmf.write(stress_pred, 0.0)

# --- Export Error Field ---
print("Exporting error field to error_field.xdmf...")
with XDMFFile("error_field.xdmf") as xdmf:
    xdmf.parameters["flush_output"] = True
    xdmf.parameters["functions_share_mesh"] = True
    xdmf.write(error_func, 0.0)

print("\nExports complete.")

### How to Use the Exported Files in ParaView

1.  Open ParaView.
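2.  Go to **File > Open...** and select `ground_truth.xdmf`, `surrogate_pred.xdmf`, and `error_field.xdmf` (you can select all three at once); each file appears as its own source in the Pipeline Browser. If ParaView asks which reader to use, choose one of the XDMF readers (for example **Xdmf3 Reader S**).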
3.  Click the **Apply** button in the Properties panel.
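4.  Select a source and use the coloring drop-down in the toolbar to switch between its fields: `Displacement` and `von Mises Stress` (ground truth), `Predicted Displacement` and `Predicted von Mises Stress` (surrogate), or `Error Field`.
5.  To view the deformed shape, add a **Warp By Vector** filter from the `Filters` menu and set its **Vectors** property to the displacement field of that source.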
6.  You can open multiple views (using the `Split` icons in the top bar) to compare the ground truth and the prediction side-by-side.