### A4.1.1. OpenXLA

> *OpenXLA is an open-source ML compiler ecosystem that provides a shared compiler infrastructure (XLA, StableHLO) across frameworks (TensorFlow, JAX, PyTorch) and hardware targets (CPU, GPU, TPU).*

**Explanation:**

The **OpenXLA** project unifies the ML compilation stack, so each framework no longer needs its own backend for every hardware target. This section covers its two core components:

**1. StableHLO (the portable IR):**

StableHLO (Stable High-Level Operations) is an MLIR dialect that serves as the exchange format between ML frameworks and compilers. It provides:

- A **versioned, backward-compatible** operation set.
- Operations for tensor compute: `stablehlo.dot_general`, `stablehlo.convolution`, `stablehlo.reduce`.
- Framework independence: JAX, TensorFlow, and PyTorch/XLA all lower to StableHLO.

**2. XLA (the optimizing compiler):**

XLA (Accelerated Linear Algebra) consumes HLO/StableHLO and produces optimized machine code:

| Phase | Action |
|-------|--------|
| HLO optimization | Algebraic simplification, CSE, constant folding |
| Fusion | Merge elementwise ops into single kernels |
| Layout assignment | Choose memory layout (row-major, tiled) per target |
| Code generation | Emit LLVM IR (CPU), PTX via LLVM (GPU), or HLO → TPU instructions |
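The "HLO optimization" row groups three classic rewrites. The sketch below applies them in a single pass over a toy SSA instruction list:

- Constant folding evaluates ops whose operands are all constants.
- Algebraic simplification rewrites `x * 1` and `x + 0` to `x`.
- Common-subexpression elimination (CSE) reuses an identical earlier instruction.

The instruction tuples and the `optimize` function are hypothetical and support only binary `add` and `multiply`. Real XLA runs separate C++ passes over its `HloModule`. This sketch also does no dead-code elimination, so unused constants remain.

```python
# Sketch: constant folding, algebraic simplification, and CSE on a toy SSA list.
# Instructions are (name, opcode, operands); only "add"/"multiply" are modeled.
def optimize(instrs):
    consts = {}  # name -> known constant value
    alias = {}   # name -> name it was replaced by
    seen = {}    # (opcode, operands) -> first instruction computing it
    out = []
    for name, op, args in instrs:
        if op == "constant":
            consts[name] = args[0]
            out.append((name, op, args))
            continue
        args = tuple(alias.get(a, a) for a in args)  # follow earlier rewrites
        lhs, rhs = args
        # 1. Constant folding: both operands are known constants.
        if lhs in consts and rhs in consts:
            value = consts[lhs] + consts[rhs] if op == "add" else consts[lhs] * consts[rhs]
            consts[name] = value
            out.append((name, "constant", (value,)))
            continue
        # 2. Algebraic simplification: x * 1 -> x, x + 0 -> x.
        identity = 1 if op == "multiply" else 0
        if consts.get(rhs) == identity:
            alias[name] = lhs
            continue
        if consts.get(lhs) == identity:
            alias[name] = rhs
            continue
        # 3. CSE: an identical instruction was already emitted.
        key = (op, args)
        if key in seen:
            alias[name] = seen[key]
            continue
        seen[key] = name
        out.append((name, op, args))
    return out, alias


program = [
    ("one", "constant", (1.0,)),
    ("two", "constant", (2.0,)),
    ("three", "constant", (3.0,)),
    ("six", "multiply", ("two", "three")),  # folds to constant 6.0
    ("t0", "multiply", ("x", "one")),       # simplifies to x
    ("t1", "add", ("t0", "y")),             # becomes add(x, y)
    ("t2", "add", ("x", "y")),              # duplicate of t1, removed by CSE
    ("out", "multiply", ("t2", "six")),     # rewritten to multiply(t1, six)
]
optimized, alias = optimize(program)
for instr in optimized:
    print(instr)
```

Eight instructions shrink to four constants plus `add(x, y)` and `multiply(t1, six)`. `alias` records how the removed names `t0` and `t2` map onto survivors.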

**Architecture:**

```
Framework (JAX / TF / PyTorch)
        ↓
    StableHLO
        ↓
    XLA Compiler
    ├── CPU codegen (LLVM)
    ├── GPU codegen (LLVM → PTX)
    └── TPU codegen
```
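Each backend in the diagram may store the same logical tensor differently, which is the decision made in the layout-assignment phase. XLA describes a dense layout by a `minor_to_major` ordering of dimensions, listed from fastest-varying to slowest-varying. The sketch below turns such an ordering into element strides. The helper names are hypothetical, and the tiled layouts used on TPU are not modeled.

```python
# Sketch: element strides for a dense XLA-style layout (tiling not modeled).
# minor_to_major lists dimensions from fastest-varying to slowest-varying.
from itertools import product


def strides_from_layout(shape, minor_to_major):
    """Stride (in elements) of each logical dimension under the layout."""
    strides = [0] * len(shape)
    step = 1
    for dim in minor_to_major:  # the most-minor dimension gets stride 1
        strides[dim] = step
        step *= shape[dim]
    return tuple(strides)


def linear_offset(index, strides):
    """Physical element offset of a logical multi-dimensional index."""
    return sum(i * s for i, s in zip(index, strides))


shape = (128, 64)
row_major = strides_from_layout(shape, minor_to_major=(1, 0))
col_major = strides_from_layout(shape, minor_to_major=(0, 1))
print("row-major strides:", row_major)     # (64, 1)
print("column-major strides:", col_major)  # (1, 128)

# Both layouts map every logical index to a distinct offset in [0, 128 * 64).
for strides in (row_major, col_major):
    offsets = {linear_offset(idx, strides) for idx in product(*map(range, shape))}
    assert offsets == set(range(128 * 64))
```

The logical shape stays `(128, 64)` in both cases. Only the address arithmetic changes, so layout assignment can be decided per target without altering program semantics.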

**Example:**

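A JAX function computing `jnp.dot(a, b) + c` lowers to `stablehlo.dot_general` followed by `stablehlo.add`. A `stablehlo.broadcast_in_dim` also appears if `c` must be broadcast to the product's shape. XLA can then fuse the add into the matrix multiply, for example as a bias epilogue on a GPU GEMM call. The intermediate product is then never written to memory as a separate buffer.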
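The fusion in this example can be illustrated with two pure-Python stand-ins. The unfused version materializes the whole `dot` result and then makes a second pass to add the bias. The fused version adds the bias in the epilogue while each accumulator is still live. This sketch assumes `c` is a bias vector of length `m` broadcast across rows. The function names are hypothetical, and XLA emits real kernels or library calls rather than Python loops.

```python
# Sketch: unfused vs. fused dot + bias add (pure-Python stand-ins for kernels).
# Assumption: bias is a length-m vector broadcast across the rows of a @ b.
def matmul_then_add(a, b, bias):
    """Unfused: write the full dot result, then re-read it to add the bias."""
    n, k, m = len(a), len(b), len(b[0])
    dot = [[sum(a[i][p] * b[p][j] for p in range(k)) for j in range(m)] for i in range(n)]
    result = [[dot[i][j] + bias[j] for j in range(m)] for i in range(n)]
    return result, n * m  # n * m intermediate elements were materialized


def fused_matmul_add(a, b, bias):
    """Fused: add the bias in the epilogue, so no intermediate buffer exists."""
    n, k, m = len(a), len(b), len(b[0])
    out = [[0.0] * m for _ in range(n)]
    for i in range(n):
        for j in range(m):
            acc = 0.0
            for p in range(k):
                acc += a[i][p] * b[p][j]
            out[i][j] = acc + bias[j]  # epilogue: bias added before the store
    return out, 0


a = [[1.0, 2.0], [3.0, 4.0]]
b = [[5.0, 6.0], [7.0, 8.0]]
bias = [0.5, -1.0]

unfused, unfused_tmp = matmul_then_add(a, b, bias)
fused, fused_tmp = fused_matmul_add(a, b, bias)
print(unfused)                 # [[19.5, 21.0], [43.5, 49.0]]
print(fused == unfused)        # True
print(unfused_tmp, fused_tmp)  # 4 0
```

Both versions compute the same values. The fused one avoids writing and re-reading the `n * m` intermediate, and on accelerators that memory traffic is often the dominant cost of elementwise ops.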

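```python
# Toy model of a StableHLO module. It mirrors the textual IR for illustration
# only and is not the real StableHLO / MLIR Python API.
from dataclasses import dataclass, field


@dataclass
class StableHLOOp:
    name: str                      # op name, e.g. "stablehlo.add"
    operands: list[str]            # SSA values consumed by the op
    result: str                    # SSA value produced by the op
    result_shape: tuple[int, ...]  # static result shape (element type assumed f32)


@dataclass
class HLOModule:
    name: str
    ops: list[StableHLOOp] = field(default_factory=list)

    def add_op(self, op_name: str, operands: list[str], result: str,
               result_shape: tuple[int, ...]) -> StableHLOOp:
        """Append an op to the module and return it."""
        operation = StableHLOOp(op_name, operands, result, result_shape)
        self.ops.append(operation)
        return operation


# jnp.dot(a, b) + c, with %bias standing in for c. %bias is assumed to already
# have shape (128, 64); a 1-D bias would first need stablehlo.broadcast_in_dim.
module = HLOModule("matmul_add")
module.add_op("stablehlo.dot_general", ["%a", "%b"], "%dot", (128, 64))
module.add_op("stablehlo.add", ["%dot", "%bias"], "%result", (128, 64))

print(f"Module: {module.name}")
print(f"Operations: {len(module.ops)}")
for op in module.ops:
    # Render the result type in MLIR tensor syntax, e.g. tensor<128x64xf32>.
    dims = "x".join(str(d) for d in op.result_shape)
    print(f"  {op.result} = {op.name}({', '.join(op.operands)}) : tensor<{dims}xf32>")

# The XLA phases from the table above, applied to this module.
xla_phases = [
    ("HLO optimization", "algebraic simplification, CSE, constant folding"),
    ("Fusion", "merge dot_general + add into fused kernel"),
    ("Layout assignment", "choose row-major for CPU, tiled for GPU"),
    ("Code generation", "emit LLVM IR or PTX"),
]

print("\nXLA compilation pipeline:")
for phase_name, description in xla_phases:
    print(f"  {phase_name}: {description}")

# Per-backend code generation paths (SASS is NVIDIA's native GPU ISA).
targets = {
    "CPU": "LLVM IR → x86/ARM",
    "GPU": "LLVM IR → PTX → SASS",
    "TPU": "HLO → TPU instructions",
}
print("\nBackend targets:")
for target, codegen in targets.items():
    print(f"  {target}: {codegen}")
```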

**References:**

[📘 OpenXLA Project. *OpenXLA: An Open Ecosystem of ML Compilers.*](https://openxla.org/)

[📘 OpenXLA Project. *StableHLO Specification.*](https://openxla.org/stablehlo)

---

[Next: TensorFlow XLA ➡️](./02_tensorflow_xla.ipynb)