Edmond1Cheng/MBDPO

MBDPO: Scaling World-Model Reinforcement Learning Through Diffusion Policy Optimization

Official implementation of "Scaling World-Model Reinforcement Learning Through Diffusion Policy Optimization" by

Xiaoyuan Cheng* (ucesxc4@ucl.ac.uk), Wenxuan Yuan* (YUAN0186@e.ntu.edu.sg), Zhancun Mu, Yuanzhao Zhang, Yiming Yang, Hai Wang, Zhuo Sun, Che Liu

Overview

MBDPO is a model-based reinforcement learning framework that unifies search and policy optimization through a diffusion policy representation inside a learned latent world model. Instead of building an explicit planner (e.g. MPPI) on top of the world model, MBDPO reformulates policy optimization as a diffusion process over imagined trajectories, where the score field is corrected by model-based returns and anchored to the behavior distribution via an implicit energy function. This eliminates the structural misalignment between search and value learning that limits prior world-model approaches, and yields monotonic scaling of performance with model capacity.
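
The following toy sketch gives some intuition for a score field that is corrected by returns but still anchored to the behavior distribution. It is not the MBDPO algorithm. It uses plain Langevin dynamics instead of a noise-scheduled diffusion over imagined trajectories. It also rests on three assumptions: a Gaussian behavior distribution N(mu_b, sigma_b^2 I), a quadratic return R(a) = -||a - a_star||^2, and an explicit energy E(a) = ||a - mu_b||^2 / (2 sigma_b^2) - lam * R(a). All names (`guided_langevin`, `mu_b`, `a_star`, `lam`) are hypothetical.

```python
import numpy as np


def guided_langevin(mu_b, sigma_b, a_star, lam, n_samples=20000, n_steps=1000,
                    step=1e-2, seed=0):
    """Toy return-guided Langevin sampler (illustrative assumption, not MBDPO).

    Target density: p(a) ∝ N(a; mu_b, sigma_b^2 I) * exp(lam * R(a)),
    with R(a) = -||a - a_star||^2, i.e. p(a) ∝ exp(-E(a)).
    """
    mu_b = np.asarray(mu_b, dtype=float)
    a_star = np.asarray(a_star, dtype=float)
    rng = np.random.default_rng(seed)
    # Start from the behavior distribution itself.
    a = rng.normal(mu_b, sigma_b, size=(n_samples, mu_b.shape[0]))
    for _ in range(n_steps):
        score_prior = -(a - mu_b) / sigma_b**2   # grad log p_b(a): anchors samples to behavior
        grad_return = -2.0 * (a - a_star)        # grad R(a): pushes toward higher return
        score = score_prior + lam * grad_return  # return-corrected score, equal to -grad E(a)
        # Unadjusted Langevin step: a <- a + h * score + sqrt(2h) * noise
        a = a + step * score + np.sqrt(2.0 * step) * rng.standard_normal(a.shape)
    return a


samples = guided_langevin(np.zeros(2), 1.0, [1.0, -1.0], lam=1.0)
print(samples.mean(axis=0), samples.var(axis=0))
```

Under these assumptions the target is Gaussian with per-dimension precision 1/sigma_b^2 + 2*lam. With mu_b = (0, 0), sigma_b = 1, a_star = (1, -1) and lam = 1, the samples should therefore concentrate near (2/3, -2/3) with variance close to 1/3. Raising lam moves mass toward the high-return action, and the prior term keeps it from drifting arbitrarily far from behavior.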

The repository contains code for training and evaluating MBDPO across 121 continuous control tasks in three settings: online from scratch, multi-task offline pretraining, and offline-to-online (O2O) fine-tuning.

Getting started

Environment

We provide ready-to-use Conda environment files for different experiment suites.

# Example: create environment for MT80 experiments
conda env create -f conda_envs/mbdpo-mt80.yml
conda activate mbdpo-mt80

# Optional: other provided environments
# conda env create -f conda_envs/mbdpo-ms2.yml
# conda env create -f conda_envs/mbdpo-myo.yml

See the notes for each environment at this link.

Offline Pretraining Dataset

For multi-task offline pretraining, we use the replay buffers from the open-source TD-MPC2 datasets (mt80 and mt30).

To download them (remember to adjust the dataset path in the configuration YAML files accordingly):

  • mt30:
mkdir -p ./offline_dataset/mt30

seq 0 3 | xargs -I {} -P 4 wget -c \
  -O ./offline_dataset/mt30/chunk_{}.pt \
  "https://huggingface.co/datasets/nicklashansen/tdmpc2/resolve/main/mt30/chunk_{}.pt?download=true"
  • mt80:
mkdir -p ./offline_dataset/mt80

seq 0 19 | xargs -I {} -P 4 wget -c \
  -O ./offline_dataset/mt80/chunk_{}.pt \
  "https://huggingface.co/datasets/nicklashansen/tdmpc2/resolve/main/mt80/chunk_{}.pt?download=true"
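
If a download is interrupted, or you prefer to fetch chunks from Python rather than wget, a minimal sketch along these lines can list the expected chunk files and fetch only the missing ones. It makes two assumptions. The first is the chunk counts from the commands above (4 for mt30, 20 for mt80). The second is that each dataset's chunks sit in a folder named after the dataset on the Hugging Face repo. `chunk_plan`, `missing_chunks`, and `download_missing` are hypothetical helpers, not part of this repository.

```python
from pathlib import Path
from urllib.request import urlretrieve

BASE_URL = "https://huggingface.co/datasets/nicklashansen/tdmpc2/resolve/main"
# Assumed chunk counts, taken from the wget commands (seq 0 3 and seq 0 19).
NUM_CHUNKS = {"mt30": 4, "mt80": 20}


def chunk_plan(dataset, root="./offline_dataset"):
    """Return (url, local_path) pairs for every chunk of `dataset`."""
    out_dir = Path(root) / dataset
    return [
        (f"{BASE_URL}/{dataset}/chunk_{i}.pt?download=true", out_dir / f"chunk_{i}.pt")
        for i in range(NUM_CHUNKS[dataset])
    ]


def missing_chunks(dataset, root="./offline_dataset"):
    """A chunk counts as present if its file exists.

    Partial files left behind by an interrupted `wget -c` are NOT detected.
    """
    return [(url, path) for url, path in chunk_plan(dataset, root) if not path.exists()]


def download_missing(dataset, root="./offline_dataset"):
    """Download each missing chunk to a .part file, then rename it into place."""
    for url, path in missing_chunks(dataset, root):
        path.parent.mkdir(parents=True, exist_ok=True)
        tmp = path.with_name(path.name + ".part")
        print(f"downloading {url} -> {path}")
        urlretrieve(url, tmp)  # whole-file download; no resume like wget -c
        tmp.replace(path)      # rename only after the download has finished
```

For example, `download_missing("mt80")` fills in whatever is absent under ./offline_dataset/mt80. Because each chunk is written to a `.part` file and renamed afterwards, a finished chunk is never confused with an interrupted one. Unlike `wget -c`, however, an interrupted chunk is fetched again from the start.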

Supported tasks

This codebase supports all 121 continuous control tasks used in our technical report: DMControl (39 tasks), Meta-World (50 tasks), ManiSkill2 (5 tasks), MyoSuite (10 tasks), Locomotion (7 tasks), and Visual RL (10 tasks). In the DMControl domain, we follow the TD-MPC2 setting and use its 11 custom tasks.

See this link for the detailed task list and notes for each domain.
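
As a quick consistency check, the per-domain counts listed above add up to the stated total of 121 tasks:

```python
# Per-domain task counts as listed in this README.
TASK_COUNTS = {
    "DMControl": 39,
    "Meta-World": 50,
    "ManiSkill2": 5,
    "MyoSuite": 10,
    "Locomotion": 7,
    "Visual RL": 10,
}
total = sum(TASK_COUNTS.values())
print(total)  # 121
```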

Example usage

1) Single-task online from scratch

python scripts/train.py task=dog-run seed=1 steps=4000000

or with the parallel launcher:

python scripts/online_parallel_train.py --config cfgs/online_parallel_config.yaml
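
The sketch below runs several seeds of the single-task command one after another. It relies only on the key=value override syntax shown above. `train_cmd` and `run_seeds` are hypothetical helpers, not part of the repository. For concurrent runs, the parallel launcher is presumably the intended route.

```python
import subprocess


def train_cmd(task, seed, steps, script="scripts/train.py"):
    """Build the single-task training command as an argv list of key=value overrides."""
    return ["python", script, f"task={task}", f"seed={seed}", f"steps={steps}"]


def run_seeds(task, seeds, steps, dry_run=True):
    """Launch one run per seed sequentially; with dry_run=True only print the commands."""
    cmds = [train_cmd(task, s, steps) for s in seeds]
    for cmd in cmds:
        print(" ".join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)  # stop on the first failing run
    return cmds


cmds = run_seeds("dog-run", [1, 2, 3], 4000000)
```

Passing `dry_run=False` actually launches the runs. The default only prints them, which lets you check the overrides first.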

2) Multi-task offline pretraining

python scripts/train.py task=mt80 multitask=true
# or
python scripts/train.py task=mt30 multitask=true

3) Offline-to-online (O2O) fine-tuning

python scripts/offline_to_online.py \
  checkpoint=/path/to/checkpoint.pt \
  save_path=/path/to/output_dir \
  off2on_task="walker-run" \
  steps=40000
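
To fine-tune one pretrained checkpoint on several tasks, a small sketch can build one O2O command per task and give each run its own output sub-directory under save_path. It relies only on the key=value arguments shown above. `o2o_cmd` is a hypothetical helper, and walker-run and dog-run are used purely as placeholder task names.

```python
from pathlib import Path


def o2o_cmd(checkpoint, save_root, task, steps=40000):
    """Build the offline-to-online command, writing each task to save_root/<task>."""
    return [
        "python", "scripts/offline_to_online.py",
        f"checkpoint={checkpoint}",
        f"save_path={Path(save_root) / task}",
        f"off2on_task={task}",  # the shell quotes around "walker-run" are not part of the value
        f"steps={steps}",
    ]


tasks = ["walker-run", "dog-run"]
cmds = [o2o_cmd("/path/to/checkpoint.pt", "/path/to/output_dir", t) for t in tasks]
for cmd in cmds:
    print(" ".join(cmd))
```

Each list can be passed to `subprocess.run(cmd, check=True)` to launch the runs one after another. Whether two runs sharing a save_path would actually clash depends on the script, so separate directories are simply the safer default.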

4) Evaluation

python scripts/evaluate.py \
  task=mt80 \
  checkpoint=/path/to/checkpoint.pt \
  eval_episodes=10

For details on parameter usage, please refer to this description.

Citation

@misc{cheng2026scalingworldmodelreinforcementlearning,
      title={Scaling World-Model Reinforcement Learning Through Diffusion Policy Optimization}, 
      author={Xiaoyuan Cheng and Wenxuan Yuan and Zhancun Mu and Yuanzhao Zhang and Yiming Yang and Hai Wang and Zhuo Sun and Che Liu},
      year={2026},
      eprint={2605.26282},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2605.26282}, 
}

Contributing

Contributions are welcome — bug reports, questions, feature requests, and pull requests all help. To get started, please open an issue or submit a pull request.

For details on reporting bugs, the pull request process, and code style, see CONTRIBUTING.md. For questions about the paper itself, feel free to contact Xiaoyuan Cheng (ucesxc4@ucl.ac.uk) or Wenxuan Yuan (YUAN0186@e.ntu.edu.sg).

License

This project is released under the MIT License.

Note that this repository depends on third-party code and simulators (DMControl, Meta-World, ManiSkill2, MyoSuite, etc.), which are subject to their own respective licenses.
