### 2.3.7  Mixed-Precision Training (Megatron-style)

---

#### Why mix precisions at all?

1. **Memory & speed**

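   * Model weights, activations, and gradients are stored in **fp16** (half-precision). Relative to fp32 this halves their memory footprint and bandwidth, and it lets matrix multiplies run on fast fp16 hardware such as tensor cores.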
   * A second copy of the weights is kept in **fp32** (“master weights”) so that the **Adam** update remains numerically stable.

2. **Gradient underflow protection**

   * Many gradient values (in some models, a majority) have magnitudes below 2⁻²⁴ ≈ 6 × 10⁻⁸, the smallest positive (subnormal) fp16 number. A pure-fp16 run rounds them to zero, so many updates are silently lost and training can stall or diverge.
   * Mixed precision combined with loss scaling keeps the math in range: forward/backward in fp16, precision-critical updates in fp32.
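Both failure modes can be reproduced on a CPU. The sketch below is a minimal illustration, assuming Python's `struct` formats `"e"` (IEEE half) and `"f"` (IEEE single) as stand-ins for GPU fp16/fp32 rounding. The helpers `to_fp16`/`to_fp32` and the example values are hypothetical choices, not part of any framework.

```python
import math
import struct


def to_fp16(x: float) -> float:
    """Round x to the nearest IEEE fp16 value; overflow becomes +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:  # struct refuses finite values above the fp16 range
        return math.copysign(math.inf, x)


def to_fp32(x: float) -> float:
    """Round x to the nearest IEEE fp32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


# (a) Underflow: a gradient below 2**-24 rounds to zero in fp16 ...
grad = 1e-8
grad_fp16 = to_fp16(grad)                  # 0.0, the update is lost
# ... but survives if it is scaled up before the cast and unscaled in fp32.
scale = 2.0 ** 16
recovered = to_fp32(to_fp16(grad * scale) / scale)

# (b) Swamping: a small update vanishes when added to an fp16 weight,
# because the fp16 spacing near 1.0 is 2**-10 (about 1e-3).
w16, w32 = 1.0, 1.0
for _ in range(1000):
    w16 = to_fp16(w16 + 1e-4)              # rounds back to 1.0 every time
    w32 = to_fp32(w32 + 1e-4)              # an fp32 master weight accumulates

print(grad_fp16, recovered, w16, w32)
```

Case (b) is why the update is applied to fp32 master weights rather than to the fp16 copy.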

---

#### End-to-end workflow

| Step                       | What happens                                                                                                   | Precision               |
| -------------------------- | -------------------------------------------------------------------------------------------------------------- | ----------------------- |
| **1. Copy & cast**         | At launch, duplicate each fp32 parameter into fp16. Adam’s **momentum** & **variance** states remain fp32.     | fp32 → fp16 copy        |
| **2. Forward pass (FWD)**  | Run the model with **fp16 weights & activations**.                                                             | fp16                    |
| **3. Loss scaling**        | Multiply the fp32 loss by $2^{\text{loss\_scale}}$ so that small gradients do not underflow in fp16.           | fp32                    |
| **4. Backward pass (BWD)** | Compute **scaled gradients**; store them in fp16 to save memory.                                               | fp16                    |
| **5. Unscale gradients**   | Convert to fp32, divide by $2^{\text{loss\_scale}}$, and check for inf / NaN.                                 | fp32                    |
| **6. Clip / regularise**   | Apply gradient-norm clipping, weight-decay, etc., on the fp32 gradients.                                       | fp32                    |
| **7. Parameter update**    | `optimizer.step()` updates **master fp32 weights**, then casts the result back to fp16 for the next iteration. | fp32 update → fp16 copy |

> **Diagram:** The attached flowchart (green = fp32, purple = fp16) follows exactly these seven stages.
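The seven stages can be traced end to end on a toy problem. This is a minimal sketch under stated assumptions, not Megatron's implementation: a hypothetical two-weight linear model with loss $\tfrac12(w \cdot x - y)^2$, fp16/fp32 rounding simulated with `struct`, a fixed loss scale of 2⁸, and weight decay omitted. `mixed_precision_step` and its hyperparameters are illustrative names and values.

```python
import math
import struct


def to_fp16(x: float) -> float:
    """Round to the nearest IEEE fp16 value; overflow becomes +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


def to_fp32(x: float) -> float:
    """Round to the nearest IEEE fp32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def mixed_precision_step(master, m, v, x, y, t, loss_scale=2.0 ** 8,
                         lr=1e-2, clip=1.0, beta1=0.9, beta2=0.999, eps=1e-8):
    """One step of the recipe for the toy loss 0.5*(w.x - y)**2.

    `master`, `m`, `v` are fp32 lists updated in place; returns (applied, loss).
    """
    # 1. Copy & cast: fp16 working copy of the fp32 master weights.
    w16 = [to_fp16(w) for w in master]
    x16 = [to_fp16(xi) for xi in x]
    # 2. Forward pass: every product and partial sum is rounded to fp16.
    pred = 0.0
    for wi, xi in zip(w16, x16):
        pred = to_fp16(pred + to_fp16(wi * xi))
    err = to_fp16(pred - y)
    loss = to_fp32(0.5 * err * err)              # loss is kept in fp32
    # 3./4. Backward of loss * loss_scale; gradients are stored in fp16.
    d_pred = to_fp16(loss_scale * err)
    g16 = [to_fp16(d_pred * xi) for xi in x16]
    # 5. Unscale in fp32; inf/NaN means the scale was too large, so skip.
    if not all(math.isfinite(g) for g in g16):
        return False, loss
    g32 = [to_fp32(g / loss_scale) for g in g16]
    # 6. Global L2-norm clipping on the unscaled fp32 gradients.
    norm = math.sqrt(sum(g * g for g in g32))
    if norm > clip:
        g32 = [g * clip / norm for g in g32]
    # 7. Adam update of the fp32 master weights and fp32 moments;
    #    the next call re-casts `master` to fp16 (step 1).
    for i, g in enumerate(g32):
        m[i] = to_fp32(beta1 * m[i] + (1 - beta1) * g)
        v[i] = to_fp32(beta2 * v[i] + (1 - beta2) * g * g)
        m_hat = m[i] / (1 - beta1 ** t)
        v_hat = v[i] / (1 - beta2 ** t)
        master[i] = to_fp32(master[i] - lr * m_hat / (math.sqrt(v_hat) + eps))
    return True, loss


master, m, v = [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]   # persistent fp32 state
losses, applied_steps = [], 0
for t in range(1, 301):
    applied, loss = mixed_precision_step(master, m, v, x=[1.0, 2.0], y=3.0, t=t)
    losses.append(loss)
    applied_steps += applied
print(losses[0], losses[-1], master)
```

Only the fp32 master weights, momentum and variance persist between calls. The fp16 copy is rebuilt from them at the start of every step.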

---

#### Memory breakdown (ϕ = number of parameters)

| Category                   | Tensor (precision)  | Footprint (bytes) |
| -------------------------- | ------------------- | ----------------- |
| **Must-have**              | master param (fp32) | 4 ϕ               |
|                            | momentum (fp32)     | 4 ϕ               |
|                            | variance (fp32)     | 4 ϕ               |
| **Working copies**         | param copy (fp16)   | 2 ϕ               |
|                            | gradients (fp16)    | 2 ϕ               |
| **Total (no activations)** |                     | **16 ϕ**          |

*If you cast the fp16 weights on-the-fly instead of storing them persistently, the extra 2 ϕ vanish, but most frameworks keep the copy for speed.*
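The table reduces to simple arithmetic. A minimal sketch: the dictionary keys, the function name `model_state_bytes` and the 1.5B-parameter example are illustrative choices, and activation memory is deliberately excluded, as in the table.

```python
# Bytes per parameter for each model-state tensor in the table above.
BYTES_PER_PARAM = {
    "master_fp32": 4,
    "adam_momentum_fp32": 4,
    "adam_variance_fp32": 4,
    "param_copy_fp16": 2,
    "grad_fp16": 2,
}


def model_state_bytes(num_params: int, keep_fp16_copy: bool = True) -> int:
    """Model-state memory (no activations) for mixed-precision Adam."""
    per_param = sum(BYTES_PER_PARAM.values())          # 16 bytes
    if not keep_fp16_copy:                              # cast on the fly instead
        per_param -= BYTES_PER_PARAM["param_copy_fp16"]
    return per_param * num_params


# A 1.5B-parameter model: 16 bytes/param -> 24 GB before any activations.
print(model_state_bytes(1_500_000_000) / 1e9, "GB")
```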

---

#### Loss-scale mechanics

1. **Static loss scale (cheap but brittle)**

   * Pick a constant power-of-2 scale (e.g. 2¹⁶).
   * *Scale-up*: multiply the loss before BWD.
   * *Scale-down*: divide gradients after BWD.
   * If any **inf / NaN** appears, skip the update for that step.

2. **Dynamic loss scale (AMP default)**

   * Begin with a large scale (e.g. 2²⁴).
   * After each step, check the fp16 gradients for **inf / NaN**:
     * If clean → **increase** (typically double) the scale after every *N* consecutive clean steps.
     * If overflow → **halve** the scale and skip the update for that step, leaving the weights unchanged.

This keeps the scale, and therefore the gradients, as large as possible without overflowing fp16.
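Both schemes fit in one small class. This is a hypothetical sketch (`DynamicLossScaler` is not a library class). The assumed growth/backoff defaults (×2 after `growth_interval` clean steps, ×0.5 on overflow) match PyTorch's `GradScaler` defaults, while `init_scale` follows the 2²⁴ example above. The overflow check runs on the still-scaled gradients.

```python
import math


class DynamicLossScaler:
    """Hypothetical minimal dynamic loss scaler (illustration only)."""

    def __init__(self, init_scale=2.0 ** 24, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._clean_steps = 0          # consecutive steps without inf/NaN

    def scale_loss(self, loss: float) -> float:
        return loss * self.scale       # scale-up before BWD

    def unscale(self, grads):
        return [g / self.scale for g in grads]   # scale-down after BWD

    def update(self, scaled_grads) -> bool:
        """Return True if the optimizer step may be applied."""
        if not all(math.isfinite(g) for g in scaled_grads):
            # Overflow: shrink the scale and skip this step.
            self.scale *= self.backoff_factor
            self._clean_steps = 0
            return False
        self._clean_steps += 1
        if self._clean_steps == self.growth_interval:
            # N clean steps in a row: try a larger scale.
            self.scale *= self.growth_factor
            self._clean_steps = 0
        return True


# Usage: one overflow, then three clean steps with growth_interval=3.
scaler = DynamicLossScaler(init_scale=2.0 ** 24, growth_interval=3)
print(scaler.update([math.inf, 1.0]), scaler.scale)   # False 8388608.0
for _ in range(3):
    scaler.update([1.0, 2.0])
print(scaler.scale)                                   # 16777216.0
```

Setting `growth_factor=1.0` and `backoff_factor=1.0` turns this into the static scheme: the scale never changes, but overflowing steps are still skipped.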

---

#### Gradient clipping

Use global L2-norm clipping on the **unscaled fp32 gradients**. For a loss $J$ and a toy model with two weights $w_1, w_2$:

$$
g_1 = \frac{\partial J}{\partial w_1},\;
g_2 = \frac{\partial J}{\partial w_2},\;
\|g\|_2 = \sqrt{g_1^2 + g_2^2}.
$$

If $\|g\|_2 > c$ (threshold), rescale: $g \leftarrow \tfrac{c}{\|g\|_2} \, g$.
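A minimal sketch of global-norm clipping over a flat list of gradient values, reproducing the two-weight example with $g_1 = 3$, $g_2 = 4$ and $c = 1$. The function name `clip_by_global_norm` is illustrative, not a specific library's API.

```python
import math


def clip_by_global_norm(grads, c):
    """Rescale every component by c / ||g||_2 when ||g||_2 > c."""
    norm = math.sqrt(sum(g * g for g in grads))
    if norm > c:
        grads = [g * c / norm for g in grads]
    return grads, norm


# g1 = dJ/dw1 = 3, g2 = dJ/dw2 = 4  ->  ||g||_2 = 5 > c = 1
clipped, norm = clip_by_global_norm([3.0, 4.0], c=1.0)
print(norm, clipped)   # 5.0 [0.6, 0.8]
```

Because every component is multiplied by the same factor, clipping shrinks the step but keeps its direction.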

---

#### Key take-aways

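* **Two copies, two precisions:** fp16 weights for fast compute, fp32 master weights and optimizer states for accurate updates.
* **Loss scaling is essential in fp16** to keep small gradients from underflowing.
* **Dynamic scaling + gradient clipping** make training far more robust: overflowing steps are skipped and exploding gradients are bounded.
* Model-state memory is **16 ϕ** bytes, about **2×** an all-fp16 Adam setup (8 ϕ) and the same as full-fp32 Adam (16 ϕ). The savings over fp32 come from halved activation memory and faster fp16 math, while the fp32 copies avoid the accuracy pitfalls of full half-precision training.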
