# Stabilization of 5 qubits using a RNN control model
In this example, we train a RNN model to stabilize 5 qubits against decay and decoherence using 4 generalized measurements plus a unitary per error correction cycle and compare its performance to the uncontrolled case. Mathematically, the model is powerful enough to express Laflamme code. To demonstrate GPU acceleration, we chose a larger batch size and fewer training iterations. The code takes 16 min on a NVIDIA Quadro RTX 6000 node.

In [1]:
# ruff: noqa
import os

os.sys.path.append("../../../..")

# ruff: noqa
from feedback_grape.fgrape import optimize_pulse
from feedback_grape.fgrape import Gate, Decay # type: ignore
from feedback_grape.utils.states import basis # type: ignore
from feedback_grape.utils.fidelity import ket2dm # type: ignore
from feedback_grape.utils.operators import sigmaz, sigmam # type: ignore
from feedback_grape.utils.modeling import embed # type: ignore

import jax
import jax.numpy as jnp

In [None]:
# Experimental parameters
N_qubits = 5
N_meas = 4
gamma_z = 0.1 # dephasing rate
gamma_m = 0.1 # decay rate
psi_target = (basis(2**N_qubits, 0) + basis(2**N_qubits, 2**N_qubits - 1)) / jnp.sqrt(2) # (|11...1> + |00...0>)/sqrt(2)
rho_target = ket2dm(psi_target)

# Model parameters
rnn_hidden_size = 16

# Training parameters
num_time_steps = 3
reward_weights = [1.0]*num_time_steps
convergence_threshold = None # no early stopping
N_training_iterations = 250 # relatively small number of iterations for demo purposes
batch_size = 128
learning_rate = 3e-3

# Evaluation parameters
eval_time_steps = 10
eval_batch_size = 128

In [None]:
# All the operators we need
def generate_hermitian(params, dim):
    assert len(params) == dim**2, "Number of real parameters must be dim^2 for an NxN Hermitian matrix."
    
    X = params.reshape(dim, dim)

    # Real part: take the symmetric part of X
    Re = 0.5 * (X + X.T)

    # Imag part: take the antisymmetric part of X
    Im = 0.5 * (X - X.T)

    # Build Hermitian matrix: Herm = Re + 1j * Im
    H = Re + 1j * Im

    return H

def generate_unitary(params, dim):
    assert len(params) == dim**2, "Number of real parameters must be dim^2 for an NxN unitary matrix."

    H = generate_hermitian(params, dim)
    return jax.scipy.linalg.expm(-1j * H)

def generate_povm(measurement_outcome, params, dim):
    """ 
        Generate a 2-outcome POVM elements M_0 and M_1 for a system with Hilbert space dimension dim.
        This function should parametrize all such POVMs up to unitary equivalence, i.e., M_i -> U M_i for some unitary U.
        I.e it parametrizes all pairs (M_0, M_1) such that M_0 M_0† + M_1 M_1† = I.

        measurement_outcome: 0 or 1, indicating which POVM element to generate.
        params: list of dim^(dim+1) real parameters.

        when measurement_outcome == 1:
            M_1 = S D S†
        when measurement_outcome == -1:
            M_0 = S (I - D) S†

        where S is a unitary parametrized by dim^2 parameters, and D is a diagonal matrix with eigenvalues parametrized by dim parameters.
    """
    assert len(params) == dim * (dim + 1), "Number of real parameters must be N * (N + 1) for an NxN POVM element."

    S = generate_unitary(params[0:dim*dim], dim=dim) # All parameters for unitary

    d_vec = jnp.astype(jnp.square(jnp.sin( params[dim*dim:dim*(dim+1)] )), jnp.complex128) # Last #dim parameters for eigenvalues
    d_vec = 1e-8 + (1 - 2e-8) * d_vec # Avoid exactly 0 or 1 eigenvalues for numerical stability

    # jnp.multiply is fast way to "matrix @ diagonal matrix" multiplication, especially for large matrices
    return jnp.where(measurement_outcome == 1,
        jnp.multiply(S, d_vec) @ S.conj().T,
        jnp.multiply(S, jnp.sqrt(1 - jnp.square(d_vec))) @ S.conj().T
    )

def jump_ops(N_qubits, gamma_z, gamma_m):
    return [
        gamma_z**0.5 * embed(sigmaz(), 1, (2**j, 2, 2**(N_qubits - j - 1)))
        for j in range(N_qubits)
    ] + [
        gamma_m**0.5 * embed(sigmam(), 1, (2**j, 2, 2**(N_qubits - j - 1)))
        for j in range(N_qubits)
    ]

# All the gates we need
key = jax.random.PRNGKey(2)
dim = 2**N_qubits # Hilbert space dimension

decay_gate = Decay(
    c_ops=jump_ops(N_qubits, gamma_z, gamma_m),
)

N_unitary_params = dim**2
U_gate = Gate(
    gate=lambda params: generate_unitary(params, dim=dim),
    initial_params = jax.random.uniform(key, (N_unitary_params,), minval=0.0, maxval=1.0),
    measurement_flag = False
)

N_msmt_params = dim * (dim + 1)
msmt_gate = Gate(
    gate=lambda msmt, params: generate_povm(msmt, params, dim=dim),
    initial_params = jax.random.uniform(key, (N_msmt_params,), minval=0.0, maxval=2*jnp.pi),
    measurement_flag = True
)

system_params = [decay_gate] + [msmt_gate]*N_meas + [U_gate]

In [None]:
# Train RNN
result = optimize_pulse(
    U_0=rho_target,
    C_target=rho_target,
    system_params=system_params,
    num_time_steps=num_time_steps,
    reward_weights=reward_weights,
    mode="nn",
    goal="fidelity",
    max_iter=N_training_iterations,
    convergence_threshold=convergence_threshold,
    learning_rate=learning_rate,
    evo_type="density",
    batch_size=batch_size,
    eval_batch_size=eval_batch_size,
    eval_time_steps=eval_time_steps,
    progress=True,
    rnn_hidden_size=rnn_hidden_size,
)

print(f"Iterations: {result.iterations}")
print(f"Average fidelity over 5 timesteps: {jnp.mean(jnp.mean(jnp.array(result.fidelity_each_timestep), axis=1)[1:6]):.2f}")
print(f"Fidelity across {eval_time_steps} timesteps: \n{jnp.mean(jnp.array(result.fidelity_each_timestep), axis=1)}")

# Expected output:
# Iterations: 250
# Average fidelity over 5 timesteps: 0.73
# Fidelity across 10 timesteps:
# [1.         0.75307206 0.78865803 0.74860818 0.69591505 0.6625737 0.63545622 0.61818136 0.58743217 0.54756492 0.51893844]

Iteration 10, Loss: -0.414569, T=33s, eta=744s
Iteration 20, Loss: -0.417313, T=68s, eta=749s
Iteration 30, Loss: -0.443282, T=102s, eta=730s
Iteration 40, Loss: -0.386187, T=136s, eta=702s
Iteration 50, Loss: -0.016632, T=170s, eta=672s
Iteration 60, Loss: 0.313386, T=204s, eta=641s
Iteration 70, Loss: 0.723101, T=238s, eta=609s
Iteration 80, Loss: 1.055035, T=272s, eta=576s
Iteration 90, Loss: 1.007991, T=306s, eta=542s
Iteration 100, Loss: 1.147771, T=340s, eta=509s
Iteration 110, Loss: 1.695535, T=375s, eta=476s
Iteration 120, Loss: 1.546634, T=409s, eta=443s
Iteration 130, Loss: 1.500587, T=443s, eta=409s
Iteration 140, Loss: 1.857865, T=477s, eta=375s
Iteration 150, Loss: 1.472231, T=511s, eta=342s
Iteration 160, Loss: 2.462461, T=545s, eta=308s
Iteration 170, Loss: 2.093630, T=579s, eta=274s
Iteration 180, Loss: 1.545639, T=613s, eta=240s
Iteration 190, Loss: 2.547595, T=647s, eta=206s
Iteration 200, Loss: 2.031090, T=682s, eta=173s
Iteration 210, Loss: 2.501371, T=716s, eta=139

In [None]:
# Simulation of the uncontrolled dynamics
import dynamiqs as dq
from feedback_grape.utils.fidelity import fidelity

dq_result = dq.mesolve(
    H=jnp.zeros((dim, dim)), # No Hamiltonian
    jump_ops=jump_ops(N_qubits, gamma_z, gamma_m),
    rho0=rho_target,
    tsave=jnp.linspace(0, eval_time_steps, eval_time_steps + 1),
)

rho_t = dq_result.states.to_jax()
fidelities = jnp.array([fidelity(C_target=rho_target, U_final=rho_ti, evo_type="density") for rho_ti in rho_t])
print(f"Fidelity of uncontrolled dynamics across {eval_time_steps} timesteps: \n{fidelities}")
print(f"Avergage fidelity over first 5 timesteps: {jnp.mean(fidelities[1:6]):.2f}")

# Expected output: Fidelity of uncontrolled dynamics across 10 timesteps:
# [1.         0.54488266 0.38306071 0.31783434 0.28817987 0.27384715 0.26740115 0.26570994 0.26725884 0.27118467 0.2769193 ]
# Avergage fidelity over first 5 timesteps: 0.36

Fidelity of uncontrolled dynamics across 10 timesteps: [1.         0.54488266 0.38306071 0.31783434 0.28817987 0.27384715
 0.26740115 0.26570994 0.26725884 0.27118467 0.2769193 ]
Avergage fidelity over first 5 timesteps: 0.36
