In [1]:
import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('png')
import seaborn as sns
sns.set_context("paper")
sns.set_style("ticks");

In [5]:
import jax
import jax.numpy as jnp
import equinox as eqx
import optax
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from jax import grad, vmap, random as jr
import jax.lax as lax

In [6]:
import jax
jax.config.update("jax_enable_x64", True)

In [7]:
# ---------------------------
# Custom Kaiming Init
# ---------------------------
def kaiming_init(in_size, out_size, key):
    std = jnp.sqrt(2.0 / in_size)
    return std * jr.normal(key, (out_size, in_size)), jnp.zeros(out_size)

# Swish activation
class Swish(eqx.Module):
    def __call__(self, x):
        return x * jax.nn.sigmoid(x)

class MLP(eqx.Module):
    layers: list
    activation: callable = eqx.static_field()

    def __init__(self, in_size, out_size, width, depth, key):
        keys = jr.split(key, depth + 1)
        self.layers = []
        self.activation = Swish()
        sizes = [in_size] + [width] * depth + [out_size]
        for i in range(depth + 1):
            layer = eqx.nn.Linear(sizes[i], sizes[i + 1], key=keys[i])
            weight, bias = kaiming_init(sizes[i], sizes[i + 1], keys[i])
            layer = eqx.tree_at(lambda l: (l.weight, l.bias), layer, (weight, bias))
            self.layers.append(layer)

    def __call__(self, x):
        for layer in self.layers[:-1]:
            x = self.activation(layer(x))
        return self.layers[-1](x)

In [21]:
def stenotic_geometry_grid(Nx=512, Ny=128, A=0.005, sigma=0.1, mu=0.5, r_inlet=0.05):
    x_vals = jnp.linspace(0.0, 1.0, Nx)
    grid_x, grid_y, grid_s = [], [], []
    for x in x_vals:
        R = r_inlet - A * (1. / jnp.sqrt(2 * jnp.pi * sigma**2)) * jnp.exp(-(x - mu)**2 / (2 * sigma**2))
        y_vals = jnp.linspace(-R, R, Ny)
        grid_x.extend([x] * Ny)
        grid_y.extend(y_vals)
        grid_s.extend([A] * Ny)
    return jnp.array(grid_x), jnp.array(grid_y), jnp.array(grid_s)

In [9]:
def velocity_ansatz(x, y, a, nu, net_u, net_v, p):
    sigma, mu, rInlet, dP, L, xStart, xEnd = p
    R = rInlet - a * (1. / jnp.sqrt(2 * jnp.pi * sigma**2)) * jnp.exp(-(x - mu)**2 / (2 * sigma**2))
    inp = jnp.stack([x, y, a, nu], axis=-1)
    u_hat = net_u(inp).squeeze()
    v_hat = net_v(inp).squeeze()
    phi = R**2 - y**2
    return phi * u_hat, phi * v_hat

In [10]:
def pressure_ansatz(x, y, s, nu, net_p, physics):
    sigma, mu, rInlet, dP, L, xStart, xEnd = physics
    input = jnp.stack([x, y, s, nu])
    p_out = net_p(input).squeeze()
    return dP * (xEnd - x) / L + (xStart - x) * (xEnd - x) * p_out

In [11]:
def loss_fn(model_tuple, x_batch, y_batch, s_batch, nu_batch, physics):
    net_u, net_v, net_p = model_tuple
    sigma, mu, rInlet, dP, L, xStart, xEnd = physics
    rho = 1.0

    def residual(x, y, s, nu):
        x_nd, y_nd, s_nd, nu_nd = x, y, s, nu
        inp = jnp.stack([x_nd, y_nd, s_nd, nu_nd])

        rho = 1.0
        dP = physics[3]
        sigma = physics[0]
        mu = physics[1]
        rInlet = physics[2]
        L = physics[4]
        xStart = physics[5]
        xEnd = physics[6]

        # Define full u and v functions with phi(x,y) computed *inside*
        def u_full(z):
            x_physical = z[0]
            y_nd = z[1]
            s = z[2]
            R = rInlet - s * (1. / jnp.sqrt(2 * jnp.pi * sigma**2)) * jnp.exp(-(x_physical - mu)**2 / (2 * sigma**2))
            phi = R**2 - y_nd**2
            return phi * net_u(z).squeeze()

        def v_full(z):
            x_physical = z[0]
            y_nd = z[1]
            s = z[2]
            R = rInlet - s * (1. / jnp.sqrt(2 * jnp.pi * sigma**2)) * jnp.exp(-(x_physical - mu)**2 / (2 * sigma**2))
            phi = R**2 - y_nd**2
            return phi * net_v(z).squeeze()

        # Velocity derivatives
        u_x = grad(u_full)(inp)[0]
        u_y = grad(u_full)(inp)[1]
        u_xx = grad(lambda z: grad(u_full)(z)[0])(inp)[0]
        u_yy = grad(lambda z: grad(u_full)(z)[1])(inp)[1]

        v_x = grad(v_full)(inp)[0]
        v_y = grad(v_full)(inp)[1]
        v_xx = grad(lambda z: grad(v_full)(z)[0])(inp)[0]
        v_yy = grad(lambda z: grad(v_full)(z)[1])(inp)[1]

        # Pressure field (includes physical pressure drop + residual)
        def p_full(z):
            x_physical = z[0]
            return dP * (xEnd - x_physical) / L + (xStart - x_physical) * (xEnd - x_physical) * net_p(z).squeeze()

        p_x = grad(p_full)(inp)[0]
        p_y = grad(p_full)(inp)[1]

        # Compute field values (needed for convective terms)
        u_val = u_full(inp)
        v_val = v_full(inp)

        # Residuals of steady-state Navier-Stokes
        res1 = u_val * u_x + v_val * u_y - nu * (u_xx + u_yy) + (1 / rho) * p_x
        res2 = u_val * v_x + v_val * v_y - nu * (v_xx + v_yy) + (1 / rho) * p_y
        res3 = u_x + v_y  # incompressibility

        return res1**2, res2**2, res3**2

    r1, r2, r3 = vmap(residual)(x_batch, y_batch, s_batch, nu_batch)
    return jnp.mean(r1) + jnp.mean(r2) + jnp.mean(r3)

In [12]:
def train_pinn(model_tuple, key, optimizer, filter_spec, physics, num_iter=1000, freq=100, batch_size=4096):
    def new_loss(diff_model, static_model, x, y, s, nu):
        model_comb = eqx.combine(diff_model, static_model)
        return loss_fn(model_comb, x, y, s, nu, physics)

    @eqx.filter_jit
    def step(opt_state, model, x, y, s, nu):
        diff_model, static_model = eqx.partition(model, filter_spec)
        loss, grads = eqx.filter_value_and_grad(new_loss)(diff_model, static_model, x, y, s, nu)
        updates, opt_state = optimizer.update(grads, opt_state)
        model = eqx.apply_updates(model, updates)
        return model, opt_state, loss

    sigma, mu, rInlet, dP, L, xStart, xEnd = physics
    fixed_nu = 1e-3
    A_vals = jnp.linspace(0.0, 1e-2, 20)

    # Build flat dataset
    x_list, y_list, s_list = [], [], []
    for A in A_vals:
        X, Y, S = stenotic_geometry_grid(Nx=256, Ny=64, A=A, sigma=sigma, mu=mu, r_inlet=rInlet)
        x_list.append(X)
        y_list.append(Y)
        s_list.append(S)

    x_all = jnp.concatenate(x_list)        # shape (20*N,)
    y_all = jnp.concatenate(y_list)
    s_all = jnp.concatenate(s_list)
    nu_all = jnp.ones_like(s_all) * fixed_nu

    dataset = jnp.stack([x_all, y_all, s_all, nu_all], axis=1)  # shape (20*N, 4)
    num_points = dataset.shape[0]
    num_batches = num_points // batch_size

    opt_state = optimizer.init(eqx.filter(model_tuple, eqx.is_inexact_array))
    losses = []
    perm_key = key

    for step_idx in range(num_iter):
        perm_key, subkey = jr.split(perm_key)
        perm = jr.permutation(subkey, num_points)
        shuffled = dataset[perm]

        for batch_idx in range(num_batches):
            batch = shuffled[batch_idx * batch_size:(batch_idx + 1) * batch_size]
            x_batch, y_batch, s_batch, nu_batch = batch[:, 0], batch[:, 1], batch[:, 2], batch[:, 3]

            model_tuple, opt_state, loss_val = step(opt_state, model_tuple, x_batch, y_batch, s_batch, nu_batch)

        if step_idx % freq == 0:
            print(f"Step {step_idx}, Loss = {loss_val:.5e}")
            losses.append(loss_val)

    return model_tuple, losses

In [22]:
import os
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from mpl_toolkits.axes_grid1 import make_axes_locatable
import numpy as np
from jax import vmap
import jax.numpy as jnp

def plot_stenotic_fig5_dnn(model_tuple, physics, A_values, Nx=512, Ny=128, save_dir='/content/drive/MyDrive/stenotic_Vary_A_outputs'):
    net_u, net_v, net_p = model_tuple
    sigma, mu, rInlet, dP, L, xStart, xEnd = physics
    nu_fixed = 1e-3  # Fixed viscosity

    # Create directory if it doesn't exist
    os.makedirs(save_dir, exist_ok=True)

    x_plot = jnp.linspace(xStart, xEnd, Nx)

    # Expand output view slightly beyond original axis limits
    x_margin = 0.03 * (xEnd - xStart)
    y_margin = 0.2 * rInlet
    xlim = (xStart - x_margin, xEnd + x_margin)
    ylim = (-rInlet - y_margin, rInlet + y_margin)

    for idx, A_val in enumerate(A_values):
        # Geometry-aware grid at this A
        Xg, Yg, S = stenotic_geometry_grid(Nx=Nx, Ny=Ny, A=A_val, sigma=sigma, mu=mu, r_inlet=rInlet)
        Nu = jnp.ones_like(Xg) * nu_fixed

        u_pred = vmap(lambda x_, y_, s_, nu_: velocity_ansatz(x_, y_, s_, nu_, net_u, net_v, physics)[0])(Xg, Yg, S, Nu)
        v_pred = vmap(lambda x_, y_, s_, nu_: velocity_ansatz(x_, y_, s_, nu_, net_u, net_v, physics)[1])(Xg, Yg, S, Nu)
        p_center = vmap(lambda x_: pressure_ansatz(x_, 0.0, A_val, nu_fixed, net_p, physics))(x_plot)

        # --- Plot u ---
        fig_u, ax_u = plt.subplots(figsize=(8, 3))
        divider_u = make_axes_locatable(ax_u)
        cax_u = divider_u.append_axes("right", size="2.5%", pad=0.05)
        sc_u = ax_u.scatter(np.array(Xg), np.array(Yg), c=np.array(u_pred), cmap='viridis', s=1)
        ax_u.set_xlim(*xlim)
        ax_u.set_ylim(*ylim)
        ax_u.set_title(f"$u(x,y)$, $A$={A_val:.2e}")
        ax_u.set_xlabel("x")
        ax_u.set_ylabel("y")
        fig_u.colorbar(sc_u, cax=cax_u)
        fig_u.tight_layout()
        fig_u.savefig(f"{save_dir}/u_A_{A_val:.2e}.png")
        plt.close(fig_u)

        # --- Plot v ---
        fig_v, ax_v = plt.subplots(figsize=(8, 3))
        divider_v = make_axes_locatable(ax_v)
        cax_v = divider_v.append_axes("right", size="2.5%", pad=0.05)
        sc_v = ax_v.scatter(np.array(Xg), np.array(Yg), c=np.array(v_pred), cmap='viridis', s=1)
        ax_v.set_xlim(*xlim)
        ax_v.set_ylim(*ylim)
        ax_v.set_title(f"$v(x,y)$, $A$={A_val:.2e}")
        ax_v.set_xlabel("x")
        ax_v.set_ylabel("y")
        fig_v.colorbar(sc_v, cax=cax_v)
        fig_v.tight_layout()
        fig_v.savefig(f"{save_dir}/v_A_{A_val:.2e}.png")
        plt.close(fig_v)

        # --- Plot pressure centerline ---
        fig_p, ax_p = plt.subplots(figsize=(8, 6))
        ax_p.plot(np.array(x_plot), np.array(p_center), 'r--', label='DNN')
        ax_p.set_xlabel("x")
        ax_p.set_ylabel("$p_c$")
        ax_p.set_title(f"$p_c(x)$, $A$={A_val:.2e}")
        ax_p.legend()
        fig_p.tight_layout()
        fig_p.savefig(f"{save_dir}/pcenter_A_{A_val:.2e}.png")
        plt.close(fig_p)

In [23]:
def initialize_models(key):
    keys = jr.split(key, 3)
    net_u = MLP(4, 1, 32, 3, keys[0])  # <== fix input size to 4
    net_v = MLP(4, 1, 32, 3, keys[1])
    net_p = MLP(4, 1, 32, 3, keys[2])
    return (net_u, net_v, net_p)

In [None]:
# Check GPU
!nvidia-smi

# Check memory
from psutil import virtual_memory
ram_gb = virtual_memory().total / 1e9
print(f"Available RAM: {ram_gb:.1f} GB")

Wed May  7 18:15:43 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          Off |   00000000:00:04.0 Off |                    0 |
| N/A   33C    P0             50W /  400W |       0MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                

In [None]:
# ---------------------------
# Run Full Workflow
# ---------------------------
key = jr.PRNGKey(0)
model = initialize_models(key)
physics = (0.1, 0.5, 0.05, 0.1, 1.0, 0.0, 1.0)
filter_spec = jax.tree_util.tree_map(lambda _: True, model)
optimizer = optax.adam(1e-3)

model_trained, _ = train_pinn(
    model, key, optimizer, filter_spec, physics,
    num_iter=2000, freq=100
)

Step 0, Loss = 8.77046e-03
Step 100, Loss = 5.08949e-03
Step 200, Loss = 1.80369e-03
Step 300, Loss = 1.18246e-03
Step 400, Loss = 9.35826e-04
Step 500, Loss = 8.15926e-04
Step 600, Loss = 7.24500e-04
Step 700, Loss = 6.09859e-04
Step 800, Loss = 4.95492e-04
Step 900, Loss = 5.00034e-04
Step 1000, Loss = 4.35050e-04
Step 1100, Loss = 3.89954e-04
Step 1200, Loss = 3.84355e-04
Step 1300, Loss = 3.80076e-04
Step 1400, Loss = 3.38069e-04
Step 1500, Loss = 3.41262e-04
Step 1600, Loss = 3.51822e-04
Step 1700, Loss = 3.09667e-04
Step 1800, Loss = 2.94005e-04
Step 1900, Loss = 2.89358e-04


FileNotFoundError: [Errno 2] No such file or directory: '/content/drive/MyDrive/model_trained_vary_A.eqx'

In [15]:
# Path to save model
save_path = "/content/drive/MyDrive/model_trained_vary_A.eqx"

# Save model
eqx.tree_serialise_leaves(save_path, model_trained)

NameError: name 'model_trained' is not defined

In [16]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [24]:
# Recreate the model structure (same as during training)
key = jr.PRNGKey(0)
model_template = initialize_models(key)

# Load the saved weights
load_path = "/content/drive/MyDrive/model_trained_vary_A.eqx"
model_loaded = eqx.tree_deserialise_leaves(load_path, model_template)
print("Model loaded from Drive.")

Model loaded from Drive.


In [25]:
physics = (0.1, 0.5, 0.05, 0.1, 1.0, 0.0, 1.0)
plot_stenotic_fig5_dnn(model_loaded, physics, A_values=[2e-3, 4e-3, 7e-3])

In [26]:
import matplotlib.pyplot as plt
import seaborn as sns
from jax import random, vmap

def plot_geometry_uncertainty(model_tuple, physics, num_samples=500, A_mean=5e-3, A_std=1e-3, seed=0, save_path='/content/drive/MyDrive/stenotic_outputs/A_uncertainty.png'):
    net_u, net_v, net_p = model_tuple
    sigma, mu, rInlet, *_ = physics
    nu_fixed = 1e-3
    x_center, y_center = 0.5, 0.0

    key = random.PRNGKey(seed + 1)
    As = random.normal(key, shape=(num_samples,)) * A_std + A_mean

    # Compute centerline velocity at (x=0.5, y=0.0)
    u_center = vmap(lambda a_: velocity_ansatz(x_center, y_center, a_, nu_fixed, net_u, net_v, physics)[0])(As)

    # --- Plotting ---
    plt.figure(figsize=(6, 4))
    sns.kdeplot(np.array(u_center), label="DNN", color='crimson', linestyle='--', linewidth=2)
    plt.xlabel(r"$u_c$")
    plt.ylabel("PDF")
    plt.title("Geometry Uncertainty Propagation")
    plt.legend()
    #plt.grid(True)
    plt.tight_layout()
    plt.savefig(save_path, dpi=300)
    plt.close()


In [27]:
plot_geometry_uncertainty(model_loaded, physics, num_samples=500, A_mean=5e-3, A_std=1e-3, seed=0)