In [None]:
import typing
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.typing import VariableDict
import flax.traverse_util as traverse_util
from functools import partial
import optax  # type: ignore

import tempfile
from enum import StrEnum

import inspeqtor.experimental as sq
from ray.tune.search.sample import Domain
from helper import get_data_model, custom_feature_map

from ray import tune

In [None]:
def transform_key(data):
    """Collapse tuple keys into '/'-separated string keys.

    Args:
        data: Mapping whose keys are tuples of strings, such as the output of
            ``flax.traverse_util.flatten_dict``.

    Returns:
        dict: A new dictionary with the same values, keyed by ``"/".join(key)``.
    """
    joined = {}
    for path, leaf in data.items():
        # Concatenate the key parts with '/'
        joined["/".join(path)] = leaf
    return joined


def clean_history_entries(
    histories: list[sq.optimize.HistoryEntryV3],
):
    """Turn training-history entries into flat, plain-Python dictionaries.

    Each entry becomes ``{"step", "loss", "loop", **aux}``. JAX arrays are
    converted to plain Python values, nested dictionaries are flattened, and
    the flattened key tuples are joined with ``"/"`` (via `transform_key`).

    Args:
        histories: History entries exposing ``step``, ``loss``, ``loop`` and an
            ``aux`` mapping.

    Returns:
        list[dict]: One flat dictionary per entry, in the same order.
    """
    clean_histories = [
        {
            "step": history.step,
            "loss": history.loss,
            "loop": history.loop,
            **history.aux,
        }
        for history in histories
    ]
    # Move from device to host as plain Python values. `tolist` returns a
    # Python scalar for 0-d arrays (exactly like `item`) and a nested list for
    # higher-rank arrays, where `item` would raise.
    clean_histories = jax.tree.map(
        lambda x: x.tolist() if isinstance(x, jnp.ndarray) else x, clean_histories
    )
    # Flatten the nested dictionaries into tuple-keyed dictionaries
    clean_histories = list(map(traverse_util.flatten_dict, clean_histories))
    # Join the tuple keys into "a/b/c" strings
    clean_histories = list(map(transform_key, clean_histories))
    return clean_histories


def default_trainable_v4(
    control_sequence: sq.control.ControlSequence,
    metric: sq.model.LossMetric,
    experiment_identifier: str,
    hamiltonian: typing.Callable | str,
    construct_model_fn: typing.Callable[
        [dict[str, int]], tuple[nn.Module, dict[str, typing.Any]]
    ],
    calculate_metrics_fn: typing.Callable,
    NUM_EPOCH: int = 1000,
    CHECKPOINT_EVERY: int = 100,
):
    """Create a trainable function for `ray.tune` hyperparameter tuning.

    Args:
        control_sequence (ControlSequence): Control sequence of the dataset;
            saved alongside every checkpoint.
        metric (LossMetric): Metric to be minimized.
        experiment_identifier (str): The experiment identifier saved with
            every checkpoint.
        hamiltonian (typing.Callable | str): Ideal Hamiltonian function or
            name, saved with every checkpoint.
        construct_model_fn (typing.Callable): Maps a hyperparameter
            configuration to ``(model, model_config)``.
        calculate_metrics_fn (typing.Callable): Forwarded to
            ``sq.model.loss_fn`` as ``calculate_metrics_fn``.
        NUM_EPOCH (int, optional): Number of training epochs. Defaults to 1000.
        CHECKPOINT_EVERY (int, optional): Save a checkpoint every this many
            steps. Defaults to 100.

    Returns:
        typing.Callable: Trainable function that receives the hyperparameter
        configuration, the train/val/test datasets and a random key, and
        returns the final ``"<loop>/<metric>"`` metrics.
    """
    from ray import train

    # Loop names expected in the training history (checked by the callback).
    loops = ("train", "val", "test")

    def trainable(
        config: dict[str, int],
        train_data: sq.optimize.DataBundled,
        val_data: sq.optimize.DataBundled,
        test_data: sq.optimize.DataBundled,
        train_key: jnp.ndarray,
    ):
        # NOTE(review): the optimizer schedule spans 8 * NUM_EPOCH steps;
        # presumably 8 is the number of optimizer steps per epoch — confirm.
        optimizer = sq.optimize.get_default_optimizer(8 * NUM_EPOCH)

        model, model_config = construct_model_fn(config)

        partial_loss_fn = partial(
            sq.model.loss_fn,
            model=model,
            loss_metric=metric,
            calculate_metrics_fn=calculate_metrics_fn,
        )

        # Metrics reported to Ray Tune for each loop, keyed "<loop>/<metric>".
        metric_types = (
            sq.model.LossMetric.MSEE,
            sq.model.LossMetric.AEF,
            sq.model.LossMetric.WAEE,
        )

        def prepare_report(
            history: list[sq.optimize.HistoryEntryV3],
        ) -> dict[str, float]:
            """Collect the metrics of the most recent entry of each loop.

            Walks the history newest-first and keeps only the first entry seen
            per loop. The result therefore matches a forward pass in which later
            entries overwrite earlier ones, without calling ``.item()`` (a
            device-to-host transfer) on every historical entry.
            """
            metrics: dict[str, float] = {}
            seen_loops: set[str] = set()
            for entry in reversed(history):
                if entry.loop in seen_loops:
                    continue
                seen_loops.add(entry.loop)
                for metric_type in metric_types:
                    metrics[f"{entry.loop}/{metric_type}"] = entry.aux[
                        metric_type
                    ].item()
                if seen_loops.issuperset(loops):
                    # All expected loops found; older entries would only be
                    # overwritten in a forward pass.
                    break

            return metrics

        def callback(
            model_params: VariableDict,
            opt_state: optax.OptState,
            history: list[sq.optimize.HistoryEntryV3],
        ) -> None:
            # Get the latest 3 entries
            last_entries = history[-3:]

            # Assert that the latest entries come from train, val, or test
            assert all(entry.loop in loops for entry in last_entries)

            # Prepare the report (latest value of each loop's metrics)
            metrics = prepare_report(history)

            # `step + 1` treats `step` as zero-based: checkpoint after every
            # CHECKPOINT_EVERY completed steps.
            if (last_entries[-1].step + 1) % CHECKPOINT_EVERY == 0:
                # Clean the history entries so they can be saved with the model
                clean_histories = clean_history_entries(history)

                with tempfile.TemporaryDirectory() as tmpdir:
                    _ = sq.model.save_model(
                        path=tmpdir,
                        experiment_identifier=experiment_identifier,
                        control_sequence=control_sequence,
                        hamiltonian=hamiltonian,
                        model_config=model_config,
                        model_params=model_params,
                        history=clean_histories,
                        with_auto_datetime=False,
                    )

                    # Report the metrics together with the checkpoint. Reporting
                    # from a temporary directory is Ray's documented pattern:
                    # the checkpoint is persisted during `report`.
                    train.report(
                        metrics=metrics,
                        checkpoint=train.Checkpoint.from_directory(tmpdir),
                    )
            else:
                # Report the metrics without a checkpoint
                train.report(
                    metrics=metrics,
                )

            return None

        _, _, history = sq.optimize.train_model(
            key=train_key,
            train_data=train_data,
            val_data=val_data,
            test_data=test_data,
            model=model,
            optimizer=optimizer,
            loss_fn=partial_loss_fn,
            NUM_EPOCH=NUM_EPOCH,
            callbacks=[callback],
        )

        # Final metrics from the latest train/val/test entries
        metrics = prepare_report(history[-3:])

        return metrics

    return trainable


def sample_from_search_space(search_space: typing.Mapping[str, Domain]):
    """Draw one concrete configuration from a Ray Tune search space.

    Args:
        search_space: Mapping from hyperparameter name to a Ray Tune ``Domain``.

    Returns:
        dict: Hyperparameter name -> the value returned by that domain's ``sample()``.
    """
    sampled = {}
    for name, domain in search_space.items():
        sampled[name] = domain.sample()
    return sampled


class SearchAlgo(StrEnum):
    HYPEROPT = "hyperopt"
    OPTUNA = "optuna"


def hypertuner(
    trainable: typing.Callable,
    train_data: sq.optimize.DataBundled,
    test_data: sq.optimize.DataBundled,
    val_data: sq.optimize.DataBundled,
    train_key: jnp.ndarray,
    metric: sq.model.LossMetric,
    search_space: typing.Mapping[str, Domain],
    num_samples: int = 100,
    search_algo: SearchAlgo = SearchAlgo.HYPEROPT,
):
    """Perform hyperparameter tuning with Ray Tune.

    Trials are ranked by ``"val/<metric>"`` (minimized), i.e. the metric key
    the trainable reports for the validation loop.

    Args:
        trainable (typing.Callable): Trainable function, e.g. one created by
            `default_trainable_v4`.
        train_data (DataBundled): Training dataset, forwarded to the trainable.
        test_data (DataBundled): Testing dataset, forwarded to the trainable.
        val_data (DataBundled): Validation dataset, forwarded to the trainable.
        train_key (jnp.ndarray): Random key forwarded to the trainable.
        metric (LossMetric): Metric to minimize.
        search_space (Mapping[str, Domain]): Ray Tune search space of the
            hyperparameters.
        num_samples (int, optional): Number of hyperparameter configurations
            to try. Defaults to 100.
        search_algo (SearchAlgo, optional): The search algorithm to be used.
            Defaults to SearchAlgo.HYPEROPT.

    Returns:
        The result grid returned by ``tune.Tuner.fit()``.

    Raises:
        ValueError: If ``search_algo`` is not a supported `SearchAlgo`.
    """
    from ray import tune
    from ray.tune.search.hyperopt import HyperOptSearch
    from ray.tune.search.optuna import OptunaSearch
    from ray.tune.search import Searcher

    # One randomly drawn configuration, used to seed HyperOpt's first trial.
    initial_points = [sample_from_search_space(search_space)]

    # Prepend 'val/' to the metric: trials are ranked on the validation loop
    prepended_metric = f"val/{metric}"

    search_algo_instance: Searcher
    if search_algo == SearchAlgo.HYPEROPT:
        search_algo_instance = HyperOptSearch(
            metric=prepended_metric,
            mode="min",
            points_to_evaluate=initial_points,
        )
    elif search_algo == SearchAlgo.OPTUNA:
        search_algo_instance = OptunaSearch(
            metric=prepended_metric,
            mode="min",
        )
    else:
        # Without this branch `search_algo_instance` would be unbound below.
        raise ValueError(
            f"Unsupported search algorithm {search_algo!r}; "
            f"expected one of {[algo.value for algo in SearchAlgo]}"
        )

    run_config = tune.RunConfig(
        name="tune_experiment",
        checkpoint_config=tune.CheckpointConfig(
            num_to_keep=10,
        ),
    )

    tuner = tune.Tuner(
        tune.with_parameters(
            trainable,
            train_data=train_data,
            val_data=val_data,
            test_data=test_data,
            train_key=train_key,
        ),
        tune_config=tune.TuneConfig(
            search_alg=search_algo_instance,
            metric=prepended_metric,
            mode="min",
            num_samples=num_samples,
        ),
        param_space=search_space,  # type: ignore
        run_config=run_config,
    )

    results = tuner.fit()

    return results


def get_best_hypertuner_results(
    results, metric: sq.model.LossMetric, loop: str = "val"
):
    """Load the checkpoint of the best trial from a hyperparameter-tuning run.

    Args:
        results: The object returned by `hypertuner` (``tuner.fit()``).
        metric (LossMetric): Loss metric used to rank trials (minimized).
        loop (str, optional): Loop whose metric ranks the trials; the metric
            key looked up is ``"<loop>/<metric>"``. Defaults to "val".

    Returns:
        tuple: ``(model_state, hist, data_config)`` as returned by
        ``sq.model.load_model(..., skip_history=False)``.

    Raises:
        ValueError: If the best trial did not save any checkpoint.
    """
    prepended_metric = f"{loop}/{metric}"

    best_result = results.get_best_result(metric=prepended_metric, mode="min")
    checkpoint = best_result.checkpoint
    if checkpoint is None:
        # Happens e.g. when training stops before the first checkpoint step
        # (NUM_EPOCH < CHECKPOINT_EVERY in the trainable).
        raise ValueError(
            f"The best trial for '{prepended_metric}' did not save any checkpoint."
        )

    with checkpoint.as_directory() as checkpoint_dir:
        model_state, hist, data_config = sq.model.load_model(
            checkpoint_dir, skip_history=False
        )
    return model_state, hist, data_config


In [None]:
key = jax.random.key(0)
key, data_key, model_key, train_key, gate_optim_key = jax.random.split(key, 5)
sample_size = 1000
shots = 3000
data_model = get_data_model()
qubit_info = sq.predefined.get_mock_qubit_information()
# NOTE(review): the whitebox is built from the mock qubit information, while the
# data below is simulated with `data_model.qubit_information` — confirm they agree.
whitebox = sq.predefined.get_single_qubit_whitebox(
    hamiltonian=data_model.ideal_hamiltonian,
    control_sequence=data_model.control_sequence,
    qubit_info=qubit_info,
    dt=data_model.dt,
)

# NOTE: Simulate the experiment with some detuning noise. The data is generated
# with `data_model.control_sequence`, the same control sequence the whitebox
# above uses, so the simulated data and the whitebox stay consistent.
exp_data, control_sequence, unitaries, noisy_simulator = (
    sq.predefined.generate_experimental_data(
        key=data_key,
        hamiltonian=data_model.total_hamiltonian,
        sample_size=sample_size,
        shots=shots,
        strategy=sq.predefined.SimulationStrategy.SHOT,
        get_qubit_information_fn=lambda: data_model.qubit_information,
        get_control_sequence_fn=lambda: data_model.control_sequence,
    )
)

# Prepare the data for training
loaded_data = sq.utils.prepare_data(
    exp_data=exp_data, control_sequence=control_sequence, whitebox=whitebox
)

model_constructor = sq.model.make_basic_blackbox_model(
    # unitary_activation_fn=lambda x: 2 * jnp.pi * (jnp.cos(x) + 1) / 2,
    # diagonal_activation_fn=lambda x: jnp.cos(x),
    # unitary_activation_fn = lambda x: (2 * jnp.pi * nn.hard_sigmoid(x)) + 1e-3,
    # diagonal_activation_fn = lambda x: ((2 * nn.hard_sigmoid(x)) - 1) + 1e-3,
)

# Choose the loss metric
metric = sq.model.LossMetric.WAEE
# Define the trainable function for hyperparameter tuning (defined in the cell above).
# NOTE(review): checkpoints record `rotating_transmon_hamiltonian`, while the
# whitebox uses `data_model.ideal_hamiltonian` — confirm these are the same model.
trainable = default_trainable_v4(
    control_sequence=loaded_data.control_sequence,
    metric=metric,
    experiment_identifier="test",
    hamiltonian=sq.predefined.rotating_transmon_hamiltonian,
    construct_model_fn=lambda x: sq.model.construct_wo_model_from_config(
        x, model_constructor
    ),
    calculate_metrics_fn=sq.model.calculate_metrics,
)

# Keep splitting the existing key instead of re-seeding with `jax.random.key(0)`:
# re-seeding reproduces the keys above, so the dataset split would reuse
# `data_key`, which already drove the data simulation. `train_key` comes from
# the first split (the re-seeded split produced the identical key anyway).
key, random_split_key_1, random_split_key_2, prediction_key = jax.random.split(
    key, 4
)
(
    train_p,
    train_u,
    train_e,
    eval_p,
    eval_u,
    eval_ex,
) = sq.utils.random_split(
    random_split_key_1,
    20,  # presumably the held-out size, split into val/test below
    loaded_data.control_parameters,
    loaded_data.unitaries,
    loaded_data.expectation_values,
)

(val_p, val_u, val_ex, test_p, test_u, test_ex) = sq.utils.random_split(
    random_split_key_2, 10, eval_p, eval_u, eval_ex
)

train_data = sq.optimize.DataBundled(custom_feature_map(train_p), train_u, train_e)
val_data = sq.optimize.DataBundled(custom_feature_map(val_p), val_u, val_ex)
test_data = sq.optimize.DataBundled(custom_feature_map(test_p), test_u, test_ex)

# Hyperparameter tuning (functions defined in the cell above)
results = hypertuner(
    trainable=trainable,
    train_data=train_data,
    test_data=test_data,
    val_data=val_data,
    train_key=train_key,
    num_samples=10,
    search_algo=SearchAlgo.OPTUNA,
    metric=metric,
    # `tune.randint` excludes the upper bound, so randint(0, 1) always samples 0
    # (no hidden layer) and randint(4, 5) always samples 4: this search space
    # contains a single configuration.
    search_space={
        "hidden_layer_1_1": tune.randint(0, 1),
        "hidden_layer_1_2": tune.randint(0, 1),
        "hidden_layer_2_1": tune.randint(0, 1),
        "hidden_layer_2_2": tune.randint(4, 5),
    },
)

# Get the best hyperparameters
model_state, train_hist, data_config = get_best_hypertuner_results(
    results, metric=metric
)

In [None]:
# Persist the best model found by the tuner under "ckpt", then read it back
# from the returned path to check that the saved checkpoint can be loaded.
save_path = sq.model.save_model(
    path="ckpt",
    experiment_identifier="test",
    # Model architecture and trained parameters of the best trial
    model_config=model_state.model_config,
    model_params=model_state.model_params,
    # Experiment metadata carried over from the loaded checkpoint / dataset
    hamiltonian=data_config.hamiltonian,
    control_sequence=loaded_data.control_sequence,
    history=train_hist,
)

loaded_model = sq.model.load_model(save_path)