A PyTorch implementation of looped transformer architectures for language modeling. The project trains and evaluates several model variants, including FLT (Fully Looped Transformer) and LT (Looped Transformer).
This project is built on top of nanochat, an open-source language model framework. The core infrastructure — including the training loop, tokenizer, dataloader, optimizer implementations (Muon, AdamW), evaluation harness, and web UI — originates from that codebase. The novel contributions of this project are the looped transformer architectures (FLT, LT, LT_i, LT_ia), the multi-step loop training objectives, and the associated experiments.
External resources inherited from the base project include:
- Pretraining dataset: karpathy/fineweb-edu-100b-shuffle, hosted on HuggingFace
- Evaluation bundle: hosted at karpathy-public.s3.us-west-2.amazonaws.com
- Optimizer code: Muon optimizer attributed to Keller et al. / modded-nanogpt contributors
These references appear throughout the codebase and are part of the upstream open-source project, not specific to this work.
- Python 3.10+
- uv package manager
- CUDA-capable GPU(s) (the default scripts target 8× A100 for training and 4× A100 for evaluation)
- A Weights & Biases account for experiment tracking
Install all dependencies with:
    uv sync --extra gpu
All three scripts (run.sh, eval.sh, check.sh) log metrics to Weights & Biases. Before running any script, open the file and uncomment the WANDB_API_KEY line, then fill in your own key:
    # In run.sh / eval.sh / check.sh, find this line and uncomment it:
    export WANDB_API_KEY="<YOUR_WANDB_API_KEY>"
You can find your API key at https://wandb.ai/authorize.
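A quick pre-flight check can catch a forgotten key before a long run starts. The following is a minimal sketch, not part of the repository: the helper name `wandb_key_is_set` is hypothetical, and it only checks that the variable is non-empty and no longer the placeholder shown above.

```python
import os


def wandb_key_is_set(env=None):
    """Return True if WANDB_API_KEY is non-empty and not the README placeholder.

    `env` defaults to the real process environment; passing a dict makes the
    helper easy to try out without touching os.environ.
    """
    env = os.environ if env is None else env
    key = env.get("WANDB_API_KEY", "").strip()
    return bool(key) and key != "<YOUR_WANDB_API_KEY>"


# Report the status for the current shell environment.
print("WANDB_API_KEY set:", wandb_key_is_set())
```

This does not validate the key against the Weights & Biases service; it only guards against the most common mistake of leaving the line commented out or unedited.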
run.sh downloads the pretraining dataset, then trains all four model variants (FLT, LT_ia, LT_i, LT) sequentially using distributed training across 8 GPUs.
    bash run.sh
What it does:
- Creates the artifact cache directory at $HOME/.cache/nanochat.
- Downloads ~250 shards of the pretraining text corpus (≈25 GB).
- Launches scripts.base_train via torchrun for each config in CONFIGS:
  - config/FLT.yaml
  - config/LT_ia.yaml
  - config/LT_i.yaml
  - config/LT.yaml
- Each run uses --loss_type END, K=12 loop iterations, and L=6 layers.
- Checkpoints are saved to $HOME/.cache/nanochat/base_checkpoints/<run_name>/.
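The sequential launch described in the steps above can be sketched in Python as follows. This is an illustration under stated assumptions, not the contents of run.sh: `--standalone`, `--nproc_per_node`, and `-m` are standard torchrun options and `--loss_type END` comes from the list above, but `--config` is a hypothetical flag name, and the flags that set K and L are omitted because their names are not given here.

```python
from pathlib import Path

# Configs and GPU count as listed for run.sh.
CONFIGS = ["config/FLT.yaml", "config/LT_ia.yaml", "config/LT_i.yaml", "config/LT.yaml"]
NPROC_PER_NODE = 8


def build_train_command(config, nproc=NPROC_PER_NODE):
    """Build one torchrun invocation as an argument list.

    "--config" is a HYPOTHETICAL flag name used only for illustration;
    check scripts/base_train.py for the real argument names.
    """
    return [
        "torchrun", "--standalone", f"--nproc_per_node={nproc}",
        "-m", "scripts.base_train",
        "--config", config,
        "--loss_type", "END",
    ]


def checkpoint_dir(run_name, base_dir=None):
    """Return <base>/base_checkpoints/<run_name>, defaulting to ~/.cache/nanochat."""
    base = Path(base_dir) if base_dir else Path.home() / ".cache" / "nanochat"
    return base / "base_checkpoints" / run_name


# Print the commands instead of executing them, one per config, in order.
for cfg in CONFIGS:
    print(" ".join(build_train_command(cfg)))
```

Printing the commands rather than running them keeps the sketch safe to execute on a machine without GPUs.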
Key parameters (editable at the top of run.sh):
| Variable | Default | Description |
|---|---|---|
| NPROC_PER_NODE | 8 | Number of GPUs per node |
| DBS | 8 | Per-device batch size |
| CONFIGS | four variants | Model configs to train |
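When changing NPROC_PER_NODE or DBS, it helps to see how they combine. The sketch below assumes plain data parallelism, where each GPU processes DBS sequences per forward/backward pass; the `grad_accum_steps` parameter is an assumption added for illustration and is not one of the variables documented above.

```python
def sequences_per_step(nproc_per_node, device_batch_size, grad_accum_steps=1):
    """Sequences consumed per optimizer step on a single node.

    grad_accum_steps is an ASSUMED knob (default 1 = no accumulation).
    """
    return nproc_per_node * device_batch_size * grad_accum_steps


# run.sh defaults: 8 GPUs, per-device batch size 8.
print(sequences_per_step(8, 8))  # 64
```

Halving the GPU count while keeping the same number of sequences per step would therefore require doubling either DBS or the accumulation steps, memory permitting.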
eval.sh evaluates one or more trained checkpoints on the CORE benchmark, bits-per-byte (BPB), and perplexity metrics.
    bash eval.sh
What it does:
- Iterates over the checkpoint directories listed in CHECKPOINT_DIRS.
- For each checkpoint, runs scripts.base_budget with evaluation modes core, bpb, sample.
- Results are logged to the wandb project FullyLoopedTransformer.
Before running, update CHECKPOINT_DIRS in eval.sh to point to the checkpoints produced by run.sh:
    CHECKPOINT_DIRS=(
        "$NANOCHAT_BASE_DIR/base_checkpoints/<your_run_name>"
        ...
    )
Key parameters (editable in eval.sh):
| Variable | Default | Description |
|---|---|---|
| NPROC_PER_NODE | 4 | Number of GPUs |
| DBS | 4 | Per-device batch size |
| EVAL_MODES | core,bpb,sample | Evaluation modes to run |
| MAX_PER_TASK | -1 | Max examples per task (-1 = all) |
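To fill in CHECKPOINT_DIRS without typing run names by hand, the checkpoint directories can be listed programmatically. This is a minimal sketch with hypothetical helper names; it assumes NANOCHAT_BASE_DIR falls back to $HOME/.cache/nanochat when unset, which matches the paths used by run.sh but is not confirmed by the scripts themselves.

```python
import os
from pathlib import Path


def list_checkpoint_dirs(base_dir=None):
    """Return sorted subdirectories of <base_dir>/base_checkpoints.

    ASSUMPTION: base_dir defaults to $NANOCHAT_BASE_DIR, then ~/.cache/nanochat.
    """
    if base_dir is None:
        default = str(Path.home() / ".cache" / "nanochat")
        base_dir = os.environ.get("NANOCHAT_BASE_DIR", default)
    root = Path(base_dir) / "base_checkpoints"
    if not root.is_dir():
        return []
    return sorted(p for p in root.iterdir() if p.is_dir())


def format_bash_array(paths):
    """Render the paths as a bash array literal to paste into eval.sh."""
    lines = ["CHECKPOINT_DIRS=("]
    lines += [f'    "{p}"' for p in paths]
    lines.append(")")
    return "\n".join(lines)


print(format_bash_array(list_checkpoint_dirs()))
```

If no checkpoints exist yet, the output is an empty array, which is a quick signal that run.sh has not finished or wrote elsewhere.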
check.sh runs short diagnostic training runs (2000 steps each) and logs per-layer gradient norms and hidden-state norms to wandb. No checkpoints are saved. This is useful for inspecting the internal training dynamics of the looped models.
    bash check.sh
What it does:
- Iterates over all combinations of configs (LT_ia, LT_i, FLT_res) and loop counts K ∈ {6, 9, 12}.
- Each run trains for 2000 steps with --loss_type END and L=6 layers.
- Per-layer gradient norms and hidden-state norms are logged every step to the wandb project FLT_check.
Key parameters (editable in check.sh):
| Variable | Default | Description |
|---|---|---|
| NPROC_PER_NODE | 4 | Number of GPUs |
| DBS | 8 | Per-device batch size |
| CONFIGS | three variants | Model configs to diagnose |
| Ks | 6, 9, 12 | Loop counts to sweep |
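The size of the diagnostic sweep follows directly from the defaults above. The sketch below enumerates the config and K combinations; the iteration order (configs in the outer loop, K in the inner loop) is an assumption, since the loop nesting inside check.sh is not described here.

```python
from itertools import product

# Defaults documented for check.sh.
CONFIGS = ["LT_ia", "LT_i", "FLT_res"]
KS = [6, 9, 12]
STEPS_PER_RUN = 2000

# Every (config, K) pair is one short diagnostic run.
runs = list(product(CONFIGS, KS))
total_steps = len(runs) * STEPS_PER_RUN

for config, k in runs:
    print(f"config={config} K={k} steps={STEPS_PER_RUN}")
print(f"{len(runs)} runs, {total_steps} training steps in total")
```

With the defaults this comes to 9 runs and 18000 training steps, so trimming CONFIGS or Ks is the simplest way to shorten a check.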
.
├── config/ # Model configuration YAML files
│ ├── FLT.yaml
│ ├── FLT_res.yaml
│ ├── LT.yaml
│ ├── LT_i.yaml
│ └── LT_ia.yaml
├── nanochat/ # Core library (model, tokenizer, dataloader, engine, ...)
├── scripts/ # Training and evaluation entry points
│ ├── base_train.py # Pretraining
│ ├── base_eval.py # CORE / BPB evaluation
│ ├── base_budget.py # Budget-controlled evaluation
│ ├── base_check.py # Diagnostic gradient/norm logging
│ └── base_rl.py # Reinforcement learning fine-tuning
├── tasks/ # Benchmark task implementations (ARC, GSM8K, MMLU, ...)
├── tests/ # Unit tests
├── run.sh # Step 1: pretraining
├── eval.sh # Step 2: evaluation
├── check.sh # Step 3: diagnostic checks
└── pyproject.toml
1. Fill in WANDB_API_KEY in run.sh / eval.sh / check.sh
2. bash run.sh # pretrain all model variants
3. bash eval.sh # evaluate checkpoints on benchmarks
4. bash check.sh # inspect gradient / hidden-state dynamics