Convert LLMs (7B–70B) into efficient Small Language Models (1B–3B) using activation-based structured pruning and knowledge distillation — achieving ~95% of teacher accuracy with 40–80% fewer parameters.
Implements the NVIDIA Minitron pruning strategy combined with response- and feature-based knowledge distillation, with plug-and-play support for DeepSeek, Kimi, LLaMA, Mistral, Qwen, Gemma, and Phi model families.
- What's New
- Overview
- Key Features
- Supported Models
- Installation
- Quick Start
- Detailed Usage
- Algorithm Details
- Benchmark Results
- Hardware Requirements
- Project Structure
- Contributing
- Citation
- License
- [2025-03] Initial release — full prune + distill + eval pipeline
- [2025-03] Shell presets added for 10 model/target combinations (`run.sh`)
- [2025-03] GQA (Grouped-Query Attention) support for LLaMA-3 and Mistral models
- [2025-03] Multi-GPU support via `accelerate` + `deepspeed` auto-detection
```text
┌──────────────────────────────────────────────────────────────────┐
│ prune_distill Pipeline │
│ │
│ Teacher LLM (7B-70B) │
│ │ │
│ ▼ │
│ ┌─────────────────────────────┐ │
│ │ Step 1: Structured Pruning │ (Minitron-style) │
│ │ • Width: remove heads/MLP │ │
│ │ • Depth: remove layers │ │
│ │ • Importance: activations │ │
│ └──────────────┬──────────────┘ │
│ │ Pruned Model (1B-3B) │
│ ▼ │
│ ┌──────────────────────────────┐ │
│ │ Step 2: Knowledge Distill. │ │
│ │ • Response KL divergence │ │
│ │ • Feature MSE matching │ │
│ │ • L = αCE + (1-α)T²KL + βF │ │
│ └──────────────┬───────────────┘ │
│ │ Distilled Model │
│ ▼ │
│ ┌──────────────────────────────┐ │
│ │ Step 3: Evaluation │ │
│ │ HellaSwag / MMLU / ARC / │ │
│ │ WinoGrande / TruthfulQA / │ │
│ │ GSM8K / HumanEval │ │
│ └──────────────────────────────┘ │
│ │
│ Student SLM (1B-3B) — ~95% accuracy at 40-80% fewer params │
└──────────────────────────────────────────────────────────────────┘
```
- Minitron-style importance scoring — activation-magnitude scoring for attention heads, MLP neurons, and transformer layers; globally ranked, no per-layer heuristics
- Three pruning strategies — `width` (heads + neurons), `depth` (layers), or `combined` — selectable at runtime
- GQA-aware pruning — correctly handles Grouped-Query Attention in LLaMA-3, Mistral, Qwen, and Phi
- Three-term distillation loss — α·L_CE + (1−α)·T²·L_KL + β·L_feature with learned hidden-state projectors
- Calibration datasets — WikiText-2, C4, PTB, or OpenWebText for importance scoring
- Training datasets — C4, SlimPajama, OpenWebText, Alpaca, Dolly for distillation
- Seven benchmark tasks — HellaSwag, MMLU, ARC-Easy/Challenge, WinoGrande, TruthfulQA, GSM8K, HumanEval via `lm-eval`
- Multi-GPU / ZeRO — auto-detects `accelerate` + DeepSpeed ZeRO-2/3 configs
- Memory efficiency — `bfloat16`, gradient checkpointing, and 8-bit quantization (`bitsandbytes`)
- WandB + TensorBoard — integrated training metrics and eval logging
- Modular design — each script is independently importable and CLI-runnable
Any HuggingFace decoder-only model with LLaMA-style architecture is supported out of the box:
| Model Family | HuggingFace IDs | Architecture |
|---|---|---|
| DeepSeek | `deepseek-ai/DeepSeek-R1-Distill-Llama-8B` · `deepseek-ai/deepseek-coder-7b-base` | LLaMA-3 |
| Kimi (Moonshot) | `moonshotai/Moonlight-16B-A3B-Instruct` · `moonshotai/Kimi-VL-A3B-Instruct` | MoE/LLaMA |
| LLaMA | `meta-llama/Llama-3.1-8B` · `meta-llama/Llama-3.2-3B` · `meta-llama/Llama-3.1-70B` | LLaMA-3 |
| Mistral | `mistralai/Mistral-7B-v0.3` · `mistralai/Mixtral-8x7B-v0.1` | Mistral |
| Qwen | `Qwen/Qwen2.5-7B` · `Qwen/Qwen2.5-14B` · `Qwen/Qwen2.5-72B` | Qwen2 |
| Gemma | `google/gemma-2-9b` · `google/gemma-7b` | Gemma |
| Phi | `microsoft/phi-4` · `microsoft/Phi-3-medium-4k-instruct` | Phi |
Requirements: Python 3.10+, CUDA 11.8+ (for GPU training)
```bash
# Clone the repository
git clone https://github.com/TiniLLM/prune_distill.git
cd prune_distill

# Install all dependencies
pip install -r requirements.txt

# (Optional) Flash Attention 2 for faster training — requires CUDA 11.6+
pip install flash-attn --no-build-isolation
```

Note: For multi-GPU training, also run:

```bash
accelerate config   # configure DeepSpeed / FSDP
```
```bash
python pipeline.py \
--model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--output-dir ./output/deepseek-3b \
--target-size 3B \
--prune-strategy combined \
--distill-dataset c4 \
--distill-epochs 3 \
--eval-tasks hellaswag,mmlu,arc_easy,arc_challenge,winogrande
```

```bash
# DeepSeek → 3B
bash run.sh deepseek-3b
# LLaMA 3.1 8B → 1B
bash run.sh llama-1b
# Kimi Moonlight → 3B
bash run.sh kimi-3b
# Evaluate an existing model
bash run.sh eval-only --model ./my_model
# See all presets
bash run.sh --help
```

Runs calibration data through the teacher, scores all attention heads, MLP neurons, and transformer layers by activation importance, then removes the lowest-scoring components to hit the target parameter count.
```bash
python prune.py \
--model meta-llama/Llama-3.1-8B \
--output ./pruned_llama \
--target-params 3B \
--prune-strategy combined \
--calibration-dataset wikitext \
--calibration-samples 512 \
--dtype bfloat16
```

Key arguments:
| Argument | Default | Description |
|---|---|---|
| `--model` | (required) | HuggingFace model ID or local path |
| `--output` | `./pruned_model` | Directory to save the pruned model |
| `--target-params` | `3B` | Target size: 1B, 1.5B, 3B, or exact count |
| `--prune-strategy` | `combined` | `width` / `depth` / `combined` |
| `--calibration-dataset` | `wikitext` | `wikitext` / `c4` / `ptb` / `openwebtext` |
| `--calibration-samples` | `512` | Number of calibration sequences |
| `--head-prune-ratio` | auto | Override attention head pruning ratio |
| `--layer-prune-ratio` | auto | Override transformer layer pruning ratio |
| `--dtype` | `bfloat16` | `float32` / `float16` / `bfloat16` |
| `--token` | `None` | HuggingFace token for gated models |
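After pruning, it is worth sanity-checking that the saved checkpoint actually lands near the requested size. Below is a minimal check using plain `transformers`; it is not a helper shipped by the repo, just an illustration of verifying the `--target-params` budget:

```python
# Count parameters of the pruned checkpoint to verify it is close to the target size.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("./pruned_llama")
n_params = sum(p.numel() for p in model.parameters())
print(f"Pruned model: {n_params / 1e9:.2f}B parameters")
```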
Trains the pruned student to mimic the full teacher using a combined loss of cross-entropy, KL divergence on logits (with temperature), and MSE on intermediate hidden states.
```bash
python distill.py \
--teacher meta-llama/Llama-3.1-8B \
--student ./pruned_llama \
--output ./distilled_llama \
--dataset c4 \
--epochs 3 \
--temperature 2.0 \
--alpha 0.7 \
--beta 0.01 \
--batch-size 4 \
--gradient-accumulation 8
```

Key arguments:
| Argument | Default | Description |
|---|---|---|
| `--teacher` | (required) | Teacher model ID or local path |
| `--student` | (required) | Student (pruned) model ID or local path |
| `--output` | `./distilled_model` | Output directory |
| `--dataset` | `c4` | `c4` / `wikitext` / `openwebtext` / `alpaca` / `dolly` / `slim_pajama` |
| `--epochs` | `3` | Training epochs |
| `--batch-size` | `2` | Per-device batch size |
| `--gradient-accumulation` | `8` | Gradient accumulation steps |
| `--temperature` | `2.0` | Distillation temperature T |
| `--alpha` | `0.7` | Weight for L_CE (1−alpha weights L_KL) |
| `--beta` | `0.01` | Weight for feature matching loss L_feature |
| `--learning-rate` | `1e-4` | AdamW learning rate |
| `--warmup-ratio` | `0.05` | LR warmup fraction |
| `--gradient-checkpointing` | `False` | Enable for lower VRAM usage |
| `--wandb-project` | `None` | WandB project name |
Runs the student (and optionally the teacher as baseline) through the lm-evaluation-harness suite and outputs a comparison table.
```bash
# Evaluate a single model
python evaluate.py \
--model ./distilled_llama \
--tasks hellaswag,mmlu,arc_easy,arc_challenge,winogrande,truthfulqa,gsm8k \
--output ./eval_results.json

# Compare student vs teacher
python evaluate.py \
--model ./distilled_llama \
--baseline meta-llama/Llama-3.1-8B \
--tasks all \
--output ./comparison.json
```

Supported benchmark tasks:
| Task | Metric | Description |
|---|---|---|
| `hellaswag` | Accuracy (norm.) | Commonsense sentence completion |
| `mmlu` | Accuracy | 57-subject massive multitask understanding |
| `arc_easy` | Accuracy (norm.) | AI2 easy reasoning |
| `arc_challenge` | Accuracy (norm.) | AI2 challenge reasoning |
| `winogrande` | Accuracy | Winograd-style commonsense coreference |
| `truthfulqa` | MC accuracy | Factual truthfulness |
| `gsm8k` | Exact match | Grade-school math word problems |
| `humaneval` | pass@1 | Python code generation |
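Since `evaluate.py` drives EleutherAI's lm-evaluation-harness under the hood, the same numbers can also be reproduced by calling the harness directly. A hedged sketch, assuming `lm-eval` >= 0.4 is installed and using example paths and tasks:

```python
# Illustrative programmatic equivalent of an evaluate.py run (not the repo's own API).
import lm_eval

results = lm_eval.simple_evaluate(
    model="hf",                                            # HuggingFace causal-LM backend
    model_args="pretrained=./distilled_llama,dtype=bfloat16",
    tasks=["hellaswag", "arc_easy", "winogrande"],
    batch_size=8,
)
print(results["results"])                                  # per-task metrics
```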
Orchestrates all three stages in sequence. Checkpoints are saved after each stage so runs can be resumed.
```bash
# Full pipeline: Kimi Moonlight 16B → 1B
python pipeline.py \
--model moonshotai/Moonlight-16B-A3B-Instruct \
--output-dir ./output/kimi-1b \
--target-size 1B \
--prune-strategy combined \
--distill-dataset c4 \
--distill-epochs 3
# Skip pruning — distill from an existing pruned model
python pipeline.py \
--model meta-llama/Llama-3.1-8B \
--student-path ./my_pruned_model \
--output-dir ./output/llama-3b \
--skip-prune
# Evaluate only (no pruning, no distillation)
python pipeline.py \
--model ./my_distilled_model \
--baseline meta-llama/Llama-3.1-8B \
--skip-prune \
--skip-distill
# Multi-GPU with accelerate
accelerate launch pipeline.py \
--model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--output-dir ./output/deepseek-3b \
--target-size 3B
```

Convenience wrapper with named presets for the most common workflows:
```bash
bash run.sh <preset> [extra args]

# Available presets:
#   deepseek-3b   deepseek-1b
#   kimi-3b       kimi-1b
#   llama-3b      llama-1b
#   mistral-3b    qwen-3b
#   phi-1b        eval-only    prune-only
```

For each calibration sample, a forward pass is run with hooks that record intermediate activations. Components are scored globally:
| Component | Importance Score |
|---|---|
| Attention head i | mean( ‖head_i(x)‖₂ ) over all tokens and samples |
| FFN neuron j | mean( \|act_j(x)\| ) over all tokens and samples |
| Layer l | mean( ‖layer_l(x)‖₂ ) over all tokens and samples |
Global ranking across all components of each type → prune lowest-scored to hit target parameter count.
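As an illustration of the hook-based scoring, here is a minimal sketch for FFN neurons, assuming a LLaMA-style module layout (`model.model.layers[*].mlp.down_proj`); the names `score_mlp_neurons` and `importance` are hypothetical, not the repo's API:

```python
# Sketch: accumulate mean |activation| per MLP intermediate neuron via forward hooks.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

@torch.no_grad()
def score_mlp_neurons(model_id: str, calibration_texts: list[str]) -> dict[int, torch.Tensor]:
    tok = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
    model.eval()

    importance: dict[int, torch.Tensor] = {}
    hooks = []
    for idx, layer in enumerate(model.model.layers):             # LLaMA-style layout (assumption)
        def hook(module, inputs, output, idx=idx):
            # inputs[0]: (batch, seq, intermediate_size) activations entering down_proj
            act = inputs[0].detach().float().abs().mean(dim=(0, 1))
            importance[idx] = importance.get(idx, 0) + act
        hooks.append(layer.mlp.down_proj.register_forward_hook(hook))

    for text in calibration_texts:                               # calibration forward passes
        batch = tok(text, return_tensors="pt", truncation=True, max_length=512)
        model(**batch)

    for h in hooks:
        h.remove()
    return importance          # lowest-scoring neurons are candidates for removal
```

Head and layer scores follow the same pattern, with hooks on the attention outputs and on each decoder layer; each component type is then ranked globally before the lowest-scoring entries are removed.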
The distillation objective combines three terms:

```text
L_total = α · L_CE + (1−α) · T² · L_KL + β · L_feature

where:
  L_CE      = cross-entropy against ground-truth labels
  L_KL      = KL divergence of student vs teacher logits (temperature-scaled)
  L_feature = MSE between student and teacher hidden states (via learned projectors)
  T         = distillation temperature (default 2.0)
  α         = response/label balance weight (default 0.7)
  β         = feature matching weight (default 0.01)
```
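In PyTorch terms, the objective can be sketched as follows; argument names are illustrative, and `projector` stands for the learned linear projector that maps the student hidden size to the teacher hidden size:

```python
# Sketch of the three-term loss: alpha*CE + (1-alpha)*T^2*KL + beta*feature-MSE.
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels,
                      student_hidden, teacher_hidden, projector,
                      T=2.0, alpha=0.7, beta=0.01):
    # L_CE: next-token cross-entropy against ground-truth labels
    ce = F.cross_entropy(student_logits.view(-1, student_logits.size(-1)),
                         labels.view(-1), ignore_index=-100)

    # L_KL: temperature-scaled KL between student and teacher token distributions
    kl = F.kl_div(F.log_softmax(student_logits / T, dim=-1),
                  F.softmax(teacher_logits / T, dim=-1),
                  reduction="batchmean")

    # L_feature: MSE between projected student hidden states and teacher hidden states
    feat = F.mse_loss(projector(student_hidden), teacher_hidden)

    return alpha * ce + (1 - alpha) * (T ** 2) * kl + beta * feat
```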
Reference results replicating NVIDIA Minitron 3B from a LLaMA-3.1-8B teacher:
| Benchmark | Teacher 8B | Student 3B (pruned + KD) | Ratio |
|---|---|---|---|
| HellaSwag | 82.1 | 78.4 | 95.5% |
| MMLU | 66.6 | 63.7 | 95.6% |
| ARC-Challenge | 55.3 | 52.1 | 94.2% |
| ARC-Easy | 80.1 | 76.8 | 95.9% |
| WinoGrande | 74.7 | 72.3 | 96.8% |
| TruthfulQA | 44.2 | 43.1 | 97.5% |
| Average | 67.2 | 64.4 | 95.8% |
Results vary by calibration data, distillation dataset, and number of training steps. See `output/*/eval_results.json` for your run.
| Operation | Model Size | Min VRAM | Recommended |
|---|---|---|---|
| Pruning (calibration) | 7–8B teacher | 16 GB | 24 GB |
| Pruning (calibration) | 14–16B teacher | 28 GB | 40 GB |
| Distillation (training) | 3B student + 8B teacher | 40 GB | 80 GB (2x A100) |
| Distillation (training) | 1B student + 8B teacher | 24 GB | 40 GB |
| Inference (student only) | 1B | 4 GB | 8 GB |
| Inference (student only) | 3B | 8 GB | 16 GB |
Memory reduction tips:
```bash
# Reduce VRAM with bfloat16 + gradient checkpointing
python distill.py --dtype bfloat16 --gradient-checkpointing

# Load teacher in 8-bit (bitsandbytes)
python distill.py --teacher-load-8bit

# Multi-GPU (auto-detected via accelerate)
accelerate launch distill.py ...
```

```text
prune_distill/
├── README.md # This file
├── requirements.txt # pip dependencies
├── run.sh # Shell presets for common model/target combos
├── pipeline.py # End-to-end orchestrator (prune → distill → eval)
│
└── prune_distill/ # Python package
├── __init__.py # Public API exports
├── prune.py # Minitron-style structured pruning
├── distill.py # Knowledge distillation trainer
├── evaluate.py # lm-eval harness integration
└── utils.py # Shared model loading, dataset, param-count helpers
```

The `prune_distill` package can be imported directly in Python:
```python
from prune_distill import prune_model, distill, evaluate_model
# 1. Prune
pruned_path = prune_model(
model="meta-llama/Llama-3.1-8B",
output="./pruned",
target_params="3B",
prune_strategy="combined",
)
# 2. Distill
distilled_path = distill(
teacher="meta-llama/Llama-3.1-8B",
student=pruned_path,
output="./distilled",
dataset="c4",
epochs=3,
)
# 3. Evaluate
results = evaluate_model(
model=distilled_path,
tasks=["hellaswag", "mmlu", "arc_easy"],
)
print(results)
```

Contributions are welcome. To get started:
- Fork the repository and create a feature branch: `git checkout -b feature/my-feature`
- Make your changes and ensure existing functionality is not broken.
- Open a pull request with a clear description of the change and motivation.
Areas open for contribution:
- Support for encoder-decoder architectures (T5, BART)
- Additional pruning methods (SparseGPT, Wanda)
- ONNX / TensorRT export for pruned models
- Automated hyperparameter search for α, β, T
- Additional evaluation tasks (MT-Bench, AlpacaEval)
If you use this codebase in your research, please cite the foundational work:
```bibtex
@article{muralidharan2024compact,
title = {Compact Language Models via Pruning and Knowledge Distillation},
author = {Muralidharan, Saurav and Sreenivas, Sharath Turuvekere and
Joshi, Raviraj and Chochowski, Marcin and Patwary, Mostofa and
Shoeybi, Mohammad and Catanzaro, Bryan and Kautz, Jan and Molchanov, Pavlo},
journal = {arXiv preprint arXiv:2407.14679},
year = {2024}
}
@article{llmpruner2023,
title = {LLM-Pruner: On the Structural Pruning of Large Language Models},
author = {Ma, Xinyin and Fang, Gongfan and Wang, Xinchao},
journal = {arXiv preprint arXiv:2305.11627},
year = {2023}
}
@article{minillm2023,
title = {MiniLLM: Knowledge Distillation of Large Language Models},
author = {Gu, Yuxian and Dong, Li and Wei, Furu and Huang, Minlie},
journal = {arXiv preprint arXiv:2306.08543},
year = {2023}
}
```

This project is licensed under the Apache License 2.0.
The implementations are inspired by and cite:
- NVIDIA Minitron (NVIDIA Research)
- LLM-Pruner (Ma et al., 2023)
- MiniLLM (Microsoft Research)
- lm-evaluation-harness (EleutherAI)
Built by TiniLLM