In [21]:
import jax
import jax.numpy as jnp
import equinox as eqx
from jax import grad, vmap, random
import optax
import matplotlib.pyplot as plt


# Parameters
sigma_1 = 0.5
v_max = 2.0
rho_max = 1.0
a = 0.9
b = 0.975
L = 10.0
N = 100
dx = L / (N-1)
x = jnp.linspace(0, L, N)
T = 1.0  # final time
dt = 0.01
time_steps = int(T/dt)

# Initial condition
rho_init = jnp.zeros(N)


In [36]:
import jax
import jax.numpy as jnp
import equinox as eqx

class RhoNN(eqx.Module):
    layers: list
    final_layer: eqx.nn.Linear

    def __init__(self, key):
        keys = jax.random.split(key, 4)
        self.layers = [
            eqx.nn.Linear(2, 64, key=keys[0]),
            eqx.nn.Linear(64, 64, key=keys[1]),
            eqx.nn.Linear(64, 64, key=keys[2]),
        ]
        self.final_layer = eqx.nn.Linear(64, 1, key=keys[3])

    def __call__(self, x, t):
        xt = jnp.concatenate([x, t], axis=-1)
        for layer in self.layers:
            xt = jax.nn.relu(layer(xt))
        return self.final_layer(xt)


In [37]:
def pde_loss(rho_n, x, t, sigma_1, v_max, rho_max):
    grad_rho_x = jax.grad(lambda x, t: rho_n(x, t), argnums=0)(x, t)
    grad_rho_t = jax.grad(lambda x, t: rho_n(x, t), argnums=1)(x, t)
    laplacian_rho = jax.grad(jax.grad(rho_n, argnums=0), argnums=0)(x, t)
    
    f_rho = v_max * (1 - 2 * rho_n(x, t) / rho_max)
    pde_residual = grad_rho_t - sigma_1 * laplacian_rho + grad_rho_x * f_rho
    return jnp.mean(pde_residual ** 2)

def boundary_loss(rho_n, x_in, x_out, t, a, b, rho_max):
    loss_in = jnp.mean((rho_n(x_in, t) - a * (rho_max - rho_n(x_in, t))) ** 2)
    loss_out = jnp.mean((rho_n(x_out, t) - b * rho_n(x_out, t)) ** 2)
    return loss_in + loss_out

def total_loss(rho_n, x, t, x_in, x_out, sigma_1, v_max, rho_max, a, b):
    return pde_loss(rho_n, x, t, sigma_1, v_max, rho_max) + boundary_loss(rho_n, x_in, x_out, t, a, b, rho_max)


In [38]:
import optax

key = jax.random.PRNGKey(0)
rho_0 = RhoNN(key)
optimizer = optax.adam(1e-3)
opt_state = optimizer.init(rho_0)

def step(rho_n, opt_state, x, t, x_in, x_out, sigma_1, v_max, rho_max, a, b):
    loss, grads = jax.value_and_grad(total_loss)(rho_n, x, t, x_in, x_out, sigma_1, v_max, rho_max, a, b)
    updates, opt_state = optimizer.update(grads, opt_state, rho_n)
    rho_n = eqx.apply_updates(rho_n, updates)
    return rho_n, opt_state, loss

x = jnp.linspace(0, L, N).reshape(-1, 1)
t = jnp.linspace(0, T, time_steps).reshape(-1, 1)
x_in = jnp.array([[0]])
x_out = jnp.array([[L]])

for epoch in range(1000):
    rho_0, opt_state, loss_value = step(rho_0, opt_state, x, t, x_in, x_out, sigma_1, v_max, rho_max, a, b)
    if epoch % 100 == 0:
        print(f'Epoch {epoch}, Loss: {loss_value}')


TypeError: Cannot concatenate arrays with shapes that differ in dimensions other than the one being concatenated: concatenating along dimension 1 for shapes (100, 1), (10000, 1).

: 

In [14]:
sigma = 0.5**2


L = 3
T = 20

a = 0.9
b = 0.975


v_max = 2
p_max = 1 # rho_max cannot be learned Susana's paper

In [18]:
import jax
import jax.numpy as jnp
import equinox as eqx

# MLP class definition provided by the user
class MLP(eqx.Module):
    layers: list

    def __init__(self, key, input_dim, hidden_dims, n_layers, output_dim):
        keys = jax.random.split(key, n_layers + 1)
        dims = [input_dim] + [hidden_dims] * n_layers + [output_dim]
        self.layers = [eqx.nn.Linear(dims[i], dims[i + 1], key=keys[i]) for i in range(len(dims) - 1)]

    def __call__(self, x):
        for layer in self.layers[:-1]:
            x = jax.nn.sigmoid(layer(x))
        x = self.layers[-1](x)
        return x

# Define the PDE loss function
def pde_loss(model, input, sigma, v_max, rho_max):
    rho, grad = jax.vmap(jax.value_and_grad(model))(input)
    hess = jax.vmap(jax.hessian(model))(input)

    dp_dx = grad[:, 0]
    dp_dt = grad[:, 1]
    d2p_dx2 = hess[:, 0, 0]

    f_rho = v_max * (1 - rho / rho_max)
    df_rho = -v_max / rho_max

    return jnp.mean((dp_dt - sigma * d2p_dx2 + dp_dx * (f_rho + rho * df_rho))**2)

# Define the boundary condition loss function at x = 0
def bc_loss_left(model, t_vals, a, sigma, v_max, rho_max):
    input = jnp.stack([jnp.zeros_like(t_vals), t_vals], axis=1)
    rho, grad = jax.vmap(jax.value_and_grad(model))(input)

    dp_dx = grad[:, 0]
    f_rho = v_max * (1 - rho / rho_max)

    return jnp.mean((dp_dx * sigma - f_rho * rho - a * (rho_max - rho))**2)

# Define the boundary condition loss function at x = L
def bc_loss_right(model, t_vals, L, b, sigma, v_max, rho_max):
    input = jnp.stack([jnp.ones_like(t_vals) * L, t_vals], axis=1)
    rho, grad = jax.vmap(jax.value_and_grad(model))(input)

    dp_dx = grad[:, 0]
    f_rho = v_max * (1 - rho / rho_max)

    return jnp.mean((dp_dx * sigma - f_rho * rho - b * rho)**2)

# Example usage
key = jax.random.PRNGKey(0)
model = MLP(key, input_dim=2, hidden_dims=64, n_layers=3, output_dim='scalar')

# Define the inputs
x = jnp.linspace(0, L, 100)
t = jnp.linspace(0, T, 100)
xv, tv = jnp.meshgrid(x, t)
input = jnp.stack([xv.flatten(), tv.flatten()], axis=1)

# Calculate losses
pde_loss_val = pde_loss(model, input, sigma, v_max, rho_max)
bc_loss_left_val = bc_loss_left(model, t, a, sigma, v_max, rho_max)
bc_loss_right_val = bc_loss_right(model, t, L, b, sigma, v_max, rho_max)

print(f"PDE Loss: {pde_loss_val}")
print(f"Boundary Condition Loss (Left): {bc_loss_left_val}")
print(f"Boundary Condition Loss (Right): {bc_loss_right_val}")


PDE Loss: 6.453939249695395e-08
Boundary Condition Loss (Left): 1.0498206615447998
Boundary Condition Loss (Right): 0.17961753904819489


In [22]:
import optax

# Total loss function
def total_loss(model, x_vals, t_vals, sigma, v_max, rho_max, a, b, L):
    input = jnp.stack([x_vals, t_vals], axis=1)
    pde_loss_val = pde_loss(model, input, sigma, v_max, rho_max)
    bc_loss_left_val = bc_loss_left(model, t_vals, a, sigma, v_max, rho_max)
    bc_loss_right_val = bc_loss_right(model, t_vals, L, b, sigma, v_max, rho_max)
    return pde_loss_val + bc_loss_left_val + bc_loss_right_val

# Training procedure
def train(model, x_vals, t_vals, sigma, v_max, rho_max, a, b, L, epochs, learning_rate):
    optimizer = optax.adam(learning_rate)
    opt_state = optimizer.init(model)

    @jax.jit
    def step(model, opt_state, x_vals, t_vals):
        loss, grads = jax.value_and_grad(total_loss)(model, x_vals, t_vals, sigma, v_max, rho_max, a, b, L)
        updates, opt_state = optimizer.update(grads, opt_state)
        model = eqx.apply_updates(model, updates)
        return model, opt_state, loss

    for epoch in range(epochs):
        model, opt_state, loss = step(model, opt_state, x_vals, t_vals)
        if epoch % 100 == 0:
            print(f"Epoch {epoch}, Loss: {loss}")

    return model

# Example usage
key = jax.random.PRNGKey(0)
model = MLP(key, input_dim=2, hidden_dims=64, n_layers=3, output_dim='scalar')

# Define the inputs
x = jnp.linspace(0, L, 100)
t = jnp.linspace(0, T, 100)
xv, tv = jnp.meshgrid(x, t)
x_vals = xv.flatten()
t_vals = tv.flatten()

# Train the model
trained_model = train(model, x_vals, t_vals, sigma, v_max, rho_max, a, b, L, epochs=1000, learning_rate=1e-3)

# Print final loss
final_loss = total_loss(trained_model, x_vals, t_vals, sigma, v_max, rho_max, a, b, L)
print(f"Final Loss: {final_loss}")


Epoch 0, Loss: 0.7285966873168945
Epoch 100, Loss: 0.15413591265678406
Epoch 200, Loss: 0.15089228749275208
Epoch 300, Loss: 0.14911772310733795


KeyboardInterrupt: 

In [23]:
# Calculate losses
pde_loss_val = pde_loss(model, input, sigma, v_max, rho_max)
bc_loss_left_val = bc_loss_left(model, t, a, sigma, v_max, rho_max)
bc_loss_right_val = bc_loss_right(model, t, L, b, sigma, v_max, rho_max)

print(f"PDE Loss: {pde_loss_val}")
print(f"Boundary Condition Loss (Left): {bc_loss_left_val}")
print(f"Boundary Condition Loss (Right): {bc_loss_right_val}")


PDE Loss: 0.0009129050886258483
Boundary Condition Loss (Left): 0.5765161514282227
Boundary Condition Loss (Right): 0.15116766095161438


: 