### A4.2.3. Custom Calls

> *A custom call is an XLA mechanism that invokes an external function (a hand-written kernel or a library routine) from within a compiled HLO graph, bridging the gap between compiler-generated and manually optimized code.*

**Explanation:**

Not every operation can be expressed efficiently as a composition of XLA primitives. **Custom calls** allow the compiled graph to call out to:

- **Vendor libraries**: cuBLAS GEMM, cuDNN convolution, MKL routines.
- **Hand-tuned kernels**: Triton kernels, inline PTX, hand-written CUDA.
- **External C/C++ functions**: FFI (Foreign Function Interface) targets.

**In XLA HLO:**

```
%result = custom-call(%input), custom_call_target="my_kernel",
    backend_config={...}, api_version=API_VERSION_TYPED_FFI
```
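To see which external targets an HLO module calls into, its text form can be scanned for `custom_call_target` attributes. The following is a minimal sketch of that idea; the helper name `find_custom_call_targets` and the sample HLO lines are hypothetical, and real HLO dumps carry many more attributes than this simple pattern looks at.

```python
import re


def find_custom_call_targets(hlo_text):
    """Hypothetical helper: return custom-call target names in order of appearance."""
    # Assumes each target is written as custom_call_target="<name>" in the HLO text.
    return re.findall(r'custom_call_target="([^"]+)"', hlo_text)


# Illustrative HLO text, not the output of a real compilation.
sample_hlo = """
%p0 = f32[128,784] parameter(0)
%p1 = f32[784,256] parameter(1)
%gemm = f32[128,256] custom-call(%p0, %p1), custom_call_target="cublas_gemm"
%out = f32[128,256] custom-call(%gemm), custom_call_target="custom_relu"
"""

targets = find_custom_call_targets(sample_hlo)
print(targets)
```

With JAX, text of this kind can be obtained from `jax.jit(f).lower(*args).compile().as_text()`; whether any custom calls appear in it depends on the backend and on the operations in `f`.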

**In JAX:**

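JAX exposes custom calls through `jax.extend.ffi.ffi_call` (the newer FFI API, which recent JAX releases also provide as `jax.ffi.ffi_call`). The older, pre-FFI route registers targets with `jax.lib.xla_client.register_custom_call_target`.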

Requirements for a custom call:

1. **Registration**: register the function name and pointer with XLA.
2. **Shape inference**: tell the compiler the output shape given the input shapes.
3. **Differentiation rule**: if the op must participate in `jax.grad`, provide a custom VJP.
4. **Batching rule**: if it must work with `jax.vmap`, provide a batching rule.
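The differentiation requirement can be sketched with `jax.custom_vjp`. In the example below, `_kernel_stand_in` is a hypothetical pure-JAX placeholder for the external kernel, so the code runs without any compiled library; a real op would issue the FFI call at that point instead. It is an illustration under those assumptions, not a reference implementation.

```python
import jax
import jax.numpy as jnp


def _kernel_stand_in(x):
    # Assumption: placeholder for the external kernel (a real op would call it via FFI).
    return jnp.maximum(x, 0.0)


@jax.custom_vjp
def custom_relu(x):
    return _kernel_stand_in(x)


def custom_relu_fwd(x):
    # Forward pass: return the output and the residuals the backward pass needs.
    return _kernel_stand_in(x), x


def custom_relu_bwd(x, g):
    # Backward pass: the cotangent flows only where the input was positive.
    return (jnp.where(x > 0, g, 0.0),)


custom_relu.defvjp(custom_relu_fwd, custom_relu_bwd)

x = jnp.array([-1.0, 0.5, 2.0])
grads = jax.grad(lambda v: custom_relu(v).sum())(x)
batched = jax.vmap(custom_relu)(jnp.stack([x, -x]))
print(grads)
print(batched.shape)
```

Because the stand-in is ordinary JAX code, `jax.vmap` works here without extra effort. An opaque external kernel gives JAX nothing to trace, so its batching behavior has to be declared explicitly; recent versions of `ffi_call` accept a `vmap_method` argument for this.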

**Use Cases:**

| Scenario | Why custom call? |
|----------|------------------|
| cuBLAS GEMM | XLA's generated GEMM may be slower than cuBLAS for certain shapes |
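| Sparse ops | XLA has little native support for sparse formats |
| Hardware-specific intrinsics | Reach instructions the compiler may not emit itself, e.g. TPU intrinsics or GPU tensor-core operations |
| Third-party libraries | Reuse existing implementations such as NCCL collectives or CUTLASS kernels |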

**Example:**

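```python
import jax
from jax.extend import ffi

def my_custom_op(x):
    # "my_kernel" must already be registered as an FFI target for the active platform.
    # The output is described by shape and dtype; here it matches the input.
    out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return ffi.ffi_call("my_kernel", out_spec)(x)
```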
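An external kernel cannot be traced, so its output shape and dtype have to be declared up front. A `jax.ShapeDtypeStruct` carries exactly that information without holding any data. As a rough illustration, the sketch below runs `jax.eval_shape` on a pure-JAX stand-in (`gemm_stand_in`, a hypothetical name) to show the kind of output description a custom GEMM call would need to supply; no external kernel is invoked.

```python
import jax
import jax.numpy as jnp


def gemm_stand_in(a, b):
    # Assumption: pure-JAX placeholder for an external GEMM kernel.
    return a @ b


# Abstract inputs: shapes and dtypes only, no arrays are allocated.
a_spec = jax.ShapeDtypeStruct((128, 784), jnp.float32)
b_spec = jax.ShapeDtypeStruct((784, 256), jnp.float32)

# eval_shape evaluates the function abstractly and returns the output spec.
out_spec = jax.eval_shape(gemm_stand_in, a_spec, b_spec)
print(out_spec.shape, out_spec.dtype)
```

For a real custom call there is no JAX function to trace, so the equivalent of `out_spec` is written by hand from the kernel's documented contract.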

```python
from dataclasses import dataclass


@dataclass
class CustomCallTarget:
    """Metadata for a custom call target (illustrative model, not an XLA type)."""
    name: str
    platform: str
    has_grad_rule: bool
    has_batch_rule: bool


def infer_output_shape(target_name, input_shapes):
    # Each rule maps the list of input shapes to the output shape.
    shape_rules = {
        "cublas_gemm": lambda shapes: (shapes[0][0], shapes[1][1]),  # (M, K) x (K, N) -> (M, N)
        "custom_relu": lambda shapes: shapes[0],  # elementwise: shape unchanged
        "sparse_matmul": lambda shapes: (shapes[0][0], shapes[1][1]),
    }
    return shape_rules[target_name](input_shapes)


registered_targets = [
    CustomCallTarget("cublas_gemm", "gpu", has_grad_rule=True, has_batch_rule=True),
    CustomCallTarget("custom_relu", "gpu", has_grad_rule=True, has_batch_rule=True),
    CustomCallTarget("sparse_matmul", "cpu", has_grad_rule=False, has_batch_rule=False),
    CustomCallTarget("nccl_allreduce", "gpu", has_grad_rule=True, has_batch_rule=False),
]

print("Registered custom call targets:")
for target in registered_targets:
    grad = "✓" if target.has_grad_rule else "✗"
    batch = "✓" if target.has_batch_rule else "✗"
    print(f"  {target.name} ({target.platform}) - grad: {grad}, vmap: {batch}")

gemm_output = infer_output_shape("cublas_gemm", [(128, 784), (784, 256)])
relu_output = infer_output_shape("custom_relu", [(128, 256)])

print("\nShape inference:")
print(f"  cublas_gemm((128,784), (784,256)) → {gemm_output}")
print(f"  custom_relu((128,256)) → {relu_output}")

# Hand-written HLO text showing how the two custom calls would chain together.
hlo_representation = [
    '%p0 = f32[128,784] parameter(0)',
    '%p1 = f32[784,256] parameter(1)',
    '%gemm = f32[128,256] custom-call(%p0, %p1), custom_call_target="cublas_gemm"',
    '%out = f32[128,256] custom-call(%gemm), custom_call_target="custom_relu"',
]

print("\nHLO with custom calls:")
for line in hlo_representation:
    print(f"  {line}")
```

**References:**

[📘 JAX Documentation. *FFI: Foreign Function Interface.*](https://jax.readthedocs.io/en/latest/ffi.html)

[📘 XLA Documentation. *Custom Calls.*](https://openxla.org/xla/custom_call)
