GRAM — Generative Recursive Reasoning Models

Independent, faithful reimplementation of GRAM by Baek, Jo, Kim, Ren, Bengio & Ahn (arXiv:2605.19376v1, 19 May 2026).

The original paper has no public code as of intake date (2026-05-21). This repository fills that gap with a single goal: anyone with one GPU should be able to reproduce the paper's numbers from a clean machine.

⚠️ This is a community reimplementation, not the authors' code. Discrepancies vs the paper may reflect either the underspecified parts of the method section (see GAP_DEFAULTS.md) or genuine bugs in this repo. Open an issue if you find one.

Two values this repo optimizes for

  1. Reproducibility — every numerical result in this README is produced by one of the scripts in this repo. No magic, no hidden state. bash scripts/reproduce.sh walks from a clean machine to our exact numbers.
  2. Transparency — every place where the paper underspecified a choice (GAP_DEFAULTS.md) and every place where our reproduction deviates from the paper's full recipe (DEVIATIONS.md) is documented explicitly with a one-line justification.

What's reproduced (so far)

| Benchmark | Paper metric | Paper (full compute) | This repo (partial compute) | Status |
|---|---|---|---|---|
| N-Queens 8×8, test accuracy_valid_solution | satisfies all constraints | 99.7 ± 0.3 % | 89.52 % (seed=42, ~8 % paper compute, RTX 4090, $1.60) | partial: about 10 pp below paper at 1/12 the budget |
| N-Queens 8×8, test coverage (N=20) | unique valid completions | 90.3 ± 1.9 % | 73.98 % | partial: climbing toward paper |
| N-Queens 8×8, hostile slice (n_sol ≥ 2) accuracy | constraint satisfaction on multi-solution puzzles | not separately reported | 92.28 % | above our pre-registered pass threshold (≥ 90 %): direct evidence that multi-solution recovery works |

Latest headline numbers from the b=768 paper-recipe run on RTX 4090 (v7_b768_fp32). Previous run at b=256 on RTX 3080 Ti is preserved in results/ (87.58 % test / 86.87 % hostile) — kept as a smaller-batch reference. Both runs use the same code, seed, and 250-epoch budget; the b=768 run is the paper-faithful single-GPU config.

An earlier version of this README reported accuracy_target_exact_match = 0.8383 as the headline. That metric counted the model as wrong whenever it produced a different-but-valid completion on a multi-solution puzzle. Both metrics are now written to every eval JSON; the headline switched to match paper Table 1's "output satisfies all constraints" definition. See REPRODUCTION.md for the audit trail.
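The difference between the two metrics can be illustrated with a small sketch. It assumes a board encoded as a list of column indices, one per row, with None marking a removed queen in the puzzle. That encoding and the function names are hypothetical; the repo's TSV format and eval code may differ.

```python
from itertools import combinations

def is_valid_solution(cols):
    """True if queens at (row, cols[row]) do not attack each other."""
    n = len(cols)
    if sorted(cols) != list(range(n)):  # one queen per column
        return False
    # No two queens may share a diagonal.
    return all(abs(c1 - c2) != abs(r1 - r2)
               for (r1, c1), (r2, c2) in combinations(enumerate(cols), 2))

def respects_clues(cols, clues):
    """True if the completion keeps every queen that the puzzle fixed."""
    return all(clue is None or clue == col for col, clue in zip(cols, clues))

# Puzzle with 5 queens removed; two distinct valid completions exist.
clues = [0, None, 7, None, None, None, 1, None]
target = [0, 4, 7, 5, 2, 6, 1, 3]      # completion stored in the dataset
prediction = [0, 5, 7, 2, 6, 3, 1, 4]  # a different completion from the model

exact_match = prediction == target
valid = is_valid_solution(prediction) and respects_clues(prediction, clues)
print(exact_match, valid)  # False True
```

Under exact match this prediction counts as an error; under the constraint-satisfaction definition it counts as correct, which is why the two numbers diverge on multi-solution puzzles.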

How fast / how cheap can you reproduce this?

We ran an optimization sprint to see how cheap a single-GPU reproduction can be. Headline: ~2.5× throughput vs our RTX 3080 Ti baseline (48 → 122 examples/sec) on a single RTX 4090, with a tradeoff on convergence at non-paper batch sizes. The full measured matrix, plus what worked, what didn't, and what we wouldn't bother with on a model this size, is in OPTIMIZATION.md.

Remaining paper benchmarks:

| Benchmark | Paper | This repo | Status |
|---|---|---|---|
| Sudoku-Extreme | 97.0 % | - | not yet attempted |
| ARC-AGI-1 / 2 | 52.0 / 11.1 % | - | not yet attempted |
| Graph Coloring 8 / 10 v (conflict ↓) | 2.7 / 3.3 | - | not yet attempted |
| Binarized MNIST (256 steps) | IS 2.04 / FID 73.34 | - | not yet attempted |

See REPRODUCTION.md for the full N-Queens reproduction report (training curve, eval breakdown by slice, hardware, cost, wall time).

What you'll need

  • 1× NVIDIA GPU with ≥ 24 GB VRAM for the paper-faithful b=768 config (we used RTX 4090 at $0.34/h on RunPod community cloud). 12 GB cards work but require b=256 (worse convergence per epoch — see OPTIMIZATION.md).
  • Python ≥ 3.11
  • PyTorch ≥ 2.0 (we used 2.4.0 + CUDA 12.4 via runpod/pytorch:2.4.0-py3.11-cuda12.4.1-devel-ubuntu22.04)
  • ~5 hours of compute for a single seed at our 250-epoch budget; ~$1.60 on RunPod 4090

Quick start (paper-recipe config that produced the headline 89.52% / 92.28%)

git clone https://github.com/ad3002/gram
cd gram
pip install -e ".[dev]"
pytest -q                                  # 47/50 unit tests pass; 3 skip on real-dataset path

# Build the dataset (deterministic; MD5s verified against scripts/MANIFEST.expected.md5)
python scripts/build_nqueens_dataset.py --n 8 --remove-k 5,6,7 --split 0.85 \
       --seed 42 --hostile-threshold 2 --out data/nqueens_8x8/

# Train (paper-recipe batch=768, fp32; produces results/eval_*_v7_b768.json)
python -m gram.train --config configs/nqueens_8x8.yaml \
       --dataset data/nqueens_8x8/ --output-dir runs/v7_b768_fp32 \
       --device cuda --seed 42 --amp-dtype fp32 \
       --override train.epochs=250 train.global_batch_size=768

# Evaluate (writes both accuracy metrics + seed + used_ema + used_act to JSON)
for split in test slice_baseline slice_hostile; do
    python -m gram.eval --checkpoint runs/v7_b768_fp32/final.pt \
        --split data/nqueens_8x8/${split}.tsv --n-samples 20 \
        --device cuda --batch-size 480 --seed 0 \
        --output runs/v7_b768_fp32/eval_${split/slice_/}.json
done
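The dataset build step above mentions MD5 verification against scripts/MANIFEST.expected.md5. A check of that kind can be sketched as follows, assuming the manifest uses the md5sum layout (one "hex-digest  relative/path" pair per line). The layout and the helper names are assumptions, not the repo's actual tooling.

```python
import hashlib
from pathlib import Path

def md5_of(path, chunk_size=1 << 20):
    """Stream a file through MD5 so large files need not fit in memory."""
    digest = hashlib.md5()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

def verify_manifest(manifest_path, root):
    """Return (relative_path, problem) pairs for every file that fails the check."""
    failures = []
    for line in Path(manifest_path).read_text().splitlines():
        if not line.strip():
            continue
        expected, rel = line.split(maxsplit=1)
        rel = rel.strip().lstrip("*")  # md5sum marks binary mode with '*'
        target = Path(root) / rel
        if not target.is_file():
            failures.append((rel, "missing"))
        elif md5_of(target) != expected.lower():
            failures.append((rel, "md5 mismatch"))
    return failures

# Hypothetical usage after building the dataset:
# problems = verify_manifest("scripts/MANIFEST.expected.md5", "data/nqueens_8x8")
# print(problems or "all files match")
```

An empty list means every listed file exists and matches its recorded digest.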

The smaller-batch archival config from our earlier 3080 Ti run is at the bottom of scripts/reproduce.sh if you want a fits-in-12-GB version.

Or run notebooks/reproduce.ipynb for a step-by-step, Colab-friendly version with side-by-side comparisons against results/.

Architecture summary

GRAM (paper §2) wraps a deterministic recursive reasoning model (HRM/TRM-style) in a probabilistic shell:

For each of N_sup supervision steps (the "outer loop"):
    For each of T transitions (the "inner-outer loop"):
        # K low-level updates with high-level state frozen
        for k in 1..K:
            l_k = fL(h_prev + l_{k-1} + e_x)
        # Deterministic high-level proposal
        u_t = fH(h_prev + l_K)
        # Stochastic guidance (the GRAM innovation)
        ε_t ~ N(μ_θ(u_t), σ_θ²(u_t) · I)    at inference
        ε_t ~ N(μ_φ(u_t, y), σ_φ²(u_t, y))  at training (variational posterior)
        h_t = u_t + ε_t
    # Decode + apply training loss; detach states for next supervision step
    logits = decoder(h_T)
    L = CE(logits, y) + β · KL(q_φ ‖ p_θ)

We follow paper §B.1 / §B.2 for everything that is pinned and document our defaults for everything that isn't.
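The stochastic step and the KL term in the pseudocode above can be sketched in plain Python. This is a toy illustration under stated assumptions: both the prior and the posterior are treated as diagonal Gaussians parameterized by a mean and a log-variance, the head outputs are hard-coded stand-ins, and the function names are hypothetical. The repo's heads.py and losses.py operate on tensors and may differ.

```python
import math
import random

def sample_guidance(mu, log_var, rng):
    """Reparameterized draw: eps = mu + sigma * z, with z ~ N(0, I)."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0)
            for m, lv in zip(mu, log_var)]

def kl_diag_gaussians(mu_q, log_var_q, mu_p, log_var_p):
    """KL(q || p) between diagonal Gaussians, summed over dimensions."""
    total = 0.0
    for mq, lq, mp, lp in zip(mu_q, log_var_q, mu_p, log_var_p):
        total += 0.5 * (lp - lq + (math.exp(lq) + (mq - mp) ** 2) / math.exp(lp) - 1.0)
    return total

rng = random.Random(0)
u_t = [0.2, -0.1, 0.5]                                  # deterministic proposal from fH
mu_p, log_var_p = [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]      # stand-in for the prior head
mu_q, log_var_q = [0.3, -0.2, 0.1], [-1.0, -1.0, -1.0]  # stand-in for the posterior head

eps_t = sample_guidance(mu_q, log_var_q, rng)  # training: sample from the posterior
h_t = [u + e for u, e in zip(u_t, eps_t)]      # h_t = u_t + eps_t
kl = kl_diag_gaussians(mu_q, log_var_q, mu_p, log_var_p)
print(len(h_t), kl >= 0.0)
```

At inference the same sample_guidance call would take the prior's parameters instead, since the target y is not available.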

What we deviated on (full table in DEVIATIONS.md)

| Knob | Paper | This run | Reason |
|---|---|---|---|
| Hardware | 8× RTX 4090 | 1× RTX 3080 Ti | cost / availability |
| Global batch | 768 | 256 | single-GPU memory |
| Epochs | 3000 | 250 | wall-time budget (~6h) |
| Seeds | mean ± std | 42 only | wall-time budget |
| EMA / ACT inference | enabled / adaptive | disabled / final-step-only | simplification for first pass |

Effective compute is ~11 % of the paper's epochs × batch_size product. The gap between our 83.8 % test accuracy (the earlier exact-match headline, accuracy_target_exact_match = 0.8383) and the paper's 99.7 % is consistent with this deficit, not with an architectural defect; see REPRODUCTION.md §"Compute budget accounting" for the math.
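As a rough cross-check, the plain epochs × batch_size ratio can be computed directly. The sketch assumes the paper recipe is 3000 epochs at global batch 768, as listed in the table above. The function name is hypothetical, and REPRODUCTION.md may account for compute differently.

```python
def compute_fraction(epochs, batch, paper_epochs=3000, paper_batch=768):
    """Share of the paper's epochs x batch_size product used by a run."""
    return (epochs * batch) / (paper_epochs * paper_batch)

b768 = compute_fraction(250, 768)  # RTX 4090 run
b256 = compute_fraction(250, 256)  # RTX 3080 Ti run

print(f"b=768: {b768:.1%}")  # b=768: 8.3%
print(f"b=256: {b256:.1%}")  # b=256: 2.8%
```

The b=768 value matches the "~8 %" and "1/12 the budget" figures in the results table. Neither value equals the ~11 % quoted above, so that number presumably comes from a different accounting; REPRODUCTION.md is the reference.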

Underspecified choices we had to make (GAP_DEFAULTS.md)

Ten places where the paper's §2 + appendices A–B don't pin a choice. We picked a reasonable default for each and documented it so you can swap and test.

Examples:

  • Posterior q_φ(ε | u, y) architecture: SwiGLU MLP over concat(u, mean-pool(y_embed)).
  • KL balancing: Dreamer-V2 variant with α = 0.8.
  • LPRM target: per-example exact-match accuracy.
  • z_0: one frozen N(0, I) sample saved in the checkpoint.

If you change one of these and our numbers change a lot, that gap is load-bearing for the paper's result. Worth knowing.
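To make the KL-balancing default concrete, here is a minimal sketch of the idea for two 1-D unit-variance Gaussians, with gradients written out by hand instead of using autograd. It assumes the Dreamer-V2 convention, in which α weights the term that trains the prior and 1 − α weights the term that regularizes the posterior. The function names are hypothetical and the repo's losses.py may be organized differently.

```python
ALPHA = 0.8  # value listed in the gap defaults above

def kl_unit_var(mu_q, mu_p):
    """KL(N(mu_q, 1) || N(mu_p, 1))."""
    return 0.5 * (mu_q - mu_p) ** 2

def balanced_kl_grads(mu_q, mu_p, alpha=ALPHA):
    """Gradients of alpha * KL(sg(q) || p) + (1 - alpha) * KL(q || sg(p)).

    sg(.) is stop-gradient: the first term only updates the prior mean,
    the second term only updates the posterior mean.
    """
    diff = mu_q - mu_p
    grad_mu_q = (1.0 - alpha) * diff  # posterior is pulled toward the prior, weakly
    grad_mu_p = alpha * (-diff)       # prior is pulled toward the posterior, strongly
    return grad_mu_q, grad_mu_p

mu_q, mu_p = 1.0, 0.0
grad_q, grad_p = balanced_kl_grads(mu_q, mu_p)
# The loss value itself is unchanged by balancing; only the gradient split moves.
loss = ALPHA * kl_unit_var(mu_q, mu_p) + (1.0 - ALPHA) * kl_unit_var(mu_q, mu_p)
print(round(grad_q, 3), round(grad_p, 3), round(loss, 3))  # 0.2 -0.8 0.5
```

With α = 0.8 the prior receives four times the gradient magnitude of the posterior, so the prior learns to predict the posterior faster than the posterior is dragged toward an untrained prior.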

Layout

gram/
├── README.md                ← you are here
├── REPRODUCTION.md          ← our actual run (hardware, cost, numbers, MD5s)
├── OPTIMIZATION.md          ← throughput sprint: what worked, what didn't, $ projection
├── DEVIATIONS.md            ← config we used vs paper config
├── GAP_DEFAULTS.md          ← 11 underspecified items + our choices
├── CITATION.cff             ← cite Baek et al. (the paper) + this repo (the reimpl)
├── LICENSE                  ← MIT
├── pyproject.toml           ← pinned deps
├── gram/                    ← package
│   ├── blocks.py            ← RoPE, MHSA, SwiGLU, RMSNorm
│   ├── heads.py             ← guidance prior + posterior + halt + LPRM
│   ├── recursive.py         ← K low-level + 1 high-level + truncated BPTT
│   ├── model.py             ← top-level GRAMModel (encoder + core + decoder)
│   ├── losses.py            ← truncated ELBO + ACT BCE + LPRM MSE
│   ├── ema.py               ← parameter EMA
│   ├── train.py             ← CLI
│   ├── eval.py              ← CLI
│   └── datasets/nqueens.py
├── configs/nqueens_8x8.yaml ← paper hyperparams + our gap-defaults
├── scripts/
│   ├── build_nqueens_dataset.py   ← deterministic, seed=42
│   └── reproduce.sh               ← one-command full pipeline
├── notebooks/reproduce.ipynb      ← step-by-step
├── tests/                          ← 50 unit tests (47 + 3 skip)
└── results/                        ← our run's eval JSONs + history + train.log + MANIFEST

License

MIT. Cite the paper, optionally cite this repo (see CITATION.cff).

Citation

If you use this code, please cite both the original paper and this reimplementation:

@article{baek2026gram,
  title={Generative Recursive Reasoning Models},
  author={Baek, Junyeob and Jo, Mingyu and Kim, Minsu and Ren, Mengye and Bengio, Yoshua and Ahn, Sungjin},
  journal={arXiv preprint arXiv:2605.19376},
  year={2026}
}

@software{komissarov2026gram_reimpl,
  title={GRAM — community reimplementation},
  author={Komissarov, Aleksey},
  year={2026},
  url={https://github.com/ad3002/gram}
}

Acknowledgements

The deterministic recursive scaffolding was inspired by the HRM and TRM recursive reasoning models mentioned in the architecture summary.

We did not fork either; this is a clean reimplementation from the GRAM paper text alone, with the above as reference reading.
