Independent, faithful reimplementation of GRAM by Baek, Jo, Kim, Ren, Bengio & Ahn (arXiv:2605.19376v1, 19 May 2026).
The original paper has no public code as of intake date (2026-05-21). This repository fills that gap with a single goal: anyone with one GPU should be able to reproduce the paper's numbers from a clean machine.
⚠️ This is a community reimplementation, not the authors' code. Discrepancies vs the paper may reflect either the underspecified parts of the method section (see `GAP_DEFAULTS.md`) or genuine bugs in this repo. Open an issue if you find one.
- Reproducibility: every numerical result in this README is produced by one of the scripts in this repo. No magic, no hidden state. `bash scripts/reproduce.sh` walks from a clean machine to our exact numbers.
- Transparency: every place where the paper underspecified a choice (`GAP_DEFAULTS.md`) and every place where our reproduction deviates from the paper's full recipe (`DEVIATIONS.md`) is documented explicitly with a one-line justification.
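The dataset build is checked against `scripts/MANIFEST.expected.md5`. Below is a minimal stdlib sketch of that kind of check. It assumes the manifest uses `md5sum`'s line format (`<hex digest>  <relative path>`), and `md5_of` / `verify_manifest` are hypothetical names, not functions shipped in this repo.

```python
import hashlib
from pathlib import Path


def md5_of(path: Path, chunk_size: int = 1 << 20) -> str:
    """Stream a file through MD5 so large TSVs need not fit in memory."""
    h = hashlib.md5()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def verify_manifest(manifest: Path, root: Path) -> list[str]:
    """Return relative paths that are missing or whose MD5 differs from the manifest.

    Assumes md5sum-style lines: '<hexdigest>  <relative/path>' (hypothetical format).
    """
    mismatches = []
    for line in manifest.read_text().splitlines():
        if not line.strip():
            continue
        expected, rel = line.split(maxsplit=1)
        rel = rel.lstrip("*")  # md5sum marks binary-mode entries with '*'
        target = root / rel
        if not target.is_file() or md5_of(target) != expected.lower():
            mismatches.append(rel)
    return mismatches
```

For example, `verify_manifest(Path("scripts/MANIFEST.expected.md5"), Path("."))` returns an empty list when every file matches, assuming the manifest's paths are relative to the repo root.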
| Benchmark | Paper metric | Paper (full compute) | This repo (partial compute) | Status |
|---|---|---|---|---|
| N-Queens 8×8: test `accuracy_valid_solution` | satisfies all constraints | 99.7 ± 0.3 % | 89.52 % (seed=42, ~8 % paper compute, RTX 4090, $1.60) | partial: ~10 pp below paper at 1/12 the budget |
| N-Queens 8×8: test coverage (N=20 samples) | unique valid completions | 90.3 ± 1.9 % | 73.98 % | partial: climbing toward paper |
| N-Queens 8×8: hostile slice (n_sol ≥ 2) accuracy | constraint satisfaction on multi-solution puzzles | not separately reported | 92.28 % | ✓ above our pre-registered pass threshold (≥ 90 %); evidence that multi-solution recovery works |
Latest headline numbers are from the b=768 paper-recipe run on an RTX 4090 (`v7_b768_fp32`). The previous run at b=256 on an RTX 3080 Ti is preserved in `results/` (87.58 % test / 86.87 % hostile) as a smaller-batch reference. Both runs use the same code, seed, and 250-epoch budget; the b=768 run is the paper-faithful single-GPU config.

Earlier in this README we reported `accuracy_target_exact_match = 0.8383` as the headline. That metric counted the model wrong whenever it produced a different-but-valid completion on a multi-solution puzzle. Both metrics are now in every eval JSON, and the headline switched to `accuracy_valid_solution` to match paper Table 1's "output satisfies all constraints" definition. See `REPRODUCTION.md` for the audit trail.
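A sketch of why the two metrics disagree on multi-solution puzzles. It assumes a 0/1 grid encoding and treats a puzzle as the clue queens left after removal (our reading of `--remove-k`); `satisfies_all_constraints` and `exact_match` are illustrative helpers, not the repo's `gram.eval` code. The example puzzle keeps 3 of 8 queens and has at least two valid completions, so it would fall in the hostile slice (n_sol ≥ 2).

```python
from typing import List, Tuple

Board = List[List[int]]  # hypothetical encoding: 1 = queen, 0 = empty


def board_from_cols(cols: List[int]) -> Board:
    """Build an n x n 0/1 board with one queen per row, at column cols[r]."""
    n = len(cols)
    return [[1 if c == cols[r] else 0 for c in range(n)] for r in range(n)]


def satisfies_all_constraints(pred: Board, puzzle: Board) -> bool:
    """Valid-solution check: keeps every clue queen and is a legal n-queens board."""
    n = len(puzzle)
    cells = [(r, c) for r in range(n) for c in range(n)]
    if any(puzzle[r][c] == 1 and pred[r][c] != 1 for r, c in cells):
        return False  # a given queen was dropped or moved
    queens: List[Tuple[int, int]] = [(r, c) for r, c in cells if pred[r][c] == 1]
    if len(queens) != n:
        return False
    # n queens need n distinct rows, columns, diagonals and anti-diagonals
    rows = {r for r, _ in queens}
    cols = {c for _, c in queens}
    diags = {r - c for r, c in queens}
    anti_diags = {r + c for r, c in queens}
    return len(rows) == len(cols) == len(diags) == len(anti_diags) == n


def exact_match(pred: Board, target: Board) -> bool:
    """Exact-match check: the prediction must equal the single stored target."""
    return pred == target


# Two distinct 8-queens solutions sharing the queens at (0,0), (2,7), (6,1).
target = board_from_cols([0, 4, 7, 5, 2, 6, 1, 3])
other = board_from_cols([0, 5, 7, 2, 6, 3, 1, 4])
puzzle = [[0] * 8 for _ in range(8)]
for r, c in [(0, 0), (2, 7), (6, 1)]:
    puzzle[r][c] = 1

print(satisfies_all_constraints(other, puzzle), exact_match(other, target))  # True False
```

Under exact match the second completion scores 0 even though it is a correct answer; the valid-solution metric accepts both.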
We ran an optimization sprint to see how cheap a single-GPU reproduction can be. Headline: ~2.5× throughput vs our RTX 3080 Ti baseline (48 → 122 examples/sec) on a single RTX 4090, with a tradeoff on convergence at non-paper batch sizes. Full measured matrix + what worked / what didn't / what we wouldn't bother with on a model this size: OPTIMIZATION.md.
| Benchmark | Paper | This repo | Status |
|---|---|---|---|
| Sudoku-Extreme | 97.0 % | — | not yet attempted |
| ARC-AGI-1 / 2 | 52.0 / 11.1 % | — | not yet attempted |
| Graph Coloring 8 / 10 v (conflict ↓) | 2.7 / 3.3 | — | not yet attempted |
| Binarized MNIST (256 steps) | IS 2.04 / FID 73.34 | — | not yet attempted |
See REPRODUCTION.md for the full N-Queens reproduction report (training curve, eval breakdown by slice, hardware, cost, wall time).
- 1× NVIDIA GPU with ≥ 24 GB VRAM for the paper-faithful `b=768` config (we used an RTX 4090 at $0.34/h on RunPod community cloud). 12 GB cards work but require `b=256` (worse convergence per epoch; see `OPTIMIZATION.md`).
- Python ≥ 3.11
- PyTorch ≥ 2.0 (we used 2.4.0 + CUDA 12.4 via `runpod/pytorch:2.4.0-py3.11-cuda12.4.1-devel-ubuntu22.04`)
- ~5 hours of compute for a single seed at our 250-epoch budget; ~$1.60 on a RunPod 4090
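A hedged preflight helper that maps GPU memory to the two batch sizes above. The 24 GB / 12 GB thresholds come from this list; the 5 % slack (usable memory reported by the driver can sit slightly below the nominal size), returning `None` for cards under 12 GB, and the function names are our assumptions. The probe uses `torch.cuda.get_device_properties(0).total_memory` and falls back to `None` without PyTorch or a GPU.

```python
import sys
from typing import Optional


def pick_global_batch_size(vram_gb: float) -> Optional[int]:
    """Map card memory in decimal GB to a global batch size (assumed thresholds)."""
    if vram_gb >= 24 * 0.95:
        return 768  # paper-faithful config
    if vram_gb >= 12 * 0.95:
        return 256  # 12 GB-class fallback, slower convergence per epoch
    return None  # untested below 12 GB


def probe_vram_gb() -> Optional[float]:
    """Best-effort VRAM probe of GPU 0; None without PyTorch or CUDA."""
    try:
        import torch
    except ImportError:
        return None
    if not torch.cuda.is_available():
        return None
    return torch.cuda.get_device_properties(0).total_memory / 1e9


if __name__ == "__main__":
    if sys.version_info < (3, 11):
        print("warning: this repo targets Python >= 3.11")
    vram = probe_vram_gb()
    batch = pick_global_batch_size(vram) if vram is not None else None
    print(f"VRAM (GB): {vram}  ->  global_batch_size: {batch}")
```

The resulting value would be passed as `--override train.global_batch_size=<value>` to `gram.train`.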
```bash
git clone https://github.com/ad3002/gram
cd gram
pip install -e ".[dev]"
pytest -q  # 47/50 unit tests pass; 3 skip on real-dataset path

# Build the dataset (deterministic; MD5s verified against scripts/MANIFEST.expected.md5)
python scripts/build_nqueens_dataset.py --n 8 --remove-k 5,6,7 --split 0.85 \
  --seed 42 --hostile-threshold 2 --out data/nqueens_8x8/

# Train (paper-recipe batch=768, fp32; writes runs/v7_b768_fp32/final.pt.
# The eval JSONs from our run of this config are in results/eval_*_v7_b768.json)
python -m gram.train --config configs/nqueens_8x8.yaml \
  --dataset data/nqueens_8x8/ --output-dir runs/v7_b768_fp32 \
  --device cuda --seed 42 --amp-dtype fp32 \
  --override train.epochs=250 train.global_batch_size=768

# Evaluate (writes both accuracy metrics + seed + used_ema + used_act to JSON;
# ${split/slice_/} strips the "slice_" prefix: eval_test/baseline/hostile.json)
for split in test slice_baseline slice_hostile; do
  python -m gram.eval --checkpoint runs/v7_b768_fp32/final.pt \
    --split data/nqueens_8x8/${split}.tsv --n-samples 20 \
    --device cuda --batch-size 480 --seed 0 \
    --output runs/v7_b768_fp32/eval_${split/slice_/}.json
done
```

The smaller-batch archival config from our earlier 3080 Ti run is at the bottom of `scripts/reproduce.sh` if you want a version that fits in 12 GB.
Or run `notebooks/reproduce.ipynb` for a step-by-step, Colab-friendly version with side-by-side comparisons against `results/`.
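A small sketch for tabulating a run's eval output against the paper. It assumes each `eval_<split>.json` (the names the evaluation loop writes) holds `accuracy_valid_solution` and `accuracy_target_exact_match` as top-level fractions in [0, 1]; that schema is our assumption, so check a file in `results/` first. `load_metrics` and `report` are hypothetical helpers.

```python
import json
from pathlib import Path

PAPER_TEST_VALID = 0.997  # paper Table 1, N-Queens 8x8, "satisfies all constraints"


def load_metrics(run_dir: Path) -> dict[str, dict[str, float]]:
    """Collect both accuracy metrics from each eval_<split>.json found in run_dir."""
    out: dict[str, dict[str, float]] = {}
    for split in ("test", "baseline", "hostile"):
        path = run_dir / f"eval_{split}.json"
        if not path.is_file():
            continue
        data = json.loads(path.read_text())
        out[split] = {
            "valid": float(data["accuracy_valid_solution"]),
            "exact": float(data["accuracy_target_exact_match"]),
        }
    return out


def report(metrics: dict[str, dict[str, float]]) -> str:
    """One line per split; the test line also shows the gap to the paper in pp."""
    lines = []
    for split, m in metrics.items():
        line = f"{split:>8}: valid {m['valid']:.2%}  exact-match {m['exact']:.2%}"
        if split == "test":
            line += f"  (gap to paper {(PAPER_TEST_VALID - m['valid']) * 100:.2f} pp)"
        lines.append(line)
    return "\n".join(lines)


if __name__ == "__main__":
    print(report(load_metrics(Path("runs/v7_b768_fp32"))))
```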
GRAM (paper §2) wraps a deterministic recursive reasoning model (HRM/TRM-style) in a probabilistic shell:
```
For each of N_sup supervision steps (the "outer loop"):
    For each of T transitions (the "inner-outer loop"):
        # K low-level updates with high-level state frozen
        for k in 1..K:
            l_k = fL(h_prev + l_{k-1} + e_x)
        # Deterministic high-level proposal
        u_t = fH(h_prev + l_K)
        # Stochastic guidance (the GRAM innovation)
        ε_t ~ N(μ_θ(u_t), σ_θ²(u_t) · I)      at inference
        ε_t ~ N(μ_φ(u_t, y), σ_φ²(u_t, y))    at training (variational posterior)
        h_t = u_t + ε_t
    # Decode + apply training loss; detach states for next supervision step
    logits = decoder(h_T)
    L = CE(logits, y) + β · KL(q_φ ‖ p_θ)
```
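To make the loop above concrete, here is a toy NumPy forward pass of one supervision step. It is a sketch under heavy assumptions, not `gram/recursive.py`: `fL`, `fH`, the Gaussian heads and the decoder are random `tanh`/linear maps; states are single vectors rather than token sequences; `l` starts at zero each supervision step and carries across transitions (hypothetical); the closed-form KL between diagonal Gaussians stands in for `KL(q_φ ‖ p_θ)`, without KL balancing; and there is no autograd, so "detach" is a no-op.

```python
import numpy as np

rng = np.random.default_rng(0)
D, K, T = 16, 3, 2  # toy state width, low-level steps, transitions
W = {name: rng.normal(scale=D ** -0.5, size=(D, D)) for name in ("L", "H", "mu", "lv")}
W_post = rng.normal(scale=(2 * D) ** -0.5, size=(2 * D, 2 * D))  # posterior on [u, y]
W_dec = rng.normal(scale=D ** -0.5, size=(D, 8))  # toy decoder to 8 classes


def fL(x): return np.tanh(x @ W["L"])
def fH(x): return np.tanh(x @ W["H"])


def prior_head(u):
    """p_θ(ε | u): diagonal Gaussian mean and (clipped) log-variance."""
    return u @ W["mu"], np.clip(u @ W["lv"], -4.0, 2.0)


def posterior_head(u, y_emb):
    """q_φ(ε | u, y): also sees a target embedding (training only)."""
    out = np.concatenate([u, y_emb]) @ W_post
    return out[:D], np.clip(out[D:], -4.0, 2.0)


def kl_diag_gauss(mu_q, lv_q, mu_p, lv_p):
    """KL(N(mu_q, e^lv_q) || N(mu_p, e^lv_p)), summed over dimensions."""
    return 0.5 * np.sum(lv_p - lv_q + (np.exp(lv_q) + (mu_q - mu_p) ** 2) / np.exp(lv_p) - 1.0)


def supervision_step(h_prev, e_x, y_emb=None):
    """T transitions of (K low-level updates, fH proposal, stochastic guidance)."""
    l, kl_total = np.zeros(D), 0.0
    for _ in range(T):
        for _ in range(K):  # low-level updates, high-level state frozen
            l = fL(h_prev + l + e_x)
        u = fH(h_prev + l)  # deterministic proposal
        mu_p, lv_p = prior_head(u)
        if y_emb is None:  # inference: sample the prior
            mu, lv = mu_p, lv_p
        else:  # training: sample the posterior and accumulate KL
            mu, lv = posterior_head(u, y_emb)
            kl_total += kl_diag_gauss(mu, lv, mu_p, lv_p)
        eps = mu + np.exp(0.5 * lv) * rng.standard_normal(D)  # reparameterised sample
        h_prev = u + eps  # h_t = u_t + ε_t
    return h_prev, kl_total


def loss(h_T, y_class, kl, beta=1.0):
    """Cross-entropy of the toy decoder plus β-weighted KL."""
    logits = h_T @ W_dec
    m = logits.max()
    log_probs = logits - m - np.log(np.sum(np.exp(logits - m)))
    return -log_probs[y_class] + beta * kl


h0 = rng.standard_normal(D)  # stands in for the frozen z_0
h_T, kl = supervision_step(h0, rng.standard_normal(D), rng.standard_normal(D))
print(h_T.shape, float(kl), float(loss(h_T, 3, kl)))
```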
We follow paper §B.1 / §B.2 for everything that is pinned and document our defaults for everything that isn't.
What we deviated on (full table in DEVIATIONS.md)
| Knob | Paper | Earlier b=256 run | Reason |
|---|---|---|---|
| Hardware | 8× RTX 4090 | 1× RTX 3080 Ti | cost / availability |
| Global batch | 768 | 256 | single-GPU memory |
| Epochs | 3000 | 250 | wall-time budget (~6h) |
| Seeds | mean ± std | 42 only | wall-time budget |
| EMA / ACT inference | enabled / adaptive | disabled / final-step-only | simplification for first pass |
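Effective compute is ~8 % of the paper's: 250 of 3000 epochs over the same training set (batch size changes the number of optimizer steps, not the number of examples processed). The gap between the b=256 run's 87.58 % test accuracy (83.8 % under the older exact-match metric) and the paper's 99.7 % is consistent with this deficit rather than with an architectural defect; see `REPRODUCTION.md` §"Compute budget accounting" for the math.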
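The arithmetic behind the compute fractions quoted in this README, as a quick check. It assumes per-epoch FLOPs are fixed by the dataset and model, so batch size only changes how many optimizer updates an epoch contains; the helper names are ours.

```python
PAPER_EPOCHS, PAPER_BATCH = 3000, 768  # paper recipe (table above)


def compute_fraction(epochs: int, paper_epochs: int = PAPER_EPOCHS) -> float:
    """Fraction of the paper's training FLOPs: only the epoch count matters."""
    return epochs / paper_epochs


def optimizer_step_fraction(epochs: int, batch: int,
                            paper_epochs: int = PAPER_EPOCHS,
                            paper_batch: int = PAPER_BATCH) -> float:
    """Fraction of the paper's optimizer updates; steps per epoch scale as 1/batch."""
    return (epochs / batch) / (paper_epochs / paper_batch)


for batch in (768, 256):
    print(f"b={batch}: compute {compute_fraction(250):.1%}, "
          f"optimizer steps {optimizer_step_fraction(250, batch):.1%}")
```

Both 250-epoch runs come out at 8.3 % of the paper's compute (the "1/12 the budget" in the headline table); the b=256 run takes 25 % of the paper's optimizer steps, the b=768 run 8.3 %.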
Underspecified choices we had to make (GAP_DEFAULTS.md)
Ten places where the paper's §2 + appendices A–B don't pin a choice. We picked a reasonable default for each and documented it so you can swap and test.
Examples:
- Posterior `q_φ(ε | u, y)` architecture: SwiGLU MLP over `concat(u, mean-pool(y_embed))`.
- KL balancing: Dreamer-V2 variant with α = 0.8.
- LPRM target: per-example exact-match accuracy.
- `z_0`: one frozen `N(0, I)` sample saved in the checkpoint.
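Three of these defaults are simple enough to sketch directly. A NumPy illustration with hypothetical names and shapes (`u` is one high-level state of width `d`, `y_embed` a `(seq_len, d)` target embedding, tokens integer arrays); the SwiGLU MLP and the KL balancing are omitted, and the seed for `z_0` is an arbitrary placeholder, not the repo's.

```python
import numpy as np


def posterior_input(u: np.ndarray, y_embed: np.ndarray) -> np.ndarray:
    """Input to q_φ(ε | u, y): concat(u, mean-pool(y_embed)), shape (2d,)."""
    return np.concatenate([u, y_embed.mean(axis=0)])


def lprm_target(pred_tokens: np.ndarray, target_tokens: np.ndarray) -> np.ndarray:
    """Per-example exact-match accuracy: 1.0 iff every token matches. (B, L) -> (B,)."""
    return np.all(pred_tokens == target_tokens, axis=1).astype(np.float32)


def make_frozen_z0(d: int, seed: int = 42) -> np.ndarray:
    """Draw one N(0, I) sample; store the array itself with the checkpoint and reuse it."""
    return np.random.default_rng(seed).standard_normal(d).astype(np.float32)


print(posterior_input(np.zeros(4), np.arange(8.0).reshape(2, 4)))
print(lprm_target(np.array([[1, 2], [1, 3]]), np.array([[1, 2], [1, 2]])))
```

Saving the `z_0` array rather than only its seed keeps it fixed even if the RNG implementation changes between library versions.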
If you change one of these and our numbers change a lot, that gap is load-bearing for the paper's result. Worth knowing.
```
gram/
├── README.md                    ← you are here
├── REPRODUCTION.md              ← our actual run (hardware, cost, numbers, MD5s)
├── OPTIMIZATION.md              ← throughput sprint: what worked, what didn't, $ projection
├── DEVIATIONS.md                ← config we used vs paper config
├── GAP_DEFAULTS.md              ← 11 underspecified items + our choices
├── CITATION.cff                 ← cite Baek et al. (the paper) + this repo (the reimpl)
├── LICENSE                      ← MIT
├── pyproject.toml               ← pinned deps
├── gram/                        ← package
│   ├── blocks.py                ← RoPE, MHSA, SwiGLU, RMSNorm
│   ├── heads.py                 ← guidance prior + posterior + halt + LPRM
│   ├── recursive.py             ← K low-level + 1 high-level + truncated BPTT
│   ├── model.py                 ← top-level GRAMModel (encoder + core + decoder)
│   ├── losses.py                ← truncated ELBO + ACT BCE + LPRM MSE
│   ├── ema.py                   ← parameter EMA
│   ├── train.py                 ← CLI
│   ├── eval.py                  ← CLI
│   └── datasets/nqueens.py
├── configs/nqueens_8x8.yaml     ← paper hyperparams + our gap-defaults
├── scripts/
│   ├── build_nqueens_dataset.py ← deterministic, seed=42
│   └── reproduce.sh             ← one-command full pipeline
├── notebooks/reproduce.ipynb    ← step-by-step
├── tests/                       ← 50 unit tests (47 + 3 skip)
└── results/                     ← our run's eval JSONs + history + train.log + MANIFEST
```
MIT. Cite the paper, optionally cite this repo (see CITATION.cff).
If you use this code, please cite both the original paper and this reimplementation:

```bibtex
@article{baek2026gram,
  title={Generative Recursive Reasoning Models},
  author={Baek, Junyeob and Jo, Mingyu and Kim, Minsu and Ren, Mengye and Bengio, Yoshua and Ahn, Sungjin},
  journal={arXiv preprint arXiv:2605.19376},
  year={2026}
}
@software{komissarov2026gram_reimpl,
  title={GRAM — community reimplementation},
  author={Komissarov, Aleksey},
  year={2026},
  url={https://github.com/ad3002/gram}
}
```

The deterministic recursive scaffolding was inspired by:
- HRM (Apache-2.0) by Sapient
- TRM (MIT) by Samsung SAIL Montreal
We did not fork either; this is a clean reimplementation from the GRAM paper text alone, with the above as reference reading.