In [3]:
import os
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
from pathlib import Path
root_path = Path.cwd().parent.absolute()
import sys
sys.path.append(str(root_path))
import os

In [4]:
from architectures.architectures import MLP
from flax import nnx
import jax
import jax.numpy as jnp
from jaxtyping import Array

In [5]:
device = jax.devices()[0] if jax.devices() else jax.devices('cpu')[0]
print(f"Using device: {device}")

Using device: cuda:0


# Import model

In [6]:
model = MLP(din=1, num_layers=3, width_layers=64, dout=1, activation_fn="SinTu", rngs=nnx.Rngs(0))

In [7]:
y = model(x = jnp.ones((10,1),device= device))

# NODE setup

In [8]:
from diffrax import diffeqsolve, ODETerm,Euler,Heun

In [10]:
# Neural ODE class
class NeuralODE(nnx.Module):
    def __init__(self, 
                 dynamics_model = nnx.Module,
                 time_dependent: bool = False,
                 solver=Euler(),
                 dt0=0.1,
                 rtol=1e-4,
                 atol=1e-6):
        self.dynamics = dynamics_model
        self.solver = solver
        self.dt0 = dt0
        self.rtol = rtol
        self.atol = atol
        self.time_dependent = time_dependent

    # Define the vector field function
    def vector_field(self,t, y, args):
        data = y
        if self.time_dependent:
            data = jnp.concatenate([t[:,None], y], axis=-1)  # Add time as a feature
        return self.dynamics(data)

    @nnx.jit
    def __call__(self, y0: Array, t_span: tuple) -> Array:
        """
        Solve the ODE from t_span[0] to t_span[1] with initial condition y0
        
        Args:
            y0: Initial condition, shape (batch_size, feature_dim) or (feature_dim,)
            t_span: Tuple of (t0, t1) for integration bounds
            
        Returns:
            Final state at time t1
        """
        
        
        
        # Create the ODE term - this is what diffrax expects
        term = ODETerm(self.vector_field)
    
        
        # Solve the ODE
        solution = diffeqsolve(
            terms=term,
            solver=self.solver,
            t0=t_span[0],
            t1=t_span[1], 
            dt0=self.dt0,
            y0=y0,
        )
        
        return solution.ys[-1]  # Return final state

In [12]:
key = jax.random.PRNGKey(42)
rngs = nnx.Rngs(key)
model = MLP(din=2, num_layers=3, width_layers=64, dout=2, activation_fn="SinTu", rngs=nnx.Rngs(0))
node = NeuralODE(
    dynamics_model=model,
    time_dependent=False,
    solver=Euler(),
    dt0=0.1,
    rtol=1e-4,
    atol=1e-6
)

# Initial condition
y0 = jnp.array([1.0, 0.5])

# Solve from t=0 to t=1
y_final = node(y0, (0.0, 1.0))
print("Final state:", y_final)

Final state: [1.3593439 0.4089642]
