DynaMix: True Zero-Shot Inference of Dynamical Systems Preserving Long-Term Statistics (NeurIPS 2025)
This repository provides a Python implementation of DynaMix, a foundation model architecture for zero-shot inference of dynamical systems that preserves long-term statistics. Note that a Julia implementation is also available at https://github.com/DurstewitzLab/DynaMix-julia (it was used to train the pretrained models).
Dynamical systems are ubiquitous in nature and society, from weather patterns and climate to financial markets and biological systems. While traditional forecasting methods require extensive training data for each specific system, DynaMix takes a fundamentally different approach by learning universal dynamical representations that generalize across systems.
DynaMix serves as a foundation model for dynamical systems. By training on a diverse collection of dynamical systems, DynaMix learns to extract and represent the underlying patterns and principles that govern temporal evolution across different domains. Unlike traditional methods that require retraining for each new system, DynaMix achieves zero-shot inference of a new system by only providing a short context window of observations.
- Accurate Zero-Shot DSR: DynaMix generalizes across diverse dynamical systems without fine-tuning, accurately capturing attractor geometry and long-term statistics.
- Multivariate Dynamics Modeling: The multivariate architecture captures dependencies across system dimensions and adapts flexibly to different dimensionalities and context lengths via embeddings.
- Efficient and Lightweight: Delivers high performance with a compact design, enabling orders-of-magnitude faster inference than traditional foundation models.
- Interpretable Dynamics Composition: Provides insights into the dynamical composition of reconstructed systems, revealing similarities across different dynamical systems.
- General Time Series Forecasting: Extends beyond DSR to general time series forecasting using adaptable embedding techniques.
Try DynaMix instantly through our interactive Huggingface Space at https://huggingface.co/spaces/DurstewitzLab/DynaMix. This demo lets you test the model's capabilities without installing any code or dependencies. Simply upload your time series data in the required format and generate long-term forecasts with zero-shot inference.
The project is implemented in Python using PyTorch.
Clone the repository and install the required packages:
```bash
cd zero-shot-DSR-python
pip install -r requirements.txt
```

DynaMix is based on a sparse mixture of experts (MoE) architecture operating in latent space:
- Expert Networks: Each expert is a specialized dynamical model (AL-RNN), given through
  $$z_t = A z_{t-1} + W \phi^*(z_{t-1}) + h,$$
  where $A$, $W$ and $h$ are learnable parameters and $\phi^*$ applies the ReLU nonlinearity only to a subset of the latent units, leaving the remaining units linear
- Gating Network: Selects experts based on the provided context and current latent representation of the dynamics
By aggregating the expert predictions according to the gating weights, the next latent state is obtained as a convex combination, $z_t = \sum_{k=1}^{K} w_{k,t} \, f_k(z_{t-1})$, where $w_{k,t}$ are the gating weights and $f_k$ denotes the dynamics of expert $k$.
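As an illustration of this update, here is a minimal PyTorch sketch of a gated mixture step. It is a simplification, not the repository's implementation: plain linear maps stand in for the AL-RNN experts, and the gate here conditions only on the current latent state (in DynaMix it is additionally conditioned on the context).

```python
import torch
import torch.nn as nn

class MixtureStep(nn.Module):
    """Illustrative MoE latent update: z_t = sum_k w_k * f_k(z_{t-1})."""

    def __init__(self, latent_dim: int, num_experts: int):
        super().__init__()
        # Each expert proposes a next latent state (linear stand-ins for AL-RNNs).
        self.experts = nn.ModuleList(
            [nn.Linear(latent_dim, latent_dim) for _ in range(num_experts)]
        )
        # The gate scores the experts from the current latent state.
        self.gate = nn.Linear(latent_dim, num_experts)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        weights = torch.softmax(self.gate(z), dim=-1)                  # (batch, K)
        proposals = torch.stack([f(z) for f in self.experts], dim=-1)  # (batch, D, K)
        # Convex combination of the expert proposals.
        return (proposals * weights.unsqueeze(1)).sum(dim=-1)          # (batch, D)
```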
Model implementations:
- DynaMix → `src/model/dynamix.py`; individual specifications for the architecture can be modified in the `settings` file.
DynaMix enables true zero-shot forecasting - the ability to generate accurate predictions for previously unseen dynamical systems without retraining or fine-tuning:
- A short context window from the new system is provided
- DynaMix infers the underlying dynamics from this context
- The model generates long-term forecasts that preserve both trajectory accuracy and statistical properties
Pre-trained DynaMix models can be accessed via Huggingface. The model has been trained on a diverse collection of dynamical systems and serves as a foundation model for zero-shot forecasting tasks.
To use a pretrained model, load it via:
```python
from src.utilities.utilities import load_hf_model

# Load the pre-trained model
model = load_hf_model("dynamix-3d-alrnn-v1.0")

# Set model to evaluation mode
model.eval()
```

To successfully load a model, the name should match a model name in the Huggingface repository.
Example forecasting evaluations of dynamical systems and time series can be found in the notebooks folder.
Given context data from the target system with shape ($T_C$, $S$, $N$) or ($T_C$, $N$), forecasts can be generated using the `DynaMixForecaster` class:
```python
import torch
from src.model.forecaster import DynaMixForecaster

# Create a forecaster with the trained model
forecaster = DynaMixForecaster(model)

# Make prediction
with torch.no_grad():  # No gradient tracking needed for inference
    reconstruction = forecaster.forecast(
        context=context_tensor,
        horizon=prediction_length,
        preprocessing_method="delay_embedding",
        standardize=True,
        fit_nonstationary=False,
        initial_x=None,
    )
```

The forecasting method requires the following inputs:
- `context`: context data in the form of a tensor with shape ($T_C$, $S$, $N$) or ($T_C$, $N$)
- `horizon`: forecast horizon, i.e. an integer specifying how many future steps to forecast
Optional arguments:

- `preprocessing_method`: for time series forecasting, choose between `pos_embedding`, `delay_embedding`, `delay_embedding_random` and `zero_embedding` as preprocessing method (default: `pos_embedding`)
- `standardize`: standardize the data? `True`/`False` (default: `True`)
- `fit_nonstationary`: fit a non-stationary time series? `True`/`False` (default: `False`)
- `initial_x`: optional initial condition for the model as a tensor of shape ($S$, $N$); otherwise the last context value is used. Does not work with `fit_nonstationary=True` (default: `None`)
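For example, a context tensor in the ($T_C$, $N$) layout can be passed directly to the forecaster from above. The following sketch uses random stand-in data; the shapes and argument values are purely illustrative:

```python
import numpy as np
import torch

# Hypothetical context: 512 observed steps of a 3-dimensional system
# (random stand-in data, for illustration only).
context_np = np.random.randn(512, 3).astype(np.float32)
context_tensor = torch.from_numpy(context_np)  # shape (T_C, N) = (512, 3)

with torch.no_grad():
    reconstruction = forecaster.forecast(
        context=context_tensor,
        horizon=1000,                            # forecast 1000 future steps
        preprocessing_method="delay_embedding",  # how the data is lifted for the model
        standardize=True,
    )
```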
The package includes visualization functions for the different types of systems, most of which require ground-truth data:
```python
from src.utilities.plotting_eval import plot_3D_attractor, plot_2D_attractor, plot_TS_forecast

# For 3D systems
fig = plot_3D_attractor(ground_truth, context, reconstruction)

# For 2D systems
fig = plot_2D_attractor(context, reconstruction)

# For time series data
fig = plot_TS_forecast(ground_truth, context, reconstruction, prediction_length)
```

To evaluate the performance of the model, several metrics are implemented (see `metrics`):
- Geometrical Misalignment: Measures how well the model captures the geometry of the attractor
- Temporal Misalignment: Measures how well the model captures the temporal patterns in the data
- MASE: Mean absolute scaled error, measures the accuracy of short-term predictions
```python
from src.metrics.metrics import geometrical_misalignment, temporal_misalignment, MASE

dstsp = geometrical_misalignment(reconstruction_tensor, ground_truth_tensor, n_bins=30)
dh = temporal_misalignment(reconstruction_tensor, ground_truth_tensor, smoothing=20)
pe = MASE(ground_truth_tensor, reconstruction_tensor, steps=10)
```

DynaMix is trained using backpropagation through time with sparse teacher forcing (STF), a technique that balances stability and learning:
- The model makes predictions over a sequence of time steps
- At regular intervals (specified by `n_interleave`), the model's state is reset to match the ground truth
- This prevents error accumulation while still allowing the model to learn long-term dependencies
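To make the idea concrete, here is a minimal sketch of a sparse-teacher-forcing loss. It is not the repository's training loop: `model.step` and the assumption that ground-truth states are directly available in latent space are hypothetical simplifications.

```python
import torch

def stf_loss(model, states_gt: torch.Tensor, n_interleave: int) -> torch.Tensor:
    """Free-run the model, but reset its state to the ground truth
    every n_interleave steps (sparse teacher forcing)."""
    T = states_gt.shape[0]
    z = states_gt[0]                      # start from the true initial state
    loss = torch.zeros(())
    for t in range(1, T):
        if t % n_interleave == 0:
            z = states_gt[t - 1]          # teacher-forced reset
        z = model.step(z)                 # one-step prediction (hypothetical API)
        loss = loss + torch.mean((z - states_gt[t]) ** 2)
    return loss / (T - 1)
```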
To train the model, see the `training_setup` script for more details. Appropriate arguments can be passed via the command line (or by changing the defaults in the `settings`):

```bash
python -m src.training.training_setup \
    --latent_dim 30 --experts 10 --pwl_units 2 \
    --epochs 2000 --batch_size 256 --device cuda
```

During training, the model is saved every `ssi` (`scalar_saving_interval`) epochs. The saved model state dictionary can be found in the specified `save_path` directory under the "checkpoints" subfolder. Additionally, training metrics are saved separately in a `metrics.pt` file.
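To evaluate or continue from such a checkpoint, the saved state dictionary can be loaded back into a model instance. The paths below are hypothetical; the exact file names under `save_path` depend on the run configuration.

```python
import torch

# Hypothetical paths; adjust to your save_path and checkpoint epoch.
state_dict = torch.load("results/run_01/checkpoints/model_2000.pt", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()

# Training metrics are stored alongside the checkpoints.
metrics = torch.load("results/run_01/metrics.pt", map_location="cpu")
```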
If you use DynaMix in your research, please cite our paper:
```bibtex
@misc{hemmer2025truezeroshotinferencedynamical,
      title={True Zero-Shot Inference of Dynamical Systems Preserving Long-Term Statistics},
      author={Christoph Jürgen Hemmer and Daniel Durstewitz},
      year={2025},
      eprint={2505.13192},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2505.13192},
}
```

