Craftax Classic: optional reset-pool kwarg #538
Open
Infatoshi wants to merge 24 commits into PufferAI:4.0 from
Scaffold for the full Craftax (not Classic) Ocean port. Currently routes
reset/step through the JAX Craftax-Symbolic-v1 oracle via the Python C
API -- parity is correct by construction but throughput is poor. The
intent is to swap in native C subsystem by subsystem while the harness
keeps parity green.
Contents:
- ocean/craftax/craftax.h: full enum set matching JAX constants.py,
EnvState-shaped C state, Ocean Craftax/Log structs, proxy reset/step.
- ocean/craftax/binding.c: Ocean glue (OBS_SIZE=8268, ACT_SIZES={43},
67 achievement log fields).
- config/ocean/craftax.ini: env_name=craftax, proxy-friendly vec sizes
(to be raised once native).
- tests/craftax_parity.py: JAX vs C parity harness, prints first
divergence with section labels, atol-tunable.
- ocean/craftax/PORT_NOTES.md: documents proxy baseline, divergences,
and the native port roadmap.
Also:
- build.sh: embed rpaths for wheel-provided CUDA libs so pufferlib._C
imports without manually preloading libnccl.so.2.
Status: `tests/craftax_parity.py --seeds 2 --steps 50` PASS.
Co-authored-by: codex (gpt-5.4)
Phase 1 of the proxy-to-native migration. Each new native piece is
covered by a JAX parity test; end-to-end harness still green.
Added:
- threefry.h: JAX-compatible PRNG (PRNGKey, partitionable split/split_n,
fold_in, uniform_u32, float helpers). Bitwise-equivalent to
jax.random at the u32 level.
- noise.h: Perlin/fractal noise matching JAX util/noise.py. Soft parity
atol=2e-6 (sinf/cosf vs XLA transcendentals).
- worldgen.h: native overworld (floor 0) smoothworld -- map, item_map,
light_map, ladder_down, ladder_up. Bitwise vs JAX for default reset
seeds.
- craftax.h reset: still obtains full JAX state, then overwrites the
visible floor-0 channels from native C. Floors 1..8 still proxied.
Tests (uv run --with pytest pytest tests/craftax_{threefry,noise,
worldgen_floor0}_test.py):
- 3 passed.
Parity harness (--seeds 8 --steps 200): PASS.
Co-authored-by: codex (gpt-5.4)
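For readers unfamiliar with the PRNG named above, here is a minimal sketch of the Threefry-2x32 block function (20 rounds) that JAX's default generator is built on. It is not the contents of threefry.h: the function names are illustrative, and the split/fold_in/uniform helpers from the commit are not reproduced.

```c
#include <assert.h>
#include <stdint.h>

/* Rotate a 32-bit word left by r bits; r must be in 1..31. */
static uint32_t rotl32(uint32_t x, unsigned r) {
    assert(r > 0 && r < 32);
    return (x << r) | (x >> (32u - r));
}

/* Threefry-2x32, 20 rounds: hash the counter (x0, x1) under the key (k0, k1).
 * Illustrative name and signature, not taken from the PR. */
static void threefry2x32(uint32_t k0, uint32_t k1,
                         uint32_t x0, uint32_t x1, uint32_t out[2]) {
    static const unsigned rot[2][4] = {{13, 15, 26, 6}, {17, 29, 16, 24}};
    const uint32_t ks[3] = {k0, k1, k0 ^ k1 ^ 0x1BD11BDAu};

    /* Initial key injection. */
    x0 += ks[0];
    x1 += ks[1];

    /* Five groups of four rounds, alternating between the two rotation sets. */
    for (uint32_t group = 0; group < 5; group++) {
        const unsigned *r = rot[group & 1u];
        for (int i = 0; i < 4; i++) {
            x0 += x1;
            x1 = rotl32(x1, r[i]);
            x1 ^= x0;
        }
        /* Key injection after every four rounds, with a round counter. */
        x0 += ks[(group + 1) % 3];
        x1 += ks[(group + 2) % 3] + (group + 1);
    }
    out[0] = x0;
    out[1] = x1;
}
```

All arithmetic is on unsigned 32-bit words, so wraparound is well defined in C, which is what makes bitwise equivalence at the u32 level achievable without any floating-point concerns.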
Phase 2 of the proxy-to-native migration. c_reset no longer calls JAX:
all 9 floors (overworld, gnomish mines, dungeon, sewers, vaults, troll
mines, fire, ice, boss) are generated in native C with matching
potion_mapping, empty mobs/plants, chest/monsters-killed init, and the
symbolic reset observation encoder. Step still proxies to JAX; the
proxy is marked dirty at reset and lazily re-synced on first step.
Added tests/craftax_worldgen_test.py diffing C vs JAX for 16 seeds
across map, item_map, mob_map, light_map, ladders, chest flags,
monsters_killed, all mob/projectile arrays, plants, potion_mapping,
scalar fields, state_rng, and the encoded reset observation.
Verification:
- tests/craftax_{threefry,noise,worldgen_floor0,worldgen}_test.py: 4 passed
- tests/craftax_parity.py --seeds 8 --steps 200: PASS
Co-authored-by: codex (gpt-5.4)
…on yet)
Phase 3 of the proxy-to-native migration. Each subsystem is a standalone C function with a JAX-parity unit test -- no changes to c_step yet, so the hybrid native/proxy sync problem does not arise. Integration into a fully native c_step is a later phase.
Native ports in step_simple.h:
- move_player
- update_plants
- boss_logic
- level_up_attributes
- clip_inventory_and_intrinsics
- calculate_inventory_achievements
- update_player_intrinsics
- drink_potion
- read_book
Still proxied: do_action, do_crafting, place_block, shoot_projectile, cast_spell, enchant, change_floor, add_items_from_chest, update_mobs, spawn_mobs.
Tests:
- tests/craftax_state_fixtures.py: ctypes CraftaxState mirror, pickle helpers, C<->JAX conversion, strict state diffing.
- tests/craftax_step_subsystem_test.py: 10 JAX-parity tests covering all 9 ported subsystems with seeds and targeted stress cases.
Verification:
- tests/craftax_step_subsystem_test.py: 10 passed
- tests/craftax_parity.py --seeds 8 --steps 200: PASS
Co-authored-by: codex (gpt-5.4)
Phase 4 of the proxy-to-native migration. 5 more step subsystems ported as standalone native C functions with JAX-parity unit tests. No c_step integration yet.
Native ports in step_medium.h:
- craftax_shoot_projectile_native
- craftax_cast_spell_native
- craftax_enchant_native
- craftax_change_floor_native
- craftax_add_items_from_chest_native
Still proxied: do_action, do_crafting, place_block, update_mobs, spawn_mobs.
Tests:
- tests/craftax_step_medium_test.py: 5 JAX-parity tests with seeded states + targeted projectile/spell/enchant/floor/chest cases.
Verification:
- tests/craftax_step_medium_test.py: 5 passed
- All prior subsystem + parity tests still pass
- tests/craftax_parity.py --seeds 8 --steps 200: PASS
Co-authored-by: codex (gpt-5.4)
Phase 5 of the proxy-to-native migration. Action-driven crafting and placement subsystems ported as standalone native C with JAX-parity unit tests. No c_step integration yet.
Native ports in step_crafting.h:
- craftax_do_crafting_native (all 12 MAKE_* actions)
- craftax_place_block_native (stone/table/furnace/plant/torch)
- craftax_add_new_growing_plant_native (internal helper)
Still proxied: do_action, update_mobs, spawn_mobs.
Tests:
- tests/craftax_step_crafting_test.py: success, missing-resource, missing-table/furnace, full-inventory, illegal target, map-boundary, first-empty-slot plant allocation.
Verification:
- tests/craftax_step_crafting_test.py: 3 passed
- All prior subsystem + parity tests still pass
- tests/craftax_parity.py --seeds 8 --steps 200: PASS
Co-authored-by: codex (gpt-5.4)
Phase 6 of the proxy-to-native migration. do_action -- mining adjacent blocks, eating plants/cows/bats/snails, drinking water, opening chests (delegates to the native add_items_from_chest from phase 4), and attacking the 3 mob classes with sword/enchantment/dex/str modifiers -- ported as a standalone native C function with JAX-parity unit tests. No c_step integration yet.
Native port in step_do_action.h:
- craftax_do_action_native (uses craftax_add_items_from_chest_native)
Still proxied: update_mobs, spawn_mobs.
Tests cover: 16 seeded states, mining/pickaxe gates, sapling rng, foods, water, fountain, chest at every level, mob kills across all 3 classes with damage modifiers, out-of-bounds, no-op target blocks, projectile-occupied target, mob-on-chest gating.
Source-of-truth note: installed JAX do_action does not mine WOOD, does not refill mana from FOUNTAIN, and does not increment player_xp on mob kills -- the native port matches that behavior.
Verification:
- tests/craftax_step_do_action_test.py: 1 passed
- All prior tests still pass (30 total)
- tests/craftax_parity.py --seeds 8 --steps 200: PASS
Co-authored-by: codex (gpt-5.4)
Phase 7 of the proxy-to-native migration. spawn_mobs ported as a standalone native C function with JAX-parity unit tests. Matches JAX split order, spawn gating, terrain/range maps, mob caps, boss-wave behavior, night melee chance, deep-thing water spawning, and sequential mob-map updates.
Native port in step_spawn_mobs.h:
- craftax_spawn_mobs_native
Still proxied: update_mobs (last remaining gameplay subsystem).
Tests:
- tests/craftax_step_spawn_mobs_test.py: seeded + targeted parity across each floor, full caps, empty slots, night vs day, boss wave pacing, player-adjacent rejection, collision-type constraints.
Verification:
- All 50 subsystem parity tests pass
- tests/craftax_parity.py --seeds 8 --steps 200: PASS
Co-authored-by: codex (gpt-5.4)
Phase 8 of the proxy-to-native migration. update_mobs ported as a
standalone native C function with JAX-parity unit tests. Last
remaining proxy subsystem is eliminated at the standalone level -- all
19 gameplay step subsystems now have native parity ports.
Native port in step_update_mobs.h:
- craftax_update_mobs_native
- Covers melee, passive, ranged, mob projectiles, player projectiles
- JAX split order preserved, including the melee scan final right-key
- Collision maps, mob-map clear/enter order, despawn, cooldowns,
player damage, armor/enchantment defense, ranged projectile
spawning, projectile expiry, player projectile damage scaling,
mob kills, achievements, monsters_killed
Integration into c_step, timestep/RNG/reward/terminal/achievement-delta
bookkeeping remain pending -- that is the next phase.
Tests (105 total subsystem parity tests pass):
- tests/craftax_step_update_mobs_test.py: seeded + targeted per class
per floor, attacks, projectiles, despawn, cooldowns, kills.
Verification:
- tests/craftax_parity.py --seeds 8 --steps 200: PASS
Co-authored-by: codex (gpt-5.4)
Phase 9 of the proxy-to-native migration. The env is now 100% native C
end to end. No CPython, ctypes, or JAX calls inside c_reset, c_step,
or c_close -- a targeted search for py_proxy, PyObject_, Py_, dlopen,
dlsym, and ctypes in ocean/craftax/{craftax.h,binding.c} returns no
hits.
Changes:
- Native stitcher in craftax.h matching JAX craftax_step order + RNG
split sequence exactly (change_floor, do_crafting, do_action,
place_block, shoot_projectile, cast_spell, drink_potion, read_book,
enchant, boss_logic, level_up_attributes, move_player, update_mobs,
spawn_mobs, update_plants, update_player_intrinsics, clip, inventory
achievements, reward, timestep, light_level, state_rng).
- c_step is now c_step_native; proxy delegation removed.
- All Python/JAX proxy scaffolding removed from craftax.h/binding.c.
- Exact 67-entry ACHIEVEMENT_REWARD_MAP added.
- Native symbolic observation encoder reused for step obs, including
mob channels and JAX scatter semantics for wrapped negative local
indices.
- CRAFTAX_ENABLE_ENV_IMPL guards the full env for the binding TU so
subsystem test headers still include cleanly.
Tests:
- tests/craftax_step_full_test.py: full native vs JAX env parity,
seeded reset + action-driven sequences.
- 106 parity tests pass (all prior + full step).
- tests/craftax_parity.py --seeds 16 --steps 2000: PASS at atol=1e-5.
Next phase: CPU optimization (SIMD obs encoding, cache-tiled mob
updates, AVX2 light propagation) for the Ryzen 9950X3D target.
Co-authored-by: codex (gpt-5.4)
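The "JAX scatter semantics for wrapped negative local indices" item above can be illustrated with a small sketch. It assumes the usual JAX behavior as I understand it: a negative index wraps once, NumPy style, and a scatter update whose index is still out of range is dropped. The helper names and the row-major grid layout are hypothetical, not taken from craftax.h.

```c
#include <assert.h>
#include <stddef.h>

/* Normalize an index the way a JAX scatter is assumed to: negative values
 * wrap once, and anything still outside [0, n) is reported as -1 (dropped). */
static int scatter_index(int idx, int n) {
    assert(n > 0);
    if (idx < 0) idx += n;
    return (idx >= 0 && idx < n) ? idx : -1;
}

/* Write value into a row-major rows x cols grid, silently skipping
 * updates whose row or column falls outside the grid after wrapping. */
static void scatter_set(float *grid, int rows, int cols,
                        int r, int c, float value) {
    int rr = scatter_index(r, rows);
    int cc = scatter_index(c, cols);
    if (rr < 0 || cc < 0) return;
    grid[(size_t)rr * (size_t)cols + (size_t)cc] = value;
}
```

The practical consequence is that a local row of -1 lands on the last row of the window rather than being discarded, so a C encoder that simply bounds-checks and skips negative indices would not reproduce the JAX output.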
Strengthens correctness verification beyond uniform-random 16x2000 to policy-biased 128-seed coverage across gameplay regimes. Zero divergences found.
Harness changes (tests/craftax_parity.py):
- --policy flag: uniform, combat, descend, suicide, boss, mixed
- reset-on-terminal tracking with per-seed episode length + counts
- richer divergence reports (seed, step, policy, reward/terminal delta, first obs field, last 10 actions for reproduction)
- isolated replay trace dumped under build/ on divergence
New tests/craftax_parity_stress.py battery (1033s total):
- mixed-wide: 64 seeds x 10000 steps, 2883 terminals
- descend-boss-target: 16 seeds x 30000 steps, 2498 terminals
- suicide-terminal-target: 32 seeds x 5000 steps, 622 terminals
- combat-projectile-xp: 16 seeds x 5000 steps, 355 terminals
All PASS at atol=1e-5 on obs + reward, exact on terminal.
Known non-issue surfaced by stress: JAX CPU-XLA JIT-fused reset can shift normalized-noise max by 1 ULP at exact sand-threshold cells. Materialized JAX worldgen and native reset agree field-by-field; this is a JAX compiler artifact, not a port bug. The stress harness uses native reset for episode continuation (with field-by-field state/obs verification) to avoid comparing against JIT-fused reset numerics.
Added tests/craftax_worldgen_test.py threshold regression.
Verification:
- tests/craftax_parity.py --seeds 16 --steps 2000: PASS
- tests/craftax_parity_stress.py: 4 cases PASS
- 106 prior subsystem parity tests: all pass
- worldgen threshold regression: pass
No production C changes. Env correctness unchanged.
Co-authored-by: codex (gpt-5.4)
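The pass criterion above (atol on obs and reward, exact on terminal) can be sketched as follows. The real harness is the Python script tests/craftax_parity.py; this C version only illustrates the comparison rule, and the StepResult struct and function names are hypothetical.

```c
#include <assert.h>
#include <stddef.h>

/* Index of the first element where |a[i] - b[i]| > atol, or -1 if none.
 * A NaN difference fails the <= test and is reported as a divergence. */
static long first_divergence(const float *a, const float *b,
                             size_t n, float atol) {
    assert(atol >= 0.0f);
    for (size_t i = 0; i < n; i++) {
        float d = a[i] - b[i];
        if (d < 0.0f) d = -d;
        if (!(d <= atol)) return (long)i;
    }
    return -1;
}

/* One env step's outputs (hypothetical layout). */
typedef struct {
    const float *obs;
    size_t obs_len;
    float reward;
    int terminal;
} StepResult;

/* Returns 1 when the two steps agree: terminal flags exactly equal,
 * reward and every obs element within atol. */
static int step_matches(const StepResult *c, const StepResult *ref, float atol) {
    if (c->terminal != ref->terminal) return 0;
    if (c->obs_len != ref->obs_len) return 0;
    if (first_divergence(&c->reward, &ref->reward, 1, atol) != -1) return 0;
    return first_divergence(c->obs, ref->obs, c->obs_len, atol) == -1;
}
```

Returning the first divergent index, rather than a bare pass/fail, is what lets a report name the first obs field that differs.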
…ark script
config/ocean/craftax.ini: 8192 agents / 16 threads / 200M steps (the proxy-friendly sizes used during migration are obsolete now that the env is fully native).
scripts/craftax_convergence_bench.py: trains craftax_classic and craftax back-to-back (default 10M env steps each), parses pufferlib run logs, prints per-threshold time-to-score + per-achievement unlock rates, and saves a two-panel plot of score vs env-steps and score vs wall time.
…marking
Brings ocean/craftax_classic/ (binding.c + craftax_classic.h) and config/ocean/craftax_classic.ini onto this branch so the convergence benchmark can train both envs back-to-back. Classic files are unchanged from the craftax-classic-rename PR branch.
scripts/craftax_convergence_bench.py now rebuilds pufferlib._C for each env before invoking puffer train, since the _C extension is compiled for one env at a time.
Full Craftax's my_log writes 4 meta + 67 achievements + n = 72 fields to the log Dict. In release builds (NDEBUG) dict_set's capacity assert is stripped, so writes past the 32-entry capacity of create_dict(32) overrun the calloc'd items array and corrupt glibc's heap -- 'malloc(): invalid size (unsorted)' at training startup.
All four create_dict(32) call sites used for env log aggregation now use create_dict(256). Classic (26 fields) and every other existing env stay well within the new capacity. No ABI change.
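The failure mode above comes from relying on assert for a bounds check. As a hedged illustration, the sketch below shows a fixed-capacity dict whose capacity check is an ordinary branch and therefore survives -DNDEBUG. The Dict layout and the dict_create / dict_set_checked / dict_free names are hypothetical and are not pufferlib's actual create_dict / dict_set.

```c
#include <assert.h>
#include <stdlib.h>

/* Hypothetical fixed-capacity log dict. */
typedef struct {
    const char *key;
    double value;
} Item;

typedef struct {
    Item *items;
    int size;
    int capacity;
} Dict;

static Dict *dict_create(int capacity) {
    assert(capacity > 0);
    Dict *d = calloc(1, sizeof *d);
    if (d == NULL) return NULL;
    d->items = calloc((size_t)capacity, sizeof *d->items);
    if (d->items == NULL) {
        free(d);
        return NULL;
    }
    d->capacity = capacity;
    return d;
}

/* Append a key/value pair. Returns 0 on success and -1 when the dict is
 * full. The check is a plain if, so it is still present under -DNDEBUG. */
static int dict_set_checked(Dict *d, const char *key, double value) {
    if (d->size >= d->capacity) return -1;
    d->items[d->size].key = key;
    d->items[d->size].value = value;
    d->size++;
    return 0;
}

static void dict_free(Dict *d) {
    if (d == NULL) return;
    free(d->items);
    free(d);
}
```

With a capacity of 32 and 72 attempted writes, this version rejects the excess writes instead of touching memory past the array, at the cost of one predictable branch per write.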
Replaces the palette-rectangle c_render in craftax_classic and the
no-op stub in craftax with a tile renderer that draws the upstream
Craftax 16x16 PNG assets. Both envs read a single textures.bin
(packed by ocean/craftax/pack_textures.py) so the on-screen look
matches the Matthews et al. reference for any overlapping block.
- pack_textures.py: packs 54 tiles (37 block + 5 player + 5 item +
3 mob + 4 arrow) at 16x16 RGBA, 55 KB on disk. Asset PNGs in the
two upstream asset dirs overlap byte-identically (md5-checked)
so classic reuses full's bin.
- craftax.h: lazy-loads the bin into Texture2Ds with POINT filter,
draws a 16x16 tile viewport centered on the player (at scale 4,
one tile = 64 px). HUD shows HP/F/D/E, stats, achievements, return.
Viewport is decoupled from the 9x11 agent obs window.
- craftax_classic.h: same loader + 16x16 viewport, adds zombie /
skeleton / cow / directional-arrow sprite overlays and an
inventory readout. Tile ids are offset into the shared bin.
- craftax.c: minimal standalone viewer (random-action by default;
press H to toggle keyboard control) for ./build.sh craftax --fast.
Run: uv run python ocean/craftax/pack_textures.py once to (re)build
the bin, then DISPLAY=:0 uv run puffer eval {craftax,craftax_classic}
--load-model-path latest.
Rewrites craftax_spawn_mobs_native to strip JAX-isms that are pointless on CPU:
- bool[48][48] validity mask -> compact (int16, int16) coord list collected in one pass over the bounding box around the player
- bounding-box scan: mobs can only spawn within MOB_DESPAWN_DISTANCE=14, so we only visit the up-to-27x27 window instead of the full 48x48 map
- early return when can_spawn is already false from the mob-cap or probability roll, skipping the scan + choice
- merged count_mobs3 + first_empty_mobs3 into a single loop
- inlined the block-type and distance checks
Choice arithmetic uses the same FP expressions as baseline so the selected cell is bitwise-identical for any given (valid_count, rng_key) pair. The baseline quirk of writing type_id[level][slot] unconditionally even when no mob spawns is preserved.
Phase timing (single-thread, random actions):
- craftax_spawn_mobs_native: 17.06 us -> 0.30 us (57x)
- full c_step: 29.6 us -> 12.3 us (2.4x)
Verified bitwise-equal to the prior implementation over 1.28M paired steps (128 envs x 10000 steps, random actions, reset exercised).
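The bounding-box scan described above can be sketched as below. This is an illustration under assumptions, not the code from step_spawn_mobs.h: the flat validity array, the Coord struct, the function name, and the square window with a caller-supplied half-width (13 gives the 27x27 window) are all hypothetical, and the real distance and block-type checks are not reproduced.

```c
#include <assert.h>
#include <stdint.h>

#define MAP_SIZE 48

/* Compact coordinate pair, as in the (int16, int16) list from the commit. */
typedef struct {
    int16_t x;
    int16_t y;
} Coord;

/* Collect every valid cell inside the square window of half-width `half`
 * around (px, py), clamped to the map. `valid` is a flat MAP_SIZE*MAP_SIZE
 * array indexed as x * MAP_SIZE + y. Cells are visited in the same x-major
 * order a full-map scan would use, so the k-th valid cell inside the window
 * keeps its relative position. Returns the number of cells written to out. */
static int collect_spawn_cells(const uint8_t *valid, int px, int py,
                               int half, Coord *out, int cap) {
    assert(half >= 0);
    assert(px >= 0 && px < MAP_SIZE && py >= 0 && py < MAP_SIZE);
    int x0 = px - half < 0 ? 0 : px - half;
    int x1 = px + half >= MAP_SIZE ? MAP_SIZE - 1 : px + half;
    int y0 = py - half < 0 ? 0 : py - half;
    int y1 = py + half >= MAP_SIZE ? MAP_SIZE - 1 : py + half;

    int n = 0;
    for (int x = x0; x <= x1; x++) {
        for (int y = y0; y <= y1; y++) {
            if (!valid[x * MAP_SIZE + y]) continue;
            assert(n < cap);
            out[n].x = (int16_t)x;
            out[n].y = (int16_t)y;
            n++;
        }
    }
    return n;
}
```

At the map center the window covers 27 x 27 = 729 cells versus 48 x 48 = 2304 for the full map, and near an edge the clamped window is smaller still.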
c_reset and the c_step auto-reset path now optionally memcpy from a pre-generated pool of worlds instead of running generate_world each episode. Pool size is a runtime kwarg (reset_pool_size) read by my_init, default 1024 via config/ocean/craftax.ini. Set to 0 to disable and regenerate every reset (required for strict per-seed determinism in tests/craftax_parity.py).
Trade: at most reset_pool_size unique maps are seen per process. With 1024 and ~270-step random-action episodes, diversity is plentiful for training. Memory cost: 1024 * sizeof(CraftaxState) ~= 267 MB once at startup.
Two reset entry points are now distinguished:
- craftax_reset_state_from_reset_key: direct (used by parity harness), always calls generate_state_from_world_key, pool-free for exact per-key determinism.
- craftax_reset_state_on_done: hot-path used by c_step on terminal, consults the pool when enabled, falls through to generate_world otherwise. Pool index derived from reset_key.word[0].
tests/craftax_parity.py picks up raylib's include path since craftax.h now pulls raylib.h (from the shared renderer).
Measurements (single-thread, random actions):
- worldgen: 2.69 ms -> 6.9 us memcpy (~390x)
- full c_step: 12.3 us -> 2.35 us (5.25x)
- training SPS: 450K -> 506K (+12%)
- 1-thread sim SPS: 81K -> 425K (5.25x)
- 16-thread sim SPS: 1.14M -> 5.53M (4.85x)
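A minimal sketch of the reset-pool idea follows. Everything here is a stand-in: World is a toy substitute for CraftaxState, generate_world is a toy generator, seeding pool entry i with seed i is an assumption, and taking the key word modulo the pool size is only one plausible reading of "pool index derived from reset_key.word[0]".

```c
#include <assert.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

/* Toy stand-in for CraftaxState. */
typedef struct {
    uint32_t seed;
    uint8_t map[64];
} World;

/* Toy stand-in for the real generate_world: fills the map deterministically. */
static void generate_world(World *w, uint32_t seed) {
    uint32_t s = seed * 2654435761u + 1u;
    w->seed = seed;
    for (size_t i = 0; i < sizeof w->map; i++) {
        s = s * 1664525u + 1013904223u;
        w->map[i] = (uint8_t)(s >> 24);
    }
}

typedef struct {
    World *worlds;
    uint32_t size; /* 0 means the pool is disabled */
} ResetPool;

/* Pre-generate `size` worlds once. Returns 0 on success, -1 on OOM. */
static int pool_init(ResetPool *p, uint32_t size) {
    p->worlds = NULL;
    p->size = 0;
    if (size == 0) return 0;
    p->worlds = malloc((size_t)size * sizeof *p->worlds);
    if (p->worlds == NULL) return -1;
    for (uint32_t i = 0; i < size; i++) generate_world(&p->worlds[i], i);
    p->size = size;
    return 0;
}

/* Hot-path reset: copy from the pool when enabled, else regenerate. */
static void reset_world(const ResetPool *p, World *dst, uint32_t key_word0) {
    assert(p != NULL && dst != NULL);
    if (p->size == 0) {
        generate_world(dst, key_word0);
        return;
    }
    memcpy(dst, &p->worlds[key_word0 % p->size], sizeof *dst);
}

static void pool_free(ResetPool *p) {
    free(p->worlds);
    p->worlds = NULL;
    p->size = 0;
}
```

The design trade is the one stated in the commit: a reset becomes a single memcpy, but the process only ever sees as many distinct worlds as the pool holds, which is why the parity harness needs the pool-free path.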
The five move_* helpers (melee/passive/ranged mobs + mob/player projectiles) now return immediately when mask=false. JAX's branchless "compute-then-mask" pattern is pointless on CPU: dead slots' output never feeds observations, rewards, or mob_map, so skipping the body and the RNG draws is semantically equivalent.
Defining CRAFTAX_JAX_PARITY at build time restores the branchless slow path for bitwise replay against JAX (required by tests/craftax_parity.py). Default build uses the early-out.
Also drops craftax_step_jax_index(player_level, NUM_LEVELS) clamps at the top of each move_* -- state->player_level is maintained in [0, NUM_LEVELS-1] by change_floor_native (explicit bounds checks) and by the worldgen init. Six redundant clamps per step eliminated.
Measurements (single-thread, random actions, pool=1024):
- update_mobs phase: 1.392 us -> 0.285 us (4.88x)
- full c_step: 2.35 us -> 1.22 us
- 1-thread sim SPS: 425K -> 819K (1.93x)
- 16-thread sim SPS: 5.53M -> 10.04M (1.82x)
- training SPS: 506K -> 544K (+7%)
Parity test with CRAFTAX_JAX_PARITY defined passes 8 seeds * 1000 steps over 27 terminals. Without the flag, parity diverges at the first mob death -- by design.
These 10 tests were written incrementally as each subsystem (noise, threefry, worldgen, 7 step subsystems) was ported from JAX, to catch divergence at each layer. Now that tests/craftax_parity.py passes end-to-end against the JAX reference, they are redundant: any bug they'd catch also breaks the integration parity test. Dropping ~5400 LOC of scaffolding.
Kept:
- craftax_parity.py (JAX<->C integration parity harness)
- craftax_state_fixtures.py (state-flattening helpers used by parity)
- craftax_parity_stress.py (adversarial action sequences)
- craftax_step_full_test.py (pytest wrapper -> parity.run)
The dashboard and CSV logger only need to surface a handful of milestones along the tech/exploration curve, not every achievement. The env still tracks all 67 internally for reward computation and for the normalized 'perf' aggregate -- we just stop shipping every one through the log Dict each episode flush.
Checkpoints chosen to span the learning curve:
- collect_wood: first resource (tier 1)
- make_wood_pickaxe: first tool
- make_stone_pickaxe: stone tier
- collect_iron: iron tier resource
- make_iron_pickaxe: iron tier tool (major milestone)
- collect_diamond: diamond tier resource
- enter_gnomish_mines: first dungeon (exploration)
- defeat_necromancer: endgame boss
Log Dict now carries 4 meta + 8 achievements + 1 n = 13 fields, well under the stock create_dict(32) capacity. This removes the need for the capacity bump in src/bindings* (reverted in the following commit).
This reverts commit 9396e79.
Adds craftax_classic_set_reset_pool_size(N) + cached c_reset path. When N>0, c_reset memcpys a pre-generated world from a fixed-size pool of size N instead of running generate_world each episode (drops ~30 us worldgen to ~0.5 us 5KB memcpy).
Pool size is a runtime kwarg (reset_pool_size) read by my_init from config/ocean/craftax_classic.ini. Default is 0 (disabled): Classic's env is already faster than the PPO trainer (GPU + backward dominate the loop), so caching does not move training SPS.
Users running sim-only workloads -- data generation, evaluation rollouts, offline RL replay -- can set reset_pool_size > 0 to get ~2x sim speedup (2.6M -> 5.5M SPS single-thread, verified bitwise-equal to fresh generate_world output).
First caller wins; the setter is idempotent and thread-safe so every env's my_init can call it without racing.
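The "first caller wins" setter could be built on C11 atomics roughly as follows. This is a sketch under assumptions rather than the code in craftax_classic.h: the three-state flag, the global names, and the return value are hypothetical, and the actual pool construction is reduced to storing the size.

```c
#include <assert.h>
#include <stdatomic.h>
#include <stdint.h>

enum { POOL_UNSET = 0, POOL_BUILDING = 1, POOL_READY = 2 };

static atomic_int g_pool_state;  /* zero-initialized, i.e. POOL_UNSET */
static uint32_t g_pool_size;     /* written once by the winning caller */

/* Hypothetical setter: the first caller fixes the pool size and every
 * later call is a no-op. Returns the pool size actually in effect. */
static uint32_t set_reset_pool_size(uint32_t n) {
    int expected = POOL_UNSET;
    if (atomic_compare_exchange_strong_explicit(
            &g_pool_state, &expected, POOL_BUILDING,
            memory_order_acq_rel, memory_order_acquire)) {
        /* We won the race: nobody else has written the size yet. */
        assert(g_pool_size == 0u);
        g_pool_size = n; /* the real code would also build the pool here */
        /* Release store publishes g_pool_size to later acquire loads. */
        atomic_store_explicit(&g_pool_state, POOL_READY, memory_order_release);
        return n;
    }
    /* Lost the race: wait until the winner has finished publishing. */
    while (atomic_load_explicit(&g_pool_state, memory_order_acquire) != POOL_READY) {
        /* spin; construction is expected to be short */
    }
    return g_pool_size;
}
```

The intermediate BUILDING state matters: without it, a second thread could observe that initialization has started and read a half-built pool before the winner finishes.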
51113c0 to
1354e61
- config/ocean/craftax.ini -> config/craftax.ini
- config/ocean/craftax_classic.ini -> config/craftax_classic.ini
- ocean/craftax/textures.bin -> resources/craftax/textures.bin
- scripts/craftax_convergence_bench.py -> tests/craftax_convergence_bench.py
- drop empty scripts/ directory
- pack_textures.py: write to resources/craftax/textures.bin
- craftax.h / craftax_classic.h: fopen textures from resources/craftax/
Used by the craftax parity harness to compile with -DCRAFTAX_JAX_PARITY, which disables the update_mobs early-out so the C env replays bitwise against JAX. Default training builds leave EXTRA_CFLAGS empty and keep the ~2x sim-SPS early-out enabled.
Summary
Adds an optional runtime reset pool to the Classic C env (parallel to the Full opt in the dependency PR #537). On c_reset, if reset_pool_size > 0, memcpy a pre-generated world from a pool of size N instead of calling generate_world. Drops ~30 us worldgen to a ~0.5 us 5 KB memcpy.
Default is 0 (disabled). Classic's env is already fast enough that it isn't the training bottleneck: GPU backward + optimizer dominate the PPO loop. Caching doesn't move training SPS for Classic (verified: reset_pool_size=1024 and =0 both hit ~2.9M SPS in puffer train craftax_classic). The knob exists for sim-only workloads (data generation, evaluation rollouts, offline RL replay) where c_step throughput matters directly.
Behavior
- reset_pool_size=0 → baseline: fresh generate_world every reset.
- reset_pool_size=N>0 → N worlds pre-generated once at first my_init using a deterministic seed sequence; every reset memcpys from cache[idx] where idx is drawn from the env's PCG. Cache entries are bitwise-equal to fresh generate_world() output for their seeds.
- craftax_classic_set_reset_pool_size is idempotent and thread-safe (acquire/release atomics), so every env's my_init can call it without racing on first-init.
Measurements
(Table columns: reset_pool_size=0 (default), reset_pool_size=1024; row values not preserved.)
Training doesn't move because the env is only 10-13% of wall-clock; the remaining 87-90% is GPU + train + forward, which the env can't speed up.
Dependency
Based on #537 (Full port + optimizations + renderer). Touches the ocean/craftax_classic/craftax_classic.h header that #537 re-adds as a side-by-side benchmarking target.
Test plan
- … ./build.sh craftax_classic)
- puffer train craftax_classic runs cleanly with default (disabled): 2.9M SPS, achievements unlocking
- … generate_world output for their seeds
- puffer train craftax_classic --env.reset-pool-size 1024 to confirm no regressions at enabled setting
🤖 Generated with Claude Code