In [52]:
import jax
import jax.numpy as jnp
from typing import Any
from flax import linen as nn
from jax import random, grad, vmap
from pikan.model_utils import gradf

# create 2 output NN
# compute vorticy w from the outputs derivatives

In [32]:
class VelocityNet(nn.Module):
    features: Any

    @nn.compact
    def __call__(self, inputs):
        x = inputs
        for feature in self.features:
            x = nn.Dense(features=feature)(x)  # Fully connected layer
            x = nn.tanh(x)  # Activation function (tanh)
        x = nn.Dense(features=2)(x)  # Final output layer with 2 outputs (u, v)
        return x

features = [32, 32, 32]    
model = VelocityNet(features)

dp = jnp.ones(3)
params = model.init(random.PRNGKey(0), dp)

In [53]:
class NS_torus():
    def __init__(self, model):
        self.model = model
        self.RE = 10
        
        self.r_net_fn = vmap(self.r_net, (None, 0,0,0))

    def neural_net(self, params, x, y, t):
        collocs = jnp.stack([x,y,t])
        outputs = self.model.apply(params, collocs)

        u = outputs[0]
        v = outputs[1]
        return u, v

    def u_net(self, params, x, y, t):
        u, _ = self.neural_net(params, x,y,t)
        return u
    
    def v_net(self, params, x, y, t):
        _, v = self.neural_net(params, x,y,t)
        return v

    def w_net(self, params, x, y, t):
        v_x = grad(self.v_net, argnums=1)(params,x,y,t)
        u_y = grad(self.u_net, argnums=2)(params,x,y,t)
        
        return v_x + u_y
    
    def r_net(self, params, x, y, t):
        # w_t + U * grad w - 1/Re * Laplacian w
        w_t = grad(self.w_net, argnums=3)(params,x,y,t)
        w_x = grad(self.w_net, argnums=1)(params,x,y,t)
        w_y = grad(self.w_net, argnums=2)(params,x,y,t)
        
        u, v = self.neural_net(params, x, y, t)
        
        w_xx = grad(grad(self.w_net, argnums=1), argnums=1)(params,x,y,t)
        w_yy = grad(grad(self.w_net, argnums=2), argnums=2)(params,x,y,t)
        eq1 = w_t + (u*w_x + v*w_y) - 1/self.RE * (w_xx+w_yy)
        
        # u_x + v_y = 0 eq2
        u_x = grad(self.u_net, argnums=1)(params,x,y,t)
        v_y = grad(self.v_net, argnums=2)(params,x,y,t)
        eq2 = u_x + v_y
        
        return eq1, eq2

In [56]:
ns_torus = NS_torus(model)
x = 1.
y = 2.
t = 3.

ns_torus.u_net(params, x, y, t)
ns_torus.v_net(params, x, y, t)
ns_torus.w_net(params, x, y, t)

BS = 32
x_star = jnp.ones((BS)) 
y_star = jnp.ones((BS))
t_star = jnp.ones((BS))
ns_torus.r_net_fn(params, x_star, y_star, t_star)

(Array([-0.20340623, -0.20340623, -0.20340623, -0.20340623, -0.20340623,
        -0.20340623, -0.20340623, -0.20340623, -0.20340623, -0.20340623,
        -0.20340623, -0.20340623, -0.20340623, -0.20340623, -0.20340623,
        -0.20340623, -0.20340623, -0.20340623, -0.20340623, -0.20340623,
        -0.20340623, -0.20340623, -0.20340623, -0.20340623, -0.20340623,
        -0.20340623, -0.20340623, -0.20340623, -0.20340623, -0.20340623,
        -0.20340623, -0.20340623], dtype=float32),
 Array([-0.17909324, -0.17909324, -0.17909324, -0.17909324, -0.17909324,
        -0.17909324, -0.17909324, -0.17909324, -0.17909324, -0.17909324,
        -0.17909324, -0.17909324, -0.17909324, -0.17909324, -0.17909324,
        -0.17909324, -0.17909324, -0.17909324, -0.17909324, -0.17909324,
        -0.17909324, -0.17909324, -0.17909324, -0.17909324, -0.17909324,
        -0.17909324, -0.17909324, -0.17909324, -0.17909324, -0.17909324,
        -0.17909324, -0.17909324], dtype=float32))