ISSUE: In this code (based on example_I), NaN values appear in the fgrape gradient updates at optimization step 234. Interestingly, this was only observed for the LUT method in example_I.

CAUSE: Apparently, the optimizer runs into problems when parametrized eigenvalues of POVM elements become zero. My interpretation is that the corresponding measurement probabilities then become zero, which could cause division-by-zero errors. In the system studied here, the issue was solved by parametrizing the POVM elements such that their eigenvalues are bounded between 1e-6 and 1 - 1e-6 instead of between 0 and 1 (line 145 of bug_helpers.py).
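
One way to enforce such bounds is to pass unconstrained parameters through a rescaled sigmoid. The following JAX sketch only illustrates the idea: the function name `bounded_eigenvalues` and the choice of a sigmoid are assumptions and not necessarily what bug_helpers.py does.

```python
import jax
import jax.numpy as jnp

EPS = 1e-6  # margin quoted in the note above


def bounded_eigenvalues(raw_params, eps=EPS):
    """Map unconstrained real parameters to eigenvalues in [eps, 1 - eps].

    Hypothetical helper: a sigmoid maps to [0, 1], and the affine rescaling
    keeps the result away from both end points, even if the sigmoid
    saturates numerically.
    """
    return eps + (1.0 - 2.0 * eps) * jax.nn.sigmoid(raw_params)


# Even for extreme raw parameters the eigenvalues stay strictly inside (0, 1),
# so a log-probability built from them has a finite gradient.
raw = jnp.array([-200.0, 0.0, 200.0])
eigenvalues = bounded_eigenvalues(raw)
grad_log = jax.grad(lambda r: jnp.sum(jnp.log(bounded_eigenvalues(r))))(raw)
```

With this kind of mapping, both an eigenvalue and its complement `1 - eigenvalue` remain positive, so neither outcome of a two-outcome measurement is assigned exactly zero probability.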

Other, unsuccessful attempts at solving the issue were:
- Trying out different methods in dynamiqs.mesolve for solving the Lindblad equation. Only Kvaerno3 was successful in the example tried out, at the cost of a 300x longer runtime. There is no guarantee that it always works.
- Trying out different error tolerances (atol=1e-10, rtol=1e-10) for the Tsit5 method in mesolve.
- Trying out different values for eps and eps_root in fgrape's adam optimizer, which should avoid division by zero. This solved some examples but caused the same issue in others.
- Changing the clipping of prob in .utils.povm from "jnp.maximum(prob, 1e-10)" to "prob + 1e-6".
- Skipping updates in which NaN is detected and re-evaluating the batch with different keys -> the re-evaluated batches mostly produced NaN values as well.
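
The last attempt can be sketched as a guard around the parameter update. This is a minimal illustration and not fgrape's actual training loop: the helper names `tree_is_finite` and `guarded_sgd_step` are hypothetical, and plain gradient descent stands in for the adam optimizer to keep the example short.

```python
import jax
import jax.numpy as jnp


def tree_is_finite(tree):
    """Return a scalar boolean: True if every leaf of the pytree is finite."""
    leaves = jax.tree_util.tree_leaves(tree)
    return jnp.all(jnp.array([jnp.all(jnp.isfinite(leaf)) for leaf in leaves]))


def guarded_sgd_step(params, grads, learning_rate):
    """Apply a gradient-descent step only if all gradients are finite.

    If any gradient entry is NaN or inf, the parameters are returned
    unchanged and the caller can re-draw the batch with a new PRNG key.
    """
    ok = tree_is_finite(grads)
    new_params = jax.tree_util.tree_map(
        lambda p, g: jnp.where(ok, p - learning_rate * g, p), params, grads
    )
    return new_params, ok


params = {"w": jnp.array([1.0, 2.0])}
good_grads = {"w": jnp.array([0.5, 0.5])}
bad_grads = {"w": jnp.array([jnp.nan, 0.5])}

updated, ok_good = guarded_sgd_step(params, good_grads, 0.1)  # step applied
skipped, ok_bad = guarded_sgd_step(params, bad_grads, 0.1)    # step skipped
```

Such a guard only hides the symptom: as noted above, the re-evaluated batches mostly produced NaN values again, which suggests the problem sits in the parameters rather than in a particular batch.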

In [1]:
# ruff: noqa
import sys, os
sys.path.append(os.path.abspath("./../../"))

# ruff: noqa
from feedback_grape.fgrape import optimize_pulse # type: ignore
from bugs.NaN_for_indefinite_povm.bug_helpers import (
    init_fgrape_protocol,
    test_implementations,
    generate_random_state,
)

import jax

test_implementations()

In [2]:
# Physical parameters
# (Attention: the number of elements in the density matrix grows as 4^(n*N_chains))
n = 2 # number of qubits per chain (>= 3)
N_chains = 2 # Number of parallel chains to simulate
gamma = 0.25 # Decay constant

# Training and evaluation parameters
training_params = {
    "N_training_iterations": 1000, # Number of training iterations
    "learning_rate": 0.02, # Learning rate
    "convergence_threshold": 1e-6,
    "batch_size": 16,
    "eval_batch_size": 16,
}

# Parameters to test

# num_time_steps: Number of time steps in the control pulse
# lut_depth: Depth of the lookup table for feedback
# reward_weights: Weights for the reward at each time step. The default only weights the last time step: [0, 0, ..., 0, 1]

num_time_steps, lut_depth, reward_weights = 2, 1, [1, 1]

In [3]:

state_callable = lambda key: generate_random_state(key, N_chains=N_chains)

system_params = init_fgrape_protocol(jax.random.PRNGKey(0), n, N_chains, gamma)

result = optimize_pulse(
    U_0=state_callable,
    C_target=state_callable,
    system_params=system_params,
    num_time_steps=num_time_steps,
    lut_depth=lut_depth,
    reward_weights=reward_weights,
    mode="lookup",
    goal="fidelity",
    max_iter=training_params["N_training_iterations"],
    convergence_threshold=training_params["convergence_threshold"],
    learning_rate=training_params["learning_rate"],
    evo_type="density",
    batch_size=training_params["batch_size"],
    eval_batch_size=training_params["eval_batch_size"],
    progress=True,
)

Iteration 10, Loss: -0.455813, T=0s, eta=52s
Iteration 20, Loss: -0.638924, T=1s, eta=51s
Iteration 30, Loss: -0.660272, T=1s, eta=50s
Iteration 40, Loss: -1.107678, T=2s, eta=50s
Iteration 50, Loss: -1.489174, T=2s, eta=49s
Iteration 60, Loss: -1.324166, T=3s, eta=48s
Iteration 70, Loss: -1.447306, T=3s, eta=47s
Iteration 80, Loss: -1.474641, T=4s, eta=47s
Iteration 90, Loss: -1.191464, T=4s, eta=46s
Iteration 100, Loss: -1.486094, T=5s, eta=46s
Iteration 110, Loss: -0.753582, T=5s, eta=46s
Iteration 120, Loss: -0.662739, T=6s, eta=45s
Iteration 130, Loss: -1.456794, T=6s, eta=44s
Iteration 140, Loss: -1.307933, T=7s, eta=45s
Iteration 150, Loss: -1.363149, T=7s, eta=44s
Iteration 160, Loss: -1.199986, T=8s, eta=43s
Iteration 170, Loss: -1.225327, T=8s, eta=43s
Iteration 180, Loss: -1.737235, T=9s, eta=42s
Iteration 190, Loss: -1.547854, T=9s, eta=42s
Iteration 200, Loss: -0.903245, T=10s, eta=41s
Iteration 210, Loss: -0.983184, T=10s, eta=41s
Iteration 220, Loss: -1.725291, T=11s, et