In [1]:
# ruff: noqa
import sys, os
sys.path.append(os.path.abspath("./../feedback-grape"))
sys.path.append(os.path.abspath("./../"))

# ruff: noqa
from feedback_grape.fgrape import optimize_pulse, evaluate_on_longer_time
from helpers import (
    init_grape_protocol,
    init_fgrape_protocol,
    test_implementations,
    generate_random_discrete_state,
    generate_random_bloch_state,
    generate_excited_state
)
from tqdm import tqdm
import jax
import jax.numpy as jnp
import numpy as np
from library.utils.FgResult_to_dict import FgResult_to_dict
import json

test_implementations()

jax.config.update("jax_enable_x64", True)

In [2]:
# Physical parameters
# (attention! #elements in density matrix grow as 4^N_chains)
N_chains = 3 # Number of parallel chains to simulate
gamma = 0.05 # Decay constant
generate_state = generate_random_bloch_state

# Training and evaluation parameters
model_type = "lut" # "lut" or "rnn"
training_params = {
    "N_samples": 1, # Number of random initial states to sample
    "samples_offset": 1,
    "N_training_iterations": 110000, # Number of training iterations
    "learning_rate": 0.01, # Learning rate
    "convergence_threshold": 1e-6,
    "batch_size": 16,
    "eval_batch_size": 16,
    "evaluation_time_steps": 50,
    "train_on_final_states_start": 10000, # EXPERIMENTAL
    "train_on_final_states_every_a": 10000, # EXPERIMENTAL
    "train_on_final_states_every_b": 10000, # EXPERIMENTAL
    #"train_eval_num_time_steps": 36, # EXPERIMENTAL (set in experiment parameters)
    "train_eval_batch_size": 32, # EXPERIMENTAL
}
N_parallel_threads = 1 # Number of parallel threads for training
save_all_samples = True # wether to evaluate and save all training samples or only the best one

# Parameters to test

#num_time_steps : Number of time steps in the control pulse
#lut_depth : Depth of the lookup table for feedback
#reward_weights: Weights for the reward at each time step. Default only weights last timestep [0, 0, ... 0, 1]

experiments = [ # (timesteps, lut_depth, reward_weights, noise_level)
    #(33, 2, [1]*33, 0.0),
    (3, 2, [1,1,1], 0.0, tet) for tet in [30,31,32,33,10,20,29,30,34,35,36]
]

experiments_format = [ # Format of variables as tuples (name, type, required by grape, required by lut, required by rnn)
    ("t", int, True, True, True), # number timesteps (or segments)
    ("l", int, False, True, False), # LUT depth
    ("w", list, True, True, True), # reward weights
    ("noise", float, True, True, True), # noise level
    ("tet", int, True, True, True), # evaluation time steps
]

# Check if experiment parameters are in specified format
for params in experiments:
    assert len(params) == len(experiments_format), "Experiment parameters length mismatch."
    for i, (param, (name, expected_type, req_grape, req_lut, req_rnn)) in enumerate(zip(params, experiments_format)):
        assert type(name) == str, "Parameter name must be a string."
        assert type(req_grape) == bool, "Requirement flags must be boolean."
        assert type(req_lut) == bool, "Requirement flags must be boolean."
        assert type(req_rnn) == bool, "Requirement flags must be boolean."
        assert type(param) == expected_type, f"Parameter '{name}' at index {i} must be of type {expected_type.__name__}."
        if type(param) == list:
            for j, item in enumerate(param):
                assert type(item) == int and 0 <= item <= 9, f"Item at index {j} in parameter '{name}' must be of type int and between 0 and 9."

def generate_param_paths(params, model_type, save_all_samples, s):
    assert model_type in ["grape", "lut", "rnn"], "Invalid model type."

    parts = []
    for (name, expected_type, req_grape, req_lut, req_rnn), param in zip(experiments_format, params):
        required = (model_type == "grape" and req_grape) or (model_type == "lut" and req_lut) or (model_type == "rnn" and req_rnn)
        if required:
            if expected_type == list:
                param_str = "".join(str(x) for x in param)
            else:
                param_str = str(param)
            parts.append(f"{name}={param_str}")

    pstr = f"{model_type}_" + "_".join(parts)
    
    if save_all_samples:
        pstr += f"_s={s}"
    
    return (
        "./optimized_architectures/" + pstr + ".json",
        "./optimized_architectures/" + pstr + "_training_data.json",
        "./evaluation_results/" + pstr + ".npz"
    )

if not os.path.exists("./evaluation_results"):
    os.makedirs("./evaluation_results")
if not os.path.exists("./optimized_architectures"):
    os.makedirs("./optimized_architectures")

In [3]:
from concurrent.futures import ThreadPoolExecutor
def run_experiment(params):
    num_time_steps, lut_depth, reward_weights, noise_level, tet = params
    print(f"Evaluating {params}")
    noisy_state = lambda key: generate_state(key, N_chains=N_chains, noise_level=noise_level)
    pure_state = lambda key: generate_state(key, N_chains=N_chains, noise_level=0.0)

    best_result = None
    best_result_fidelities = None
    fidelities_each = []
    for s in range(training_params["samples_offset"], training_params["N_samples"] + training_params["samples_offset"]):
        print(f" Sample {s}/{training_params['N_samples']}")
        system_params = init_fgrape_protocol(jax.random.PRNGKey(s), N_chains, gamma)
        # Train LUT
        result = optimize_pulse(
            U_0=noisy_state,
            C_target=pure_state,
            system_params=system_params,
            num_time_steps=num_time_steps,
            lut_depth=lut_depth,
            reward_weights=reward_weights,
            mode="lookup" if model_type == "lut" else "nn",
            goal="fidelity",
            max_iter=training_params["N_training_iterations"],
            convergence_threshold=training_params["convergence_threshold"],
            learning_rate=training_params["learning_rate"],
            evo_type="density",
            batch_size=training_params["batch_size"],
            eval_batch_size=1, # This evaluation is discarded
            train_on_final_states_start = training_params["train_on_final_states_start"],
            train_on_final_states_every_a = training_params["train_on_final_states_every_a"],
            train_on_final_states_every_b = training_params["train_on_final_states_every_b"],
            train_eval_num_time_steps = tet,#training_params["train_eval_num_time_steps"],
            train_eval_batch_size = training_params["train_eval_batch_size"],
        )
        eval_result = evaluate_on_longer_time( # Evaluate on longer time and choose best LUT accordingly
            U_0 = pure_state,
            C_target = pure_state,
            system_params = system_params,
            optimized_trainable_parameters = result.optimized_trainable_parameters,
            num_time_steps = training_params["evaluation_time_steps"],
            evo_type = "density",
            goal = "fidelity",
            eval_batch_size = training_params["eval_batch_size"],
            mode = "lookup" if model_type == "lut" else "nn",
        )
        fidelities = eval_result.fidelity_each_timestep

        if best_result is None or np.mean(fidelities) > np.mean(best_result_fidelities):
            best_result = result
            best_result_fidelities = fidelities

        fidelities_each.append(float(np.mean(fidelities)))

        if save_all_samples:
            path1, path2, path3 = generate_param_paths(params, model_type, save_all_samples, s)
            with open(path1, "w") as f:
                json.dump(FgResult_to_dict(result), f)

            with open(path2, "w") as f:
                training_data = {
                    "average_fidelity": float(np.mean(fidelities)),
                    "training_params": training_params,
                }
                json.dump(training_data, f)

            np.savez(path3, fidelities_lut=np.array(fidelities))

    path1, path2, path3 = generate_param_paths(params, model_type, False, s)
    np.savez(path3, fidelities_lut=np.array(best_result_fidelities))

    with open(path1, "w") as f:
        json.dump(FgResult_to_dict(best_result), f)

    with open(path2, "w") as f:
        training_data = {
            "fidelities_each_sample": fidelities_each,
            "average_fidelity": float(np.mean(fidelities_each)),
            "training_params": training_params,
        }
        json.dump(training_data, f)

# Run experiments in parallel
if N_parallel_threads > 1:
    with ThreadPoolExecutor(max_workers=N_parallel_threads) as executor:
        executor.map(run_experiment, experiments)
else:
    for params in experiments:
        run_experiment(params)

# Play a sound when done
os.system('say "done."')

Evaluating (3, 2, [1, 1, 1], 0.0, 30)
 Sample 1/1


  0%|          | 0/11 [00:00<?, ?it/s]

Mean loss -1.7349043059523763


  9%|▉         | 1/11 [02:49<28:13, 169.34s/it]

Mean loss -0.8240096078479968


 18%|█▊        | 2/11 [05:52<26:35, 177.27s/it]

Mean loss -1.6476823115072545


 27%|██▋       | 3/11 [08:38<22:59, 172.44s/it]

Mean loss -0.8259304293394749


 36%|███▋      | 4/11 [11:24<19:49, 169.91s/it]

Mean loss -1.640488583477001


 45%|████▌     | 5/11 [14:10<16:51, 168.54s/it]

Mean loss -1.2270153510796051


 55%|█████▍    | 6/11 [16:56<13:56, 167.34s/it]

Mean loss -1.675865005284402


 64%|██████▎   | 7/11 [19:41<11:06, 166.71s/it]

Mean loss -1.2099841928945936


 73%|███████▎  | 8/11 [22:28<08:20, 166.75s/it]

Mean loss -1.6714049302995722


 82%|████████▏ | 9/11 [25:14<05:33, 166.64s/it]

Mean loss -1.4341891587516145


 91%|█████████ | 10/11 [28:00<02:46, 166.44s/it]

Mean loss -1.6856029763945182


100%|██████████| 11/11 [30:47<00:00, 167.98s/it]


Evaluating (3, 2, [1, 1, 1], 0.0, 31)
 Sample 1/1


  0%|          | 0/11 [00:00<?, ?it/s]

Mean loss -1.7349043059523763


  9%|▉         | 1/11 [02:50<28:21, 170.15s/it]

Mean loss -0.8121287163972944


 18%|█▊        | 2/11 [05:38<25:19, 168.88s/it]

Mean loss -1.6489039386466022


 27%|██▋       | 3/11 [08:25<22:26, 168.30s/it]

Mean loss -0.836129561864984


 36%|███▋      | 4/11 [11:12<19:34, 167.80s/it]

Mean loss -1.6389177013944412


 45%|████▌     | 5/11 [13:58<16:42, 167.12s/it]

Mean loss -1.2280736547353202


 55%|█████▍    | 6/11 [16:46<13:57, 167.41s/it]

Mean loss -1.6760938169024462


 64%|██████▎   | 7/11 [19:33<11:08, 167.13s/it]

Mean loss -1.2264671007418624


 73%|███████▎  | 8/11 [22:20<08:21, 167.02s/it]

Mean loss -1.6688909209216123


 82%|████████▏ | 9/11 [25:09<05:35, 167.83s/it]

Mean loss -1.4164625288459598


 91%|█████████ | 10/11 [27:56<02:47, 167.51s/it]

Mean loss -1.6871187645037193


100%|██████████| 11/11 [30:41<00:00, 167.39s/it]


Evaluating (3, 2, [1, 1, 1], 0.0, 32)
 Sample 1/1


  0%|          | 0/11 [00:00<?, ?it/s]

Mean loss -1.7349043059523763


  9%|▉         | 1/11 [02:48<28:06, 168.64s/it]

Mean loss -0.7998941441691475


 18%|█▊        | 2/11 [05:36<25:14, 168.30s/it]

Mean loss -1.6512829328069951


 27%|██▋       | 3/11 [08:22<22:16, 167.11s/it]

Mean loss -0.7912801820366303


 36%|███▋      | 4/11 [11:10<19:33, 167.66s/it]

Mean loss -1.6395839118870414


 45%|████▌     | 5/11 [13:58<16:45, 167.60s/it]

Mean loss -1.2155817345631068


 55%|█████▍    | 6/11 [16:43<13:54, 166.91s/it]

Mean loss -1.675464633309841


 64%|██████▎   | 7/11 [19:29<11:05, 166.46s/it]

Mean loss -1.233885152681125


 73%|███████▎  | 8/11 [22:14<08:18, 166.14s/it]

Mean loss -1.6701054498744896


 82%|████████▏ | 9/11 [25:00<05:32, 166.11s/it]

Mean loss -1.42641044165569


 91%|█████████ | 10/11 [27:45<02:45, 165.74s/it]

Mean loss -1.688330732596762


100%|██████████| 11/11 [30:34<00:00, 166.79s/it]


Evaluating (3, 2, [1, 1, 1], 0.0, 33)
 Sample 1/1


  0%|          | 0/11 [00:00<?, ?it/s]

Mean loss -1.7349043059523763


  9%|▉         | 1/11 [02:49<28:15, 169.52s/it]

Mean loss 0.14820682239997005


 18%|█▊        | 2/11 [05:39<25:28, 169.80s/it]

Mean loss -1.7313194021329392


 27%|██▋       | 3/11 [08:25<22:24, 168.07s/it]

Mean loss 0.24911627452823595


 36%|███▋      | 4/11 [11:12<19:33, 167.69s/it]

Mean loss -1.6056149665328459


 45%|████▌     | 5/11 [14:01<16:48, 168.06s/it]

Mean loss 0.45547967438637715


 55%|█████▍    | 6/11 [16:47<13:56, 167.39s/it]

Mean loss -1.6732053554496729


 64%|██████▎   | 7/11 [19:34<11:09, 167.26s/it]

Mean loss 0.6625567466910028


 73%|███████▎  | 8/11 [22:20<08:20, 166.92s/it]

Mean loss -1.6324532509296126


 82%|████████▏ | 9/11 [25:09<05:35, 167.62s/it]

Mean loss 0.016016101808490585


 91%|█████████ | 10/11 [27:59<02:48, 168.38s/it]

Mean loss -1.6675857110318928


100%|██████████| 11/11 [30:48<00:00, 168.08s/it]


Evaluating (3, 2, [1, 1, 1], 0.0, 10)
 Sample 1/1


  0%|          | 0/11 [00:00<?, ?it/s]

Mean loss -1.7349043059523763


  9%|▉         | 1/11 [02:51<28:32, 171.28s/it]

Mean loss -1.0233529932403662


 18%|█▊        | 2/11 [05:38<25:22, 169.16s/it]

Mean loss -1.6491605724581533


 27%|██▋       | 3/11 [08:23<22:15, 166.99s/it]

Mean loss -1.090115593266125


 36%|███▋      | 4/11 [11:07<19:22, 166.06s/it]

Mean loss -1.6392273194305857


 45%|████▌     | 5/11 [13:52<16:33, 165.50s/it]

Mean loss -1.3525033941755147


 55%|█████▍    | 6/11 [16:36<13:45, 165.04s/it]

Mean loss -1.6706256989710546


 64%|██████▎   | 7/11 [19:22<11:01, 165.37s/it]

Mean loss -1.383896680745696


 73%|███████▎  | 8/11 [22:09<08:17, 165.72s/it]

Mean loss -1.6678476217601397


 82%|████████▏ | 9/11 [24:58<05:33, 166.78s/it]

Mean loss -1.6370265003023028


 91%|█████████ | 10/11 [27:45<02:47, 167.03s/it]

Mean loss -1.6773601504297302


100%|██████████| 11/11 [30:30<00:00, 166.40s/it]


Evaluating (3, 2, [1, 1, 1], 0.0, 20)
 Sample 1/1


  0%|          | 0/11 [00:00<?, ?it/s]

Mean loss -1.7349043059523763


  9%|▉         | 1/11 [02:50<28:24, 170.42s/it]

Mean loss -0.8981911366173263


 18%|█▊        | 2/11 [05:37<25:15, 168.42s/it]

Mean loss -1.6521633165498737


 27%|██▋       | 3/11 [08:26<22:30, 168.80s/it]

Mean loss -0.8581303994354533


 36%|███▋      | 4/11 [11:11<19:31, 167.42s/it]

Mean loss -1.636066663889438


 45%|████▌     | 5/11 [14:00<16:46, 167.68s/it]

Mean loss -1.3086669673700464


 55%|█████▍    | 6/11 [16:45<13:54, 166.93s/it]

Mean loss -1.667614816245087


 64%|██████▎   | 7/11 [19:32<11:07, 166.79s/it]

Mean loss -1.3314201952436546


 73%|███████▎  | 8/11 [22:17<08:19, 166.44s/it]

Mean loss -1.659450900339355


 82%|████████▏ | 9/11 [25:03<05:32, 166.36s/it]

Mean loss -1.5272313902971086


 91%|█████████ | 10/11 [27:48<02:45, 165.90s/it]

Mean loss -1.6745806576018265


100%|██████████| 11/11 [30:33<00:00, 166.65s/it]


Evaluating (3, 2, [1, 1, 1], 0.0, 29)
 Sample 1/1


  0%|          | 0/11 [00:00<?, ?it/s]

Mean loss -1.7349043059523763


  9%|▉         | 1/11 [02:48<28:06, 168.60s/it]

Mean loss -0.8160919783717226


 18%|█▊        | 2/11 [05:34<25:03, 167.01s/it]

Mean loss -1.6493002249313458


 27%|██▋       | 3/11 [08:22<22:20, 167.57s/it]

Mean loss -0.871246078681682


 36%|███▋      | 4/11 [11:12<19:40, 168.63s/it]

Mean loss -1.6416286836994236


 45%|████▌     | 5/11 [13:56<16:40, 166.82s/it]

Mean loss -1.2357008671715841


 55%|█████▍    | 6/11 [16:40<13:49, 165.93s/it]

Mean loss -1.6744348932821849


 64%|██████▎   | 7/11 [19:24<11:00, 165.07s/it]

Mean loss -1.2216899846299016


 73%|███████▎  | 8/11 [22:08<08:14, 164.97s/it]

Mean loss -1.669588762994088


 82%|████████▏ | 9/11 [24:52<05:28, 164.44s/it]

Mean loss -1.4481381728889287


 91%|█████████ | 10/11 [27:35<02:44, 164.25s/it]

Mean loss -1.683942374352178


100%|██████████| 11/11 [30:25<00:00, 165.97s/it]


Evaluating (3, 2, [1, 1, 1], 0.0, 30)
 Sample 1/1


  0%|          | 0/11 [00:00<?, ?it/s]

Mean loss -1.7349043059523763


  9%|▉         | 1/11 [02:50<28:24, 170.49s/it]

Mean loss -0.8240096078479968


 18%|█▊        | 2/11 [05:35<25:03, 167.08s/it]

Mean loss -1.6476823115072545


 27%|██▋       | 3/11 [08:23<22:20, 167.50s/it]

Mean loss -0.8259304293394749


 36%|███▋      | 4/11 [11:07<19:23, 166.17s/it]

Mean loss -1.640488583477001


 45%|████▌     | 5/11 [13:53<16:36, 166.11s/it]

Mean loss -1.2270153510796051


 55%|█████▍    | 6/11 [16:37<13:47, 165.47s/it]

Mean loss -1.675865005284402


 64%|██████▎   | 7/11 [19:21<10:59, 164.98s/it]

Mean loss -1.2099841928945936


 73%|███████▎  | 8/11 [22:11<08:19, 166.49s/it]

Mean loss -1.6714049302995722


 82%|████████▏ | 9/11 [24:54<05:30, 165.49s/it]

Mean loss -1.4341891587516145


 91%|█████████ | 10/11 [27:42<02:46, 166.35s/it]

Mean loss -1.6856029763945182


100%|██████████| 11/11 [30:30<00:00, 166.45s/it]


Evaluating (3, 2, [1, 1, 1], 0.0, 34)
 Sample 1/1


  0%|          | 0/11 [00:00<?, ?it/s]

Mean loss -1.7349043059523763


  9%|▉         | 1/11 [02:47<27:57, 167.73s/it]

Mean loss 0.6769840973692902


 18%|█▊        | 2/11 [05:37<25:22, 169.20s/it]

Mean loss -1.6602896001894336


 27%|██▋       | 3/11 [08:21<22:12, 166.57s/it]

Mean loss 0.7441310179190801


 36%|███▋      | 4/11 [11:06<19:21, 165.96s/it]

Mean loss -1.6141786690730802


 45%|████▌     | 5/11 [13:49<16:30, 165.01s/it]

Mean loss 0.2108001040219678


 55%|█████▍    | 6/11 [16:36<13:48, 165.74s/it]

Mean loss -1.6787395922448334


 64%|██████▎   | 7/11 [19:21<11:01, 165.45s/it]

Mean loss 0.22490126645890188


 73%|███████▎  | 8/11 [22:06<08:15, 165.26s/it]

Mean loss -1.6857705013039856


 82%|████████▏ | 9/11 [24:57<05:34, 167.12s/it]

Mean loss 0.2872243514257434


 91%|█████████ | 10/11 [27:44<02:46, 166.87s/it]

Mean loss -1.6669307768679134


100%|██████████| 11/11 [30:31<00:00, 166.50s/it]


Evaluating (3, 2, [1, 1, 1], 0.0, 35)
 Sample 1/1


  0%|          | 0/11 [00:00<?, ?it/s]

Mean loss -1.7349043059523763


  9%|▉         | 1/11 [02:51<28:39, 171.98s/it]

Mean loss 0.8192115745965868


 18%|█▊        | 2/11 [05:36<25:08, 167.63s/it]

Mean loss -1.6633647847998603


 27%|██▋       | 3/11 [08:20<22:06, 165.80s/it]

Mean loss 0.9086494126409476


 36%|███▋      | 4/11 [11:05<19:19, 165.68s/it]

Mean loss -1.6376783613428152


 45%|████▌     | 5/11 [13:50<16:31, 165.33s/it]

Mean loss 1.1525035938197508


 55%|█████▍    | 6/11 [16:35<13:46, 165.37s/it]

Mean loss -1.6795172344574192


 64%|██████▎   | 7/11 [19:21<11:01, 165.36s/it]

Mean loss 0.6159419277441595


 73%|███████▎  | 8/11 [22:09<08:18, 166.27s/it]

Mean loss -1.68943191831703


 82%|████████▏ | 9/11 [24:57<05:33, 166.88s/it]

Mean loss 0.7582485488509517


 91%|█████████ | 10/11 [27:46<02:47, 167.35s/it]

Mean loss -1.6958374655544266


100%|██████████| 11/11 [30:33<00:00, 166.68s/it]


Evaluating (3, 2, [1, 1, 1], 0.0, 36)
 Sample 1/1


  0%|          | 0/11 [00:00<?, ?it/s]

Mean loss -1.7349043059523763


  9%|▉         | 1/11 [02:45<27:37, 165.73s/it]

Mean loss 0.8722125246326167


 18%|█▊        | 2/11 [05:31<24:52, 165.82s/it]

Mean loss -1.6584384144019784


 27%|██▋       | 3/11 [08:22<22:25, 168.17s/it]

Mean loss 0.6156441802672769


 36%|███▋      | 4/11 [11:14<19:46, 169.51s/it]

Mean loss -1.6425011043033046


 45%|████▌     | 5/11 [14:02<16:54, 169.05s/it]

Mean loss 0.6963952904725855


 55%|█████▍    | 6/11 [16:50<14:04, 168.88s/it]

Mean loss -1.6826023954884892


 64%|██████▎   | 7/11 [19:35<11:09, 167.36s/it]

Mean loss 0.5723783427199991


 73%|███████▎  | 8/11 [22:26<08:25, 168.50s/it]

Mean loss -1.663534286771281


 82%|████████▏ | 9/11 [25:10<05:34, 167.23s/it]

Mean loss 0.4965691397207697


 91%|█████████ | 10/11 [27:56<02:46, 166.84s/it]

Mean loss -1.6651722014635926


100%|██████████| 11/11 [30:45<00:00, 167.77s/it]


0