In [1]:
import jax
import inspeqtor.experimental as sq
from helper import get_data_model, custom_feature_map
from ray import tune

In [2]:
key = jax.random.key(0)
key, data_key, model_key, train_key, gate_optim_key = jax.random.split(key, 5)
sample_size = 1000
shots = 3000
data_model = get_data_model()
qubit_info = sq.predefined.get_mock_qubit_information()
whitebox = sq.predefined.get_single_qubit_whitebox(
    hamiltonian=data_model.ideal_hamiltonian,
    control_sequence=data_model.control_sequence,
    qubit_info=qubit_info,
    dt=data_model.dt,
)

# NOTE pick the pulse sequence you want to use
# def get_control_sequence_fn():
#         return sq.predefined.get_drag_control_sequence(qubit_info)

get_control_sequence_fn = sq.predefined.get_multi_drag_control_sequence_v3

# NOTE: Simulate the experiment with some detuning noise
exp_data, control_sequence, unitaries, noisy_simulator = (
    sq.predefined.generate_experimental_data(
        key=data_key,
        hamiltonian=data_model.total_hamiltonian,
        sample_size=sample_size,
        shots=shots,
        strategy=sq.predefined.SimulationStrategy.SHOT,
        get_qubit_information_fn=lambda: data_model.qubit_information,
        get_control_sequence_fn=lambda: data_model.control_sequence,
    )
)

# Prepare the data for training
loaded_data = sq.utils.prepare_data(
    exp_data=exp_data, control_sequence=control_sequence, whitebox=whitebox
)

model_constructor = sq.model.make_basic_blackbox_model(
    # unitary_activation_fn=lambda x: 2 * jnp.pi * (jnp.cos(x) + 1) / 2,
    # diagonal_activation_fn=lambda x: jnp.cos(x),
    # unitary_activation_fn = lambda x: (2 * jnp.pi * nn.hard_sigmoid(x)) + 1e-3,
    # diagonal_activation_fn = lambda x: ((2 * nn.hard_sigmoid(x)) - 1) + 1e-3,
)

# Choose the loss metric
metric = sq.model.LossMetric.WAEE
# Define trainanle function for hyperparameter tuning
trainable = sq.optimize.default_trainable_v4(
    control_sequence=loaded_data.control_sequence,
    metric=metric,
    experiment_identifier="test",
    hamiltonian=sq.predefined.rotating_transmon_hamiltonian,
    construct_model_fn=lambda x: sq.model.construct_wo_model_from_config(
        x, model_constructor
    ),
    calculate_metrics_fn=sq.model.calculate_metrics,
)

key = jax.random.key(0)
key, random_split_key_1, random_split_key_2, train_key, prediction_key = (
    jax.random.split(key, 5)
)
(
    train_p,
    train_u,
    train_e,
    eval_p,
    eval_u,
    eval_ex,
) = sq.utils.random_split(
    random_split_key_1,
    20,  # Test size
    loaded_data.control_parameters,
    loaded_data.unitaries,
    loaded_data.expectation_values,
)

(val_p, val_u, val_ex, test_p, test_u, test_ex) = sq.utils.random_split(
    random_split_key_2, 10, eval_p, eval_u, eval_ex
)

train_data = sq.optimize.DataBundled(custom_feature_map(train_p), train_u, train_e)
val_data = sq.optimize.DataBundled(custom_feature_map(val_p), val_u, val_ex)
test_data = sq.optimize.DataBundled(custom_feature_map(test_p), test_u, test_ex)

# Hyperparameter tuning
results = sq.optimize.hypertuner(
    trainable=trainable,
    train_data=train_data,
    test_data=test_data,
    val_data=val_data,
    train_key=train_key,
    num_samples=10,
    search_algo=sq.optimize.SearchAlgo.OPTUNA,
    metric=metric,
    search_space={
        "hidden_layer_1_1": tune.randint(0, 1),  # (0, 1) means no hidden layer
        "hidden_layer_1_2": tune.randint(0, 1),  # (0, 1) means no hidden layer
        "hidden_layer_2_1": tune.randint(0, 1),  # (0, 1) means no hidden layer
        "hidden_layer_2_2": tune.randint(4, 5),
    },
)

# Get the best hyperparameters
model_state, train_hist, data_config = sq.optimize.get_best_hypertuner_results(
    results, metric=metric
)

0,1
Current time:,2025-06-26 01:32:49
Running for:,00:10:23.94
Memory:,28.5/64.0 GiB

Trial name,status,loc,hidden_layer_1_1,hidden_layer_1_2,hidden_layer_2_1,hidden_layer_2_2,iter,total time (s),train/MSE[E],train/AE[F],train/WAE[E]
trainable_acf4d6fb,TERMINATED,127.0.0.1:89395,0,0,0,4,1001,595.416,0.000238776,0.00161791,0.0458556
trainable_c7eb8b3e,TERMINATED,127.0.0.1:89413,0,0,0,4,1001,593.861,0.000238776,0.00161791,0.0458556
trainable_82cfb7b1,TERMINATED,127.0.0.1:89425,0,0,0,4,1001,594.262,0.000238776,0.00161791,0.0458556
trainable_a1da021a,TERMINATED,127.0.0.1:89440,0,0,0,4,1001,599.626,0.000238776,0.00161791,0.0458556
trainable_c625f3ed,TERMINATED,127.0.0.1:89452,0,0,0,4,1001,594.15,0.000238776,0.00161791,0.0458556
trainable_291098d4,TERMINATED,127.0.0.1:89468,0,0,0,4,1001,590.62,0.000238776,0.00161791,0.0458556
trainable_0126af9c,TERMINATED,127.0.0.1:89496,0,0,0,4,1001,595.898,0.000238776,0.00161791,0.0458556
trainable_cd68519b,TERMINATED,127.0.0.1:89506,0,0,0,4,1001,596.411,0.000238776,0.00161791,0.0458556
trainable_cb648ad8,TERMINATED,127.0.0.1:89520,0,0,0,4,1001,600.154,0.000238776,0.00161791,0.0458556
trainable_ebf2447b,TERMINATED,127.0.0.1:89536,0,0,0,4,1001,593.284,0.000238776,0.00161791,0.0458556


[36m(trainable pid=89395)[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/Users/porametpathumsoot/ray_results/tune_experiment/trainable_acf4d6fb_1_hidden_layer_1_1=0,hidden_layer_1_2=0,hidden_layer_2_1=0,hidden_layer_2_2=4_2025-06-26_01-22-25/checkpoint_000000)
[36m(trainable pid=89425)[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/Users/porametpathumsoot/ray_results/tune_experiment/trainable_82cfb7b1_3_hidden_layer_1_1=0,hidden_layer_1_2=0,hidden_layer_2_1=0,hidden_layer_2_2=4_2025-06-26_01-22-29/checkpoint_000000)[32m [repeated 2x across cluster][0m
[36m(trainable pid=89395)[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/Users/porametpathumsoot/ray_results/tune_experiment/trainable_acf4d6fb_1_hidden_layer_1_1=0,hidden_layer_1_2=0,hidden_layer_2_1=0,hidden_layer_2_2=4_2025-06-26_01-22-25/checkpoint_000001)[32m [repeated 2x across cluster][0m
[36m(trainable pid=89413)[0m Checkpoint successfully

In [3]:
# Save model and load model
save_path = sq.model.save_model(
    path="ckpt",
    experiment_identifier="test",
    control_sequence=loaded_data.control_sequence,
    hamiltonian=data_config.hamiltonian,
    model_config=model_state.model_config,
    model_params=model_state.model_params,
    history=train_hist,
)

loaded_model = sq.model.load_model(save_path)