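This notebook tests how well the lookup tables optimized in "example_F_lookup_training_test.ipynb" stabilize states over longer time durations. The cells below first compute a baseline without feedback: a batch of random initial states is evolved under repeated transport and decay steps, and the mean fidelity with the initial state is recorded at every step.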

In [1]:
# ruff: noqa
import sys, os
# Make the sibling `feedback-grape` checkout and the parent directory importable.
# os.path.abspath resolves these relative to the kernel's current working
# directory, so the notebook must be started from a folder where "./../" points
# at the project root (presumably this notebook's own folder — TODO confirm).
sys.path.append(os.path.abspath("./../feedback-grape"))
sys.path.append(os.path.abspath("./../"))

# ruff: noqa
from helpers import (
    transport_unitary,
    initialize_chain_of_zeros,
    partial_trace,
    generate_random_discrete_state,
    generate_random_bloch_state,
    test_implementations,
)
import numpy as np
from feedback_grape.utils.fidelity import fidelity
from feedback_grape.utils.operators import sigmam
from library.utils.qubit_chain_1D import embed
from feedback_grape.utils.solver import mesolve
from tqdm import tqdm
import jax
import jax.numpy as jnp

# Run the helpers' built-in checks before the (slow) evaluation below.
# NOTE(review): what exactly is verified lives in helpers.test_implementations — confirm.
test_implementations()

In [2]:
# Physical parameters
# NOTE(review): an earlier comment said n must be >= 3, yet n = 2 is used here
# (and produced the saved results); confirm the intended minimum chain length.
n = 2  # number of qubits per chain
N_chains = 3  # number of parallel chains to simulate
gamma = 0.25  # decay constant (scales the jump operator built from sigmam)
evaluation_time_steps = 100  # number of transport/decay steps to simulate
batch_size = 256  # number of random states evaluated in parallel (vmapped)

# Sampler for the random initial states. generate_random_discrete_state is also
# imported and can presumably be swapped in here (same call signature — TODO confirm).
generate_random_state = generate_random_bloch_state

# Path to save results
results_path = "./evaluation_results/baseline_fidelities.npz"

In [3]:
# Dimension of the evaluated N_chains-qubit states (kept as sys_B_dim by the
# partial trace in propagate_single_timestep).
base_dim = 2**N_chains
T_half = transport_unitary(0.5, n, N_chains)
# NOTE(review): this builds ONE collective jump operator (the sum of sigma_minus
# over all N_chains*n sites), not one jump operator per site; confirm that
# collective rather than independent per-qubit decay is intended.
c_ops = [sum(gamma * embed(sigmam(), idx, N_chains * n) for idx in range(N_chains * n))]

# Sample a batch of random initial states (fixed seed for reproducibility).
keys = jax.random.split(jax.random.PRNGKey(42), batch_size)
states = jnp.array([generate_random_state(key, N_chains, noise_level=0.0) for key in keys])

# The stabilization target of each state is the state itself. Re-sampling with
# the same PRNG keys would reproduce `states` exactly (the t=0 mean fidelity of
# 1.0 confirms this), so reuse it. JAX arrays are immutable, and later cells only
# rebind `states`, so this alias stays fixed.
target_states = states

# Column 0 holds the t=0 fidelity of each state with its target; columns
# 1..evaluation_time_steps are filled by the evolution loop.
initial_fidelities = jax.vmap(
    lambda rho, rho_target: fidelity(C_target=rho_target, U_final=rho, evo_type="density")
)(states, target_states)
fidelities_each = np.zeros((len(states), evaluation_time_steps + 1))
fidelities_each[:, 0] = np.asarray(initial_fidelities)

def propagate_single_timestep(rho, rho_target):
    """Advance one state by a half-transport / decay / half-transport step.

    Args:
        rho: density matrix of the evaluated register (presumably
            base_dim x base_dim — TODO confirm against helpers).
        rho_target: density matrix the evolved state is compared against.

    Returns:
        Tuple ``(rho_next, fid)``: the reduced state returned by
        ``partial_trace`` and its fidelity with ``rho_target``.

    Reads the module-level n, N_chains, T_half, c_ops and base_dim.
    """
    T_half_dag = T_half.conj().T

    # Embed rho into the full chain register (presumably padding the other
    # sites with |0>, per initialize_chain_of_zeros) and half-transport it.
    full_rho = initialize_chain_of_zeros(rho, n=n, N_chains=N_chains)
    full_rho = T_half @ full_rho @ T_half_dag

    # Dissipative evolution under the jump operators in c_ops.
    full_rho = mesolve(jump_ops=c_ops, rho0=full_rho)

    # Second half transport, then reduce to the base_dim-dimensional subsystem
    # (presumably tracing out subsystem A of dimension base_dim**(n-1)).
    full_rho = T_half @ full_rho @ T_half_dag
    rho_next = partial_trace(full_rho, sys_A_dim=base_dim ** (n - 1), sys_B_dim=base_dim)

    return rho_next, fidelity(C_target=rho_target, U_final=rho_next, evo_type="density")

# jit(vmap(f)) is the idiomatic ordering: the whole batched step is traced and
# compiled once, and each loop iteration is a single compiled dispatch.
propagate_single_timestep_vmap = jax.jit(jax.vmap(propagate_single_timestep))

for i in tqdm(range(evaluation_time_steps)):
    states, fid = propagate_single_timestep_vmap(states, target_states)
    fidelities_each[:, i + 1] = fid

# Average over the batch of random states -> mean fidelity per time step
# (index 0 is t = 0). This rebinds fidelities_each from (batch, T+1) to (T+1,).
fidelities_each = fidelities_each.mean(axis=0)

# np.savez does not create missing directories, so ensure the target exists.
os.makedirs(os.path.dirname(results_path) or ".", exist_ok=True)
np.savez(results_path, fidelities=fidelities_each)

100%|██████████| 100/100 [01:35<00:00,  1.04it/s]


In [4]:
# Batch-averaged fidelity at each time step (index 0 = initial state).
fidelities_each

array([1.        , 0.92197055, 0.85573523, 0.79992903, 0.75321514,
       0.71433558, 0.68213931, 0.65559516, 0.63379448, 0.6159473 ,
       0.60137454, 0.58949808, 0.57982993, 0.57196134, 0.56555231,
       0.56032187, 0.55603926, 0.55251606, 0.54959932, 0.54716556,
       0.54511562, 0.54337037, 0.54186696, 0.54055577, 0.53939786,
       0.53836282, 0.53742706, 0.53657238, 0.53578478, 0.53505361,
       0.53437074, 0.53373001, 0.53312676, 0.53255743, 0.53201931,
       0.53151027, 0.53102865, 0.53057306, 0.53014236, 0.52973552,
       0.52935162, 0.52898978, 0.52864915, 0.52832891, 0.52802822,
       0.52774627, 0.52748223, 0.52723528, 0.5270046 , 0.52678937,
       0.52658879, 0.52640207, 0.52622843, 0.52606712, 0.52591742,
       0.5257786 , 0.52564999, 0.52553094, 0.52542081, 0.52531902,
       0.525225  , 0.52513821, 0.52505814, 0.52498432, 0.5249163 ,
       0.52485364, 0.52479596, 0.52474289, 0.52469407, 0.52464918,
       0.52460793, 0.52457002, 0.52453521, 0.52450324, 0.52447