LeWorldModel trains a pixel-based latent world model for continuous-control navigation. It collects trajectories from a PPO policy in swm/TwoRoom-v1, learns latent dynamics from rendered frames and actions, then plans in latent space with CEM toward a visual goal.
The implementation is inspired by *LeWorldModel: Stable End-to-End Joint-Embedding Predictive Architecture from Pixels*.
The world model follows a joint-embedding predictive architecture:
```
frame_t  -> encoder   -> z_t
z_t, a_t -> predictor -> z_{t+1}
```
Training predicts the next latent embedding rather than reconstructing pixels (see the sketch after this list). The model combines:
- a ViT encoder from torchvision;
- a causal transformer predictor conditioned on continuous actions;
- a next-latent prediction loss;
- SIGReg, a Gaussian latent regularizer used to reduce the risk of representation collapse.
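As a minimal PyTorch sketch of one training step (the encoder/predictor call signatures and the simple moment-matching stand-in for SIGReg are assumptions made for illustration; the real objective lives in world_model/leworldmodel.py):

```python
import torch
import torch.nn.functional as F

def training_step(encoder, predictor, frames, actions, reg_weight=0.1):
    """Next-latent prediction plus a Gaussian latent regularizer.
    frames: (B, T, C, H, W) rendered RGB; actions: (B, T, action_dim)."""
    B, T = frames.shape[:2]

    # Encode every frame into a latent embedding z_t.
    z = encoder(frames.flatten(0, 1)).unflatten(0, (B, T))  # (B, T, D)

    # Predict z_{t+1} from (z_t, a_t); no pixel reconstruction.
    z_pred = predictor(z[:, :-1], actions[:, :-1])          # (B, T-1, D)
    pred_loss = F.mse_loss(z_pred, z[:, 1:])

    # Gaussian latent regularizer (SIGReg stand-in): pull the latent
    # batch toward zero mean / unit variance per dimension so the
    # encoder cannot collapse all frames to a single point.
    flat = z.flatten(0, 1)
    reg_loss = flat.mean(0).pow(2).mean() + (flat.var(0) - 1).pow(2).mean()

    return pred_loss + reg_weight * reg_loss  # reg_weight is illustrative
```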
At evaluation time, main_wm.py encodes the current frame and a goal frame, uses CEM to optimize candidate action sequences through latent rollouts, and executes the first action of the best sequence in MPC style.
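As a rough sketch of that loop (the frame preprocessing, the encoder input format, and the `plan` callable are assumptions for illustration; the real logic lives in main_wm.py):

```python
import torch

def mpc_rollout(env, encoder, plan, goal_frame, max_steps=200):
    """Receding-horizon control: replan from the current frame every
    step and execute only the first action of the best sequence."""
    with torch.no_grad():
        z_goal = encoder(goal_frame.unsqueeze(0))  # encode the visual goal once
        obs, info = env.reset()
        for _ in range(max_steps):
            # Encode the current rendered frame into the latent space.
            frame = torch.from_numpy(env.render().copy()).permute(2, 0, 1).float() / 255
            z_t = encoder(frame.unsqueeze(0))
            # CEM optimizes a whole action sequence through latent rollouts...
            action_seq = plan(z_t, z_goal)         # (horizon, action_dim)
            # ...but MPC-style, only the first action is executed.
            obs, reward, terminated, truncated, info = env.step(action_seq[0].numpy())
            if terminated or truncated:
                break
```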
The active environment wrapper is `envs/two_room_env.py`. It wraps `swm/TwoRoom-v1` from stable_worldmodel and adds the project-specific behavior used by training and data collection (a sketch follows the list):
- continuous action space: `Box(-1, 1, (2,), float32)`;
- normalized 10D state observations for PPO;
- rendered RGB frames for world-model data;
- target rendering enabled;
- optional fixed seed or explicit start/goal positions;
- shaped progress reward based on normalized distance-to-target improvement;
- success bonus when the agent reaches the target;
- frame skip for shorter effective control horizons.
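A compressed sketch of the reward shaping and frame skip (the info keys and constants below are assumptions; see envs/two_room_env.py for the real wrapper):

```python
import numpy as np
import gymnasium as gym

class TwoRoomWrapper(gym.Wrapper):
    """Illustrative sketch only; names and constants are assumptions."""

    def __init__(self, env, frame_skip=4, success_bonus=10.0):
        super().__init__(env)
        self.frame_skip = frame_skip
        self.success_bonus = success_bonus
        self.action_space = gym.spaces.Box(-1.0, 1.0, (2,), np.float32)
        self._prev_dist = None

    def step(self, action):
        # Frame skip: repeat the action to shorten the control horizon.
        for _ in range(self.frame_skip):
            obs, _, terminated, truncated, info = self.env.step(action)
            if terminated or truncated:
                break
        # Shaped progress reward: improvement in normalized
        # distance-to-target since the previous step.
        dist = info["normalized_target_distance"]  # assumed info key
        reward = self._prev_dist - dist
        self._prev_dist = dist
        if info.get("success", False):             # assumed info key
            reward += self.success_bonus
        return obs, reward, terminated, truncated, info

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        self._prev_dist = info.get("normalized_target_distance", 1.0)
        return obs, info
```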
| Path | Purpose |
| --- | --- |
| `main_rl.py` | PPO training, evaluation, and trajectory collection |
| `main_wm.py` | world-model training and CEM-based latent planning |
| `two_room_test.py` | quick manual TwoRoom rollout script |
| `envs/two_room_env.py` | TwoRoom wrapper used by the current workflow |
| `agents/ppo_agent.py` | custom PPO implementation |
| `architectures/simple_mlp.py` | PPO actor/critic MLP embeddings |
| `world_model/leworldmodel.py` | encoder, predictor, SIGReg, and training loop |
| `layers/transformer.py` | custom transformer block used by the world model |
| `layers/*.py` | additional transformer/DiT/Llama/GPT-2 components |
| `runners/runner.py` | single-environment training loop |
| `runners/parallel_runner.py` | threaded parallel runner |
| `runners/vectorized_runner.py` | vectorized runner |
| `assets/` | README media |
| `datasets/` | saved trajectory datasets, ignored by git |
| `saved/` | saved model checkpoints, ignored by git |
| `arrays/` | saved training statistics, ignored by git |
| `plot_results.py` | utility for plotting saved run statistics |
Train a state-based PPO agent in the TwoRoom environment:
```bash
uv run python main_rl.py --model-name two_room_ppo --fixed-seed 423
```

The policy observes the normalized 10D TwoRoom state and outputs 2D continuous actions in [-1, 1].
Collect rendered trajectories from the trained PPO policy:
```bash
uv run python main_rl.py \
    --model-name two_room_ppo \
    --fixed-seed 423 \
    --save-trajectories \
    --num-samples-to-save 1000
```

With a fixed seed, this writes `datasets/dataset_fixed_seed_423.pkl`.
Each trajectory stores normalized states, rendered RGB frames, continuous actions, and shaped rewards.
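For example, loading the pickle might look like this (the per-trajectory key names below are assumptions, not a documented schema):

```python
import pickle

# Hypothetical inspection of the saved dataset; the field names
# mirror the description above but are not guaranteed by the format.
with open("datasets/dataset_fixed_seed_423.pkl", "rb") as f:
    trajectories = pickle.load(f)

traj = trajectories[0]
print(len(trajectories))      # number of saved trajectories
print(traj["states"].shape)   # normalized 10D states, (T, 10)
print(traj["frames"].shape)   # rendered RGB frames, (T, H, W, 3)
print(traj["actions"].shape)  # continuous actions, (T, 2)
print(traj["rewards"].shape)  # shaped rewards, (T,)
```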
Train the latent dynamics model from the collected visual dataset:
```bash
uv run python main_wm.py \
    --model-name two_room_wm \
    --dataset-name datasets/dataset_fixed_seed_423.pkl \
    --fixed-seed 423 \
    --sequence-length 4 \
    --batch-size 32 \
    --epochs-number 5
```

When a saved world model exists, main_wm.py can load it and evaluate planning in the same fixed-seed environment:
```bash
uv run python main_wm.py \
    --model-name two_room_wm \
    --dataset-name datasets/dataset_fixed_seed_423.pkl \
    --fixed-seed 423
```

The evaluation path selects the final frame of the highest-reward trajectory as the visual goal, encodes it, and runs CEM over latent rollouts. The current defaults use a planning horizon of 50, a population size of 64, 4 CEM iterations, and an elite fraction of 0.1.
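A compressed sketch of that CEM loop, assuming a single-step `predictor(z, a)` interface (the actual model is a causal transformer over sequences) and a squared-distance-to-goal cost, neither of which is guaranteed to match main_wm.py:

```python
import torch

def cem_plan(predictor, z0, z_goal, horizon=50, population=64,
             iterations=4, elite_frac=0.1, action_dim=2):
    """Illustrative CEM over latent rollouts using the defaults quoted
    above. z0 and z_goal are (1, D) latent embeddings."""
    n_elite = max(2, int(elite_frac * population))
    mean = torch.zeros(horizon, action_dim)
    std = torch.ones(horizon, action_dim)
    for _ in range(iterations):
        # Sample candidate action sequences, clipped to the action range.
        seqs = (mean + std * torch.randn(population, horizon, action_dim)).clamp(-1, 1)
        # Roll every candidate forward through the latent dynamics.
        z = z0.expand(population, -1).clone()
        cost = torch.zeros(population)
        for t in range(horizon):
            z = predictor(z, seqs[:, t])
            cost += (z - z_goal).pow(2).sum(-1)  # squared distance to the goal latent
        # Refit the sampling distribution to the lowest-cost elites.
        elites = seqs[cost.topk(n_elite, largest=False).indices]
        mean, std = elites.mean(0), elites.std(0) + 1e-6
    return mean  # MPC executes only mean[0], then replans
```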
This README was written by Codex and validated by the author.
