In [None]:
from pathlib import Path
from xpinn import XPINN
import optax
from type_util import Array
from jax import hessian, jacobian, jit, vmap
import jax.numpy as np
import numpy as onp

from jax import config

# Enable 64-bit (double-precision) arrays in JAX; without this flag JAX
# computes in float32 by default.
# NOTE(review): presumably enabled for accuracy of the higher-order
# derivatives in the PDE residuals below — confirm.
config.update("jax_enable_x64", True)

In [4]:
import jax.numpy as np
from jax import grad, vmap, jit, hessian
from typing import Callable
from type_util import Params

# Signature alias: (params, dict mapping names to arrays) -> array.
# NOTE(review): the name suggests a loss function over named input batches;
# it is not used in the cells visible here — confirm against callers.
LFunc = Callable[[Params, dict[str, np.ndarray]], np.ndarray]

In [None]:
def navier_stokes_functional(model, nu):
    """
    Build a jitted evaluator for the 2D incompressible Navier-Stokes residuals.

    The network is treated as predicting a stream function ``psi`` (output
    channel 0) and a pressure ``p`` (output channel 1). Velocities are derived
    from the stream function as ``u = -dpsi/dy`` and ``v = dpsi/dx``, which is
    divergence-free by construction, so only the two momentum equations need
    a residual.

    Parameters
    ----------
    model : callable
        ``model(params, xyt) -> outputs`` where ``xyt`` is a batch of points
        with columns ``(x, y, t)`` and ``outputs`` holds ``psi`` in channel 0
        and ``p`` in channel 1. It is evaluated here on batches of one point.
    nu : float
        Kinematic viscosity.

    Returns
    -------
    callable
        Jitted ``function(params, xyt)`` taking an ``(N, 3)`` array of
        ``(x, y, t)`` points (a single ``(3,)`` point is treated as ``N = 1``)
        and returning the tuple ``(u, v, p_x, p_y, f, g)`` of ``(N,)`` arrays,
        where ``f`` and ``g`` are the x- and y-momentum residuals::

            f = u_t + u u_x + v u_y + p_x - nu (u_xx + u_yy)
            g = v_t + u v_x + v v_y + p_y - nu (v_xx + v_yy)
    """

    # Autodiff needs a scalar-valued function of a single point, so evaluate
    # the model on a batch of one and pick the requested output channel.
    def _output(params, point, channel):
        return np.ravel(model(params, point[None, :]))[channel]

    def _psi(params, point):
        return _output(params, point, 0)

    def _pressure(params, point):
        return _output(params, point, 1)

    def _velocity(params, point):
        # Gradient w.r.t. the point (not params): (psi_x, psi_y, psi_t).
        d_psi = grad(_psi, argnums=1)(params, point)
        return np.stack([-d_psi[1], d_psi[0]])

    # Per-point derivative operators, differentiating w.r.t. the point only.
    d_velocity = jacobian(_velocity, argnums=1)   # (2, 3): rows (u, v), cols (x, y, t)
    d2_velocity = hessian(_velocity, argnums=1)   # (2, 3, 3)
    d_pressure = grad(_pressure, argnums=1)       # (3,): (p_x, p_y, p_t)

    def _pointwise(params, point):
        vel = _velocity(params, point)
        J = d_velocity(params, point)
        H = d2_velocity(params, point)
        grad_p = d_pressure(params, point)

        u_vel, v_vel = vel[0], vel[1]
        # Convective derivative (u d/dx + v d/dy) of both velocity components.
        convection = u_vel * J[:, 0] + v_vel * J[:, 1]
        laplacian = H[:, 0, 0] + H[:, 1, 1]
        # Column 2 of J is the time derivative (u_t, v_t).
        residual = J[:, 2] + convection + grad_p[:2] - nu * laplacian

        return u_vel, v_vel, grad_p[0], grad_p[1], residual[0], residual[1]

    # Share params across the batch; map over the leading (point) axis.
    batched = vmap(_pointwise, in_axes=(None, 0))

    def function(params, xyt):
        return batched(params, np.atleast_2d(xyt))

    return jit(function)
