In [3]:
import jax

jax.config.update("jax_enable_x64", True)
from torch.utils.data import DataLoader, Subset
import optax
import jax.numpy as jnp

from specq_jax.core import (
    SpecQDataset,
    # create_train_step,
    # create_train_step_v2,
    # loss as loss_fn,
    # loss_v2 as loss_fn_v2,
    gate_loss,
    X,
    Y,
    Z,
    calculate_expvals,
    plot_expvals,
    gate_fidelity,
    Wo_2_level,
    rotating_transmon_hamiltonian,
    batched_calculate_expectation_value,
    batch_mse
)
from specq_jax.data import load_data

from specq_jax.model import BasicBlackBox
from exp_data_0020 import get_multi_drag_pulse_sequence
import specq_dev.specq.shared as specq

from jaxopt import ProjectedGradient
from jaxopt.projection import projection_box
import pandas as pd
import matplotlib.pyplot as plt
from alive_progress import alive_bar, alive_it

from torch import Generator, manual_seed

from flax.training.train_state import TrainState
from jaxtyping import Array, Complex, Float
from flax import linen as nn
from typing import Callable

In [28]:
def loss_v2(
    state: TrainState,
    pulse_parameters: Float[Array, "batch num_pulses num_features"],  # noqa: F722
    unitaries: Complex[Array, "batch dim dim"],  # noqa: F722
    expectation_values: Complex[Array, "batch num_expectations"],  # noqa: F722
    evaluate_expectation_values: list[
        specq.ExpectationValue
    ] = specq.default_expectation_values,
):
    # Predict Vo for each pauli operator from paluse parameters
    Wos_params = state.apply_fn(state.params, pulse_parameters)

    predict_expectation_values = []

    # Calculate expectation values for all cases
    for idx, exp_case in enumerate(evaluate_expectation_values):
        Wo = jax.vmap(Wo_2_level, in_axes=(0, 0))(
            Wos_params[exp_case.observable]["U"], Wos_params[exp_case.observable]["D"]
        )
        # Calculate expectation value for each pauli operator
        batch_expectaion_values = batched_calculate_expectation_value(
            unitaries,
            Wo,
            jnp.array(exp_case.initial_statevector),
        )

        predict_expectation_values.append(batch_expectaion_values)

    return jnp.mean(
        batch_mse(expectation_values, jnp.array(predict_expectation_values).T)
    )


def create_train_step_v2(
    key: jnp.ndarray,
    model: nn.Module,
    optimiser: optax.GradientTransformation,
    loss_fn: Callable[
        [
            TrainState,
            Float[Array, "batch num_pulses num_features"],  # noqa: F722
            Complex[Array, "batch dim dim"],  # noqa: F722
            Complex[Array, "batch num_expectations"],  # noqa: F722
        ],
        Float[Array, "1"],
    ],
    input_shape: Array,
):
    params = model.init(
        key,
        jnp.ones(input_shape, jnp.float32),
    )  # dummy key just as example input

    # opt_state = optimiser.init(params)

    state = TrainState.create(apply_fn=model.apply, params=params, tx=optimiser)

    @jax.jit
    def train_step(
        state: TrainState,
        pulse_parameters: Float[Array, "batch num_pulses num_features"],  # noqa: F722
        unitaries: Complex[Array, "batch dim dim"],  # noqa: F722
        expectations: Complex[Array, "batch num_expectations"],  # noqa: F722
    ):
        loss, grads = jax.value_and_grad(loss_fn)(
            state,
            pulse_parameters,
            unitaries,
            expectations,
        )

        state = state.apply_gradients(grads=grads)

        return state, loss

    @jax.jit
    def test_step(
        state: TrainState,
        pulse_parameters: Float[Array, "batch num_pulses num_features"],  # noqa: F722
        unitaries: Complex[Array, "batch dim dim"],  # noqa: F722
        expectations: Complex[Array, "batch num_expectations"],  # noqa: F722
    ):
        loss = loss_fn(
            state,
            pulse_parameters,
            unitaries,
            expectations,
        )
        return loss

    return train_step, test_step, state


def with_validation_train(
    train_dataloader: DataLoader,
    val_dataloader: DataLoader,
    train_step,
    test_step,
    state,
    num_epochs=1250,
):

    history = []
    total_len = len(train_dataloader)

    NUM_EPOCHS = num_epochs

    with alive_bar(int(NUM_EPOCHS * total_len), force_tty=True) as bar:
        for epoch in range(NUM_EPOCHS):
            total_loss = 0.0
            for i, batch in enumerate(train_dataloader):

                _pulse_parameters = batch["x0"].numpy()
                _unitaries = batch["x1"].numpy()
                _expectations = batch["y"].numpy()

                state, loss = train_step(
                    state, _pulse_parameters, _unitaries, _expectations
                )

                history.append(
                    {
                        "epoch": epoch,
                        "step": i,
                        "loss": float(loss),
                        "global_step": epoch * total_len + i,
                        "val_loss": None,
                    }
                )

                total_loss += loss

                bar()

            # Validation
            val_loss = 0.0
            for i, batch in enumerate(val_dataloader):

                _pulse_parameters = batch["x0"].numpy()
                _unitaries = batch["x1"].numpy()
                _expectations = batch["y"].numpy()

                val_loss += test_step(
                    state, _pulse_parameters, _unitaries, _expectations
                )

            history[-1]["val_loss"] = float(val_loss / len(val_dataloader))

    return state, history


def plot_history(history, lr_scheduler):

    hist_df = pd.DataFrame(history)
    train = hist_df[["global_step", "loss"]].values

    train_x = train[:, 0]
    train_y = train[:, 1]

    validate = hist_df[["global_step", "val_loss"]].replace(0, jnp.nan).dropna().values

    validate_x = validate[:, 0]
    validate_y = validate[:, 1]
    # The second plot has height ratio 2
    fig, ax = plt.subplots(2, 1, figsize=(10, 6), sharex=True, height_ratios=[3, 1])

    # The first plot is the training loss and the validation loss
    ax[0].plot(train_x, train_y, label="train_loss")
    ax[0].plot(validate_x, validate_y, ".-", label="val_loss")
    ax[0].set_yscale("log")

    # plot the horizontal line [1e-3, 1e-2]
    ax[0].axhline(1e-3, color="red", linestyle="--")
    ax[0].axhline(1e-2, color="red", linestyle="--")

    # The second plot is the learning rate
    lr = lr_scheduler(train_x)
    ax[1].plot(train_x, lr, label="learning_rate")
    ax[1].set_yscale("log")

    # for thred in [1e-3, 1e-2, 1e-5, 1e-4]:
    #     ax[1].axhline(thred, color="red", linestyle="--")

    ax[0].legend()
    ax[1].legend()

    fig.tight_layout()

    return fig, ax


def optimize(x0, lower, upper, fun):

    pg = ProjectedGradient(fun=fun, projection=projection_box)
    opt_params, state = pg.run(jnp.array(x0), hyperparams_proj=(lower, upper))

    return opt_params, state

In [3]:
exp_data, pulse_parameters, unitaries, expectations, pulse_sequence, simulator = (
    load_data(
        "../../specq-experiment/datasets/0020",
        get_multi_drag_pulse_sequence,
        rotating_transmon_hamiltonian,
    )
)

In [29]:
start_idx, end_idx = 0, 1500

# Final goal of setting up is to create a dataset and a dataloader
dataset = SpecQDataset(
    pulse_parameters=pulse_parameters[start_idx: end_idx],
    unitaries=unitaries[start_idx: end_idx],
    expectation_values=expectations[start_idx: end_idx],
)

batch_size = 150
# Randomly split dataset into training and validation
key = jax.random.PRNGKey(0)
val_indices = jax.random.choice(
    key, len(dataset), (int(0.2 * len(dataset)),), replace=False
).tolist()

training_indices = list(
    set([i for i in range(len(dataset))]) - set(val_indices)
)

train_dataset = Subset(dataset, training_indices)
val_dataset = Subset(dataset, val_indices)

g = Generator()
g.manual_seed(0)

# dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, generator=g)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, generator=g)

# len(train_dataloader), len(val_dataloader)

key, model_key = jax.random.split(jax.random.PRNGKey(0))
model = BasicBlackBox(feature_size=5)
# model = ParallelBlackBox(hidden_sizes=(20, 10))
# optimiser = optax.adam(learning_rate=1e-3)

warmup_start_lr, warmup_steps = 1e-6, 1000
start_lr, end_lr, steps = 1e-2, 1e-5, 10_000
lr_scheduler = optax.join_schedules(
    [
        optax.linear_schedule(
            warmup_start_lr,
            start_lr,
            warmup_steps,
        ),
        optax.linear_schedule(
            start_lr,
            end_lr,
            steps - warmup_steps,
        ),
    ],
    [warmup_steps],
)

optimiser = optax.adam(lr_scheduler)

train_step, test_step, state = create_train_step_v2(
    key=model_key,
    model=model,
    optimiser=optimiser,
    loss_fn=loss_v2,
    input_shape=(batch_size, pulse_parameters.shape[1]),
)

In [30]:
state, history = with_validation_train(
    train_dataloader, val_dataloader, train_step, test_step, state
)

|⚠︎                                       | (!) 0/10000 [0%] in 0.5s (0.00/s)    


TypeError: argument of type 'TrainState' is not iterable

In [22]:
# A simple model with one linear layer.
key1, key2 = jax.random.split(jax.random.key(0))
x1 = jax.random.normal(key1, (5,))  # A simple JAX array.
model = nn.Dense(features=3)
variables = model.init(key2, x1)

# Flax's TrainState is a pytree dataclass and is supported in checkpointing.
# Define your class with `@flax.struct.dataclass` decorator to make it compatible.
tx = optax.sgd(learning_rate=0.001)  # An Optax SGD optimizer.
state = TrainState.create(apply_fn=model.apply, params=variables["params"], tx=tx)

loss_fn = lambda params, x: jnp.sum(state.apply_fn({"params": params}, x))
loss, grad = jax.value_and_grad(loss_fn)(state.params, x1)

# Perform a simple gradient update similar to the one during a normal training workflow.
state = state.apply_gradients(grads=grad)
# state = state.apply_gradients(grads=jax.tree_util.tree_map(jnp.ones_like, state.params))

# Some arbitrary nested pytree with a dictionary and a NumPy array.
config = {"dimensions": jnp.array([5, 3])}

# Bundle everything together.
ckpt = {"model": state, "config": config, "data": [x1]}
ckpt

{'model': TrainState(step=1, apply_fn=<bound method Module.apply of Dense(
     # attributes
     features = 3
     use_bias = True
     dtype = None
     param_dtype = float32
     precision = None
     kernel_init = init
     bias_init = zeros
     dot_general = None
     dot_general_cls = None
 )>, params={'bias': Array([-0.001, -0.001, -0.001], dtype=float32), 'kernel': Array([[ 0.2641001 , -0.6103823 , -0.23097458],
        [ 0.11147016, -0.87561315,  0.9810296 ],
        [ 0.36252323,  0.18267715, -0.6856925 ],
        [-0.849457  , -0.63919145, -0.4793319 ],
        [-0.68809915, -0.33936214, -0.05847799]], dtype=float32)}, tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x361395760>, update=<function chain.<locals>.update_fn at 0x361395940>), opt_state=(EmptyState(), EmptyState())),
 'config': {'dimensions': Array([5, 3], dtype=int64)},
 'data': [Array([-2.61055618,  0.03385296,  1.08633353, -1.48029861,  0.4889569 ],      dtype=float64)]}

In [17]:
state = TrainState.create(apply_fn=model.apply, params=variables["params"], tx=tx)


state.apply_fn(state.params, x1)

ApplyScopeInvalidVariablesTypeError: The first argument passed to an apply function should be a dictionary of collections. Each collection should be a dictionary with string keys. (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ApplyScopeInvalidVariablesTypeError)

In [19]:
model.apply(variables, x1)

Array([0.63957241, 2.55293094, 0.58267994], dtype=float64)