SDEs gradients issues with EnsembleProblems #765

Closed
gabrevaya opened this issue Sep 14, 2022 · 5 comments
Comments

@gabrevaya

gabrevaya commented Sep 14, 2022

Strangely, the following code works if DiffEqFlux is not loaded, but it breaks if it is.

MWE
using Zygote
using StochasticDiffEq
using SciMLSensitivity
using DiffEqFlux # if this is loaded, the error arises

function f!(du, u, p, t)
  x, y = u
  α, β, δ, γ = p
  du[1] = α*x - β*x*y
  du[2] = -δ*y + γ*x*y
end

function g!(du,u,p,t)
  du .= 0.02f0
end

u₀ = Float32[1.0, 1.0]
p = Float32[1.25, 1.5, 1.75, 2]
tspan = (0.f0, 1.f0)
prob = SDEProblem(f!, g!, u₀, tspan, p)

function diffeq(prob, p)
  prob_func(prob,i,repeat) = remake(prob, p = p[i])
  ens_prob = EnsembleProblem(prob, prob_func = prob_func)
  sol = solve(ens_prob, SOSRI(), EnsembleSerial(); trajectories = length(p), saveat = 0:0.01:1)
  return sum(sol)
end

p0 = fill(Float32[1.25, 1.5, 1.75, 2], 3) # Having more than 2 sets of parameters for the ensemble causes an error 
gradient(p -> diffeq(prob, p), p0)
Error
ERROR: MethodError: no method matching +(::Tuple{}, ::NamedTuple{(), Tuple{}})
Closest candidates are:
  +(::Any, ::Any, ::Any, ::Any...) at operators.jl:591
  +(::Union{InitialValues.NonspecificInitialValue, InitialValues.SpecificInitialValue{typeof(+)}}, ::Any) at ~/.julia/packages/InitialValues/OWP8V/src/InitialValues.jl:154
  +(::MultivariatePolynomials.RationalPoly, ::Any) at ~/.julia/packages/MultivariatePolynomials/1bIGc/src/rational.jl:50
  ...
Stacktrace:
  [1] accum(x::Tuple{}, y::NamedTuple{(), Tuple{}})
    @ Zygote ~/.julia/packages/Zygote/PD12J/src/lib/lib.jl:17
  [2] macro expansion
    @ ~/.julia/packages/Zygote/PD12J/src/lib/lib.jl:27 [inlined]
  [3] accum(x::NamedTuple{(:data, :itr), Tuple{Tuple{}, Nothing}}, y::NamedTuple{(:data, :itr), Tuple{NamedTuple{(), Tuple{}}, Nothing}})
    @ Zygote ~/.julia/packages/Zygote/PD12J/src/lib/lib.jl:27
  [4] macro expansion
    @ ~/.julia/packages/Zygote/PD12J/src/lib/lib.jl:27 [inlined]
  [5] accum(x::NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :kwargs, :noise_rate_prototype, :seed), Tuple{Nothing, Nothing, Vector{Float32}, Nothing, Nothing, Nothing, NamedTuple{(:data, :itr), Tuple{Tuple{}, Nothing}}, Nothing, Nothing}}, y::NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :kwargs, :noise_rate_prototype, :seed), Tuple{Nothing, Nothing, Vector{Float32}, Nothing, Nothing, Nothing, NamedTuple{(:data, :itr), Tuple{NamedTuple{(), Tuple{}}, Nothing}}, Nothing, Nothing}})
    @ Zygote ~/.julia/packages/Zygote/PD12J/src/lib/lib.jl:27
  [6] macro expansion
    @ ~/.julia/packages/Zygote/PD12J/src/lib/lib.jl:27 [inlined]
  [7] accum(x::NamedTuple{(:prob, :prob_func, :output_func, :reduction, :u_init, :safetycopy), Tuple{NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :kwargs, :noise_rate_prototype, :seed), Tuple{Nothing, Nothing, Vector{Float32}, Nothing, Nothing, Nothing, NamedTuple{(:data, :itr), Tuple{Tuple{}, Nothing}}, Nothing, Nothing}}, NamedTuple{(:p,), Tuple{Vector{Union{Nothing, Vector{Float32}}}}}, Nothing, Nothing, Nothing, Nothing}}, y::NamedTuple{(:prob, :prob_func, :output_func, :reduction, :u_init, :safetycopy), Tuple{NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :kwargs, :noise_rate_prototype, :seed), Tuple{Nothing, Nothing, Vector{Float32}, Nothing, Nothing, Nothing, NamedTuple{(:data, :itr), Tuple{NamedTuple{(), Tuple{}}, Nothing}}, Nothing, Nothing}}, NamedTuple{(:p,), Tuple{Vector{Union{Nothing, Vector{Float32}}}}}, Nothing, Nothing, Nothing, Nothing}})
    @ Zygote ~/.julia/packages/Zygote/PD12J/src/lib/lib.jl:27
  [8] macro expansion
    @ ~/.julia/packages/Zygote/PD12J/src/lib/lib.jl:27 [inlined]
  [9] accum(x::NamedTuple{(:kwargs, :prob, :alg), Tuple{Nothing, NamedTuple{(:prob, :prob_func, :output_func, :reduction, :u_init, :safetycopy), Tuple{NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :kwargs, :noise_rate_prototype, :seed), Tuple{Nothing, Nothing, Vector{Float32}, Nothing, Nothing, Nothing, NamedTuple{(:data, :itr), Tuple{Tuple{}, Nothing}}, Nothing, Nothing}}, NamedTuple{(:p,), Tuple{Vector{Union{Nothing, Vector{Float32}}}}}, Nothing, Nothing, Nothing, Nothing}}, Nothing}}, y::NamedTuple{(:kwargs, :prob, :alg), Tuple{Nothing, NamedTuple{(:prob, :prob_func, :output_func, :reduction, :u_init, :safetycopy), Tuple{NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :kwargs, :noise_rate_prototype, :seed), Tuple{Nothing, Nothing, Vector{Float32}, Nothing, Nothing, Nothing, NamedTuple{(:data, :itr), Tuple{NamedTuple{(), Tuple{}}, Nothing}}, Nothing, Nothing}}, NamedTuple{(:p,), Tuple{Vector{Union{Nothing, Vector{Float32}}}}}, Nothing, Nothing, Nothing, Nothing}}, Nothing}})
    @ Zygote ~/.julia/packages/Zygote/PD12J/src/lib/lib.jl:27
 [10] _mapreduce(f::typeof(identity), op::typeof(Zygote.accum), #unused#::IndexLinear, A::Vector{NamedTuple{(:kwargs, :prob, :alg), Tuple{Nothing, NamedTuple{(:prob, :prob_func, :output_func, :reduction, :u_init, :safetycopy), Tuple{NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :kwargs, :noise_rate_prototype, :seed), Tuple{Nothing, Nothing, Vector{Float32}, Nothing, Nothing, Nothing, NamedTuple{(:data, :itr), Tuple{NamedTuple{(), Tuple{}}, Nothing}}, Nothing, Nothing}}, NamedTuple{(:p,), Tuple{Vector{Union{Nothing, Vector{Float32}}}}}, Nothing, Nothing, Nothing, Nothing}}, Nothing}}})
    @ Base ./reduce.jl:438
 [11] _mapreduce_dim(f::Function, op::Function, #unused#::Base._InitialValue, A::Vector{NamedTuple{(:kwargs, :prob, :alg), Tuple{Nothing, NamedTuple{(:prob, :prob_func, :output_func, :reduction, :u_init, :safetycopy), Tuple{NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :kwargs, :noise_rate_prototype, :seed), Tuple{Nothing, Nothing, Vector{Float32}, Nothing, Nothing, Nothing, NamedTuple{(:data, :itr), Tuple{NamedTuple{(), Tuple{}}, Nothing}}, Nothing, Nothing}}, NamedTuple{(:p,), Tuple{Vector{Union{Nothing, Vector{Float32}}}}}, Nothing, Nothing, Nothing, Nothing}}, Nothing}}}, #unused#::Colon)
    @ Base ./reducedim.jl:365
 [12] #mapreduce#765
    @ ./reducedim.jl:357 [inlined]
 [13] mapreduce
    @ ./reducedim.jl:357 [inlined]
 [14] #reduce#767
    @ ./reducedim.jl:406 [inlined]
 [15] reduce(op::Function, A::Vector{NamedTuple{(:kwargs, :prob, :alg), Tuple{Nothing, NamedTuple{(:prob, :prob_func, :output_func, :reduction, :u_init, :safetycopy), Tuple{NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :kwargs, :noise_rate_prototype, :seed), Tuple{Nothing, Nothing, Vector{Float32}, Nothing, Nothing, Nothing, NamedTuple{(:data, :itr), Tuple{NamedTuple{(), Tuple{}}, Nothing}}, Nothing, Nothing}}, NamedTuple{(:p,), Tuple{Vector{Union{Nothing, Vector{Float32}}}}}, Nothing, Nothing, Nothing, Nothing}}, Nothing}}})
    @ Base ./reducedim.jl:406
 [16] (::SciMLSensitivity.var"#∇responsible_map_internal#416"{Vector{typeof((λ))}})(Δ::EnsembleSolution{FillArrays.Fill{Float32, 1, Tuple{Base.OneTo{Int64}}}, 2, Vector{Vector{FillArrays.Fill{Float32, 1, Tuple{Base.OneTo{Int64}}}}}})
    @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/LEq5V/src/zygote.jl:34
 [17] #741#back
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67 [inlined]
 [18] Pullback
    @ ~/.julia/packages/SciMLBase/jEZRD/src/ensemble/basic_ensemble_solve.jl:145 [inlined]
 [19] (::typeof((#solve_batch#475)))(Δ::EnsembleSolution{FillArrays.Fill{Float32, 1, Tuple{Base.OneTo{Int64}}}, 2, Vector{Vector{FillArrays.Fill{Float32, 1, Tuple{Base.OneTo{Int64}}}}}})
    @ Zygote ~/.julia/packages/Zygote/PD12J/src/compiler/interface2.jl:0
 [20] Pullback
    @ ~/.julia/packages/SciMLBase/jEZRD/src/ensemble/basic_ensemble_solve.jl:144 [inlined]
 [21] (::typeof((solve_batch##kw)))(Δ::EnsembleSolution{FillArrays.Fill{Float32, 1, Tuple{Base.OneTo{Int64}}}, 2, Vector{Vector{FillArrays.Fill{Float32, 1, Tuple{Base.OneTo{Int64}}}}}})
    @ Zygote ~/.julia/packages/Zygote/PD12J/src/compiler/interface2.jl:0
 [22] macro expansion
    @ ./timing.jl:382 [inlined]
 [23] Pullback
    @ ~/.julia/packages/SciMLBase/jEZRD/src/ensemble/basic_ensemble_solve.jl:56 [inlined]
 [24] (::typeof((#__solve#470)))(Δ::FillArrays.Fill{Float32, 3, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/PD12J/src/compiler/interface2.jl:0
 [25] Pullback
    @ ~/.julia/packages/SciMLBase/jEZRD/src/ensemble/basic_ensemble_solve.jl:45 [inlined]
 [26] (::typeof((__solve##kw)))(Δ::FillArrays.Fill{Float32, 3, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/PD12J/src/compiler/interface2.jl:0
 [27] #208
    @ ~/.julia/packages/Zygote/PD12J/src/lib/lib.jl:206 [inlined]
 [28] (::Zygote.var"#1914#back#210"{Zygote.var"#208#209"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{Nothing, Nothing}}, typeof((__solve##kw))}})(Δ::FillArrays.Fill{Float32, 3, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
 [29] Pullback
    @ ~/.julia/packages/DiffEqBase/BHoDE/src/solve.jl:841 [inlined]
 [30] (::typeof((#solve#33)))(Δ::FillArrays.Fill{Float32, 3, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/PD12J/src/compiler/interface2.jl:0
 [31] (::Zygote.var"#208#209"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{Nothing, Nothing}}, typeof((#solve#33))})(Δ::FillArrays.Fill{Float32, 3, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/PD12J/src/lib/lib.jl:206
 [32] #1914#back
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67 [inlined]
 [33] Pullback
    @ ~/.julia/packages/DiffEqBase/BHoDE/src/solve.jl:837 [inlined]
 [34] (::typeof((solve##kw)))(Δ::FillArrays.Fill{Float32, 3, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/PD12J/src/compiler/interface2.jl:0
 [35] Pullback
    @ ~/Documents/doctorado/issues/goku_net_on_SDEs/MWE.jl:25 [inlined]
 [36] (::typeof((diffeq)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/PD12J/src/compiler/interface2.jl:0
 [37] Pullback
    @ ~/Documents/doctorado/issues/goku_net_on_SDEs/MWE.jl:30 [inlined]
 [38] (::typeof((#12)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/PD12J/src/compiler/interface2.jl:0
 [39] (::Zygote.var"#60#61"{typeof((#12))})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/PD12J/src/compiler/interface.jl:45
 [40] gradient(f::Function, args::Vector{Vector{Float32}})
    @ Zygote ~/.julia/packages/Zygote/PD12J/src/compiler/interface.jl:97
 [41] top-level scope
    @ ~/Documents/doctorado/issues/goku_net_on_SDEs/MWE.jl:30

The error doesn't appear if only 2 sets of parameters are used for the ensemble. I also found another issue when I remake the problem to change the time span, even if I don't take the gradient with respect to it. The following MWE is the same as the previous one, except that the parameter vector p0 has length 2 to avoid the previous error, a time range is added as an argument, and the problem is remade to change the time span. As in the previous case, it works without any errors if DiffEqFlux is not loaded.

MWE 2
using Zygote
using StochasticDiffEq
using SciMLSensitivity
using DiffEqFlux # if this is loaded, the error arises

function f!(du, u, p, t)
    x, y = u
    α, β, δ, γ = p
    du[1] = α*x - β*x*y
    du[2] = -δ*y + γ*x*y
end

function g!(du,u,p,t)
    du .= 0.02f0
end

u₀ = Float32[1.0, 1.0]
p = Float32[1.25, 1.5, 1.75, 2]
tspan = (0.f0, 1.f0)
prob = SDEProblem(f!, g!, u₀, tspan, p)

function diffeq(prob, p, t)
    prob = remake(prob; tspan = (t[1], t[end])) # This causes an error
    prob_func(prob,i,repeat) = remake(prob, p = p[i])
    ens_prob = EnsembleProblem(prob, prob_func = prob_func)
    sol = solve(ens_prob, SOSRI(), EnsembleSerial(); trajectories = length(p), saveat = t)
    return sum(sol)
end

p0 = fill(Float32[1.25, 1.5, 1.75, 2], 2) # Having more than 2 sets of parameters for the ensemble causes another error 
t = 0:0.01:1
gradient(p -> diffeq(prob, p, t), p0)
Error 2
ERROR: Need an adjoint for constructor NamedTuple{(), Tuple{}}. Gradient is of type Tuple{}
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] (::Zygote.Jnew{NamedTuple{(), Tuple{}}, Nothing, true})(Δ::Tuple{})
    @ Zygote ~/.julia/packages/Zygote/PD12J/src/lib/lib.jl:327
  [3] (::Zygote.var"#1958#back#226"{Zygote.Jnew{NamedTuple{(), Tuple{}}, Nothing, true}})(Δ::Tuple{})
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
  [4] Pullback
    @ ./boot.jl:607 [inlined]
  [5] (::typeof((NamedTuple{(), Tuple{}})))(Δ::Tuple{})
    @ Zygote ~/.julia/packages/Zygote/PD12J/src/compiler/interface2.jl:0
  [6] macro expansion
    @ ./namedtuple.jl:342 [inlined]
  [7] Pullback
    @ ./namedtuple.jl:341 [inlined]
  [8] (::typeof((structdiff)))(Δ::Tuple{})
    @ Zygote ~/.julia/packages/Zygote/PD12J/src/compiler/interface2.jl:0
  [9] Pullback
    @ ~/.julia/packages/SciMLBase/jEZRD/src/problems/sde_problems.jl:97 [inlined]
 [10] (::typeof((Type##kw)))(Δ::NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :kwargs, :noise_rate_prototype, :seed), Tuple{Nothing, Nothing, Vector{Float32}, Nothing, Nothing, Nothing, NamedTuple{(:data, :itr), Tuple{Tuple{}, Nothing}}, Nothing, Nothing}})
    @ Zygote ~/.julia/packages/Zygote/PD12J/src/compiler/interface2.jl:0
 [11] Pullback
    @ ~/.julia/packages/SciMLBase/jEZRD/src/remake.jl:32 [inlined]
 [12] (::typeof((#remake#507)))(Δ::NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :kwargs, :noise_rate_prototype, :seed), Tuple{Nothing, Nothing, Vector{Float32}, Nothing, Nothing, Nothing, NamedTuple{(:data, :itr), Tuple{Tuple{}, Nothing}}, Nothing, Nothing}})
    @ Zygote ~/.julia/packages/Zygote/PD12J/src/compiler/interface2.jl:0
 [13] Pullback
    @ ~/.julia/packages/SciMLBase/jEZRD/src/remake.jl:28 [inlined]
 [14] (::typeof((remake##kw)))(Δ::NamedTuple{(:f, :g, :u0, :tspan, :p, :noise, :kwargs, :noise_rate_prototype, :seed), Tuple{Nothing, Nothing, Vector{Float32}, Nothing, Nothing, Nothing, NamedTuple{(:data, :itr), Tuple{Tuple{}, Nothing}}, Nothing, Nothing}})
    @ Zygote ~/.julia/packages/Zygote/PD12J/src/compiler/interface2.jl:0
 [15] Pullback
    @ ~/Documents/doctorado/issues/goku_net_on_SDEs/MWE2.jl:23 [inlined]
 [16] (::typeof((diffeq)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/PD12J/src/compiler/interface2.jl:0
 [17] Pullback
    @ ~/Documents/doctorado/issues/goku_net_on_SDEs/MWE2.jl:32 [inlined]
 [18] (::typeof((#9)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/PD12J/src/compiler/interface2.jl:0
 [19] (::Zygote.var"#60#61"{typeof((#9))})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/PD12J/src/compiler/interface.jl:45
 [20] gradient(f::Function, args::Vector{Vector{Float32}})
    @ Zygote ~/.julia/packages/Zygote/PD12J/src/compiler/interface.jl:97
 [21] top-level scope
    @ ~/Documents/doctorado/issues/goku_net_on_SDEs/MWE2.jl:32
Project and version info
(goku_net_on_SDEs) pkg> st
Status `~/Documents/doctorado/issues/goku_net_on_SDEs/Project.toml`
  [aae7a2af] DiffEqFlux v1.52.0
  [1ed8b502] SciMLSensitivity v7.9.0
  [789caeaf] StochasticDiffEq v6.53.0
  [e88e6eb3] Zygote v0.6.47
julia> versioninfo()
Julia Version 1.8.1
Commit afb6c60d69a (2022-09-06 15:09 UTC)
Platform Info:
  OS: macOS (arm64-apple-darwin21.5.0)
  CPU: 8 × Apple M1 Pro
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-13.0.1 (ORCJIT, apple-m1)
  Threads: 1 on 6 virtual cores
Environment:
  JULIA_EDITOR = code
  JULIA_NUM_THREADS = 
@gabrevaya
Author

In the second case (remaking the problem with a new tspan), if we don't use an EnsembleProblem, we get another error, which might be the cause of Error 2.

MWE 3
function diffeq(prob, p, t)
  prob = remake(prob; tspan = (t[1], t[end]))
  sol = solve(prob, SOSRI(); saveat = t)
  return sum(sol)
end

p0 = Float32[1.25, 1.5, 1.75, 2]
gradient(p -> diffeq(prob, p, 0:0.01:1), p0)
Error 3
ERROR: Gradient ChainRulesCore.ZeroTangent() should be a tuple
Stacktrace:
[1] error(s::String)
  @ Base ./error.jl:35
[2] gradtuple1(x::ChainRulesCore.ZeroTangent)
  @ ZygoteRules ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:24
[3] (::Zygote.var"#1794#back#155"{typeof(identity)})(Δ::ChainRulesCore.ZeroTangent)
  @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
[4] Pullback
  @ ~/Documents/doctorado/issues/goku_net_on_SDEs/MWE2.jl:41 [inlined]
[5] (::typeof((diffeq)))(Δ::Float32)
  @ Zygote ~/.julia/packages/Zygote/PD12J/src/compiler/interface2.jl:0
[6] Pullback
  @ ~/Documents/doctorado/issues/goku_net_on_SDEs/MWE2.jl:47 [inlined]
[7] (::typeof((#5)))(Δ::Float32)
  @ Zygote ~/.julia/packages/Zygote/PD12J/src/compiler/interface2.jl:0
[8] (::Zygote.var"#60#61"{typeof((#5))})(Δ::Float32)
  @ Zygote ~/.julia/packages/Zygote/PD12J/src/compiler/interface.jl:45
[9] gradient(f::Function, args::Vector{Float32})
  @ Zygote ~/.julia/packages/Zygote/PD12J/src/compiler/interface.jl:97
[10] top-level scope
  @ ~/Documents/doctorado/issues/goku_net_on_SDEs/MWE2.jl:47

Although they are similar, I realize that my first comment covers two different issues. Let me know if you would prefer that I open a separate GitHub issue or keep both here.

@avik-pal
Member

I have seen this error before but couldn't fix it properly. I worked around it with https://github.com/SciML/DeepEquilibriumNetworks.jl/blob/ba6d66fcbdbd8bb2d39a5a27a3e4fced127aa584/experiments/src/DEQExperiments.jl#L16. But this is in no way the correct solution. It looks like Zygote is generating an incorrect backward pass.

@ChrisRackauckas
Member

Roll Zygote back to v0.6.43. There was a change to the accum derivative that broke a lot of code.

FluxML/Zygote.jl#1304

I think this might be the same issue
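
For reference, the rollback suggested above could be applied with the standard Pkg API roughly as follows. This is a minimal sketch, assuming the active project is the one the MWEs run in; compat bounds of the other installed packages may prevent the resolver from selecting this version.

```julia
using Pkg

# Install exactly the suggested release, then pin it so that
# later `Pkg.update()` calls leave it alone.
Pkg.add(name = "Zygote", version = "0.6.43")
Pkg.pin("Zygote")

# Show which Zygote version the active environment resolved to.
Pkg.status("Zygote")
```

`Pkg.free("Zygote")` removes the pin again once a fixed release is available.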

@gabrevaya
Author

gabrevaya commented Sep 19, 2022

Hi! Thanks for your responses! Unfortunately, neither the workaround suggested by @avik-pal nor pinning Zygote to v0.6.43 worked for any of the MWEs. @ChrisRackauckas, note that in contrast to the issue that you mentioned, if I don't load DiffEqFlux the code works without any error, which is very strange (at least to me).

@ChrisRackauckas
Member

The original MWE here is solved. The other two fail because the derivative w.r.t. t is not defined for SDEs, but it's also not analytically defined, so that's not solvable. Now, as to why Zygote crashes: it doesn't have proper activity analysis to know not to differentiate that term until after it has pulled back more. That is an issue with Zygote upstream, though. Other AD systems have solved that. So I'll close this, but feel free to ask any more questions.
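
Since the time span is not something being differentiated here, one possible way to keep Zygote from ever tracing it is to do the tspan `remake` before calling `gradient`. The following is only a sketch under that assumption, not a confirmed fix for the versions listed above. `ensemble_loss` is a hypothetical name: it is `diffeq` from MWE 2 with the tspan `remake` removed, and the `Float32` conversion of the end points is an addition to match the element type of the problem.

```julia
using Zygote
using StochasticDiffEq
using SciMLSensitivity

# Same Lotka-Volterra drift and constant diffusion as in the MWEs above.
function f!(du, u, p, t)
    x, y = u
    α, β, δ, γ = p
    du[1] = α*x - β*x*y
    du[2] = -δ*y + γ*x*y
end

function g!(du, u, p, t)
    du .= 0.02f0
end

u₀ = Float32[1.0, 1.0]
p = Float32[1.25, 1.5, 1.75, 2]
prob = SDEProblem(f!, g!, u₀, (0.f0, 1.f0), p)

# Hypothetical helper: `diffeq` from MWE 2 without the tspan remake.
function ensemble_loss(prob, ps, t)
    prob_func(prob, i, repeat) = remake(prob, p = ps[i])
    ens_prob = EnsembleProblem(prob, prob_func = prob_func)
    sol = solve(ens_prob, SOSRI(), EnsembleSerial(); trajectories = length(ps), saveat = t)
    return sum(sol)
end

t = 0:0.01:1

# Change the time span once, outside of anything Zygote differentiates.
prob_t = remake(prob; tspan = (Float32(t[1]), Float32(t[end])))

p0 = fill(Float32[1.25, 1.5, 1.75, 2], 2)
grads = gradient(ps -> ensemble_loss(prob_t, ps, t), p0)
```

With the `remake` hoisted out, the differentiated function has the same shape as the first MWE with two parameter sets, which was reported above to run without errors.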
