In [7]:
# ruff: noqa
import sys, os
import json

import jax

# Enable 64-bit precision before anything else touches JAX. The flag is meant
# to be set at startup: arrays/compiled functions created before this call
# (e.g. while importing the project modules below, or inside
# `test_implementations()`) would otherwise silently stay in 32-bit precision.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
import numpy as np
from tqdm import tqdm

# Make the local `feedback_grape` package and the shared `helpers` module
# importable (paths are relative to this notebook's working directory).
sys.path.append(os.path.abspath("./../../feedback-grape"))
sys.path.append(os.path.abspath("./../../"))
sys.path.append(os.path.abspath("./../"))

from feedback_grape.fgrape import evaluate_on_longer_time
from helpers import (
    init_grape_protocol,
    init_fgrape_protocol,
    test_implementations,
    generate_random_discrete_state,
    generate_random_bloch_state,
)

# Run the helpers' implementation check; it now executes with x64 enabled,
# i.e. at the same precision as the evaluation cell below.
test_implementations()

# Physical parameters
# (attention! #elements in density matrix grow as 4^N_chains)
N_chains = 3  # Number of parallel chains to simulate
gamma = 0.05  # Decay constant
generate_state = generate_random_bloch_state  # Sampler used to draw the evaluation states
evaluation_time_steps = 200  # Number of time steps for evaluation
batch_size = 256  # Number of random states to evaluate in parallel

# Optimized lookup tables (read from ./optimized/) to evaluate.
# Uncomment entries to include them in the run.
filenames = [
    #"lut_t=2_l=2_w=11_noise=0.0_s=0_trained_on_bloch_sphere.json",
    #"lut_t=3_l=2_w=111_noise=0.0_s=1_trained_on_bloch_sphere.json",
    #"lut_t=3_l=2_w=111_noise=0.0_s=2_trained_on_discrete_states.json",
    "lut_t=3_l=2_w=111_noise=0.0_s=1_trained_on_bloch_sphere_high_decay.json",
    "lut_t=3_l=2_w=111_noise=0.0_s=2_trained_on_discrete_states_high_decay.json",
]

In [8]:


# Evaluate every optimized lookup table (LUT) over `evaluation_time_steps` steps
# and save its per-timestep fidelities, one .npz file per LUT.
results_dir = "./results_bloch_sphere"
os.makedirs(results_dir, exist_ok=True)  # np.savez does not create missing directories


def pure_state(key):
    """Draw states from the configured `generate_state` sampler with noise_level=0.0."""
    return generate_state(key, N_chains=N_chains, noise_level=0.0)


for filename in tqdm(filenames, desc="evaluating LUTs"):
    with open(os.path.join("./optimized", filename), "r", encoding="utf-8") as f:
        lut = json.load(f)["optimized_trainable_parameters"]

    # Rebuilt for every LUT with the same fixed key so each LUT starts from an
    # identical, freshly initialised system.
    # NOTE(review): hoist out of the loop if evaluate_on_longer_time is
    # confirmed not to mutate system_params.
    system_params = init_fgrape_protocol(jax.random.PRNGKey(0), N_chains, gamma)

    # Evaluate the LUT on a longer horizon with density-matrix evolution.
    # The same sampler is passed as initial state and target — presumably the
    # task is preserving the sampled state; confirm against fgrape's contract.
    eval_result = evaluate_on_longer_time(
        U_0 = pure_state,
        C_target = pure_state,
        system_params = system_params,
        optimized_trainable_parameters = lut,
        num_time_steps = evaluation_time_steps,
        evo_type = "density",
        goal = "fidelity",
        eval_batch_size = batch_size,
        mode = "lookup",
    )
    fidelities = eval_result.fidelity_each_timestep

    # Strip the extension without assuming it is exactly ".json" (5 chars).
    stem = os.path.splitext(filename)[0]
    np.savez(os.path.join(results_dir, f"{stem}.npz"), fidelities=fidelities)