In [20]:
# ruff: noqa
import sys, os
sys.path.append(os.path.abspath("./../feedback-grape"))
sys.path.append(os.path.abspath("./../"))

# ruff: noqa
from feedback_grape.fgrape import optimize_pulse, Decay, Gate, evaluate_on_longer_time
from feedback_grape.utils.states import basis
from feedback_grape.utils.tensor import tensor
from feedback_grape.utils.operators import sigmap, sigmam
import jax.numpy as jnp
import jax
from feedback_grape.utils.fidelity import ket2dm
from library.utils.qubit_chain_1D import embed
from jax.scipy.linalg import expm
from tqdm import tqdm
import json

jax.config.update("jax_enable_x64", True)

In [21]:
from numpy import identity


def generate_traceless_hermitian(params, dim):
    assert len(params) == dim**2 - 1, "Number of real parameters must be dim^2 - 1 for an NxN traceless Hermitian matrix."
    
    # Read the first (dim**2 - dim) / 2 as the real parts of the upper triangle
    real_parts = jnp.array(params[: (dim**2 - dim) // 2])

    # Read the next (dim**2 - dim) / 2 as the imaginary parts of the upper triangle
    imag_parts = jnp.array(params[(dim**2 - dim) // 2 : - (dim - 1)])

    # Read the last (dim - 1) as the diagonal elements, set the last diagonal element to ensure tracelessness
    trace = sum(params[- (dim - 1):])
    diag_parts = jnp.append(params[- (dim - 1):], jnp.array([-trace]))

    # Construct the Hermitian matrix
    triag_parts = real_parts + 1j * imag_parts

    return jnp.array([
        [
            diag_parts[i] if i == j else
            triag_parts[(i * (i - 1)) // 2 + j - i - 1] if i < j else
            jnp.conj(triag_parts[(j * (j - 1)) // 2 + i - j - 1])
            for j in range(dim)
        ] for i in range(dim)
    ])
generate_traceless_hermitian = jax.jit(generate_traceless_hermitian, static_argnames=['dim'])

def generate_hermitian(params, dim):
    assert len(params) == dim**2, "Number of real parameters must be dim^2 for an NxN Hermitian matrix."
    
    # Generate traceless hermitanian from first dim^2 - 1 parameters and read last parameter as trace
    return generate_traceless_hermitian(params[:-1], dim) + jnp.eye(dim) * params[-1] / dim
generate_hermitian = jax.jit(generate_hermitian, static_argnames=['dim'])

def generate_unitary(params, dim):
    assert len(params) == dim**2, "Number of real parameters must be dim^2 for an NxN Hermitian matrix."

    H = generate_hermitian(params, dim)
    return jax.scipy.linalg.expm(-1j * H)
generate_unitary = jax.jit(generate_unitary, static_argnames=['dim'])

def generate_special_unitary(params, dim):
    assert len(params) == dim**2 - 1, "Number of real parameters must be dim^2 - 1 for an NxN traceless Hermitian matrix."
    
    H = generate_traceless_hermitian(params, dim)
    return jax.scipy.linalg.expm(-1j * H)
generate_special_unitary = jax.jit(generate_special_unitary, static_argnames=['dim'])

def partial_trace(rho, sys_A_dim, sys_B_dim):
    """ Compute the partial trace over system A of a density matrix rho = rho_AB.
        sys_A_dim: Dimension of system A
        sys_B_dim: Dimension of system B
    """
    dim_A = sys_A_dim
    dim_B = sys_B_dim

    #assert rho.shape == (dim_A * dim_B, dim_A * dim_B), "Input density matrix has incorrect dimensions."

    rho_B = jnp.zeros((dim_B, dim_B), dtype=rho.dtype)

    def loop_body(i, rho_B):
        return rho_B + jax.lax.dynamic_slice(rho, (i*dim_B, i*dim_B), (dim_B, dim_B))
    
    rho_B = jax.lax.fori_loop(0, dim_A, loop_body, rho_B) # Compiler friendly loop

    return rho_B
partial_trace = jax.jit(partial_trace, static_argnames=['sys_A_dim', 'sys_B_dim'])

def generate_povm1(measurement_outcome, params):
    """ 
        Generate a 2-outcome POVM elements M_0 and M_1 for a qubit system.
        This function should parametrize all such POVMs up to unitary equivalence, i.e., M_i -> U M_i for some unitary U.
        I.e it parametrizes all pairs (M_0, M_1) such that M_0 M_0† + M_1 M_1† = I.

        measurement_outcome: 0 or 1, indicating which POVM element to generate.
        params: list of 4 real parameters [phi, theta, alpha, beta].

        when measurement_outcome == 1:
            M_1 = S D S†
        when measurement_outcome == -1:
            M_0 = S (I - D) S†

        phi, theta parametrize the unitary S, and alpha, beta parametrize the eigenvalues of M_1.
    """
    phi, theta, alpha, beta = params
    S = jnp.array(
        [[jnp.cos(phi),                   -jnp.sin(phi)*jnp.exp(-1j*theta)],
         [jnp.sin(phi)*jnp.exp(1j*theta),  jnp.cos(phi)                  ]]
    )
    s1 = jnp.sin(alpha)**2
    s2 = jnp.sin(beta)**2
    D_0 = jnp.array(
        [[s1, 0],
         [0,  s2]]
    )
    D_1 = jnp.array(
        [[(1 - s1*s1)**0.5, 0],
         [0, (1 - s2*s2)**0.5]]
    )

    return jnp.where(measurement_outcome == 1,
        tensor(identity(2), S @ D_0 @ S.conj().T),
        tensor(identity(2), S @ D_1 @ S.conj().T)
    )

def generate_povm2(measurement_outcome, params, dim):
    """ 
        Generate a 2-outcome POVM elements M_0 and M_1 for a system with Hilbert space dimension dim.
        This function should parametrize all such POVMs up to unitary equivalence, i.e., M_i -> U M_i for some unitary U.
        I.e it parametrizes all pairs (M_0, M_1) such that M_0 M_0† + M_1 M_1† = I.

        measurement_outcome: 0 or 1, indicating which POVM element to generate.
        params: list of 15 real parameters.

        when measurement_outcome == 1:
            M_1 = S D S†
        when measurement_outcome == -1:
            M_0 = S (I - D) S†

        where S is a unitary parametrized by dim^2 parameters, and D is a diagonal matrix with eigenvalues parametrized by dim parameters.
    """
    S = generate_unitary(params, dim=dim)

    d_vec = jnp.sin( params[dim*(dim-1):dim*dim] ) ** 2

    return jnp.where(measurement_outcome == 1,
        S @ jnp.diag(d_vec) @ S.conj().T,
        S @ jnp.diag(jnp.sqrt(1 - d_vec**2)) @ S.conj().T
    )
generate_povm2 = jax.jit(generate_povm2, static_argnames=['dim'])

def initialize_chain_of_zeros(rho, n):
    """ Initialize double chain of qubits where the first pair is in state rho and the rest in |0><0|. """
    dim = 4**(n - 1)
    rho_zero = jnp.zeros((dim, dim), dtype=rho.dtype)
    rho_zero = rho_zero.at[-1, -1].set(1.0)

    return tensor(rho, rho_zero)
initialize_chain_of_zeros = jax.jit(initialize_chain_of_zeros, static_argnames=['n'])

def gram_schmidt(matrix):
    """ Perform Gram-Schmidt orthonormalization on the rows of the complex input matrix. """
    def body_fun(i, val):
        v_i, Q = val
        for j in range(i):
            q_j = Q[j]
            v_i = v_i - jnp.dot(v_i, q_j.conj()) * q_j
        q_i = v_i / jnp.linalg.norm(v_i)
        Q = Q.at[i].set(q_i)
        return v_i, Q

    num_rows = matrix.shape[0]
    Q = jnp.zeros_like(matrix, dtype=matrix.dtype)

    def gs_loop(i, Q):
        v_i = matrix[i]
        _, Q = body_fun(i, (v_i, Q))
        return Q

    Q = jax.lax.fori_loop(0, num_rows, gs_loop, Q)
    return Q

def transport_unitary(n):
    tau = jnp.pi/2 # fixed time for each transport unitary (cf. paper)

    def theory_J(j, n):
        return (j*(n-j))**0.5

    J_theory = jnp.array([theory_J(j+1, n) for j in range(n-1)])

    H_I = sum([
        g_prime_j*embed(sigmap(), 2*j, 2*n)@embed(sigmam(), 2*(j+1), 2*n) # first qubit chain
        + g_prime_j*embed(sigmap(), 2*j+1, 2*n)@embed(sigmam(), 2*(j+1)+1, 2*n) # second qubit chain
        for j,g_prime_j in enumerate(J_theory)
    ])
    H_I = H_I + H_I.conj().T

    # Need another phase of i^(n-1) for basis vectors with one spin up. (c.f. eq 15 in [1])
    # Hence we add a term with eigenvalue -1 for all those vectors to Homiltonian.
    # When evolving for time tau = pi/2, this adds a exp(-1j*pi/2*(-1)) = i.
    # We multiply by ((n - 1) % 4), to get the desired i^(n-1) term.
    H_phase = - ((n - 1) % 4) * sum(
        [
            embed(jnp.array([[1,0],[0,0]]), 2*j, 2*n) # first qubit chain
            + embed(jnp.array([[1,0],[0,0]]), 2*j + 1, 2*n) # second qubit chain
            for j in range(n)
        ]
    )

    H_I = H_I + H_phase
    
    return expm(-1j*tau*H_I)

def random_state(key):
    """ Generate a pair of up or down states with equal probability. """
    random_value = jax.random.uniform(key, minval=0.0, maxval=1.0)

    psi_one  = basis(2, 0)
    psi_zero = basis(2, 1)
    psi = jnp.where(random_value < 0.5, psi_one, psi_zero)

    return ket2dm(tensor(psi, psi))
random_state = jax.jit(random_state)

In [22]:
# Test unitary and special unitary generators in two dimensions
for i in range(10):
    key = jax.random.PRNGKey(i)
    key, subkey1, subkey2 = jax.random.split(key, 3)

    params = jax.random.uniform(subkey1, (4,), minval=0.0, maxval=2*jnp.pi)

    U = generate_unitary(params, 2)
    SU = generate_special_unitary(params[:-1], 2)

    assert jnp.allclose(U @ U.conj().T, jnp.eye(2)), "Unitary condition failed"
    assert jnp.allclose(SU @ SU.conj().T, jnp.eye(2)), "Special Unitary condition failed"
    assert jnp.isclose(jnp.linalg.det(SU), 1.0), "Determinant condition for Special Unitary failed"

# Test that partial trace works correctly
for i in range(10):
    key = jax.random.PRNGKey(i)
    key, subkey1, subkey2 = jax.random.split(key, 3)

    rho_A = jax.random.normal(subkey1, (2,2)) + 1j * jax.random.normal(subkey1, (2,2))
    rho_A = rho_A @ rho_A.conj().T
    rho_A = rho_A / jnp.trace(rho_A)

    rho_B = jax.random.normal(subkey2, (2,2)) + 1j * jax.random.normal(subkey2, (2,2))
    rho_B = rho_B @ rho_B.conj().T
    rho_B = rho_B / jnp.trace(rho_B)

    rho_AB = tensor(rho_A, rho_B)

    traced_rho_B = partial_trace(rho_AB, 2, 2)

    assert jnp.allclose(traced_rho_B, rho_B), "Partial trace did not return correct result"

# Test povm generator
for i in range(10):
    key = jax.random.PRNGKey(i)
    key, subkey = jax.random.split(key, 2)

    for f,N_params in [(generate_povm1, 4), (lambda msmt, params: generate_povm2(msmt, params, 4), 16)]:
        params = jax.random.uniform(subkey, (N_params,), minval=0.0, maxval=2*jnp.pi)

        M_0 = f(-1, params)
        M_1 = f(1, params)

        assert jnp.allclose(M_0 @ M_0.conj().T + M_1 @ M_1.conj().T, jnp.eye(4)), "POVM elements do not sum to identity"
        assert jnp.allclose(M_0, M_0.conj().T), "POVM element M_0 is not Hermitian"
        assert jnp.allclose(M_1, M_1.conj().T), "POVM element M_1 is not Hermitian"
        assert jnp.all(jnp.linalg.eigvals(M_0) >= 0), "POVM element M_0 is not positive semidefinite"
        assert jnp.all(jnp.linalg.eigvals(M_1) >= 0), "POVM element M_1 is not positive semidefinite"

In [23]:
# Physical parameters
n = 3 # Number of qubit pairs in the chain (attention! elements in density matrix grow as 16^n)
N_samples = 10 # Number of random initial states to sample

# Training parameters
num_time_steps = 2 # Number of time steps
N_training_iterations = 1000 # Number of training iterations
learning_rate = 0.02 # Learning rate
lut_depth = 1 # Length of lookup table for measurement feedback

def initialize_system_params(key):
    subkey1, subkey2 = jax.random.split(key, 2)
    
    initial_gate = Gate(
        gate=lambda rho, _: initialize_chain_of_zeros(rho, n=n),
        initial_params = jnp.array([]),
        measurement_flag = False,
        quantum_channel_flag = True
    )

    T = transport_unitary(n)
    
    T_gate = Gate(
        gate=lambda _: T,
        initial_params = jnp.array([]),
        measurement_flag = False
    )

    ptrace_gate = Gate(
        gate=lambda rho, _: partial_trace(rho, sys_A_dim=4**(n-1), sys_B_dim=4),
        initial_params = jnp.array([]),
        measurement_flag = False,
        quantum_channel_flag = True
    )

    povm_gate = Gate(
        gate=lambda msmt, params: generate_povm2(msmt, params, dim=4),
        initial_params = jax.random.uniform(subkey1, (16,), minval=0.0, maxval=2*jnp.pi),
        measurement_flag = True
    )

    U_gate = Gate(
        gate=lambda params: generate_unitary(params, dim=4),
        initial_params = jax.random.uniform(subkey2, (16,), minval=0.0, maxval=1.0),
        measurement_flag = False
    )

    return [initial_gate, T_gate, ptrace_gate, povm_gate, U_gate]

In [24]:
best_result = None
fidelity_list = []

for s in tqdm(range(N_samples)):
    system_params = initialize_system_params(jax.random.PRNGKey(s))

    result = optimize_pulse(
        U_0=random_state,
        C_target=random_state,
        system_params=system_params,
        num_time_steps=num_time_steps,
        lut_depth=lut_depth,
        mode="lookup",
        goal="fidelity",
        max_iter=N_training_iterations,
        convergence_threshold=1e-6,
        learning_rate=learning_rate,
        evo_type="density",
        batch_size=8,
        eval_batch_size=10
    )

    fidelity_list.append(float(result.final_fidelity))
    if best_result is None or result.final_fidelity > best_result.final_fidelity:
        best_result = result

print(f"Average fidelity over {N_samples} samples: {jnp.mean(jnp.array(fidelity_list)):.2f}, Best fidelity: {best_result.final_fidelity:.2f}")

  x_bar = _convert_element_type(x_bar, x.aval.dtype, x.aval.weak_type)
100%|██████████| 10/10 [01:17<00:00,  7.72s/it]

Average fidelity over 10 samples: 0.80, Best fidelity: 1.00





In [25]:
fidelity_list

[0.0,
 0.9999997861590976,
 0.9997231267718165,
 0.9999999107900497,
 0.9996901979740918,
 0.9998569992365101,
 0.9992372907918574,
 0.9999999651646541,
 0.9999989291087288,
 0.0]

In [26]:
best_result.iterations

342