Skip to content

Commit

Permalink
Fix scalar indexing for diagonal-noise SDEs
Browse files Browse the repository at this point in the history
  • Loading branch information
frankschae committed Aug 3, 2023
1 parent 0897da4 commit 1aee06b
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 27 deletions.
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,3 @@ For information on using the package,
[see the stable documentation](https://docs.sciml.ai/SciMLSensitivity/stable/). Use the
[in-development documentation](https://docs.sciml.ai/SciMLSensitivity/dev/) for the version of
the documentation, which contains the unreleased features.

40 changes: 14 additions & 26 deletions src/derivative_wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -878,36 +878,24 @@ function _jacNoise!(λ, y, p, t, S::TS, isnoise::ZygoteVJP, dgrad, dλ,

if StochasticDiffEq.is_diagonal_noise(prob)
if inplace_sensitivity(S)
for i in 1:m
_dy, back = Zygote.pullback(y, p) do u, p
out_ = Zygote.Buffer(similar(u))
f(out_, u, p, t)
copy(out_[i])
end
tmp1, tmp2 = back(λ[i]) #issue: tmp2 = zeros(p)
if dgrad !== nothing
if tmp2 !== nothing
!isempty(dgrad) && (dgrad[:, i] .= vec(tmp2))
end
end
!== nothing && (dλ[:, i] .= vec(tmp1))
dy !== nothing && (dy[i] = _dy)
_dy, back = Zygote.pullback(y, p) do u, p
out_ = Zygote.Buffer(similar(u))
f(out_, u, p, t)
copy(out_)
end
else
for i in 1:m
_dy, back = Zygote.pullback(y, p) do u, p
f(u, p, t)[i]
end
tmp1, tmp2 = back(λ[i])
if dgrad !== nothing
if tmp2 !== nothing
!isempty(dgrad) && (dgrad[:, i] .= vec(tmp2))
end
end
!== nothing && (dλ[:, i] .= vec(tmp1))
dy !== nothing && (dy[i] = _dy)
_dy, back = Zygote.pullback(y, p) do u, p
f(u, p, t)
end
end
out = [back(x) for x in eachcol(Diagonal(λ))]
if dgrad !== nothing
if tmp2 !== nothing
!isempty(dgrad) && (dgrad .= vec(stack(last.(out))))
end
end
!== nothing && (dλ .= vec(stack(first.(out))))
dy !== nothing && (dy .= _dy)
else
if inplace_sensitivity(S)
for i in 1:m
Expand Down

0 comments on commit 1aee06b

Please sign in to comment.