In [1]:
import os
import sys
from pathlib import Path

# Disable JAX/XLA GPU memory preallocation. This only takes effect if it is set
# before JAX initializes its backend, which is why it lives in the first cell.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'

# Make the project root importable so the local packages (architectures,
# geometry, functionals, flows) resolve. Assumes the notebook is run from a
# direct subdirectory of the project root — TODO confirm if the layout changes.
root_path = Path.cwd().parent.absolute()
# Guard so that re-running this cell does not keep appending duplicate entries.
if str(root_path) not in sys.path:
    sys.path.append(str(root_path))

In [None]:
from flax import nnx
import jax
from jax._src.tree_util import tree_structure
import jax.numpy as jnp
from jaxtyping import Array

In [5]:
from architectures.node import NeuralODE
from architectures.architectures import MLP
from geometry.G_matrix import G_matrix
from functionals.functions import create_potentials
from functionals.linear_funcitonal_class import LinearPotential
from flows.gradient_flow import run_gradient_flow
from flows.visualization import visualize_gradient_flow_results

In [6]:
# Create the collection of named potentials. It is indexed by name further
# down this notebook (e.g. potentials['strong_double_well']); see
# functionals.functions.create_potentials for the keys actually available.
potentials = create_potentials()

In [7]:
# Gradient flow parameters
h = 1e-2  # Time step size
max_iterations = 100
tolerance = 1e-6
n_samples = 4_000  # Monte Carlo sample size

# Split the seed up front so each consumer gets its own key. A key must not be
# reused after it has been handed to nnx.Rngs, so model initialisation and
# sampling each receive a fresh subkey.
key = jax.random.PRNGKey(0)
key, init_key, subkey = jax.random.split(key, 3)
rngs = nnx.Rngs(init_key)

# Create NODE model
gradient_flow_model = MLP(din=2, num_layers=1, width_layers=32, dout=2,
                          activation_fn="tanh", rngs=rngs)
gradient_flow_node = NeuralODE(
    dynamics_model=gradient_flow_model,
    time_dependent=False,
    dt0=0.1,
)

# Shrink the trainable parameters so the initial dynamics are near zero, which
# keeps the initial flow map close to the identity.
# Only nnx.Param leaves are scaled. Any other state the module may hold (e.g.
# RNG keys/counters) is split off and merged back untouched, because scaling
# it would corrupt it.
scale_factor = 1e-2
graphdef, initial_params, other_state = nnx.split(gradient_flow_node, nnx.Param, ...)
small_params = jax.tree.map(lambda p: p * scale_factor, initial_params)
gradient_flow_node = nnx.merge(graphdef, small_params, other_state)

# Generate reference samples from λ = N(0, I)
z_samples = jax.random.normal(subkey, (n_samples, 2))

G_mat_flow = G_matrix(gradient_flow_node)

# Options: 'quadratic', 'double_well', 'strong_double_well', 'quartic'
potential = potentials['strong_double_well']
solver = 'minres'


In [None]:

# Run the gradient flow from the near-identity NODE. The returned `results`
# feed the visualization cell below.
results = run_gradient_flow(
    gradient_flow_node,
    z_samples,
    G_mat_flow,
    potential,
    h=h,
    tolerance=tolerance,
    max_iterations=max_iterations,
    solver=solver,
)

Starting gradient flow with LinearPotential...
Potential function: double_well_potential_fn
Potential parameters: {'alpha': 5.0}
Selected device: cuda:0
Initial energy: 8.136963
Target: minimize energy functional


2025-08-27 14:01:49.745975: W external/xla/xla/service/gpu/autotuning/dot_search_space.cc:200] All configs were filtered out because none of them sufficiently match the hints. Maybe the hints set does not contain a good representative set of valid configs?Working around this by using the full hints set instead.


In [None]:
# Visualize the `results` returned by run_gradient_flow in the previous cell.
# That cell must have been run first, since `results` is defined there.
visualize_gradient_flow_results(results)