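This notebook tests how well the lookup tables optimized in `example_F_lookup_training_test.ipynb` stabilize states over longer time durations. As a reference point, the cells below first compute a baseline: the batch-averaged fidelity of randomly sampled states evolving under decay alone, with no feedback applied.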

In [5]:
# ruff: noqa
import sys, os
sys.path.append(os.path.abspath("./../feedback-grape"))
sys.path.append(os.path.abspath("./../"))

# ruff: noqa
from helpers import (
    generate_random_discrete_state,
    generate_random_bloch_state,
    test_implementations,
    generate_decay_superoperator,
)
import numpy as np
from feedback_grape.utils.fidelity import fidelity
from feedback_grape.utils.operators import sigmam
from library.utils.qubit_chain_1D import embed
from feedback_grape.utils.solver import mesolve
from tqdm import tqdm
import jax
import jax.numpy as jnp

# Run the helpers' self-test routine before any evaluation.
# NOTE(review): presumably raises if a helper implementation is broken — TODO confirm in helpers.py.
test_implementations()

In [6]:
# ---- Physical model -----------------------------------------------------
N_chains = 3  # how many chains are simulated side by side
gamma = 0.05  # decay constant handed to generate_decay_superoperator

# ---- Evaluation setup ---------------------------------------------------
evaluation_time_steps = 100  # propagation steps per trajectory
batch_size = 256  # random states propagated together in one batch
generate_random_state = generate_random_bloch_state  # sampler for initial/target states

# ---- Output -------------------------------------------------------------
results_path = "./evaluation_results/baseline_fidelities.npz"  # where the averaged fidelity curve is written

In [7]:
decay_superoperator = generate_decay_superoperator(N_chains, gamma)

# Sample a batch of random initial states; the fixed seed makes the batch reproducible.
keys = jax.random.split(jax.random.PRNGKey(42), batch_size)
states = jnp.array([generate_random_state(key, N_chains, noise_level=0.0) for key in keys])

# Each trajectory is scored against its own initial state, so the targets are
# the same batch rather than a second, identical sampling pass. JAX arrays are
# immutable and `states` is only ever rebound later, so `target_states` keeps
# the initial states for the whole evaluation.
target_states = states

# fidelities_each[i, t]: fidelity of sample i after t propagation steps
# (column 0 is the initial fidelity, before any decay step).
fidelities_each = np.zeros((len(states), evaluation_time_steps + 1))
for i, (state, target_state) in enumerate(zip(states, target_states)):
    fidelities_each[i, 0] = fidelity(
        C_target=target_state,
        U_final=state,
        evo_type="density",
    )

def propagate_single_timestep(rho, rho_target):
    """Apply one decay step to a state and score it against a target.

    Args:
        rho: State to propagate; it is passed to ``fidelity`` with
            ``evo_type="density"``, i.e. treated as a density matrix.
        rho_target: Reference state the propagated state is compared to.

    Returns:
        Tuple ``(rho_next, fid)``: the state after one application of
        ``decay_superoperator`` and its fidelity with respect to ``rho_target``.
    """
    rho_next = decay_superoperator(rho)
    fid = fidelity(C_target=rho_target, U_final=rho_next, evo_type="density")
    return rho_next, fid


# Compile the *batched* function (jit outside vmap) so the whole vmapped step is
# traced once, cached, and dispatched as a single XLA computation per call,
# instead of re-running the vmap transformation in Python every iteration.
propagate_single_timestep_vmap = jax.jit(jax.vmap(propagate_single_timestep))

# Set to True to assert after every step that each propagated state is still a
# valid density matrix (Hermitian, unit trace, positive semidefinite). Off by
# default: it copies the batch to host memory and diagonalizes every state.
check_physicality = False

for i in tqdm(range(evaluation_time_steps)):
    states, fid = propagate_single_timestep_vmap(states, target_states)
    fidelities_each[:, i + 1] = fid

    if check_physicality:
        rho_batch = np.asarray(states)
        rho_batch_dag = np.conj(np.swapaxes(rho_batch, -1, -2))
        assert np.allclose(rho_batch, rho_batch_dag), "State is not Hermitian"
        assert np.allclose(np.trace(rho_batch, axis1=-2, axis2=-1).real, 1.0), "State is not normalized"
        assert np.all(np.linalg.eigvalsh(rho_batch) >= -1e-10), "State is not positive semidefinite"

# Average over all sampled states: one mean fidelity per time step (index 0 = initial).
fidelities_each = fidelities_each.mean(axis=0)

# np.savez does not create missing directories, so ensure the output folder exists.
os.makedirs(os.path.dirname(results_path) or ".", exist_ok=True)
np.savez(results_path, fidelities=fidelities_each)

100%|██████████| 100/100 [00:00<00:00, 268.58it/s]


In [8]:
# Batch-averaged fidelity per time step: entry 0 is the initial fidelity,
# entry t is the fidelity after t applications of the decay step.
fidelities_each

array([1.        , 0.92570502, 0.86063175, 0.80365187, 0.75377759,
       0.71014377, 0.67199223, 0.63865821, 0.60955859, 0.58418154,
       0.56207762, 0.54285197, 0.52615751, 0.51168902, 0.49917798,
       0.48838801, 0.479111  , 0.4711636 , 0.46438424, 0.4586305 ,
       0.45377678, 0.44971232, 0.44633935, 0.44357165, 0.44133309,
       0.43955647, 0.43818249, 0.43715879, 0.43643916, 0.43598282,
       0.43575377, 0.43572028, 0.43585436, 0.43613133, 0.43652947,
       0.43702964, 0.43761504, 0.43827088, 0.43898418, 0.4397436 ,
       0.44053916, 0.4413622 , 0.44220513, 0.44306137, 0.44392521,
       0.4447917 , 0.4456566 , 0.44651626, 0.44736757, 0.44820789,
       0.44903501, 0.44984707, 0.45064254, 0.45142019, 0.45217903,
       0.45291829, 0.45363739, 0.45433593, 0.45501364, 0.45567039,
       0.45630615, 0.456921  , 0.45751509, 0.45808864, 0.45864194,
       0.45917531, 0.45968913, 0.46018382, 0.46065979, 0.46111752,
       0.46155747, 0.46198013, 0.46238601, 0.46277559, 0.46314