In [5]:
# ruff: noqa
import sys, os

# Make the project packages importable relative to this notebook's location.
sys.path.append(os.path.abspath("./../../feedback-grape"))
sys.path.append(os.path.abspath("./../../"))
sys.path.append(os.path.abspath("./../"))

import jax

# Enable float64 *before* any JAX arrays are created: in particular before the
# project modules below are imported and before test_implementations() runs.
# Otherwise the self-test (and any import-time arrays) would run in float32
# while the experiments themselves run in float64.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
import numpy as np
import json
from tqdm import tqdm
from feedback_grape.fgrape import evaluate_on_longer_time
from helpers import (
    init_fgrape_protocol,
    test_implementations,
    generate_random_discrete_state,
    generate_random_bloch_state,
)

# Sanity-check the helper implementations (now under the same float64 setting
# used by the experiments below).
test_implementations()

# Physical parameters
# (attention! #elements in density matrix grow as 4^n*N_chains)
N_chains = 3 # Number of parallel chains to simulate
n = 3 # Number of qubits per chain
gamma = 0.25 # Decay constant
generate_state = generate_random_bloch_state # Generator for initial/target states
evaluation_time_steps = 50 # Number of time steps for evaluation
batch_size = 16 # Number of random states to evaluate in parallel

N_parallel_threads = 1 # Number of parallel threads used to run the evaluation experiments

# Optimized lookup-table files (read from ./optimized/) to evaluate.
filenames = [
    "lut_t=2_l=2_w=11_noise=0.0_s=0_trained_on_bloch_sphere.json",
    "lut_t=3_l=2_w=111_noise=0.0_s=1_trained_on_bloch_sphere.json",
    "lut_t=3_l=2_w=111_noise=0.0_s=2_trained_on_discrete_states.json",
    "lut_t=3_l=2_w=111_noise=0.0_s=1_trained_on_bloch_sphere_high_decay.json",
    "lut_t=3_l=2_w=111_noise=0.0_s=2_trained_on_discrete_states_high_decay.json",
]
# Output directory for the per-file .npz fidelity results (keep the trailing slash).
results_dir = "./results_bloch_sphere/"

In [6]:
from concurrent.futures import ThreadPoolExecutor

def run_experiment(filename):
    """Evaluate one optimized lookup-table (LUT) policy and save its per-step fidelities.

    Loads ``optimized_trainable_parameters`` from ``./optimized/<filename>``,
    prepends empty parameter lists for the non-parametrized gates, evaluates
    the LUT for ``evaluation_time_steps`` steps with ``evaluate_on_longer_time``
    (density-matrix evolution, fidelity goal, lookup mode) and writes the
    per-time-step fidelities to ``<results_dir>/<filename stem>.npz``.

    Uses the notebook-level settings ``n``, ``N_chains``, ``gamma``,
    ``generate_state``, ``evaluation_time_steps``, ``batch_size`` and
    ``results_dir``.

    Args:
        filename: Name of a JSON file located in ``./optimized/``.

    Returns:
        The path of the written ``.npz`` file.
    """
    with open(os.path.join("./optimized", filename), "r") as f:
        lut = json.load(f)["optimized_trainable_parameters"]

    # Prepend empty parameter lists as placeholders for the non-parametrized gates.
    # NOTE(review): the count 3 presumably matches the number of non-parametrized
    # gates configured by init_fgrape_protocol -- confirm against helpers.
    lut["initial_params"][:0] = [[] for _ in range(3)]

    system_params = init_fgrape_protocol(jax.random.PRNGKey(0), n, N_chains, gamma)

    def pure_state(key):
        # Noise-free state from the configured generator; used below as both
        # the initial state and the target state.
        return generate_state(key, N_chains=N_chains, noise_level=0.0)

    # Evaluate the LUT policy over evaluation_time_steps time steps.
    eval_result = evaluate_on_longer_time(
        U_0 = pure_state,
        C_target = pure_state,
        system_params = system_params,
        optimized_trainable_parameters = lut,
        num_time_steps = evaluation_time_steps,
        evo_type = "density",
        goal = "fidelity",
        eval_batch_size = batch_size,
        mode = "lookup",
    )
    fidelities = eval_result.fidelity_each_timestep

    # Create the output directory on first use so np.savez does not fail;
    # exist_ok makes this safe when several threads call it concurrently.
    os.makedirs(results_dir, exist_ok=True)
    stem, _ = os.path.splitext(filename)
    out_path = os.path.join(results_dir, f"{stem}.npz")
    np.savez(out_path, fidelities=fidelities)
    return out_path

# Run every experiment; each run writes its results to disk via run_experiment.
# A thread count of 1 or less runs the experiments sequentially
# (ThreadPoolExecutor would reject max_workers <= 0).
if N_parallel_threads <= 1:
    for filename in filenames:
        print("\nStarting new experiment")
        run_experiment(filename)
else:
    with ThreadPoolExecutor(max_workers=N_parallel_threads) as executor:
        # An exception raised inside a worker is only re-raised when its result
        # is retrieved from the iterator returned by executor.map. Consume the
        # iterator so failures surface here instead of being silently dropped.
        list(executor.map(run_experiment, filenames))


Starting new experiment
starting vmap calculation of trajectories
Calculating time step 1/50
Calculating time step 2/50
Calculating time step 3/50
Calculating time step 4/50
Calculating time step 5/50
Calculating time step 6/50
Calculating time step 7/50
Calculating time step 8/50
Calculating time step 9/50
Calculating time step 10/50
Calculating time step 11/50
Calculating time step 12/50
Calculating time step 13/50
Calculating time step 14/50
Calculating time step 15/50
Calculating time step 16/50
Calculating time step 17/50
Calculating time step 18/50
Calculating time step 19/50
Calculating time step 20/50
Calculating time step 21/50
Calculating time step 22/50
Calculating time step 23/50
Calculating time step 24/50
Calculating time step 25/50
Calculating time step 26/50
Calculating time step 27/50
Calculating time step 28/50
Calculating time step 29/50
Calculating time step 30/50
Calculating time step 31/50
Calculating time step 32/50
Calculating time step 33/50
Calculating time s