
MegaTrain

Full Precision Training of 100B+ Parameter LLMs on a Single GPU

Requirements: Python 3.9+ and PyTorch 2.0+. Licensed under Apache-2.0.

A RAM-centric architecture that stores parameters in host memory and treats GPUs as transient compute engines, enabling full-precision training of 100B+ models on a single GPU.

Quick Start | Supported Models | Data Preparation | Performance | Citation


Features

  • Single GPU, Massive Models -- Train 120B+ models on one GPU by leveraging CPU RAM for parameter storage
  • Universal Model Support -- Any HuggingFace decoder-only model works out of the box via AutoModelForCausalLM
  • Hybrid Architecture -- Automatic handling of mixed attention (linear + full) and MoE layers
  • LlamaFactory-style Data -- Flexible dataset_info.json registry with alpaca/sharegpt format support
  • 1.84x Faster -- Outperforms DeepSpeed ZeRO-3 on 14B models through pipelined double-buffered execution
  • YAML Configuration -- Easy model/dataset/hyperparameter setup with 25+ ready-made configs
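The pipelined double-buffered execution mentioned above follows a producer/consumer pattern: while the GPU computes layer *i*, the weights of layer *i+1* are already being copied out of host RAM into a second buffer. A minimal pure-Python simulation of that overlap (names like `fetch_weights` are illustrative, not MegaTrain's actual API):

```python
from concurrent.futures import ThreadPoolExecutor

# Host-resident "parameters": one scalar weight per layer (illustrative).
host_weights = [0.5, 1.5, 2.0, 0.25]

def fetch_weights(layer_idx):
    # Stand-in for an asynchronous host-to-device copy of one layer.
    return host_weights[layer_idx]

def compute(x, w):
    # Stand-in for the layer's forward pass on the GPU.
    return x * w

def pipelined_forward(x):
    # One background worker plays the role of the copy engine/stream.
    with ThreadPoolExecutor(max_workers=1) as copier:
        pending = copier.submit(fetch_weights, 0)       # prefetch layer 0
        for i in range(len(host_weights)):
            w = pending.result()                        # wait for layer i's copy
            if i + 1 < len(host_weights):
                # Kick off the copy of layer i+1 before computing layer i,
                # so transfer and compute overlap.
                pending = copier.submit(fetch_weights, i + 1)
            x = compute(x, w)
    return x

print(pipelined_forward(8.0))  # 8 * 0.5 * 1.5 * 2.0 * 0.25 = 3.0
```

In the real system the "copy" is a CUDA host-to-device transfer on a side stream and the two buffers alternate, but the scheduling idea is the same.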

Quick Start

```shell
# Install
git clone https://github.com/DLYuanGod/MegaTrain.git
cd MegaTrain
pip install -e .

# Train with built-in demo data
python examples/train.py --config examples/configs/llama3_8b.yaml

# Train any supported model
python examples/train.py --config examples/configs/qwen3_5_27b.yaml
```

Supported Models

| Model Family | Model Sizes | Architecture |
|---|---|---|
| Qwen2/Qwen2.5 | 0.5B/1.5B/3B/7B/14B/32B/72B | Dense |
| Qwen3 | 0.6B/1.7B/4B/8B/14B/32B | Dense |
| Qwen3.5 | 0.8B/2B/4B/9B/27B | Hybrid (linear+full attn) |
| Qwen3.5 MoE | 35B-A3B/122B-A10B/397B-A17B | Hybrid + MoE |
| Qwen3-Next | 80B-A3B | Hybrid + MoE |
| Llama 2 | 7B/13B/70B | Dense |
| Llama 3/3.1/3.2/3.3 | 1B/3B/8B/70B | Dense |
| Llama 4 | Scout-17B-16E/Maverick | MoE |
| Mistral | 7B | Dense |
| Mixtral | 8x7B/8x22B | MoE |
| DeepSeek (LLM/Code/R1) | 7B/16B/67B | Dense |
| Phi-3/Phi-4 | 3.8B/14B | Dense |
| Gemma 2/3 | 2B/7B/9B/27B | Dense |
| GLM-4/GLM-4.5 | 9B/32B | Dense |
| InternLM 2/2.5 | 7B/20B | Dense |
| Yi 1.5 | 6B/9B/34B | Dense |
| Baichuan 2 | 7B/13B | Dense |
| GPT-OSS | 20B/120B | Dense |
| Any HF decoder-only model | Any size | Auto-detected |

MegaTrain uses HuggingFace's AutoModelForCausalLM with automatic model structure discovery. Any decoder-only transformer model is supported without code changes.
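Structure discovery of this kind typically works by walking the model's attribute tree until a list of decoder blocks is found (e.g. `model.model.layers` in Llama-style models). A hedged, framework-free sketch of the idea (this is not MegaTrain's actual code; real HF models use `nn.ModuleList`, which is list-like):

```python
def find_layer_list(module, max_depth=4):
    """Breadth-first search of a module's attributes for the first
    non-empty list of sub-blocks, returning (attribute_name, list)."""
    frontier = [module]
    for _ in range(max_depth):
        nxt = []
        for m in frontier:
            for attr_name, value in vars(m).items():
                if isinstance(value, list) and len(value) > 0:
                    return attr_name, value
                if hasattr(value, "__dict__"):
                    nxt.append(value)   # descend into nested submodules
        frontier = nxt
    raise ValueError("no decoder layer list found")

# Dummy objects mimicking HF's model.model.layers nesting.
class Block: pass
class Inner:
    def __init__(self): self.layers = [Block(), Block()]
class Outer:
    def __init__(self): self.model = Inner()

name, layers = find_layer_list(Outer())
print(name, len(layers))  # layers 2
```

Once the layer list is located, each block can be streamed to the GPU independently, which is what makes the approach model-agnostic.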

Data Preparation

MegaTrain supports a LlamaFactory-compatible data system with flexible format support.

Option 1: Dataset Registry (Recommended)

Register datasets in data/dataset_info.json and reference by name:

```yaml
dataset:
  name: "alpaca_en_demo"    # name from dataset_info.json
  dataset_dir: "data"
  max_seq_len: 1024
```

Supports alpaca format, sharegpt format, local JSON/JSONL files, and HuggingFace Hub datasets. See data/README.md for details.
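For reference, an alpaca-format record carries `instruction`/`input`/`output` keys. A minimal converter into a (prompt, response) training pair could look like this (the prompt template is an illustrative assumption, not necessarily the one MegaTrain applies):

```python
def alpaca_to_pair(record):
    """Convert one alpaca-format record into a (prompt, response) pair."""
    prompt = record["instruction"]
    if record.get("input"):              # "input" is optional extra context
        prompt += "\n\n" + record["input"]
    return prompt, record["output"]

record = {
    "instruction": "Translate to French.",
    "input": "Hello",
    "output": "Bonjour",
}
print(alpaca_to_pair(record))  # ('Translate to French.\n\nHello', 'Bonjour')
```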

Option 2: Direct Path (Legacy)

```yaml
dataset:
  path: "/path/to/arrow/dataset"
  query_field: "query"
  response_field: "response"
```

Provided Datasets

| Dataset | Source | Format |
|---|---|---|
| alpaca_en_demo | Built-in | Alpaca |
| MetaMathQA | HuggingFace Hub | Alpaca |
| Open-Platypus | HuggingFace Hub | Alpaca |
| MathInstruct | HuggingFace Hub | Alpaca |
| CodeAlpaca-20k | HuggingFace Hub | Alpaca |
| ShareGPT4 | HuggingFace Hub | ShareGPT |
| UltraChat-200k | HuggingFace Hub | ShareGPT |
| OpenThoughts-114k | HuggingFace Hub | ShareGPT |
| OpenR1-Math-94k | HuggingFace Hub | ShareGPT |

Configuration

> [!CAUTION]
> Do NOT guess `batch_size`! Use the resource calculator to find the optimal batch size for your hardware; a wrong batch size leads to OOM or an underutilized GPU.

```shell
python scripts/calc_resource.py
```
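The trade-off the calculator resolves is roughly this: activation memory grows linearly with batch size, sequence length, and retained layers, so the largest safe batch is bounded by the VRAM left over after weights and buffers. A back-of-the-envelope estimate (coefficients and the `reserved_gb` slack are illustrative assumptions, not the calculator's actual model):

```python
def max_batch_size(vram_gb, hidden, n_layers, seq_len,
                   bytes_per_act=2, reserved_gb=10):
    """Crude upper bound on batch size: VRAM remaining after a fixed
    reservation, divided by per-sample activation cost (~`hidden`
    values per token per layer, at bf16 = 2 bytes each)."""
    free = (vram_gb - reserved_gb) * 1024**3
    per_sample = seq_len * hidden * n_layers * bytes_per_act
    return int(free // per_sample)

# e.g. an 80 GB GPU with 27B-class model dimensions (illustrative)
print(max_batch_size(vram_gb=80, hidden=5120, n_layers=64, seq_len=1024))  # 112
```

Gradient checkpointing shrinks the effective `n_layers` term, which is why raising `checkpoint_interval` trades compute for a larger feasible batch.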
```yaml
model:
  name: "Qwen/Qwen3.5-27B"
  dtype: "bfloat16"
  attn_implementation: "flash_attention_2"

dataset:
  name: "metamath"
  max_seq_len: 1024

training:
  batch_size: 64       # <-- Use calc_resource.py to determine this!
  num_steps: 500
  learning_rate: 1.0e-5

optimizer:
  type: "deepspeed_adam"
```

See examples/configs/ for ready-made configurations.

| Config | Model | Architecture |
|---|---|---|
| qwen_7b.yaml | Qwen 2.5 7B | Dense |
| qwen3_8b.yaml | Qwen 3 8B | Dense |
| qwen3_5_27b.yaml | Qwen 3.5 27B | Hybrid (linear+full attn) |
| qwen3_next_80b.yaml | Qwen3-Next 80B-A3B | Hybrid + MoE |
| glm4_flash.yaml | GLM-4-Flash 9B | Dense |
| llama3_8b.yaml | Llama 3.1 8B | Dense |
| gpt_oss_20b.yaml | GPT-OSS 20B | Dense |

Performance

| Model | GPU | TFLOPS | CPU RAM | GPU VRAM |
|---|---|---|---|---|
| Qwen 2.5 32B | 1x H100 | ~259 | ~327 GB | ~40 GB |
| Qwen 3.5 27B | 1x H100 | ~157 | ~275 GB | ~38 GB |

Key Techniques

  • Double buffering for overlapped weight transfer between CPU and GPU
  • Per-layer structure grouping for hybrid/MoE architectures
  • Gradient checkpointing every K layers to reduce GPU memory
  • Async gradient collection with slab pool
  • Manual gradient computation (no autograd overhead)
  • HuggingFace native Flash Attention integration
  • DeepSpeed CPUAdam for 5-7x faster optimizer steps
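Two of the techniques above, checkpointing every K layers and manual gradient computation, combine naturally: the forward pass saves activations only at segment boundaries, and the backward pass recomputes each segment from its checkpoint before applying the chain rule by hand. A toy sketch with scalar "layers" `y = w * x` (illustrative only, not MegaTrain's kernels):

```python
def forward_checkpointed(x, weights, k=2):
    """Forward through scalar layers y = w*x, saving activations only
    at every k-th layer boundary (gradient checkpointing)."""
    saved = {0: x}
    for i, w in enumerate(weights):
        x = w * x
        if (i + 1) % k == 0:
            saved[i + 1] = x
    return x, saved

def backward_manual(weights, saved, k=2):
    """Manual backward: recompute each segment from its checkpoint,
    then chain-rule by hand (loss is the final output itself)."""
    n = len(weights)
    grads = [0.0] * n
    upstream = 1.0                        # d(loss)/d(output) = 1
    for seg_end in range(n, 0, -k):
        seg_start = max(seg_end - k, 0)
        # Recompute intra-segment activations from the saved checkpoint.
        acts = [saved[seg_start]]
        for i in range(seg_start, seg_end):
            acts.append(weights[i] * acts[-1])
        # d(y)/d(w_i) = upstream * input activation; d(y)/d(x) = w_i.
        for i in range(seg_end - 1, seg_start - 1, -1):
            grads[i] = upstream * acts[i - seg_start]
            upstream *= weights[i]
    return grads

weights = [2.0, 3.0, 4.0, 5.0]
out, saved = forward_checkpointed(1.0, weights)
grads = backward_manual(weights, saved)
print(out, grads)  # 120.0 [60.0, 40.0, 30.0, 24.0]
```

Only activations at checkpoints are kept alive between forward and backward, so GPU memory for activations shrinks by roughly a factor of K at the cost of one extra forward pass per segment.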

Installation

```shell
git clone https://github.com/DLYuanGod/MegaTrain.git
cd MegaTrain
pip install -e .

# Optional: faster attention & optimizer
pip install flash-attn
pip install flash-linear-attention causal-conv1d  # for Qwen3.5 linear attention
pip install deepspeed                             # for CPUAdam optimizer
```

Troubleshooting

Out of Memory?
  • Reduce batch_size in config
  • Increase checkpoint_interval
  • Reduce max_seq_len
Slow Training?
  • Use deepspeed_adam optimizer (5-7x faster than PyTorch AdamW)
  • Install Flash Attention
  • Install flash-linear-attention + causal-conv1d for Qwen3.5 models
  • Increase num_workers for data loading
New Model Not Working?
  • Ensure it's a decoder-only model (not encoder-decoder like T5)
  • Set trust_remote_code: true in the config if the model requires it
  • Try attn_implementation: "sdpa" or "eager" if flash attention fails

Citation

If you use MegaTrain in your research, please cite:

```bibtex
@misc{yuan2026megatrainprecisiontraining100b,
  title={MegaTrain: Full Precision Training of 100B+ Parameter Large Language Models on a Single GPU},
  author={Zhengqing Yuan and Hanchi Sun and Lichao Sun and Yanfang Ye},
  year={2026},
  eprint={2604.05091},
  archivePrefix={arXiv},
  primaryClass={cs.CL},
  url={https://arxiv.org/abs/2604.05091},
}
```

Acknowledgement

This project benefits from the following open-source works:

  • LLaMA-Factory -- Our data loading system (dataset_info.json registry, alpaca/sharegpt format support) is inspired by LlamaFactory's elegant dataset management design. Thanks to @hiyouga and all contributors.
  • HuggingFace Transformers -- Universal model loading and native Flash Attention integration.
  • DeepSpeed -- SIMD-accelerated CPUAdam optimizer.
  • Flash Attention -- Memory-efficient attention and cross-entropy loss.
  • Flash Linear Attention -- Efficient linear attention kernels for hybrid models like Qwen3.5.

License

This repository is licensed under the Apache-2.0 License.
