In [1]:
import jax 
import jax.numpy as jnp

import qontrol as ql 
import equinox as eqx
import dynamiqs as dq
import optimistix as optx

from jax import Array
from qontrol.cost import Cost
from qontrol.model import Model
from dynamiqs.method import Method as dq_Method
from dynamiqs import Options as dq_Options

A simple optimization loop, modeled on the Optimistix interactive-solve example: https://docs.kidger.site/optimistix/examples/interactive/

**TODO:** replace this hand-written loop with a more robust one based on Optimistix's own iteration driver: https://github.com/patrick-kidger/optimistix/blob/main/optimistix/_iterate.py

In [7]:
def loss(
    parameters: Array | dict,
    costs:Cost,
    model: Model,
    method: dq_Method,
    dq_options: dq_Options,
) -> [float, Array]:

    result, H = model(parameters, method, dq_options)
    cost_values, terminate = zip(*costs(result, H, parameters), strict=True)
    total_cost = jnp.log(jt_reduce(jnp.add, cost_values))
    expects = result.expects if hasattr(result, 'expects') else None

    return total_cost, (total_cost, cost_values, terminate, expects)

def _loss(_params, _args):
    """Two-argument `(params, args)` adapter around `loss`.

    `_args` is indexed as `(costs, model)`. The candidate `_params` handed in by
    the caller is what gets evaluated (not a notebook-level variable).

    NOTE: `method` and `dq_options` are read from the notebook's global
    namespace at call time, so both must be defined before this is called.
    """
    return loss(_params, _args[0], _args[1], method, dq_options)

In [8]:
def optimize(
    parameters,
    costs,
    model,
    optimizer,
    method,
    dq_options,
    epochs=100,
):
    """Run a manual, step-by-step optimization loop and return the result.

    Args:
        parameters: Initial optimization variables.
        costs: Cost callable, forwarded to `loss`.
        model: Model callable, forwarded to `loss`.
        optimizer: Solver object exposing `init(...)` and `step(...)`; it is
            driven one step per epoch. Presumably an Optimistix minimiser --
            confirm against the caller.
        method: Forwarded to `loss` on every evaluation.
        dq_options: Forwarded to `loss` on every evaluation.
        epochs: Maximum number of solver steps.

    Returns:
        The parameters after the last executed step (the initial parameters
        when `epochs` is 0).
    """

    def objective(y, args):
        # Bind the `method` / `dq_options` passed to *this* call rather than
        # whatever happens to live in the notebook's global namespace.
        return loss(y, args[0], args[1], method, dq_options)

    args = (costs, model)
    options = {}
    tags = frozenset()

    # filter_eval_shape tolerates non-array leaves (e.g. callables) inside `args`.
    f_struct, aux_struct = eqx.filter_eval_shape(objective, parameters, args)
    state = optimizer.init(objective, parameters, args, options, f_struct, aux_struct, tags)
    step = eqx.filter_jit(
        eqx.Partial(optimizer.step, fn=objective, options=options, tags=tags)
    )

    for epoch in range(epochs):
        # Keyword arguments are required here: `fn`, `options` and `tags` are
        # already bound by keyword, so positional arguments would collide with `fn`.
        parameters, state, aux = step(y=parameters, state=state, args=args)
        total_cost, cost_values, terminate, _expects = aux

        print(f"Epoch {epoch+1}, Total Cost: {total_cost}, Individual Costs: {cost_values}")
        if any(terminate):
            print("Termination condition met. Stopping optimization.")
            break

    return parameters