# General framework for the optimization

Here, we explain the general framework step by step and provide code that is ready to copy and adapt.

### Libraries

For the code, we need to import [JAX](https://jax.readthedocs.io/) (including `jax.numpy`), [jaxopt](https://jaxopt.github.io/) and [icomo](https://icomo.readthedocs.io/en/stable/).

```python
import jax
import jax.numpy as jnp  # used as jnp in the cost function
import jaxopt
import icomo
```

### Model equations

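The dynamics of the system (e.g. disease spread) are described by differential equations. They are written as a Python function with the signature `(t, y, args)` that returns the time derivatives, which icomo then integrates, as in the following SIRS example: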

```python
def SIRS(t, y, args):
    """
    Computes the derivatives for the SIRS (Susceptible-Infected-Recovered-Susceptible) model.

    Parameters:
        t (float): Current time.
        y (dict): Dictionary with current values of compartments "S", "I", and "R".
        args (dict): Dictionary containing the model parameters:
            - beta_t (function): Time-dependent transmission rate function β(t).
            - gamma (float): Recovery rate γ.
            - nu (float): Loss of immunity rate ν.

    Returns:
        dict: Derivatives for "S", "I", and "R" at time t.
    """
    β = args["beta_t"]
    γ = args["gamma"]
    nu = args["nu"]
    dS = -β(t) * y["I"] * y["S"] + nu * y["R"]
    dI = β(t) * y["I"] * y["S"] - γ * y["I"]
    dR = γ * y["I"] - nu * y["R"]
    dy = {"S": dS, "I": dI, "R": dR}
    return dy
```
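Written out, the function implements $\dot S = -\beta(t) I S + \nu R$, $\dot I = \beta(t) I S - \gamma I$ and $\dot R = \gamma I - \nu R$. Because every flow leaves one compartment and enters another, the three derivatives always sum to zero, so the total population is conserved. A quick sanity check of this property, as a minimal sketch with hypothetical parameter values and a constant $\beta$ (not taken from the text):

```python
def SIRS(t, y, args):
    # Same right-hand side as above, with ASCII names
    beta = args["beta_t"]
    gamma = args["gamma"]
    nu = args["nu"]
    dS = -beta(t) * y["I"] * y["S"] + nu * y["R"]
    dI = beta(t) * y["I"] * y["S"] - gamma * y["I"]
    dR = gamma * y["I"] - nu * y["R"]
    return {"S": dS, "I": dI, "R": dR}

# Hypothetical values: constant transmission rate, arbitrary state
args = {"beta_t": lambda t: 0.3, "gamma": 0.1, "nu": 0.01}
y = {"S": 0.9, "I": 0.05, "R": 0.05}

dy = SIRS(0.0, y, args)
total_change = dy["S"] + dy["I"] + dy["R"]  # should vanish up to rounding
print(dy, total_change)
```

A non-zero `total_change` would point to a sign or term error in the model equations.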

### Simulation

For the optimization, we need a function that runs the simulation, i.e. integrates the differential equations for a given control input.

- The optimizer is called without bounds, so it may propose any value between -inf and inf. Bounds therefore have to be enforced inside the simulation, for example with a sigmoid function that maps the values to the interval (0, 1).
- `icomo.interpolate_func` turns arrays of values at given time points into a callable function of time, which is needed because the right-hand side of the differential equations is evaluated at arbitrary times during integration.
- `icomo.diffeqsolve` integrates the differential equations and returns the solution at the requested output times.
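The first two ingredients can be illustrated without icomo. In this minimal sketch, `jnp.interp` (piecewise-linear interpolation) is used only as a dependency-free stand-in for `icomo.interpolate_func`; the knot times and values are hypothetical:

```python
import jax
import jax.numpy as jnp

# Unconstrained values as an optimizer might propose them
x = jnp.array([-5.0, -1.0, 0.0, 1.0, 5.0])

# sigmoid(-x) maps every real number into (0, 1); a larger x gives a smaller
# factor, i.e. a stronger reduction of transmission
frac_reduc = jax.nn.sigmoid(-x)

# Hypothetical control time points and values
t_beta = jnp.array([0.0, 10.0, 20.0])
beta_values = jnp.array([0.3, 0.15, 0.3])

def beta_func(t):
    # Callable beta(t): linear interpolation between the knots
    return jnp.interp(t, t_beta, beta_values)

print(frac_reduc)
print(beta_func(5.0))  # halfway between 0.3 and 0.15
```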

```python
def simulation(x):
    """
    Simulates the SIRS model dynamics under time-dependent mitigation measures.

    Uses the globals beta_0_t, t_beta, I0, gamma, nu and t_out, which must be defined beforehand.

    Parameters:
        x (array-like): Unconstrained control input at the time points t_beta.
                        Transformed via a sigmoid to lie in the interval (0, 1).

    Returns:
        tuple:
            - dict: Time series of the compartments "S", "I", and "R" as computed by the ODE solver.
            - ndarray: Factor multiplying the transmission rate (frac_reduc), interpolated to the output time grid.
    """
    frac_reduc = jax.nn.sigmoid(-x)  # transform unbounded values to (0, 1)
    beta_t = frac_reduc * beta_0_t  # time-dependent beta at the time points t_beta
    beta_t_func = icomo.interpolate_func(ts_in=t_beta, values=beta_t)  # beta is now a callable function of t
    y0 = {'S': 1 - I0, 'I': I0, 'R': 0.0}  # initial conditions
    args = {'gamma': gamma, 'nu': nu, 'beta_t': beta_t_func}  # arguments for the ODE

    output = icomo.diffeqsolve(args=args, ODE=SIRS, y0=y0, ts_out=t_out)  # solve the ODE

    eff_frac_reduc = icomo.interpolate_func(t_beta, frac_reduc, 'cubic')(t_out)  # interpolate frac_reduc to t_out
    return output.ys, eff_frac_reduc  # return compartments and effective transmission factor
```
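To experiment with the structure of `simulation` without icomo, the following sketch swaps in a fixed-step forward-Euler integrator written with `jax.lax.scan` and linear interpolation via `jnp.interp`. It is an illustration under these assumptions, not a replacement for the adaptive solver used by `icomo.diffeqsolve`; all parameter values and the name `simulation_euler` are hypothetical:

```python
import jax
import jax.numpy as jnp

# Hypothetical setup values (chosen for illustration, not taken from the text)
I0, gamma, nu = 1e-3, 0.1, 0.01
n_steps, dt = 1000, 0.1
t_out = dt * jnp.arange(n_steps)              # uniform output grid
t_beta = jnp.linspace(0.0, 100.0, 11)         # time points of the control values
beta_0_t = jnp.full(t_beta.shape, 0.3)        # transmission rate without mitigation

def simulation_euler(x):
    """Forward-Euler stand-in for the icomo-based simulation (illustration only)."""
    frac_reduc = jax.nn.sigmoid(-x)           # map unconstrained x to (0, 1)
    beta_knots = frac_reduc * beta_0_t        # controlled transmission rate at t_beta

    def step(y, t):
        S, I, R = y
        beta = jnp.interp(t, t_beta, beta_knots)   # piecewise-linear beta(t)
        dy = jnp.stack([-beta * I * S + nu * R,
                        beta * I * S - gamma * I,
                        gamma * I - nu * R])
        return y + dt * dy, y                 # carry the next state, record the current one

    y0 = jnp.array([1.0 - I0, I0, 0.0])
    _, ys = jax.lax.scan(step, y0, t_out)     # ys has shape (n_steps, 3)
    eff_frac_reduc = jnp.interp(t_out, t_beta, frac_reduc)
    return {"S": ys[:, 0], "I": ys[:, 1], "R": ys[:, 2]}, eff_frac_reduc

ys_weak, eff_weak = simulation_euler(jnp.zeros(t_beta.shape))       # factor 0.5 everywhere
ys_strong, _ = simulation_euler(jnp.full(t_beta.shape, 3.0))        # much stronger reduction
```

Because everything is built from JAX primitives, the function stays differentiable with respect to `x`, which is what the optimization below relies on.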



### Cost function

This is the function that we want to minimize. For a given control x (a single value or an array), it runs the simulation and computes the total cost of the result.

`jax.value_and_grad` computes both the cost and its gradient with respect to x using automatic differentiation, so the cost function does not need to be differentiated analytically by hand.
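The total cost computed in `min_func` can be read as a rectangle-rule approximation of a time integral. In the notation below, introduced here for illustration, $f(t;x)$ is `frac_reduc`, $m = 1 - f$ is the mitigation, $C_I$ and $C_M$ correspond to `CI_func` and `CM_func`, and the output times $t_k$ in `t_out` are assumed to be evenly spaced by $\Delta t$ = `dt`:

$$
J(x) = \int_0^T \Big[ C_I\big(a, I(t;x)\big) + C_M\big(m(t;x)\big) \Big]\,dt
\;\approx\; \Delta t \sum_k \Big[ C_I\big(a, I(t_k;x)\big) + C_M\big(m(t_k;x)\big) \Big],
\qquad m(t;x) = 1 - f(t;x).
$$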

```python
@jax.jit
def min_func(x):
    """
    Objective function for optimization.

    Runs a simulation based on the mitigation input `x` and calculates the total cost as the sum of infection and mitigation costs.
    Uses the globals CI_func, CM_func, a and dt, which must be defined beforehand.

    Parameters:
        x (float or array-like): Mitigation input(s).

    Returns:
        float: Total cost computed as the sum of infection and mitigation costs over time.
    """
    output, frac_reduc = simulation(x)  # carry out simulation
    m = 1 - frac_reduc  # mitigation = 1 - frac_reduc
    cost = jnp.sum(CI_func(a, output['I']) + CM_func(m)) * dt  # sum over time steps approximates the time integral
    return cost


value_and_grad_min_func = jax.jit(jax.value_and_grad(min_func))  # returns (cost, gradient) via automatic differentiation
```
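What `jax.value_and_grad` returns can be checked on a small example. This sketch uses a hypothetical stand-in cost `toy_cost` (a sigmoid-bounded control with a quadratic penalty) and compares the automatic gradient with a central finite difference:

```python
import jax
import jax.numpy as jnp

def toy_cost(x):
    # Hypothetical stand-in for min_func: sigmoid-bounded control, quadratic penalty
    frac = jax.nn.sigmoid(-x)
    return jnp.sum((frac - 0.3) ** 2)

value_and_grad_toy = jax.jit(jax.value_and_grad(toy_cost))

x = jnp.array([0.0, 1.0, -1.0])
value, grad = value_and_grad_toy(x)           # cost and d(cost)/dx, same shape as x

# Central finite difference for the first component, as a sanity check of the gradient
eps = 1e-3
e0 = jnp.array([1.0, 0.0, 0.0])
fd = (toy_cost(x + eps * e0) - toy_cost(x - eps * e0)) / (2 * eps)
print(value, grad, fd)
```

The returned pair `(value, grad)` is exactly the format expected by a solver configured with `value_and_grad=True`.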


	

### Carry out optimization and get results

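The optimization is carried out with `jaxopt.ScipyMinimize`, which receives the function defined above that returns both the cost and its gradient (hence `value_and_grad=True`).
For the algorithm, we chose L-BFGS-B (limited-memory Broyden–Fletcher–Goldfarb–Shanno with bound constraints). Since no bounds are passed here, it acts as an unconstrained optimizer; the bounds are enforced by the sigmoid in the simulation instead.
The solver needs an initial guess `x_0`.

The optimized trajectories are obtained by running the simulation again with the optimal parameters `res.params`.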

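```python
# x_0 (initial guess, e.g. an array with the same shape as t_beta) must be defined beforehand
# use the L-BFGS-B algorithm; jit=False because value_and_grad_min_func is already jitted
solver = jaxopt.ScipyMinimize(fun=value_and_grad_min_func, value_and_grad=True, method="L-BFGS-B", jit=False, maxiter=500)
res = solver.run(x_0)  # carry out optimization

# simulate with the optimal control; res.params is unconstrained (between -inf and inf),
# and the sigmoid inside simulation() maps it to the transmission factor
output, frac_reduc = simulation(res.params)
```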
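The whole pipeline can be tried end to end on a toy problem. This sketch makes several assumptions that are not from the text: a forward-Euler integrator and linear interpolation replace icomo, the cost terms are the hypothetical choices $C_I = a \cdot I$ and $C_M = m^2$, and `scipy.optimize.minimize` (the SciPy routine that `jaxopt.ScipyMinimize` wraps) is called directly with `jac=True`, so that one function returns both value and gradient:

```python
import jax
import jax.numpy as jnp
import numpy as np
from scipy.optimize import minimize

# Hypothetical toy setup (values and cost terms are illustrative, not from the text)
I0, gamma, nu, a = 1e-3, 0.1, 0.01, 10.0
n_steps, dt = 500, 0.2
t_out = dt * jnp.arange(n_steps)
t_beta = jnp.linspace(0.0, float(t_out[-1]), 11)
beta_0_t = jnp.full(t_beta.shape, 0.3)

def simulate(x):
    frac_reduc = jax.nn.sigmoid(-x)              # bounded transmission factor
    beta_knots = frac_reduc * beta_0_t

    def step(y, t):
        S, I, R = y
        beta = jnp.interp(t, t_beta, beta_knots)
        dy = jnp.stack([-beta * I * S + nu * R,
                        beta * I * S - gamma * I,
                        gamma * I - nu * R])
        return y + dt * dy, y                    # forward-Euler step

    _, ys = jax.lax.scan(step, jnp.array([1.0 - I0, I0, 0.0]), t_out)
    return ys[:, 1], jnp.interp(t_out, t_beta, frac_reduc)

def cost(x):
    I, frac_reduc = simulate(x)
    m = 1.0 - frac_reduc                         # mitigation strength
    return jnp.sum(a * I + m ** 2) * dt          # hypothetical C_I and C_M

value_and_grad_cost = jax.jit(jax.value_and_grad(cost))

def fun(x):
    # SciPy works with float64 NumPy arrays; convert in both directions
    v, g = value_and_grad_cost(jnp.asarray(x, dtype=jnp.float32))
    return float(v), np.asarray(g, dtype=np.float64)

x_0 = np.zeros(t_beta.shape)
res = minimize(fun, x_0, jac=True, method="L-BFGS-B", options={"maxiter": 500})
I_opt, frac_opt = simulate(jnp.asarray(res.x, dtype=jnp.float32))
print(res.fun, res.message)
```

As in the main code, `res.x` is unconstrained and only becomes a meaningful transmission factor after passing through the sigmoid in `simulate`.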