### A4.3.1. Operator Fusion

$$
\text{Bytes}_{\text{saved}} = \sum_{i=1}^{n-1} \text{size}(\text{intermediate}_i) \times 2
$$

where $n$ is the number of fused ops and the factor 2 accounts for one write and one read of each eliminated intermediate.
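
As a worked instance of the formula, consider a chain of $n = 3$ fused ops whose two intermediates are each a $128 \times 256$ tensor. The 4-byte (float32) element size is an assumption made here for illustration:

$$
\text{Bytes}_{\text{saved}} = \sum_{i=1}^{2} (128 \times 256 \times 4) \times 2 = 2 \times 131{,}072 \times 2 = 524{,}288 \text{ bytes} = 512 \text{ KiB}
$$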

**Explanation:**

**Operator fusion** merges multiple operations into a single kernel so that intermediate results stay in registers or shared memory instead of being written to and read from global memory (DRAM/HBM).

**Why Fusion Matters:**

ML workloads are often **memory-bandwidth bound** for elementwise and reduction ops. A chain like `relu(matmul(x, w) + b)` without fusion requires:
- Write the `matmul` result to memory → read it for `add` → write the `add` result → read it for `relu`.

With fusion, a single kernel computes all three ops, and the intermediate values never leave registers or shared memory.
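
The difference can be sketched in plain Python. This is a toy simulation, not GPU code: the function names `unfused_chain` and `fused_chain` are hypothetical, nested lists stand in for global memory, a local variable stands in for a register, and traffic is counted in elements rather than bytes. Reads of the inputs `x`, `w`, and `b` are identical in both versions and are left out of the count.

```python
def count(matrix):
    """Number of elements in a list-of-lists matrix."""
    return sum(len(row) for row in matrix)


def unfused_chain(x, w, b):
    """relu(matmul(x, w) + b) as three passes; each intermediate is stored and re-read."""
    traffic = {"writes": 0, "reads": 0}
    cols = list(zip(*w))

    # Kernel 1: matmul stores its full result.
    mm = [[sum(a * c for a, c in zip(row, col)) for col in cols] for row in x]
    traffic["writes"] += count(mm)

    # Kernel 2: add re-reads the matmul result and stores its own.
    added = [[v + b[j] for j, v in enumerate(row)] for row in mm]
    traffic["reads"] += count(mm)
    traffic["writes"] += count(added)

    # Kernel 3: relu re-reads the add result and stores the final output.
    out = [[max(v, 0.0) for v in row] for row in added]
    traffic["reads"] += count(added)
    traffic["writes"] += count(out)
    return out, traffic


def fused_chain(x, w, b):
    """Same computation in one pass; only the final output is stored."""
    traffic = {"writes": 0, "reads": 0}
    cols = list(zip(*w))
    out = []
    for row in x:
        out_row = []
        for j, col in enumerate(cols):
            acc = sum(a * c for a, c in zip(row, col))  # stays local ("in a register")
            acc = acc + b[j]                            # bias add on the same value
            out_row.append(max(acc, 0.0))               # relu, then one store
            traffic["writes"] += 1
        out.append(out_row)
    return out, traffic


x = [[1.0, -2.0], [3.0, 4.0]]
w = [[1.0, 0.0], [0.0, 1.0]]
b = [0.5, -10.0]
print(unfused_chain(x, w, b))
print(fused_chain(x, w, b))
```

Both functions return the same values; only the number of simulated stores and re-reads differs, which is the quantity the formula at the top of this section measures.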

**XLA Fusion Categories:**

| Type | Description | Example |
|------|-------------|--------|
| Element-wise | All ops are pointwise on same shape | `add → relu → mul` |
| Input fusion | Consumer reads producer's output element-by-element | `broadcast → add` |
| Output fusion | Consumer is fused into the output of an expensive producer such as a matmul | `matmul → bias_add` |
| Loop fusion | Ops share a common iteration space | `transpose → elementwise` |

**Fusion Decisions:**

The compiler uses a **cost model** considering:
- Register pressure: too many fused ops may spill to memory.
- Shared memory limits: the fused kernel must fit its tile data.
- Recomputation vs. materialization: sometimes recomputing is cheaper than storing.
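
The three considerations above can be sketched as a toy decision function. Everything here is hypothetical: `FusionCandidate`, `DeviceLimits`, `should_fuse`, and the default limits are illustrative placeholders, not XLA's actual cost model or any real device's specification.

```python
from dataclasses import dataclass


@dataclass
class FusionCandidate:
    registers_per_thread: int  # estimated register use of the fused kernel
    shared_mem_bytes: int      # tile data the fused kernel must hold
    intermediate_bytes: int    # size of the intermediate that fusion would eliminate
    recompute_flops: int       # extra FLOPs spent recomputing the producer


@dataclass
class DeviceLimits:
    # Placeholder numbers chosen for illustration only.
    max_registers_per_thread: int = 255
    max_shared_mem_bytes: int = 48 * 1024
    bytes_per_second: float = 1.0e12
    flops_per_second: float = 1.0e13


def should_fuse(candidate, device):
    """Return (decision, reason) for a proposed fusion."""
    if candidate.registers_per_thread > device.max_registers_per_thread:
        return False, "register pressure: fused kernel would spill"
    if candidate.shared_mem_bytes > device.max_shared_mem_bytes:
        return False, "shared memory: tile data does not fit"
    # Materializing costs one write plus one read of the intermediate.
    materialize_seconds = 2 * candidate.intermediate_bytes / device.bytes_per_second
    recompute_seconds = candidate.recompute_flops / device.flops_per_second
    if recompute_seconds > materialize_seconds:
        return False, "recomputation is slower than materializing"
    return True, "fuse"


device = DeviceLimits()
cheap = FusionCandidate(64, 16 * 1024, 4 * 1024 * 1024, 1_000_000)
print(should_fuse(cheap, device))
```

A real compiler weighs these factors jointly rather than as sequential hard cutoffs; the sketch only shows how each bullet can turn into a concrete check.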

**Example:**

```
Unfused: matmul(128×784, 784×256) → [128×256 write] → add(bias) → [128×256 write] → relu → [128×256 write]
Fused:   matmul+add+relu → [128×256 write]   (2 intermediate writes eliminated)
```

```python
from dataclasses import dataclass


@dataclass
class Operator:
    name: str
    output_shape: tuple[int, ...]
    element_bytes: int = 4  # 4 bytes per element, e.g. float32

    @property
    def output_bytes(self) -> int:
        total_elements = 1
        for dim in self.output_shape:
            total_elements *= dim
        return total_elements * self.element_bytes


def analyze_fusion(ops):
    """Return (unfused, fused, saved) memory traffic in bytes for a chain of ops.

    Reads of the chain's external inputs are the same with or without
    fusion, so they are not counted here.
    """
    # Each intermediate is written once and read once by the next op.
    intermediate_bytes_saved = sum(
        op.output_bytes * 2
        for op in ops[:-1]
    )
    # The final output is written exactly once in both cases.
    fused_memory_traffic = ops[-1].output_bytes
    # Unfused traffic is the intermediates plus the final write, so that
    # unfused - fused equals the bytes saved, matching the formula above.
    unfused_memory_traffic = intermediate_bytes_saved + fused_memory_traffic

    return unfused_memory_traffic, fused_memory_traffic, intermediate_bytes_saved


ops = [
    Operator("matmul", (128, 256)),
    Operator("bias_add", (128, 256)),
    Operator("relu", (128, 256)),
]

unfused_traffic, fused_traffic, saved = analyze_fusion(ops)

print("Operator chain: " + " → ".join(op.name for op in ops))
print(f"Output shape: {ops[-1].output_shape}")
print(f"\nUnfused memory traffic: {unfused_traffic:,} bytes ({unfused_traffic / 1024:.1f} KB)")
print(f"Fused memory traffic:   {fused_traffic:,} bytes ({fused_traffic / 1024:.1f} KB)")
print(f"Bytes saved by fusion:  {saved:,} bytes ({saved / 1024:.1f} KB)")
print(f"Traffic reduction:      {saved / unfused_traffic:.0%}")

larger_ops = [
    Operator("matmul", (1024, 1024)),
    Operator("bias_add", (1024, 1024)),
    Operator("relu", (1024, 1024)),
    Operator("dropout", (1024, 1024)),
]

unfused_large, fused_large, saved_large = analyze_fusion(larger_ops)
print(f"\nLarger example ({len(larger_ops)} ops, 1024×1024):")
print(f"  Unfused: {unfused_large / (1024**2):.1f} MB")
print(f"  Fused:   {fused_large / (1024**2):.1f} MB")
print(f"  Saved:   {saved_large / (1024**2):.1f} MB ({saved_large / unfused_large:.0%})")
```

**References:**

[📘 OpenXLA Project. *XLA Architecture: Fusion.*](https://openxla.org/xla/architecture)

[📘 Chen, T. et al. (2018). *TVM: An Automated End-to-End Optimizing Compiler for Deep Learning.* OSDI.](https://www.usenix.org/conference/osdi18/presentation/chen)
