In [None]:
import Zygote
import Random: shuffle
import LinearAlgebra: norm

"""
    mean_squared_error(f, st, x, y, θ, λ)

A priori loss function. Use it to train the closure term so that the model output
reproduces the target right-hand side.

The data term is the sum of squared differences between the prediction and `y`,
normalised by the sum of squares of `y`. An L1 penalty `λ * norm(θ, 1)` is added on top.

# Arguments
- `f`: Model, called as `f(x, θ, st)`; the first element of its result is the prediction.
- `st`: State of the model, forwarded to `f` unchanged.
- `x`: Input data.
- `y`: Target data.
- `θ`: Parameters of the model.
- `λ`: Weight of the L1 regularisation term.

# Returns
- A tuple `(total_loss, nothing)`: the regularised loss and a `nothing` placeholder.
"""
function mean_squared_error(f, st, x, y, θ, λ)
    model_output = f(x, θ, st)
    ŷ = Array(model_output[1])
    # Squared error relative to the magnitude of the target.
    data_term = sum(abs2, ŷ - y) / sum(abs2, y)
    penalty = λ * norm(θ, 1)
    return data_term + penalty, nothing
end

"""
    create_randloss_derivative(GS_data, FG_target, f, st;
        nuse = size(GS_data, ndims(GS_data)), λ = 0)

Create a randomized loss function that compares the derivatives.

Each call of the returned function draws `nuse` distinct indices along the last
dimension of `GS_data` (and the same indices of `FG_target`) and evaluates
`mean_squared_error` on that subset. Subsampling is used because evaluating the
entire dataset at each iteration would be too expensive.

# Arguments
- `GS_data`: The input data; samples are taken along its last dimension.
- `FG_target`: The target data, indexed along the same dimension as `GS_data`.
- `f`: The model function.
- `st`: The model state.

# Keywords
- `nuse`: The number of samples used per evaluation. Defaults to the length of the
  last dimension of `GS_data`, i.e. all samples.
- `λ`: The regularization parameter. Defaults to 0.

# Returns
- A function `randloss(θ)` that returns the result of `mean_squared_error` on the
  randomly selected subset of the data.

# Throws
- `ArgumentError` if `nuse` is not between 1 and the number of available samples.
"""
function create_randloss_derivative(GS_data,
        FG_target,
        f,
        st;
        nuse = size(GS_data, ndims(GS_data)),
        λ = 0)
    d = ndims(GS_data)
    nsample = size(GS_data, d)
    # Fail early with a clear message instead of a BoundsError inside the loss.
    1 <= nuse <= nsample || throw(ArgumentError(
        "nuse must be between 1 and the number of samples ($nsample), got $nuse"))
    function randloss(θ)
        # The data selection is bookkeeping only, so it is hidden from Zygote.
        i = Zygote.@ignore sort(shuffle(1:nsample)[1:nuse])
        x_use = Zygote.@ignore ArrayType(selectdim(GS_data, d, i))
        y_use = Zygote.@ignore ArrayType(selectdim(FG_target, d, i))
        return mean_squared_error(f, st, x_use, y_use, θ, λ)
    end
    return randloss
end

Auxiliary function that solves the NeuralODE for a given set of parameters `θ`.

In [None]:
"""
    predict_u_CNODE(uv0, θ, tg)

Compute the prediction starting from `uv0` with parameters `θ` and trim the target
`tg` to the time steps that were actually predicted.

The prediction is `Array(training_CNODE(uv0, θ, st)[1])`. `training_CNODE`, the fallback
`training_CNODE_2` and the state `st` are names defined outside this function.
Time is taken to be the third array dimension (the code slices `[:, :, k]`).

Handling of an unstable solve (NaN in the prediction):
- if more than one leading time step is NaN-free, only those steps are kept;
- otherwise the prediction is recomputed with `training_CNODE_2`;
- if that prediction also contains NaN, a constant array filled with
  `1e6 * sum(θ)` and shaped like `tg` is returned as the prediction.

# Returns
- A tuple `(sol, target)` where `target` is `tg` restricted to the first
  `size(sol, 3)` time steps.
"""
function predict_u_CNODE(uv0, θ, tg)
    sol = Array(training_CNODE(uv0, θ, st)[1])

    # Count the leading time steps that are free of NaN.
    nt = size(sol, 3)
    nok = 0
    while nok < nt && !any(isnan, sol[:, :, nok + 1])
        nok += 1
    end

    # Handle an unstable solver.
    if nok < nt
        if nok > 1
            # If some steps ran successfully, use only them for the loss.
            println("Unstability after ", nok, " steps")
            sol = sol[:, :, 1:nok]
        else
            # Otherwise run the auxiliary solver.
            println("Using auxiliary solver ")
            sol = Array(training_CNODE_2(uv0, θ, st)[1])
            if any(isnan, sol)
                println("ERROR: NaN detected in the prediction")
                # Return the same (prediction, target) pair shape as the normal path
                # so that callers can always destructure the result.
                return fill(1e6 * sum(θ), size(tg)), tg
            end
        end
    end

    return sol, tg[:, :, 1:size(sol, 3)]
end

"""
    create_randloss_MulDtO(target; nunroll, nintervals = 1, noverlaps = 1,
        nsamples, λ_c, λ_l1)

Create a random loss function for the multishooting method with multiple shooting
intervals.

Each call of the returned function picks a random window of consecutive time steps
along the last dimension of `target`, draws `nsamples` indices (with replacement)
along its second dimension, and evaluates `loss_MulDtO_oneset` on that selection.

# Arguments
- `target`: The target data for the loss function; time is its last dimension.

# Keywords
- `nunroll`: The number of time steps to unroll.
- `nintervals`: The number of shooting intervals. Defaults to 1.
- `noverlaps`: The number of time steps that overlap between consecutive intervals.
  Defaults to 1.
- `nsamples`: The number of samples to select.
- `λ_c`: The weight for the continuity term. It sets how strongly the pieces are
  made to match.
- `λ_l1`: The coefficient for the L1 regularization term in the loss function.

# Returns
- `randloss_MulDtO`: A function of `θ` that returns the result of
  `loss_MulDtO_oneset` on the random selection.

# Throws
- `ArgumentError` if the arguments are inconsistent with each other or with the
  number of time steps available in `target`.
"""
function create_randloss_MulDtO(
        target; nunroll, nintervals = 1, noverlaps = 1, nsamples, λ_c, λ_l1)
    # Check the consistency of the input arguments.
    nunroll >= 1 ||
        throw(ArgumentError("nunroll must be at least 1, got $nunroll"))
    nintervals >= 1 ||
        throw(ArgumentError("nintervals must be at least 1, got $nintervals"))
    nsamples >= 1 ||
        throw(ArgumentError("nsamples must be at least 1, got $nsamples"))
    0 <= noverlaps <= nunroll || throw(ArgumentError(
        "noverlaps must be between 0 and nunroll ($nunroll), got $noverlaps"))

    # Get the number of time steps.
    d = ndims(target)
    nt = size(target, d)

    # Requested length of consecutive time steps. Each interval is nunroll + 1 long
    # because the initial condition is included as step 0.
    length_required = nintervals * (nunroll + 1) - noverlaps * (nintervals - 1)
    length_required <= nt || throw(ArgumentError(
        "the shooting intervals need $length_required time steps, " *
        "but target only has $nt"))
    last_start = nt - length_required + 1

    function randloss_MulDtO(θ)
        # Select a random initial time that can accommodate all the multishooting
        # intervals (hidden from Zygote, it is only data selection) ...
        istart = Zygote.@ignore rand(1:last_start)
        trajectory = Zygote.@ignore ArrayType(selectdim(target,
            d,
            istart:(istart + length_required - 1)))
        # ... and select a certain number of samples,
        trajectory = Zygote.@ignore trajectory[:, rand(1:size(trajectory, 2), nsamples), :]
        # then return the loss for this multishooting set.
        return loss_MulDtO_oneset(trajectory,
            θ,
            nunroll = nunroll,
            nintervals = nintervals,
            noverlaps = noverlaps,
            nsamples = nsamples,
            λ_c = λ_c,
            λ_l1 = λ_l1)
    end
    return randloss_MulDtO
end

"""
    loss_MulDtO_oneset(trajectory, θ; λ_c = 1e1, λ_l1 = 1e1, nunroll, nintervals,
        noverlaps, nsamples = size(trajectory, 2))

Compute the loss function for the multiple shooting method with a continuous neural
ODE (CNODE) model.
Check https://docs.sciml.ai/DiffEqFlux/dev/examples/multiple_shooting/ for more details.

The loss is the relative squared error between the predicted and the reference
intervals, plus `λ_c` times a continuity term, plus `λ_l1` times the L1 norm of `θ`.
The continuity term is only evaluated when the prediction covers all `nunroll + 1`
steps, i.e. when `predict_u_CNODE` did not shorten it.

# Arguments
- `trajectory`: The trajectory of the system, indexed as `[grid, sample, time]`.
- `θ`: The parameters of the CNODE model.

# Keywords
- `λ_c`: The weight for the continuity term. It sets how strongly the pieces are
  made to match. Default is `1e1`.
- `λ_l1`: The weight for the L1 regularization term. Default is `1e1`.
- `nunroll`: The number of time steps to unroll the trajectory.
- `nintervals`: The number of intervals to divide the trajectory into.
- `noverlaps`: The number of time steps that overlap between consecutive intervals.
- `nsamples`: The number of samples. Defaults to `size(trajectory, 2)`.

# Returns
- `loss`: The computed loss value.
- `nothing`: Placeholder return value.
"""
function loss_MulDtO_oneset(trajectory,
        θ;
        λ_c = 1e1,
        λ_l1 = 1e1,
        nunroll,
        nintervals,
        noverlaps,
        nsamples = size(trajectory, 2))
    # Get the time steps where the intervals start.
    # NOTE(review): the first interval starts at 1 and interval i >= 1 starts at
    # i * (nunroll + 1 - noverlaps), so the first stride is one step shorter than
    # the following ones. Confirm that this is the intended layout.
    starting_points = [i == 0 ? 1 : i * (nunroll + 1 - noverlaps)
                       for i in 0:(nintervals - 1)]

    # Take all the time intervals and concatenate them in the batch dimension.
    # The resulting columns are ordered interval by interval:
    #   column (i - 1) * nsamples + s  <->  interval i, sample s
    list_tr = cat([trajectory[:, :, i:(i + nunroll)]
                   for i in starting_points]...,
        dims = 2)

    # Get all the initial conditions (same column ordering).
    list_starts = cat([trajectory[:, :, i] for i in starting_points]...,
        dims = 2)

    # Use the differentiable solver to get the predictions.
    pred, list_tr = predict_u_CNODE(list_starts, θ, list_tr)

    # The loss is the sum of the differences between the real trajectory and the
    # predicted one.
    loss = sum(abs2, list_tr .- pred) / sum(abs2, list_tr)

    continuity = if λ_c > 0 && nintervals > 1 && size(list_tr, 3) == nunroll + 1
        # Continuity term: compare the end of one interval with the start of the
        # next one. The prediction is stored as
        #   pred[grid, nintervals * nsamples, nunroll + 1]
        # Last `noverlaps` points of every interval except the last one ...
        pred_end = pred[:, 1:((nintervals - 1) * nsamples), (end - noverlaps + 1):end]
        # ... against the first `noverlaps` points of the following interval,
        # EXCLUDING the initial condition (which is already part of the loss).
        # Shifting the columns by `nsamples` pairs each sample with the same sample
        # in the next interval.
        pred_start = pred[:, (nsamples + 1):(nintervals * nsamples), 2:(1 + noverlaps)]
        sum(abs, pred_end .- pred_start)
    else
        zero(loss)
    end

    return loss + (continuity * λ_c) + λ_l1 * norm(θ, 1), nothing
end

---

*This notebook was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).*