In [None]:
# ruff: noqa
import sys, os
sys.path.append(os.path.abspath("./../feedback-grape"))
sys.path.append(os.path.abspath("./../"))

# ruff: noqa
from helpers import (
    test_implementations,
    state_types,
    calculate_baseline
)
from tqdm import tqdm
import jax, re
import numpy as np

# Enable 64-bit floats before any JAX work, including the implementation
# checks below, so that those checks run in the same precision as the
# evaluation. (Previously the flag was set only after test_implementations().)
jax.config.update("jax_enable_x64", True)

# Run the implementation checks provided by the helpers module.
# NOTE(review): presumably raises if an implementation is wrong — confirm in helpers.
test_implementations()

# Physical parameters. None means "parse the value from each model filename".
gamma_p_global = None  # None = take value from filename ("gammap=<x>_")
gamma_m_global = None  # None = take value from filename ("gammam=<x>_")
evaluation_state_type = None  # None = take value from filename ("rhoe=<name>_"); otherwise a key of state_types

# Evaluation parameters
evaluation_time_steps = 20  # Number of time steps for evaluation
batch_size = 128  # Number of random states to evaluate in parallel
dir_ = "./12_results/"  # Models are read from <dir_>/models, baselines written to <dir_>/baseline


# Collect one model file per physical configuration from <dir_>/models.
# Files that differ only in their sample tag "_s=<n>" share all the physical
# parameters parsed later, so one representative per group is enough.
# sorted() makes the chosen representative (and therefore the name of the
# saved baseline file) independent of os.listdir()'s arbitrary ordering.
json_files = sorted(f for f in os.listdir(dir_ + "/models") if f.endswith(".json"))
unique_files = []
processed_files = set()
for filename in tqdm(json_files):
    # Drop the sample tag to obtain the configuration key. A file without a
    # sample tag is its own key (previously this crashed on None.group(0)).
    s = re.search(r"_s=(\d+)", filename)
    fn = filename.replace(s.group(0), "") if s is not None else filename
    if fn in processed_files:
        continue
    processed_files.add(fn)
    unique_files.append(filename)
print(f"Found {len(unique_files)} unique model files for evaluation.")
os.makedirs(dir_ + "/baseline", exist_ok=True)

# Record the settings of this baseline run next to its results.
# A value of None means it was taken from each model's filename.
parameter_report = [
    "Physical parameters:\n",
    f"gamma_p_global: {gamma_p_global}\n",
    f"gamma_m_global: {gamma_m_global}\n",
    f"evaluation_state_type: {evaluation_state_type}\n",
    "\nEvaluation parameters:\n",
    f"evaluation_time_steps: {evaluation_time_steps}\n",
    f"batch_size: {batch_size}\n",
]
with open(f"{dir_}/baseline/evaluation_parameters.txt", "w") as report_file:
    report_file.writelines(parameter_report)

100%|██████████| 19/19 [00:00<00:00, 91389.65it/s]

Found 19 unique model files for evaluation.





In [2]:
def _filename_field(pattern, filename):
    """Return capture group 1 of `pattern` in `filename`.

    Raises:
        ValueError: if the pattern does not occur, naming the offending file
            (instead of an opaque AttributeError on ``None.group``).
    """
    match = re.search(pattern, filename)
    if match is None:
        raise ValueError(f"Pattern {pattern!r} not found in model filename {filename!r}")
    return match.group(1)


# Compute and save a baseline fidelity curve for every unique model
# configuration. Parameters come from the filename unless overridden by
# gamma_p_global / gamma_m_global / evaluation_state_type.
for filename in tqdm(unique_files):
    N_qubits = int(_filename_field(r"Nqubits=(\d+)_", filename))
    N_meas = int(_filename_field(r"Nmeas=(\d+)_", filename))  # parsed but not passed to calculate_baseline
    if gamma_p_global is None:
        gamma_p = float(_filename_field(r"gammap=([\d.]+)_", filename))
    else:
        gamma_p = gamma_p_global
    if gamma_m_global is None:
        gamma_m = float(_filename_field(r"gammam=([\d.]+)_", filename))
    else:
        gamma_m = gamma_m_global

    if evaluation_state_type is None:
        state_type_str = _filename_field(r"rhoe=([a-zA-Z]+)_", filename)
        generate_state = state_types[state_type_str]
    else:
        generate_state = state_types[evaluation_state_type]

    fidelities_each, states_each = calculate_baseline(
        N_qubits=N_qubits,
        gamma_p=gamma_p,
        gamma_m=gamma_m,
        evaluation_time_steps=evaluation_time_steps,
        batch_size=batch_size,
        generate_state=generate_state,
        key=jax.random.PRNGKey(0),  # same fixed key for every file -> reproducible runs
    )
    # Average over axis 0.
    # NOTE(review): originally annotated "Average over all basis states";
    # axis 0 is presumably the batch of generated states — confirm against
    # calculate_baseline.
    fidelities_each = fidelities_each.mean(axis=0)

    # Save as "<model name without .json>.npz". The previous filename[:-4]
    # removed only "json" (4 of the 5 extension characters), producing
    # "<name>..npz".
    base_name = os.path.splitext(filename)[0]
    np.savez(f"{dir_}/baseline/{base_name}.npz", fidelities=fidelities_each)

100%|██████████| 2/2 [00:02<00:00,  1.44s/it]
