In [1]:
import jax
import inspeqtor.experimental as sq
from helper import get_data_model, custom_feature_map
from ray import tune

In [None]:
key = jax.random.key(0)
key, data_key, model_key, train_key, gate_optim_key = jax.random.split(key, 5)
sample_size = 1000
shots = 3000
data_model = get_data_model()
qubit_info = sq.predefined.get_mock_qubit_information()
whitebox = sq.predefined.get_single_qubit_whitebox(
    hamiltonian=data_model.ideal_hamiltonian,
    control_sequence=data_model.control_sequence,
    qubit_info=qubit_info,
    dt=data_model.dt,
)

# NOTE pick the pulse sequence you want to use
# def get_control_sequence_fn():
#         return sq.predefined.get_drag_control_sequence(qubit_info)

get_control_sequence_fn = sq.predefined.get_multi_drag_control_sequence_v3

# NOTE: Simulate the experiment with some detuning noise
exp_data, control_sequence, unitaries, noisy_simulator = (
    sq.predefined.generate_experimental_data(
        key=data_key,
        hamiltonian=data_model.total_hamiltonian,
        sample_size=sample_size,
        shots=shots,
        strategy=sq.predefined.SimulationStrategy.SHOT,
        get_qubit_information_fn=lambda: data_model.qubit_information,
        get_control_sequence_fn=lambda: data_model.control_sequence,
    )
)

# Prepare the data for training
loaded_data = sq.utils.prepare_data(
    exp_data=exp_data, control_sequence=control_sequence, whitebox=whitebox
)

model_constructor = sq.model.make_basic_blackbox_model(
    # unitary_activation_fn=lambda x: 2 * jnp.pi * (jnp.cos(x) + 1) / 2,
    # diagonal_activation_fn=lambda x: jnp.cos(x),
    # unitary_activation_fn = lambda x: (2 * jnp.pi * nn.hard_sigmoid(x)) + 1e-3,
    # diagonal_activation_fn = lambda x: ((2 * nn.hard_sigmoid(x)) - 1) + 1e-3,
)

# Choose the loss metric
metric = sq.model.LossMetric.WAEE
# Define trainanle function for hyperparameter tuning
trainable = sq.optimize.default_trainable_v4(
    control_sequence=loaded_data.control_sequence,
    metric=metric,
    experiment_identifier="test",
    hamiltonian=sq.predefined.rotating_transmon_hamiltonian,
    construct_model_fn=lambda x: sq.model.construct_wo_model_from_config(
        x, model_constructor
    ),
    calculate_metrics_fn=sq.model.calculate_metrics,
)

key = jax.random.key(0)
key, random_split_key_1, random_split_key_2, train_key, prediction_key = (
    jax.random.split(key, 5)
)
(
    train_p,
    train_u,
    train_e,
    eval_p,
    eval_u,
    eval_ex,
) = sq.utils.random_split(
    random_split_key_1,
    20,  # Test size
    loaded_data.pulse_parameters,
    loaded_data.unitaries,
    loaded_data.expectation_values,
)

(val_p, val_u, val_ex, test_p, test_u, test_ex) = sq.utils.random_split(
    random_split_key_2, 10, eval_p, eval_u, eval_ex
)

# Hyperparameter tuning
results = sq.optimize.hypertuner(
    trainable=trainable,
    train_pulse_parameters=custom_feature_map(train_p),
    train_unitaries=train_u,
    train_expectation_values=train_e,
    test_pulse_parameters=custom_feature_map(test_p),
    test_unitaries=test_u,
    test_expectation_values=test_ex,
    val_pulse_parameters=custom_feature_map(val_p),
    val_unitaries=val_u,
    val_expectation_values=val_ex,
    train_key=train_key,
    num_samples=10,
    search_algo=sq.optimize.SearchAlgo.OPTUNA,
    metric=metric,
    search_space={
        "hidden_layer_1_1": tune.randint(0, 1),  # (0, 1) means no hidden layer
        "hidden_layer_1_2": tune.randint(0, 1),  # (0, 1) means no hidden layer
        "hidden_layer_2_1": tune.randint(0, 1),  # (0, 1) means no hidden layer
        "hidden_layer_2_2": tune.randint(4, 5),
    },
)

# Get the best hyperparameters
model_state, train_hist, data_config = sq.optimize.get_best_hypertuner_results(
    results, metric=metric
)

0,1
Current time:,2025-04-16 22:14:46
Running for:,00:11:10.20
Memory:,30.5/64.0 GiB

Trial name,status,loc,hidden_layer_1_1,hidden_layer_1_2,hidden_layer_2_1,hidden_layer_2_2,iter,total time (s),train/MSE[E],train/AE[F],train/WAE[E]
trainable_ec2ce321,TERMINATED,127.0.0.1:8800,0,0,0,4,1001,632.21,0.000238892,0.00161811,0.0458585
trainable_c7fa8972,TERMINATED,127.0.0.1:8836,0,0,0,4,1001,637.341,0.000238892,0.00161811,0.0458585
trainable_721fb803,TERMINATED,127.0.0.1:8879,0,0,0,4,1001,639.467,0.000238892,0.00161811,0.0458585
trainable_88f2caa0,TERMINATED,127.0.0.1:8931,0,0,0,4,1001,640.274,0.000238892,0.00161811,0.0458585
trainable_5d3d25ba,TERMINATED,127.0.0.1:8968,0,0,0,4,1001,639.508,0.000238892,0.00161811,0.0458585
trainable_d8d6734f,TERMINATED,127.0.0.1:9008,0,0,0,4,1001,646.261,0.000238892,0.00161811,0.0458585
trainable_690a2b12,TERMINATED,127.0.0.1:9052,0,0,0,4,1001,648.292,0.000238892,0.00161811,0.0458585
trainable_bdc8bf4a,TERMINATED,127.0.0.1:9092,0,0,0,4,1001,646.51,0.000238892,0.00161811,0.0458585
trainable_75472892,TERMINATED,127.0.0.1:9167,0,0,0,4,1001,642.626,0.000238892,0.00161811,0.0458585
trainable_46b1e533,TERMINATED,127.0.0.1:9228,0,0,0,4,1001,634.753,0.000238892,0.00161811,0.0458585


2025-04-16 22:14:46,340	INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/Users/porametpathumsoot/ray_results/tune_experiment' in 0.0106s.
2025-04-16 22:14:46,343	INFO tune.py:1041 -- Total run time: 670.21 seconds (670.19 seconds for the tuning loop).
2025-04-16 22:14:46,343	INFO tune.py:1041 -- Total run time: 670.21 seconds (670.19 seconds for the tuning loop).


In [None]:
# Save model and load model
save_path = sq.model.save_model(
    path="ckpt",
    experiment_identifier="test",
    control_sequence=loaded_data.control_sequence,
    hamiltonian=data_config.hamiltonian,
    model_config=model_state.model_config,
    model_params=model_state.model_params,
    history=train_hist,
)

loaded_model = sq.model.load_model(save_path)