In [1]:
# ruff: noqa
import sys, os
sys.path.append(os.path.abspath("./../feedback-grape"))
sys.path.append(os.path.abspath("./../"))

# ruff: noqa
from feedback_grape.fgrape import optimize_pulse, evaluate_on_longer_time
from helpers import (
    init_fgrape_protocol,
    test_implementations,
    generate_random_discrete_state,
    generate_random_bloch_state,
    generate_excited_state,
    generate_ground_state,
)
from tqdm import tqdm
import jax
import jax.numpy as jnp
import numpy as np
from library.utils.FgResult_to_dict import FgResult_to_dict
import json

# Sanity-check the local `helpers` implementations before any experiment runs
# (what is checked is defined in helpers.test_implementations, not visible here).
test_implementations()

In [2]:
# Training / evaluation settings shared by every experiment in this notebook.
dir_ = "./00_results"  # root output directory (models/, training_params/, eval/ live below it)

training_params = dict(
    N_samples=3,  # independently seeded runs per experiment (seed = index + samples_offset)
    samples_offset=0,  # shift applied to the run index before it is used as PRNG seed
    N_training_iterations=10_000,  # max_iter passed to optimize_pulse
    learning_rate=1e-2,  # optimizer learning rate passed to optimize_pulse
    convergence_threshold=1e-6,  # convergence_threshold passed to optimize_pulse
    batch_size=16,  # training batch size passed to optimize_pulse
    eval_batch_size=16,  # batch size for evaluate_on_longer_time
    evaluation_time_steps=50,  # num_time_steps for evaluate_on_longer_time
)

# How many experiments are run concurrently in a thread pool (1 = run sequentially).
N_parallel_threads = 1

# Physical parameters to test. Each entry of `experiments` is one tuple whose
# positions must match `experiments_format`:
#   num_time_steps        : number of time steps in the control pulse used for training
#   lut_memory            : number of past measurements the lookup table remembers (LUT depth)
#   reward_weights        : reward weight per time step, each an int in 0..9;
#                           e.g. [0, 0, ..., 0, 1] weights only the last time step
#   N_qubits              : number of parallel chains to simulate
#   N_meas                : number of measurements per time step
#   gamma_p, gamma_m      : sigma^+ / sigma^- rates
#   training_state_type   : key of `state_types` used to sample training states
#   evaluation_state_type : key of `state_types` used to sample evaluation states

experiments = [ # (num_time_steps, lut_memory, reward_weights, N_qubits, N_meas, gamma_p, gamma_m, training_state_type, evaluation_state_type)
    (3, 2, [1,1,1], 3, 1, 0.05, 0.05, "bloch", "bloch"),
]


# Format specification for experiment tuples, one entry per tuple position:
#   (name, expected type, required by grape, required by lut, required by rnn)
# `name` is the key written into generated file names, and the expected type is
# checked exactly by the validation cell. The three flags mark which model types a
# parameter is relevant to; generate_param_paths puts only the flagged parameters
# into that model type's file names.
experiments_format = [ # Format of variables as tuples (name, type, required by grape, required by lut, required by rnn)
    ("t", int, True, True, True), # number of time steps (or segments)
    ("l", int, False, True, False), # LUT memory (depth); only relevant to LUT models
    ("w", list, True, True, True), # reward weights
    ("N_qubits", int, True, True, True), # number of chains
    ("N_meas", int, True, True, True), # number of measurements per time step
    ("gamma_p", float, True, True, True), # sigma^+ rate
    ("gamma_m", float, True, True, True), # sigma^- rate
    ("rho_t", str, True, True, True), # training state type (key of state_types)
    ("rho_e", str, True, True, True), # evaluation state type (key of state_types)
]

# State-type name (as used in `experiments`) -> state generator from `helpers`.
# run_experiment calls each generator as generator(key, N_qubits).
state_types = dict(
    discrete=generate_random_discrete_state,
    bloch=generate_random_bloch_state,
    excited=generate_excited_state,
    ground=generate_ground_state,
)


# Validate the experiment specification before any (long) training starts.

# Check the format table once, independently of the experiments, so it is validated
# even when `experiments` is empty.
for name, expected_type, req_grape, req_lut, req_rnn in experiments_format:
    assert type(name) == str, "Parameter name must be a string."
    assert type(req_grape) == bool, "Requirement flags must be boolean."
    assert type(req_lut) == bool, "Requirement flags must be boolean."
    assert type(req_rnn) == bool, "Requirement flags must be boolean."

# Parameters that must name a generator in `state_types`. Without this check a typo
# would only surface as a KeyError inside run_experiment, after the run has started.
state_type_params = {"rho_t", "rho_e"}

for params in experiments:
    assert len(params) == len(experiments_format), "Experiment parameters length mismatch."
    for i, (param, (name, expected_type, *_flags)) in enumerate(zip(params, experiments_format)):
        # Exact type match (not isinstance): e.g. True is rejected where an int is
        # expected, and 1 is rejected where a float is expected.
        assert type(param) == expected_type, f"Parameter '{name}' at index {i} must be of type {expected_type.__name__}."
        if type(param) == list:
            # generate_param_paths joins list items without a separator, which is
            # presumably why items are limited to single digits (unambiguous file names).
            for j, item in enumerate(param):
                assert type(item) == int and 0 <= item <= 9, f"Item at index {j} in parameter '{name}' must be of type int and between 0 and 9."
        if name in state_type_params:
            assert param in state_types, f"Parameter '{name}' must be one of {sorted(state_types)}, got '{param}'."

# Ensure the output tree exists: <dir_>/, <dir_>/models, <dir_>/training_params, <dir_>/eval.
for subdir in ("", "/models", "/training_params", "/eval"):
    os.makedirs(dir_ + subdir, exist_ok=True)

def generate_param_paths(dir_, params, model_type, s):
    """Build the three JSON output paths for one trained model.

    Every entry of the module-level ``experiments_format`` whose requirement flag
    for ``model_type`` is set contributes a ``name=value`` token to the file stem.
    List-typed values are written as their items concatenated without a separator.
    The stem is ``<model_type>_<tokens joined by "_">_s=<s>``.

    Args:
        dir_: Root output directory.
        params: One experiment tuple, ordered like ``experiments_format``.
        model_type: One of "grape", "lut", "rnn".
        s: Sample index appended to the stem.

    Returns:
        Tuple ``(model_path, training_params_path, eval_path)`` located in
        ``dir_/models``, ``dir_/training_params`` and ``dir_/eval``.

    Raises:
        AssertionError: if ``model_type`` is not recognised.
    """
    assert model_type in ["grape", "lut", "rnn"], "Invalid model type."

    tokens = []
    for spec, value in zip(experiments_format, params):
        name, expected_type, req_grape, req_lut, req_rnn = spec
        wanted = {"grape": req_grape, "lut": req_lut, "rnn": req_rnn}[model_type]
        if not wanted:
            continue
        text = "".join(map(str, value)) if expected_type == list else str(value)
        tokens.append(f"{name}={text}")

    stem = f"{model_type}_{'_'.join(tokens)}_s={s}"
    return tuple(
        f"{dir_}/{sub}/{stem}.json" for sub in ("models", "training_params", "eval")
    )

In [3]:
from concurrent.futures import ThreadPoolExecutor

def run_experiment(params):
    """Train and evaluate lookup-table (LUT) controllers for one experiment tuple.

    For every sample index ``s`` in ``range(training_params["N_samples"])`` this:

    1. builds the system with ``init_fgrape_protocol``, seeded with
       ``jax.random.PRNGKey(s + training_params["samples_offset"])``;
    2. trains a LUT controller with ``optimize_pulse`` (``mode="lookup"``,
       ``goal="fidelity"``, ``evo_type="density"``) over ``num_time_steps`` steps;
    3. evaluates the trained parameters with ``evaluate_on_longer_time`` over
       ``training_params["evaluation_time_steps"]`` steps;
    4. writes the training result, ``training_params`` and the evaluation result
       as JSON to the paths returned by ``generate_param_paths(dir_, params, "lut", s)``.

    Args:
        params: One experiment tuple, ordered like ``experiments_format``.

    Raises:
        KeyError: if the training or evaluation state type is not a key of
            ``state_types``.
    """
    (
        num_time_steps,
        lut_memory,
        reward_weights,
        N_qubits,
        N_meas,
        gamma_p,
        gamma_m,
        training_state_type,
        evaluation_state_type,
    ) = params

    print(f"Evaluating {params}")

    # Resolve the generators up front so an unknown state type fails immediately
    # instead of from inside optimize_pulse.
    make_training_state = state_types[training_state_type]
    make_evaluation_state = state_types[evaluation_state_type]
    training_state = lambda key: make_training_state(key, N_qubits)
    evaluation_state = lambda key: make_evaluation_state(key, N_qubits)

    for s in tqdm(range(training_params["N_samples"])):
        # Resolve output paths before the long training run. `dir_` must be passed
        # explicitly: it is the first positional parameter of generate_param_paths.
        model_path, params_path, eval_path = generate_param_paths(dir_, params, "lut", s)

        system_params = init_fgrape_protocol(
            jax.random.PRNGKey(s + training_params["samples_offset"]),
            N_qubits,
            N_meas,
            gamma_p,
            gamma_m,
        )

        # Train the LUT. The same generator is used for the initial state and the
        # target (NOTE(review): presumably the goal is to preserve the sampled state — confirm).
        result = optimize_pulse(
            U_0=training_state,
            C_target=training_state,
            system_params=system_params,
            num_time_steps=num_time_steps,
            lut_depth=lut_memory,
            reward_weights=reward_weights,
            mode="lookup",
            goal="fidelity",
            max_iter=training_params["N_training_iterations"],
            convergence_threshold=training_params["convergence_threshold"],
            learning_rate=training_params["learning_rate"],
            evo_type="density",
            batch_size=training_params["batch_size"],
            eval_batch_size=1,  # This evaluation is discarded
            progress=True,
        )

        # Evaluate the trained parameters over `evaluation_time_steps` steps.
        eval_result = evaluate_on_longer_time(
            U_0=evaluation_state,
            C_target=evaluation_state,
            system_params=system_params,
            optimized_trainable_parameters=result.optimized_trainable_parameters,
            num_time_steps=training_params["evaluation_time_steps"],
            evo_type="density",
            goal="fidelity",
            eval_batch_size=training_params["eval_batch_size"],
            mode="lookup",
        )

        with open(model_path, "w") as f:
            json.dump(FgResult_to_dict(result), f)

        with open(params_path, "w") as f:
            json.dump(training_params, f)

        with open(eval_path, "w") as f:
            json.dump(FgResult_to_dict(eval_result), f)

# Run all experiments, several at once in a thread pool if N_parallel_threads > 1.
if N_parallel_threads > 1:
    with ThreadPoolExecutor(max_workers=N_parallel_threads) as executor:
        # Consume the results: Executor.map re-raises a worker's exception only when
        # that result is retrieved, so an unconsumed map silently drops failures.
        list(executor.map(run_experiment, experiments))
else:
    for params in experiments:
        run_experiment(params)

# Audible notification when done. `say` is the macOS text-to-speech command, so it
# is only invoked there instead of producing "command not found" elsewhere.
if sys.platform == "darwin":
    os.system('say "done."')

Evaluating (3, 2, [1, 1, 1], 3, 1, 0.05, 0.05, 'bloch', 'bloch')


  0%|          | 0/3 [00:00<?, ?it/s]

Iteration 10, Loss (avg over 10): 0.249154, T=0s, eta=121s
Iteration 20, Loss (avg over 10): 0.262430, T=0s, eta=117s
Iteration 30, Loss (avg over 10): 0.300980, T=0s, eta=116s
Iteration 40, Loss (avg over 10): 0.279364, T=0s, eta=116s
Iteration 50, Loss (avg over 10): 0.410086, T=0s, eta=116s
Iteration 60, Loss (avg over 10): 0.390650, T=0s, eta=115s
Iteration 70, Loss (avg over 10): 0.327669, T=0s, eta=115s
Iteration 80, Loss (avg over 10): 0.277465, T=0s, eta=114s
Iteration 90, Loss (avg over 10): 0.289765, T=1s, eta=114s
Iteration 100, Loss (avg over 10): 0.203208, T=1s, eta=114s
Iteration 110, Loss (avg over 10): 0.143093, T=1s, eta=114s
Iteration 120, Loss (avg over 10): 0.077756, T=1s, eta=113s
Iteration 130, Loss (avg over 10): 0.191410, T=1s, eta=113s
Iteration 140, Loss (avg over 10): 0.118959, T=1s, eta=113s
Iteration 150, Loss (avg over 10): 0.235809, T=1s, eta=113s
Iteration 160, Loss (avg over 10): 0.184289, T=1s, eta=112s
Iteration 170, Loss (avg over 10): 0.136732, T=1s

100%|██████████| 1/1 [02:06<00:00, 126.82s/it]
  0%|          | 0/3 [02:23<?, ?it/s]


KeyboardInterrupt: 