In [1]:
import jax
import inspeqtor.experimental as sq

In [2]:
# --- Configuration -----------------------------------------------------------
# One root PRNG key, split into independent streams for each stochastic stage.
# `model_key` and `gate_optim_key` are not used in this cell; presumably they
# are consumed by later cells of the notebook (kept so those cells still work).
key = jax.random.key(0)
key, data_key, model_key, train_key, gate_optim_key = jax.random.split(key, 5)

# Tunable knobs for the mock experiment and the hyperparameter search.
sample_size = 1000  # number of mock experiment samples to generate
shots = 3000  # shots used by the SHOT simulation strategy (presumably per sample)
detune = 0.001  # detuning noise passed to the mock simulator (units: TODO confirm)
num_tuning_samples = 10  # hyperparameter trials to run (the table below shows 10 trials)

qubit_info = sq.predefined.get_mock_qubit_information()

# Pulse-sequence factory used to generate the mock data. To use the DRAG pulse
# sequence instead, replace the assignment below with:
#
#     def get_pulse_sequence_fn():
#         return sq.predefined.get_drag_pulse_sequence(qubit_info)
get_pulse_sequence_fn = sq.predefined.get_multi_drag_pulse_sequence_v3

# --- Generate mock data ------------------------------------------------------
# Simulate the experiment with detuning noise; the generator's outputs are
# unpacked into the names below for use in this and later cells.
(
    exp_data,
    pulse_sequence,
    noisy_unitaries,
    signal_params_list,
    noisy_simulator,
    whitebox,
) = sq.predefined.generate_mock_experiment_data(
    key=data_key,
    sample_size=sample_size,
    shots=shots,
    strategy=sq.predefined.SimulationStrategy.SHOT,
    detune=detune,
    get_pulse_sequence_fn=get_pulse_sequence_fn,
    get_qubit_information_fn=sq.predefined.get_mock_qubit_information,
)

# Package the raw experiment data (with its pulse sequence and whitebox) into
# the form consumed by the training utilities.
loaded_data = sq.utils.prepare_data(
    exp_data=exp_data, pulse_sequence=pulse_sequence, whitebox=whitebox
)

# --- Hyperparameter tuning ---------------------------------------------------
# Loss metric shared by the trainable and by best-trial selection, so both
# rank models the same way.
metric = sq.model.LossMetric.WAEE

# Define the trainable function evaluated for each hyperparameter trial.
trainable = sq.optimize.default_trainable_v3(
    pulse_sequence=loaded_data.pulse_sequence,
    metric=metric,
    experiment_identifier="test",
    hamiltonian=sq.predefined.rotating_transmon_hamiltonian,
)

# Run the search with the Optuna search algorithm.
results = sq.optimize.hypertuner(
    trainable=trainable,
    pulse_parameters=loaded_data.pulse_parameters,
    unitaries=loaded_data.unitaries,
    expectation_values=loaded_data.expectation_values,
    train_key=train_key,
    num_samples=num_tuning_samples,
    search_algo=sq.optimize.SearchAlgo.OPTUNA,
    metric=metric,
)

# Pick the best trial (presumably by `metric`) and unpack its model state,
# training history and data configuration.
model_state, train_hist, data_config = sq.optimize.get_best_hypertuner_results(
    results, metric=metric
)

0,1
Current time:,2024-12-06 18:16:04
Running for:,00:01:42.54
Memory:,36.2/64.0 GiB

Trial name,status,loc,hidden_layer_1_1,hidden_layer_1_2,hidden_layer_2_1,hidden_layer_2_2,iter,total time (s),val/MSE[E],val/AE[F],val/WAE[E]
trainable_e5d99e81,TERMINATED,127.0.0.1:87271,23,43,44,26,1001,73.5491,0.000278764,0.00110606,0.0442366
trainable_36ae63f4,TERMINATED,127.0.0.1:87305,15,18,30,16,1001,72.4405,0.000245542,0.00110266,0.0396225
trainable_6f9ff8dc,TERMINATED,127.0.0.1:87341,9,14,13,21,1001,74.8515,0.000245345,0.00110394,0.039513
trainable_040c1241,TERMINATED,127.0.0.1:87374,13,20,40,39,1001,74.9713,0.00027257,0.00110171,0.0419242
trainable_2f7f3d3a,TERMINATED,127.0.0.1:87404,36,6,13,29,1001,75.3379,0.000243729,0.00110347,0.0392391
trainable_5bcf9ef2,TERMINATED,127.0.0.1:87441,36,36,13,22,1001,75.0287,0.00024992,0.00111679,0.0400251
trainable_9c475bc8,TERMINATED,127.0.0.1:87489,24,10,34,33,1001,74.6081,0.000246143,0.00110673,0.0402256
trainable_f71aa40a,TERMINATED,127.0.0.1:87518,13,22,7,37,1001,75.4412,0.000239674,0.00110556,0.0393171
trainable_2f1711a5,TERMINATED,127.0.0.1:87547,37,38,9,21,1001,75.5121,0.000244614,0.00110332,0.0393406
trainable_c5b31b1d,TERMINATED,127.0.0.1:87613,14,18,6,29,1001,75.3941,0.000250129,0.00110455,0.0397181


[36m(trainable pid=87271)[0m   Q = jnp.zeros(U.shape[:-1] + (2, 2), dtype=jnp.complexfloating)
[36m(trainable pid=87271)[0m   Diag = jnp.zeros(D.shape[:-1] + (2, 2), dtype=jnp.complexfloating)
[36m(trainable pid=87271)[0m   Q = jnp.zeros(U.shape[:-1] + (2, 2), dtype=jnp.complexfloating)
[36m(trainable pid=87271)[0m   Diag = jnp.zeros(D.shape[:-1] + (2, 2), dtype=jnp.complexfloating)
[36m(trainable pid=87271)[0m   Q = jnp.zeros(U.shape[:-1] + (2, 2), dtype=jnp.complexfloating)
[36m(trainable pid=87271)[0m   Diag = jnp.zeros(D.shape[:-1] + (2, 2), dtype=jnp.complexfloating)
[36m(trainable pid=87271)[0m   Q = jnp.zeros(U.shape[:-1] + (2, 2), dtype=jnp.complexfloating)
[36m(trainable pid=87271)[0m   Diag = jnp.zeros(D.shape[:-1] + (2, 2), dtype=jnp.complexfloating)
[36m(trainable pid=87374)[0m   Q = jnp.zeros(U.shape[:-1] + (2, 2), dtype=jnp.complexfloating)[32m [repeated 12x across cluster][0m
[36m(trainable pid=87374)[0m   Diag = jnp.zeros(D.shape[:-1] + (2, 2), dty

In [None]:
# Persist the model selected by the hyperparameter search to the "ckpt"
# directory, then reload it from the path that `save_model` returns.
checkpoint_fields = dict(
    path="ckpt",
    experiment_identifier="test",
    pulse_sequence=loaded_data.pulse_sequence,
    hamiltonian=data_config.hamiltonian,
    model_config=model_state.model_config,
    model_params=model_state.model_params,
    history=train_hist,
)
save_path = sq.model.save_model(**checkpoint_fields)

# Round-trip: read the checkpoint straight back from disk.
loaded_model = sq.model.load_model(save_path)