```
python3 -m MaxText.train src/MaxText/configs/base.yml run_name=$YOUR_JOB_NAME ba
```
Note that `use_qwix_quantization` is not set to `True`.

For further reading, please refer to the [Qwix Read the Docs website](https://qwix.readthedocs.io/en/latest/get_started.html#).

## DeepSeek V3 Fine-tuning FP8 Recipe
To improve the performance of DeepSeek V3 fine-tuning, we developed a custom FP8 recipe optimized for throughput. The recipe targets specific compute-intensive and bandwidth-heavy components while preserving training stability through a fine-grained scaling strategy.

### Quantization Scope
To realize these gains, the recipe employs a w8a8g8 (8-bit weights, activations, and gradients) strategy targeting three primary areas, illustrated in the sketch after this list:

* Megablox Kernels: Specifically the `gmm` and `tgmm` operations.

* Attention Projections: Utilizing convolution fusion.

* Communication: Specifically the weight All-Gathers.
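
The gains come from replacing BF16 matmuls on these paths with FP8 ones. Below is a minimal, illustrative JAX sketch of a w8a8 matmul with per-axis scaling, of the kind the Megablox `gmm` path performs per expert; the shapes, names, and scale choice are assumptions for illustration, not the MaxText implementation.

```python
import jax
import jax.numpy as jnp

def fp8_matmul(x: jnp.ndarray, w: jnp.ndarray) -> jnp.ndarray:
  # Per-axis scales: one per activation row and one per weight column.
  x_amax = jnp.maximum(jnp.max(jnp.abs(x), axis=-1, keepdims=True), 1e-12)
  w_amax = jnp.maximum(jnp.max(jnp.abs(w), axis=0, keepdims=True), 1e-12)
  x_scale = x_amax / 448.0  # 448 = largest finite float8_e4m3fn value
  w_scale = w_amax / 448.0
  # Cast to FP8; XLA rounds to nearest even on the conversion.
  x_q = (x / x_scale).astype(jnp.float8_e4m3fn)
  w_q = (w / w_scale).astype(jnp.float8_e4m3fn)
  # Upcast for the dot so the sketch runs on any backend; a real kernel
  # keeps the operands in FP8 and accumulates in a wider type.
  out = jnp.dot(x_q.astype(jnp.bfloat16), w_q.astype(jnp.bfloat16))
  # Fold the scales back in (dequantize the result).
  return out * x_scale * w_scale

x = jax.random.normal(jax.random.PRNGKey(0), (8, 64))
w = jax.random.normal(jax.random.PRNGKey(1), (64, 32))
print(fp8_matmul(x, w).shape)  # (8, 32)
```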

### FP8 Recipe
* Rounding: round-to-nearest-even
* Precision:
  * Activations and weights: `e4m3fn`
  * Gradients: `e5m2`
* Scaling granularity: per-axis
* Scaling mode:
  * Static for weights and activations
  * Dynamic for gradients (see the sketch after this list)
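
The static/dynamic split is the load-bearing part of the recipe: weight and activation scales can be fixed ahead of time, while gradient scales must track each step's tensor. Below is a minimal sketch of how that split could look, assuming a hypothetical precomputed `calibrated_scale`; this is not MaxText or Qwix code.

```python
import jax.numpy as jnp

E4M3_MAX = 448.0    # largest finite float8_e4m3fn value
E5M2_MAX = 57344.0  # largest finite float8_e5m2 value

def quantize_static(w: jnp.ndarray, calibrated_scale: jnp.ndarray):
  # Weights/activations: the scale is fixed ahead of time (static), and
  # the cast to e4m3fn rounds to nearest even.
  return (w / calibrated_scale).astype(jnp.float8_e4m3fn), calibrated_scale

def quantize_dynamic(g: jnp.ndarray, axis: int = -1):
  # Gradients: recompute a per-axis scale from the current tensor every
  # step (dynamic) and cast to e5m2, whose wider exponent range tolerates
  # the larger dynamic range of gradients.
  amax = jnp.maximum(jnp.max(jnp.abs(g), axis=axis, keepdims=True), 1e-12)
  scale = amax / E5M2_MAX
  return (g / scale).astype(jnp.float8_e5m2), scale
```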

### Convergence
To validate this recipe, we utilized MaxText, following the MLCommons MLPerf Training methodology to ensure a reproducible and standardized evaluation. Using the C4 dataset (loaded via TFDS) as the reference corpus, we tracked convergence by monitoring validation loss on a held-out split. This aligns with MLPerf's time-to-quality principle, where the primary metric is the speed at which the model reaches a target quality.
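
As a concrete reference point, here is a minimal sketch of pulling the C4 validation split through TFDS for the held-out loss measurement; the split name and the absence of preprocessing are assumptions, not the exact MLPerf harness.

```python
import tensorflow_datasets as tfds

# Load the held-out split of C4 (English config) without shuffling, so the
# validation loss is computed over a fixed, reproducible stream.
val_ds = tfds.load("c4/en", split="validation", shuffle_files=False)
for example in val_ds.take(1):
  print(example["text"])
```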

For this specific case, we derived our training duration from the MLPerf 405B benchmark, targeting roughly 2–3 billion tokens after resuming from a checkpoint. In our configuration, we executed 300 steps with a sequence length of 4096 and a global batch size of 2048, resulting in a total of approximately 2.5 billion tokens.
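
The token total follows directly from the step count, sequence length, and global batch size:

```python
# Token budget for the configuration above.
steps, seq_len, global_batch = 300, 4096, 2048
tokens = steps * seq_len * global_batch
print(f"{tokens / 1e9:.2f}B tokens")  # 2.52B
```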

### Performance Sensitivity
Please note that the FP8 benefits are highly sensitive to model parameters, the efficiency of the BF16 baseline, and hardware utilization; consequently, results will vary when this recipe is applied to other models. Any variance in these factors shifts the ratio of compute-bound to memory-bound operations, directly altering the potential gains.