LoPT: Localized Post-Training for LLMs

LoPT is a lightweight training recipe for localized post-training of decoder-only language models. It supports standard end-to-end (E2E) post-training and localized learning for supervised fine-tuning (SFT) and on-policy GRPO.

English | 中文

🗺️ Plan

  • ✅ 💻 Release training code
  • ✅ 📄 Release paper to arXiv

✨ Highlights

  • Supervised fine-tuning (SFT) with both end-to-end (E2E) and LoPT training.
  • GRPO with Hugging Face TRL on-policy rollouts.
  • Two-block LoPT training by default (--num_blocks 2).
  • Four-block LoPT training (--num_blocks 4).
  • Checkpoints saved as standard Hugging Face causal LM models.

🧠 Method

Standard E2E post-training backpropagates the task loss through all layers. LoPT splits the transformer into k contiguous blocks. The final block receives the task loss, while earlier blocks receive local feature reconstruction losses. Hidden states are detached at block boundaries, so task-loss gradients do not propagate through the entire model.

The auxiliary reconstruction heads are used only during training. After training, the saved model can be loaded with AutoModelForCausalLM as usual.
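The block-wise idea above can be sketched in a few lines of PyTorch. This is a minimal illustration with made-up shapes and linear "layers"; the reconstruction target (the block input) and the aux-head design are assumptions for the sketch, not the repo's actual implementation:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d, k = 16, 2
layers = nn.ModuleList([nn.Linear(d, d) for _ in range(4)])
blocks = [layers[:2], layers[2:]]                     # k=2 contiguous blocks
aux_heads = [nn.Linear(d, d) for _ in range(k - 1)]   # training-only heads

x = torch.randn(8, d)
target = torch.randn(8, d)

h = x
total_loss = torch.zeros(())
for i, block in enumerate(blocks):
    for layer in block:
        h = layer(h)
    if i < k - 1:
        # earlier block: local feature-reconstruction loss
        total_loss = total_loss + nn.functional.mse_loss(aux_heads[i](h), x)
        h = h.detach()  # task-loss gradients stop at the block boundary
    else:
        # final block: the task loss
        total_loss = total_loss + nn.functional.mse_loss(h, target)

total_loss.backward()
# Because of the detach, parameters in block 0 receive gradients only
# from their local loss, never from the task loss.
```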

⚙️ Installation

git clone https://github.com/HumyuShi/LoPT.git
cd LoPT

conda create -n lopt python=3.11 -y
conda activate lopt

pip install -r requirements.txt

If your CUDA driver does not support the default PyTorch wheel resolved by pip, install a matching PyTorch wheel first, then run pip install -r requirements.txt. For example:

pip install --index-url https://download.pytorch.org/whl/cu118 torch==2.6.0+cu118
pip install -r requirements.txt

For multi-GPU GRPO, configure Accelerate once:

accelerate config

🤖 Supported Models

You can pass any local path or Hugging Face model id for a Llama/Qwen/Mistral-style decoder-only model whose layers are exposed as model.layers in AutoModelForCausalLM. The following aliases are built in:

| Alias | Hugging Face model |
| --- | --- |
| qwen3-4b | Qwen/Qwen3-4B |
| qwen2.5-7b | Qwen/Qwen2.5-7B-Instruct |
| qwen2.5-32b | Qwen/Qwen2.5-32B-Instruct |
| llama3.1-8b | meta-llama/Llama-3.1-8B-Instruct |
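Conceptually, alias resolution is a lookup applied before model loading. A hypothetical helper built from the table above (the repo's actual code may differ):

```python
# Hypothetical alias-resolution helper; names taken from the table above.
MODEL_ALIASES = {
    "qwen3-4b": "Qwen/Qwen3-4B",
    "qwen2.5-7b": "Qwen/Qwen2.5-7B-Instruct",
    "qwen2.5-32b": "Qwen/Qwen2.5-32B-Instruct",
    "llama3.1-8b": "meta-llama/Llama-3.1-8B-Instruct",
}

def resolve_model(name_or_path: str) -> str:
    """Map a built-in alias to its Hugging Face model id; pass local
    paths and full model ids through unchanged."""
    return MODEL_ALIASES.get(name_or_path, name_or_path)

print(resolve_model("qwen2.5-7b"))       # → Qwen/Qwen2.5-7B-Instruct
print(resolve_model("./my-checkpoint"))  # → ./my-checkpoint
```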

📚 Training Datasets

| Preset | Use | Official source |
| --- | --- | --- |
| metamathqa | SFT | https://huggingface.co/datasets/meta-math/MetaMathQA |
| alpaca | SFT | https://huggingface.co/datasets/tatsu-lab/alpaca |
| tulu3 | SFT | https://huggingface.co/datasets/allenai/tulu-3-sft-mixture |
| gsm8k | GRPO | https://huggingface.co/datasets/openai/gsm8k |
| numinamath | GRPO | https://huggingface.co/datasets/AI-MO/NuminaMath-CoT |

Check the original model and dataset licenses before training. Some sources, including Llama-family models and Alpaca, have access or usage restrictions.

The training scripts can load directly from Hugging Face. You can also download and optionally export local parquet files:

python scripts/download_datasets.py --dataset all
python scripts/download_datasets.py --dataset metamathqa --write_parquet --output_dir data

Local parquet or JSON files can be passed with --train_file.

📊 Benchmark Datasets

For benchmark reporting, we use the official dataset releases below. The table lists dataset sources only.

| Benchmark | Area | Config / split | Official source |
| --- | --- | --- | --- |
| MMLU | multitask knowledge | official subject splits | https://huggingface.co/datasets/cais/mmlu |
| IFEval | instruction following | default release | https://huggingface.co/datasets/google/IFEval |
| ARC-Challenge | grade-school science reasoning | ARC-Challenge | https://huggingface.co/datasets/allenai/ai2_arc |
| GSM8K | grade-school math reasoning | main, test | https://huggingface.co/datasets/openai/gsm8k |
| HellaSwag | commonsense completion | default release | https://huggingface.co/datasets/allenai/hellaswag |
| TruthfulQA | truthfulness | official multiple-choice/generation configs | https://huggingface.co/datasets/truthfulqa/truthful_qa |
| WinoGrande | commonsense coreference | official validation/test release | https://huggingface.co/datasets/allenai/winogrande |

🏋️ SFT Training

MetaMathQA, LoPT k=2:

torchrun --nproc_per_node 8 train_sft.py \
  --model_name_or_path Qwen/Qwen2.5-7B-Instruct \
  --dataset_preset metamathqa \
  --method ll \
  --num_blocks 2 \
  --max_seq_length 1024 \
  --per_device_train_batch_size 4 \
  --learning_rate 2e-5 \
  --lr_k1 2e-5 \
  --lambda_aux 10 \
  --gradient_checkpointing \
  --output_dir outputs/qwen25-7b-metamathqa-lopt-k2
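Two of the flags above are LoPT-specific: --lr_k1 (a separate learning rate for the first block) and --lambda_aux (the weight on the local reconstruction losses). A small sketch of how both could be wired up with standard PyTorch mechanisms; the semantics of the flags are assumed from their names, not taken from the repo's code:

```python
import torch
import torch.nn as nn

# Hypothetical two-block setup; in LoPT these would be contiguous
# slices of model.layers.
block1, block2 = nn.Linear(16, 16), nn.Linear(16, 16)

# Separate parameter groups let the first block use its own learning rate.
optimizer = torch.optim.AdamW([
    {"params": block1.parameters(), "lr": 2e-5},  # --lr_k1 2e-5
    {"params": block2.parameters(), "lr": 2e-5},  # --learning_rate 2e-5
])

# The auxiliary losses are folded into the objective with a scalar weight.
lambda_aux = 10.0                                  # --lambda_aux 10
task_loss, aux_loss = torch.tensor(1.0), torch.tensor(0.05)
total = task_loss + lambda_aux * aux_loss
print(total.item())  # 1.5
```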

MetaMathQA, LoPT k=4:

torchrun --nproc_per_node 8 train_sft.py \
  --model_name_or_path Qwen/Qwen2.5-7B-Instruct \
  --dataset_preset metamathqa \
  --method ll \
  --num_blocks 4 \
  --max_seq_length 1024 \
  --per_device_train_batch_size 4 \
  --learning_rate 2e-5 \
  --lr_k1 2e-5 \
  --lambda_aux 10 \
  --gradient_checkpointing \
  --output_dir outputs/qwen25-7b-metamathqa-lopt-k4

MetaMathQA, E2E baseline:

torchrun --nproc_per_node 8 train_sft.py \
  --model_name_or_path Qwen/Qwen2.5-7B-Instruct \
  --dataset_preset metamathqa \
  --method e2e \
  --max_seq_length 1024 \
  --per_device_train_batch_size 4 \
  --learning_rate 2e-5 \
  --gradient_checkpointing \
  --output_dir outputs/qwen25-7b-metamathqa-e2e

Alpaca SFT:

torchrun --nproc_per_node 8 train_sft.py \
  --model_name_or_path Qwen/Qwen2.5-7B-Instruct \
  --dataset_preset alpaca \
  --method ll \
  --num_blocks 2 \
  --max_seq_length 512 \
  --per_device_train_batch_size 4 \
  --learning_rate 2e-5 \
  --lambda_aux 10 \
  --gradient_checkpointing \
  --output_dir outputs/qwen25-7b-alpaca-lopt-k2

Tulu-3 SFT:

torchrun --nproc_per_node 8 train_sft.py \
  --model_name_or_path meta-llama/Llama-3.1-8B-Instruct \
  --dataset_preset tulu3 \
  --method ll \
  --num_blocks 2 \
  --max_seq_length 2048 \
  --per_device_train_batch_size 2 \
  --learning_rate 2e-5 \
  --lambda_aux 10 \
  --gradient_checkpointing \
  --output_dir outputs/llama31-8b-tulu3-lopt-k2

🎯 GRPO Training

GRPO uses TRL's GRPOTrainer, so completions are sampled from the current actor during training.
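At the core of GRPO is a group-relative advantage: the rewards of the completions sampled for one prompt are normalized within that group. A minimal sketch of the standard formulation (not TRL's exact code):

```python
import statistics

def group_advantages(rewards, eps=1e-4):
    """Group-relative advantages as in GRPO:
    A_i = (r_i - mean(r)) / (std(r) + eps)."""
    mu = statistics.fmean(rewards)
    sigma = statistics.pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]

# Four sampled completions for one prompt (--num_generations 4) with a
# binary correctness reward: correct answers get positive advantage,
# incorrect ones negative.
print(group_advantages([1.0, 0.0, 0.0, 1.0]))
```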

GSM8K, LoPT k=2:

accelerate launch --num_processes 8 train_grpo.py \
  --model_name_or_path Qwen/Qwen2.5-7B-Instruct \
  --task gsm8k \
  --method ll \
  --num_blocks 2 \
  --per_device_train_batch_size 1 \
  --gradient_accumulation_steps 8 \
  --num_generations 4 \
  --steps_per_generation 1 \
  --num_iterations 1 \
  --loss_type grpo \
  --learning_rate 1e-6 \
  --lr_k1 1e-6 \
  --lambda_aux 10 \
  --gradient_checkpointing \
  --output_dir outputs/qwen25-7b-gsm8k-grpo-lopt-k2

GSM8K, LoPT k=4:

accelerate launch --num_processes 8 train_grpo.py \
  --model_name_or_path Qwen/Qwen2.5-7B-Instruct \
  --task gsm8k \
  --method ll \
  --num_blocks 4 \
  --per_device_train_batch_size 1 \
  --gradient_accumulation_steps 8 \
  --num_generations 4 \
  --steps_per_generation 1 \
  --num_iterations 1 \
  --loss_type grpo \
  --learning_rate 1e-6 \
  --lr_k1 1e-6 \
  --lambda_aux 10 \
  --gradient_checkpointing \
  --output_dir outputs/qwen25-7b-gsm8k-grpo-lopt-k4

GSM8K, E2E baseline:

accelerate launch --num_processes 8 train_grpo.py \
  --model_name_or_path Qwen/Qwen2.5-7B-Instruct \
  --task gsm8k \
  --method e2e \
  --per_device_train_batch_size 1 \
  --gradient_accumulation_steps 8 \
  --num_generations 4 \
  --steps_per_generation 1 \
  --num_iterations 1 \
  --loss_type grpo \
  --learning_rate 1e-6 \
  --gradient_checkpointing \
  --output_dir outputs/qwen25-7b-gsm8k-grpo-e2e

NuminaMath GRPO:

accelerate launch --num_processes 8 train_grpo.py \
  --model_name_or_path Qwen/Qwen2.5-7B-Instruct \
  --task numinamath \
  --method ll \
  --num_blocks 2 \
  --per_device_train_batch_size 1 \
  --gradient_accumulation_steps 8 \
  --num_generations 4 \
  --max_completion_length 512 \
  --steps_per_generation 1 \
  --num_iterations 1 \
  --loss_type grpo \
  --learning_rate 1e-6 \
  --lambda_aux 10 \
  --gradient_checkpointing \
  --output_dir outputs/qwen25-7b-numinamath-grpo-lopt-k2

🧪 Smoke Tests

Before launching a full run, validate the setup with a small model, a handful of samples, and a couple of steps:

CUDA_VISIBLE_DEVICES=0 python train_sft.py \
  --model_name_or_path HuggingFaceTB/SmolLM2-135M-Instruct \
  --dataset_preset alpaca \
  --method ll \
  --num_blocks 2 \
  --max_seq_length 64 \
  --max_samples 16 \
  --max_steps 2 \
  --per_device_train_batch_size 1 \
  --output_dir outputs/smoke-sft-lopt

CUDA_VISIBLE_DEVICES=0 python train_grpo.py \
  --model_name_or_path HuggingFaceTB/SmolLM2-135M-Instruct \
  --task gsm8k \
  --method ll \
  --num_blocks 2 \
  --max_samples 8 \
  --max_steps 1 \
  --num_generations 2 \
  --gradient_accumulation_steps 2 \
  --max_prompt_length 128 \
  --max_completion_length 64 \
  --steps_per_generation 2 \
  --num_iterations 1 \
  --loss_type grpo \
  --per_device_train_batch_size 1 \
  --output_dir outputs/smoke-grpo-lopt
