In [4]:
# ruff: noqa
import sys, os
sys.path.append(os.path.abspath("./../feedback-grape"))
sys.path.append(os.path.abspath("./../"))

# ruff: noqa
from feedback_grape.fgrape import optimize_pulse, evaluate_on_longer_time
from helpers import (
    init_grape_protocol,
    init_fgrape_protocol,
    test_implementations,
    generate_random_discrete_state,
    generate_random_bloch_state
)
from tqdm import tqdm
import jax
import jax.numpy as jnp
import numpy as np
from library.utils.FgResult_to_dict import FgResult_to_dict
import json

# Enable 64-bit floats *before* any JAX computation runs. JAX only honours this
# flag reliably when it is set before arrays are created. test_implementations()
# presumably exercises JAX code (TODO confirm), so in the original order it would
# have run in 32-bit, while all of the training below runs in 64-bit.
jax.config.update("jax_enable_x64", True)

test_implementations()

In [5]:
# --- Physical system ---------------------------------------------------------
# Caution: a density matrix has 4**N_chains elements, so cost grows quickly.
N_chains = 3   # number of parallel chains to simulate
gamma = 0.05   # decay constant

# --- Training / evaluation settings (also stored in each *_training_data.json) --
training_params = dict(
    N_samples=10,                  # independent training runs per experiment, seeded 0..N_samples-1; the best one is kept
    N_training_iterations=10_000,  # iteration cap for each optimize_pulse call
    learning_rate=0.01,            # optimiser learning rate
    convergence_threshold=1e-6,    # passed to optimize_pulse as convergence_threshold
    batch_size=8,                  # training batch size
    eval_batch_size=16,            # batch size used by evaluate_on_longer_time
    evaluation_time_steps=100,     # length of the evaluation run used to pick the best LUT
)

N_parallel_threads = 4                          # worker threads; a value <= 1 runs experiments sequentially
generate_state = generate_random_bloch_state    # sampler for initial (noisy) and target (noise-free) states

# Experiments to run. Each entry is the tuple
#   (num_time_steps, lut_depth, reward_weights, noise_level)
# num_time_steps : number of time steps (segments) in the control pulse
# lut_depth      : depth of the lookup table used for feedback
# reward_weights : reward weight for each time step; here only the final step
#                  is rewarded, i.e. [0, 0, ..., 0, 1]
# noise_level    : noise_level handed to generate_state for the training inputs
experiments = [
    (steps, 2, [0] * (steps - 1) + [1], level)
    for steps in range(9, 16)
    for level in (0.0, 0.5, 1.0, float("inf"))
]

# Description of each field of an experiment tuple, in the same order:
#   (name, type, used by grape, used by lut, used by rnn)
# The name and the per-model flags decide which fields end up in file names.
experiments_format = [
    ("t",     int,   True,  True,  True),   # number of time steps (segments)
    ("l",     int,   False, True,  False),  # LUT depth
    ("w",     list,  True,  True,  True),   # reward weights
    ("noise", float, True,  True,  True),   # noise level
]

# Validate the experiment table up front, so a malformed entry fails right away
# instead of after hours of training on the entries listed before it.

# The format table itself: one (name, type, grape?, lut?, rnn?) tuple per field.
# These checks do not depend on any experiment, so they run once.
for name, expected_type, req_grape, req_lut, req_rnn in experiments_format:
    assert type(name) == str, "Parameter name must be a string."
    assert type(req_grape) == bool, "Requirement flags must be boolean."
    assert type(req_lut) == bool, "Requirement flags must be boolean."
    assert type(req_rnn) == bool, "Requirement flags must be boolean."

for params in experiments:
    assert len(params) == len(experiments_format), "Experiment parameters length mismatch."
    for i, (param, (name, expected_type, *_flags)) in enumerate(zip(params, experiments_format)):
        # Exact type match (not isinstance), so e.g. True is rejected where an int is expected.
        assert type(param) == expected_type, f"Parameter '{name}' at index {i} must be of type {expected_type.__name__}."
        if type(param) == list:
            for j, item in enumerate(param):
                # Single digits only: generate_param_paths concatenates list items
                # into the file name without a separator, so multi-digit values would be ambiguous.
                assert type(item) == int and 0 <= item <= 9, f"Item at index {j} in parameter '{name}' must be of type int and between 0 and 9."
    # Reward weights are per time step, so there must be exactly one per step.
    values_by_name = {fmt[0]: value for fmt, value in zip(experiments_format, params)}
    if "t" in values_by_name and "w" in values_by_name:
        assert len(values_by_name["w"]) == values_by_name["t"], (
            f"Reward weights {values_by_name['w']} must have one entry per time step (t={values_by_name['t']})."
        )

def generate_param_paths(params, model_type):
    """Build the output file paths for one experiment and model type.

    Only the fields that ``experiments_format`` flags as relevant for
    ``model_type`` are encoded, each as ``name=value``, joined by underscores
    and prefixed with ``"<model_type>_"``. List values are written as their
    items concatenated with no separator.

    Args:
        params: one experiment tuple, ordered like ``experiments_format``.
        model_type: one of ``"grape"``, ``"lut"`` or ``"rnn"``.

    Returns:
        ``(result_json, training_data_json, evaluation_npz)`` as paths relative
        to the current working directory.

    Raises:
        AssertionError: if ``model_type`` is not one of the three known types.
    """
    assert model_type in ["grape", "lut", "rnn"], "Invalid model type."

    encoded_fields = []
    for (name, kind, in_grape, in_lut, in_rnn), value in zip(experiments_format, params):
        wanted = {"grape": in_grape, "lut": in_lut, "rnn": in_rnn}[model_type]
        if not wanted:
            continue
        text = "".join(map(str, value)) if kind == list else str(value)
        encoded_fields.append(f"{name}={text}")

    stem = model_type + "_" + "_".join(encoded_fields)
    result_json = f"./optimized_architectures/{stem}.json"
    training_json = f"./optimized_architectures/{stem}_training_data.json"
    evaluation_npz = f"./evaluation_results/{stem}.npz"
    return (result_json, training_json, evaluation_npz)

In [6]:
from concurrent.futures import ThreadPoolExecutor

def run_experiment(params):
    """Train (or reuse) a lookup-table feedback controller for one experiment.

    Args:
        params: ``(num_time_steps, lut_depth, reward_weights, noise_level)``,
            one entry of ``experiments``.

    For each of ``training_params["N_samples"]`` seeds, a LUT is trained with
    ``optimize_pulse``. Training starts from states drawn by ``noisy_state``,
    with ``pure_state`` as the target. Each LUT is then scored with
    ``evaluate_on_longer_time`` on ``pure_state`` inputs over
    ``training_params["evaluation_time_steps"]`` steps, and the one with the
    highest mean per-timestep fidelity is kept.

    Side effects: writes three files whose paths come from
    ``generate_param_paths(params, "lut")``:
      * path3 (.npz)                 - per-timestep fidelities of the best LUT
      * path2 (_training_data.json)  - mean fidelity of every sample + settings
      * path1 (.json)                - the best result. Its existence marks the
                                       experiment as finished, so it is written
                                       last and atomically.
    If path1 already exists, training is skipped entirely.

    Raises:
        ValueError: if ``training_params["N_samples"]`` is smaller than 1.
    """
    num_time_steps, lut_depth, reward_weights, noise_level = params
    print(f"Evaluating {params}")
    noisy_state = lambda key: generate_state(key, N_chains=N_chains, noise_level=noise_level)
    pure_state = lambda key: generate_state(key, N_chains=N_chains, noise_level=0.0)

    # Only test LUTs in this example
    path1, path2, path3 = generate_param_paths(params, "lut")
    if os.path.exists(path1):
        print("LUT already trained, skipping training.")
        return

    print("Training LUT")
    # Create the output directories now. Otherwise the first write would raise
    # FileNotFoundError only *after* the whole (hour-long) training loop.
    for path in (path1, path2, path3):
        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)

    best_result = None
    best_result_fidelities = None
    best_score = -np.inf
    fidelities_each = []
    # desc labels the bar, since several experiments may print progress concurrently.
    progress = tqdm(range(training_params["N_samples"]), desc=f"t={num_time_steps} noise={noise_level}")
    for s in progress:
        system_params = init_fgrape_protocol(jax.random.PRNGKey(s), N_chains, gamma)

        # Train LUT
        result = optimize_pulse(
            U_0=noisy_state,
            C_target=pure_state,
            system_params=system_params,
            num_time_steps=num_time_steps,
            lut_depth=lut_depth,
            reward_weights=reward_weights,
            mode="lookup",
            goal="fidelity",
            max_iter=training_params["N_training_iterations"],
            convergence_threshold=training_params["convergence_threshold"],
            learning_rate=training_params["learning_rate"],
            evo_type="density",
            batch_size=training_params["batch_size"],
            eval_batch_size=1, # This evaluation is discarded
        )

        eval_result = evaluate_on_longer_time( # Evaluate on longer time and choose best LUT accordingly
            U_0 = pure_state,
            C_target = pure_state,
            system_params = system_params,
            optimized_trainable_parameters = result.optimized_trainable_parameters,
            num_time_steps = training_params["evaluation_time_steps"],
            evo_type = "density",
            goal = "fidelity",
            eval_batch_size = training_params["eval_batch_size"],
            mode = "lookup",
        )

        fidelities = eval_result.fidelity_each_timestep
        mean_fidelity = float(np.mean(fidelities))

        # A diverged run can give NaN. Score it as -inf: NaN compares False
        # both ways, so a NaN "best" could otherwise never be replaced.
        score = -np.inf if np.isnan(mean_fidelity) else mean_fidelity
        if best_result is None or score > best_score:
            best_result = result
            best_result_fidelities = fidelities
            best_score = score

        fidelities_each.append(mean_fidelity)

    if best_result is None:
        raise ValueError("training_params['N_samples'] must be at least 1.")

    np.savez(path3, fidelities_lut=np.array(best_result_fidelities))

    with open(path2, "w") as f:
        training_data = {
            "fidelities_each_sample": fidelities_each,
            "average_fidelity": float(np.mean(fidelities_each)),
            "training_params": training_params,
        }
        json.dump(training_data, f)

    # path1 is the "done" marker checked above, so it is written last.
    # Convert and write it to a temp file first, so a failure can never leave
    # an empty or partial path1 that would make future runs skip this experiment.
    best_result_dict = FgResult_to_dict(best_result)
    tmp_path = path1 + ".tmp"
    with open(tmp_path, "w") as f:
        json.dump(best_result_dict, f)
    os.replace(tmp_path, path1)

# Run every experiment, in a thread pool when N_parallel_threads > 1.
# Each outcome is collected explicitly. A bare `executor.map(...)` whose iterator
# is never consumed silently drops any exception raised in run_experiment, so a
# crashed experiment would look exactly like a finished one.
def _run_and_capture(params):
    """Run one experiment; return the exception it raised, or None on success."""
    try:
        run_experiment(params)
    except Exception as exc:
        print(f"Experiment {params} failed: {exc!r}")
        return exc
    return None

if N_parallel_threads > 1:
    with ThreadPoolExecutor(max_workers=N_parallel_threads) as executor:
        outcomes = list(executor.map(_run_and_capture, experiments))
else:
    outcomes = [_run_and_capture(params) for params in experiments]

failed_experiments = [(params, exc) for params, exc in zip(experiments, outcomes) if exc is not None]

# Play a sound when done (the `say` command only exists on macOS).
if sys.platform == "darwin":
    os.system('say "done."')

if failed_experiments:
    raise RuntimeError(
        f"{len(failed_experiments)} of {len(experiments)} experiments failed: "
        + ", ".join(str(params) for params, _ in failed_experiments)
    ) from failed_experiments[0][1]

Evaluating (9, 2, [0, 0, 0, 0, 0, 0, 0, 0, 1], 0.0)Evaluating (9, 2, [0, 0, 0, 0, 0, 0, 0, 0, 1], 0.5)

Training LUT
Training LUT
Evaluating (9, 2, [0, 0, 0, 0, 0, 0, 0, 0, 1], 1.0)
Training LUT
Evaluating (9, 2, [0, 0, 0, 0, 0, 0, 0, 0, 1], inf)
Training LUT


  0%|          | 0/10 [00:00<?, ?it/s]



[A[A
 10%|█         | 1/10 [05:08<46:17, 308.63s/it]

[A[A
 20%|██        | 2/10 [10:27<41:58, 314.83s/it]

[A[A
 30%|███       | 3/10 [15:50<37:09, 318.46s/it]

[A[A
 40%|████      | 4/10 [21:14<32:04, 320.74s/it]

[A[A
 50%|█████     | 5/10 [26:33<26:39, 319.93s/it]

[A[A
 60%|██████    | 6/10 [31:52<21:19, 319.83s/it]

[A[A
 70%|███████   | 7/10 [37:12<15:58, 319.65s/it]

[A[A
 80%|████████  | 8/10 [42:31<10:38, 319.37s/it]

[A[A
 90%|█████████ | 9/10 [47:53<05:20, 320.29s/it]

[A[A
100%|██████████| 10/10 [53:17<00:00, 319.71s/it]


Evaluating (10, 2, [0, 0, 0, 0, 0, 0, 0, 0, 0, 1], 0.0)
Training LUT


  0%|          | 0/10 [00:00<?, ?it/s]

100%|██████████| 10/10 [53:45<00:00, 322.56s/it]


Evaluating (10, 2, [0, 0, 0, 0, 0, 0, 0, 0, 0, 1], 0.5)
Training LUT




100%|██████████| 10/10 [54:11<00:00, 325.16s/it]


Evaluating (10, 2, [0, 0, 0, 0, 0, 0, 0, 0, 0, 1], 1.0)
Training LUT



100%|██████████| 10/10 [54:37<00:00, 327.74s/it]


Evaluating (10, 2, [0, 0, 0, 0, 0, 0, 0, 0, 0, 1], inf)
Training LUT



 10%|█         | 1/10 [06:01<54:10, 361.21s/it]

[A[A
 20%|██        | 2/10 [11:57<47:45, 358.14s/it]

[A[A
 30%|███       | 3/10 [17:56<41:49, 358.57s/it]

[A[A
 40%|████      | 4/10 [23:52<35:45, 357.66s/it]

[A[A
 50%|█████     | 5/10 [29:54<29:56, 359.30s/it]

[A[A
 60%|██████    | 6/10 [35:55<23:59, 359.81s/it]

[A[A
 70%|███████   | 7/10 [41:59<18:03, 361.13s/it]

[A[A
 80%|████████  | 8/10 [48:01<12:02, 361.30s/it]

[A[A
 90%|█████████ | 9/10 [54:05<06:02, 362.15s/it]

[A[A
100%|██████████| 10/10 [1:00:06<00:00, 360.66s/it]


Evaluating (11, 2, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], 0.0)
Training LUT


  0%|          | 0/10 [00:00<?, ?it/s]

100%|██████████| 10/10 [1:00:18<00:00, 361.84s/it]


Evaluating (11, 2, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], 0.5)
Training LUT




100%|██████████| 10/10 [1:00:25<00:00, 362.58s/it]


Evaluating (11, 2, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], 1.0)
Training LUT



100%|██████████| 10/10 [1:01:13<00:00, 367.32s/it]


Evaluating (11, 2, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], inf)
Training LUT



 10%|█         | 1/10 [06:42<1:00:21, 402.37s/it]

[A[A
 20%|██        | 2/10 [13:23<53:33, 401.73s/it]  

[A[A
 30%|███       | 3/10 [20:04<46:50, 401.54s/it]

[A[A
 40%|████      | 4/10 [26:47<40:10, 401.79s/it]

[A[A
 50%|█████     | 5/10 [33:28<33:28, 401.71s/it]

[A[A
 60%|██████    | 6/10 [40:13<26:50, 402.65s/it]

[A[A
 70%|███████   | 7/10 [46:57<20:09, 403.04s/it]

[A[A
 80%|████████  | 8/10 [53:40<13:26, 403.17s/it]

[A[A
 90%|█████████ | 9/10 [1:00:23<06:43, 403.21s/it]

[A[A
100%|██████████| 10/10 [1:07:07<00:00, 402.79s/it]


Evaluating (12, 2, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], 0.0)
Training LUT


  0%|          | 0/10 [00:00<?, ?it/s]

100%|██████████| 10/10 [1:07:10<00:00, 403.09s/it]


Evaluating (12, 2, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], 0.5)
Training LUT




100%|██████████| 10/10 [1:07:15<00:00, 403.60s/it]


Evaluating (12, 2, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], 1.0)
Training LUT



100%|██████████| 10/10 [1:07:26<00:00, 404.65s/it]


Evaluating (12, 2, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], inf)
Training LUT



 10%|█         | 1/10 [07:23<1:06:29, 443.23s/it]

[A[A
 20%|██        | 2/10 [14:44<58:56, 442.04s/it]  

[A[A
 30%|███       | 3/10 [22:04<51:27, 441.09s/it]

[A[A
 40%|████      | 4/10 [29:27<44:12, 442.02s/it]

[A[A
 50%|█████     | 5/10 [36:51<36:52, 442.59s/it]

[A[A
 60%|██████    | 6/10 [44:16<29:33, 443.47s/it]

[A[A
 70%|███████   | 7/10 [51:39<22:10, 443.42s/it]

[A[A
 80%|████████  | 8/10 [58:57<14:42, 441.48s/it]

[A[A
 90%|█████████ | 9/10 [1:06:18<07:21, 441.47s/it]

[A[A
100%|██████████| 10/10 [1:13:40<00:00, 442.10s/it]


Evaluating (13, 2, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], 0.0)
Training LUT


  0%|          | 0/10 [00:00<?, ?it/s]

100%|██████████| 10/10 [1:14:19<00:00, 445.90s/it]


Evaluating (13, 2, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], 0.5)
Training LUT




100%|██████████| 10/10 [1:14:12<00:00, 445.23s/it]


Evaluating (13, 2, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], 1.0)
Training LUT



100%|██████████| 10/10 [1:14:19<00:00, 445.99s/it]


Evaluating (13, 2, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], inf)
Training LUT



 10%|█         | 1/10 [08:03<1:12:29, 483.29s/it]

[A[A
 20%|██        | 2/10 [16:05<1:04:23, 482.90s/it]

[A[A
 30%|███       | 3/10 [24:09<56:21, 483.12s/it]  

[A[A
 40%|████      | 4/10 [32:13<48:20, 483.42s/it]

[A[A
 50%|█████     | 5/10 [40:18<40:20, 484.07s/it]

[A[A
 60%|██████    | 6/10 [48:23<32:17, 484.47s/it]

[A[A
 70%|███████   | 7/10 [56:29<24:14, 484.76s/it]

[A[A
 80%|████████  | 8/10 [1:04:31<16:07, 483.92s/it]

[A[A
 90%|█████████ | 9/10 [1:12:33<08:03, 483.54s/it]

[A[A
100%|██████████| 10/10 [1:20:37<00:00, 483.70s/it]


Evaluating (14, 2, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], 0.0)
Training LUT


100%|██████████| 10/10 [1:19:25<00:00, 476.55s/it]


Evaluating (14, 2, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], 0.5)
Training LUT




100%|██████████| 10/10 [1:20:47<00:00, 484.72s/it]


Evaluating (14, 2, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], 1.0)
Training LUT




[A[A
100%|██████████| 10/10 [1:21:01<00:00, 486.13s/it]


Evaluating (14, 2, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], inf)
Training LUT



 10%|█         | 1/10 [08:42<1:18:24, 522.70s/it]

[A[A
 20%|██        | 2/10 [17:24<1:09:37, 522.22s/it]

[A[A
 30%|███       | 3/10 [26:11<1:01:11, 524.54s/it]

[A[A
 40%|████      | 4/10 [34:57<52:30, 525.15s/it]  

[A[A
 50%|█████     | 5/10 [43:47<43:53, 526.65s/it]

[A[A
 60%|██████    | 6/10 [52:36<35:09, 527.48s/it]

[A[A
 70%|███████   | 7/10 [1:01:24<26:22, 527.57s/it]

[A[A
 80%|████████  | 8/10 [1:10:09<17:33, 526.85s/it]

[A[A
 90%|█████████ | 9/10 [1:18:55<08:46, 526.68s/it]

[A[A
100%|██████████| 10/10 [1:27:38<00:00, 525.87s/it]


Evaluating (15, 2, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], 0.0)
Training LUT


100%|██████████| 10/10 [1:27:37<00:00, 525.78s/it]


Evaluating (15, 2, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], 0.5)
Training LUT




100%|██████████| 10/10 [1:28:37<00:00, 531.75s/it]


Evaluating (15, 2, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], 1.0)
Training LUT




[A[A
100%|██████████| 10/10 [1:27:57<00:00, 527.75s/it]


Evaluating (15, 2, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], inf)
Training LUT



 10%|█         | 1/10 [09:25<1:24:48, 565.44s/it]

[A[A
[A
 20%|██        | 2/10 [18:54<1:15:40, 567.57s/it]

 30%|███       | 3/10 [24:21<53:23, 457.61s/it]  
[A

 40%|████      | 4/10 [33:51<50:13, 502.18s/it]
[A

 50%|█████     | 5/10 [43:27<44:02, 528.53s/it]
[A

 70%|███████   | 7/10 [55:24<20:40, 413.45s/it]
[A

 80%|████████  | 8/10 [1:04:53<15:25, 462.77s/it]
[A

 90%|█████████ | 9/10 [1:14:23<08:16, 496.47s/it]
[A

100%|██████████| 10/10 [1:23:53<00:00, 503.32s/it]

[A

[A[A
100%|██████████| 10/10 [1:29:35<00:00, 537.52s/it]
100%|██████████| 10/10 [1:33:28<00:00, 560.86s/it]


100%|██████████| 10/10 [1:32:55<00:00, 557.54s/it]


0