In [19]:
"""
Implementation of OASIS fidelity estimation.
"""

import itertools
import numpy as np
from scipy.optimize import linprog


def haar_random_state(dim, rng=None):
    """
    Draw a random normalized complex vector of length `dim`.

    The real and imaginary parts are independent standard-normal draws
    (real part first, then imaginary part), and the resulting vector is
    rescaled to unit Euclidean norm.

    dim: length of the state vector
    rng: optional numpy Generator; a fresh default_rng() is used when None
    returns: complex128 array of shape (dim,)
    """
    generator = np.random.default_rng() if rng is None else rng
    real_part = generator.normal(size=(dim,))
    imag_part = generator.normal(size=(dim,))
    amplitudes = real_part + 1j * imag_part
    normalized = amplitudes / np.linalg.norm(amplitudes)
    return normalized.astype(np.complex128)


def target_and_rho(n, fidelity=0.9, rng=None):
    """
    Build a random pure n-qubit target and a depolarized version of it.

        O = |ψ><ψ|
        ρ = fidelity * O + (1 - fidelity) * I / 2^n

    Because ψ is normalized, O is a rank-one projector and therefore
    tr(ρ O) = fidelity + (1 - fidelity) / 2^n; `fidelity` is the mixing
    weight of the target, not tr(ρ O) itself.

    n: number of qubits
    fidelity: weight of the pure target in the mixture
    rng: optional numpy Generator; a fresh default_rng() is used when None
    returns: (psi, O, rho)
    """
    generator = rng if rng is not None else np.random.default_rng()
    dim = 2 ** n

    psi = haar_random_state(dim, generator)
    O = np.outer(psi, psi.conj())

    # White-noise term: (1 - fidelity) times the maximally mixed state.
    noise = (1.0 - fidelity) * np.eye(dim, dtype=np.complex128) / dim
    rho = fidelity * O + noise
    return psi, O, rho


def get_single_qubit_measurement_unitaries():
    """
    Return the single-qubit rotations keyed by Pauli axis label.

        'Z' -> identity
        'X' -> Hadamard
        'Y' -> Hadamard @ diag(1, -i)

    returns: dict mapping 'Z', 'X', 'Y' to 2x2 complex128 arrays
    """
    inv_sqrt2 = 1 / np.sqrt(2)
    identity = np.eye(2, dtype=np.complex128)
    hadamard = inv_sqrt2 * np.array([[1, 1], [1, -1]], dtype=np.complex128)
    s_dagger = np.diag(np.array([1, -1j], dtype=np.complex128))

    unitaries = {}
    unitaries['Z'] = identity
    unitaries['X'] = hadamard
    unitaries['Y'] = hadamard @ s_dagger
    return unitaries


# Single-qubit rotations keyed by axis label ('X', 'Y', 'Z'); built once at
# import time and looked up by build_n_qubit_U for every tensor factor.
_SINGLE_QUBIT_U = get_single_qubit_measurement_unitaries()


def build_n_qubit_U(pauli_axes):
    """
    Assemble the n-qubit unitary of a tensor-product Pauli measurement setting.

    The single-qubit rotations are combined with a Kronecker product in the
    order given, so the first axis is the leftmost tensor factor. An empty
    iterable yields the 1x1 identity.

    pauli_axes: iterable of length n with entries in {'X', 'Y', 'Z'}
    returns: 2^n x 2^n complex128 unitary
    """
    product = np.ones((1, 1), dtype=np.complex128)
    for axis in pauli_axes:
        factor = _SINGLE_QUBIT_U[axis]
        product = np.kron(product, factor)
    return product


def build_povm(n):
    """
    Construct the POVM of uniformly random n-qubit Pauli-basis measurements.

    For every setting U in {X, Y, Z}^n (3^n settings, each with
    p(U) = 1 / 3^n) and every computational basis state b, the element is

        Π_{U,b} = p(U) * U† |b><b| U

    n: number of qubits
    returns: (settings, U_list, pU, Pi_list) where
        settings: list of the 3^n axis-label tuples
        U_list:   array of shape (3^n, 2^n, 2^n) holding the unitaries
        pU:       the uniform probability 1 / 3^n
        Pi_list:  flat list of the POVM elements in setting-major order,
                  i.e. Π_{U,b} sits at index U * 2^n + b
    """
    dim = 2 ** n
    settings = list(itertools.product(['X', 'Y', 'Z'], repeat=n))  # 3^n settings
    pU = 1.0 / len(settings)  # p(U) = 1 / 3^n

    # Computational basis projectors |b><b|, taken from the identity's rows.
    basis = np.eye(dim, dtype=np.complex128)
    projectors = [np.outer(basis[b], basis[b].conj()) for b in range(dim)]

    unitaries = [build_n_qubit_U(axes) for axes in settings]
    Pi_list = [
        pU * (U.conj().T @ projector @ U)
        for U in unitaries
        for projector in projectors
    ]

    return settings, np.array(unitaries), pU, Pi_list


def solve_oasis_gt_lp(n, O, Pi_list, pU, num_U):
    """
    Solve:

      minimize   ∑_U p(U) t_U
      subject to -t_U ≤ ω_{U,b} ≤ t_U  for all U,b
                 ∑_{U,b} ω_{U,b} Π_{U,b} = O

    The decision vector is x = (ω_{0,0}, ..., ω_{num_U-1, 2^n-1}, t_0, ...,
    t_{num_U-1}), with ω in setting-major order (index U * 2^n + b), matching
    the order of Pi_list.

    n: number of qubits (dim = 2^n)
    O: dim x dim observable to decompose
    Pi_list: sequence of num_U * dim POVM elements (dim x dim each),
             setting-major
    pU: uniform setting probability p(U) used as the objective weight
    num_U: number of measurement settings
    returns: (omega, t) with omega of shape (num_U, dim) and t of shape (num_U,)
    raises: ValueError if O or Pi_list do not have the shapes implied by n
            and num_U; RuntimeError if the LP solver does not succeed
    """
    dim = 2 ** n
    num_b = dim

    Nw = num_U * num_b   # number of ω variables
    Nt = num_U           # number of t variables
    Nvar = Nw + Nt

    O_arr = np.asarray(O, dtype=np.complex128)
    Pi_arr = np.asarray(Pi_list, dtype=np.complex128)
    if O_arr.shape != (dim, dim):
        raise ValueError(
            f"O must have shape {(dim, dim)}, got {O_arr.shape}."
        )
    # A mismatched Pi_list would otherwise leave zero columns in A_eq (or
    # index past its end) and yield meaningless coefficients.
    if Pi_arr.shape != (Nw, dim, dim):
        raise ValueError(
            f"Pi_list must contain {Nw} matrices of shape {(dim, dim)}, "
            f"got array shape {Pi_arr.shape}."
        )

    # Equality constraints A_eq x = b_eq. Matrix entry (a, c) contributes two
    # consecutive rows: 2*(a*dim + c) for the real part and the next one for
    # the imaginary part. The t columns stay zero.
    num_eq = 2 * dim * dim
    A_eq = np.zeros((num_eq, Nvar), dtype=float)
    b_eq = np.zeros(num_eq, dtype=float)

    entries = Pi_arr.reshape(Nw, dim * dim).T  # row a*dim + c, column j
    A_eq[0::2, :Nw] = entries.real
    A_eq[1::2, :Nw] = entries.imag

    target = O_arr.reshape(dim * dim)
    b_eq[0::2] = target.real
    b_eq[1::2] = target.imag

    # Inequality constraints A_ub x <= 0 encoding |ω_{U,b}| <= t_U. Variable
    # j = U * num_b + b owns rows 2j (ω - t <= 0) and 2j + 1 (-ω - t <= 0).
    A_ub = np.zeros((2 * Nw, Nvar), dtype=float)
    b_ub = np.zeros(2 * Nw, dtype=float)

    w_cols = np.arange(Nw)
    t_cols = Nw + w_cols // num_b  # column of the t_U paired with each ω
    upper_rows = 2 * w_cols
    lower_rows = upper_rows + 1

    A_ub[upper_rows, w_cols] = 1.0
    A_ub[upper_rows, t_cols] = -1.0
    A_ub[lower_rows, w_cols] = -1.0
    A_ub[lower_rows, t_cols] = -1.0

    # Objective: minimize ∑_U p(U) t_U (same pU for all U, ω has zero cost).
    c = np.zeros(Nvar, dtype=float)
    c[Nw:] = pU

    # ω_{U,b} unrestricted (real), t_U ≥ 0.
    bounds = [(None, None)] * Nw + [(0.0, None)] * Nt

    res = linprog(c,
                  A_ub=A_ub, b_ub=b_ub,
                  A_eq=A_eq, b_eq=b_eq,
                  bounds=bounds,
                  method='highs')

    if not res.success:
        raise RuntimeError("LP failed: " + res.message)

    x = res.x
    omega = x[:Nw].reshape((num_U, num_b))
    t = x[Nw:]
    return omega, t


def compute_q_distribution(omega, pU_scalar):
    """
    Turn the coefficients ω_{U,b} into a sampling distribution over settings.

        max_abs_U = max_b |ω_{U,b}|
        q(U) = p(U) max_abs_U / Σ_{U'} p(U') max_abs_{U'}

    omega: array of shape (num_U, num_b)
    pU_scalar: the (uniform) default probability p(U)
    returns: (q, max_abs), both of shape (num_U,)
    raises: ValueError if the normalizing sum is not positive
    """
    max_abs = np.abs(omega).max(axis=1)  # shape (num_U,)
    unnormalized = pU_scalar * max_abs
    total = unnormalized.sum()
    if total <= 0:
        raise ValueError("All max |ω_{U,b}| are zero; q(U) undefined.")
    return unnormalized / total, max_abs


def oasis_gt_estimator(rho, omega, U_list, pU_scalar, num_shots, rng=None):
    """
    Monte Carlo estimator of tr(ρ O) using importance sampling.

    Each shot samples a setting U from q(U) (see compute_q_distribution),
    samples an outcome b from the diagonal of U ρ U†, and scores
    ω_{U,b} * p(U) / q(U). The estimate is the mean score over all shots.

    rho: dim x dim density matrix
    omega: coefficients of shape (num_U, dim)
    U_list: array of shape (num_U, dim, dim) with the setting unitaries
    pU_scalar: the (uniform) default probability p(U)
    num_shots: number of shots, must be >= 1
    rng: optional numpy Generator; a fresh default_rng() is used when None
    returns: the estimate of tr(ρ O)
    raises: ValueError on inconsistent shapes, non-positive num_shots, an
            empty sampling support, or a non-positive total outcome
            probability for a sampled setting
    """
    if rng is None:
        rng = np.random.default_rng()

    if num_shots < 1:
        raise ValueError("num_shots must be a positive integer.")

    num_U, dim, _ = U_list.shape
    # Explicit check rather than `assert`, which is stripped under -O.
    if omega.shape != (num_U, dim):
        raise ValueError(
            f"omega must have shape {(num_U, dim)}, got {omega.shape}."
        )

    q, _ = compute_q_distribution(omega, pU_scalar)

    # We only sample from U with q(U) > 0
    support = np.nonzero(q > 1e-15)[0]
    if len(support) == 0:
        raise ValueError("No U with positive q(U).")

    q_support = q[support] / q[support].sum()

    # Outcome distributions depend only on the setting, so each is computed
    # the first time its setting is drawn and reused afterwards. Filling the
    # cache lazily keeps the RNG draw order (setting, outcome, setting, ...)
    # and the point at which errors surface the same as computing per shot.
    outcome_probs = {}

    est_sum = 0.0
    for _ in range(num_shots):
        # Sample U index from support
        U_index = support[rng.choice(len(support), p=q_support)]

        probs = outcome_probs.get(U_index)
        if probs is None:
            U = U_list[U_index]

            # Rotate state
            rho_prime = U @ rho @ U.conj().T

            # Computational-basis measurement probabilities (diag of ρ')
            probs = np.real_if_close(np.diag(rho_prime))
            probs = np.clip(probs.real, 0.0, None)
            total_p = probs.sum()
            if total_p <= 0:
                raise ValueError(
                    "Non-positive total probability in measurement."
                )
            probs /= total_p
            outcome_probs[U_index] = probs

        # Sample outcome b
        b = rng.choice(dim, p=probs)

        # Importance-sampling score
        est_sum += omega[U_index, b] * pU_scalar / q[U_index]

    return est_sum / num_shots


def main():
    """
    Run one seeded end-to-end demonstration and print the results.

    Steps: draw a random target and noisy state, build the Pauli POVM, solve
    the LP for ω, verify that Σ ω Π reproduces O, then compare the sampled
    estimate of tr(ρ O) with its exact value.
    """
    # CONFIG
    n, fidelity, num_shots, seed = 3, 0.9, 4425, 0

    rng = np.random.default_rng(seed)
    dim = 2 ** n

    _psi, O, rho = target_and_rho(n, fidelity=fidelity, rng=rng)
    settings, U_list, pU, Pi_list = build_povm(n)
    num_U = len(settings)

    print(f"n = {n}, dim = {dim}, number of Pauli settings = {num_U}")
    print(f"Number of POVM elements Π_(U,b) = {len(Pi_list)}")

    print("Solving LP for ω_{U,b} ...")
    omega, t = solve_oasis_gt_lp(n, O, Pi_list, pU, num_U=num_U)
    print("LP solved.")

    # Optional: check reconstruction accuracy of O ≈ sum ω Π.
    # omega is (num_U, dim) in the same setting-major order as Pi_list, so
    # its flattened entries pair up one-to-one with the POVM elements.
    O_rec = np.zeros(O.shape, dtype=np.complex128)
    for weight, Pi in zip(omega.ravel(), Pi_list):
        O_rec += weight * Pi
    rec_err = np.linalg.norm(O_rec - O)
    print(f"Reconstruction ‖Σ ω Π - O‖_F = {rec_err:.3e}")

    true_fid = float(np.trace(rho @ O).real)
    print(f"True fidelity  tr(ρ O)       = {true_fid}")

    print(f"Running estimator with N = {num_shots} shots ...")
    est = oasis_gt_estimator(
        rho, omega, U_list, pU, num_shots=num_shots, rng=rng
    )
    print(f"Estimated fidelity           = {est}")
    print(f"Squared error                = {(true_fid-est)**2:.3e}")


# Run the demonstration only when executed as a script, not on import.
if __name__ == "__main__":
    main()


n = 3, dim = 8, number of Pauli settings = 27
Number of POVM elements Π_(U,b) = 216
Solving LP for ω_{U,b} ...
LP solved.
Reconstruction ‖Σ ω Π - O‖_F = 2.924e-15
True fidelity  tr(ρ O)       = 0.9125000000000002
Running estimator with N = 4425 shots ...
Estimated fidelity           = 0.9039289986037831
Squared error                = 7.346e-05
