als244/flextrain

FlexTrain

A high-efficiency training engine for transformer LLMs in tight-GPU-memory and long-context regimes. These are the operating points where mainstream engines (DeepSpeed ZeRO, FSDP, Megatron-LM, Unsloth) either lose throughput to idle stalls and excessive recomputation or need additional GPUs (more aggregate VRAM) to fit the model at all. A working-set planner and a DP solver schedule every tensor (parameters, gradients, optimizer state, activations) between GPU and host RAM. You declare GPU and host RAM budgets at model construction, and the planner automatically parameterizes a unified data-sizing, offloading, and recomputation policy that fits those budgets, with no manual tuning required.

The single-GPU engine is fully implemented today; multi-GPU support is in active development. On a 32 GiB GPU + 192 GiB host: a 9B dense model full-fine-tunes at 80% of an RTX 5090's bf16 peak throughput; 12B dense full-fine-tunes on the same card; 27B-parameter LoRA tuning hits 78% of peak — all on one GPU. See docs/verified_runs.md for the full table.
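
As a rough illustration of why these model sizes need host offloading on a 32 GiB card, the sketch below estimates the persistent training state for full fine-tuning with AdamW. The 16 bytes per parameter (bf16 weights and gradients, fp32 master weights, two fp32 moments) is a common mixed-precision assumption, not FlexTrain's documented layout (see docs/dtypes.md), and activations are ignored entirely.

```python
GIB = 1 << 30  # bytes per GiB


def training_state_gib(n_params: float, bytes_per_param: int = 16) -> float:
    """Weights + gradients + optimizer state in GiB, excluding activations.

    bytes_per_param=16 is an assumed mixed-precision AdamW layout:
    2 (bf16 weight) + 2 (bf16 grad) + 4 (fp32 master) + 4 + 4 (fp32 moments).
    """
    return n_params * bytes_per_param / GIB


GPU_GIB, HOST_GIB = 32, 192  # budgets quoted in the text above

for name, n_params in [("9B", 9e9), ("12B", 12e9)]:
    need = training_state_gib(n_params)
    fits_gpu = need <= GPU_GIB
    fits_total = need <= GPU_GIB + HOST_GIB
    print(f"{name}: ~{need:.0f} GiB state, fits GPU alone: {fits_gpu}, "
          f"fits GPU + host: {fits_total}")
```

Under this assumption neither model's state fits in GPU memory alone, but both fit once host RAM is counted, which is the regime the planner targets.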

Supported architectures. Llama 2 / 3 / 3.1+, Mistral, Qwen2, Qwen3 (dense + MoE), Qwen3.5 (dense + MoE) / Qwen3.6, Qwen3-Next, OLMoE, Gemma 2 (dense), Gemma 3 (dense, text-only path) — see docs/architectures.md. Each arch has forward + backward + HF safetensors load. End-to-end smoke runs are recorded in docs/verified_runs.md (Llama, Qwen, OLMoE families) and docs/gemma_runs.md (Gemma family).

Install

conda create -n flextrain python=3.12
conda activate flextrain
pip install torch triton
pip install -e .
flextrain-post-install            # fetch prebuilt flash-attn + causal-conv1d wheels

pip install -e . builds the two in-tree C/CUDA helpers (matmul_dispatcher, transmission_scheduler) and pulls in flash-linear-attention so Qwen3-Next / Qwen3.5 hybrid layers (Gated DeltaNet) work out of the box. Set FLEXTRAIN_SKIP_HELPERS=1 to skip the helper builds when iterating on Python-only code.

flextrain-post-install runs once after the editable install and fetches:

  • flash-attention prebuilt wheels matching your (python, torch, CUDA) tuple from mjun0812/flash-attention-prebuild-wheels. flash-attn 2 always; flash-attn 3 added only on Hopper (compute capability exactly 9.0) — FA3 prebuilt wheels currently ship only sm_90 kernels, so they're skipped on Ada (sm_89), Blackwell (sm_120), and earlier sm_80 / sm_86 to avoid the "no kernel image is available for execution on the device" runtime error. The in-tree FA2 path covers those GPUs. If no matching wheel exists for your combo, that package is skipped silently — the install does not fail. FLEXTRAIN_SKIP_FLASH_ATTN=1 opts out.
  • causal-conv1d prebuilt wheel from Dao-AILab/causal-conv1d GitHub releases (Mamba-style state-space layers; FLA also uses it for some kernels). FLEXTRAIN_SKIP_CAUSAL_CONV1D=1 opts out.
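
The selection rules above can be sketched as a small pure function. This is an illustration only: `wheels_to_fetch` and the package labels are hypothetical names, not the actual flextrain-post-install code, and the wheel-matching step (python, torch, CUDA) is omitted.

```python
import os


def wheels_to_fetch(compute_capability, env=None):
    """Return which prebuilt packages a post-install step would try to fetch.

    compute_capability is a (major, minor) tuple, e.g. the value returned by
    torch.cuda.get_device_capability(). Labels are illustrative, not real
    wheel names.
    """
    env = os.environ if env is None else env
    packages = []
    if env.get("FLEXTRAIN_SKIP_FLASH_ATTN") != "1":
        packages.append("flash-attn-2")          # always attempted
        if tuple(compute_capability) == (9, 0):  # Hopper only: FA3 ships sm_90 kernels
            packages.append("flash-attn-3")
    if env.get("FLEXTRAIN_SKIP_CAUSAL_CONV1D") != "1":
        packages.append("causal-conv1d")
    return packages


print(wheels_to_fetch((9, 0), env={}))   # Hopper
print(wheels_to_fetch((12, 0), env={}))  # Blackwell: no FA3
```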

The two-step install (rather than auto-installing during pip install -e .) is forced by modern pip's default build isolation: setup.py cannot reach the user environment from inside the isolated build env, so wheel installs have to run as a separate user-env step.

Optional extras: -e ".[hopper]" (tilelang for FLA on Hopper + Triton ≥ 3.4 — works around a correctness bug in the default Triton chunk_bwd kernel), -e ".[peft]" (HF PEFT for LoRA parity tests), -e ".[flash-attn-build]" / -e ".[causal-conv1d-build]" (build those packages from source instead of fetching a prebuilt wheel — slow; minutes-to-hours).

Quickstart

python train.py \
  --model meta-llama/Llama-3.1-8B \
  --mode lora \
  --max-seq-len 1024 \
  --max-global-batch-tokens 1024

This command auto-discovers your GPU and host memory budgets, probes the hardware (sustained TFLOPS, PCIe bandwidth, memory bandwidth), downloads the model into models/ if needed, runs 20 steps on a tiny bundled SFT dataset, and logs to runs/<model>_<mode>_sl<seq_len>.

Common flags:

| flag | what it does |
| --- | --- |
| --mode {full,lora} | full fine-tune or LoRA |
| --use-muon | use HybridMuonAdamW for --mode full (Muon on dense projections, AdamW elsewhere) |
| --data-source {synthetic,json_sft} | synthetic tokens or SFT JSONL |
| --dataset path/or/repo | local JSONL, HF repo, or http(s) URL |
| --truncate-long | truncate response of records longer than --max-seq-len instead of dropping them (default: drop) |
| --max-gpu-mem-gib N / --max-host-mem-gib N | override auto-discovered budgets |
| --force-save-level {0,1,2,3} | force activation save-tier (debug) |
| --save | export final.safetensors after training |

Run python train.py --help for the full list.
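
To make the difference between the default drop behaviour and --truncate-long concrete, here is a minimal sketch over already-tokenized records. The function name and the choice to drop a record whose prompt alone fills the window are assumptions for illustration; the real data loader may handle that edge case differently.

```python
def fit_record(prompt_ids, response_ids, max_seq_len, truncate_long=False):
    """Return (prompt_ids, response_ids) that fit max_seq_len, or None if dropped."""
    total = len(prompt_ids) + len(response_ids)
    if total <= max_seq_len:
        return prompt_ids, response_ids           # already fits, keep unchanged
    room = max_seq_len - len(prompt_ids)          # tokens left for the response
    if not truncate_long or room <= 0:
        return None                               # default: drop the long record
    return prompt_ids, response_ids[:room]        # keep the prompt, trim the response


prompt, response = list(range(6)), list(range(10))
print(fit_record(prompt, response, max_seq_len=8))                      # dropped
print(fit_record(prompt, response, max_seq_len=8, truncate_long=True))  # trimmed
```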

Learning rate schedule

Every run uses linear warmup → constant peak → cosine cooldown. Tune via:

| flag | default | what |
| --- | --- | --- |
| --lr | 3e-5 (full), 1e-4 (lora), 1e-3 (--use-muon) | peak (max) LR |
| --lr-warmup-pct | 0.1 | fraction of steps spent ramping 0 → peak |
| --lr-cooldown-start-pct | 0.8 | fraction at which cosine cooldown begins |
| --lr-final-pct | 0.1 | final LR as a fraction of peak |

Example with a tighter warmup and a deeper cooldown:

python train.py --model models/Llama-3.1-8B --mode full \
  --max-seq-len 1024 --max-global-batch-tokens 524288 \
  --lr 5e-5 --lr-warmup-pct 0.03 --lr-final-pct 0.01
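
The shape of the schedule can be sketched as a function of training progress. This is one plausible reading of the three flags, written for illustration; the exact step indexing and interpolation inside train.py are assumptions here.

```python
import math


def lr_at(step, total_steps, peak,
          warmup_pct=0.1, cooldown_start_pct=0.8, final_pct=0.1):
    """Linear warmup -> constant peak -> cosine cooldown to final_pct * peak."""
    frac = step / max(1, total_steps)             # training progress in [0, 1]
    if warmup_pct > 0 and frac < warmup_pct:
        return peak * frac / warmup_pct           # linear ramp 0 -> peak
    if frac < cooldown_start_pct:
        return peak                               # constant plateau
    span = max(1e-12, 1.0 - cooldown_start_pct)
    t = min(1.0, (frac - cooldown_start_pct) / span)  # cooldown progress in [0, 1]
    floor = peak * final_pct
    return floor + 0.5 * (peak - floor) * (1.0 + math.cos(math.pi * t))


for step in (0, 5, 10, 50, 80, 90, 100):
    print(step, f"{lr_at(step, 100, peak=1e-4):.2e}")
```

With the defaults, the LR ramps over the first 10% of steps, holds at peak until 80%, then follows a half cosine down to 10% of peak at the final step.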

Profiling with nsys

Wrap a window of steady-state steps for nsys profile:

nsys profile --capture-range=cudaProfilerApi --capture-range-end=stop \
  python train.py --model models/Llama-3.1-8B --mode full \
                  --max-seq-len 1024 --max-global-batch-tokens 524288 \
                  --steps 10 --profile-start-step 5 --profile-stop-step 7

--profile-start-step calls cudaProfilerStart() right before that step begins; --profile-stop-step calls cudaProfilerStop() after it ends (default = start + 2). nsys' capture range opens/closes on those markers, so warmup and final-step teardown stay out of the report. Each captured step is wrapped in an NVTX range so the timeline groups by step.
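
The marker placement described above can be sketched as a loop with injected hooks, so it runs without a GPU. In a real run the hooks would be PyTorch's torch.cuda.profiler.start / stop and torch.cuda.nvtx.range_push / range_pop; the function name, the 0-based step indexing, and the NVTX range label are assumptions for illustration.

```python
def run_with_profile_window(num_steps, start_step, stop_step=None, *,
                            run_step, profiler_start, profiler_stop,
                            range_push, range_pop):
    """Run num_steps steps, bracketing [start_step, stop_step] with profiler markers."""
    if stop_step is None:
        stop_step = start_step + 2                # documented default window
    for step in range(num_steps):
        if step == start_step:
            profiler_start()                      # opens the nsys capture range
        in_window = start_step <= step <= stop_step
        if in_window:
            range_push(f"step_{step}")            # one NVTX range per captured step
        run_step(step)
        if in_window:
            range_pop()
        if step == stop_step:
            profiler_stop()                       # closes the nsys capture range


events = []
run_with_profile_window(
    10, 5, 7,
    run_step=lambda s: events.append(f"run {s}"),
    profiler_start=lambda: events.append("profiler start"),
    profiler_stop=lambda: events.append("profiler stop"),
    range_push=lambda name: events.append(f"push {name}"),
    range_pop=lambda: events.append("pop"),
)
print(events)
```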

Air-gapped compute nodes

If your training nodes have no internet, pre-stage the model + dataset on a login node with download.py, then point train.py at the local paths:

# Login node (has internet):
python download.py model meta-llama/Llama-3.1-8B --target models/Llama-3.1-8B
python download.py dataset HuggingFaceH4/no_robots --target datasets/no_robots.jsonl

# Compute node (no internet):
python train.py --model models/Llama-3.1-8B \
                --data-source json_sft \
                --dataset datasets/no_robots.jsonl \
                --max-seq-len 1024 --max-global-batch-tokens 1024

download.py dataset normalizes common SFT schemas (instruction/output, prompt/completion, chat-style messages, ...) into FlexTrain's JSONL format. For models, download.py model --allow-patterns '*.safetensors' '*.json' skips redundant pytorch_model.bin shards.
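
As an illustration of what schema normalization involves, the sketch below maps the three schemas named above onto a single record shape. The target shape {"prompt", "response"} and the function name are hypothetical; FlexTrain's actual JSONL format is described in docs/dataset.md and may differ.

```python
import json


def normalize_record(rec):
    """Map a few common SFT schemas to {"prompt": str, "response": str}.

    Returns None when the record matches none of the known schemas.
    The output shape is an assumption for illustration only.
    """
    if "instruction" in rec and "output" in rec:
        return {"prompt": rec["instruction"], "response": rec["output"]}
    if "prompt" in rec and "completion" in rec:
        return {"prompt": rec["prompt"], "response": rec["completion"]}
    msgs = rec.get("messages")
    if msgs:
        # Use the last assistant turn as the response and all earlier turns as prompt.
        for i in range(len(msgs) - 1, -1, -1):
            if msgs[i].get("role") == "assistant":
                prompt = "\n".join(m["content"] for m in msgs[:i])
                return {"prompt": prompt, "response": msgs[i]["content"]}
    return None


raw = [
    {"instruction": "Add 2 and 3.", "output": "5"},
    {"prompt": "Capital of France?", "completion": "Paris"},
    {"messages": [{"role": "user", "content": "Hi"},
                  {"role": "assistant", "content": "Hello!"}]},
]
for rec in raw:
    print(json.dumps(normalize_record(rec)))
```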

Python API

from flextrain import from_pretrained
from flextrain.optim.adamw import AdamW, AdamWHyperparams

am = from_pretrained(
    "models/Llama-3.1-8B",
    optimizer=AdamW(AdamWHyperparams(lr=3e-5)),
    max_seq_len=1024,
    max_global_batch_tokens=1024,
    max_gpu_mem_bytes=24 << 30,
    max_host_mem_bytes=110 << 30,
)

for batch in your_dataloader:
    am.fwd_bwd(batch)
    am.step()

For LoRA, pass lora_targets="all" (and optionally lora_rank, lora_alpha). For full fine-tuning with Muon on dense projections, swap the optimizer for flextrain.optim.HybridMuonAdamW.

Hardware probe. from_pretrained runs probe_hardware() on the first call (~10s) when hw_cost isn't provided, so the save-level DP solver gets accurate TFLOPS / PCIe numbers. To avoid re-probing across calls, cache the result and pass it explicitly:

from flextrain.core.hw_probe import probe_hardware
probe = probe_hardware()  # ~10s; sustained TFLOPS / PCIe / mem-bw
am = from_pretrained(..., hw_cost=probe.hw_cost, mem_bw_gbps=probe.mem_bw_gbps)

train.py does this once at startup and reuses the result.
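
To show why the solver needs measured costs, here is a toy dynamic program over per-layer save levels: each level trades GPU memory for extra time (recompute or transfer), and the DP picks one level per layer to minimize total time under a memory budget. It is a simplified multiple-choice knapsack written for illustration, not FlexTrain's solver; the cost numbers, the three-level setup, and the function name are all made up.

```python
def choose_save_levels(options, budget):
    """Pick one (mem, time) option per layer, minimizing time with sum(mem) <= budget.

    options[i] is a list of (mem_units, time_cost) tuples, one per save level
    of layer i. Returns (total_time, levels), or None if nothing fits.
    """
    best = {0: (0.0, [])}                         # used memory -> (time, chosen levels)
    for layer_opts in options:
        nxt = {}
        for used, (time, picks) in best.items():
            for level, (mem, cost) in enumerate(layer_opts):
                m = used + mem
                if m > budget:
                    continue                      # over budget, prune this choice
                cand = (time + cost, picks + [level])
                if m not in nxt or cand[0] < nxt[m][0]:
                    nxt[m] = cand                 # keep cheapest plan per memory use
        best = nxt
    if not best:
        return None
    return min(best.values(), key=lambda plan: plan[0])


# Toy costs: level 0 saves nothing (cheap in memory, slow), level 2 saves everything.
layer = [(0, 3.0), (2, 1.0), (4, 0.0)]
print(choose_save_levels([layer] * 3, budget=6))
print(choose_save_levels([layer] * 3, budget=12))
```

In practice the time costs would come from the probed TFLOPS and PCIe figures, which is why stale or missing probe data would lead to a worse plan.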

Layout

flextrain/
  core/      Layer/Block protocols, ActivationSchema, save-level DP solver,
             working-set sizer, hardware probe
  engine/    ActiveModel trainer, buffer manager, streams/events
  nn/        blocks/ (attention, FFN, MoE, RoPE, ...) + layers/ (full models)
  optim/     AdamW, Muon, HybridMuonAdamW
  ops/       FlexTrain-owned Triton kernels
  io/        HF weight load/save, per-arch adapters, download helpers
  bench/     parity tests + microbenchmarks
train.py     end-to-end CLI
download.py  pre-stage models/datasets for air-gapped nodes

Documentation

architectures.md       supported HF configs
working_set.md         how the planner picks chunk size, GPU layer counts, save tiers
sft_vs_pretraining.md  targets / loss-mask conventions
dataset.md             data format + built-in JSON SFT source
weights.md             HF safetensors I/O, custom archs
lora.md                LoRA, MoE per-expert LoRA, HF PEFT parity
extending/             adding a new model: tutorial + flow + per-level contracts (block / layer / chunk / model)
dtypes.md              compute / master / grad / opt-state dtypes
optimizers.md          AdamW / Muon / HybridMuonAdamW
export.md              export to vLLM / SGLang / HF (full, LoRA adapter, merged)
verified_runs.md       end-to-end smoke runs on the reference workstation
gemma_runs.md          Gemma 2 / Gemma 3 verified runs + 5-step parity vs HF + re-verify protocol

Tests

Each test is a standalone script (no test runner yet):

python tests/test_arch_parity.py              # FT-vs-HF loss + logit parity across archs
python tests/test_arch_lora_e2e.py --all      # LoRA-mode parity vs HF PEFT (all archs)
python tests/test_llama32_1b_parity.py        # Llama-3.2-1B end-to-end vs HF transformers
python tests/parity_qwen3_5_9b_35b_5step.py   # Qwen3.5-9B + 35B-A3B 5-step replay

Per-area suites live in subdirs: tests/moe/ (MoE block / backend parity), tests/multi_chunk_dense_parity/ (chunked linear-attn correctness), tests/io/ (HF safetensors round-trip), tests/cross_machine_parity/ (replay across machines).

About

Transformer Training on Tight GPU Memory Budget: Unified Data-Sizing, Offloading, and Recomputation
