Development repository. This is the experimental/development repo. The official release lives at github.com/sisinflab/WaveDiT; project page: danesed.github.io/wavedit-page.
PyTorch implementation of "WaveDiT: Distribution-Aware Wavelet Flow Matching for Efficient 3D Brain MRI Synthesis" (MICCAI 2026).
WaveDiT synthesises full-resolution, high-fidelity, conditional 3D brain MRIs by performing flow matching in the 3D Haar wavelet domain with a slice-wise HDiT backbone. It is guided by Morpheus, a state-aware uncertainty scheduler that adaptively weights the loss and the sampling across frequency bands.
Links: Project page · Paper (MICCAI 2026, proceedings link forthcoming) · arXiv (preprint)
- Wavelet flow matching: operates on the 8-channel 3D Haar latent (1 LLL + 7 HF bands).
- Morpheus uncertainty scheduler: Bayesian heteroscedastic loss weighting + uncertainty-minimising sampling guidance.
- HDiT backbone: neighbourhood + spatio-depth factorised attention for efficient 3D modelling.
- Multiple flow formulations: cfm, rectified, ot_fm.
- Conditional synthesis: numeric and categorical metadata (e.g. age), with classifier-free guidance.
- Single-file configs: one YAML fully describes a run; checkpoints are self-contained for generation.
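To make the 8-channel latent concrete, here is a minimal pure-Python sketch of one level of an orthonormal 3D Haar transform on a toy volume. It is not the repository's differentiable implementation in wavedit/wavelets/; the function name `haar3d_level`, the [D][H][W] axis order, the band naming (first letter = depth axis) and the 1/sqrt(8) normalisation are all assumptions made for illustration.

```python
import itertools
import math


def haar3d_level(volume):
    """One level of an orthonormal 3D Haar transform.

    `volume` is a nested list indexed [D][H][W] with even sizes (assumed layout).
    Returns a dict with 8 half-resolution subbands: 1 LLL + 7 high-frequency bands.
    """
    D, H, W = len(volume), len(volume[0]), len(volume[0][0])
    scale = 1.0 / math.sqrt(8.0)  # makes the 2x2x2 block transform orthonormal
    bands = {}
    for a, b, c in itertools.product((0, 1), repeat=3):
        # "L" = low-pass (sum), "H" = high-pass (difference) along that axis
        name = "".join("H" if flag else "L" for flag in (a, b, c))
        out = [[[0.0] * (W // 2) for _ in range(H // 2)] for _ in range(D // 2)]
        for d, h, w in itertools.product(range(D // 2), range(H // 2), range(W // 2)):
            acc = 0.0
            for i, j, k in itertools.product((0, 1), repeat=3):
                sign = -1.0 if (a * i + b * j + c * k) % 2 else 1.0
                acc += sign * volume[2 * d + i][2 * h + j][2 * w + k]
            out[d][h][w] = scale * acc
        bands[name] = out
    return bands


def energy(nested):
    """Sum of squares over a 3D nested list."""
    return sum(x * x for plane in nested for row in plane for x in row)


# Toy 4x4x4 volume; a real MRI volume would be far larger.
volume = [[[float(d * 16 + h * 4 + w) for w in range(4)] for h in range(4)] for d in range(4)]
bands = haar3d_level(volume)
print(sorted(bands))
print(energy(volume), sum(energy(band) for band in bands.values()))
```

Because the block transform is orthonormal, the total energy of the eight subbands matches that of the input, which is a quick sanity check for any DWT/IDWT pair.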
Wavelet subbands are not statistically equal: the low-frequency approximation stays close to Gaussian, while the high-frequency bands are sparse and heavy-tailed, and these statistics shift along the flow trajectory. Morpheus is a lightweight network that, at each step, reads the statistical signature of the current noisy state (per-band mean, standard deviation, max amplitude, L2 energy, skewness and kurtosis) and predicts a per-band log-variance. That prediction plays two roles:
- Weighting the loss: it forms a Bayesian heteroscedastic objective (0.5 * exp(-s) * ||v - v_target||^2 + 0.5 * s) that down-weights inherently unpredictable high-frequency content, while the 0.5 * s term prevents trivial variance inflation. The result is state-dependent precision instead of a uniform MSE.
- Conditioning the backbone: the projected log-variances become a frequency hint, injected alongside the time, slice and age embeddings, so the transformer adapts its prediction to the current reliability of each band, during both training and sampling.
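The two ingredients described above, the per-band statistical signature and the heteroscedastic objective, can be sketched in plain Python. This is an illustrative sketch, not the Morpheus code: the function names are hypothetical, "L2 energy" is assumed to mean the sum of squares, kurtosis is assumed to be the non-excess fourth standardised moment, and the loss is assumed to be averaged over bands.

```python
import math


def band_signature(values):
    """Statistics of one subband, as listed in the text (definitions assumed)."""
    n = len(values)
    mean = sum(values) / n
    std = math.sqrt(sum((x - mean) ** 2 for x in values) / n)
    eps = 1e-12  # avoids division by zero for a constant band
    return {
        "mean": mean,
        "std": std,
        "max_abs": max(abs(x) for x in values),
        "l2_energy": sum(x * x for x in values),
        "skewness": sum((x - mean) ** 3 for x in values) / n / (std ** 3 + eps),
        "kurtosis": sum((x - mean) ** 4 for x in values) / n / (std ** 4 + eps),
    }


def heteroscedastic_loss(sq_errors, log_vars):
    """Mean over bands of 0.5 * exp(-s) * err + 0.5 * s.

    sq_errors[b] stands for ||v - v_target||^2 in band b,
    log_vars[b] for the predicted log-variance s of that band.
    """
    terms = [0.5 * math.exp(-s) * e + 0.5 * s for e, s in zip(sq_errors, log_vars)]
    return sum(terms) / len(terms)


# Toy example: an easy low-frequency band and a hard high-frequency band.
sq_errors = [0.1, 4.0]
uniform = heteroscedastic_loss(sq_errors, [0.0, 0.0])            # s = 0: plain 0.5 * MSE
matched = heteroscedastic_loss(sq_errors, [math.log(e) for e in sq_errors])
inflated = heteroscedastic_loss(sq_errors, [10.0, 10.0])         # variance blown up
print(uniform, matched, inflated)
print(band_signature([1.0, -1.0, 1.0, -1.0]))
```

Setting the derivative with respect to s to zero gives s = log(err) per band, so the objective is minimised when the predicted variance matches the actual squared error. Inflating s everywhere shrinks the first term but is penalised by 0.5 * s, which is why the inflated case scores worst here.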
conda create -n wavedit_env python=3.11 && conda activate wavedit_env
pip install -r requirements.txt
# Optional but recommended: fused neighbourhood-attention CUDA kernels (match your build):
pip install natten -f https://whl.natten.org
# Optional, faster global attention:
# pip install -U xformers

Developed for Python 3.11 and PyTorch 2.6 (CUDA recommended).
NATTEN is optional. It is the fastest, reference implementation of the neighbourhood attention used in the default config, but WaveDiT ships an equivalent built-in pure-PyTorch fallback, so the model runs without NATTEN, including on CPU. The backend is chosen automatically; override with WAVEDIT_NA_BACKEND=auto|natten|torch.
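For readers wondering how such an automatic choice typically works, the following is a hedged sketch of backend resolution driven by the WAVEDIT_NA_BACKEND variable. The function `resolve_na_backend` is hypothetical and the actual selection logic in WaveDiT may differ; only the variable name and its three values come from this README.

```python
import importlib.util
import os

VALID_BACKENDS = ("auto", "natten", "torch")


def resolve_na_backend(env=None):
    """Return "natten" or "torch" based on WAVEDIT_NA_BACKEND (hypothetical helper)."""
    env = os.environ if env is None else env
    choice = env.get("WAVEDIT_NA_BACKEND", "auto").strip().lower()
    if choice not in VALID_BACKENDS:
        raise ValueError(f"WAVEDIT_NA_BACKEND must be one of {VALID_BACKENDS}, got {choice!r}")
    # find_spec checks whether the package is importable without importing it
    natten_available = importlib.util.find_spec("natten") is not None
    if choice == "auto":
        return "natten" if natten_available else "torch"
    if choice == "natten" and not natten_available:
        raise RuntimeError("NATTEN was requested but is not installed")
    return choice


print(resolve_na_backend({"WAVEDIT_NA_BACKEND": "torch"}))
print(resolve_na_backend({}))  # depends on whether NATTEN is installed
```

Forcing the pure-PyTorch path this way is mainly useful for checking results on a machine without the fused kernels.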
configs/                  One YAML per experiment (cfm, rectified, ot_fm)
train.sh                  bash train.sh [config.yaml]      -> launches training
generate.sh               bash generate.sh <ckpt> [outdir] -> generates samples
scripts/
  train.py                config-driven training entry point
  generate.py             generation (specific condition sets or linear interpolation)
  prepare_metadata.py     build the metadata CSV from NIfTI folders
tools/
  slim_checkpoint.py      strip optimiser state for release/inference
wavedit/
  config.py               typed config loaded from YAML
  data/                   unified dataset (CSV / filename), augmentation, collation
  wavelets/               differentiable 3D Haar DWT/IDWT
  models/                 WaveletFlowMatching, DiT3D backbone, Morpheus, sampling, hdit/
  training/               Trainer + checkpoint I/O
  generation/             sample generation
  evaluation/             metrics + W&B visualisation
  utils/                  logging + seeding
See data/README.md. In short, build a catalog once:
python scripts/prepare_metadata.py --input-dirs /path/to/scans --output-csv ./data/dataset.csv

Then point data.metadata_csv in your config at it. Raw scans and catalogs are
git-ignored and must be obtained from the original dataset providers.
Edit a config (data paths, architecture, hyper-parameters) and launch:
bash train.sh configs/cfm.yaml

Or run the entry point directly:
PYTHONPATH=. python scripts/train.py configs/cfm.yaml

Each run writes to <checkpoint_dir>/<run_name>/: best.pth, last.pth, a copy of the
resolved config.yaml, and logs. Set logging.wandb: true for W&B metrics and
visualisations. Switch the objective with model.flow (cfm | rectified | ot_fm).
Checkpoints are self-contained (they embed the config and condition metadata), so generation needs only the checkpoint and your sampling choices.
# Specific condition sets (N samples each)
PYTHONPATH=. python scripts/generate.py checkpoints/WaveDiT_CFM/best.pth out/ \
specific --conditions "age=45.0" "age=70.5" --num-samples 10 \
--cfg-scale 1.5 --num-flow-steps 10 --sampler heun --save-size 182 218 182
# Linearly interpolate one condition (one sample per step)
PYTHONPATH=. python scripts/generate.py checkpoints/WaveDiT_CFM/best.pth out/ \
linear --condition age --min 6 --max 95 --num 100

Or use the launcher: bash generate.sh checkpoints/WaveDiT_CFM/best.pth.
| Argument | Meaning |
|---|---|
| --cfg-scale | Classifier-free guidance scale (1.0 = none). |
| --num-flow-steps | ODE integration steps (overrides the checkpoint default). |
| --sampler | heun (2nd order) or euler. |
| --morpheus-scale | Uncertainty-guidance scale (0 disables it). |
| --save-size | Center-crop saved volumes to D H W (default: full model output). |
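To illustrate what --sampler, --num-flow-steps and --cfg-scale control, here is a toy sketch of Euler versus Heun integration of a flow ODE, plus the usual classifier-free guidance mix. It is not the repository's sampler: `integrate` and `guided_velocity` are hypothetical names, the guidance formula is the common convention in which a scale of 1.0 recovers the conditional prediction, and the velocity field dx/dt = x is a stand-in for the trained model.

```python
import math


def guided_velocity(v_cond, v_uncond, cfg_scale):
    """Common classifier-free guidance mix; cfg_scale = 1.0 returns v_cond unchanged."""
    return v_uncond + cfg_scale * (v_cond - v_uncond)


def integrate(velocity, x0, num_steps, sampler="heun"):
    """Integrate dx/dt = velocity(x, t) from t = 0 to t = 1 in num_steps steps."""
    x, dt = x0, 1.0 / num_steps
    for i in range(num_steps):
        t = i * dt
        k1 = velocity(x, t)
        if sampler == "euler":
            x = x + dt * k1  # 1st order: one model call per step
        elif sampler == "heun":
            k2 = velocity(x + dt * k1, t + dt)  # slope at the Euler-predicted endpoint
            x = x + 0.5 * dt * (k1 + k2)        # 2nd order: two model calls per step
        else:
            raise ValueError(f"unknown sampler: {sampler!r}")
    return x


# Toy velocity field with a known solution: x(1) = e for x(0) = 1.
toy_velocity = lambda x, t: x
euler_result = integrate(toy_velocity, 1.0, num_steps=10, sampler="euler")
heun_result = integrate(toy_velocity, 1.0, num_steps=10, sampler="heun")
print(abs(euler_result - math.e), abs(heun_result - math.e))
```

With the same number of steps, Heun lands closer to the exact solution than Euler at the cost of twice as many velocity evaluations, which is the trade-off behind choosing a sampler and a step count.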
% Temporary arXiv preprint.
@article{danese2026wavedit,
title = {WaveDiT: Distribution-Aware Wavelet Flow Matching for Efficient 3D Brain MRI Synthesis},
author = {Danese, Danilo and Lombardi, Angela and Fasano, Giuseppe and Attimonelli, Matteo and Di Noia, Tommaso},
journal = {arXiv preprint arXiv:XXXX.XXXXX},
year = {2026}
}

WaveDiT builds on the wavelet-domain analysis and multi-level evaluation protocol of our previous work, FlowLet.
The HDiT backbone is adapted from k-diffusion. See LICENSE.
