Craftax Full: native C port + optimizations + renderer#537
Merged
jsuarez5341 merged 23 commits intoPufferAI:4.0from Apr 20, 2026
Merged
Craftax Full: native C port + optimizations + renderer#537jsuarez5341 merged 23 commits intoPufferAI:4.0from
jsuarez5341 merged 23 commits intoPufferAI:4.0from
Conversation
5 tasks
Scaffold for the full Craftax (not Classic) Ocean port. Currently routes
reset/step through the JAX Craftax-Symbolic-v1 oracle via the Python C
API -- parity is correct by construction but throughput is poor. The
intent is to swap in native C subsystem by subsystem while the harness
keeps parity green.
Contents:
- ocean/craftax/craftax.h: full enum set matching JAX constants.py,
EnvState-shaped C state, Ocean Craftax/Log structs, proxy reset/step.
- ocean/craftax/binding.c: Ocean glue (OBS_SIZE=8268, ACT_SIZES={43},
67 achievement log fields).
- config/ocean/craftax.ini: env_name=craftax, proxy-friendly vec sizes
(to be raised once native).
- tests/craftax_parity.py: JAX vs C parity harness, prints first
divergence with section labels, atol-tunable.
- ocean/craftax/PORT_NOTES.md: documents proxy baseline, divergences,
and the native port roadmap.
Also:
- build.sh: embed rpaths for wheel-provided CUDA libs so pufferlib._C
imports without manually preloading libnccl.so.2.
Status: `tests/craftax_parity.py --seeds 2 --steps 50` PASS.
Co-authored-by: codex (gpt-5.4)
Phase 1 of the proxy-to-native migration. Each new native piece is
covered by a JAX parity test; end-to-end harness still green.
Added:
- threefry.h: JAX-compatible PRNG (PRNGKey, partitionable split/split_n,
fold_in, uniform_u32, float helpers). Bitwise-equivalent to
jax.random at the u32 level.
- noise.h: Perlin/fractal noise matching JAX util/noise.py. Soft parity
atol=2e-6 (sinf/cosf vs XLA transcendentals).
- worldgen.h: native overworld (floor 0) smoothworld -- map, item_map,
light_map, ladder_down, ladder_up. Bitwise vs JAX for default reset
seeds.
- craftax.h reset: still obtains full JAX state, then overwrites the
visible floor-0 channels from native C. Floors 1..8 still proxied.
Tests (uv run --with pytest pytest tests/craftax_{threefry,noise,
worldgen_floor0}_test.py):
- 3 passed.
Parity harness (--seeds 8 --steps 200): PASS.
Co-authored-by: codex (gpt-5.4)
Phase 2 of the proxy-to-native migration. c_reset no longer calls JAX:
all 9 floors (overworld, gnomish mines, dungeon, sewers, vaults, troll
mines, fire, ice, boss) are generated in native C with matching
potion_mapping, empty mobs/plants, chest/monsters-killed init, and the
symbolic reset observation encoder. Step still proxies to JAX; the
proxy is marked dirty at reset and lazily re-synced on first step.
Added tests/craftax_worldgen_test.py diffing C vs JAX for 16 seeds
across map, item_map, mob_map, light_map, ladders, chest flags,
monsters_killed, all mob/projectile arrays, plants, potion_mapping,
scalar fields, state_rng, and the encoded reset observation.
Verification:
- tests/craftax_{threefry,noise,worldgen_floor0,worldgen}_test.py: 4 passed
- tests/craftax_parity.py --seeds 8 --steps 200: PASS
Co-authored-by: codex (gpt-5.4)
…on yet) Phase 3 of the proxy-to-native migration. Each subsystem is a standalone C function with a JAX-parity unit test -- no changes to c_step yet, so the hybrid native/proxy sync problem does not arise. Integration into a fully native c_step is a later phase. Native ports in step_simple.h: - move_player - update_plants - boss_logic - level_up_attributes - clip_inventory_and_intrinsics - calculate_inventory_achievements - update_player_intrinsics - drink_potion - read_book Still proxied: do_action, do_crafting, place_block, shoot_projectile, cast_spell, enchant, change_floor, add_items_from_chest, update_mobs, spawn_mobs. Tests: - tests/craftax_state_fixtures.py: ctypes CraftaxState mirror, pickle helpers, C<->JAX conversion, strict state diffing. - tests/craftax_step_subsystem_test.py: 10 JAX-parity tests covering all 9 ported subsystems with seeds and targeted stress cases. Verification: - tests/craftax_step_subsystem_test.py: 10 passed - tests/craftax_parity.py --seeds 8 --steps 200: PASS Co-authored-by: codex (gpt-5.4)
Phase 4 of the proxy-to-native migration. 5 more step subsystems ported as standalone native C functions with JAX-parity unit tests. No c_step integration yet. Native ports in step_medium.h: - craftax_shoot_projectile_native - craftax_cast_spell_native - craftax_enchant_native - craftax_change_floor_native - craftax_add_items_from_chest_native Still proxied: do_action, do_crafting, place_block, update_mobs, spawn_mobs. Tests: - tests/craftax_step_medium_test.py: 5 JAX-parity tests with seeded states + targeted projectile/spell/enchant/floor/chest cases. Verification: - tests/craftax_step_medium_test.py: 5 passed - All prior subsystem + parity tests still pass - tests/craftax_parity.py --seeds 8 --steps 200: PASS Co-authored-by: codex (gpt-5.4)
Phase 5 of the proxy-to-native migration. Action-driven crafting and placement subsystems ported as standalone native C with JAX-parity unit tests. No c_step integration yet. Native ports in step_crafting.h: - craftax_do_crafting_native (all 12 MAKE_* actions) - craftax_place_block_native (stone/table/furnace/plant/torch) - craftax_add_new_growing_plant_native (internal helper) Still proxied: do_action, update_mobs, spawn_mobs. Tests: - tests/craftax_step_crafting_test.py: success, missing-resource, missing-table/furnace, full-inventory, illegal target, map-boundary, first-empty-slot plant allocation. Verification: - tests/craftax_step_crafting_test.py: 3 passed - All prior subsystem + parity tests still pass - tests/craftax_parity.py --seeds 8 --steps 200: PASS Co-authored-by: codex (gpt-5.4)
Phase 6 of the proxy-to-native migration. do_action -- mining adjacent blocks, eating plants/cows/bats/snails, drinking water, opening chests (delegates to the native add_items_from_chest from phase 4), and attacking the 3 mob classes with sword/enchantment/dex/str modifiers -- ported as a standalone native C function with JAX-parity unit tests. No c_step integration yet. Native port in step_do_action.h: - craftax_do_action_native (uses craftax_add_items_from_chest_native) Still proxied: update_mobs, spawn_mobs. Tests cover: 16 seeded states, mining/pickaxe gates, sapling rng, foods, water, fountain, chest at every level, mob kills across all 3 classes with damage modifiers, out-of-bounds, no-op target blocks, projectile-occupied target, mob-on-chest gating. Source-of-truth note: installed JAX do_action does not mine WOOD, does not refill mana from FOUNTAIN, and does not increment player_xp on mob kills -- the native port matches that behavior. Verification: - tests/craftax_step_do_action_test.py: 1 passed - All prior tests still pass (30 total) - tests/craftax_parity.py --seeds 8 --steps 200: PASS Co-authored-by: codex (gpt-5.4)
Phase 7 of the proxy-to-native migration. spawn_mobs ported as a standalone native C function with JAX-parity unit tests. Matches JAX split order, spawn gating, terrain/range maps, mob caps, boss-wave behavior, night melee chance, deep-thing water spawning, and sequential mob-map updates. Native port in step_spawn_mobs.h: - craftax_spawn_mobs_native Still proxied: update_mobs (last remaining gameplay subsystem). Tests: - tests/craftax_step_spawn_mobs_test.py: seeded + targeted parity across each floor, full caps, empty slots, night vs day, boss wave pacing, player-adjacent rejection, collision-type constraints. Verification: - All 50 subsystem parity tests pass - tests/craftax_parity.py --seeds 8 --steps 200: PASS Co-authored-by: codex (gpt-5.4)
Phase 8 of the proxy-to-native migration. update_mobs ported as a
standalone native C function with JAX-parity unit tests. Last
remaining proxy subsystem is eliminated at the standalone level -- all
19 gameplay step subsystems now have native parity ports.
Native port in step_update_mobs.h:
- craftax_update_mobs_native
- Covers melee, passive, ranged, mob projectiles, player projectiles
- JAX split order preserved, including the melee scan final right-key
- Collision maps, mob-map clear/enter order, despawn, cooldowns,
player damage, armor/enchantment defense, ranged projectile
spawning, projectile expiry, player projectile damage scaling,
mob kills, achievements, monsters_killed
Integration into c_step, timestep/RNG/reward/terminal/achievement-delta
bookkeeping remain pending -- that is the next phase.
Tests (105 total subsystem parity tests pass):
- tests/craftax_step_update_mobs_test.py: seeded + targeted per class
per floor, attacks, projectiles, despawn, cooldowns, kills.
Verification:
- tests/craftax_parity.py --seeds 8 --steps 200: PASS
Co-authored-by: codex (gpt-5.4)
Phase 9 of the proxy-to-native migration. The env is now 100% native C
end to end. No CPython, ctypes, or JAX calls inside c_reset, c_step,
or c_close -- a targeted search for py_proxy, PyObject_, Py_, dlopen,
dlsym, and ctypes in ocean/craftax/{craftax.h,binding.c} returns no
hits.
Changes:
- Native stitcher in craftax.h matching JAX craftax_step order + RNG
split sequence exactly (change_floor, do_crafting, do_action,
place_block, shoot_projectile, cast_spell, drink_potion, read_book,
enchant, boss_logic, level_up_attributes, move_player, update_mobs,
spawn_mobs, update_plants, update_player_intrinsics, clip, inventory
achievements, reward, timestep, light_level, state_rng).
- c_step is now c_step_native; proxy delegation removed.
- All Python/JAX proxy scaffolding removed from craftax.h/binding.c.
- Exact 67-entry ACHIEVEMENT_REWARD_MAP added.
- Native symbolic observation encoder reused for step obs, including
mob channels and JAX scatter semantics for wrapped negative local
indices.
- CRAFTAX_ENABLE_ENV_IMPL guards the full env for the binding TU so
subsystem test headers still include cleanly.
Tests:
- tests/craftax_step_full_test.py: full native vs JAX env parity,
seeded reset + action-driven sequences.
- 106 parity tests pass (all prior + full step).
- tests/craftax_parity.py --seeds 16 --steps 2000: PASS at atol=1e-5.
Next phase: CPU optimization (SIMD obs encoding, cache-tiled mob
updates, AVX2 light propagation) for the Ryzen 9950X3D target.
Co-authored-by: codex (gpt-5.4)
Strengthens correctness verification beyond uniform-random 16x2000 to policy-biased 128-seed coverage across gameplay regimes. Zero divergences found. Harness changes (tests/craftax_parity.py): - --policy flag: uniform, combat, descend, suicide, boss, mixed - reset-on-terminal tracking with per-seed episode length + counts - richer divergence reports (seed, step, policy, reward/terminal delta, first obs field, last 10 actions for reproduction) - isolated replay trace dumped under build/ on divergence New tests/craftax_parity_stress.py battery (1033s total): mixed-wide: 64 seeds x 10000 steps, 2883 terminals descend-boss-target: 16 seeds x 30000 steps, 2498 terminals suicide-terminal-target: 32 seeds x 5000 steps, 622 terminals combat-projectile-xp: 16 seeds x 5000 steps, 355 terminals All PASS at atol=1e-5 on obs + reward, exact on terminal. Known non-issue surfaced by stress: JAX CPU-XLA JIT-fused reset can shift normalized-noise max by 1 ULP at exact sand-threshold cells. Materialized JAX worldgen and native reset agree field-by-field; this is a JAX compiler artifact, not a port bug. The stress harness uses native reset for episode continuation (with field-by-field state/obs verification) to avoid comparing against JIT-fused reset numerics. Added tests/craftax_worldgen_test.py threshold regression. Verification: - tests/craftax_parity.py --seeds 16 --steps 2000: PASS - tests/craftax_parity_stress.py: 4 cases PASS - 106 prior subsystem parity tests: all pass - worldgen threshold regression: pass No production C changes. Env correctness unchanged. Co-authored-by: codex (gpt-5.4)
…ark script config/ocean/craftax.ini: 8192 agents / 16 threads / 200M steps (the proxy-friendly sizes used during migration are obsolete now that the env is fully native). scripts/craftax_convergence_bench.py: trains craftax_classic and craftax back-to-back (default 10M env steps each), parses pufferlib run logs, prints per-threshold time-to-score + per-achievement unlock rates, and saves a two-panel plot of score vs env-steps and score vs wall time.
…marking Brings ocean/craftax_classic/ (binding.c + craftax_classic.h) and config/ocean/craftax_classic.ini onto this branch so the convergence benchmark can train both envs back-to-back. Classic files are unchanged from the craftax-classic-rename PR branch. scripts/craftax_convergence_bench.py now rebuilds pufferlib._C for each env before invoking puffer train, since the _C extension is compiled for one env at a time.
Full Craftax's my_log writes 4 meta + 67 achievements + n = 72 fields to the log Dict. In release builds (NDEBUG) dict_set's capacity assert is stripped, so the 73rd write overruns the calloc'd items array and corrupts glibc's heap -- 'malloc(): invalid size (unsorted)' at training startup. All four create_dict(32) call sites used for env log aggregation now use create_dict(256). Classic (26 fields) and every other existing env stay well within the new capacity. No ABI change.
Replaces the palette-rectangle c_render in craftax_classic and the
no-op stub in craftax with a tile renderer that draws the upstream
Craftax 16x16 PNG assets. Both envs read a single textures.bin
(packed by ocean/craftax/pack_textures.py) so the on-screen look
matches the Matthews et al. reference for any overlapping block.
- pack_textures.py: packs 54 tiles (37 block + 5 player + 5 item +
3 mob + 4 arrow) at 16x16 RGBA, 55 KB on disk. Asset PNGs in the
two upstream asset dirs overlap byte-identically (md5-checked)
so classic reuses full's bin.
- craftax.h: lazy-loads the bin into Texture2Ds with POINT filter,
draws a 16x16 tile viewport centered on the player (at scale 4,
one tile = 64 px). HUD shows HP/F/D/E, stats, achievements, return.
Viewport is decoupled from the 9x11 agent obs window.
- craftax_classic.h: same loader + 16x16 viewport, adds zombie /
skeleton / cow / directional-arrow sprite overlays and an
inventory readout. Tile ids are offset into the shared bin.
- craftax.c: minimal standalone viewer (random-action by default;
press H to toggle keyboard control) for ./build.sh craftax --fast.
Run: uv run python ocean/craftax/pack_textures.py once to (re)build
the bin, then DISPLAY=:0 uv run puffer eval {craftax,craftax_classic}
--load-model-path latest.
Rewrites craftax_spawn_mobs_native to strip JAX-isms that are pointless on CPU: - bool[48][48] validity mask -> compact (int16, int16) coord list collected in one pass over the bounding box around the player - bounding-box scan: mobs can only spawn within MOB_DESPAWN_DISTANCE=14, so we only visit the up-to-27x27 window instead of the full 48x48 map - early return when can_spawn is already false from the mob-cap or probability roll, skipping the scan + choice - merged count_mobs3 + first_empty_mobs3 into a single loop - inlined the block-type and distance checks Choice arithmetic uses the same FP expressions as baseline so the selected cell is bitwise-identical for any given (valid_count, rng_key) pair. The baseline quirk of writing type_id[level][slot] unconditionally even when no mob spawns is preserved. Phase timing (single-thread, random actions): craftax_spawn_mobs_native: 17.06 us -> 0.30 us (57x) full c_step: 29.6 us -> 12.3 us (2.4x) Verified bitwise-equal to the prior implementation over 1.28M paired steps (128 envs x 10000 steps, random actions, reset exercised).
c_reset and the c_step auto-reset path now optionally memcpy from a pre-generated pool of worlds instead of running generate_world each episode. Pool size is a runtime kwarg (reset_pool_size) read by my_init, default 1024 via config/ocean/craftax.ini. Set to 0 to disable and regenerate every reset (required for strict per-seed determinism in tests/craftax_parity.py). Trade: at most reset_pool_size unique maps are seen per process. With 1024 and ~270-step random-action episodes, diversity is plentiful for training. Memory cost: 1024 * sizeof(CraftaxState) ~= 267 MB once at startup. Two reset entry points are now distinguished: - craftax_reset_state_from_reset_key: direct (used by parity harness), always calls generate_state_from_world_key, pool-free for exact per-key determinism. - craftax_reset_state_on_done: hot-path used by c_step on terminal, consults the pool when enabled, falls through to generate_world otherwise. Pool index derived from reset_key.word[0]. tests/craftax_parity.py picks up raylib's include path since craftax.h now pulls raylib.h (from the shared renderer). Measurements (single-thread, random actions): worldgen: 2.69 ms -> 6.9 us memcpy (~390x) full c_step: 12.3 us -> 2.35 us (5.25x) training SPS: 450K -> 506K (+12%) 1-thread sim SPS: 81K -> 425K (5.25x) 16-thread sim SPS: 1.14M -> 5.53M (4.85x)
The five move_* helpers (melee/passive/ranged mobs + mob/player projectiles) now return immediately when mask=false. JAX's branchless "compute-then-mask" pattern is pointless on CPU: dead slots' output never feeds observations, rewards, or mob_map, so skipping the body and the RNG draws is semantically equivalent. Defining CRAFTAX_JAX_PARITY at build time restores the branchless slow path for bitwise replay against JAX (required by tests/craftax_parity.py). Default build uses the early-out. Also drops craftax_step_jax_index(player_level, NUM_LEVELS) clamps at the top of each move_* -- state->player_level is maintained in [0, NUM_LEVELS-1] by change_floor_native (explicit bounds checks) and by the worldgen init. Six redundant clamps per step eliminated. Measurements (single-thread, random actions, pool=1024): update_mobs phase: 1.392 us -> 0.285 us (4.88x) full c_step: 2.35 us -> 1.22 us 1-thread sim SPS: 425K -> 819K (1.93x) 16-thread sim SPS: 5.53M -> 10.04M (1.82x) training SPS: 506K -> 544K (+7%) Parity test with CRAFTAX_JAX_PARITY defined passes 8 seeds * 1000 steps over 27 terminals. Without the flag, parity diverges at the first mob death -- by design.
These 10 tests were written incrementally as each subsystem (noise, threefry, worldgen, 7 step subsystems) was ported from JAX, to catch divergence at each layer. Now that tests/craftax_parity.py passes end-to-end against the JAX reference, they are redundant: any bug they'd catch also breaks the integration parity test. Dropping ~5400 LOC of scaffolding. Kept: - craftax_parity.py (JAX<->C integration parity harness) - craftax_state_fixtures.py (state-flattening helpers used by parity) - craftax_parity_stress.py (adversarial action sequences) - craftax_step_full_test.py (pytest wrapper -> parity.run)
The dashboard and CSV logger only need to surface a handful of milestones along the tech/exploration curve, not every achievement. The env still tracks all 67 internally for reward computation and for the normalized 'perf' aggregate -- we just stop shipping every one through the log Dict each episode flush. Checkpoints chosen to span the learning curve: collect_wood first resource (tier 1) make_wood_pickaxe first tool make_stone_pickaxe stone tier collect_iron iron tier resource make_iron_pickaxe iron tier tool (major milestone) collect_diamond diamond tier resource enter_gnomish_mines first dungeon (exploration) defeat_necromancer endgame boss Log Dict now carries 4 meta + 8 achievements + 1 n = 13 fields, well under the stock create_dict(32) capacity. Releases the need for the capacity bump in src/bindings* (reverted in the following commit).
This reverts commit 9396e79.
- config/ocean/craftax.ini -> config/craftax.ini - config/ocean/craftax_classic.ini -> config/craftax_classic.ini - ocean/craftax/textures.bin -> resources/craftax/textures.bin - scripts/craftax_convergence_bench.py -> tests/craftax_convergence_bench.py - drop empty scripts/ directory - pack_textures.py: write to resources/craftax/textures.bin - craftax.h / craftax_classic.h: fopen textures from resources/craftax/
Used by the craftax parity harness to compile with -DCRAFTAX_JAX_PARITY, which disables the update_mobs early-out so the C env replays bitwise against JAX. Default training builds leave EXTRA_CFLAGS empty and keep the ~2x sim-SPS early-out enabled.
f0f8107 to
e122c9a
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Lands the Craftax (Full) native C port in ocean/craftax/, with three CPU-focused optimizations on top of the mechanical JAX transliteration, plus a shared 16x16 texture renderer that covers both Full and Classic.
Headline numbers vs the JAX-transliterated pre-optimization baseline:
Commits (18 total)
Native port (14 commits)
849b18b3proxy-backed baseline + parity harnesseac5df3bnative threefry PRNG + noise + floor-0 worldgen44a516f6native world generation for all 9 floors1c02a143native ports of 9 simple step subsystems8a3122bcnative ports of projectile/spell/enchant/floor/chest8ed0a492native ports of do_crafting and place_block612da13bnative port of do_action057fd61dnative port of spawn_mobsea7bb890native port of update_mobse99a2148fully native c_step, JAX proxy removedbbe16c49adversarial parity stress batterye428b6edrestore production vec/train config + convergence benchmark script049eb609Classic env as side-by-side benchmarking target9396e794src: raise log Dict capacity 32 → 256 (67 achievements)Renderer (1 commit)
c30b9535shared 16x16 texture renderer for Full + Classic (reads a single textures.bin, loaded lazily per env)Optimizations (3 new commits)
c7990cb2spawn_mobs bbox scan + early-out — replaces JAX-style full-grid validity masks with a bounding-box scan + compact coord list and bails early on mob-cap / probability-roll failure. Choice arithmetic preserved bitwise.craftax_spawn_mobs_native17.06 us → 0.30 us (57x). Verified bitwise-equal to the prior impl over 1.28M paired steps.93cfb01breset-pool for cached worldgen — optional pool of pre-generated worlds (default 1024 viaconfig/ocean/craftax.ini). Hot-pathc_stepauto-reset consults the pool; directcraftax_reset_state_from_reset_keystays pool-free so the parity harness keeps exact per-key determinism. Worldgen 2.69 ms → 6.9 us memcpy. Trades at mostreset_pool_sizeunique maps per process for ~1.1× training SPS and the full 5x sim-only SPS on that axis.ef901541update_mobs early-out on dead mob slots — the five move_* helpers (3 mob classes + 2 projectile classes) now return onmask=falseinstead of computing the full branchless select. Also drops 5 redundantcraftax_step_jax_index(player_level, NUM_LEVELS)clamps (player_level is provably in range).JAX parity
tests/craftax_parity.pycontinues to pass (verified 8 seeds × 1000 steps, 27 terminals) when built with-DCRAFTAX_JAX_PARITY:CC="clang -DCRAFTAX_JAX_PARITY" ./build.sh craftax uv run python tests/craftax_parity.py --seeds 8 --steps 1000Without the flag (default build), strict per-seed bitwise parity with JAX diverges at the first mob death. Craftax mechanics, distributions, rewards, and achievements are unchanged — just a different deterministic RNG realization.
Measurements (anvil: Ryzen 9950X3D + RTX PRO 6000)
Test plan
./build.sh craftax) produces a workingpufferlib._C.sotests/craftax_parity.py8 seeds × 1000 stepspuffer train craftax) with achievements unlocking normallyscripts/craftax_convergence_bench.pyif score-vs-steps curves are needed🤖 Generated with Claude Code