
Forward sensitivity and adjoints for DDEs #281

Open
devmotion opened this issue Jun 11, 2020 · 25 comments
@devmotion (Member)

For simple cases we can just differentiate through the DDE solver but if the DDE system contains parameter-dependent C1-discontinuities the forward sensitivities have jump discontinuities which, e.g., ForwardDiff can't deal with (see SciML/DelayDiffEq.jl#183 and https://epubs.siam.org/doi/abs/10.1137/100814949).

Hence it would be great if we could add forward sensitivity equations and adjoint methods for DDEs. I would like to help with getting them in here, I've just never worked on DiffEqSensitivity and hence might need some guidance and/or time to get started. I would assume that (hopefully) one can exploit the existing implementations for ODEs.

@devmotion (Member Author)

@ChrisRackauckas (Member)

@andrschl commented Nov 18, 2020

I have a first implementation of an interpolating adjoint method for DDEs with constant delays. The code can be found here:
https://github.com/andrschl/dde_adjoint_method

However, I am sure the structure and efficiency could be improved a lot. Also, I am new to Julia and DiffEqSensitivity, and I am not sure how this would best fit into the existing sensitivity framework.

As described in the GitHub repo, there is also quite a large difference between my sensitivity method and the AD method as soon as there are any delays. So I am not sure whether this is a bug in my code or due to the issues mentioned above.

@andrschl commented Nov 19, 2020

@ChrisRackauckas, @devmotion Apparently, performance and memory usage of my approach become terrible as the number of sample times increases. I assume this is because of the adjoint state history, which I constructed as a dictionary of dense solutions of previous steps. And as I saw, the solution itself contains the initial history, so I am essentially storing all solutions. Any suggestions on how to do this at a lower level and more efficiently?

The second point I am not sure about is the discontinuities. Currently I account for discontinuities in the adjoint state by passing them to the solver via constant_lags. I assume it would be better to do this directly, but I am struggling a bit to find the related code in the library.

Would it be an option to use DiscreteCallbacks for the C⁰ and C¹ discontinuities and a single solve(...) call? This would require that the method of steps can handle jumps in the history function.

@devmotion (Member Author)

Yes, as mentioned in the initial discussion in DelayDiffEq, IMO we should just use callbacks. The solver handles discontinuities fine; a common example in the tests is actually a discontinuity at the initial time point, which is then propagated to discontinuities in higher-order derivatives (or derivatives of the same order in the case of neutral DDEs) that have to be considered by the solver as well.
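For intuition, the discontinuity propagation can be sketched in a few lines of plain Julia: a toy explicit-Euler method of steps (not DelayDiffEq's actual machinery) for y'(t) = -y(t - 1) with constant history, where the C¹ kink at t = 0 propagates to a C² kink at t = 1, a C³ kink at t = 2, and so on — exactly the breakpoints a solver (or a callback-based adjoint) has to know about.

```julia
# Toy method of steps with explicit Euler for y'(t) = -y(t - 1),
# history y(t) = 1 for t <= 0, with the delay an integer multiple of the step.
# Exact solution: y(t) = 1 - t on [0, 1], y(t) = t^2/2 - 2t + 3/2 on [1, 2].
function solve_dde_euler(T; h = 1e-3)
    n = round(Int, T / h)
    d = round(Int, 1 / h)                     # delay τ = 1 expressed in steps
    u = Vector{Float64}(undef, n + 1)
    u[1] = 1.0
    delayed(i) = i - d >= 1 ? u[i - d] : 1.0  # history lookup before t = 0
    for i in 1:n
        u[i + 1] = u[i] - h * delayed(i)
    end
    return u
end

u = solve_dde_euler(2.0)  # u[end] approximates y(2) = -1/2
```

The kinks never hit the interior of a step here because the delay is an exact multiple of the step size; an adaptive solver instead tracks the breakpoint tree explicitly, which is what the callback discussion above is about.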

@andrschl commented Nov 19, 2020

Ok, great. This should simplify things a lot and then we can also add discontinuities coming from non-smoothness at the initial point.

@andrschl commented Nov 19, 2020

Adding the callbacks improved things a lot :D The code looks much simpler, and runtime and memory usage are better than for ReverseDiffAdjoint() in my example.

I pushed the code to my repo. Currently I am ignoring the discontinuities due to the initial condition, but it shouldn't be hard to add them.

Also, I feel like my adjoint function already has a similar form to adjoint_sensitivities(sol, alg, dg, ts; sensealg=InterpolatingAdjoint()), but I would need some guidance to properly include it in the DiffEqSensitivity framework :)

@ChrisRackauckas (Member)

So what needs to be done to get this into the DiffEqSensitivity framework? The dispatches currently live here:

https://github.com/SciML/DiffEqSensitivity.jl/blob/master/src/local_sensitivity/interpolating_adjoint.jl

@andrschl commented Dec 3, 2020

Thanks for the reference to the code, and sorry for my late reply. I am currently adding some things, such as a continuous loss with delays. I will try to include it as soon as I am done and find some time. Would you prefer a similar struct and function as for ODEInterpolatingAdjointSensitivityFunction, just for DDEs (I did not look at SDDEs)?

@ChrisRackauckas (Member)

Would you prefer a similar struct and function as for ODEInterpolatingAdjointSensitivityFunction just for DDEs (I did not look at SDDEs)?

I think that might be what's needed. Don't worry about SDDEs: I think from the SDE adjoint work it's clear that "SDDEs should just work", but it'll probably take like 5 years before someone has a clear proof that it actually does work, so you might as well just make sure the current SDE handling code works and let us address it in the future.

@andrschl commented Dec 7, 2020

Ok, yeah, makes sense. However, I am still not sure whether the DDE adjoint is actually correct. In my tests, there is an average mismatch of 1-2% compared to ReverseDiffAdjoint; in some gradient components it is sometimes much larger. And my test case does not include any parameter-dependent discontinuities. Without delays the mismatch is much smaller (approx. 0.1%).
@devmotion @ChrisRackauckas Do you know of any existing implementations to compare it to? I only found this C++ library for forward sensitivities.

Also for the discontinuities I am currently using the following two callbacks:

    # callback for adjoint jumps
    a_jumps = ξ[2:end-1]
    a_jump_idx = 1
    affect1!(integrator) = begin
        a_jump_idx += 1
        integrator.u[1:data_dim] += dl(sol(T - integrator.t), data[:, N - a_jump_idx])
    end
    cb1 = PresetTimeCallback(a_jumps, affect1!, save_positions = (true, true))

    # callback for adjoint C¹, C² discontinuities
    higher_order_disc = get_higher_order_disc(ξ, lags, order = order)
    affect2!(integrator) = nothing
    cb2 = PresetTimeCallback(higher_order_disc, affect2!, save_positions = (false, true))

    cb = CallbackSet(cb1, cb2)

Here cb1 handles the adjoint jumps in the presence of a discrete loss, and cb2 is supposed to handle the other discontinuities. Is it enough to do affect2!(integrator) = nothing for the higher-order discontinuities?

@devmotion (Member Author)

I'm sorry, I can't focus on this right now and I would have to look up the adjoint equations first.

However, in my opinion, the major question when implementing the forward sensitivity equations is what would be a good interface and API for all the derivatives that are required here. As soon as one has decided on an API for all the delayed Jacobians, derivatives, Jacobians, and parameter derivatives of the delayed arguments that appear in the extended system, as shown in eq. 5 of the SIAM paper linked above (they are not part of the standard DiffEqFunctions API), the implementation of the extended DDE system should hopefully be quite straightforward.

@andrschl commented Dec 21, 2020

Thanks for your answer, and sorry for my late one. Well, for the gradients I just defined the functions

    # we need the functions fi: xi -> f(x0, ..., xi, ..., xk)
    ndelays = length(lags)
    fs = []
    for i in 0:ndelays
        g = function (x, xt, p)
            y = vcat(xt[1:data_dim*i], x, xt[(i+1)*data_dim+1:end])
            return f(y, p)
        end
        push!(fs, g)
    end

and then calculated the VJP in the adjoint DDE function. The same could also be used for forward sensitivities.
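To make the wrapper idea concrete, here is a self-contained sketch mirroring the snippet above, with a toy right-hand side taking the concatenated vector [x; x(t-τ₁); x(t-τ₂)]. A central-difference Jacobian stands in for whatever AD/VJP backend would be used in practice; everything besides the `fs` construction itself is illustrative.

```julia
# Toy RHS with data_dim = 2 and two delays; the concatenated argument is
# y = [x; x(t - τ₁); x(t - τ₂)], following the convention in the snippet above.
data_dim = 2
f(y, p) = p[1] .* y[1:2] .+ p[2] .* y[3:4] .+ y[5:6]

# Build fs[i+1]: the RHS as a function of the i-th (delayed) block only,
# with all other blocks frozen at their values in xt.
ndelays = 2
fs = []
for i in 0:ndelays
    g = function (x, xt, p)
        y = vcat(xt[1:data_dim*i], x, xt[(i+1)*data_dim+1:end])
        return f(y, p)
    end
    push!(fs, g)
end

# Central-difference Jacobian of g wrt its first argument; in the real
# implementation this would be an AD Jacobian or VJP call instead.
function fd_jacobian(g, x, xt, p; ε = 1e-6)
    m = length(g(x, xt, p))
    J = zeros(m, length(x))
    for j in eachindex(x)
        e = zeros(length(x)); e[j] = ε
        J[:, j] = (g(x .+ e, xt, p) .- g(x .- e, xt, p)) ./ (2ε)
    end
    return J
end
```

For this linear toy RHS the delayed Jacobians are p[1]·I, p[2]·I, and I, which the finite differences recover to high accuracy.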

@devmotion (Member Author)

My point was that it is not possible to specify the Jacobians manually 🤷 Similar to the ODE case, there should be an interface that allows users to specify all these required derivatives and Jacobians, and uses AD in case they are not provided. We don't want to hardcode AD, and in particular not a particular AD system (your code always uses Zygote, it seems).

@andrschl

Ok, I see. Thanks. Yes, as I am working with neural networks, I did not really need that, but it shouldn't be a problem to add it. I am currently busy, but I'll try to make it ready for a PR once I have submitted my master's thesis.

Another issue I have been thinking about concerns efficiency. A general form of the DDE adjoint method looks like this:
[equation image: general form of the DDE adjoint equations]
where:
[equation image: definitions of the terms]
So we only have a DDE in a(t) and just a standard integration for the possibly high-dimensional b(t). Hence we need neither an interpolation for b(t) nor ODE-like error control on it. The error control issue may be handled by changing the internal norm, as suggested in https://arxiv.org/pdf/2009.09457.pdf. But what about the interpolation? Can we tell the solver to only store intermediate values for a(t)?

@ChrisRackauckas (Member)

The error control issue may be handled by changing the internal norm as in suggested in https://arxiv.org/pdf/2009.09457.pdf.

It's not a good idea to do that, even on ODEs. It's numerically unstable to not have error control on the integral term. That doesn't mean there aren't random equations where it will work, but there are counterexamples where it completely fails (purely time-dependent equations) and there are examples where you get poor optimization behavior due to the gradient inaccuracies (pretty much any non-neural-network example; this is actually one of the well-described issues in continuous sensitivity analysis, so there are about a hundred papers to this effect). Given the known issues with this, it should never be made a default; instead, a tutorial should mention how it could be done and caution the user as to the possible effects one might see (just like BacksolveAdjoint).

But what about the interpolation? Can we tell the solver to only store intermediate values for a(t)?

save_idxs. We can make a stronger save_idxs for DDEs that makes it assume no history on the non-saved parts if needed.

@devmotion (Member Author)

it shouldn't be a problem to add it

It is trivial if there exists a reasonable design: it is something that has to be addressed in DiffEqBase in the DDEFunction and could probably also be exploited in DelayDiffEq, and hence it is not primarily a DiffEqSensitivity issue.

Can we tell the solver to only store intermediate values for a(t)?

Are you looking for save_idxs and saveat keyword arguments?

@devmotion (Member Author)

We can make a stronger save_idxs for DDEs that makes it assume no history on the non-saved parts if needed.

I think I already made some optimizations here. At least it was discussed in some PR and might be part of some issue.

@andrschl

It's not a good idea to do that, even on ODEs. It's numerically unstable to not have the error control on the integral term. [...] it should not ever be made a default, and instead a tutorial should mention how it could be done and caution the user as to the possible effects one might see (just like BacksolveAdjoint).

Thanks a lot for pointing this out! Then I'd better forget about this^^

save_idxs. We can make a stronger save_idxs for DDEs that makes it assume no history on the non-saved parts if needed.

I used h(p,t,idxs=1:data_dim) so that the interpolation only solves for the values of a(t). When we use save_idxs=1:data_dim, are we not completely losing dJ/dp = b(T) at the end? However, I think the main problem is the intermediate values of b(t), which are stored due to the two callbacks mentioned above. Ideally I would like to only ever save the values of a(t), except for dJ/dp = b(T) at the end. In a small test example of mine, the solution object has a size of 1 GB due to this. So I guess it would be great if we could get rid of the b(t) values which are never used :)

It is trivial if there exists a reasonable design: it is something that has to be addressed in DiffEqBase in the DDEFunction and could probably also be exploited in DelayDiffEq, and hence it is not primarily a DiffEqSensitivity issue.

OK, I see. So this means we should be able to add the Jacobians directly to the DDEFunction, right?

@devmotion (Member Author) commented Dec 27, 2020

So this means we should be able to add the jacobians directly to the DDEFunction, right?

Yes, exactly, the Jacobians and all other terms required in the sensitivity equations.

Edit: See also SciML/DelayDiffEq.jl#138.

@devmotion (Member Author)

Indeed there is already an issue: SciML/OrdinaryDiffEq.jl#335

@andrschl

OK, great. Thanks for the link to the issue. But in this case it might be better to go with the QuadratureAdjoint until this is solved, i.e., calculating a dense adjoint and then integrating it using QuadGK.jl.
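The quadrature-adjoint idea can be sketched in plain Julia: instead of augmenting the adjoint system with b(t), evaluate dJ/dp = ∫₀ᵀ a(t)ᵀ ∂f/∂p dt after the fact. QuadGK.jl would do adaptive quadrature against the dense adjoint interpolant; a fixed-grid trapezoidal rule keeps this sketch dependency-free, and the scalar integrand below is a toy stand-in, not a real adjoint.

```julia
# Fixed-grid trapezoidal rule as a stand-in for QuadGK's adaptive quadrature.
function trapz_integral(g, t0, t1; n = 10_000)
    ts = range(t0, t1; length = n + 1)
    h = step(ts)
    vals = g.(ts)
    return h * (sum(vals) - (vals[1] + vals[end]) / 2)
end

a(t)    = exp(-t)   # pretend dense adjoint interpolant a(t)
dfdp(t) = sin(t)    # pretend ∂f/∂p evaluated along the forward solution

# dJ/dp = ∫₀¹ a(t) * ∂f/∂p dt; analytically (1 - e⁻¹(sin 1 + cos 1)) / 2 ≈ 0.2458
dJdp = trapz_integral(t -> a(t) * dfdp(t), 0.0, 1.0)
```

The trade-off relative to the augmented (interpolating) approach is storing a dense adjoint versus integrating one extra, possibly high-dimensional, state.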

@devmotion (Member Author)

I just had a closer look at your code (for the first time, I have to admit), and unfortunately I think the implementation only works in your special case (at most). It seems it mainly exploits the fact that you define your DDE explicitly as dde(u, p, h, t) = f(vcat(u, h(t - lags[1]), h(t - lags[2]), ..., h(t - lags[end])), p) and then only work with f in the adjoint implementation.

In general, f might depend on t, and dde can depend on delayed derivatives of the solution y of arbitrary order. However, the more general problem is that f is not accessible: users just provide dde, and at most a list of constant and state-dependent lags. Therefore, as far as I can see, we currently can't compute the required delayed Jacobians if they are not provided by the user. In the case of only constant delays, we might be able to inject a modified history function where h is fixed apart from v = h(t - lags[i]), but this won't work in general: as soon as some lags are non-constant, some of them might become (at least numerically) equal at some time point, and hence we would not be able to fix h(t - lags[j]) for all j that are not equal to i. Also, this whole approach fails as soon as there are some higher-order derivatives. And we don't even know for which derivatives and delays we should compute the Jacobians!

I am not sure yet what the best solution is to this problem. Somehow to me it feels like it would be helpful for the sensitivity analysis to enforce more structure in, e.g., the arguments of dde (maybe explicitly provide the delayed states?), similar to your f, but on the other hand it is a huge feature to be able to evaluate h at any time point (and also in-place!) and even have access to its derivatives, so I don't think this should be changed.

If users specify the delayed Jacobians manually, it could be sufficient to encode more information in how delays are specified, e.g., by explicitly specifying the order of derivatives. Then the delayed Jacobians could be specified as a tuple (or array) of functions in some fixed order (e.g., first element the regular Jacobian wrt y(t), and then the delayed states of order zero with constant lags, then the delayed states of order zero with state-dependent lags, then the delayed states of order 1 with constant lags, etc.), similar to the jac field of ODEFunction. And similarly, the derivatives of the delayed arguments with respect to y(t) and the parameters could be specified as tuple (or array) of functions in some specific order.
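A rough sketch of what such an interface could look like, with the delayed Jacobians carried as a tuple in lag order and `nothing` meaning "fall back to AD". The `ExtendedDDEFunction` name and its fields are invented for illustration; they are not part of DiffEqBase.

```julia
# Hypothetical container for a DDE RHS plus user-provided derivatives.
struct ExtendedDDEFunction{F,J,DJ,PJ}
    f::F              # dde(u, p, h, t)
    jac::J            # Jacobian wrt y(t), or nothing -> use AD
    delayed_jacs::DJ  # tuple of Jacobians wrt y(t - lags[i]), in lag order
    paramjac::PJ      # derivative wrt p, or nothing -> use AD
end

# Example: dy/dt = -p[1] * y(t) + p[2] * y(t - 1); scalar problem, so the
# Jacobians are scalars. h(p, t) follows the DelayDiffEq history signature.
f(u, p, h, t)    = -p[1] * u + p[2] * h(p, t - 1)
jac(u, p, h, t)  = -p[1]              # wrt y(t)
djac(u, p, h, t) = p[2]               # wrt the delayed state y(t - 1)
pjac(u, p, h, t) = [-u, h(p, t - 1)]  # wrt (p[1], p[2])

fun = ExtendedDDEFunction(f, jac, (djac,), pjac)
```

Encoding the derivative order of each delayed argument (for neutral DDEs) would need an extra field per entry, which is exactly the API question raised above.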

@devmotion (Member Author)

Maybe the cleanest approach would be to add some special types of DDEFunctions that encode additional information (e.g., with only constant delays of order zero) such as certain mathematical structures of dde.

In general, it might also be useful to perform symbolic calculations with ModelingToolkit (or to obtain more information about the mathematical structure), I guess.

@andrschl commented Dec 29, 2020

Thanks a lot for looking into it. The current version of the code on GitHub is a bit outdated; I added some minor changes like the quadrature adjoint and a continuous loss in the meantime. But yeah, I see the problems. I explicitly use that definition of f in order to calculate the Jacobians, and the adjoint code that I implemented so far will only work for constant delays and no dependency on derivatives of u. And actually, for this case we could also use ReverseDiffAdjoint, which was faster in my experiments^^ Also, I have been assuming that the history doesn't depend on p, but that should be trivial to add. I think I came across the formulas for state-dependent delays and neutral-type equations somewhere but would have to look it up.

If users specify the delayed Jacobians manually, it could be sufficient to encode more information in how delays are specified, e.g., by explicitly specifying the order of derivatives. Then the delayed Jacobians could be specified as a tuple (or array) of functions in some fixed order (e.g., first element the regular Jacobian wrt y(t), and then the delayed states of order zero with constant lags, then the delayed states of order zero with state-dependent lags, then the delayed states of order 1 with constant lags, etc.), similar to the jac field of ODEFunction. And similarly, the derivatives of the delayed arguments with respect to y(t) and the parameters could be specified as tuple (or array) of functions in some specific order.

Yes, that makes sense. But then, to begin with, it might be best to enforce a manual definition of the Jacobians? Although this is a bit annoying in the case I used it for.

Maybe the cleanest approach would be to add some special types of DDEFunctions that encode additional information (e.g., with only constant delays of order zero) such as certain mathematical structures of dde.

That would be great. But I am not sure how we would have to do this such that it is still compatible with the DDE solver. Maybe the solver, if provided with a DDEFunction of this special type, could internally define dde(u, p, h, t) = dde_special_type(u, h(t - lags[1]), h(t - lags[2]), ..., h(t - lags[end]), p)?
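That wrapping could be sketched roughly as follows. `make_dde` is an invented helper name, and the history function uses the h(p, t) signature from DelayDiffEq; this only covers constant lags of derivative order zero.

```julia
# Generate the solver-facing dde(u, p, h, t) from a "special type" RHS that
# takes the delayed states explicitly, given the tuple of constant lags.
function make_dde(dde_special_type, lags)
    return (u, p, h, t) -> dde_special_type(u, (h(p, t - τ) for τ in lags)..., p)
end

# Example: f takes the current state and two delayed states explicitly.
f(u, u1, u2, p) = -p[1] * u + p[2] * u1 + p[3] * u2
dde = make_dde(f, (0.5, 1.0))
```

The adjoint code could then work with `dde_special_type` and the lag list directly, which is exactly the structure the sensitivity equations need.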

In general, it might also be useful to perform symbolic calculations with ModelingToolkit (or to obtain more information about the mathematical structure), I guess.

What exactly do you mean by this?
