Summary
Add a TPU backend for QDP using JAX (and Pallas for custom kernels), so encoding workflows can run on Google TPU instead of being tied to CUDA-only infrastructure.
Primary Motivation
We want to avoid NVIDIA vendor lock-in.
Today, core execution paths are CUDA-centric, which limits deployment flexibility, cost control, and hardware strategy.
Scope
- Add a backend abstraction layer for device-specific implementations.
- Implement a jax_tpu backend for:
  - amplitude encoding
  - angle encoding
  - basis encoding
- Add backend selection/configuration in the API (cuda, jax_tpu, ...).
- Add correctness parity tests against current CUDA outputs (with tolerances).
- Add docs for setup and backend selection.
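To make the scope items above concrete, here is a minimal sketch of what the abstraction layer, backend selection, and a parity check could look like. Everything in it is an assumption for illustration rather than existing QDP API: the names EncoderBackend, ReferenceBackend, JaxTpuBackend, get_backend, and assert_parity are hypothetical, the pure-Python reference backend only stands in for the current CUDA path, and the encoding conventions (zero-padding to 2^n amplitudes, L2 normalization, one-hot basis states) may differ from what QDP actually does. Angle encoding is omitted for brevity.

```python
import math
from typing import Callable, Dict, Protocol, Sequence


class EncoderBackend(Protocol):
    """Hypothetical interface that each device backend would implement."""

    name: str

    def amplitude_encode(self, data: Sequence[float], num_qubits: int) -> Sequence[float]: ...

    def basis_encode(self, index: int, num_qubits: int) -> Sequence[float]: ...


class ReferenceBackend:
    """Pure-Python stand-in for the existing CUDA path (assumption, not QDP code)."""

    name = "reference"

    def amplitude_encode(self, data, num_qubits):
        dim = 1 << num_qubits
        if len(data) > dim:
            raise ValueError(f"{len(data)} values do not fit in {num_qubits} qubits")
        # Zero-pad to 2**num_qubits amplitudes, then normalize to unit L2 norm.
        padded = list(data) + [0.0] * (dim - len(data))
        norm = math.sqrt(sum(x * x for x in padded))
        if norm == 0.0:
            raise ValueError("cannot amplitude-encode an all-zero vector")
        return [x / norm for x in padded]

    def basis_encode(self, index, num_qubits):
        dim = 1 << num_qubits
        if not 0 <= index < dim:
            raise ValueError(f"index {index} out of range for {num_qubits} qubits")
        # One-hot state vector for the computational basis state |index>.
        state = [0.0] * dim
        state[index] = 1.0
        return state


class JaxTpuBackend:
    """Sketch of a JAX backend; JAX is imported lazily so other backends work without it."""

    name = "jax_tpu"

    def __init__(self):
        import jax.numpy as jnp  # only required when this backend is selected

        self._jnp = jnp

    def amplitude_encode(self, data, num_qubits):
        jnp = self._jnp
        dim = 1 << num_qubits
        if len(data) > dim:
            raise ValueError(f"{len(data)} values do not fit in {num_qubits} qubits")
        values = jnp.asarray(data, dtype=jnp.float32)
        # Functional update: write the data into a zero vector, then normalize.
        # Unlike the reference backend, this sketch does not guard against a zero norm.
        padded = jnp.zeros(dim, dtype=jnp.float32).at[: len(data)].set(values)
        return padded / jnp.linalg.norm(padded)

    def basis_encode(self, index, num_qubits):
        jnp = self._jnp
        dim = 1 << num_qubits
        if not 0 <= index < dim:
            raise ValueError(f"index {index} out of range for {num_qubits} qubits")
        return jnp.zeros(dim, dtype=jnp.float32).at[index].set(1.0)


# Hypothetical registry: backend name -> zero-argument factory.
_BACKENDS: Dict[str, Callable[[], EncoderBackend]] = {
    "reference": ReferenceBackend,
    "jax_tpu": JaxTpuBackend,
}


def get_backend(name: str) -> EncoderBackend:
    """Select a backend by name, failing loudly on unknown names."""
    try:
        factory = _BACKENDS[name]
    except KeyError:
        raise ValueError(f"unknown backend {name!r}; available: {sorted(_BACKENDS)}") from None
    return factory()


def assert_parity(actual, expected, atol: float = 1e-6) -> None:
    """Element-wise comparison with an absolute tolerance, for cross-backend tests."""
    assert len(actual) == len(expected), "state vectors differ in length"
    for i, (a, e) in enumerate(zip(actual, expected)):
        assert abs(float(a) - float(e)) <= atol, f"mismatch at index {i}: {a} vs {e}"
```

A parity test under these assumptions would encode the same input with two backends, for example get_backend("reference") and get_backend("jax_tpu"), and pass both results to assert_parity. The tolerance matters because the JAX sketch computes in float32, so its results will generally not be bit-identical to a double-precision reference.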