In [1]:
import jax
from jax import numpy as jnp
from jax.scipy.linalg import qr

# TLDR
Good idea, accurate, but implementation EXTREMELY memory intensive, using jax.normal works well enough

In [2]:
def get_haar_random_input_state(num_qubits, seed):
    dim = 2**num_qubits  # Dimension of the Hilbert space
    key = jax.random.PRNGKey(seed)

    # 1. Generate a complex matrix where real and imaginary parts are N(0,1)
    # This samples from the Ginibre ensemble.
    complex_matrix_key, qr_key = jax.random.split(key)
    Z = jax.random.normal(complex_matrix_key, (dim, dim)) + \
        1j * jax.random.normal(complex_matrix_key, (dim, dim))

    # 2. Compute a QR decomposition.
    Q, R = qr(Z)

    # 3. Compute the diagonal matrix Lambda from R's diagonal.
    diag_R = jnp.diag(R)
    Lambda = jnp.diag(diag_R / jnp.abs(diag_R))

    # 4. Compute Q' = Q * Lambda, which will be Haar-random.
    U = Q @ Lambda

    # 5. Apply the Haar-random unitary to a fiducial state (e.g., |0...0>).
    # For a pure state, this is usually represented as a column vector.
    initial_state = jnp.zeros(dim, dtype=jnp.complex64)
    initial_state = initial_state.at[0].set(1.0) # |0...0>

    # Apply the unitary to the initial state
    random_state = U @ initial_state

    return random_state

def get_haar_random_input_data(num_qubits, num_vals, seed=0):
    """Generate Haar-random input data for a Pennylane circuit."""
    key = jax.random.PRNGKey(seed)
    return jax.vmap(get_haar_random_input_state, in_axes=(None, 0))(num_qubits, jax.random.split(key, num_vals).flatten())


In [3]:
NUM_QUBITS = 5
NUM_VALS = 10

input_data = get_haar_random_input_data(NUM_QUBITS, NUM_VALS)
input_data

Array([[-8.75854492e-02-8.75854269e-02j,  3.60251800e-03+3.60251823e-03j,
        -3.80960107e-02-3.80960107e-02j,  1.27542362e-01+1.27542347e-01j,
        -2.94704437e-01-2.94704437e-01j,  2.25514341e-02+2.25514323e-02j,
         1.20910294e-01+1.20910302e-01j,  3.65596890e-01+3.65596920e-01j,
         1.18467575e-02+1.18467584e-02j,  1.43082382e-03+1.43082382e-03j,
         1.17756777e-01+1.17756769e-01j,  5.60371242e-02+5.60371205e-02j,
         9.00723157e-04+9.00723157e-04j, -2.44235969e-03-2.44235946e-03j,
        -9.29441899e-02-9.29441899e-02j, -2.21627474e-01-2.21627474e-01j,
         9.59815010e-02+9.59814936e-02j, -1.48696993e-02-1.48696983e-02j,
         1.14258923e-01+1.14258923e-01j,  4.14645858e-02+4.14645895e-02j,
        -2.38429010e-02-2.38429010e-02j,  8.98373201e-02+8.98373201e-02j,
         6.64911047e-02+6.64911121e-02j, -2.84422822e-02-2.84422822e-02j,
        -1.26324713e-01-1.26324713e-01j, -1.00990152e-02-1.00990152e-02j,
         9.51910838e-02+9.51910764e-02