### A4.1.2. TensorFlow XLA

> *XLA (Accelerated Linear Algebra) is TensorFlow's domain-specific compiler that fuses operations and generates optimized code for CPUs, GPUs, and TPUs from TensorFlow graphs.*

**Explanation:**

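In TensorFlow, XLA compiles subgraphs of the computation graph into fused kernels. Without XLA, each TF op dispatches a separate precompiled kernel, and every intermediate result is written to device memory and read back by the next op. With XLA, a cluster of ops is compiled together, so fusible ops share one kernel and their intermediates never leave registers, which eliminates that intermediate memory traffic.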
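To make the memory-traffic argument concrete, here is a back-of-the-envelope sketch. It is not a measurement of XLA: the function name `elementwise_chain_traffic` and the cost model (every unfused op reads its input and writes its output once, a fused chain reads and writes once in total) are simplifying assumptions for illustration.

```python
def elementwise_chain_traffic(num_elements, num_ops, bytes_per_element=4):
    """Return (unfused_bytes, fused_bytes) for a chain of unary elementwise ops.

    Assumed cost model (illustrative only, not measured):
    - unfused: each op reads its input from memory and writes its output back
    - fused:   one read of the original input, one write of the final result
    """
    tensor_bytes = num_elements * bytes_per_element
    unfused_bytes = num_ops * 2 * tensor_bytes
    fused_bytes = 2 * tensor_bytes
    return unfused_bytes, fused_bytes


# A chain of 3 elementwise ops over one million float32 values.
unfused, fused = elementwise_chain_traffic(num_elements=1_000_000, num_ops=3)
print(f"unfused: {unfused / 1e6:.0f} MB moved")
print(f"fused:   {fused / 1e6:.0f} MB moved")
print(f"reduction: {unfused / fused:.0f}x")
```

Under this model the saving grows linearly with the length of the fused chain, which is why long elementwise sequences tend to benefit most.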

**How TF Invokes XLA:**

| Mechanism | Description |
|-----------|------------|
| `tf.function(jit_compile=True)` | Explicitly compile the traced function with XLA |
| Auto-clustering | TF identifies XLA-compatible subgraphs (clusters) and compiles each one with XLA; off by default |
| `TF_XLA_FLAGS=--tf_xla_auto_jit=2` | Environment variable that turns on auto-clustering for every eligible cluster (`1` is a more conservative setting) |
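The environment-variable route can be sketched as below. The helper `xla_flags` is a hypothetical convenience function, not part of TensorFlow. The `--tf_xla_cpu_global_jit` flag is the one the TensorFlow XLA documentation lists for extending auto-clustering to CPU; treat its exact behaviour as something to confirm for your TF version. The variable has to be set before TensorFlow is imported, so the sketch deliberately does not import it.

```python
import os


def xla_flags(auto_jit_level=2, include_cpu=False):
    """Build a TF_XLA_FLAGS string (hypothetical helper, for illustration)."""
    flags = [f"--tf_xla_auto_jit={auto_jit_level}"]
    if include_cpu:
        # Assumption: per the TF XLA docs, CPU auto-clustering needs this extra flag.
        flags.append("--tf_xla_cpu_global_jit")
    return " ".join(flags)


# Set the variable first; only import tensorflow afterwards (not done here).
os.environ["TF_XLA_FLAGS"] = xla_flags(auto_jit_level=2)
print(os.environ["TF_XLA_FLAGS"])
```

Setting the variable in the shell that launches the program achieves the same thing and avoids any import-order concerns.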

**XLA Compilation Flow in TF:**

```
tf.function → TF Graph → XLA HLO → XLA optimizations → LLVM IR → machine code
```

**Key Optimizations:**

- **Op fusion**: chains of elementwise ops become one kernel (e.g., the bias add and `relu` in `relu(matmul(x, w) + b)`).
- **Buffer assignment**: aliases output buffers to inputs when safe, avoiding copies.
- **Layout optimization**: chooses hardware-preferred data layouts and inserts transposes where needed (e.g., NHWC → NCHW when the GPU's convolution kernels prefer it).
- **Constant folding**: evaluates ops with known inputs at compile time.
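Of the optimizations above, constant folding is the easiest to show in a few lines. The following is a toy sketch over a tiny expression tree. The `Const`, `Param`, and `BinOp` types and the `fold` function are invented for illustration and are unrelated to XLA's actual HLO data structures or pass implementation.

```python
import operator
from dataclasses import dataclass
from typing import Union


@dataclass(frozen=True)
class Const:
    value: float


@dataclass(frozen=True)
class Param:
    name: str  # a runtime input, unknown at compile time


@dataclass(frozen=True)
class BinOp:
    op: str  # "add" or "mul"
    lhs: "Expr"
    rhs: "Expr"


Expr = Union[Const, Param, BinOp]
_OPS = {"add": operator.add, "mul": operator.mul}


def fold(expr):
    """Recursively replace subtrees whose inputs are all constants."""
    if not isinstance(expr, BinOp):
        return expr
    lhs, rhs = fold(expr.lhs), fold(expr.rhs)
    if isinstance(lhs, Const) and isinstance(rhs, Const):
        # Both operands are known now, so evaluate at "compile time".
        return Const(_OPS[expr.op](lhs.value, rhs.value))
    return BinOp(expr.op, lhs, rhs)


# x * (2 + 3) folds to x * 5; the multiply stays because x is a runtime value.
expr = BinOp("mul", Param("x"), BinOp("add", Const(2.0), Const(3.0)))
folded = fold(expr)
print(folded)
```

The multiply survives because one operand is only known at run time; a real compiler applies the same bottom-up idea across a much larger op set.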

**Limitations:**

- XLA compiles for static shapes, so dynamic shapes require recompilation for each distinct shape.
- Not all TF ops have XLA lowerings (e.g., `tf.py_function`).
- Compilation adds latency on first call.
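The recompilation and first-call costs can be pictured as a cache keyed on input shape. The sketch below is a pure-Python stand-in: `ShapeKeyedCompiler` is a hypothetical class, its "compile" step merely returns a Python closure, and it assumes rectangular nested lists. It only mirrors the caching behaviour, not how TensorFlow or XLA implement it.

```python
class ShapeKeyedCompiler:
    """Toy model of shape-specialized compilation (illustrative only)."""

    def __init__(self):
        self._cache = {}
        self.compile_count = 0

    def _compile(self, shape):
        # Stand-in for an expensive compile; a real compiler would emit
        # code specialized to `shape`.
        self.compile_count += 1
        return lambda rows: [[max(v, 0.0) for v in row] for row in rows]

    def __call__(self, rows):
        shape = (len(rows), len(rows[0]) if rows else 0)
        if shape not in self._cache:
            # First call with this shape pays the compile cost.
            self._cache[shape] = self._compile(shape)
        return self._cache[shape](rows)


relu = ShapeKeyedCompiler()
relu([[1.0, -2.0]])    # shape (1, 2): compiles
relu([[-3.0, 4.0]])    # same shape: cache hit
relu([[1.0], [-1.0]])  # shape (2, 1): compiles again
print(f"compilations: {relu.compile_count}")
```

A workload with many distinct shapes (for example, unpadded variable-length batches) would keep missing the cache in this model, which is the situation the first limitation warns about.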

**Example:**

```python
@tf.function(jit_compile=True)
def fused_layer(x, w, b):
    return tf.nn.relu(tf.matmul(x, w) + b)
```

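Without XLA: 3 kernel launches (matmul, add, relu). With XLA: ideally 1 fused kernel. In practice the count depends on the backend; on GPU, for instance, the matmul may remain a separate library call while the add and relu are fused.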

```python
from dataclasses import dataclass, field


@dataclass
class TFOp:
    """One node of a TF graph: op name, input value names, output value name."""
    name: str
    inputs: list[str]
    output: str


@dataclass
class XLACluster:
    """Toy model of a group of ops that XLA compiles together."""
    ops: list[TFOp] = field(default_factory=list)

    @property
    def kernel_count_unfused(self):
        # Without XLA, each op launches its own kernel.
        return len(self.ops)

    @property
    def kernel_count_fused(self):
        # Idealized: assumes the whole cluster fuses into a single kernel.
        return 1

    @property
    def memory_roundtrips_saved(self):
        # Each intermediate between consecutive ops no longer goes through memory.
        return len(self.ops) - 1


graph_ops = [
    TFOp("MatMul", ["%x", "%w"], "%mm"),
    TFOp("BiasAdd", ["%mm", "%b"], "%add"),
    TFOp("Relu", ["%add"], "%out"),
]

cluster = XLACluster(ops=graph_ops)

print("TF Graph (unfused):")
for op in graph_ops:
    print(f"  {op.output} = {op.name}({', '.join(op.inputs)})")

print(f"\nWithout XLA: {cluster.kernel_count_unfused} kernel launches")
print(f"With XLA:    {cluster.kernel_count_fused} fused kernel")
print(f"Memory roundtrips eliminated: {cluster.memory_roundtrips_saved}")

compilation_mechanisms = [
    ("tf.function(jit_compile=True)", "explicit, per-function"),
    ("auto-clustering", "runtime identifies fusible subgraphs"),
    ("TF_XLA_FLAGS=--tf_xla_auto_jit=2", "aggressive global auto-clustering"),
]

print("\nXLA activation methods:")
for method, description in compilation_mechanisms:
    print(f"  {method}: {description}")
```

**References:**

[📘 TensorFlow. *XLA: Optimizing Compiler for Machine Learning.*](https://www.tensorflow.org/xla)

[📘 OpenXLA Project. *XLA Architecture.*](https://openxla.org/xla/architecture)

---

[⬅️ Previous: OpenXLA](./01_openxla.ipynb) | [Next: JAX Compilation ➡️](./03_jax_compilation.ipynb)