In [1]:
# ruff: noqa
import sys, os
sys.path.append(os.path.abspath("./../feedback-grape"))
sys.path.append(os.path.abspath("./../"))

# ruff: noqa
from feedback_grape.fgrape import evaluate_on_longer_time
from helpers import (
    init_system_params_for_custom_protocol,
    test_implementations,
    state_types,
    lut_from_protocol,
)
from tqdm import tqdm
import json, jax, re
import numpy as np
from library.utils.FgResult_to_dict import FgResult_to_dict
from example_N.custom_models.do_nothing_Nqubits_3 import protocol as do_nothing_protocol
from custom_models.stabilizer_code_Nqubits_3 import protocol as stabilizer_code_protocol
# Invoked once before any evaluation; presumably a self-test of the helpers module.
test_implementations()

# Enable 64-bit (double precision) arrays in JAX.
jax.config.update("jax_enable_x64", True)

# --- Physical parameters -------------------------------------------------
# Each value below overrides the corresponding value encoded in the model
# filename. Leave it as None to use the value parsed from the filename.
gamma_p_global = None
gamma_m_global = None
evaluation_state_type = None

# --- Evaluation parameters -----------------------------------------------
evaluation_time_steps = 200  # number of time steps simulated per evaluation
batch_size = 16 * 128  # number of random states evaluated in parallel (2048)
N_parallel_threads = 2  # worker threads used for evaluation; 1 = run serially
dir_ = "./10_results/"  # root directory holding models/ (input) and custom/ (output)

# Protocols to evaluate (do_nothing_protocol is currently disabled).
protocols = [stabilizer_code_protocol]

# Collect one representative model file per physical configuration: files that
# differ only in their sample tag ("_s=<n>") describe the same physical setup.
gamma_grid = np.logspace(np.log10(0.0001), np.log10(0.1), 12)[:6].tolist()
# True: replace the discovered files with the fixed (gamma_p, gamma_m) grid below.
use_gamma_grid_override = True

# Sorted so that the representative kept for each configuration is deterministic
# (os.listdir returns entries in arbitrary order).
json_files = sorted(
    f for f in os.listdir(os.path.join(dir_, "models")) if f.endswith(".json")
)
unique_files = []
processed_files = set()
for filename in tqdm(json_files):
    # Strip the sample tag so physically identical files share one key. Files
    # without a sample tag are keyed by their full name instead of crashing.
    fn = re.sub(r"_s=\d+", "", filename)
    if fn in processed_files:
        continue
    processed_files.add(fn)
    unique_files.append(filename)

if use_gamma_grid_override:
    # NOTE(review): this discards the files discovered above and evaluates a
    # fixed 6x6 grid of (gamma_p, gamma_m) values instead. evaluate_protocol
    # only parses parameters from the filename, so these need not exist on disk.
    unique_files = [
        f"lut_t=3_l=2_w=111_Nqubits=3_Nmeas=2_gammap={gamma_p}_gammam={gamma_m}_rhot=bloch_rhoe=bloch_s=0.json"
        for gamma_m in gamma_grid
        for gamma_p in gamma_grid
    ]
print(f"Found {len(unique_files)} unique model files for evaluation.")

# Create one output directory per protocol and record the settings used for it.
for protocol in protocols:
    out_dir = f"{dir_}/custom/{protocol['label']}"
    os.makedirs(out_dir, exist_ok=True)

    # Physical and evaluation parameters, written as a plain-text summary.
    settings_text = "".join([
        "Physical parameters:\n",
        f"gamma_p_global: {gamma_p_global}\n",
        f"gamma_m_global: {gamma_m_global}\n",
        f"evaluation_state_type: {evaluation_state_type}\n",
        "\nEvaluation parameters:\n",
        f"evaluation_time_steps: {evaluation_time_steps}\n",
        f"batch_size: {batch_size}\n",
    ])
    with open(f"{out_dir}/evaluation_parameters.txt", "w") as f:
        f.write(settings_text)

100%|██████████| 578/578 [00:00<00:00, 772810.87it/s]

Found 36 unique model files for evaluation.





In [2]:
from concurrent.futures import ThreadPoolExecutor

def _parse_filename_field(filename, pattern, cast):
    """Return ``cast(group 1)`` of ``pattern`` searched in ``filename``.

    Raises:
        ValueError: if ``pattern`` does not occur in ``filename`` (or if
            ``cast`` rejects the captured text).
    """
    match = re.search(pattern, filename)
    if match is None:
        raise ValueError(f"Could not find {pattern!r} in filename {filename!r}")
    return cast(match.group(1))


def evaluate_protocol(filename, protocol):
    """Evaluate a hand-written lookup-table protocol and save the result as JSON.

    The system parameters (number of qubits, number of measurements, gamma_p,
    gamma_m, evaluation state type) are parsed from ``filename``. The file
    itself is never read. The module-level overrides ``gamma_p_global``,
    ``gamma_m_global`` and ``evaluation_state_type`` take precedence over the
    parsed values when they are not None.

    Args:
        filename: model filename containing ``Nqubits=<int>_``, ``Nmeas=<int>_``,
            ``gammap=<float>_``, ``gammam=<float>_`` and ``rhoe=<type>_`` fields.
        protocol: protocol mapping with a ``"label"`` key (names the output
            directory); it is passed to ``lut_from_protocol``.

    Side effects:
        Writes ``FgResult_to_dict(result)`` as JSON to
        ``{dir_}/custom/{protocol['label']}/<filename>``. When
        ``evaluation_state_type`` is set, the ``_rhoe=`` tag in the output name
        is replaced by it.

    Raises:
        ValueError: if a required field cannot be parsed from ``filename``.
    """
    N_qubits = _parse_filename_field(filename, r"Nqubits=(\d+)_", int)
    N_meas = _parse_filename_field(filename, r"Nmeas=(\d+)_", int)

    # The gamma patterns accept scientific notation too (e.g. "1e-05"), which
    # is how Python formats small floats in f-strings.
    if gamma_p_global is None:
        gamma_p = _parse_filename_field(filename, r"gammap=([\d.eE+-]+)_", float)
    else:
        gamma_p = gamma_p_global
    if gamma_m_global is None:
        gamma_m = _parse_filename_field(filename, r"gammam=([\d.eE+-]+)_", float)
    else:
        gamma_m = gamma_m_global

    rhoe = _parse_filename_field(filename, r"rhoe=([a-zA-Z0-9.]+)_", str)
    if evaluation_state_type is None:
        state_type = rhoe
    else:
        state_type = evaluation_state_type
        # Record the overriding state type in the output filename.
        filename = filename.replace("_rhoe=" + rhoe + "_", f"_rhoe={evaluation_state_type}_")
    # Looked up once, so an unknown state type fails here rather than mid-evaluation.
    make_state = state_types[state_type]

    def evaluation_state(key):
        return make_state(key, N_qubits)

    model = lut_from_protocol(protocol, N_qubits, N_meas)

    system_params = init_system_params_for_custom_protocol(
        N_qubits,
        N_meas,
        gamma_p,
        gamma_m,
    )

    eval_result = evaluate_on_longer_time(  # Evaluate on longer time and choose best LUT accordingly
        U_0=evaluation_state,
        C_target=evaluation_state,
        system_params=system_params,
        optimized_trainable_parameters=model,
        num_time_steps=evaluation_time_steps,
        evo_type="density",
        goal="fidelity",
        eval_batch_size=batch_size,
        mode="lookup",
    )

    # splitext removes the whole extension; the previous filename[:-4] left a
    # trailing "." behind for ".json" inputs, producing "..json".
    out_name = os.path.splitext(filename)[0] + ".json"
    # Single quotes inside the f-string keep this valid on Python < 3.12.
    with open(f"{dir_}/custom/{protocol['label']}/{out_name}", "w") as f:
        json.dump(FgResult_to_dict(eval_result), f)

# Evaluate every selected file with every protocol.
for protocol in protocols:
    if N_parallel_threads <= 1:
        # Serial path (also used for non-positive thread counts, which
        # ThreadPoolExecutor would reject).
        for filename in tqdm(unique_files):
            evaluate_protocol(filename, protocol)
    else:
        with ThreadPoolExecutor(max_workers=N_parallel_threads) as executor:
            # `protocol` is passed as a second iterable rather than captured by a
            # closure over the loop variable.
            results = executor.map(
                evaluate_protocol, unique_files, [protocol] * len(unique_files)
            )
            # Consuming the iterator re-raises any exception from a worker;
            # leaving it unconsumed would silently drop failed evaluations.
            list(tqdm(results, total=len(unique_files)))