diff --git a/docs/explanations/quantization.md b/docs/explanations/quantization.md index a665a28c0..e54124c14 100644 --- a/docs/explanations/quantization.md +++ b/docs/explanations/quantization.md @@ -141,4 +141,34 @@ python3 -m MaxText.train src/MaxText/configs/base.yml run_name=$YOUR_JOB_NAME ba ``` Note that `use_qwix_quantization` is not set to `True`. -For further reading, please refer to the [Qwix Read the Docs website](https://qwix.readthedocs.io/en/latest/get_started.html#). \ No newline at end of file +For further reading, please refer to the [Qwix Read the Docs website](https://qwix.readthedocs.io/en/latest/get_started.html#). + +## DeepSeek V3 Fine-tuning FP8 Recipe +To improve the performance of DeepSeek V3 fine-tuning, we developed a custom recipe optimized for FP8 throughput. The method prioritizes specific compute-intensive and bandwidth-heavy components while preserving training stability through a fine-grained scaling strategy. + +### Quantization Scope +To realize these gains, the recipe employs a w8a8g8 (8-bit weights, activations and gradients) strategy targeting three primary areas: + +* Megablox Kernels: Specifically the `gmm` and `tgmm` operations. + +* Attention Projections: Utilizing convolution fusion. + +* Communication: Specifically the weight All-Gathers. + +### FP8 Recipe +* Rounding: rounding to nearest even +* Precision + * Activations and weights: e4m3fn + * Gradients: e5m2 +* Scaling granularity: per-axis +* Scaling mode: + * static for weights and activations + * dynamic for gradients + +### Convergence +To validate this recipe, we utilized MaxText following the MLPerf Training framework by MLCommons to ensure a reproducible and standardized evaluation. Using the C4 dataset (loaded via TFDS) as the reference corpus, we tracked convergence by monitoring validation loss on a held-out split. This aligns with MLPerf’s time-to-quality principle, where the primary metric is the speed at which the model achieves target quality. + +For this specific case, we derived our training duration from the MLPerf 405B benchmark, targeting roughly 2–3 billion tokens after resuming from a checkpoint. In our configuration, we executed 300 steps with a sequence length of 4096 and a global batch size of 2048, resulting in a total of approximately 2.5 billion tokens. + +### Performance Sensitivity +Please note that the FP8 benefits are highly sensitive to model parameters, the efficiency of the BF16 baseline, and hardware utilization; consequently, results will vary when this recipe is applied to other models. Any variance in these factors shifts the ratio of compute-bound to memory-bound operations, directly altering the potential gains.