In [1]:
# ruff: noqa
import sys, os
sys.path.append(os.path.abspath("./../feedback-grape"))
sys.path.append(os.path.abspath("./../"))

# ruff: noqa
from feedback_grape.fgrape import optimize_pulse, evaluate_on_longer_time
from helpers import (
    init_grape_protocol,
    init_fgrape_protocol,
    test_implementations,
    generate_random_discrete_state,
    generate_random_bloch_state
)
from tqdm import tqdm
import jax
import jax.numpy as jnp
from library.utils.FgResult_to_dict import FgResult_to_dict
import json

# Run the helpers' self-test (presumably validates the helper implementations
# before any experiment runs — confirm in helpers.py).
test_implementations()

# Switch JAX to 64-bit floating point for the simulations in the cells below.
# NOTE(review): JAX recommends enabling x64 at startup, before any JAX work.
# If test_implementations() above uses JAX, it runs in 32-bit mode — confirm
# that this ordering is intended.
jax.config.update("jax_enable_x64", True)

In [2]:
# ---- Physical parameters ----
# Attention: the number of density-matrix elements grows as 4^n * N_chains.
n = 2          # qubits per chain (must be >= 2)
N_chains = 3   # number of parallel chains to simulate
gamma = 0.25   # decay constant

assert n >= 2, "Chain lengths must be at least 2."

# ---- Training and evaluation parameters ----
training_params = dict(
    N_samples=3,                 # independent training runs per architecture (PRNG seeds 0..N_samples-1)
    N_training_iterations=1000,  # max optimizer iterations per run
    learning_rate=0.01,
    convergence_threshold=1e-6,
    batch_size=8,
    eval_batch_size=16,
    evaluation_time_steps=200,   # horizon of the longer post-training evaluation
)

# Sampler for initial states; used both for noisy training inputs and noise-free targets.
generate_state = generate_random_discrete_state

# ---- Architectures to run ----
do_test_grape = False
do_test_lut = True
do_test_rnn = False

# ---- Experiment grid ----
# Each entry is (num_time_steps, lut_depth, reward_weights, noise_level):
#   num_time_steps : number of time steps (segments) in the control pulse
#   lut_depth      : depth of the lookup table used for feedback
#   reward_weights : reward weight per time step, one digit (0-9) each;
#                    e.g. [0, 0, ..., 0, 1] rewards only the final time step
#   noise_level    : noise level of the initial states used for training
experiments = [
    (1, 1, [1], 0.0),
]

# Schema of one experiment entry, field by field:
#   (short name used in file names, required Python type,
#    required by grape, required by lut, required by rnn)
# The "required by" flags control which fields appear in that architecture's file names.
experiments_format = [
    ("t", int, True, True, True),        # number of time steps (segments)
    ("l", int, False, True, False),      # LUT depth
    ("w", list, True, True, True),       # reward weights
    ("noise", float, True, True, True),  # noise level
]

# Validate the schema itself once. Each entry must be
# (name: str, type, required_by_grape: bool, required_by_lut: bool, required_by_rnn: bool).
for name, expected_type, req_grape, req_lut, req_rnn in experiments_format:
    assert type(name) is str, "Parameter name must be a string."
    assert type(req_grape) is bool, "Requirement flags must be boolean."
    assert type(req_lut) is bool, "Requirement flags must be boolean."
    assert type(req_rnn) is bool, "Requirement flags must be boolean."

# Check that every experiment matches the schema field by field.
for params in experiments:
    assert len(params) == len(experiments_format), "Experiment parameters length mismatch."
    for i, (param, (name, expected_type, *_flags)) in enumerate(zip(params, experiments_format)):
        # Exact type match on purpose: e.g. an int noise level (0) is rejected where a float is expected.
        assert type(param) is expected_type, f"Parameter '{name}' at index {i} must be of type {expected_type.__name__}."
        if type(param) is list:
            for j, item in enumerate(param):
                # generate_param_paths concatenates list items into the file name
                # ("".join), so multi-digit items would make names ambiguous
                # ([1, 2] vs [12]). Hence single digits only.
                assert type(item) is int and 0 <= item <= 9, f"Item at index {j} in parameter '{name}' must be of type int and between 0 and 9."

def generate_param_paths(params, model_type):
    """Return the (result JSON, training-data JSON, evaluation .npz) paths for one run.

    The shared file stem is ``<model_type>_<name>=<value>_...``. It is built from
    the ``experiments_format`` fields whose "required by" flag is set for
    ``model_type``. List values (reward weights) are concatenated item by item,
    e.g. ``[0, 1] -> "01"``.

    Args:
        params: one experiment tuple, ordered like ``experiments_format``.
        model_type: "grape", "lut" or "rnn".

    Returns:
        Tuple of three relative paths: the optimized-result JSON and the
        training-data JSON under ``./optimized_architectures/``, and the
        evaluation ``.npz`` under ``./evaluation_results/``.

    Raises:
        AssertionError: if ``model_type`` is not one of the three known names.
    """
    assert model_type in ["grape", "lut", "rnn"], "Invalid model type."

    fields = []
    for spec, value in zip(experiments_format, params):
        name, value_type, for_grape, for_lut, for_rnn = spec
        wanted = {"grape": for_grape, "lut": for_lut, "rnn": for_rnn}[model_type]
        if not wanted:
            continue
        text = "".join(map(str, value)) if value_type == list else str(value)
        fields.append(f"{name}={text}")

    stem = model_type + "_" + "_".join(fields)
    architecture_dir = "./optimized_architectures/"
    return (
        architecture_dir + stem + ".json",
        architecture_dir + stem + "_training_data.json",
        "./evaluation_results/" + stem + ".npz",
    )

In [3]:
def _run_architecture(params, model_type, mode, init_protocol, npz_key,
                      skip_message, start_message, pass_lut_depth):
    """Train one architecture on one experiment, then evaluate the best run on a longer horizon.

    Runs ``training_params["N_samples"]`` independent trainings (PRNG seed ``s``
    for run ``s``) and keeps the run with the highest ``final_fidelity``. It then
    writes the three files returned by ``generate_param_paths``:

    * path1: the best ``FgResult`` as JSON. Its existence marks the experiment
      as done, so it is skipped on re-runs.
    * path2: every run's final fidelity, their mean, and ``training_params``.
    * path3: the best run's fidelity at every step of the longer evaluation (.npz).

    Args:
        params: one ``(num_time_steps, lut_depth, reward_weights, noise_level)`` tuple.
        model_type: "grape", "lut" or "rnn"; selects the output file names.
        mode: ``mode`` argument for ``optimize_pulse`` and ``evaluate_on_longer_time``.
        init_protocol: called as ``init_protocol(key, n, N_chains, gamma)`` to build
            the ``system_params`` of one run.
        npz_key: array name under which the evaluation fidelities are stored in path3.
        skip_message: printed when path1 already exists.
        start_message: printed before training starts.
        pass_lut_depth: whether ``lut_depth`` is forwarded to ``optimize_pulse``.
    """
    num_time_steps, lut_depth, reward_weights, noise_level = params
    path1, path2, path3 = generate_param_paths(params, model_type)
    if os.path.exists(path1):
        print(skip_message)
        return
    print(start_message)

    # The output directories may not exist on a fresh checkout; open() would fail.
    for path in (path1, path2, path3):
        os.makedirs(os.path.dirname(path), exist_ok=True)

    # Training starts from noisy states; the target is always the noise-free state.
    def noisy_state(key):
        return generate_state(key, N_chains=N_chains, noise_level=noise_level)

    def pure_state(key):
        return generate_state(key, N_chains=N_chains, noise_level=0.0)

    extra_kwargs = {"lut_depth": lut_depth} if pass_lut_depth else {}

    best_result, best_system_params = None, None
    fidelities_each = []
    for s in tqdm(range(training_params["N_samples"])):
        system_params = init_protocol(jax.random.PRNGKey(s), n, N_chains, gamma)
        result = optimize_pulse(
            U_0=noisy_state,
            C_target=pure_state,
            system_params=system_params,
            num_time_steps=num_time_steps,
            reward_weights=reward_weights,
            mode=mode,
            goal="fidelity",
            max_iter=training_params["N_training_iterations"],
            convergence_threshold=training_params["convergence_threshold"],
            learning_rate=training_params["learning_rate"],
            evo_type="density",
            batch_size=training_params["batch_size"],
            eval_batch_size=training_params["eval_batch_size"],
            **extra_kwargs,
        )
        # Remember the system the best run was trained with, so that the
        # evaluation below pairs matching parameters and system.
        if best_result is None or result.final_fidelity > best_result.final_fidelity:
            best_result, best_system_params = result, system_params
        fidelities_each.append(float(result.final_fidelity))

    if best_result is not None:
        with open(path1, "w") as f:
            json.dump(FgResult_to_dict(best_result), f)

        # Evaluate the saved best run, not the last loop iteration's `result`.
        evaluation = evaluate_on_longer_time(
            U_0=pure_state,
            C_target=pure_state,
            system_params=best_system_params,
            optimized_trainable_parameters=best_result.optimized_trainable_parameters,
            num_time_steps=training_params["evaluation_time_steps"],
            evo_type="density",
            goal="fidelity",
            eval_batch_size=training_params["eval_batch_size"],
            mode=mode,
        )
        jnp.savez(path3, **{npz_key: jnp.array(evaluation.fidelity_each_timestep)})

    with open(path2, "w") as f:
        training_data = {
            "fidelities_each_sample": fidelities_each,
            "average_fidelity": float(jnp.mean(jnp.array(fidelities_each))),
            "training_params": training_params,
        }
        json.dump(training_data, f)


# Per-architecture differences: (enabled?, keyword arguments for _run_architecture).
architectures = [
    (do_test_grape, dict(
        model_type="grape", mode="no-measurement", init_protocol=init_grape_protocol,
        npz_key="fidelities_grape",
        skip_message="Grape already optimized, skipping optimization.",
        start_message="Optimizing grape", pass_lut_depth=False,
    )),
    (do_test_lut, dict(
        model_type="lut", mode="lookup", init_protocol=init_fgrape_protocol,
        npz_key="fidelities_lut",
        skip_message="LUT already trained, skipping training.",
        start_message="Training LUT", pass_lut_depth=True,
    )),
    # NOTE(review): lut_depth is forwarded in "nn" mode, but experiments_format
    # leaves it out of the RNN file name. If it affects "nn" training, experiments
    # that differ only in lut_depth would overwrite each other's files — confirm.
    (do_test_rnn, dict(
        model_type="rnn", mode="nn", init_protocol=init_fgrape_protocol,
        npz_key="fidelities_rnn",
        skip_message="RNN already trained, skipping training.",
        start_message="Training RNN", pass_lut_depth=True,
    )),
]

for params in experiments:
    print(f"Evaluating {params}")
    for enabled, arch_kwargs in architectures:
        if enabled:
            _run_architecture(params, **arch_kwargs)

# Play a sound when done (`say` is a macOS command).
if sys.platform == "darwin":
    os.system('say "done."')

Evaluating (1, 1, [1], 0.0)
Training LUT


100%|██████████| 3/3 [02:31<00:00, 50.66s/it]


0