SestoAle/LeWorldModel

LeWorldModel

Trained agent rollout in the TwoRoom environment

LeWorldModel trains a pixel-based latent world model for continuous-control navigation. It collects trajectories from a PPO policy in swm/TwoRoom-v1, learns latent dynamics from rendered frames and actions, then plans in latent space with CEM toward a visual goal.

The implementation is inspired by LeWorldModel: Stable End-to-End Joint-Embedding Predictive Architecture from Pixels.

Main Ideas

The world model follows a joint-embedding predictive architecture:

frame_t  -> encoder   -> z_t
z_t, a_t -> predictor -> z_{t+1}

Training predicts the next latent embedding rather than reconstructing pixels. The model combines:

  • a ViT encoder from torchvision;
  • a causal transformer predictor conditioned on continuous actions;
  • a next-latent prediction loss;
  • SIGReg, a Gaussian latent regularizer used to reduce representation collapse risk.
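The training objective above can be sketched with toy linear stand-ins for the ViT encoder and the action-conditioned predictor. Everything here is illustrative: the dimensions, the linear maps, and the simplified zero-mean/unit-variance regularizer standing in for SIGReg are assumptions, not the actual implementation in world_model/leworldmodel.py.

```python
import numpy as np

rng = np.random.default_rng(0)
D_PIX, D_LAT, D_ACT = 16, 8, 2  # toy sizes, not the real ViT/transformer dims

# Toy linear stand-ins for the ViT encoder and the action-conditioned predictor.
W_enc = rng.normal(size=(D_PIX, D_LAT)) * 0.1
W_pred = rng.normal(size=(D_LAT + D_ACT, D_LAT)) * 0.1

def encode(frame):
    return frame @ W_enc                      # frame_t  -> z_t

def predict(z, a):
    return np.concatenate([z, a]) @ W_pred    # z_t, a_t -> predicted z_{t+1}

frame_t, frame_next = rng.normal(size=D_PIX), rng.normal(size=D_PIX)
action_t = rng.uniform(-1, 1, size=D_ACT)

z_t = encode(frame_t)
z_next_target = encode(frame_next)            # target is a latent, not pixels
z_next_pred = predict(z_t, action_t)

# Next-latent prediction loss: no pixel reconstruction anywhere.
pred_loss = np.mean((z_next_pred - z_next_target) ** 2)

# SIGReg-style idea (hypothetical simplification): keep the batch of latents
# close to a standard Gaussian per dimension to discourage collapse.
batch_z = rng.normal(size=(32, D_PIX)) @ W_enc
reg_loss = np.mean(batch_z.mean(axis=0) ** 2) + np.mean((batch_z.var(axis=0) - 1.0) ** 2)

total_loss = pred_loss + 0.1 * reg_loss       # 0.1 is an illustrative weight
```

A collapsed encoder (all latents identical) would zero out `pred_loss` trivially; the variance term in the regularizer is what makes that solution costly.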

At evaluation time, main_wm.py encodes the current frame and a goal frame, uses CEM to optimize candidate action sequences through latent rollouts, and executes the first action of the best sequence in MPC style.
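The plan-then-execute-first-action loop can be sketched as follows. The toy latent dynamics, dimensions, and function names here are hypothetical stand-ins for the real encoder and transformer predictor; only the CEM/MPC structure mirrors the description above.

```python
import numpy as np

rng = np.random.default_rng(0)
D_LAT, D_ACT, HORIZON = 4, 2, 5   # toy sizes, not the real defaults

def encode(frame):                 # stand-in for the ViT encoder
    return frame[:D_LAT]

def rollout_cost(z0, z_goal, actions, step):
    # Roll a candidate action sequence through the latent dynamics and
    # score it by the distance of the final latent to the goal latent.
    z = z0
    for a in actions:
        z = step(z, a)
    return np.sum((z - z_goal) ** 2)

def plan_first_action(z0, z_goal, step, pop=64, iters=4, elite_frac=0.1):
    # Minimal CEM over action sequences; MPC executes only mean[0].
    mean = np.zeros((HORIZON, D_ACT))
    std = np.ones((HORIZON, D_ACT))
    n_elite = max(1, int(pop * elite_frac))
    for _ in range(iters):
        cand = np.clip(mean + std * rng.normal(size=(pop, HORIZON, D_ACT)), -1, 1)
        costs = np.array([rollout_cost(z0, z_goal, c, step) for c in cand])
        elite = cand[np.argsort(costs)[:n_elite]]
        mean, std = elite.mean(axis=0), elite.std(axis=0) + 1e-6
    return mean[0]

# Toy latent dynamics: the action nudges the first two latent dimensions.
def step(z, a):
    out = z.copy()
    out[:D_ACT] += 0.1 * a
    return out

z_now = encode(np.zeros(8))
z_goal = encode(np.full(8, 0.3))
first_action = plan_first_action(z_now, z_goal, step)
```

In the real loop this repeats every control step: replan from the newly observed frame, execute the new first action, and so on.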

Environment

The active environment wrapper is envs/two_room_env.py. It wraps swm/TwoRoom-v1 from stable_worldmodel and adds the project-specific behavior used by training and data collection:

  • continuous action space: Box(-1, 1, (2,), float32);
  • normalized 10D state observations for PPO;
  • rendered RGB frames for world-model data;
  • target rendering enabled;
  • optional fixed seed or explicit start/goal positions;
  • shaped progress reward based on normalized distance-to-target improvement;
  • success bonus when the agent reaches the target;
  • frame skip for shorter effective control horizons.
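The shaped-reward and frame-skip behaviors in the list above can be illustrated with a toy stand-in. This is not the actual envs/two_room_env.py: the point-mass dynamics, step size, and success threshold are all hypothetical, and only the reward/frame-skip logic corresponds to the bullets.

```python
import numpy as np

class ToyTwoRoom:
    # Hypothetical stand-in for envs/two_room_env.py: a 2D point agent,
    # showing only the shaped progress reward, success bonus, and frame skip.
    def __init__(self, frame_skip=4, goal=(0.8, 0.8)):
        self.frame_skip = frame_skip
        self.goal = np.asarray(goal, dtype=np.float64)
        self.pos = np.zeros(2)
        self.max_dist = np.linalg.norm(self.goal - self.pos)

    def step(self, action):
        action = np.clip(action, -1.0, 1.0)        # Box(-1, 1, (2,), float32)
        prev_dist = np.linalg.norm(self.goal - self.pos)
        for _ in range(self.frame_skip):           # repeat action -> shorter horizon
            self.pos = self.pos + 0.05 * action
        dist = np.linalg.norm(self.goal - self.pos)
        reward = (prev_dist - dist) / self.max_dist  # normalized progress
        success = dist < 0.05
        if success:
            reward += 1.0                           # success bonus
        return self.pos.copy(), reward, success

env = ToyTwoRoom()
obs, reward, done = env.step(np.array([1.0, 1.0]))  # move toward the goal
```

Moving toward the target yields a positive reward even before the target is reached, which is the point of the shaping: it densifies an otherwise sparse success signal.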

Repository Layout

main_rl.py                         PPO training, evaluation, and trajectory collection
main_wm.py                         world-model training and CEM-based latent planning
two_room_test.py                   quick manual TwoRoom rollout script

envs/two_room_env.py               TwoRoom wrapper used by the current workflow
agents/ppo_agent.py                custom PPO implementation
architectures/simple_mlp.py        PPO actor/critic MLP embeddings

world_model/leworldmodel.py        encoder, predictor, SIGReg, and training loop
layers/transformer.py              custom transformer block used by the world model
layers/*.py                        additional transformer/DiT/Llama/GPT-2 components

runners/runner.py                  single-environment training loop
runners/parallel_runner.py         threaded parallel runner
runners/vectorized_runner.py       vectorized runner

assets/                            README media
datasets/                          saved trajectory datasets, ignored by git
saved/                             saved model checkpoints, ignored by git
arrays/                            saved training statistics, ignored by git
plot_results.py                    utility for plotting saved run statistics

Workflow

1. Train PPO

Train a state-based PPO agent in the TwoRoom environment:

uv run python main_rl.py --model-name two_room_ppo --fixed-seed 423

The policy observes the normalized 10D TwoRoom state and outputs 2D continuous actions in [-1, 1].
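One common way to keep continuous PPO actions inside [-1, 1] is a tanh-squashed Gaussian head. The sketch below is an assumption about the general technique, not the actual actor in agents/ppo_agent.py; the single linear layer and fixed log-std are purely illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
STATE_DIM, ACT_DIM = 10, 2   # normalized 10D state in, 2D continuous action out

# Hypothetical one-layer actor: sample from a Gaussian, then squash with tanh
# so every action component lands strictly inside [-1, 1].
W = rng.normal(size=(STATE_DIM, ACT_DIM)) * 0.1
log_std = np.full(ACT_DIM, -0.5)

def act(state):
    mean = state @ W
    sample = mean + np.exp(log_std) * rng.normal(size=ACT_DIM)
    return np.tanh(sample)

state = rng.uniform(0.0, 1.0, size=STATE_DIM)   # normalized observation
action = act(state)
```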

2. Collect World-Model Data

Collect rendered trajectories from the trained PPO policy:

uv run python main_rl.py \
  --model-name two_room_ppo \
  --fixed-seed 423 \
  --save-trajectories \
  --num-samples-to-save 1000

With a fixed seed, this writes:

datasets/dataset_fixed_seed_423.pkl

Each trajectory stores normalized states, rendered RGB frames, continuous actions, and shaped rewards.
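A dataset with that per-trajectory content could be laid out as below. The dict keys, array shapes, and frame resolution are hypothetical: the actual pickle schema written by main_rl.py may differ.

```python
import io
import pickle
import numpy as np

rng = np.random.default_rng(0)
T = 8  # toy trajectory length

# Hypothetical layout for one stored trajectory.
trajectory = {
    "states": rng.random((T, 10)).astype(np.float32),          # normalized 10D states
    "frames": rng.integers(0, 256, (T, 64, 64, 3)).astype(np.uint8),  # rendered RGB
    "actions": rng.uniform(-1, 1, (T, 2)).astype(np.float32),  # continuous actions
    "rewards": rng.random(T).astype(np.float32),               # shaped rewards
}

# Round-trip through pickle, the format the dataset file uses.
buf = io.BytesIO()
pickle.dump([trajectory], buf)
buf.seek(0)
dataset = pickle.load(buf)
```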

3. Train the World Model

Train the latent dynamics model from the collected visual dataset:

uv run python main_wm.py \
  --model-name two_room_wm \
  --dataset-name datasets/dataset_fixed_seed_423.pkl \
  --fixed-seed 423 \
  --sequence-length 4 \
  --batch-size 32 \
  --epochs-number 5

4. Evaluate with CEM Planning

When a saved world model exists, main_wm.py can load it and evaluate planning in the same fixed-seed environment:

uv run python main_wm.py \
  --model-name two_room_wm \
  --dataset-name datasets/dataset_fixed_seed_423.pkl \
  --fixed-seed 423

The evaluation path selects the final frame of the highest-reward trajectory as the visual goal, encodes it, and runs CEM over latent rollouts. The current defaults use a planning horizon of 50, a population size of 64, 4 CEM iterations, and an elite fraction of 0.1 (i.e., roughly the top 6 of 64 candidates per iteration).

Note

This README was written by Codex and validated by the author.

About

Simple implementation of the paper LeWorldModel. This paper (and JEPA-style papers in general) is interesting for high-level, low-hertz decision planning. For game AI, something similar could be used for high-level planning over behaviors that take multiple frames to complete: e.g., latent planning handles path-finding while RL handles path-following.
