In [1]:
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import (
    Predictive,
    SVI,
    TraceMeanField_ELBO,
    autoguide,
    init_to_feasible,
)
from numpyro.contrib.module import random_flax_module
import inspeqtor.experimental as sq

jax.config.update("jax_enable_x64", True)

## Simulate some data and load it to LoadedData object

In [2]:
key = jax.random.key(0)
key, data_key, model_key, train_key, gate_optim_key = jax.random.split(key, 5)
sample_size = 1000

qubit_info = sq.predefined.get_mock_qubit_information()


# NOTE pick the pulse sequence you want to use
# def get_pulse_sequence_fn():
#         return sq.predefined.get_drag_pulse_sequence(qubit_info)

get_pulse_sequence_fn = sq.predefined.get_multi_drag_pulse_sequence_v3

# NOTE: Simulate the experiment with some detuning noise
(
    exp_data,
    pulse_sequence,
    noisy_unitaries,
    signal_params_list,
    noisy_simulator,
    whitebox,
) = sq.predefined.generate_mock_experiment_data(
    key=data_key,
    sample_size=sample_size,
    shots=3000,
    strategy=sq.predefined.SimulationStrategy.SHOT,
    detune=0.001,
    get_pulse_sequence_fn=get_pulse_sequence_fn,
    get_qubit_information_fn=sq.predefined.get_mock_qubit_information,
)

# Prepare the data for training
loaded_data = sq.utils.prepare_data(
    exp_data=exp_data, pulse_sequence=pulse_sequence, whitebox=whitebox
)

  ) = sq.predefined.generate_mock_experiment_data(
  hamiltonian = detune_hamiltonian(ideal_hamiltonian, detune)
  out = fun(*args, **kwargs)
  out = fun(*args, **kwargs)


## Dataset metric

-   `var` is the lowest test MSE of expectation value that model could hope to get.
-   `mse_ideal2exp` is the MSE of expectation value between ideal evolution and noisy observation. This show how much data is noisy. This point could be used as a reference to compare with the test MSE of expectation value.


In [3]:
sq.utils.get_dataset_metrics(loaded_data)

DatasetMetrics(var=0.00022213154722644025, mse_ideal2exp=0.0004172154452236824, total_iterations=9000, step_for_optimizer=8000, warmup_steps=800, cool_down_steps=1000)

## Define Bayesian model

In [4]:
def make_bmlp_model(SHOTS: int):
    def model_v2b(
        pulse_parameters: jnp.ndarray,
        unitaries: jnp.ndarray,
        expectation_values: jnp.ndarray | None = None,
    ):
        # NOTE: Seem correct?, but not fast.

        # NOTE: Perform auto expanding on unitaries
        samples_shape = pulse_parameters.shape[:-2]
        unitaries = jnp.broadcast_to(
            unitaries, samples_shape + unitaries.shape[-3:]
        )
        # unitaries = helper.lexpand(unitaries, *samples_shape)

        # Define the base module
        base_module = sq.model.BasicBlackBoxV2(
            hidden_sizes_1=[9, 29, 13],
            hidden_sizes_2=[46, 32],
        )
        # Initialize BMLP model
        net = random_flax_module(
            "nn",
            base_module,
            dist.Normal(0, 1.0),
            input_shape=pulse_parameters.shape,
        )

        # Predict Wo from pulse parameters
        Wo = net(pulse_parameters)

        # With unitary and Wo, calculate expectation values
        expvals = sq.model.get_predict_expectation_value(
            Wos_params=Wo,
            unitaries=unitaries,
            evaluate_expectation_values=sq.constant.default_expectation_values_order,
        )

        # Calculate the variance of expectation value from the expectation value
        expected_scale = jnp.sqrt(
            sq.utils.variance_of_observable(
                expval=expvals,
                shots=SHOTS,
            )
        )

        for idx, exp in enumerate(sq.constant.default_expectation_values_order):
            numpyro.sample(
                f"obs/{exp.initial_state}/{exp.observable}",
                dist.TruncatedNormal(
                    loc=expvals[..., idx], # type: ignore
                    scale=expected_scale[..., idx], # type: ignore
                    low=-1.0,
                    high=1.0,
                ),
                obs=(
                    expectation_values[..., idx]
                    if expectation_values is not None
                    else None
                ),
            )

    return model_v2b

In [5]:
key = jax.random.key(0)
key, random_split_key, train_key, prediction_key = jax.random.split(key, 4)
(
    train_pulse_parameters,
    train_unitaries,
    train_expectation_values,
    test_pulse_parameters,
    test_unitaries,
    test_expectation_values,
) = sq.utils.random_split(
    random_split_key,
    100,  # Test size
    loaded_data.pulse_parameters,
    loaded_data.unitaries,
    loaded_data.expectation_values,
)

In [6]:
model = make_bmlp_model(loaded_data.experiment_data.experiment_config.shots)
guide = autoguide.AutoDiagonalNormal(model, init_loc_fn=init_to_feasible)
optimizer = sq.optimize.get_default_optimizer(8000)

svi = SVI(
    model=model,
    guide=guide,
    optim=numpyro.optim.optax_to_numpyro(optimizer),
    loss=TraceMeanField_ELBO(),
)

svi_result = svi.run(
    train_key,
    num_steps=10_000,
    pulse_parameters=train_pulse_parameters,
    unitaries=train_unitaries,
    expectation_values=train_expectation_values,
    progress_bar=False,
)

In [7]:
params, losses = svi_result.params, svi_result.losses
predictive = Predictive(model, guide=guide, params=params, num_samples=1_000)


y_pred = predictive(
    prediction_key,
    pulse_parameters=test_pulse_parameters,
    unitaries=test_unitaries,
)

In [8]:
predicted_expvals = [
    y_pred[f"obs/{exp.initial_state}/{exp.observable}"]
    for exp in sq.constant.default_expectation_values_order
]

predicted_expvals = jnp.array(predicted_expvals)
# Swapaxes
predicted_expvals = jnp.swapaxes(predicted_expvals, 0, 1)
predicted_expvals = jnp.swapaxes(predicted_expvals, 1, 2)
print(predicted_expvals.shape)

optimal_mse_losses = sq.utils.variance_of_observable(
    expval=predicted_expvals,
    shots=loaded_data.experiment_data.experiment_config.shots,
)
print(optimal_mse_losses.shape, optimal_mse_losses.mean(), optimal_mse_losses.std())

mse_losses = jax.vmap(jax.vmap(sq.model.mse, in_axes=(0, 0)), in_axes=(0, None))(
    predicted_expvals,
    test_expectation_values,
)
print(mse_losses.shape, mse_losses.mean(), mse_losses.std())

(1000, 100, 18)
(1000, 100, 18) 0.00022232106868100968 0.00013795110879701736
(1000, 100) 0.0005340922778215452 0.0002168604855420022
