
prune_distill — LLM → SLM via Structured Pruning + Knowledge Distillation

License: Apache 2.0 · Python 3.10+ · PyTorch 2.2+ · HuggingFace · arXiv: Minitron

Convert LLMs (7B–70B) into efficient Small Language Models (1B–3B) using activation-based structured pruning and knowledge distillation — achieving ~95% of teacher accuracy at 40–80% fewer parameters.

Implements the NVIDIA Minitron pruning strategy combined with response- and feature-based knowledge distillation, with plug-and-play support for DeepSeek, Kimi, LLaMA, Mistral, Qwen, Gemma, and Phi model families.


Table of Contents

  • What's New
  • Overview
  • Key Features
  • Supported Models
  • Installation
  • Quick Start
  • Detailed Usage
  • Algorithm Details
  • Benchmark Results
  • Hardware Requirements
  • Project Structure
  • Module API
  • Contributing
  • Citation
  • License

What's New

  • [2025-03] Initial release — full prune + distill + eval pipeline
  • [2025-03] Shell presets added for 10 model/target combinations (run.sh)
  • [2025-03] GQA (Grouped-Query Attention) support for LLaMA-3 and Mistral models
  • [2025-03] Multi-GPU support via accelerate + deepspeed auto-detection

Overview

┌──────────────────────────────────────────────────────────────────┐
│                    prune_distill Pipeline                        │
│                                                                  │
│   Teacher LLM (7B-70B)                                           │
│        │                                                         │
│        ▼                                                         │
│   ┌─────────────────────────────┐                                │
│   │  Step 1: Structured Pruning │  (Minitron-style)              │
│   │  • Width: remove heads/MLP  │                                │
│   │  • Depth: remove layers     │                                │
│   │  • Importance: activations  │                                │
│   └──────────────┬──────────────┘                                │
│                  │ Pruned Model (1B-3B)                          │
│                  ▼                                               │
│   ┌──────────────────────────────┐                               │
│   │  Step 2: Knowledge Distill.  │                               │
│   │  • Response KL divergence    │                               │
│   │  • Feature MSE matching      │                               │
│   │  • L = αCE + (1-α)T²KL + βF │                               │
│   └──────────────┬───────────────┘                               │
│                  │ Distilled Model                               │
│                  ▼                                               │
│   ┌──────────────────────────────┐                               │
│   │  Step 3: Evaluation          │                               │
│   │  HellaSwag / MMLU / ARC /    │                               │
│   │  WinoGrande / TruthfulQA /   │                               │
│   │  GSM8K / HumanEval           │                               │
│   └──────────────────────────────┘                               │
│                                                                  │
│   Student SLM (1B-3B) — ~95% accuracy at 40-80% fewer params    │
└──────────────────────────────────────────────────────────────────┘

Key Features

  • Minitron-style importance scoring — activation-magnitude scoring for attention heads, MLP neurons, and transformer layers; globally ranked, no per-layer heuristics
  • Three pruning strategies — width (heads + neurons), depth (layers), or combined — selectable at runtime
  • GQA-aware pruning — correctly handles Grouped-Query Attention in LLaMA-3, Mistral, Qwen, and Phi (see the sketch after this list)
  • Three-term distillation loss — α·L_CE + (1−α)·T²·L_KL + β·L_feature with learned hidden-state projectors
  • Calibration datasets — WikiText-2, C4, PTB, or OpenWebText for importance scoring
  • Training datasets — C4, SlimPajama, OpenWebText, Alpaca, Dolly for distillation
  • Seven benchmark tasks — HellaSwag, MMLU, ARC-Easy/Challenge, WinoGrande, TruthfulQA, GSM8K, HumanEval via lm-eval
  • Multi-GPU / ZeRO — auto-detects accelerate + DeepSpeed ZeRO-2/3 configs
  • Memory efficiency — bfloat16, gradient checkpointing, and 8-bit quantization (bitsandbytes)
  • WandB + TensorBoard — integrated training metrics and eval logging
  • Modular design — each script is independently importable and CLI-runnable
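
Below is a minimal sketch of the GQA-aware idea referenced in the list above. It is illustrative only — the function name, shapes, and head counts are assumptions, not the repo's actual API. The point is that query heads are scored and removed in whole groups that share a KV head, so the query-to-KV ratio stays valid after pruning:

import torch

def select_gqa_head_groups(head_scores, num_kv_heads, keep_groups):
    """Return indices of query heads and KV heads to keep (illustrative).

    head_scores : (num_q_heads,) importance score per query head
    num_kv_heads: number of KV heads (= number of GQA groups)
    keep_groups : number of groups to keep after pruning
    """
    num_q_heads = head_scores.numel()
    group_size = num_q_heads // num_kv_heads            # query heads per KV head
    group_scores = head_scores.view(num_kv_heads, group_size).sum(dim=1)
    keep_kv = torch.topk(group_scores, keep_groups).indices.sort().values
    keep_q = torch.cat([
        torch.arange(g * group_size, (g + 1) * group_size)
        for g in keep_kv.tolist()
    ])
    return keep_q, keep_kv

# Example: 32 query heads sharing 8 KV heads (LLaMA-3-style), keep 6 of 8 groups
scores = torch.rand(32)
q_idx, kv_idx = select_gqa_head_groups(scores, num_kv_heads=8, keep_groups=6)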

Supported Models

Any HuggingFace decoder-only model with LLaMA-style architecture is supported out of the box:

Model Family HuggingFace IDs Architecture
DeepSeek deepseek-ai/DeepSeek-R1-Distill-Llama-8B · deepseek-ai/deepseek-coder-7b-base LLaMA-3
Kimi (Moonshot) moonshotai/Moonlight-16B-A3B-Instruct · moonshotai/Kimi-VL-A3B-Instruct MoE/LLaMA
LLaMA meta-llama/Llama-3.1-8B · meta-llama/Llama-3.2-3B · meta-llama/Llama-3.1-70B LLaMA-3
Mistral mistralai/Mistral-7B-v0.3 · mistralai/Mixtral-8x7B-v0.1 Mistral
Qwen Qwen/Qwen2.5-7B · Qwen/Qwen2.5-14B · Qwen/Qwen2.5-72B Qwen2
Gemma google/gemma-2-9b · google/gemma-7b Gemma
Phi microsoft/phi-4 · microsoft/Phi-3-medium-4k-instruct Phi

Installation

Requirements: Python 3.10+, CUDA 11.8+ (for GPU training)

# Clone the repository
git clone https://github.com/TiniLLM/prune_distill.git
cd prune_distill

# Install all dependencies
pip install -r requirements.txt

# (Optional) Flash Attention 2 for faster training — requires CUDA 11.6+
pip install flash-attn --no-build-isolation

Note: For multi-GPU training, also run:

accelerate config    # configure DeepSpeed / FSDP

Quick Start

Full Pipeline: DeepSeek 8B → 3B Student

python pipeline.py \
  --model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
  --output-dir ./output/deepseek-3b \
  --target-size 3B \
  --prune-strategy combined \
  --distill-dataset c4 \
  --distill-epochs 3 \
  --eval-tasks hellaswag,mmlu,arc_easy,arc_challenge,winogrande

Shell Presets — One Command

# DeepSeek → 3B
bash run.sh deepseek-3b

# LLaMA 3.1 8B → 1B
bash run.sh llama-1b

# Kimi Moonlight → 3B
bash run.sh kimi-3b

# Evaluate an existing model
bash run.sh eval-only --model ./my_model

# See all presets
bash run.sh --help

Detailed Usage

1. Structured Pruning (prune.py)

Runs calibration data through the teacher, scores all attention heads, MLP neurons, and transformer layers by activation importance, then removes the lowest-scoring components to hit the target parameter count.

python prune.py \
  --model meta-llama/Llama-3.1-8B \
  --output ./pruned_llama \
  --target-params 3B \
  --prune-strategy combined \
  --calibration-dataset wikitext \
  --calibration-samples 512 \
  --dtype bfloat16

Key arguments:

Argument Default Description
--model (required) HuggingFace model ID or local path
--output ./pruned_model Directory to save the pruned model
--target-params 3B Target size: 1B, 1.5B, 3B, or exact count
--prune-strategy combined width / depth / combined
--calibration-dataset wikitext wikitext / c4 / ptb / openwebtext
--calibration-samples 512 Number of calibration sequences
--head-prune-ratio auto Override attention head pruning ratio
--layer-prune-ratio auto Override transformer layer pruning ratio
--dtype bfloat16 float32 / float16 / bfloat16
--token None HuggingFace token for gated models

2. Knowledge Distillation (distill.py)

Trains the pruned student to mimic the full teacher using a combined loss of cross-entropy, KL divergence on logits (with temperature), and MSE on intermediate hidden states.

python distill.py \
  --teacher meta-llama/Llama-3.1-8B \
  --student  ./pruned_llama \
  --output   ./distilled_llama \
  --dataset  c4 \
  --epochs   3 \
  --temperature 2.0 \
  --alpha 0.7 \
  --beta 0.01 \
  --batch-size 4 \
  --gradient-accumulation 8

Key arguments:

Argument Default Description
--teacher (required) Teacher model ID or local path
--student (required) Student (pruned) model ID or local path
--output ./distilled_model Output directory
--dataset c4 c4 / wikitext / openwebtext / alpaca / dolly / slim_pajama
--epochs 3 Training epochs
--batch-size 2 Per-device batch size
--gradient-accumulation 8 Gradient accumulation steps
--temperature 2.0 Distillation temperature T
--alpha 0.7 Weight for L_CE (1−alpha weights L_KL)
--beta 0.01 Weight for feature matching loss L_feature
--learning-rate 1e-4 AdamW learning rate
--warmup-ratio 0.05 LR warmup fraction
--gradient-checkpointing False Enable for lower VRAM usage
--wandb-project None WandB project name

3. Evaluation (evaluate.py)

Runs the student (and optionally the teacher as baseline) through the lm-evaluation-harness suite and outputs a comparison table.

# Evaluate a single model
python evaluate.py \
  --model ./distilled_llama \
  --tasks hellaswag,mmlu,arc_easy,arc_challenge,winogrande,truthfulqa,gsm8k \
  --output ./eval_results.json

# Compare student vs teacher
python evaluate.py \
  --model ./distilled_llama \
  --baseline meta-llama/Llama-3.1-8B \
  --tasks all \
  --output ./comparison.json

Supported benchmark tasks:

Task Metric Description
hellaswag Accuracy (norm.) Commonsense sentence completion
mmlu Accuracy 57-subject massive multitask understanding
arc_easy Accuracy (norm.) AI2 easy reasoning
arc_challenge Accuracy (norm.) AI2 challenge reasoning
winogrande Accuracy Commonsense NLI
truthfulqa MC accuracy Factual truthfulness
gsm8k Exact match Grade-school math word problems
humaneval pass@1 Python code generation

4. End-to-End Pipeline (pipeline.py)

Orchestrates all three stages in sequence. Checkpoints are saved after each stage so runs can be resumed.

# Full pipeline: Kimi Moonlight 16B → 1B
python pipeline.py \
  --model moonshotai/Moonlight-16B-A3B-Instruct \
  --output-dir ./output/kimi-1b \
  --target-size 1B \
  --prune-strategy combined \
  --distill-dataset c4 \
  --distill-epochs 3

# Skip pruning — distill from an existing pruned model
python pipeline.py \
  --model meta-llama/Llama-3.1-8B \
  --student-path ./my_pruned_model \
  --output-dir   ./output/llama-3b \
  --skip-prune

# Evaluate only (no pruning, no distillation)
python pipeline.py \
  --model    ./my_distilled_model \
  --baseline meta-llama/Llama-3.1-8B \
  --skip-prune \
  --skip-distill

# Multi-GPU with accelerate
accelerate launch pipeline.py \
  --model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
  --output-dir ./output/deepseek-3b \
  --target-size 3B

5. Shell Presets (run.sh)

Convenience wrapper with named presets for the most common workflows:

bash run.sh <preset> [extra args]

# Available presets:
#   deepseek-3b   deepseek-1b
#   kimi-3b       kimi-1b
#   llama-3b      llama-1b
#   mistral-3b    qwen-3b
#   phi-1b        eval-only    prune-only

Algorithm Details

Importance Scoring (Minitron-style)

For each calibration sample, a forward pass is run with hooks that record intermediate activations. Components are scored globally:

Component Importance Score
Attention head i mean( ‖head_i(x)‖₂ ) over all tokens and samples
FFN neuron j mean( |neuron_j(x)| ) over all tokens and samples
Layer l mean( ‖layer_l(x)‖₂ ) over all tokens and samples

Global ranking across all components of each type → prune lowest-scored to hit target parameter count.
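
A minimal sketch of the hook-based scoring idea, assuming a LLaMA-style MLP whose activation output has shape (batch, seq, intermediate_size). Function and argument names are illustrative, not the actual prune.py internals:

import torch

def score_mlp_neurons(model, mlp_act_module, calib_loader, device="cuda"):
    """Illustrative activation-magnitude scoring for one MLP's hidden neurons."""
    scores = None

    def hook(_module, _inputs, output):
        nonlocal scores
        # (batch, seq, intermediate_size) -> mean |activation| per neuron
        batch_scores = output.detach().abs().mean(dim=(0, 1)).float()
        scores = batch_scores if scores is None else scores + batch_scores

    handle = mlp_act_module.register_forward_hook(hook)
    model.eval()
    with torch.no_grad():
        for batch in calib_loader:
            model(input_ids=batch["input_ids"].to(device))
    handle.remove()
    return scores / len(calib_loader)  # higher = more important; prune the lowest

The same pattern extends to attention heads and whole layers by hooking their outputs, taking the mean L2 norm over tokens and samples, and then ranking all components of each type globally.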

Distillation Loss

L_total = α · L_CE  +  (1−α) · T² · L_KL  +  β · L_feature

where:
  L_CE      = cross-entropy against ground-truth labels
  L_KL      = KL divergence of student vs teacher logits (temperature-scaled)
  L_feature = MSE between student and teacher hidden states (via learned projectors)
  T         = distillation temperature (default 2.0)
  α         = response/label balance weight (default 0.7)
  β         = feature matching weight (default 0.01)
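
A minimal PyTorch sketch of this loss, assuming the student and teacher share a tokenizer/vocabulary and that projector is a learned linear layer mapping the student hidden size to the teacher's. Names are illustrative, not the distill.py API:

import torch
import torch.nn.functional as F

def distill_loss(student_logits, teacher_logits, labels,
                 student_hidden, teacher_hidden, projector,
                 T=2.0, alpha=0.7, beta=0.01):
    vocab = student_logits.size(-1)

    # Hard-label cross-entropy against ground-truth tokens
    ce = F.cross_entropy(student_logits.view(-1, vocab), labels.view(-1))

    # Temperature-scaled KL divergence between student and teacher distributions
    kl = F.kl_div(
        F.log_softmax(student_logits.view(-1, vocab) / T, dim=-1),
        F.softmax(teacher_logits.view(-1, vocab) / T, dim=-1),
        reduction="batchmean",
    )

    # Feature matching: project student hidden states to the teacher's width
    feat = F.mse_loss(projector(student_hidden), teacher_hidden)

    return alpha * ce + (1 - alpha) * T**2 * kl + beta * feat

# e.g. projector = torch.nn.Linear(student_hidden_size, teacher_hidden_size)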

Benchmark Results

Reference results replicating NVIDIA Minitron 3B from a LLaMA-3.1-8B teacher:

Benchmark Teacher 8B Student 3B (pruned + KD) Ratio
HellaSwag 82.1 78.4 95.5%
MMLU 66.6 63.7 95.6%
ARC-Challenge 55.3 52.1 94.2%
ARC-Easy 80.1 76.8 95.9%
WinoGrande 74.7 72.3 96.8%
TruthfulQA 44.2 43.1 97.5%
Average 67.2 64.4 95.8%

Results vary by calibration data, distillation dataset, and number of training steps. See output/*/eval_results.json for your run.


Hardware Requirements

Operation Model Size Min VRAM Recommended
Pruning (calibration) 7–8B teacher 16 GB 24 GB
Pruning (calibration) 14–16B teacher 28 GB 40 GB
Distillation (training) 3B student + 8B teacher 40 GB 80 GB (2x A100)
Distillation (training) 1B student + 8B teacher 24 GB 40 GB
Inference (student only) 1B 4 GB 8 GB
Inference (student only) 3B 8 GB 16 GB

Memory reduction tips:

# Reduce VRAM with bfloat16 + gradient checkpointing
python distill.py --dtype bfloat16 --gradient-checkpointing

# Load teacher in 8-bit (bitsandbytes)
python distill.py --teacher-load-8bit

# Multi-GPU (auto-detected via accelerate)
accelerate launch distill.py ...

Project Structure

prune_distill/
├── README.md               # This file
├── requirements.txt        # pip dependencies
├── run.sh                  # Shell presets for common model/target combos
├── pipeline.py             # End-to-end orchestrator (prune → distill → eval)
│
└── prune_distill/          # Python package
    ├── __init__.py         # Public API exports
    ├── prune.py            # Minitron-style structured pruning
    ├── distill.py          # Knowledge distillation trainer
    ├── evaluate.py         # lm-eval harness integration
    └── utils.py            # Shared model loading, dataset, param-count helpers

Module API

The prune_distill package can be imported directly in Python:

from prune_distill import prune_model, distill, evaluate_model

# 1. Prune
pruned_path = prune_model(
    model="meta-llama/Llama-3.1-8B",
    output="./pruned",
    target_params="3B",
    prune_strategy="combined",
)

# 2. Distill
distilled_path = distill(
    teacher="meta-llama/Llama-3.1-8B",
    student=pruned_path,
    output="./distilled",
    dataset="c4",
    epochs=3,
)

# 3. Evaluate
results = evaluate_model(
    model=distilled_path,
    tasks=["hellaswag", "mmlu", "arc_easy"],
)
print(results)

Contributing

Contributions are welcome. To get started:

  1. Fork the repository and create a feature branch:
    git checkout -b feature/my-feature
  2. Make your changes and ensure existing functionality is not broken.
  3. Open a pull request with a clear description of the change and motivation.

Areas open for contribution:

  • Support for encoder-decoder architectures (T5, BART)
  • Additional pruning methods (SparseGPT, Wanda)
  • ONNX / TensorRT export for pruned models
  • Automated hyperparameter search for α, β, T
  • Additional evaluation tasks (MT-Bench, AlpacaEval)

Citation

If you use this codebase in your research, please cite the foundational work:

@article{muralidharan2024compact,
  title   = {Compact Language Models via Pruning and Knowledge Distillation},
  author  = {Muralidharan, Saurav and Sreenivas, Sharath Turuvekere and
             Joshi, Raviraj and Chochowski, Marcin and Patwary, Mostofa and
             Shoeybi, Mohammad and Catanzaro, Bryan and Kautz, Jan and Molchanov, Pavlo},
  journal = {arXiv preprint arXiv:2407.14679},
  year    = {2024}
}

@article{llmpruner2023,
  title   = {LLM-Pruner: On the Structural Pruning of Large Language Models},
  author  = {Ma, Xinyin and Fang, Gongfan and Wang, Xinchao},
  journal = {arXiv preprint arXiv:2305.11627},
  year    = {2023}
}

@article{minillm2023,
  title   = {MiniLLM: Knowledge Distillation of Large Language Models},
  author  = {Gu, Yuxian and Dong, Li and Wei, Furu and Huang, Minlie},
  journal = {arXiv preprint arXiv:2306.08543},
  year    = {2023}
}

License

This project is licensed under the Apache License 2.0.

The implementation is inspired by the Minitron, LLM-Pruner, and MiniLLM works cited above.


Built by TiniLLM
