In [1]:
import os
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
from pathlib import Path
root_path = Path.cwd().parent.absolute()
import sys
sys.path.append(str(root_path))
import os

In [2]:
from architectures.architectures import MLP
from flax import nnx
import jax
from jax._src.tree_util import tree_structure
import jax.numpy as jnp
from jaxtyping import Array

In [3]:
device = jax.devices()[0] if jax.devices() else jax.devices('cpu')[0]
print(f"Using device: {device}")

Using device: cuda:0


# Import model

In [4]:
model = MLP(din=1, num_layers=3, width_layers=64, dout=1, activation_fn="SinTu", rngs=nnx.Rngs(0))

In [5]:
y = model(x = jnp.ones((10,1),device= device))

# NODE setup

In [6]:
from diffrax import diffeqsolve, ODETerm,Euler,Heun,DirectAdjoint
from ODE_solvers.solvers import string_2_solver

from jaxtyping import PyTree
from typing import Tuple,Optional

In [7]:
# Neural ODE class
class NeuralODE(nnx.Module):
    def __init__(self, 
                 dynamics_model = nnx.Module,
                 time_dependent: bool = False,
                 solver: str = "euler",
                 dt0=0.1,
                 rtol=1e-4,
                 atol=1e-6):
        self.dynamics = dynamics_model
        self.solver = string_2_solver(solver)
        self.dt0 = dt0
        self.rtol = rtol
        self.atol = atol
        self.time_dependent = time_dependent

    # Define the vector field function
    def vector_field(self,t, y, args):
        data = y
        if self.time_dependent:
            data = jnp.concatenate([t[:,None], y], axis=-1)  # Add time as a feature
        return self.dynamics(data)

    # @nnx.jit
    def __call__(self, y0: Array, t_span: Tuple[float,float], params: Optional[PyTree] = None) -> Array:
        """
        Solve the ODE from t_span[0] to t_span[1] with initial condition y0
        
        Args:
            y0: Initial condition, shape (batch_size, feature_dim) or (feature_dim,)
            t_span: Tuple of (t0, t1) for integration bounds
            
        Returns:
            Final state at time t1
        """
        
        if params is None:
            model = self.dynamics
        else:
            graphdef,_ = nnx.split(self.dynamics)
            model = nnx.merge(graphdef, params)
        # Use defined model for the vector field
        def vector_field(t: float, y: Array, args: Optional[dict] = None):
            data = y
            if self.time_dependent:
                data = jnp.concatenate([t[:,None], y], axis=-1)  # Add time as a feature
            return model(data)
        
        t_list = jnp.arange(t_span[0], t_span[1], self.dt0)

        y = self.solver(vector_field,t_list,y0,history=False)
       
        
        return y.reshape(-1,y0.shape[-1])  # Return final state

In [8]:
key = jax.random.PRNGKey(0)
rngs = nnx.Rngs(key)
model = MLP(din=2, num_layers=3, width_layers=64, dout=2, activation_fn="SinTu", rngs=nnx.Rngs(0))
node = NeuralODE(
    dynamics_model=model,
    time_dependent=False,
    dt0=0.1,
    rtol=1e-4,
    atol=1e-6
)

# Initial condition
y0 = jnp.array([1.0, 0.5])

# Solve from t=0 to t=1
y_final = node(y0, (0.0, 1.0))
print("Final state:", y_final)

Final state: [[1.3219943  0.41655946]]


In [9]:
graphdef,state = nnx.split(node)

In [10]:
a = jax.tree.map(lambda x: x ,state)

In [20]:
from typing import Callable

def minres(A_func: Callable, b: PyTree, tol: float = 3e-4, x0: Optional[PyTree] = None, maxiter: int = 100) -> PyTree:
    """
    Simplified MINRES for your G matrix system with PyTree support.
    """
    @jax.jit
    def dot_tree(x: PyTree, y: PyTree) -> Array:
        return sum(jax.tree.leaves(jax.tree.map(lambda a, b: jnp.sum(a * b), x, y)))

    @jax.jit
    def norm_tree(x: PyTree) -> Array:
        return jnp.sqrt(dot_tree(x, x))

    @jax.jit
    def clone_tree(x: PyTree) -> PyTree:
        return jax.tree.map(lambda a: jnp.array(a), x)
    @jax.jit
    def xpay_tree(x: PyTree, y: PyTree, alpha: float) -> PyTree:
        return jax.tree.map(lambda a, b: a + alpha * b, x, y)
    
    # def step():

    if x0 is None:
        x = jax.tree.map(jnp.zeros_like, b)
    else:
        x = clone_tree(x0)

    Ax = A_func(x)
    r = xpay_tree(b, Ax, -1.0)
    p0 = clone_tree(r)
    s0 = A_func(p0)
    p1 = clone_tree(p0)
    s1 = clone_tree(s0)

    for i in range(maxiter):
        
        p2 = clone_tree(p1)
        p1 = clone_tree(p0)
        s2 = clone_tree(s1)
        s1 = clone_tree(s0)

        alpha = dot_tree(r, s1) / dot_tree(s1, s1)

        x = xpay_tree(x, p1, alpha)
        r = xpay_tree(r, s1, -alpha)

        if norm_tree(r) < tol**2:
            print(f"Converged in {i} iterations")
            info = {"success": True, "iterations": i, "norm_res": norm_tree(r)}
            break

        p0 = clone_tree(s1)
        s0 = A_func(s1)
        beta1 = dot_tree(s0, s1) / dot_tree(s1, s1)
        p0 = xpay_tree(p0, p1, -beta1)
        s0 = xpay_tree(s0, s1, -beta1)

        if i > 1:
            beta2 = dot_tree(s0,s2)/dot_tree(s2,s2)
            p0 = xpay_tree(p0, p2, -beta2)
            s0 = xpay_tree(s0, s2, -beta2)
    if i == maxiter - 1:
        info = {"success": False, "iterations": maxiter, "norm_res": norm_tree(r)}
    print(info)
    return x, info

    

In [21]:
from functools import partial
from jax import jit,grad,vmap
from jax.scipy.sparse.linalg import cg,gmres
class G_matrix:
    '''
    Computation of G matrix
    '''

    def __init__(self, node: nnx.Module ):
        
        '''
        Initialize G matrix computation 

        Args:
            node: Neural ODE model nnx.Module instance     
        '''

        self.node = node
        

    @partial(jit,static_argnums = (0,))
    def mvp(self,z_samples: Array, eta: PyTree, params: Optional[PyTree] = None)-> PyTree:
        '''
        Computation of G eta
        Args:
            z_samples: (Bs,d) Samples from reference density
            eta: PyTree with same GraphDef as node
            parms: PyTree where the G matrix is computed at
        Return:
            G(theta) eta : PyTree
        '''

        if params is None:
            
            _,params = nnx.split(self.node)
        
        def single_sample_contribution(z: Array)-> PyTree:

            # Define the flow map

            def flow_map(p):

                return self.node(z.reshape(1,-1), (0.0, 1.0), params=p)

            # Step 1: Compute \partial_{theta}T @ eta using Jvp

            jvp_result = jax.jvp(flow_map,(params,),(eta,))[1]

            # Step 2: Compute \partial_{\theta}T @ jvp_result

            _,vjp_fn = jax.vjp(flow_map,params)

            result = vjp_fn(jvp_result)[0]

            return result
    
        # Vectorize over all samples

        contributions = vmap(single_sample_contribution)(z_samples)

        return jax.tree.map(lambda x: jnp.mean(x, axis=0), contributions)
    
    # @partial(jit,static_argnums = (0,6))
    def solve_system(self,z_samples: Array, b: PyTree, params: Optional[PyTree] = None, tol: float = 1e-5, maxiter: int = 10, method: str = "cg", x0: Optional[PyTree] = None) -> PyTree:

        '''
        Solve G(theta) x = b using conjugate gradient method

        Args:
            z_samples: (Bs,d) Samples from reference density
            b: PyTree with same GraphDef as node
            parms: PyTree where the G matrix is computed at
            tol: Tolerance for CG solver
            maxiter: Maximum number of iterations for CG solver
            method: Method to use for solving the linear system ("cg" or "gmres")
            x0: Initial guess for the solution

        Returns:
            x: PyTree solution to G(theta)x = b
        '''
        if method not in ["cg","gmres","minres"]:
            raise ValueError(f"Unknown method: {method}")
        if method == "cg":
            solver = cg
        elif method == "gmres":
            solver = gmres
        elif method == "minres":
            solver = minres
        if params is None:
            _,params = nnx.split(self.node)
        # Define the linear operator for G(theta)
        matvec = lambda eta: self.mvp(z_samples, eta, params)
        # Use Jax inbuilts methods cg or gmres. 
        x,info = solver(matvec,b,tol = tol, maxiter = maxiter,x0=x0)
        
        # x,info = minres(matvec, b, tol=tol, maxiter=maxiter,x0 = x0)
        return x


In [22]:
G_mat = G_matrix(node)

In [23]:
# Test

key = jax.random.PRNGKey(0)
rngs = nnx.Rngs(key)

n_samples = 100_000
d = 2
z_samples = jax.random.normal(key,(n_samples,d))


In [50]:
_,params = nnx.split(node)
key,subkey = jax.random.split(key)
eta = jax.tree.map(lambda p:jax.random.normal(subkey, p.shape)*10, params) #   jax.numpy.ones(p.shape)
key,subkey = jax.random.split(key)
epsilon = jax.tree.map(lambda p: jax.random.normal(subkey, p.shape)*0.1, params)
eta_pert = jax.tree.map(lambda e,ep: e+ep, eta, epsilon)

In [51]:
# Compute G_hat@eta
result = G_mat.mvp(z_samples, eta)

In [52]:
# Test solver 
result_solver = G_mat.solve_system(z_samples, result, tol=1e-6, maxiter=50,x0 = eta_pert, method="minres")

{'success': False, 'iterations': 50, 'norm_res': Array(0.00059873, dtype=float32)}


In [53]:
jax.tree.map(lambda x,y: jnp.linalg.norm(x-y)/jnp.linalg.norm(y), result_solver, eta)

State({
  'dynamics': {
    'layers': {
      0: {
        'bias': VariableState( # 1 (4 B)
          type=Param,
          value=Array(0.00914392, dtype=float32)
        ),
        'kernel': VariableState( # 1 (4 B)
          type=Param,
          value=Array(0.00967326, dtype=float32)
        )
      },
      2: {
        'bias': VariableState( # 1 (4 B)
          type=Param,
          value=Array(0.00933145, dtype=float32)
        ),
        'kernel': VariableState( # 1 (4 B)
          type=Param,
          value=Array(0.0096993, dtype=float32)
        )
      },
      4: {
        'bias': VariableState( # 1 (4 B)
          type=Param,
          value=Array(0.00953458, dtype=float32)
        ),
        'kernel': VariableState( # 1 (4 B)
          type=Param,
          value=Array(0.00981776, dtype=float32)
        )
      },
      6: {
        'bias': VariableState( # 1 (4 B)
          type=Param,
          value=Array(0.0084059, dtype=float32)
        ),
        'kernel': VariableS

In [54]:
verify_result = G_mat.mvp(z_samples, result_solver)

In [55]:

jax.tree.leaves(jax.tree.map(lambda x,y: jnp.linalg.norm(x-y)/jnp.linalg.norm(y), result, verify_result))

[Array(0.00116854, dtype=float32),
 Array(0.00106134, dtype=float32),
 Array(0.00177689, dtype=float32),
 Array(0.00133497, dtype=float32),
 Array(0.0026245, dtype=float32),
 Array(0.00197169, dtype=float32),
 Array(0.00475153, dtype=float32),
 Array(0.00286135, dtype=float32)]