In [4]:
# ruff: noqa
import sys, os
sys.path.append(os.path.abspath("./../feedback-grape"))
sys.path.append(os.path.abspath("./../"))

# ruff: noqa
from feedback_grape.fgrape import optimize_pulse, evaluate_on_longer_time
from helpers import (
    init_grape_protocol,
    init_fgrape_protocol,
    test_implementations,
    generate_random_state,
)
from tqdm import tqdm
import jax
import jax.numpy as jnp
from library.utils.FgResult_to_dict import FgResult_to_dict
import json

# Enable 64-bit precision *before* any JAX computation runs. JAX expects this flag to be set
# at startup; setting it after test_implementations() meant those checks ran in 32-bit
# precision while the experiments below run in 64-bit.
jax.config.update("jax_enable_x64", True)

# Sanity checks from helpers.py (presumably exercising the protocol/state helpers — see helpers).
test_implementations()

In [5]:
# Physical parameters
# (attention! the number of density-matrix elements grows as 4^n * N_chains — keep both small)
n = 2  # number of qubits per chain (>= 2)
N_chains = 2  # number of parallel chains to simulate
gamma = 0.25  # decay constant

assert n >= 2, "Chain lengths must be at least 2."

# Training and evaluation parameters (passed to optimize_pulse / evaluate_on_longer_time)
training_params = {
    "N_samples": 3,  # number of independent training runs (PRNG seeds 0..N_samples-1); the best one is kept
    "N_training_iterations": 1000,  # max_iter for optimize_pulse
    "learning_rate": 0.02,
    "convergence_threshold": 1e-6,
    "batch_size": 16,
    "eval_batch_size": 16,
    "evaluation_time_steps": 200,  # horizon used by evaluate_on_longer_time
}

# Experiments to run, as (num_time_steps, lut_depth, reward_weights):
#   num_time_steps : number of time steps in the control pulse
#   lut_depth      : depth of the lookup table used for feedback (must be <= num_time_steps)
#   reward_weights : reward weight for each time step; [0, ..., 0, 1] rewards only the last step
experiments = [
    # lut_depth fixed at 1
    (1, 1, [1]),
    (2, 1, [0, 1]),
    (2, 1, [1, 1]),
    (3, 1, [0, 0, 1]),
    (3, 1, [0, 1, 1]),
    (3, 1, [1, 1, 1]),
    # lut_depth == num_time_steps. (1, 1, [1]) already appears above and is not repeated:
    # a duplicate would retrain with the same seeds and overwrite the same output files.
    (2, 2, [0, 1]),
    (2, 2, [1, 1]),
    (3, 3, [0, 0, 1]),
    (3, 3, [0, 1, 1]),
    (3, 3, [1, 1, 1]),
]

for t, l, weights in experiments:
    assert type(t) is int and t >= 1, "Number of time steps must be a positive integer."
    assert type(l) is int and l >= 1, "LUT depth must be a positive integer."
    assert len(weights) == t, "Length of weights must equal number of time steps."
    assert l <= t, "LUT depth cannot exceed number of time steps."
    # Each weight becomes one character of the output file names.
    assert all(type(w) is int and 0 <= w <= 9 for w in weights), "Weights must be integers between 0 and 9 to save as single character in filename."

# Output files are named only from these values, so a duplicate entry would just redo and overwrite work.
assert len({(t, l, tuple(w)) for t, l, w in experiments}) == len(experiments), "Duplicate experiment configurations."



In [6]:
# Output locations; created up front so the first write does not fail on a fresh checkout.
ARCH_DIR = "./optimized_architectures"
EVAL_DIR = "./evaluation_results"
os.makedirs(ARCH_DIR, exist_ok=True)
os.makedirs(EVAL_DIR, exist_ok=True)

# GRAPE (mode="no-measurement") is slow for num_time_steps > 2, so it is disabled by default.
RUN_GRAPE = False

HERMITIAN_ERROR = "Argument rho0 is not hermitian."


def state_callable(key):
    """Return generate_random_state(key, N_chains=N_chains); used as both U_0 and C_target."""
    return generate_random_state(key, N_chains=N_chains)


def _report_failure(what, err):
    """Print why `what` failed, abbreviating the known non-hermitian rho0 error to one line."""
    if HERMITIAN_ERROR in str(err):
        print(f'{what} failed with error: "{HERMITIAN_ERROR}"')
    else:
        print(f"{what} failed with error: {err}")


def train_and_evaluate(stem, npz_key, mode, init_protocol, num_time_steps, reward_weights, **mode_kwargs):
    """Train one protocol for every seed, keep the best run, and evaluate it on a longer horizon.

    Args:
        stem: file-name stem, e.g. "lut_t=2_l=1_s=3_w=01". Writes `{ARCH_DIR}/{stem}.json`
            (best result), `{ARCH_DIR}/{stem}_training_data.json` (per-seed fidelities) and
            `{EVAL_DIR}/{stem}.npz` (per-timestep fidelities of the longer evaluation).
        npz_key: array name inside the .npz file, e.g. "fidelities_lut".
        mode: mode passed to optimize_pulse and evaluate_on_longer_time.
        init_protocol: init_grape_protocol or init_fgrape_protocol.
        num_time_steps: number of time steps used during training.
        reward_weights: per-timestep reward weights.
        **mode_kwargs: extra optimize_pulse arguments (e.g. lut_depth).

    Returns:
        The best optimize_pulse result, or None if every seed failed.
    """
    best_result = None
    best_system_params = None
    fidelities_each = []
    for s in tqdm(range(training_params["N_samples"])):
        system_params = init_protocol(jax.random.PRNGKey(s), n, N_chains, gamma)
        try:  # optimization occasionally hits numerical issues; skip that seed instead of aborting
            result = optimize_pulse(
                U_0=state_callable,
                C_target=state_callable,
                system_params=system_params,
                num_time_steps=num_time_steps,
                reward_weights=reward_weights,
                mode=mode,
                goal="fidelity",
                max_iter=training_params["N_training_iterations"],
                convergence_threshold=training_params["convergence_threshold"],
                learning_rate=training_params["learning_rate"],
                evo_type="density",
                batch_size=training_params["batch_size"],
                eval_batch_size=training_params["eval_batch_size"],
                **mode_kwargs,
            )
            if best_result is None or result.final_fidelity > best_result.final_fidelity:
                best_result = result
                # Remember the system the best run was trained on so evaluation matches it
                # (previously the last seed's system_params/result were evaluated instead).
                best_system_params = system_params
            fidelities_each.append(float(result.final_fidelity))
        except Exception as e:
            _report_failure(f"Sample {s}", e)

    if best_result is not None:
        with open(f"{ARCH_DIR}/{stem}.json", "w") as f:
            json.dump(FgResult_to_dict(best_result), f)

        try:
            eval_result = evaluate_on_longer_time(
                U_0=state_callable,
                C_target=state_callable,
                system_params=best_system_params,
                optimized_trainable_parameters=best_result.optimized_trainable_parameters,
                num_time_steps=training_params["evaluation_time_steps"],
                evo_type="density",
                goal="fidelity",
                eval_batch_size=training_params["eval_batch_size"],
                mode=mode,
            )
            jnp.savez(f"{EVAL_DIR}/{stem}.npz", **{npz_key: jnp.array(eval_result.fidelity_each_timestep)})
        except Exception as e:
            _report_failure("Evaluation on longer time", e)

    with open(f"{ARCH_DIR}/{stem}_training_data.json", "w") as f:
        training_data = {
            "fidelities_each_sample": fidelities_each,
            # None instead of NaN when every seed failed: NaN is not valid JSON.
            "average_fidelity": float(jnp.mean(jnp.array(fidelities_each))) if fidelities_each else None,
            "training_params": training_params,
        }
        json.dump(training_data, f)

    return best_result


for num_time_steps, lut_depth, reward_weights in experiments:
    print(f"Evaluating num_time_steps={num_time_steps}, lut_depth={lut_depth}, reward_weights={reward_weights}")
    weights_str = "".join(str(w) for w in reward_weights)
    run_suffix = f"s={training_params['N_samples']}_w={weights_str}"

    if RUN_GRAPE:
        grape_stem = f"grape_t={num_time_steps}_{run_suffix}"
        if os.path.exists(f"{ARCH_DIR}/{grape_stem}.json"):
            print(f"Grape for t={num_time_steps} and weights={weights_str} already optimized, skipping optimization.")
        else:
            print("Optimizing grape")
            train_and_evaluate(grape_stem, "fidelities_grape", "no-measurement", init_grape_protocol, num_time_steps, reward_weights)

    print("Training LUT")
    train_and_evaluate(
        f"lut_t={num_time_steps}_l={lut_depth}_{run_suffix}", "fidelities_lut", "lookup",
        init_fgrape_protocol, num_time_steps, reward_weights, lut_depth=lut_depth,
    )

    # The RNN file name omits lut_depth, so the RNN is trained once per (num_time_steps, reward_weights).
    rnn_stem = f"rnn_t={num_time_steps}_{run_suffix}"
    if os.path.exists(f"{ARCH_DIR}/{rnn_stem}.json"):
        print(f"RNN for t={num_time_steps} and weights={weights_str} already trained, skipping training.")
    else:
        print("Training RNN")
        train_and_evaluate(
            rnn_stem, "fidelities_rnn", "nn",
            init_fgrape_protocol, num_time_steps, reward_weights, lut_depth=lut_depth,
        )

# Play a sound when done (`say` only exists on macOS).
if sys.platform == "darwin":
    os.system('say "done."')

Evaluating num_time_steps=1, lut_depth=1, reward_weights=[1]
Training LUT


100%|██████████| 3/3 [00:33<00:00, 11.21s/it]


Training RNN


100%|██████████| 3/3 [00:40<00:00, 13.65s/it]


Evaluating num_time_steps=2, lut_depth=1, reward_weights=[0, 1]
Training LUT


  0%|          | 0/3 [00:06<?, ?it/s]


KeyboardInterrupt: 