In [1]:
import os
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
from pathlib import Path
root_path = Path.cwd().parent.absolute()
import sys
sys.path.append(str(root_path))
import os

In [2]:
from architectures.architectures import MLP

from flax import nnx
import jax
from jax._src.tree_util import tree_structure
import jax.numpy as jnp
from jaxtyping import Array

In [3]:
device = jax.devices()[0] if jax.devices() else jax.devices('cpu')[0]
print(f"Using device: {device}")

Using device: cuda:0


# Import model

In [4]:
model = MLP(din=1, num_layers=3, width_layers=64, dout=1, activation_fn="SinTu", rngs=nnx.Rngs(0))

In [5]:
y = model(x = jnp.ones((10,1),device= device))

# NODE setup

In [6]:
from architectures.node import NeuralODE



In [7]:
key = jax.random.PRNGKey(0)
rngs = nnx.Rngs(key)
model = MLP(din=2, num_layers=3, width_layers=64, dout=2, activation_fn="tanh", rngs=nnx.Rngs(0))
node = NeuralODE(
    dynamics_model=model,
    time_dependent=False,
    dt0=0.1,
    rtol=1e-4,
    atol=1e-6
)

# Initial condition
y0 = jnp.array([1.0, 0.5])

# Solve from t=0 to t=1
y_final = node(y0, (0.0, 1.0))
print("Final state:", y_final)

Final state: [[1.2042602  0.14924549]]


# G Mat

In [8]:
from geometry.G_matrix import G_matrix

In [9]:
G_mat = G_matrix(node)

In [10]:
# Test

key = jax.random.PRNGKey(0)
rngs = nnx.Rngs(key)

n_samples = 100_000
d = 2
z_samples = jax.random.normal(key,(n_samples,d))


In [11]:
_,params = nnx.split(node)
key,subkey = jax.random.split(key)
eta = jax.tree.map(lambda p:jax.random.normal(subkey, p.shape)*10, params) #   jax.numpy.ones(p.shape)
key,subkey = jax.random.split(key)
epsilon = jax.tree.map(lambda p: jax.random.normal(subkey, p.shape)*0.1, params)
eta_pert = jax.tree.map(lambda e,ep: e+ep, eta, epsilon)

In [12]:
# Compute G_hat@eta
result = G_mat.mvp(z_samples, eta)

In [13]:
# Test solver 
result_solver = G_mat.solve_system(z_samples, result, tol=1e-6, maxiter=50,x0 = None, method="minres")

{'success': False, 'iterations': 50, 'norm_res': Array(0.00699871, dtype=float32)}


In [14]:
jax.tree.map(lambda x,y: jnp.linalg.norm(x-y)/jnp.linalg.norm(y), result_solver, eta)

State({
  'dynamics': {
    'layers': {
      0: {
        'bias': VariableState( # 1 (4 B)
          type=Param,
          value=Array(0.97402227, dtype=float32)
        ),
        'kernel': VariableState( # 1 (4 B)
          type=Param,
          value=Array(1.0005058, dtype=float32)
        )
      },
      2: {
        'bias': VariableState( # 1 (4 B)
          type=Param,
          value=Array(0.97868973, dtype=float32)
        ),
        'kernel': VariableState( # 1 (4 B)
          type=Param,
          value=Array(1.0002952, dtype=float32)
        )
      },
      4: {
        'bias': VariableState( # 1 (4 B)
          type=Param,
          value=Array(0.94813395, dtype=float32)
        ),
        'kernel': VariableState( # 1 (4 B)
          type=Param,
          value=Array(0.99911475, dtype=float32)
        )
      },
      6: {
        'bias': VariableState( # 1 (4 B)
          type=Param,
          value=Array(0.19923553, dtype=float32)
        ),
        'kernel': VariableS

In [17]:
verify_result = G_mat.mvp(z_samples, result_solver)

In [18]:

jax.tree.leaves(jax.tree.map(lambda x,y: jnp.linalg.norm(x-y)/jnp.linalg.norm(y), result, verify_result))

[Array(6.983893e-05, dtype=float32),
 Array(0.00012929, dtype=float32),
 Array(7.586988e-05, dtype=float32),
 Array(7.900696e-05, dtype=float32),
 Array(5.2768468e-05, dtype=float32),
 Array(7.3272946e-05, dtype=float32),
 Array(5.998748e-06, dtype=float32),
 Array(5.5182903e-05, dtype=float32)]