```text
██╗   ██╗██████╗ ███╗   ███╗██╗
██║   ██║██╔══██╗████╗ ████║██║
██║   ██║██████╔╝██╔████╔██║██║
╚██╗ ██╔╝██╔═══╝ ██║╚██╔╝██║██║
 ╚████╔╝ ██║     ██║ ╚═╝ ██║███████╗
  ╚═══╝  ╚═╝     ╚═╝     ╚═╝╚══════╝
```
Fourier-Hermite · Vlasov-Poisson · JAX
JAX Fourier–Hermite Vlasov–Poisson solver with a learned interface closure.
vpml is a 1D1V collisionless-plasma solver that discretises the Vlasov–Poisson
system with Fourier modes in space and orthonormal Hermite functions in
velocity, and advances it in time with an IMEX CNAB2 (Crank–Nicolson /
second-order Adams–Bashforth) scheme implemented in JAX. Classical closures
for the Hermite truncation boundary (hypercollisions, Hou–Li filtering, and a
nonlocal closure) are included so that the results of Palisso et al.
(arXiv:2412.07073) can be reproduced end to end.
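As a rough illustration of the IMEX CNAB2 idea, the sketch below integrates a scalar model problem du/dt = L·u + N(u): the linear term is treated implicitly with Crank–Nicolson and the nonlinear term explicitly with second-order Adams–Bashforth. It is a minimal sketch, not vpml's solver. The logistic test problem, the Heun-type start-up step and the function names are assumptions, chosen so that second-order convergence can be checked against an exact solution. In a Fourier–Hermite setting the implicit part would typically hold the stiff linear operator (streaming plus closure) and the explicit part the nonlinear field coupling, but how vpml itself splits the terms is not shown here.

```python
import math


def cnab2(L, N, u0, dt, n_steps):
    """Integrate du/dt = L*u + N(u) with IMEX CNAB2 (scalar sketch, n_steps >= 1).

    L (linear, possibly stiff) is treated implicitly with Crank-Nicolson;
    N (nonlinear) is treated explicitly with 2nd-order Adams-Bashforth.
    """
    cn_plus = 1.0 + 0.5 * dt * L   # explicit half of the Crank-Nicolson average
    cn_minus = 1.0 - 0.5 * dt * L  # implicit half, "inverted" by division below

    # Start-up: AB2 needs N at two time levels, so take one IMEX Heun-type step
    # first (Crank-Nicolson for L, trapezoidal predictor-corrector for N).
    n0 = N(u0)
    u_pred = (cn_plus * u0 + dt * n0) / cn_minus
    u = (cn_plus * u0 + 0.5 * dt * (n0 + N(u_pred))) / cn_minus

    n_prev = n0
    for _ in range(n_steps - 1):
        n_curr = N(u)
        # AB2 extrapolation of N, Crank-Nicolson for L.
        u = (cn_plus * u + dt * (1.5 * n_curr - 0.5 * n_prev)) / cn_minus
        n_prev = n_curr
    return u


def logistic_exact(t, a, u0):
    """Exact solution of du/dt = a*u - u**2."""
    growth = math.exp(a * t)
    return a * u0 * growth / (a + u0 * (growth - 1.0))


a, u0, T = -2.0, 1.0, 1.0
errors = []
for n_steps in (200, 400):
    u_T = cnab2(a, lambda u: -u * u, u0, T / n_steps, n_steps)
    errors.append(abs(u_T - logistic_exact(T, a, u0)))

# For a second-order scheme, halving dt should cut the error by roughly 2**2 = 4.
print(f"error ratio: {errors[0] / errors[1]:.2f}")
```

Because the implicit part is linear, the "solve" reduces to a division here; with a matrix-valued linear operator it becomes a linear solve with (I - dt/2 · L) at each step.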
At a glance
- Python ≥ 3.10 · JAX / `jax.numpy` · `float64` throughout
- CPU by default (including on macOS); CUDA is opt-in
- Three sibling packages: `vpml/` (library), `benchmarks/` (paper figures), `model/` (learned closure)
- CLI entry points: `fh-nonlinear-sim`, `fh-benchmarks-2412-07073`, `fh-ml-tail-closure-train`, `fh-learned-closure-eval`
Quick start:
```bash
python -m venv venv && source venv/bin/activate
pip install -e .
# Regenerate the linear-Landau benchmark (classical truncation closure)
python -m benchmarks.fh_benchmarks_2412_07073_jax linear_landau --outdir out_bench
```
Outputs land in `out_bench/` as `linear_landau_*.png`.
Dependencies: `jax`, `jaxlib`, `numpy`, `matplotlib`, `scipy`.
For better eigenvalue / root-finding accuracy, enable 64-bit JAX:
```bash
export JAX_ENABLE_X64=True
```
vpml bootstraps JAX before `jax` is imported and prints the active backend when the
main benchmark or model scripts start.
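If you work with JAX directly in your own Python session, the same 64-bit switch can be flipped in-process. This is standard JAX configuration (`jax.config.update`, `jax.default_backend`), not a vpml API, and the velocity grid below is only an illustrative array. The update has to happen before any JAX arrays are created.

```python
import jax
import jax.numpy as jnp

# Flip JAX to 64-bit before any arrays are created; this is the in-process
# equivalent of `export JAX_ENABLE_X64=True`.
jax.config.update("jax_enable_x64", True)

# Illustrative velocity grid; with x64 enabled the default float dtype is float64.
grid = jnp.linspace(-12.0, 12.0, 9)
print(jax.default_backend(), grid.dtype)
```

`jax.default_backend()` reports the platform JAX actually selected (for example `cpu` or `gpu`), which is a quick way to cross-check which backend is active.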
- On Linux, `VPML_JAX_BACKEND=auto` leaves backend selection to JAX.
- On macOS, `vpml` defaults to CPU rather than `jax-metal`, because this repo relies heavily on `float64` and complex dtypes.
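The selection rule above can be summarised as a small pure-Python function. This is a hypothetical sketch, not vpml's bootstrap code: the helper names, the mapping of `gpu` to JAX's `cuda` platform, and the use of the standard `JAX_PLATFORMS` environment variable are assumptions. The point it illustrates is ordering: the platform has to be chosen before `jax` is first imported.

```python
import os
import platform


def resolve_jax_platforms(env=None, system=None):
    """Hypothetical VPML_JAX_BACKEND -> JAX_PLATFORMS mapping.

    Returns the value to export as JAX_PLATFORMS, or None to let JAX decide.
    """
    env = os.environ if env is None else env
    system = platform.system() if system is None else system
    choice = env.get("VPML_JAX_BACKEND", "auto").strip().lower()
    if choice == "cpu":
        return "cpu"
    if choice == "gpu":
        return "cuda"  # assumption: a CUDA-enabled jaxlib is installed
    if choice == "auto":
        # macOS ("Darwin"): stay on CPU for float64/complex support;
        # elsewhere defer to JAX's own backend selection.
        return "cpu" if system == "Darwin" else None
    raise ValueError(f"unknown VPML_JAX_BACKEND={choice!r}")


def bootstrap_jax_backend():
    """Hypothetical helper: set JAX_PLATFORMS before the first `import jax`."""
    platforms = resolve_jax_platforms()
    if platforms is not None:
        # setdefault: an explicit JAX_PLATFORMS from the user still wins.
        os.environ.setdefault("JAX_PLATFORMS", platforms)
    return platforms
```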
Overrides:
```bash
export VPML_JAX_BACKEND=cpu
export VPML_JAX_BACKEND=gpu
```
If you actually want CUDA, install a CUDA-enabled JAX build:
```bash
pip install -U "jax[cuda13]"
```
Nonlinear physical-grid simulations:
```bash
python -m benchmarks.fh_nonlinear_sim_jax two_stream --outdir out_nl
python -m benchmarks.fh_nonlinear_sim_jax bump_on_tail --system AC --outdir out_nl --vmin -12 --vmax 12
```
Paper figures (Palisso et al., arXiv:2412.07073), individually or all at once:
```bash
python -m benchmarks.fh_benchmarks_2412_07073_jax fig2 --outdir out_bench
python -m benchmarks.fh_benchmarks_2412_07073_jax fig3 --outdir out_bench
python -m benchmarks.fh_benchmarks_2412_07073_jax fig4 --outdir out_bench --Nv 20
python -m benchmarks.fh_benchmarks_2412_07073_jax linear_landau --method truncation --outdir out_bench
./benchmarks/run_all_benchmarks.sh out_bench
```
With a trained learned-closure checkpoint:
```bash
python -m benchmarks.fh_benchmarks_2412_07073_jax linear_landau \
  --method learned --outdir out_bench \
  --learned-checkpoint out_model/interface_closure.npz
python -m benchmarks.fh_benchmarks_2412_07073_jax fig10_learned_comparison \
  --outdir out_bench --learned-checkpoint out_model/interface_closure.npz
LEARNED_CHECKPOINT=out_model/interface_closure.npz ./benchmarks/run_all_benchmarks.sh out_bench
```
The learned closure is intentionally not supported in the fig3
response-function or fig4 eigenvalue benchmarks: it is state-dependent, so it
cannot be written as a fixed modified-Hermite matrix whose response function or
eigenvalues could be computed.
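To make the distinction concrete: a classical closure such as hypercollisions only adds a fixed term to the truncated Hermite operator, so each Fourier mode k is governed by a constant N_v × N_v matrix whose eigenvalues can be computed once. The NumPy sketch below builds such a matrix under assumed conventions (unit thermal velocity; orthonormal Hermite functions ψ_n with v ψ_n = √(n/2) ψ_{n-1} + √((n+1)/2) ψ_{n+1}; a hypercollision rate ν (n/(N_v − 1))^α; electric-field coupling omitted). It is not vpml's operator, and the function names are hypothetical.

```python
import numpy as np


def streaming_matrix(nv):
    """Symmetric tridiagonal matrix <psi_n, v psi_m> in an orthonormal Hermite basis.

    Assumes psi_n(v) = H_n(v) exp(-v**2 / 2) / sqrt(2**n n! sqrt(pi)), for which
    v psi_n = sqrt(n/2) psi_{n-1} + sqrt((n+1)/2) psi_{n+1}.
    """
    off = np.sqrt(np.arange(1, nv) / 2.0)
    return np.diag(off, 1) + np.diag(off, -1)


def linear_operator(k, nv, nu=0.0, alpha=4):
    """dC/dt = A C for one Fourier mode k: free streaming plus an assumed
    hypercollision damping nu * (n / (nv - 1))**alpha on each Hermite mode n."""
    n = np.arange(nv)
    damping = nu * (n / (nv - 1)) ** alpha
    return -1j * k * streaming_matrix(nv) - np.diag(damping)


k, nv = 0.5, 16
undamped = np.linalg.eigvals(linear_operator(k, nv))
damped = np.linalg.eigvals(linear_operator(k, nv, nu=1.0))

# Bare truncation: eigenvalues lie on the imaginary axis (no decay).
print(np.max(np.abs(undamped.real)))
# With hypercollisions every eigenvalue acquires a negative real part.
print(np.max(damped.real))
```

Since the streaming matrix is real symmetric, the bare truncation can only rotate phases, which is the source of recurrence; the fixed damping term is what makes the truncated operator dissipative.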
Train the learned interface closure:
```bash
python -m model.train.train \
  --checkpoint out_model/interface_closure.npz \
  --dataset-cache out_model/interface_closure_dataset.npz
```
Writes:
- `out_model/interface_closure.npz`
- `out_model/interface_closure.metrics.npz`
- `out_model/interface_closure_dataset.npz`
- `out_model/interface_closure.loss.png` (if `--loss-plot` is passed)
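Assuming these `.npz` outputs are standard NumPy archives, as the extension suggests, their contents can be listed without vpml. The array names inside are not documented here, so this sketch just enumerates whatever it finds; `describe_npz` is a hypothetical helper, and the path is the default from the training command.

```python
from pathlib import Path

import numpy as np


def describe_npz(path):
    """Return {array name: (shape, dtype)} for every array in an .npz archive."""
    summary = {}
    # allow_pickle=False refuses pickled object arrays for safety.
    with np.load(path, allow_pickle=False) as archive:
        for name in archive.files:
            array = archive[name]  # each access reads/decompresses the member
            summary[name] = (array.shape, array.dtype)
    return summary


checkpoint = Path("out_model/interface_closure.npz")
if checkpoint.exists():
    for name, (shape, dtype) in describe_npz(checkpoint).items():
        print(f"{name}: shape={shape}, dtype={dtype}")
```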
The main offline training lane is `q_only`; the pure-online lane is kept separate as `online_rollout`.
Evaluate a trained checkpoint:
```bash
python -m model.eval \
  --checkpoint out_model/interface_closure.npz \
  --outdir out_model/eval
```
Writes `summary.json`, per-case `*.npz` rollouts, and `*_summary.png` plots
under `heldout_landau/` and `benchmark_rollouts/`.
`N_v` sweep wrappers (under `model/train/`):

| Wrapper | Training setup (swept over `N_v`) |
|---|---|
| `run_nv_sweep_single_qloss.sh` | Offline target-specific `q_only` |
| `run_nv_sweep_single_qloss_fixed_ratio.sh` | Offline fixed-ratio `q_only` |
| `run_nv_sweep_online_rollout.sh` | Pure `online_rollout` |
| `run_nv_sweep_higher_order_hermite_fixed_ratio.sh` | Higher-order-Hermite teacher, fixed ratio |
Example:
```bash
./model/train/run_nv_sweep_single_qloss.sh out_bench/nv_sweep_single_qloss
```
Each wrapper trains one checkpoint per deployment `N_v` and then calls
`python -m model.eval_nv_sweep ...`, emitting:
- `summary.json`
- `nv_sweep_metric1.png`, `nv_sweep_metric2.png`
- `fig10_learned_vs_nonlocal_nv_sweep_phase_space.png`
- `cases/*.npz`
For the offline wrappers, per-`N_v` dataset caches live under
`models/nv*/interface_closure_dataset.npz`. The pure-online wrapper writes no
dataset cache.
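A small sketch for locating those caches after a sweep. It assumes the `models/` path is relative to the wrapper's output directory (for example `out_bench/nv_sweep_single_qloss`); that layout and the helper name `find_dataset_caches` are assumptions, not documented behaviour.

```python
from pathlib import Path


def find_dataset_caches(sweep_dir):
    """List per-N_v dataset caches under <sweep_dir>/models/nv*/ (assumed layout).

    Results are sorted lexicographically, so e.g. nv16 sorts before nv8.
    """
    root = Path(sweep_dir)
    return sorted(root.glob("models/nv*/interface_closure_dataset.npz"))


for cache in find_dataset_caches("out_bench/nv_sweep_single_qloss"):
    print(cache.parent.name, cache.stat().st_size, "bytes")
```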
Repo map & design boundary
- `vpml/core.py`: Fourier–Hermite operators, closures, implicit/CNAB2 solvers, learned-closure runtime
- `vpml/linear_landau.py`: shared linear-Landau rollout helpers and dispersion / root-finding utilities
- `vpml/nonlinear_landau.py`: shared nonlinear-Landau rollout runtime for benchmarks and learned-model eval
- `vpml/physical_grid.py`: physical-grid semi-Lagrangian teacher solver and projection helpers
- `vpml/metrics/`: reusable rollout metrics
- `vpml/visualization/`: reusable plotting helpers
- `benchmarks/fh_benchmarks_2412_07073_jax.py`: paper benchmark regeneration for Palisso et al. (arXiv:2412.07073)
- `benchmarks/run_all_benchmarks.sh`: full benchmark shell entry point
- `benchmarks/run_linear_landau_suite.sh`: linear-Landau benchmark shell entry point
- `benchmarks/fh_nonlinear_sim_jax.py`: standalone nonlinear physical-grid simulations
- `model/model.py`: thin learned-model surface built on top of `vpml`
- `model/train/train.py`: learned interface-closure training entry point
- `model/train/data.py`: dataset / cache / reference-building surface for learned-closure workflows
- `model/eval.py`: post-train learned-model evaluation
- `model/eval_nv_sweep.py`: learned-model nonlinear `N_v` sweep evaluation
- `model/train/run_nv_sweep_single_qloss.sh`: per-`N_v` offline `q_only` sweep wrapper
- `model/train/run_nv_sweep_single_qloss_fixed_ratio.sh`: per-`N_v` fixed-ratio offline `q_only` sweep wrapper
- `model/train/run_nv_sweep_online_rollout.sh`: per-`N_v` pure `online_rollout` sweep wrapper
  - Default recipe: denser nonlinear amplitude coverage, `TEACHER_NX=256`, `TEACHER_DT=0.005`, and `TRAIN_ONLINE_V_PROBES=256` for the top-end `N_v` comparison lane
  - Runtime note: the wrapper prebuilds a shared `online_reference_dataset.npz`, reuses it across `N_v`, and enables bounded per-`N_v` parallel training on larger CPU boxes
- `model/train/run_nv_sweep_higher_order_hermite_fixed_ratio.sh`: per-`N_v` higher-order-Hermite teacher sweep wrapper
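The Fourier–Hermite operators in `vpml/core.py` and the projection helpers in `vpml/physical_grid.py` both revolve around expanding a velocity profile in orthonormal Hermite functions. The sketch below shows one way to do that projection on a uniform velocity grid, using the stable three-term recurrence for ψ_n and trapezoidal quadrature. The normalisation (unit thermal velocity, no √2 rescaling of v), the grid bounds (borrowed from the `--vmin -12 --vmax 12` example) and the function names are assumptions, not vpml's projection helpers.

```python
import numpy as np


def hermite_functions(nv, v):
    """Orthonormal Hermite functions psi_0..psi_{nv-1} evaluated on the grid v.

    psi_n(v) = H_n(v) exp(-v**2/2) / sqrt(2**n n! sqrt(pi)), computed with
    psi_{n+1} = sqrt(2/(n+1)) v psi_n - sqrt(n/(n+1)) psi_{n-1}, which avoids
    overflow from evaluating H_n and n! separately.
    """
    psi = np.zeros((nv, v.size))
    psi[0] = np.pi ** -0.25 * np.exp(-0.5 * v**2)
    if nv > 1:
        psi[1] = np.sqrt(2.0) * v * psi[0]
    for n in range(1, nv - 1):
        psi[n + 1] = np.sqrt(2.0 / (n + 1)) * v * psi[n] - np.sqrt(n / (n + 1)) * psi[n - 1]
    return psi


def project(f_values, v, nv):
    """Hermite coefficients C_n = integral of f(v) psi_n(v) dv (uniform grid, trapezoid rule)."""
    weights = np.full(v.size, v[1] - v[0])
    weights[[0, -1]] *= 0.5  # trapezoid end-point weights
    return hermite_functions(nv, v) @ (f_values * weights)


v = np.linspace(-12.0, 12.0, 2001)
maxwellian = np.exp(-0.5 * v**2) / np.sqrt(2.0 * np.pi)
coeffs = project(maxwellian, v, nv=8)
print(np.round(coeffs, 6))

# Reconstruction is the transpose operation: f(v) ~ sum_n C_n psi_n(v).
f_approx = coeffs @ hermite_functions(8, v)
```

With this normalisation a unit-temperature Maxwellian lands entirely in C_0 = π^(-1/4)/√2, so any non-Maxwellian feature (a beam or a bump on the tail) shows up directly in the higher coefficients.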