LoPT is a lightweight training recipe for localized post-training of decoder-only language models.
- ✅ 💻 Release training code
- ✅ 📄 Release paper on arXiv
- SFT with E2E and LoPT training.
- GRPO with Hugging Face TRL on-policy rollouts.
- Default two-block LoPT training with `--num_blocks 2`.
- Four-block LoPT training with `--num_blocks 4`.
- Checkpoints are saved as standard Hugging Face causal LM models.
Standard E2E post-training backpropagates the task loss through all layers. LoPT
splits the transformer into k contiguous blocks. The final block receives the
task loss, while earlier blocks receive local feature reconstruction losses.
Hidden states are detached at block boundaries, so task-loss gradients do not
propagate through the entire model.
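The boundary mechanics can be sketched with a toy two-block model (illustrative only, not the repo's actual code): the earlier block gets a local reconstruction loss through an auxiliary head, the final block gets the task loss, and `detach()` at the boundary stops task-loss gradients from reaching the earlier block.

```python
import torch
import torch.nn as nn

# Toy stand-in for a transformer split into k=2 contiguous blocks.
hidden = 16
block1 = nn.Linear(hidden, hidden)    # earlier block
aux_head = nn.Linear(hidden, hidden)  # auxiliary reconstruction head (training only)
block2 = nn.Linear(hidden, hidden)    # final block, receives the task loss

x = torch.randn(4, hidden)
h1 = torch.relu(block1(x))

# Local loss for block 1: reconstruct the block's input from its output.
aux_loss = nn.functional.mse_loss(aux_head(h1), x)

# Detach at the block boundary: task-loss gradients stop here.
h2 = block2(h1.detach())
task_loss = h2.pow(2).mean()  # placeholder for the real task loss

# The task loss has no gradient path to block 1, while the aux loss does.
task_grad_block1 = torch.autograd.grad(
    task_loss, block1.weight, retain_graph=True, allow_unused=True)[0]
aux_grad_block1 = torch.autograd.grad(aux_loss, block1.weight)[0]
```

In real training the two losses are combined (weighted by `--lambda_aux`) and optimized jointly; the detach is what keeps the backward pass of the task loss confined to the final block.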
The auxiliary reconstruction heads are used only during training. After
training, the saved model can be loaded with `AutoModelForCausalLM` as usual.
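For example, a saved checkpoint loads like any other causal LM (the path here is a hypothetical `--output_dir` value from one of the runs below):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hypothetical checkpoint directory written by train_sft.py.
ckpt = "outputs/qwen25-7b-metamathqa-lopt-k2"
model = AutoModelForCausalLM.from_pretrained(ckpt)
tokenizer = AutoTokenizer.from_pretrained(ckpt)
```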
```bash
git clone https://github.com/HumyuShi/LoPT.git
cd LoPT
conda create -n lopt python=3.11 -y
conda activate lopt
pip install -r requirements.txt
```

If your CUDA driver does not support the default PyTorch wheel resolved by pip, install a matching PyTorch wheel first, then run `pip install -r requirements.txt`. For example:

```bash
pip install --index-url https://download.pytorch.org/whl/cu118 torch==2.6.0+cu118
pip install -r requirements.txt
```

For multi-GPU GRPO, configure Accelerate once:

```bash
accelerate config
```

You can pass any local path or Hugging Face model id for Llama/Qwen/Mistral-like decoder-only models that expose `model.layers` in `AutoModelForCausalLM`. The following aliases are built in:
| Alias | Hugging Face model |
|---|---|
| `qwen3-4b` | `Qwen/Qwen3-4B` |
| `qwen2.5-7b` | `Qwen/Qwen2.5-7B-Instruct` |
| `qwen2.5-32b` | `Qwen/Qwen2.5-32B-Instruct` |
| `llama3.1-8b` | `meta-llama/Llama-3.1-8B-Instruct` |
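The alias lookup can be pictured as a simple mapping that mirrors the table above (a sketch; the repo's actual resolution logic may differ):

```python
# Built-in aliases from the table above, mapped to Hugging Face model ids.
MODEL_ALIASES = {
    "qwen3-4b": "Qwen/Qwen3-4B",
    "qwen2.5-7b": "Qwen/Qwen2.5-7B-Instruct",
    "qwen2.5-32b": "Qwen/Qwen2.5-32B-Instruct",
    "llama3.1-8b": "meta-llama/Llama-3.1-8B-Instruct",
}

def resolve_model(name_or_path: str) -> str:
    """Map a known alias to its Hugging Face id; pass local paths
    and full model ids through unchanged."""
    return MODEL_ALIASES.get(name_or_path, name_or_path)
```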
| Preset | Use | Official source |
|---|---|---|
| `metamathqa` | SFT | https://huggingface.co/datasets/meta-math/MetaMathQA |
| `alpaca` | SFT | https://huggingface.co/datasets/tatsu-lab/alpaca |
| `tulu3` | SFT | https://huggingface.co/datasets/allenai/tulu-3-sft-mixture |
| `gsm8k` | GRPO | https://huggingface.co/datasets/openai/gsm8k |
| `numinamath` | GRPO | https://huggingface.co/datasets/AI-MO/NuminaMath-CoT |
Check the original model and dataset licenses before training. Some sources, including Llama-family models and Alpaca, have access or usage restrictions.
The training scripts can load directly from Hugging Face. You can also download and optionally export local parquet files:
```bash
python scripts/download_datasets.py --dataset all
python scripts/download_datasets.py --dataset metamathqa --write_parquet --output_dir data
```

Local parquet or JSON files can be passed with `--train_file`.
For benchmark reporting, we use the official dataset releases below. The table lists dataset sources only.
| Benchmark | Area | Config / split | Official source |
|---|---|---|---|
| MMLU | multitask knowledge | official subject splits | https://huggingface.co/datasets/cais/mmlu |
| IFEval | instruction following | default release | https://huggingface.co/datasets/google/IFEval |
| ARC-Challenge | grade-school science reasoning | ARC-Challenge | https://huggingface.co/datasets/allenai/ai2_arc |
| GSM8K | grade-school math reasoning | main, test | https://huggingface.co/datasets/openai/gsm8k |
| HellaSwag | commonsense completion | default release | https://huggingface.co/datasets/allenai/hellaswag |
| TruthfulQA | truthfulness | official multiple-choice/generation configs | https://huggingface.co/datasets/truthfulqa/truthful_qa |
| WinoGrande | commonsense coreference | official validation/test release | https://huggingface.co/datasets/allenai/winogrande |
MetaMathQA, LoPT k=2:

```bash
torchrun --nproc_per_node 8 train_sft.py \
  --model_name_or_path Qwen/Qwen2.5-7B-Instruct \
  --dataset_preset metamathqa \
  --method ll \
  --num_blocks 2 \
  --max_seq_length 1024 \
  --per_device_train_batch_size 4 \
  --learning_rate 2e-5 \
  --lr_k1 2e-5 \
  --lambda_aux 10 \
  --gradient_checkpointing \
  --output_dir outputs/qwen25-7b-metamathqa-lopt-k2
```

MetaMathQA, LoPT k=4:
```bash
torchrun --nproc_per_node 8 train_sft.py \
  --model_name_or_path Qwen/Qwen2.5-7B-Instruct \
  --dataset_preset metamathqa \
  --method ll \
  --num_blocks 4 \
  --max_seq_length 1024 \
  --per_device_train_batch_size 4 \
  --learning_rate 2e-5 \
  --lr_k1 2e-5 \
  --lambda_aux 10 \
  --gradient_checkpointing \
  --output_dir outputs/qwen25-7b-metamathqa-lopt-k4
```

MetaMathQA, E2E baseline:
```bash
torchrun --nproc_per_node 8 train_sft.py \
  --model_name_or_path Qwen/Qwen2.5-7B-Instruct \
  --dataset_preset metamathqa \
  --method e2e \
  --max_seq_length 1024 \
  --per_device_train_batch_size 4 \
  --learning_rate 2e-5 \
  --gradient_checkpointing \
  --output_dir outputs/qwen25-7b-metamathqa-e2e
```

Alpaca SFT:
```bash
torchrun --nproc_per_node 8 train_sft.py \
  --model_name_or_path Qwen/Qwen2.5-7B-Instruct \
  --dataset_preset alpaca \
  --method ll \
  --num_blocks 2 \
  --max_seq_length 512 \
  --per_device_train_batch_size 4 \
  --learning_rate 2e-5 \
  --lambda_aux 10 \
  --gradient_checkpointing \
  --output_dir outputs/qwen25-7b-alpaca-lopt-k2
```

Tulu-3 SFT:
```bash
torchrun --nproc_per_node 8 train_sft.py \
  --model_name_or_path meta-llama/Llama-3.1-8B-Instruct \
  --dataset_preset tulu3 \
  --method ll \
  --num_blocks 2 \
  --max_seq_length 2048 \
  --per_device_train_batch_size 2 \
  --learning_rate 2e-5 \
  --lambda_aux 10 \
  --gradient_checkpointing \
  --output_dir outputs/llama31-8b-tulu3-lopt-k2
```

GRPO uses TRL's `GRPOTrainer`, so completions are sampled from the current actor during training.
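GRPO scores each sampled completion with a reward function. As an illustration only (this is a hypothetical reward, not necessarily the one the repo uses), a GSM8K-style exact-match reward can compare the last number in each completion against the reference answer:

```python
import re

def gsm8k_reward(completions, answers):
    """Return 1.0 for each completion whose final number matches the
    reference answer, else 0.0 (GSM8K-style exact match)."""
    rewards = []
    for completion, answer in zip(completions, answers):
        # Strip thousands separators, then take the last number as the prediction.
        numbers = re.findall(r"-?\d+(?:\.\d+)?", completion.replace(",", ""))
        pred = numbers[-1] if numbers else None
        rewards.append(1.0 if pred == str(answer) else 0.0)
    return rewards
```

Because the rollouts are on-policy, the reward is evaluated on fresh samples from the current model at every generation step.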
GSM8K, LoPT k=2:

```bash
accelerate launch --num_processes 8 train_grpo.py \
  --model_name_or_path Qwen/Qwen2.5-7B-Instruct \
  --task gsm8k \
  --method ll \
  --num_blocks 2 \
  --per_device_train_batch_size 1 \
  --gradient_accumulation_steps 8 \
  --num_generations 4 \
  --steps_per_generation 1 \
  --num_iterations 1 \
  --loss_type grpo \
  --learning_rate 1e-6 \
  --lr_k1 1e-6 \
  --lambda_aux 10 \
  --gradient_checkpointing \
  --output_dir outputs/qwen25-7b-gsm8k-grpo-lopt-k2
```

GSM8K, LoPT k=4:
```bash
accelerate launch --num_processes 8 train_grpo.py \
  --model_name_or_path Qwen/Qwen2.5-7B-Instruct \
  --task gsm8k \
  --method ll \
  --num_blocks 4 \
  --per_device_train_batch_size 1 \
  --gradient_accumulation_steps 8 \
  --num_generations 4 \
  --steps_per_generation 1 \
  --num_iterations 1 \
  --loss_type grpo \
  --learning_rate 1e-6 \
  --lr_k1 1e-6 \
  --lambda_aux 10 \
  --gradient_checkpointing \
  --output_dir outputs/qwen25-7b-gsm8k-grpo-lopt-k4
```

GSM8K, E2E baseline:
```bash
accelerate launch --num_processes 8 train_grpo.py \
  --model_name_or_path Qwen/Qwen2.5-7B-Instruct \
  --task gsm8k \
  --method e2e \
  --per_device_train_batch_size 1 \
  --gradient_accumulation_steps 8 \
  --num_generations 4 \
  --steps_per_generation 1 \
  --num_iterations 1 \
  --loss_type grpo \
  --learning_rate 1e-6 \
  --gradient_checkpointing \
  --output_dir outputs/qwen25-7b-gsm8k-grpo-e2e
```

NuminaMath GRPO:
```bash
accelerate launch --num_processes 8 train_grpo.py \
  --model_name_or_path Qwen/Qwen2.5-7B-Instruct \
  --task numinamath \
  --method ll \
  --num_blocks 2 \
  --per_device_train_batch_size 1 \
  --gradient_accumulation_steps 8 \
  --num_generations 4 \
  --max_completion_length 512 \
  --steps_per_generation 1 \
  --num_iterations 1 \
  --loss_type grpo \
  --learning_rate 1e-6 \
  --lambda_aux 10 \
  --gradient_checkpointing \
  --output_dir outputs/qwen25-7b-numinamath-grpo-lopt-k2
```

Use small samples and a small model before full training:
```bash
CUDA_VISIBLE_DEVICES=0 python train_sft.py \
  --model_name_or_path HuggingFaceTB/SmolLM2-135M-Instruct \
  --dataset_preset alpaca \
  --method ll \
  --num_blocks 2 \
  --max_seq_length 64 \
  --max_samples 16 \
  --max_steps 2 \
  --per_device_train_batch_size 1 \
  --output_dir outputs/smoke-sft-lopt
```
```bash
CUDA_VISIBLE_DEVICES=0 python train_grpo.py \
  --model_name_or_path HuggingFaceTB/SmolLM2-135M-Instruct \
  --task gsm8k \
  --method ll \
  --num_blocks 2 \
  --max_samples 8 \
  --max_steps 1 \
  --num_generations 2 \
  --gradient_accumulation_steps 2 \
  --max_prompt_length 128 \
  --max_completion_length 64 \
  --steps_per_generation 2 \
  --num_iterations 1 \
  --loss_type grpo \
  --per_device_train_batch_size 1 \
  --output_dir outputs/smoke-grpo-lopt
```