In [1]:
# ruff: noqa
import sys, os
sys.path.append(os.path.abspath("./../feedback-grape"))
sys.path.append(os.path.abspath("./../"))

# ruff: noqa
from feedback_grape.fgrape import optimize_pulse, Decay, Gate, evaluate_on_longer_time
from feedback_grape.utils.operators import cosm, sinm, identity
from feedback_grape.utils.states import coherent
import jax.numpy as jnp
import jax
from feedback_grape.utils.operators import create, destroy
from feedback_grape.utils.fidelity import ket2dm
from library.utils.FgResult_to_dict import FgResult_to_dict
from tqdm import tqdm
import json

jax.config.update("jax_enable_x64", True)

# Simulation-wide settings.
# NOTE(review): N_chains, gamma, evaluation_time_steps and batch_size are not
# read by any cell visible in this notebook (the decay gate uses a hard-coded
# rate and the evaluation call uses its own literals) — confirm whether they
# are still needed or should be wired into the code below.
# (attention! #elements in density matrix grow as 4^N_chains)
N_chains = 3 # Number of parallel chains to simulate
gamma = 0.05 # Decay constant
evaluation_time_steps = 200 # Number of time steps for evaluation
batch_size = 32 # Number of random states to evaluate in parallel

# Cavity parameters
N_cav = 30  # dimension argument passed to coherent/create/destroy/identity (presumably the Fock-space truncation, not a number of modes — confirm)
N_snap = 15  # number of phase parameters drawn for the SNAP gate; remaining N_cav - N_snap diagonal entries stay 1

# Target state: sum of the states returned by `coherent` for +alpha and -alpha.
alpha = 2
psi_target = coherent(N_cav, alpha) + coherent(N_cav, -alpha)

# Normalize psi_target before constructing rho_target
psi_target = psi_target / jnp.linalg.norm(psi_target)

# Presumably the density matrix |psi><psi| of the normalized target — verify against ket2dm.
rho_target = ket2dm(psi_target)

# Load every lookup-table model ("lut*.json") found in ./optimized_architectures.
# `filenames[i]` and `models[i]` always describe the same file.
filenames = []
models = []
# os.listdir returns entries in arbitrary order, so sort to make the
# evaluation order reproducible across machines and runs.
for file in sorted(os.listdir("./optimized_architectures")):
    if not (file.startswith("lut") and file.endswith(".json")):
        continue

    with open(os.path.join("./optimized_architectures", file), "r") as f:
        model_params = json.load(f)["optimized_trainable_parameters"]

    # Append only after the file was parsed successfully so both lists
    # can never get out of step.
    filenames.append(file)
    models.append(model_params)

In [2]:
def displacement_gate(alphas):
    """Build the matrix exponential exp(alpha * create - conj(alpha) * destroy).

    Args:
        alphas: pair (real part, imaginary part) of the complex amplitude alpha.

    Returns:
        The matrix returned by ``jax.scipy.linalg.expm`` for the generator
        built from ``create(N_cav)`` and ``destroy(N_cav)``.
    """
    real_part, imag_part = alphas
    amplitude = real_part + 1j * imag_part
    raising = create(N_cav)
    lowering = destroy(N_cav)
    generator = amplitude * raising - amplitude.conj() * lowering
    return jax.scipy.linalg.expm(generator)

def initialize_displacement_gate(key):
    """Create the displacement Gate with randomly drawn starting parameters.

    Args:
        key: JAX PRNG key used for the draw.

    Returns:
        A ``Gate`` wrapping ``displacement_gate`` with two float64 parameters
        sampled uniformly between -pi/2 and pi/2 and ``measurement_flag=False``.
    """
    half_pi = jnp.pi / 2
    start_params = jax.random.uniform(
        key,
        shape=(2,),  # real and imaginary part of the displacement
        minval=-half_pi,
        maxval=half_pi,
        dtype=jnp.float64,
    )
    return Gate(
        gate=displacement_gate,
        initial_params=start_params,
        measurement_flag=False,
    )

def displacement_gate_dag(alphas):
    """Conjugate transpose of ``displacement_gate(alphas)``.

    Args:
        alphas: pair (real part, imaginary part), forwarded unchanged to
            ``displacement_gate``.

    Returns:
        The complex-conjugated, transposed displacement matrix.
    """
    forward = displacement_gate(alphas)
    conjugated = forward.conj()
    return conjugated.T

def initialize_displacement_gate_dag(key):
    """Create the conjugate-transposed displacement Gate with random starting parameters.

    Args:
        key: JAX PRNG key used for the draw.

    Returns:
        A ``Gate`` wrapping ``displacement_gate_dag`` with two float64
        parameters sampled uniformly between -pi/2 and pi/2 and
        ``measurement_flag=False``.
    """
    half_pi = jnp.pi / 2
    start_params = jax.random.uniform(
        key,
        shape=(2,),  # real and imaginary part of the displacement
        minval=-half_pi,
        maxval=half_pi,
        dtype=jnp.float64,
    )
    return Gate(
        gate=displacement_gate_dag,
        initial_params=start_params,
        measurement_flag=False,
    )

def snap_gate(phase_list):
    """Diagonal phase gate acting on the first ``len(phase_list)`` levels.

    Args:
        phase_list: sequence of phases; entry k becomes exp(1j * phase_list[k])
            on diagonal position k.

    Returns:
        An N_cav x N_cav diagonal matrix whose leading diagonal entries are
        the exponentiated phases and whose remaining entries are 1.
    """
    n_phases = len(phase_list)
    phase_factors = jnp.exp(1j * jnp.array(phase_list))
    # Levels beyond the supplied phases are left untouched (factor 1).
    untouched = jnp.ones(shape=(N_cav - n_phases))
    return jnp.diag(jnp.concatenate((phase_factors, untouched)))

def initialize_snap_gate(key):
    """Create the SNAP Gate with randomly drawn starting phases.

    Args:
        key: JAX PRNG key used for the draw.

    Returns:
        A ``Gate`` wrapping ``snap_gate`` with ``N_snap`` float64 phases
        sampled uniformly between -pi/2 and pi/2 and ``measurement_flag=False``.
    """
    half_pi = jnp.pi / 2
    start_phases = jax.random.uniform(
        key,
        shape=(N_snap,),
        minval=-half_pi,
        maxval=half_pi,
        dtype=jnp.float64,
    )
    return Gate(
        gate=snap_gate,
        initial_params=start_phases,
        measurement_flag=False,
    )

def povm_measure_operator(measurement_outcome, params):
    """
    Measurement operator Mm for the cavity state.

    Returns Mm itself, NOT the POVM element Em = Mm_dag @ Mm.

    Args:
        measurement_outcome: outcome m; the value 1 selects the cosine
            branch, any other value selects the sine branch.
        params: pair (gamma, delta) defining the operator
            gamma * create @ destroy + delta * identity / 2.

    Returns:
        ``cosm(angle)`` where ``measurement_outcome == 1``, otherwise
        ``sinm(angle)``.
    """
    strength, offset = params
    angle = strength * create(N_cav) @ destroy(N_cav) + offset * identity(N_cav) / 2
    outcome_is_one = measurement_outcome == 1
    # jnp.where evaluates both branches and selects by the outcome.
    return jnp.where(outcome_is_one, cosm(angle), sinm(angle))

def initialize_povm_gate(key):
    """Create the measurement Gate with randomly drawn starting parameters.

    Args:
        key: JAX PRNG key used for the draw.

    Returns:
        A ``Gate`` wrapping ``povm_measure_operator`` with two float64
        parameters (gamma and delta) sampled uniformly between -pi/2 and pi/2
        and ``measurement_flag=True``.
    """
    half_pi = jnp.pi / 2
    start_params = jax.random.uniform(
        key,
        shape=(2,),  # one entry each for gamma and delta
        minval=-half_pi,
        maxval=half_pi,
        dtype=jnp.float64,
    )
    return Gate(
        gate=povm_measure_operator,
        initial_params=start_params,
        measurement_flag=True,
    )

# Decay step with a single collapse operator sqrt(0.005) * destroy(N_cav).
# The same object is placed twice in the sequence built by initialize_system_params.
# NOTE(review): the rate 0.005 is hard-coded and differs from the module-level
# `gamma = 0.05` ("Decay constant") — confirm which value is intended.
decay_gate = Decay(c_ops=[jnp.sqrt(0.005) * destroy(N_cav)])

def initialize_system_params(key):
    """Assemble the gate sequence of one time step.

    The sequence is: decay, measurement, decay, displacement, SNAP,
    conjugate-transposed displacement. The four parametrized gates each get
    their own sub-key split from ``key``.

    Args:
        key: JAX PRNG key that is split into four sub-keys.

    Returns:
        List of six entries in the order described above.
    """
    povm_key, disp_key, snap_key, disp_dag_key = jax.random.split(key, 4)
    sequence = [decay_gate]
    sequence.append(initialize_povm_gate(povm_key))
    sequence.append(decay_gate)
    sequence.append(initialize_displacement_gate(disp_key))
    sequence.append(initialize_snap_gate(snap_key))
    sequence.append(initialize_displacement_gate_dag(disp_dag_key))
    return sequence

In [3]:
# Evaluate every loaded lookup-table model, starting from and targeting
# rho_target, and store the per-time-step fidelities as
# ./evaluation_results/<model file name without extension>.npz.
# NOTE(review): num_time_steps and eval_batch_size are literals here and do not
# use the module-level evaluation_time_steps / batch_size — confirm intent.

# Create the output directory up front so a missing folder cannot make the
# save fail after the (expensive) evaluation has already run.
os.makedirs("./evaluation_results", exist_ok=True)

for filename, model in zip(filenames, models):
    # Fixed PRNG key: every model is evaluated with the same gate sequence
    # and the same initial gate parameters.
    system_params = initialize_system_params(jax.random.PRNGKey(0))

    result = evaluate_on_longer_time(
        U_0=rho_target,
        C_target=rho_target,
        system_params=system_params,
        optimized_trainable_parameters=model,
        num_time_steps=1000,
        evo_type="density",
        goal="fidelity",
        eval_batch_size=16,  # Default is 10
        mode="lookup",
        rnn=None,
        rnn_hidden_size=30,
    )

    fidelities = result.fidelity_each_timestep

    # Strip the extension explicitly instead of slicing off a fixed number
    # of characters.
    output_stem = os.path.splitext(filename)[0]
    jnp.savez(f"./evaluation_results/{output_stem}.npz", fidelities_lut=fidelities)