[Feature][QDP] Add Google TPU support via JAX/Pallas backend #1156

@400Ping

Summary

Add a TPU backend for QDP using JAX (and Pallas for custom kernels), so encoding workflows can run on Google TPU instead of being tied to CUDA-only infrastructure.

Primary Motivation

We want to avoid NVIDIA vendor lock-in.
Today, core execution paths are CUDA-centric, which limits deployment flexibility, cost control, and hardware strategy.

Scope

  • Add a backend abstraction layer for device-specific implementations.
  • Implement a jax_tpu backend for:
    • amplitude encoding
    • angle encoding
    • basis encoding
  • Add backend selection/configuration in the API (cuda, jax_tpu, ...).
  • Add correctness parity tests against current CUDA outputs (with tolerances).
  • Add docs for setup and backend selection.
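
To make the scope above more concrete, here is a minimal sketch of what a backend abstraction with name-based selection and a tolerance-based parity check could look like. Every name in it (`EncoderBackend`, `register_backend`, `get_backend`, `ReferenceBackend`, `assert_parity`) is hypothetical and not part of the current QDP API. The pure-Python `reference` backend is only a stand-in so the sketch runs without CUDA or a TPU. The angle-encoding convention (`cos(x)|0> + sin(x)|1>` per qubit, qubit 0 as the least significant bit) is an assumption and may differ from what QDP actually uses.

```python
import math
from abc import ABC, abstractmethod
from typing import Callable, Dict, List, Sequence, Type


class EncoderBackend(ABC):
    """Hypothetical device-agnostic interface (not the actual QDP API)."""

    @abstractmethod
    def amplitude(self, data: Sequence[float], num_qubits: int) -> List[float]: ...

    @abstractmethod
    def angle(self, data: Sequence[float], num_qubits: int) -> List[float]: ...

    @abstractmethod
    def basis(self, index: int, num_qubits: int) -> List[float]: ...


_BACKENDS: Dict[str, Type[EncoderBackend]] = {}


def register_backend(name: str) -> Callable[[Type[EncoderBackend]], Type[EncoderBackend]]:
    """Class decorator that makes a backend selectable by name."""
    def decorator(cls: Type[EncoderBackend]) -> Type[EncoderBackend]:
        _BACKENDS[name] = cls
        return cls
    return decorator


def get_backend(name: str) -> EncoderBackend:
    """Instantiate the backend registered under `name`, or raise ValueError."""
    if name not in _BACKENDS:
        raise ValueError(f"unknown backend {name!r}; available: {sorted(_BACKENDS)}")
    return _BACKENDS[name]()


@register_backend("reference")
class ReferenceBackend(EncoderBackend):
    """Pure-Python stand-in; real backends would run on CUDA or TPU."""

    def amplitude(self, data, num_qubits):
        dim = 1 << num_qubits
        if len(data) > dim:
            raise ValueError("input longer than 2**num_qubits")
        padded = list(data) + [0.0] * (dim - len(data))  # zero-pad to 2**n
        norm = math.sqrt(sum(x * x for x in padded))
        if norm == 0.0:
            raise ValueError("cannot normalize an all-zero vector")
        return [x / norm for x in padded]

    def angle(self, data, num_qubits):
        # Assumed convention: qubit k is cos(x_k)|0> + sin(x_k)|1>,
        # with qubit 0 as the least significant bit of the state index.
        if len(data) != num_qubits:
            raise ValueError("angle encoding expects one value per qubit")
        state = [1.0]
        for x in data:
            c, s = math.cos(x), math.sin(x)
            state = [a * c for a in state] + [a * s for a in state]
        return state

    def basis(self, index, num_qubits):
        dim = 1 << num_qubits
        if not 0 <= index < dim:
            raise ValueError("basis index out of range")
        state = [0.0] * dim
        state[index] = 1.0
        return state


def assert_parity(expected, actual, rtol=1e-6, atol=1e-9):
    """Element-wise comparison of two state vectors within tolerances."""
    assert len(expected) == len(actual), "state vectors differ in length"
    for i, (e, a) in enumerate(zip(expected, actual)):
        assert math.isclose(e, a, rel_tol=rtol, abs_tol=atol), (
            f"mismatch at index {i}: expected {e}, got {a}"
        )
```

Under this sketch, a `jax_tpu` backend would be another `EncoderBackend` subclass registered under that name, and a parity test would call `assert_parity(get_backend("cuda").amplitude(x, n), get_backend("jax_tpu").amplitude(x, n))`. The tolerances shown are placeholders; suitable values would depend on the precision each device uses.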
