### A4.3.3. Dispatch

> *Dispatch is the runtime mechanism that selects and launches the correct kernel implementation for an operation based on the device, data type, layout, and other properties of its inputs.*

**Explanation:**

When a framework executes an operation like `matmul`, the **dispatcher** must resolve which concrete kernel to invoke. This decision depends on multiple axes:

**Dispatch Axes:**

| Axis | Examples |
|------|----------|
| Device | CPU, CUDA, TPU, XPU |
| Dtype | float32, float16, bfloat16, int8 |
| Layout | strided (dense), sparse_coo, sparse_csr |
| Autograd | needs_grad → wrap with gradient tracking |
| Quantization | quantized weight → specialized kernel |

**PyTorch Dispatcher:**

PyTorch uses a **dispatch table** indexed by **dispatch keys** (device, autograd, quantized, etc.). Each op is registered with multiple implementations:

```
aten::matmul
  ├── CPU → mkl_matmul / openblas_matmul
  ├── CUDA → cublas_matmul / cutlass_matmul
  ├── AutogradCPU → autograd wrapper → CPU kernel
  └── QuantizedCPU → quantized_matmul_int8
```

**XLA/JAX Dispatch:**

In JIT-compiled systems (XLA, JAX), dispatch happens at **compile time**: the compiler selects the kernel during lowering/code generation, not at runtime.
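As a rough illustration of compile-time dispatch, the sketch below resolves every kernel once during a "lowering" step and then runs the resulting program with no further table lookups. It is a toy model in plain Python, not how XLA or JAX are implemented; `KERNEL_REGISTRY`, `resolve`, and `compile_program` are hypothetical names introduced for this example.

```python
# Hypothetical kernel registry keyed by (op, device, dtype).
KERNEL_REGISTRY = {
    ("add", "cpu", "float32"): lambda x, y: x + y,
    ("mul", "cpu", "float32"): lambda x, y: x * y,
}

lookups = 0  # counts how often the registry is consulted


def resolve(op, device, dtype):
    """Look up one kernel; this is the only place dispatch happens."""
    global lookups
    lookups += 1
    return KERNEL_REGISTRY[(op, device, dtype)]


def compile_program(ops, device, dtype):
    """'Lowering': resolve every kernel once, ahead of execution."""
    resolved = [resolve(op, device, dtype) for op in ops]

    def run(x, y):
        # No registry lookups here: the kernel sequence is already fixed.
        acc = x
        for kernel in resolved:
            acc = kernel(acc, y)
        return acc

    return run


program = compile_program(["add", "mul"], "cpu", "float32")
outputs = [program(1.0, 2.0) for _ in range(1000)]

print(outputs[0])  # (1.0 + 2.0) * 2.0 = 6.0
print(lookups)     # 2: one lookup per op, regardless of how often the program runs
```

The lookup count stays at two after a thousand runs, which is the sense in which dispatch has moved from run time to compile time.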

**Tradeoffs:**

- **Eager dispatch** (PyTorch): flexible, supports dynamic shapes and control flow, but pays a per-op dispatch overhead on the order of microseconds.
- **Compiled dispatch** (XLA): no per-op dispatch overhead at runtime, but an upfront compilation cost.
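One way to reason about this tradeoff is a break-even estimate: how many calls it takes before the saved per-op overhead repays the compilation cost. The sketch below assumes the compiled and eager paths otherwise spend the same time in kernels (it ignores fusion and other compiler optimizations), and the numbers are hypothetical, chosen only to make the arithmetic concrete.

```python
def break_even_calls(compile_cost_us, per_op_overhead_us, ops_per_call):
    """Calls needed before compilation pays for itself (all times in microseconds)."""
    saved_per_call_us = per_op_overhead_us * ops_per_call
    # Integer ceiling division: round up to a whole number of calls.
    return -(-compile_cost_us // saved_per_call_us)


# Hypothetical workload: 2 s compile time, 5 us dispatch overhead per op,
# 200 ops per call of the model.
calls = break_even_calls(
    compile_cost_us=2_000_000,
    per_op_overhead_us=5,
    ops_per_call=200,
)
print(calls)  # 2000
```

Under these assumed numbers, a function called fewer than 2000 times would be cheaper to run eagerly; a training loop with many more iterations would amortize the compile cost.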

**Example:**

Calling `torch.matmul(a, b)` where `a` is on CUDA and has `requires_grad=True`:
1. Dispatcher checks keys: `AutogradCUDA` → `CUDA`.
2. Autograd wrapper records the op on the tape.
3. CUDA backend dispatches to cuBLAS GEMM.
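
The three steps above can be sketched as a layered dispatcher: a handler for the higher-priority key does its own work and then redispatches with that key removed. This is a simplified model, not PyTorch's actual implementation. The key names come from the example; `HANDLERS`, `dispatch_matmul`, the `tape` list, and the pure-Python matrix multiply standing in for cuBLAS are all assumptions made for illustration.

```python
# Toy model of layered dispatch with redispatch.
PRIORITY = ["AutogradCUDA", "CUDA"]  # highest-priority key first

tape = []  # stand-in for the autograd graph


def cuda_matmul(keys, a, b):
    # Stand-in for a cuBLAS GEMM call: plain Python matrix multiply.
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def autograd_matmul(keys, a, b):
    # Step 2: record the op, then redispatch with the autograd key removed.
    tape.append("matmul")
    return dispatch_matmul(keys - {"AutogradCUDA"}, a, b)


HANDLERS = {"AutogradCUDA": autograd_matmul, "CUDA": cuda_matmul}


def dispatch_matmul(keys, a, b):
    # Steps 1 and 3: run the handler for the highest-priority key present.
    for key in PRIORITY:
        if key in keys:
            return HANDLERS[key](keys, a, b)
    raise KeyError(f"no kernel registered for keys {sorted(keys)}")


result = dispatch_matmul({"AutogradCUDA", "CUDA"}, [[1, 2], [3, 4]], [[5, 6], [7, 8]])
print(result)  # [[19, 22], [43, 50]]
print(tape)    # ['matmul']
```

Calling `dispatch_matmul` with only `{"CUDA"}` in the key set skips the autograd layer entirely, which mirrors what happens for tensors that do not require gradients.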

```python
from dataclasses import dataclass, field


@dataclass
class DispatchKey:
    """Properties of an op's inputs that select a kernel."""
    device: str
    dtype: str
    requires_grad: bool = False

    @property
    def key_tuple(self):
        return (self.device, self.dtype, self.requires_grad)


@dataclass
class DispatchTable:
    """Maps (device, dtype, requires_grad) to a kernel name for one op."""
    op_name: str
    entries: dict = field(default_factory=dict)

    def register(self, device, dtype, requires_grad, kernel_name):
        key = (device, dtype, requires_grad)
        self.entries[key] = kernel_name

    def dispatch(self, dispatch_key):
        key = dispatch_key.key_tuple
        if key in self.entries:
            return self.entries[key]
        # No grad-specific kernel: fall back to the plain kernel
        # registered for the same device and dtype.
        fallback_key = (dispatch_key.device, dispatch_key.dtype, False)
        if fallback_key in self.entries:
            return self.entries[fallback_key]
        raise KeyError(f"{self.op_name}: no kernel registered for {key}")


matmul_table = DispatchTable("aten::matmul")
matmul_table.register("cpu", "float32", False, "mkl_sgemm")
matmul_table.register("cpu", "float64", False, "mkl_dgemm")
matmul_table.register("cuda", "float32", False, "cublas_sgemm")
matmul_table.register("cuda", "float16", False, "cublas_hgemm")
matmul_table.register("cuda", "float32", True, "autograd_cuda_sgemm")
matmul_table.register("cpu", "int8", False, "quantized_matmul_int8")

print(f"Dispatch table: {matmul_table.op_name}")
print(f"Registered kernels: {len(matmul_table.entries)}")
for key, kernel in matmul_table.entries.items():
    device, dtype, grad = key
    grad_str = ", grad" if grad else ""
    print(f"  ({device}, {dtype}{grad_str}) → {kernel}")

test_cases = [
    DispatchKey("cuda", "float32", requires_grad=True),
    DispatchKey("cuda", "float16", requires_grad=False),
    DispatchKey("cpu", "float32", requires_grad=False),
    DispatchKey("cpu", "int8", requires_grad=False),
]

print("\nDispatch resolution:")
for key in test_cases:
    kernel = matmul_table.dispatch(key)
    grad_str = ", grad" if key.requires_grad else ""
    print(f"  matmul({key.device}, {key.dtype}{grad_str}) → {kernel}")
```

**References:**

[📘 Yang, E. Z. *Let's talk about the PyTorch dispatcher.*](http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/)

[📘 PyTorch Documentation. *Extending dispatcher for a new backend.*](https://pytorch.org/tutorials/advanced/extend_dispatcher.html)
