REAP pruning for MLX MoE models on Apple Silicon.
The shortest version:
collect telemetry -> build a pruning plan -> apply the plan
This repo ports the pruning side of Cerebras REAP to local MLX workflows. It is built for people who want to calibrate on real data, inspect the telemetry, and physically write a smaller MLX checkpoint.
- Paper: https://arxiv.org/abs/2510.13999
- Original Cerebras REAP repo: https://github.com/CerebrasResearch/reap
- Vendored upstream submodule in this repo: external/cerebras-reap
The submodule is there so you can diff local MLX work against the upstream research implementation without leaving the repo.
- Collects per-expert telemetry from an MLX MoE model.
- Builds a pruning plan with REAP or simpler scoring rules.
- Applies that plan to an MLX checkpoint.
- Compares two telemetry files under the same prune config with an exact parity report.
- Supports prompt calibration, Hugging Face datasets, and local dataset files.
- Supports lower-memory collection modes, token chunking, sample mini-batching, packing, and chat-style calibration rows.
- It does not stream checkpoint writeback layer-by-layer during apply. Collection can run in lower-memory modes, but apply still loads the model normally.
- It does not ship a full benchmark harness in the repo. You still need to run before/after evals yourself.
Requirements:
- Apple Silicon Mac
- Node 20+
- pnpm
- Python 3.11+
- mlx and mlx_lm
Install and build:
pnpm install
pnpm build
Collect telemetry:
node dist/cli/index.js collect \
--model ./models/qwen1.5-moe-a2.7b-chat-4bit \
--output ./tmp \
--dataset-file ./calibration/tinystories.jsonl \
--dataset-format jsonl \
--dataset-text-field text \
--max-samples 1024 \
--min-samples 512 \
--max-tokens 512 \
--sample-batch-size 8 \
--pack-samples \
--collect-mode reload_per_layer \
--batch-size 128 \
--lazy-load
Build a pruning plan from the telemetry:
node dist/cli/index.js run \
--model ./tmp/telemetry-*.json \
--output ./tmp/plan \
--ratio 0.5 \
--min-experts 1 \
--no-legacy
Apply the plan to the checkpoint:
node dist/cli/index.js apply \
--model ./models/qwen1.5-moe-a2.7b-chat-4bit \
--plan ./tmp/plan/pruning-plan.json \
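--output ./tmp/pruned-model
Use --dry-run on apply if you want to validate the plan before writing a new checkpoint. The ./tmp/telemetry-*.json glob in the run step assumes collect wrote exactly one telemetry file to ./tmp; if more than one file matches, pass the exact path instead.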
Commands:
- collect: run prompt or dataset calibration and write telemetry JSON.
- run: score experts from telemetry and write pruning-plan.json.
- parity: run the same prune config against two telemetry files and diff the exact prune set.
- full: run collect -> run -> apply in one command.
- apply: remove pruned experts from the MLX checkpoint.
- observe: summarize an observation log.
- init: generate synthetic telemetry for local testing.
You can calibrate with exactly one input source per run:
- --prompt <text>
- --dataset <huggingface-name>
- --dataset-file <path>
Local dataset files support:
- json
- jsonl
- csv
- parquet
- plain text
Useful controls:
- --max-samples <n> limits how many rows are processed.
- --min-samples <n> fails the run if fewer than n usable rows remain after processing.
- --max-tokens <n> limits how many tokens of each sample are processed.
- --dataset-text-field <field.path> selects a text field.
- --dataset-messages-field <field.path> renders chat-style rows through the tokenizer chat template when available.
That means you are not stuck with one-off prompts. You can hand the collector a real dataset and bound both sample count and per-sample length.
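For local files, the field-path and sample-count controls can be pictured with the sketch below. It assumes JSONL input, treats a dotted path such as data.text as nested dictionary keys (numeric parts index into lists), counts every non-blank row toward --max-samples, and fails when fewer than --min-samples of those rows yield non-empty text. These semantics are assumptions, not a copy of the collector's loader, and resolve_field and select_texts are hypothetical names.

```python
import json
from typing import Any, Iterable


def resolve_field(row: Any, path: str) -> Any:
    """Follow a dotted path like "data.text" through nested dicts/lists; None if absent."""
    value = row
    for part in path.split("."):
        if isinstance(value, dict) and part in value:
            value = value[part]
        elif isinstance(value, list) and part.isdigit() and int(part) < len(value):
            value = value[int(part)]
        else:
            return None
    return value


def select_texts(lines: Iterable[str], field: str, max_samples: int, min_samples: int) -> list[str]:
    """Process up to max_samples JSONL rows and keep the usable text (hypothetical semantics)."""
    texts: list[str] = []
    processed = 0
    for line in lines:
        if processed >= max_samples:
            break
        line = line.strip()
        if not line:
            continue  # blank lines are not counted as rows
        processed += 1
        try:
            row = json.loads(line)
        except json.JSONDecodeError:
            continue  # malformed rows count as processed but unusable
        text = resolve_field(row, field)
        if isinstance(text, str) and text.strip():
            texts.append(text)
    if len(texts) < min_samples:
        raise ValueError(f"{len(texts)} usable samples, need at least {min_samples}")
    return texts


rows = [
    '{"data": {"text": "Once upon a time"}}',
    '{"data": {"text": ""}}',
    '{"meta": {"id": 3}}',
]
print(select_texts(rows, "data.text", max_samples=3, min_samples=1))  # ['Once upon a time']
```

With min_samples=2 the same call raises, because only one of the three processed rows carries usable text.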
If you want reap-mlx to track the published REAP paper more closely, use the paper's calibration standards as your baseline instead of inventing larger sample counts by default.
- For models with <= 110B parameters, the paper calibrates on 1,024 randomly selected samples packed to 2,048 tokens.
- For models with >= 110B parameters, the paper uses 12,228 samples with a maximum sequence length of 16,384 tokens and no packing.
- The paper shows that domain-specific calibration matters a lot. In particular, coding models calibrated on c4 can degrade badly, while evol-codealpaca preserves coding quality much better.
Recommended dataset choices from the paper:
- coding: theblackcat102/evol-codealpaca-v1
- creative writing: euclaise/WritingPrompts_curated
- math: allenai/tulu-3-sft-personas-math
- larger tool-use / agentic mixes: add Salesforce/xlam-function-calling-60k and SWE-bench/SWE-smith-trajectories
For local Apple Silicon workflows, the most practical paper-aligned baseline is the <= 110B recipe:
node dist/cli/index.js collect \
--model ./models/qwen1.5-moe-a2.7b-chat-4bit \
--output ./tmp/paper-coding \
--dataset theblackcat102/evol-codealpaca-v1 \
--dataset-split train \
--max-samples 1024 \
--min-samples 1024 \
--max-tokens 2048 \
--pack-samples \
--renorm-topk \
--collect-mode reload_per_layer \
--batch-size 128 \
--lazy-load
Do not treat 20k samples as a paper-backed default. If you want to test 20k samples, treat it as an experiment and compare it against the paper-backed 1,024-sample baseline on your actual eval suite.
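When sweeping the paper's recommended datasets, it can help to assemble the <= 110B recipe programmatically so that baseline and experiment differ only in the sample count. The sketch below only builds an argument list from flags shown in this README; PAPER_DATASETS and paper_collect_args are hypothetical names, and datasets other than evol-codealpaca may also need --dataset-text-field or --dataset-messages-field, which this sketch does not set.

```python
import shlex

# Domain -> Hugging Face dataset, taken from the paper's recommendations above.
PAPER_DATASETS = {
    "coding": "theblackcat102/evol-codealpaca-v1",
    "creative_writing": "euclaise/WritingPrompts_curated",
    "math": "allenai/tulu-3-sft-personas-math",
}


def paper_collect_args(model: str, output: str, domain: str, samples: int = 1024) -> list[str]:
    """Assemble the <= 110B collect recipe; samples != 1024 marks an experiment, not the baseline."""
    return [
        "node", "dist/cli/index.js", "collect",
        "--model", model,
        "--output", output,
        "--dataset", PAPER_DATASETS[domain],
        "--dataset-split", "train",
        "--max-samples", str(samples),
        "--min-samples", str(samples),
        "--max-tokens", "2048",
        "--pack-samples",
        "--renorm-topk",
        "--collect-mode", "reload_per_layer",
        "--batch-size", "128",
        "--lazy-load",
    ]


model = "./models/qwen1.5-moe-a2.7b-chat-4bit"
baseline = paper_collect_args(model, "./tmp/paper-coding", "coding")
experiment = paper_collect_args(model, "./tmp/coding-20k", "coding", samples=20_000)
print(shlex.join(baseline))
# Pass either list to subprocess.run(args, check=True) to execute it.
```

Keeping every other flag fixed is the point: any eval difference between the two telemetry runs can then be attributed to the sample count alone.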
The collector has three batching-related knobs, and they are not the same thing.
--sample-batch-size <n> batches multiple independent samples or conversations into one model pass.
Use this when the goal is throughput on a larger calibration set.
--pack-samples packs multiple independent short samples into fuller context windows.
Use this when your dataset has lots of short rows and you do not need one-row-per-window isolation.
--batch-size <n> chunks flattened token activations inside expert scoring.
Use this when collection is running hot on memory.
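The difference between --pack-samples and --sample-batch-size can be sketched on token-id lists. Packing concatenates short samples into one window, so tokens from different samples share a context; sample batching keeps each sample in its own row. This assumes greedy in-order packing, truncation of over-long samples, and no separator tokens or attention masking between packed samples; the collector's actual policy may differ, and pack_samples and batch_samples are hypothetical names.

```python
def pack_samples(samples: list[list[int]], window: int) -> list[list[int]]:
    """Greedily concatenate tokenized samples, in order, into windows of at most `window` tokens."""
    windows: list[list[int]] = []
    current: list[int] = []
    for tokens in samples:
        tokens = tokens[:window]  # truncate a sample that would not fit in any window
        if current and len(current) + len(tokens) > window:
            windows.append(current)  # close the window before it overflows
            current = []
        current.extend(tokens)
    if current:
        windows.append(current)
    return windows


def batch_samples(samples: list[list[int]], batch_size: int) -> list[list[list[int]]]:
    """Group independent samples into batches; each sample stays its own row."""
    return [samples[i:i + batch_size] for i in range(0, len(samples), batch_size)]


samples = [[1, 2], [3, 4, 5], [6], [7, 8, 9, 10]]
print(pack_samples(samples, window=5))       # [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]
print(batch_samples(samples, batch_size=2))  # [[[1, 2], [3, 4, 5]], [[6], [7, 8, 9, 10]]]
```

Packing turns four short rows into two full windows, while batching still processes four rows, just two at a time. Neither changes how --batch-size chunks token activations inside expert scoring.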
collect and full support these modes:
- single_pass: the original behavior
- replay_per_layer: replay hidden states layer-by-layer in one process
- reload_per_layer: rerun one selected layer per collector process, then merge telemetry
reload_per_layer is the lowest-memory observation mode in this repo today. It is slower, but it cuts working-set pressure enough to make some runs practical on smaller machines.
--layer-wise remains as shorthand for the layer replay path.
--lazy-load asks MLX to defer parameter materialization during load. Pair it with reload_per_layer when you want the most conservative collection profile.
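reload_per_layer relies on each layer's telemetry being collected in its own process and combined afterwards. The sketch below shows only that merge step, on an assumed schema of the form {"model": ..., "layers": {index: stats}}. The real telemetry JSON layout is not documented here, so both the schema and merge_layer_telemetry are hypothetical.

```python
from typing import Any


def merge_layer_telemetry(parts: list[dict[str, Any]]) -> dict[str, Any]:
    """Merge per-layer telemetry dicts of the (hypothetical) form {"model": str, "layers": {idx: stats}}."""
    merged: dict[str, Any] = {"model": None, "layers": {}}
    for part in parts:
        if merged["model"] is None:
            merged["model"] = part["model"]
        elif part["model"] != merged["model"]:
            raise ValueError("telemetry parts come from different models")
        for idx, stats in part["layers"].items():
            if idx in merged["layers"]:
                raise ValueError(f"layer {idx} was collected twice")
            merged["layers"][idx] = stats
    # order layers numerically so the merged output does not depend on process finish order
    merged["layers"] = dict(sorted(merged["layers"].items(), key=lambda kv: int(kv[0])))
    return merged


parts = [
    {"model": "m", "layers": {"1": {"frequency": [3, 0]}}},
    {"model": "m", "layers": {"0": {"frequency": [1, 2]}}},
]
print(list(merge_layer_telemetry(parts)["layers"]))  # ['0', '1']
```

Sorting the merged layers and rejecting duplicates keeps the result deterministic, which matters if the merged file is later hashed for a parity check.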
This project supports these expert-scoring methods:
reap, reap_l2, frequency, weighted_frequency_sum, ean_sum, ean_mean, ean_ca, weighted_ean_sum, weighted_ean_sum_l2, max_activations
Notes:
- weighted_frequency_sum maps to weighted routing-frequency totals (weightedExpertFrequencySum, with a gateValueSum fallback for local collector telemetry).
- ean_ca, weighted_ean_sum_l2, and reap_l2 can require richer telemetry fields when you ingest data produced outside the local collector.
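To see how the simpler methods relate to REAP, here is a sketch that scores one expert from (gate weight, output norm) pairs for the tokens routed to it. The formulas are a reading of the method names (frequency counts tokens, ean_* aggregates expert activation norms, weighted_* multiplies by the gate value), not a copy of this repo's scorer. The richer methods (ean_ca, reap_l2, weighted_ean_sum_l2, max_activations) are omitted, and scoring a never-routed expert as 0.0 is an assumption; expert_scores is a hypothetical name.

```python
from statistics import fmean

METHODS = ("frequency", "weighted_frequency_sum", "ean_sum", "ean_mean", "weighted_ean_sum", "reap")


def expert_scores(routed: list[tuple[float, float]]) -> dict[str, float]:
    """Score one expert from (gate_weight, output_norm) pairs over the tokens routed to it."""
    if not routed:
        return {name: 0.0 for name in METHODS}  # assumption: never-routed experts score 0
    gates = [g for g, _ in routed]
    norms = [n for _, n in routed]
    weighted = [g * n for g, n in routed]
    return {
        "frequency": float(len(routed)),       # how often the expert was selected
        "weighted_frequency_sum": sum(gates),  # routing frequency weighted by gate value
        "ean_sum": sum(norms),                 # total expert activation norm
        "ean_mean": fmean(norms),              # average expert activation norm
        "weighted_ean_sum": sum(weighted),     # gate-weighted activation norm, summed
        "reap": fmean(weighted),               # gate-weighted activation norm, averaged
    }


busy = expert_scores([(0.5, 2.0), (0.25, 4.0)])  # routed often, weakly
sharp = expert_scores([(0.75, 4.0)])              # routed once, strongly
for name in METHODS:
    print(f"{name}: busy={busy[name]:.2f} sharp={sharp[name]:.2f}")
```

In this toy case frequency and ean_sum favor the often-routed expert, ean_mean, weighted_ean_sum, and reap favor the one with fewer but stronger contributions, and weighted_frequency_sum ties. That disagreement is why the method choice can change the prune set.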
The REAP-style score used here is:
saliency_j = mean( g_j(x) * ||f_j(x)|| )
Where:
- g_j(x) is the router softmax weight for expert j
- f_j(x) is the expert output
- the mean is taken over the tokens routed to that expert
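The formula can be computed from router logits with a running sum and a routed-token count per expert. This sketch assumes softmax-then-top-k routing, optional renormalization of the top-k gate weights to mirror --renorm-topk, and precomputed output norms ||f_j(x)|| per token and expert (a real collector gets those by running the experts). Token chunking stands in for --batch-size: since saliency is a ratio of running sums, the chunk size only bounds how many tokens are handled at once and does not change the result. reap_saliency and its input layout are hypothetical.

```python
import math


def softmax(logits: list[float]) -> list[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def reap_saliency(
    router_logits: list[list[float]],  # [token][expert]
    output_norms: list[list[float]],   # [token][expert], ||f_j(x)||
    top_k: int,
    renorm_topk: bool = False,
    batch_size: int = 128,
) -> list[float]:
    """saliency_j = mean over routed tokens of g_j(x) * ||f_j(x)||."""
    num_experts = len(router_logits[0])
    weighted_sum = [0.0] * num_experts
    routed = [0] * num_experts
    for start in range(0, len(router_logits), batch_size):
        chunk = zip(router_logits[start:start + batch_size], output_norms[start:start + batch_size])
        for logits, norms in chunk:
            probs = softmax(logits)
            top = sorted(range(num_experts), key=lambda j: probs[j], reverse=True)[:top_k]
            # with renormalization, the selected gate weights are rescaled to sum to 1
            denom = sum(probs[j] for j in top) if renorm_topk else 1.0
            for j in top:
                weighted_sum[j] += probs[j] / denom * norms[j]
                routed[j] += 1
    # experts that were never routed get 0.0 (assumption)
    return [s / c if c else 0.0 for s, c in zip(weighted_sum, routed)]


logits = [[2.0, 1.0, 0.0, -1.0], [0.0, 3.0, 1.0, -2.0], [1.0, 0.5, 2.5, -1.0]]
norms = [[1.0, 2.0, 3.0, 4.0]] * 3
print(reap_saliency(logits, norms, top_k=2, renorm_topk=True))
```

In this example expert 3 is never in any token's top 2, so it ends with saliency 0.0 and would be the first pruning candidate.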
If you are trying to answer this question:
same exact telemetry + same exact prune config => same exact experts pruned?
Use parity.
node dist/cli/index.js parity \
--left ./left.telemetry.json \
--right ./right.telemetry.json \
--output ./tmp/parity \
--prune-method reap \
--n-experts-to-prune-per-layer 15 \
--min-experts 1 \
--no-legacy \
--require-identical-telemetry
It writes:
- parity-report.json
- parity-report.md
- left and right pruning outputs under the output directory
The report includes:
- normalized telemetry hashes
- the first differing expert row
- the exact prune-set diff
- per-layer expert deltas
This is the cleanest correctness check in the repo. If telemetry is identical and the prune sets differ, something is wrong.
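The two checks at the heart of parity can be sketched in a few lines. The repo does not spell out how telemetry is normalized before hashing, so this assumes canonical JSON (sorted keys, fixed separators) fed to SHA-256, and represents each prune set as a mapping from layer to the set of pruned expert indices. telemetry_hash and prune_set_diff are hypothetical names.

```python
import hashlib
import json
from typing import Any


def telemetry_hash(telemetry: Any) -> str:
    # canonical JSON: sorted keys and fixed separators, so key order and whitespace do not matter
    canonical = json.dumps(telemetry, sort_keys=True, separators=(",", ":"))
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()


def prune_set_diff(left: dict[int, set[int]], right: dict[int, set[int]]) -> dict[int, dict[str, list[int]]]:
    """Per layer, list experts pruned only on the left and only on the right."""
    diff: dict[int, dict[str, list[int]]] = {}
    for layer in sorted(set(left) | set(right)):
        only_left = sorted(left.get(layer, set()) - right.get(layer, set()))
        only_right = sorted(right.get(layer, set()) - left.get(layer, set()))
        if only_left or only_right:
            diff[layer] = {"only_left": only_left, "only_right": only_right}
    return diff


left = {"layers": {"0": {"reap": [0.1, 0.4]}}, "model": "m"}
right = {"model": "m", "layers": {"0": {"reap": [0.1, 0.4]}}}
print(telemetry_hash(left) == telemetry_hash(right))  # True
print(prune_set_diff({0: {1, 2}, 1: {0}}, {0: {1, 3}, 1: {0}}))
# {0: {'only_left': [2], 'only_right': [3]}}
```

Canonical JSON only removes key-order and whitespace differences; float values that serialize differently still produce different hashes.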
collect options:
--model <dir> MLX model directory
--output <dir> Telemetry output directory
--prompt <text> Single calibration text
--dataset <name> HuggingFace dataset name
--dataset-file <path> Local calibration dataset file
--dataset-format <fmt> auto|json|jsonl|csv|parquet|text
--dataset-split <name> Dataset split (default: train)
--dataset-text-field <field> Text field path
--dataset-messages-field <field> Chat messages array field path
--max-samples <n> Max dataset samples to aggregate (default: 100)
--min-samples <n> Require at least n usable samples (default: 1)
--max-tokens <n> Per-sample token cap (default: 256)
--sample-batch-size <n> Batch multiple samples or conversations together
--pack-samples Pack short independent samples into fuller windows
--layers <spec> Example: 0-3,8,10
--renorm-topk Renormalize top-k gate weights to sum to 1
--layer-wise Enable layer-wise collection mode
--collect-mode <name> single_pass|replay_per_layer|reload_per_layer
--batch-size <n> Token chunk size for collection batching
--lazy-load Ask MLX to lazily materialize weights during load
--python <bin> Python binary (default: python3)
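A --layers spec such as 0-3,8,10 reads naturally as inclusive ranges plus single indices. Here is a minimal parser under that assumption (inclusive ranges, duplicates collapsed, result sorted); parse_layer_spec is a hypothetical name, not the CLI's own parser.

```python
def parse_layer_spec(spec: str) -> list[int]:
    """Parse "0-3,8,10" into [0, 1, 2, 3, 8, 10], treating ranges as inclusive."""
    layers: set[int] = set()
    for part in spec.split(","):
        part = part.strip()
        if not part:
            continue
        if "-" in part:
            lo_text, hi_text = part.split("-", 1)
            lo, hi = int(lo_text), int(hi_text)
            if lo > hi:
                raise ValueError(f"descending range: {part}")
            layers.update(range(lo, hi + 1))
        else:
            layers.add(int(part))
    return sorted(layers)


print(parse_layer_spec("0-3,8,10"))  # [0, 1, 2, 3, 8, 10]
```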
full options:
--model <dir> MLX model directory
--output <dir> Pipeline output directory
--prompt <text> Single calibration text
--dataset <name> HuggingFace dataset name
--dataset-file <path> Local calibration dataset file
--dataset-format <fmt> auto|json|jsonl|csv|parquet|text
--dataset-split <name> Dataset split (default: train)
--dataset-text-field <field> Text field path
--dataset-messages-field <field> Chat messages array field path
--max-samples <n> Max dataset samples to aggregate (default: 100)
--min-samples <n> Require at least n usable samples (default: 1)
--max-tokens <n> Per-sample token cap (default: 256)
--sample-batch-size <n> Batch multiple samples or conversations together
--pack-samples Pack short independent samples into fuller windows
--layers <spec> Example: 0-3,8,10
--renorm-topk Renormalize top-k gate weights to sum to 1
--layer-wise Enable layer-wise collection mode
--collect-mode <name> single_pass|replay_per_layer|reload_per_layer
--batch-size <n> Token chunk size for collection batching
--lazy-load Ask MLX to lazily materialize weights during load
--ratio <0..0.95> Target prune ratio per layer
--min-experts <n> Minimum experts kept per layer
--dry-run Validate apply step without writing pruned model
run options:
--model <file> Telemetry JSON from collect
--output <dir> Plan and observation output directory
--ratio <0..0.95> Target prune ratio per layer
--calibration <n> Calibration rounds (default: 2)
--min-experts <n> Minimum experts kept per layer
--no-legacy Require REAP saliency fields
--json Print plan JSON to stdout
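Turning per-expert scores into a prune set for one layer can be sketched as pruning the lowest-scoring experts, capped so that at least --min-experts remain. How reap-mlx rounds ratio * num_experts is not stated here, so flooring is an assumption, as is breaking score ties by expert index. A deterministic tie-break matters for parity, since identical telemetry must always yield the same prune set. plan_layer is a hypothetical name.

```python
import math


def plan_layer(scores: list[float], ratio: float, min_experts: int = 1) -> list[int]:
    """Return the indices of experts to prune in one layer (lowest scores first)."""
    if not 0.0 <= ratio <= 0.95:
        raise ValueError("ratio must be in [0, 0.95]")
    n = len(scores)
    n_prune = math.floor(n * ratio)                   # assumption: round down
    n_prune = min(n_prune, max(0, n - min_experts))   # never keep fewer than min_experts
    # ascending score; ties broken by expert index so the result is deterministic
    order = sorted(range(n), key=lambda j: (scores[j], j))
    return sorted(order[:n_prune])


print(plan_layer([0.3, 0.1, 0.9, 0.2], ratio=0.5))                  # [1, 3]
print(plan_layer([0.3, 0.1, 0.9, 0.2], ratio=0.95, min_experts=2))  # [1, 3]
```

In the second call the ratio alone would prune three of four experts, but --min-experts 2 caps it at two.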
parity options:
--left <file> Left telemetry JSON
--right <file> Right telemetry JSON
--output <dir> Output directory for left/right plans and parity report
--ratio <0..0.95> Target prune ratio per layer
--n-experts-to-prune-per-layer <n> Number of experts to prune in each layer
--prune-method <name> reap|reap_l2|frequency|weighted_frequency_sum|ean_sum|ean_mean|ean_ca|weighted_ean_sum|weighted_ean_sum_l2|max_activations
--require-identical-telemetry Fail unless normalized telemetry hashes match exactly
--json Print parity report JSON to stdout
apply options:
--model <dir> Source MLX model
--plan <file> pruning-plan.json
--output <dir> Pruned model output
--dry-run Validate plan only
observe options:
--file <path> Observation log file
--json JSON output
init options:
--output <file> Synthetic telemetry output
--model-name <name> Default: synthetic-moe
--layers <n> Default: 8
--experts <n> Default: 8
--seed <int> RNG seed
- This repo is focused on pruning MLX MoE checkpoints, not on being a full research harness.
- It supports switch_mlp-style MLX MoE checkpoints, including full-precision and quantized expert weights in the collector.
- Lower-memory collection is here now. Lower-memory apply is not.
- The repo has planner tests, collector wiring tests, and an exact parity harness.
- The repo still does not include a built-in benchmark suite for pruned versus unpruned models.
If you want a broader research stack, including full evaluation workflows and other compression paths, use the upstream Cerebras repo. A checked-out copy now lives in this repo at external/cerebras-reap.
Development checks:
pnpm lint
pnpm build
pnpm test
pnpm verify
References:
- Paper: https://arxiv.org/abs/2510.13999
- Cerebras implementation: https://github.com/CerebrasResearch/reap
- MLX: https://github.com/ml-explore/mlx