In [1]:
# Plotting setup: inline figures rendered as PNG, seaborn "paper" context and "ticks" style.
import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib_inline
# Render inline figures (and any Figure object displayed by a cell) as PNG.
matplotlib_inline.backend_inline.set_matplotlib_formats('png')
import seaborn as sns
sns.set_context("paper")
# Trailing ';' suppresses the cell's output.
sns.set_style("ticks");

In [4]:
import jax
import jax.numpy as jnp
import equinox as eqx
import optax
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from jax import grad, vmap, random as jr
import jax.lax as lax

In [5]:
def kaiming_init(in_size, out_size, key):
    """He (Kaiming) normal initialisation for one dense layer.

    Args:
        in_size: fan-in of the layer.
        out_size: number of output units.
        key: JAX PRNG key for the weight draw.

    Returns:
        ``(weight, bias)`` where ``weight`` has shape ``(out_size, in_size)`` and is
        drawn from N(0, 2 / in_size), and ``bias`` is a zero vector of length ``out_size``.
    """
    scale = jnp.sqrt(2.0 / in_size)
    weight = scale * jr.normal(key, (out_size, in_size))
    bias = jnp.zeros(out_size)
    return weight, bias

class Swish(eqx.Module):
    """Parameter-free Swish / SiLU activation, ``x * sigmoid(x)``, applied element-wise."""

    def __call__(self, x):
        gate = jax.nn.sigmoid(x)
        return x * gate

class MLP(eqx.Module):
    """Fully connected network: ``depth`` hidden layers of size ``width``, Swish activations.

    Layer sizes run ``in_size -> width (x depth) -> out_size``. Every
    ``eqx.nn.Linear`` has its default parameters overwritten with a Kaiming-normal
    weight and a zero bias from ``kaiming_init``. The output layer is linear
    (no activation).
    """

    layers: list
    activation: callable = eqx.static_field()

    def __init__(self, in_size, out_size, width, depth, key):
        layer_keys = jr.split(key, depth + 1)
        dims = [in_size, *([width] * depth), out_size]
        self.activation = Swish()

        built = []
        for n_in, n_out, layer_key in zip(dims[:-1], dims[1:], layer_keys):
            linear = eqx.nn.Linear(n_in, n_out, key=layer_key)
            # Replace Equinox's default init with He-normal weights and zero bias
            # (the same per-layer key is used for both draws).
            new_params = kaiming_init(n_in, n_out, layer_key)
            linear = eqx.tree_at(lambda m: (m.weight, m.bias), linear, new_params)
            built.append(linear)
        self.layers = built

    def __call__(self, x):
        *hidden, head = self.layers
        for linear in hidden:
            x = self.activation(linear(x))
        return head(x)

In [10]:
# ---------------------------
# Hard-encoded Boundary Ansatz
# ---------------------------
def velocity_ansatz(x, y, nu, net_u, net_v, physics):
    """Velocity field with the no-slip walls hard-encoded.

    Both networks see the stacked input ``[x, y, nu]``. Their scalar outputs are
    multiplied by ``d**2 / 4 - y**2`` (channel diameter ``d = 0.1``). This factor
    vanishes at ``y = ±d/2``, so ``u = v = 0`` on the walls for any weights.
    ``physics`` is accepted for signature compatibility but is not used.

    Returns:
        ``(u, v)`` at the point.
    """
    d = 0.1  # tube diameter
    features = jnp.stack([x, y, nu])
    wall_factor = (d**2 / 4) - y**2
    u_raw = net_u(features).squeeze()
    v_raw = net_v(features).squeeze()
    return wall_factor * u_raw, wall_factor * v_raw

In [9]:
def pressure_ansatz(x, net_p, physics):
    """Pressure along x with the inlet/outlet values hard-encoded.

    ``physics`` is unpacked as ``(_, _, _, dP, L, xStart, xEnd)``. The linear
    part ``dP * (xEnd - x) / L`` reaches 0 at the outlet. The network correction
    (input ``[x]``) is scaled by ``(xStart - x) * (xEnd - x)``. That factor
    vanishes at both ends, so the boundary pressures are fixed by the linear
    term alone.
    """
    _, _, _, dP, L, x_in, x_out = physics
    correction = net_p(jnp.stack([x])).squeeze()
    bubble = (x_in - x) * (x_out - x)
    linear_drop = dP * (x_out - x) / L
    return linear_drop + bubble * correction

In [8]:
# ---------------------------
# Loss Function
# ---------------------------
def loss_fn(model_tuple, x_batch, y_batch, nu_batch, physics):
    """Mean squared steady incompressible Navier-Stokes residual over a batch.

    For every collocation point ``(x, y, nu)`` the residual is
    ``div**2 + mom_x**2 + mom_y**2`` with

        div   = u_x + v_y
        mom_x = u u_x + v u_y + p_x / rho - nu (u_xx + u_yy)
        mom_y = u v_x + v v_y             - nu (v_xx + v_yy)

    and ``rho = 1``. ``mom_y`` has no pressure term because ``pressure_ansatz``
    here takes only ``x``, so ``p_y == 0``.

    Args:
        model_tuple: ``(net_u, net_v, net_p)``.
        x_batch, y_batch, nu_batch: 1-D arrays of equal length (mapped with ``vmap``).
        physics: forwarded unchanged to ``velocity_ansatz`` / ``pressure_ansatz``.

    Returns:
        Scalar mean of the per-point residuals.
    """
    net_u, net_v, net_p = model_tuple
    rho = 1.0

    def residual(x, y, nu):
        u_fn = lambda z: velocity_ansatz(z[0], z[1], z[2], net_u, net_v, physics)[0]
        v_fn = lambda z: velocity_ansatz(z[0], z[1], z[2], net_u, net_v, physics)[1]
        p_fn = lambda z: pressure_ansatz(z[0], net_p, physics)

        # Differentiate w.r.t. the stacked point z = [x, y, nu]; only the x and y
        # components of the derivatives enter the equations.
        point = jnp.stack([x, y, nu])
        u, v = u_fn(point), v_fn(point)

        u_grad = grad(u_fn)(point)
        v_grad = grad(v_fn)(point)
        # Pressure depends on x alone, so it is differentiated on a length-1 input.
        p_x = grad(p_fn)(jnp.array([x]))[0]

        u_x, u_y = u_grad[0], u_grad[1]
        v_x, v_y = v_grad[0], v_grad[1]

        def second(f, i):
            # Pure second derivative d^2 f / dz_i^2 at `point` (nested reverse mode).
            return grad(lambda z: grad(f)(z)[i])(point)[i]

        u_xx, u_yy = second(u_fn, 0), second(u_fn, 1)
        v_xx, v_yy = second(v_fn, 0), second(v_fn, 1)

        div = u_x + v_y
        mom_x = u * u_x + v * u_y + (1 / rho) * p_x - nu * (u_xx + u_yy)
        mom_y = u * v_x + v * v_y - nu * (v_xx + v_yy)

        return div**2 + mom_x**2 + mom_y**2

    residuals = vmap(residual)(x_batch, y_batch, nu_batch)
    return jnp.mean(residuals)

In [11]:
# ---------------------------
# Training Loop
# ---------------------------
def train_pinn(model_tuple, key, optimizer, filter_spec, physics, num_iter=1000, freq=100, batch_size=4096):
    """Train the viscosity-parametric PINN ``(net_u, net_v, net_p)``.

    The collocation set is a fixed 256 x 256 grid on x in [0, 1], y in [-0.05, 0.05].
    Each optimisation step takes one mini-batch of grid points and averages
    ``loss_fn`` over 100 viscosities evenly spaced in [2e-4, 2e-3]. Every point in
    the batch is paired with every viscosity. Each outer iteration reshuffles the
    grid and sweeps all full mini-batches; a trailing partial batch is dropped.

    Args:
        model_tuple: ``(net_u, net_v, net_p)`` Equinox models.
        key: PRNG key driving the per-iteration shuffles.
        optimizer: Optax gradient transformation.
        filter_spec: bool pytree for ``eqx.partition``; ``True`` leaves are trained.
        physics: forwarded unchanged to ``loss_fn``.
        num_iter: number of passes over the grid.
        freq: print and record the loss every ``freq`` passes.
        batch_size: collocation points per step; must be in ``[1, 65536]``.

    Returns:
        ``(trained model_tuple, losses)``. ``losses`` holds the last mini-batch
        loss of every ``freq``-th pass.

    Raises:
        ValueError: if ``batch_size`` is non-positive or larger than the grid.
            Previously this ended in a ZeroDivisionError, or in a NameError on
            the unbound ``loss_val``.
    """
    def new_loss(diff_model, static_model, x, y, nu):
        model_comb = eqx.combine(diff_model, static_model)
        return loss_fn(model_comb, x, y, nu, physics)

    def vmapped_loss(diff_model, static_model, x, y, nu_vec):
        # Mean over viscosities of the batch-mean residual at that viscosity.
        def loss_for_nu(nu_val):
            return new_loss(diff_model, static_model, x, y, jnp.full_like(x, nu_val))
        return jnp.mean(jax.vmap(loss_for_nu)(nu_vec))

    @eqx.filter_jit
    def step(opt_state, model, x, y, nu_vec):
        diff_model, static_model = eqx.partition(model, filter_spec)
        loss, grads = eqx.filter_value_and_grad(vmapped_loss)(diff_model, static_model, x, y, nu_vec)
        updates, opt_state = optimizer.update(grads, opt_state)
        model = eqx.apply_updates(model, updates)
        return model, opt_state, loss

    # Fixed geometry grid
    x_vals = jnp.linspace(0.0, 1.0, 256)
    y_vals = jnp.linspace(-0.05, 0.05, 256)
    X, Y = jnp.meshgrid(x_vals, y_vals)
    X = X.flatten()
    Y = Y.flatten()

    dataset = jnp.stack([X, Y], axis=1)
    num_points = dataset.shape[0]
    if not 0 < batch_size <= num_points:
        raise ValueError(
            f"batch_size must be in [1, {num_points}] (got {batch_size}); "
            "otherwise no optimisation step would run"
        )
    num_batches = num_points // batch_size

    # Collocation points in viscosity space
    nu_vec = jnp.linspace(2.0e-4, 2.0e-3, 100)

    opt_state = optimizer.init(eqx.filter(model_tuple, eqx.is_inexact_array))
    losses = []
    perm_key = key

    for step_idx in range(num_iter):
        perm_key, subkey = jr.split(perm_key)
        perm = jr.permutation(subkey, num_points)
        shuffled = dataset[perm]

        for batch_idx in range(num_batches):
            batch = shuffled[batch_idx * batch_size: (batch_idx + 1) * batch_size]
            x_batch, y_batch = batch[:, 0], batch[:, 1]
            model_tuple, opt_state, loss_val = step(opt_state, model_tuple, x_batch, y_batch, nu_vec)

        if step_idx % freq == 0:
            print(f"Step {step_idx}, Loss = {loss_val:.5e}")
            losses.append(loss_val)

    return model_tuple, losses

In [7]:
def mc_pinn_uc_distribution(model_tuple, physics, num_samples=1000, save_path='/content/drive/MyDrive/MCPDF.png', random_state=None):
    """Monte-Carlo comparison of PINN and analytical centreline-velocity PDFs.

    Viscosities are drawn from N(1e-3, 2.67e-4^2) truncated to nu > 0. The PINN
    streamwise velocity is evaluated at (x, y) = (0.5, 0) for each draw. Its
    kernel-density estimate is plotted against that of the analytical centreline
    value u_c = dP d^2 / (8 nu L). The figure is saved and then shown.

    NOTE(review): assumes the viscosity-parametric ``velocity_ansatz(x, y, nu, ...)``
    is the active definition when this is called.

    Args:
        model_tuple: ``(net_u, net_v, net_p)``; only net_u / net_v are used.
        physics: tuple with the pressure drop dP at index 3 and length L at
            index 4 (also forwarded to ``velocity_ansatz``). These were previously
            hard-coded as 0.1 and 1.0.
        num_samples: number of viscosity samples.
        save_path: output path of the figure (300 dpi).
        random_state: optional seed or NumPy Generator passed to
            ``truncnorm.rvs`` for reproducible sampling. ``None`` keeps the
            previous unseeded behaviour.

    Returns:
        ``(uc_vals, uc_analytic, nu_samples)``.
    """
    import seaborn as sns
    import matplotlib.pyplot as plt
    from scipy.stats import truncnorm

    net_u, net_v, _ = model_tuple
    d = 0.1  # channel diameter; must match the value hard-coded in velocity_ansatz
    dP, L = physics[3], physics[4]

    # Truncated Normal Sampling for ν (standardised lower bound corresponds to nu = 0)
    mean_nu = 1e-3
    std_nu = 2.67e-4
    a, b = (0 - mean_nu) / std_nu, jnp.inf
    nu_samples = truncnorm.rvs(a, b, loc=mean_nu, scale=std_nu, size=num_samples,
                               random_state=random_state)
    nu_samples = jnp.array(nu_samples)

    # Predict using vmap at the channel centreline x = 0.5, y = 0
    x = jnp.full((num_samples,), 0.5)
    y = jnp.full((num_samples,), 0.0)

    uc_vals = vmap(lambda x_, y_, nu_: velocity_ansatz(x_, y_, nu_, net_u, net_v, physics)[0])(x, y, nu_samples)

    # Analytical Solution: u_c = dP * (d/2)^2 / (2 * nu * L)
    uc_analytic = (dP * d**2) / (2 * nu_samples * L * 4)

    # Plot
    fig, ax = plt.subplots(figsize=(5, 4))
    sns.kdeplot(uc_analytic, color='blue', label='Analytical', fill=True, alpha=0.3, ax=ax)
    sns.kdeplot(uc_vals, color='red', linestyle='--', label='DNN', linewidth=2, ax=ax)

    ax.set_title("PDF of Centerline Velocity $u_c$")
    ax.set_xlabel("$u_c$")
    ax.set_ylabel("PDF")
    ax.legend()
    fig.tight_layout()
    fig.savefig(save_path, dpi=300)
    # Show before closing. Closing first (as before) meant the figure was never displayed.
    plt.show()
    plt.close(fig)

    return uc_vals, uc_analytic, nu_samples

In [12]:
def evaluate_pipe_profiles(model_tuple, physics, nu_vals):
    """Plot PINN vs analytical cross-section profiles u(y) at x = 0.5 for several nu.

    Uses the viscosity-parametric ``velocity_ansatz(x, y, nu, ...)``, with one
    curve pair per entry of ``nu_vals``. The reference is the plane-Poiseuille
    profile u(y) = dP / (2 nu L) * (d^2 / 4 - y^2). dP and L are now read from
    ``physics``; previously they were hard-coded as 0.1 and 1.0, silently
    ignoring ``physics``.

    Args:
        model_tuple: ``(net_u, net_v, net_p)``; only net_u / net_v are used.
        physics: tuple with dP at index 3 and L at index 4; also forwarded to
            ``velocity_ansatz``.
        nu_vals: iterable of viscosities to plot.
    """
    net_u, net_v, _ = model_tuple
    d = 0.1  # channel diameter; matches the value hard-coded in velocity_ansatz
    dP, L = physics[3], physics[4]
    y = jnp.linspace(-0.05, 0.05, 256)
    x = jnp.full_like(y, 0.5)

    fig, ax = plt.subplots(figsize=(8, 5))
    for nu_val in nu_vals:
        nu = jnp.full_like(y, nu_val)
        u_pred = vmap(lambda x_, y_, nu_: velocity_ansatz(x_, y_, nu_, net_u, net_v, physics)[0])(x, y, nu)
        u_analytical = dP / (2 * nu_val * L) * ((d**2) / 4 - y**2)
        ax.plot(y, u_pred, 'r--', label=f'DNN ν={nu_val:.5f}')
        ax.plot(y, u_analytical, 'b-', label=f'Analytical ν={nu_val:.5f}')

    ax.set_xlabel('y')
    ax.set_ylabel('u(y)')
    ax.set_title('Cross-section velocity profiles (single model)')
    ax.legend()
    ax.grid(True)
    fig.tight_layout()
    plt.show()

In [13]:
def initialize_models(key):
    """Build fresh ``(net_u, net_v, net_p)`` MLPs: width 32, 3 hidden layers each.

    net_u and net_v take the 3 inputs ``[x, y, nu]``; net_p takes only ``[x]``.
    Each network gets its own sub-key split from ``key``.
    """
    k_u, k_v, k_p = jr.split(key, 3)
    hidden_width, hidden_depth = 32, 3
    return (
        MLP(3, 1, hidden_width, hidden_depth, k_u),
        MLP(3, 1, hidden_width, hidden_depth, k_v),
        MLP(1, 1, hidden_width, hidden_depth, k_p),
    )

In [11]:
# Train the viscosity-parametric PINN and save its weights to Google Drive.
key = jr.PRNGKey(0)
model = initialize_models(key)
# Mark every leaf of the model pytree as trainable for eqx.partition.
filter_spec = jax.tree_util.tree_map(lambda _: True, model)
optimizer = optax.adam(1e-3)

# Circular pipe physics
# Layout read by pressure_ansatz: (_, _, _, dP, L, xStart, xEnd); the first three slots
# are unused placeholders here.
# NOTE(review): the ansatz and the analytical checks are 2-D (x, y) plane-channel forms,
# so "circular" may be a misnomer — confirm.
physics = (None, None, None, 0.1, 1.0, 0.0, 1.0)

# NOTE(review): the same `key` seeds model initialisation and the shuffling inside
# train_pinn; consider splitting it into independent sub-keys.
model_trained, losses = train_pinn(
    model, key, optimizer, filter_spec, physics,
    num_iter=9000, freq=1000
)

# Path to save model
save_path = "/content/drive/MyDrive/model_trainedPos.eqx"

# Save model (leaves only; reloading needs a template with the same architecture)
eqx.tree_serialise_leaves(save_path, model_trained)

Step 0, Loss = 1.41596e-02
Step 1000, Loss = 3.23707e-05
Step 2000, Loss = 3.81011e-06
Step 3000, Loss = 2.79203e-06
Step 4000, Loss = 1.07705e-05
Step 5000, Loss = 4.96962e-06
Step 6000, Loss = 4.14946e-06
Step 7000, Loss = 1.39188e-06
Step 8000, Loss = 1.09920e-05


In [14]:
# Colab only: mount Google Drive so the /content/drive/MyDrive/... paths used for saving
# and loading in this notebook are available.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [15]:
# Recreate the model structure
# The template only supplies the pytree structure and shapes; its weights are replaced by the
# deserialised leaves.
# NOTE(review): initialize_models is redefined later in this notebook (width 20, 3-input net_p).
# This cell must run against the width-32 / 1-input-pressure version that produced the file,
# otherwise loading fails.
key = jr.PRNGKey(0)
model_template = initialize_models(key)

# Load the saved weights
load_path = "/content/drive/MyDrive/model_trainedPos.eqx"
model_loaded = eqx.tree_deserialise_leaves(load_path, model_template)
print("Model loaded from Drive.")

Model loaded from Drive.


In [17]:
# Monte-Carlo PDF of the centreline velocity for the loaded model (figure saved to Drive).
# NOTE(review): requires the (x, y, nu) velocity_ansatz and the 7-tuple `physics` from the
# first part of the notebook to be the active definitions.
uc_vals, uc_analytic, nu_samples = mc_pinn_uc_distribution(model_loaded, physics)

In [6]:
import jax.lax as lax

def R_fn(x, A, mu, sigma, r0=0.05, atol=1e-8):
    """Local wall radius of a pipe with an optional Gaussian stenosis.

    R(x) = r0 - A * N(x; mu, sigma), where N is the normal probability density.
    Wherever ``|A| < atol`` the geometry is treated as a straight pipe and exactly
    ``r0`` is returned.

    ``jnp.where`` (element-wise, broadcasting) replaces ``lax.cond``. ``lax.cond``
    needs a scalar predicate and identically shaped branch outputs, so it failed for
    array ``x`` / ``A`` outside ``vmap``. Both branches were already evaluated
    eagerly, so nothing is lost. Under ``vmap`` (how it is called in this notebook)
    the result is the same element-wise select as before.

    Args:
        x: axial position(s).
        A: stenosis amplitude(s); 0 gives a straight pipe.
        mu, sigma: centre and width of the Gaussian constriction (sigma > 0).
        r0: unconstricted radius.
        atol: amplitudes smaller than this in magnitude are treated as 0.

    Returns:
        Radius array broadcast from ``x`` and ``A``.
    """
    bump = A * jnp.exp(-((x - mu) ** 2) / (2 * sigma**2)) / (jnp.sqrt(2 * jnp.pi) * sigma)
    return jnp.where(jnp.abs(A) < atol, r0, r0 - bump)

In [8]:
def velocity_ansatz(x, y, s, net_u, net_v, physics):
    """Velocity with no-slip enforced on the (possibly stenosed) wall.

    ``physics`` is unpacked as
    ``(sigma, mu, rInlet, dP, L, xStart, xEnd, geometry_type, nu)``. Only sigma, mu
    and rInlet are used, to get the local radius ``R = R_fn(x, s, mu, sigma, rInlet)``.
    The network outputs at ``[x, y, s]`` are scaled by ``R**2 - y**2``, so
    ``u = v = 0`` at ``y = ±R``.

    Returns:
        ``(u, v)`` at the point.
    """
    sigma, mu, rInlet, _dP, _L, _x_start, _x_end, _geometry, _nu = physics
    radius = R_fn(x, s, mu, sigma, rInlet)
    wall_factor = radius**2 - y**2
    features = jnp.stack([x, y, s])
    u = wall_factor * net_u(features).squeeze()
    v = wall_factor * net_v(features).squeeze()
    return u, v

In [9]:
def pressure_ansatz(x, y, s, net_p, physics):
    """Pressure with the inlet/outlet values hard-encoded (stenosis variant).

    ``physics`` is unpacked as
    ``(sigma, mu, rInlet, dP, L, xStart, xEnd, geometry_type, nu)``. The linear part
    ``dP * (xEnd - x) / L`` vanishes at the outlet. The network correction (input
    ``[x, y, s]``) is scaled by ``(xStart - x) * (xEnd - x)``, which vanishes at
    both ends.
    """
    _sigma, _mu, _r_in, dP, L, x_in, x_out, _geometry, _nu = physics
    correction = net_p(jnp.stack([x, y, s])).squeeze()
    bubble = (x_in - x) * (x_out - x)
    linear_drop = dP * (x_out - x) / L
    return linear_drop + bubble * correction

In [10]:
def loss_fn(model_tuple, x_batch, y_batch, s_batch, physics):
    """Mean squared steady incompressible Navier-Stokes residual (fixed viscosity).

    The viscosity is the scalar ``physics[-1]``, and the density is implicitly 1.
    For every collocation point ``(x, y, s)`` (``s`` = stenosis amplitude) the
    residual is ``div**2 + mom_x**2 + mom_y**2`` with

        div   = u_x + v_y
        mom_x = u u_x + v u_y + p_x - nu (u_xx + u_yy)
        mom_y = u v_x + v v_y + p_y - nu (v_xx + v_yy)

    Args:
        model_tuple: ``(net_u, net_v, net_p)``.
        x_batch, y_batch, s_batch: 1-D arrays of equal length (mapped with ``vmap``).
        physics: 9-tuple whose last entry is nu; forwarded to the ansatz functions.

    Returns:
        Scalar mean of the per-point residuals.
    """
    net_u, net_v, net_p = model_tuple
    nu = physics[-1]  # fixed scalar

    def residual(x, y, s):
        u_fn = lambda z: velocity_ansatz(z[0], z[1], z[2], net_u, net_v, physics)[0]
        v_fn = lambda z: velocity_ansatz(z[0], z[1], z[2], net_u, net_v, physics)[1]
        p_fn = lambda z: pressure_ansatz(z[0], z[1], z[2], net_p, physics)

        # Differentiate w.r.t. the stacked point z = [x, y, s]; only the x and y
        # components of the derivatives enter the equations.
        point = jnp.stack([x, y, s])
        u, v = u_fn(point), v_fn(point)

        u_grad = grad(u_fn)(point)
        v_grad = grad(v_fn)(point)
        p_grad = grad(p_fn)(point)

        u_x, u_y = u_grad[0], u_grad[1]
        v_x, v_y = v_grad[0], v_grad[1]
        p_x, p_y = p_grad[0], p_grad[1]

        def second(f, i):
            # Pure second derivative d^2 f / dz_i^2 at `point` (nested reverse mode).
            return grad(lambda z: grad(f)(z)[i])(point)[i]

        u_xx, u_yy = second(u_fn, 0), second(u_fn, 1)
        v_xx, v_yy = second(v_fn, 0), second(v_fn, 1)

        div = u_x + v_y
        conv_x = u * u_x + v * u_y
        conv_y = u * v_x + v * v_y
        mom_x = conv_x + p_x - nu * (u_xx + u_yy)
        mom_y = conv_y + p_y - nu * (v_xx + v_yy)

        return div**2 + mom_x**2 + mom_y**2

    residuals = vmap(residual)(x_batch, y_batch, s_batch)
    return jnp.mean(residuals)

In [11]:
def train_pinn(model_tuple, key, optimizer, filter_spec, physics, num_iter=1000, freq=100, batch_size=4096):
    """Train ``(net_u, net_v, net_p)`` at the fixed viscosity in ``physics``.

    ``physics`` is ``(sigma, mu, rInlet, dP, L, xStart, xEnd, geometry_type, nu)``.

    Collocation points form a 64 x 64 grid in (x, eta), with x in [0, 1] and
    eta in [-1, 1]. They are mapped into the fluid as ``y = eta * R(x, s)``, where
    R comes from ``R_fn``. The stenosis amplitude ``s`` is 0 when
    ``geometry_type == "pipe"``. Otherwise it is drawn per point from U[0, 0.01].

    The previous version used a fixed y grid on [-0.15, 0.15]. Since R <= rInlet
    (0.05 in the configured physics), about two thirds of those points lay outside
    the pipe and were spent on a region the model is never evaluated in.

    Each outer iteration reshuffles the points and sweeps all full mini-batches; a
    trailing partial batch is dropped.

    Args:
        model_tuple: ``(net_u, net_v, net_p)`` Equinox models.
        key: PRNG key for the stenosis amplitudes and the per-iteration shuffles.
        optimizer: Optax gradient transformation.
        filter_spec: bool pytree for ``eqx.partition``; ``True`` leaves are trained.
        physics: see above; forwarded unchanged to ``loss_fn``.
        num_iter: number of passes over the collocation set.
        freq: record the loss every ``freq`` passes.
        batch_size: points per step; must be in ``[1, 4096]``.

    Returns:
        ``(trained model_tuple, losses)``. ``losses`` holds the last mini-batch
        loss of every ``freq``-th pass.

    Raises:
        ValueError: if ``batch_size`` is non-positive or larger than the
            collocation set. Previously this ended in a ZeroDivisionError or a
            NameError.
    """
    def new_loss(diff_model, static_model, x, y, s):
        model_comb = eqx.combine(diff_model, static_model)
        return loss_fn(model_comb, x, y, s, physics)

    @eqx.filter_jit
    def step(opt_state, model, x, y, s):
        diff_model, static_model = eqx.partition(model, filter_spec)
        loss, grads = eqx.filter_value_and_grad(new_loss)(diff_model, static_model, x, y, s)
        updates, opt_state = optimizer.update(grads, opt_state)
        model = eqx.apply_updates(model, updates)
        return model, opt_state, loss

    sigma, mu, r_inlet = physics[0], physics[1], physics[2]
    geometry_type = physics[7]

    # Fixed (x, eta) grid; eta is the wall-normalised cross-stream coordinate.
    x_vals = jnp.linspace(0.0, 1.0, 64)
    eta_vals = jnp.linspace(-1.0, 1.0, 64)
    X, ETA = jnp.meshgrid(x_vals, eta_vals)
    X = X.flatten()
    ETA = ETA.flatten()

    if geometry_type == "pipe":
        S = jnp.zeros_like(X)
    else:
        S = jr.uniform(key, shape=X.shape, minval=0.0, maxval=0.01)

    # Map eta onto the local fluid cross-section |y| <= R(x, s).
    R_local = vmap(lambda x_, s_: R_fn(x_, s_, mu, sigma, r_inlet))(X, S)
    Y = ETA * R_local

    # Shuffle and batch
    dataset = jnp.stack([X, Y, S], axis=1)
    num_points = dataset.shape[0]
    if not 0 < batch_size <= num_points:
        raise ValueError(
            f"batch_size must be in [1, {num_points}] (got {batch_size}); "
            "otherwise no optimisation step would run"
        )
    num_batches = num_points // batch_size
    perm_key = key

    opt_state = optimizer.init(eqx.filter(model_tuple, eqx.is_inexact_array))
    losses = []

    for step_idx in range(num_iter):
        perm_key, subkey = jr.split(perm_key)
        perm = jr.permutation(subkey, num_points)
        dataset = dataset[perm]

        for batch_idx in range(num_batches):
            batch = dataset[batch_idx * batch_size: (batch_idx + 1) * batch_size]
            x_batch, y_batch, s_batch = batch[:, 0], batch[:, 1], batch[:, 2]

            model_tuple, opt_state, loss_val = step(opt_state, model_tuple, x_batch, y_batch, s_batch)

        if step_idx % freq == 0:
            losses.append(loss_val)

    return model_tuple, losses

In [12]:
def initialize_models(key):
    """Build fresh ``(net_u, net_v, net_p)`` MLPs: width 20, 3 hidden layers each.

    All three networks take the 3 inputs ``[x, y, s]``. Each gets its own sub-key
    split from ``key``.
    """
    k_u, k_v, k_p = jr.split(key, 3)
    hidden_width, hidden_depth = 20, 3
    return tuple(
        MLP(3, 1, hidden_width, hidden_depth, sub_key)
        for sub_key in (k_u, k_v, k_p)
    )

In [24]:
import matplotlib.pyplot as plt
import jax.numpy as jnp
from jax import vmap

def evaluate_all_pipe_profiles(models_by_nu, physics_base, nu_vals,
                               save_path='/content/drive/MyDrive/UProfile.png'):
    """Overlay PINN and analytical straight-pipe profiles u(y) at x = 0.5.

    For each ``nu`` in ``nu_vals`` the model ``models_by_nu[nu]`` is evaluated with
    ``physics_base`` whose last entry is replaced by ``nu``. The stenosis amplitude
    is fixed at ``s = 0``, so the wall sits at ``y = ±rInlet`` (``physics_base[2]``).
    The reference is ``u(y) = dP / (2 nu L) * (rInlet**2 - y**2)``. Previously the
    diameter was hard-coded as 0.1 instead of being read from ``physics_base``.

    The figure is saved to ``save_path`` (300 dpi) and returned. A notebook cell
    ending in this call therefore renders it; before, it was closed without ever
    being shown.

    Args:
        models_by_nu: mapping nu -> ``(net_u, net_v, net_p)``.
        physics_base: ``(sigma, mu, rInlet, dP, L, xStart, xEnd, geometry_type, nu)``.
        nu_vals: viscosities to plot (keys of ``models_by_nu``).
        save_path: output path of the figure.

    Returns:
        The matplotlib ``Figure``.
    """
    r_wall = physics_base[2]  # rInlet = wall radius of the straight pipe (s = 0)
    y = jnp.linspace(-r_wall, r_wall, 200)
    x = jnp.full_like(y, 0.5)
    s = jnp.zeros_like(y)

    fig, ax = plt.subplots(figsize=(7, 5))

    for nu_val in nu_vals:
        net_u, net_v, _ = models_by_nu[nu_val]
        physics = physics_base[:-1] + (nu_val,)
        dP, L = physics[3], physics[4]

        u_pred = vmap(lambda x_, y_, s_: velocity_ansatz(x_, y_, s_, net_u, net_v, physics)[0])(x, y, s)
        u_analytical = dP / (2 * nu_val * L) * (r_wall**2 - y**2)

        ax.plot(y, u_pred, 'r--', linewidth=1.5)       # DNN
        ax.plot(y, u_analytical, 'b-', linewidth=1.5)  # Analytical

        # Label just above the peak of the analytical curve
        peak_idx = jnp.argmax(u_analytical)
        ax.text(float(y[peak_idx]), float(u_analytical[peak_idx] + 0.010), f"$\\nu$ = {nu_val:.5f}",
                ha='center', fontsize=9)

    # Legend entries for the line styles only (one per style, not per curve)
    ax.plot([], [], 'r--', label="DNN")
    ax.plot([], [], 'b-', label="Analytical")

    ax.set_xlabel(r"$y$")
    ax.set_ylabel(r"$u(y)$")
    ax.set_title(r"(a) Cross-section velocity profiles")
    ax.legend()
    ax.set_xlim(-1.1 * r_wall, 1.1 * r_wall)
    ax.set_ylim(bottom=0)
    fig.tight_layout()
    fig.savefig(save_path, dpi=300)
    plt.close(fig)
    return fig


In [16]:
# Train one fixed-viscosity PINN per value in nu_vals and keep the models and loss histories.
key = jr.PRNGKey(0)
# filter_spec only needs the pytree structure, so a throwaway model built from `key` suffices.
filter_spec = jax.tree_util.tree_map(lambda _: True, initialize_models(key))
# The same optimizer object is reused; its state is created inside each train_pinn call.
optimizer = optax.adam(1e-3)

geometry_type = "pipe"  # or "stenosis" (any value other than "pipe" takes the stenosis branch in train_pinn)
nu_vals = [2.1e-4, 3.2e-4, 6.1e-4, 1.9e-3]
models_by_nu = {}
losses_by_nu = {}

for i, nu_fixed in enumerate(nu_vals):
    print(f"=== Training model for ν = {nu_fixed:.5e} ===")
    # Distinct, reproducible key per viscosity.
    subkey = jr.fold_in(key, i)
    model_tuple = initialize_models(subkey)

    # (sigma, mu, rInlet, dP, L, xStart, xEnd, geometry_type, nu)
    physics = (0.1, 0.5, 0.05, 0.1, 1.0, 0.0, 1.0, geometry_type, nu_fixed)

    model_trained, losses = train_pinn(
        model_tuple, subkey, optimizer, filter_spec, physics,
        num_iter=2000, freq=100
    )
    models_by_nu[nu_fixed] = model_trained
    losses_by_nu[nu_fixed] = losses

=== Training model for ν = 2.10000e-04 ===
=== Training model for ν = 3.20000e-04 ===
=== Training model for ν = 6.10000e-04 ===
=== Training model for ν = 1.90000e-03 ===


In [26]:
# `physics` is the tuple left over from the last iteration of the training loop; it serves
# only as the base, and its viscosity slot is replaced for each model.
evaluate_all_pipe_profiles(models_by_nu, physics, nu_vals)