In [7]:
# ruff: noqa
import sys, os
# Make the sibling `feedback-grape` checkout and the parent directory importable.
# NOTE: os.path.abspath resolves these relative to the current working directory,
# so this assumes the notebook is started from its own folder. Presumably
# `feedback_grape` lives in ../feedback-grape and `helpers` / `library` in ../ —
# these appends must therefore run before the project imports below.
sys.path.append(os.path.abspath("./../feedback-grape"))
sys.path.append(os.path.abspath("./../"))

# ruff: noqa
from feedback_grape.fgrape import optimize_pulse, evaluate_on_longer_time
from helpers import (
    init_fgrape_protocol,
    test_implementations,
    generate_superposition_state,
    experiment_param_formats,
)
from tqdm import tqdm
from library.utils.FgResult_to_dict import FgResult_to_dict
import json, jax

# Self-check provided by helpers, run once before any training starts.
# NOTE(review): what exactly it validates is defined in helpers.py — TODO document.
test_implementations()

In [8]:
# Training and evaluation parameters
dir_ = "./01_results"  # Root directory for all result files

training_params = {
    # Number of seeds per experiment; seed s is turned into jax.random.PRNGKey(s)
    # for init_fgrape_protocol (the initial/target state itself is not resampled).
    "N_samples": 1,
    "samples_offset": 0,  # First seed; seeds run samples_offset .. samples_offset + N_samples - 1
    "N_training_iterations": 10000,  # Passed to optimize_pulse as max_iter
    "learning_rate": 0.01,  # Passed to optimize_pulse as learning_rate
    "convergence_threshold": 1e-6,  # Passed to optimize_pulse as convergence_threshold
    "batch_size": 8,  # Training batch size for optimize_pulse
    "eval_batch_size": 128,  # Batch size for evaluate_on_longer_time
    "evaluation_time_steps": 20,  # Number of time steps for the longer-time evaluation
}
N_parallel_threads = 1  # Experiments run concurrently via ThreadPoolExecutor; 1 = sequential

# Physical parameters to test. Each entry is a tuple
#   (num_time_steps, reward_weights, N_qubits, N_meas, gamma_z, gamma_m)
# num_time_steps : Number of time steps in the control sequence used for training
# reward_weights : Weight of the reward at each time step (one entry per time step);
#                  e.g. [0, 0, ..., 0, 1] weights only the last time step
# N_qubits       : Number of qubits / parallel chains to simulate; passed to
#                  generate_superposition_state and init_fgrape_protocol
# N_meas, gamma_z, gamma_m : passed through to init_fgrape_protocol
#                  (presumably measurement count and rates — TODO confirm in helpers.py)
experiments = [  # (timesteps, reward_weights, N_qubits, N_meas, gamma_z, gamma_m)
    (3, [1,1,1], 2, 1, 0.1, 0.1),
    #(2, [1,1], 3, 2, 0.1, 0.1),
    #(2, [1,1], 4, 3, 0.1, 0.1),
    #(2, [1,1], 5, 4, 0.1, 0.1),
    #(2, [1,1], 6, 5, 0.1, 0.1),
]

# Fail fast on an inconsistent configuration instead of deep inside training.
for exp_steps, exp_weights, *_ in experiments:
    assert len(exp_weights) == exp_steps, (
        f"reward_weights {exp_weights} must have one entry per time step ({exp_steps})."
    )

# Check if directory structure exists, create if not
for dir__ in [dir_, f"{dir_}/models", f"{dir_}/training_params", f"{dir_}/eval"]:
    os.makedirs(dir__, exist_ok=True)

# Generate parameter-based file paths
def generate_param_paths(dir_, params, s, formats=None):
    """Build the (model, training-params, eval) JSON paths for one experiment run.

    The file stem encodes every experiment parameter as ``name=value`` joined by
    ``_``, followed by ``_s=<seed>``. List values are concatenated without a
    separator (e.g. ``[1, 1, 1]`` -> ``111``).

    Args:
        dir_: Root results directory.
        params: Experiment tuple, one value per entry of ``formats``.
        s: Sample/seed index appended to the stem.
        formats: Sequence of ``(name, expected_type)`` pairs describing
            ``params``. Defaults to ``experiment_param_formats`` from helpers.

    Returns:
        Tuple ``(model_path, training_params_path, eval_path)``.

    Raises:
        ValueError: If ``params`` and ``formats`` differ in length. ``zip`` would
            otherwise silently drop parameters from the file name, letting
            different experiments overwrite each other's files.
        TypeError: If a parameter's type is not exactly the expected type.
    """
    if formats is None:
        formats = experiment_param_formats
    formats = list(formats)
    if len(params) != len(formats):
        raise ValueError(
            f"Got {len(params)} parameters but {len(formats)} formats: {params!r}"
        )

    parts = []
    for (name, param_type), param in zip(formats, params):
        # Exact type match on purpose (e.g. True is rejected where int is expected).
        if type(param) is not param_type:
            raise TypeError(
                f"Parameter {name} has type {type(param)}, expected {param_type}."
            )
        if type(param) is list:
            # NOTE: separator-free encoding is ambiguous ([1, 11] and [11, 1] both
            # give "111"); kept unchanged so existing result files still match.
            param_str = "".join(str(x) for x in param)
        else:
            param_str = str(param)
        parts.append(f"{name}={param_str}")

    pstr = "_".join(parts) + f"_s={s}"

    return (
        f"{dir_}/models/{pstr}.json",
        f"{dir_}/training_params/{pstr}.json",
        f"{dir_}/eval/{pstr}.json",
    )

In [9]:
from concurrent.futures import ThreadPoolExecutor

def _dump_json_atomic(obj, path):
    """Write ``obj`` as JSON to ``path`` so that ``path`` is never left half-written.

    The JSON is written to ``<path>.tmp`` first and then moved over ``path`` with
    ``os.replace``. If serialisation fails or the kernel is interrupted mid-write,
    ``path`` keeps its previous state (or stays absent) and the temp file is removed.
    """
    tmp_path = f"{path}.tmp"
    try:
        with open(tmp_path, "w") as f:
            json.dump(obj, f)
        os.replace(tmp_path, path)
    finally:
        # The temp file only still exists if something above failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def run_experiment(params, skip_existing=False):
    """Train and evaluate one experiment configuration for every sample seed.

    For each seed ``s`` in ``[samples_offset, samples_offset + N_samples)`` this
    writes, in order, three JSON files whose paths come from
    ``generate_param_paths``:
      1. the ``training_params`` dict (training_params/ directory),
      2. the ``optimize_pulse`` result (models/ directory),
      3. the ``evaluate_on_longer_time`` result (eval/ directory).

    Args:
        params: Tuple ``(num_time_steps, reward_weights, N_qubits, N_meas,
            gamma_z, gamma_m)``, i.e. one entry of ``experiments``.
        skip_existing: If True, skip seeds whose eval file already exists.
            Defaults to False, which always retrains and overwrites old files.
    """
    (
        num_time_steps,
        reward_weights,
        N_qubits,
        N_meas,
        gamma_z,
        gamma_m,
    ) = params

    print(f"Evaluating {params}")
    # Built once and reused for every seed as both the initial state and the target.
    rho = generate_superposition_state(N_qubits)

    first_seed = training_params["samples_offset"]
    seeds = range(first_seed, first_seed + training_params["N_samples"])
    for s in tqdm(seeds):
        model_path, params_path, eval_path = generate_param_paths(dir_, params, s)
        # The eval file is written last, so its presence marks a finished run.
        if skip_existing and os.path.exists(eval_path):
            print(f"Skipping existing experiment for params {params} and seed {s}.")
            continue

        _dump_json_atomic(training_params, params_path)

        # The seed enters through the PRNG key given to init_fgrape_protocol.
        system_params = init_fgrape_protocol(
            jax.random.PRNGKey(s),
            N_qubits,
            N_meas,
            gamma_z,
            gamma_m,
        )

        # Train the feedback policy in "nn" mode.
        result = optimize_pulse(
            U_0=rho,
            C_target=rho,
            system_params=system_params,
            num_time_steps=num_time_steps,
            reward_weights=reward_weights,
            mode="nn",
            goal="fidelity",
            max_iter=training_params["N_training_iterations"],
            convergence_threshold=training_params["convergence_threshold"],
            learning_rate=training_params["learning_rate"],
            evo_type="density",
            batch_size=training_params["batch_size"],
            eval_batch_size=1,  # This evaluation is discarded
            progress=True,
        )
        _dump_json_atomic(FgResult_to_dict(result), model_path)

        # Evaluate the trained parameters on training_params["evaluation_time_steps"] steps.
        eval_result = evaluate_on_longer_time(
            U_0=rho,
            C_target=rho,
            system_params=system_params,
            optimized_trainable_parameters=result.optimized_trainable_parameters,
            num_time_steps=training_params["evaluation_time_steps"],
            evo_type="density",
            goal="fidelity",
            eval_batch_size=training_params["eval_batch_size"],
            mode="nn",
        )
        _dump_json_atomic(FgResult_to_dict(eval_result), eval_path)

        # Drop references before the next seed so the memory can be reclaimed early.
        del eval_result, result, system_params

# Run experiments, in parallel threads if requested.
if N_parallel_threads > 1:
    with ThreadPoolExecutor(max_workers=N_parallel_threads) as executor:
        # Exhaust the iterator: Executor.map only re-raises a worker's exception
        # when that result is retrieved, so an unconsumed map silently drops errors.
        list(executor.map(run_experiment, experiments))
else:
    for params in experiments:
        run_experiment(params)

# Play a sound when done. `say` is a macOS command; elsewhere it only produces a
# shell error. Wrapping it in `if` also stops the notebook echoing its return code.
if sys.platform == "darwin":
    os.system('say "done."')

Evaluating (3, [1, 1, 1], 2, 1, 0.1, 0.1)


  0%|          | 0/1 [00:00<?, ?it/s]

Iteration 10, Loss: -0.406589, T=0s, eta=26s
Iteration 20, Loss: -0.509377, T=0s, eta=27s
Iteration 30, Loss: -0.517850, T=0s, eta=28s
Iteration 40, Loss: -0.523290, T=0s, eta=30s
Iteration 50, Loss: -0.557825, T=0s, eta=33s
Iteration 60, Loss: -0.626103, T=0s, eta=33s
Iteration 70, Loss: -0.512271, T=0s, eta=35s
Iteration 80, Loss: -0.567076, T=0s, eta=38s
Iteration 90, Loss: -0.614312, T=0s, eta=38s
Iteration 100, Loss: -0.639389, T=0s, eta=38s
Iteration 110, Loss: -0.585400, T=0s, eta=37s
Iteration 120, Loss: -0.636840, T=0s, eta=37s
Iteration 130, Loss: -0.572076, T=0s, eta=37s
Iteration 140, Loss: -0.706890, T=0s, eta=36s
Iteration 150, Loss: -0.738864, T=0s, eta=36s
Iteration 160, Loss: -0.784133, T=0s, eta=36s
Iteration 170, Loss: -0.835147, T=0s, eta=35s
Iteration 180, Loss: -0.637041, T=0s, eta=35s
Iteration 190, Loss: -0.432075, T=0s, eta=35s
Iteration 200, Loss: -0.494759, T=0s, eta=35s
Iteration 210, Loss: -0.243678, T=0s, eta=35s
Iteration 220, Loss: -0.691832, T=0s, eta=3

100%|██████████| 1/1 [00:32<00:00, 32.94s/it]


0