diff --git a/src/SciMLSensitivity.jl b/src/SciMLSensitivity.jl index 7529176bd..5e6156aae 100644 --- a/src/SciMLSensitivity.jl +++ b/src/SciMLSensitivity.jl @@ -37,7 +37,8 @@ import SciMLBase: unwrapped_f, _unwrap_val import SciMLBase: AbstractOverloadingSensitivityAlgorithm, AbstractSensitivityAlgorithm, AbstractForwardSensitivityAlgorithm, AbstractAdjointSensitivityAlgorithm, AbstractSecondOrderSensitivityAlgorithm, - AbstractShadowingSensitivityAlgorithm + AbstractShadowingSensitivityAlgorithm, + AbstractTimeseriesSolution include("parameters_handling.jl") include("sensitivity_algorithms.jl") diff --git a/src/adjoint_common.jl b/src/adjoint_common.jl index 7c356e58d..6659174a4 100644 --- a/src/adjoint_common.jl +++ b/src/adjoint_common.jl @@ -651,3 +651,13 @@ function out_and_ts(_ts, duplicate_iterator_times, sol) end return out, ts end + +Zygote.@adjoint function Zygote.literal_getproperty(sol::AbstractTimeseriesSolution, + ::Val{:u}) + function solu_adjoint(Δ) + zerou = zero(sol.prob.u0) + _Δ = @. ifelse(Δ === nothing, (zerou,), Δ) + (SciMLBase.build_solution(sol.prob, sol.alg, sol.t, _Δ),) + end + sol.u, solu_adjoint +end