Continuous-adjoint methods for diagonal-noise SDEs scale in the square of number of dimensions #854
MWE:
On CPU:
On GPU:
|
I don't immediately see how we can rewrite the block https://github.com/SciML/SciMLSensitivity.jl/blob/1997fb1a2c288f3da37f61c7b0894eb4e42c5cd6/src/derivative_wrappers.jl#L897C13-L909 without scalar indexing... Probably worth looking into |
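For context, the linked block roughly amounts to the following per-channel loop (a paraphrased sketch reconstructed from the MWE later in this thread, with toy values of my own — not the actual source): one reverse pass per noise channel, each seeded with the scalar λ[i], which is where the scalar indexing comes from.

```julia
using Zygote

# Hypothetical stand-ins; names mirror the thread's MWE, values are made up.
f(u, p, t) = [p[1] * u[1], -p[2] * u[2]]   # toy diagonal diffusion
p = [1.5, 3.0]
m = 2
y, λ, t = [0.4, 0.7], [2.0, 4.0], 0.3

dgrad = zeros(length(p), m)
dλ = zeros(m, m)
for i in 1:m                     # one reverse pass per noise channel
    _, back = Zygote.pullback(y, p) do u, p
        f(u, p, t)[i]            # the scalar indexing that breaks on GPU
    end
    tmp1, tmp2 = back(λ[i])
    dλ[:, i] .= tmp1
    dgrad[:, i] .= tmp2
end
```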
Wait, why is the indexing required there? Why not just compute all the derivatives together, like this?

```julia
_dy, back = Zygote.pullback(y, p) do u, p
    f(u, p, t)
end
tmp1, tmp2 = back(λ)
if dgrad !== nothing
    if tmp2 !== nothing
        !isempty(dgrad) && (vec(dgrad) .= vec(tmp2))
    end
end
dλ !== nothing && (vec(dλ) .= vec(tmp1))
dy !== nothing && (dy = _dy)
```
|
Because if the primal SDE has diagonal noise, the adjoint SDE has commutative noise [see (14) in App. 9.5 of https://arxiv.org/pdf/2001.01328.pdf]. |
I get that, but I don't see why that particular piece of code needs the indexing. Isn't that exactly the same result as what I posted? |
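For what it's worth, here is a minimal sketch of why the two computations differ (toy function and values of my own, not from the issue): seeding one pullback with the full λ contracts over all output components at once, so it returns the sum of the per-component gradient columns rather than the columns themselves.

```julia
using Zygote

# Toy 2-output function standing in for the diffusion (hypothetical values).
f(u, p) = [p[1] * u[1], p[2] * u[1] * u[2]]

u = [0.5, 0.25]
p = [1.5, 3.0]
λ = [2.0, 4.0]

# One pullback per output component, as in the indexed loop:
cols = map(1:2) do i
    _, back_i = Zygote.pullback((u, p) -> f(u, p)[i], u, p)
    back_i(λ[i])[2]          # gradient w.r.t. p from component i alone
end

# A single pullback seeded with the whole λ:
_, back = Zygote.pullback(f, u, p)
tmp2 = back(λ)[2]

# The single vjp collapses the channels: it equals the SUM of the columns,
# so the per-channel structure needed for the diagonal-noise adjoint is lost.
tmp2 ≈ cols[1] + cols[2]
```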
Maybe there is a trivial solution..

```julia
using Lux, Zygote, DifferentialEquations, ComponentArrays, Random, SciMLSensitivity

p = [1.5, 1.0, 3.0, 1.0]
m = 2

function f(u, p, t)
    dx = p[1] * u[1] - p[2] * u[1] * u[2] * t
    dy = -p[3] * u[2] + t * p[4] * u[1] * u[2]
    [dx, dy]
end

Random.seed!(434988934)
y = rand(m)
λ = rand(m)
t = rand()
dW = rand(m)

dgrad = zeros(length(p), m)
dλ = zeros(m, m)
dy = zeros(m)

for i in 1:m
    _dy, back = Zygote.pullback(y, p) do u, p
        f(u, p, t)[i]
    end
    tmp1, tmp2 = back(λ[i])
    dgrad[:, i] .= vec(tmp2)
    dλ[:, i] .= vec(tmp1)
    dy[i] = _dy
end

dy2, back = Zygote.pullback(y, p) do u, p
    f(u, p, t)
end
tmp1, tmp2 = back(λ)
```

```julia
julia> dgrad
4×2 Matrix{Float64}:
  0.0813625   0.0
 -0.0261179   0.0
  0.0        -0.409558
  0.0         0.168718
```

vs.

```julia
julia> tmp2
4-element Vector{Float64}:
  0.08136250711023468
 -0.02611788331081627
 -0.40955766401332616
  0.16871806153418192
```

How to multiply tmp2 with dW such that dgrad * dW == tmp2 (*) dW? And:

```julia
julia> dλ
2×2 Matrix{Float64}:
  0.107301    0.188725
 -0.0374919  -1.52155

julia> tmp1
2-element Vector{Float64}:
  0.2960254183310152
 -1.5590457490405374
```

How to multiply tmp1 with dW such that dλ * dW == tmp1 (*) dW? |
Zygote has a bug here that's easy to work around:

```julia
using Lux, Zygote, DifferentialEquations, ComponentArrays, Random, SciMLSensitivity, LinearAlgebra

p = [1.5, 1.0, 3.0, 1.0]
m = 2

function f(u, p, t)
    dx = p[1] * u[1] - p[2] * u[1] * u[2] * t
    dy = -p[3] * u[2] + t * p[4] * u[1] * u[2]
    [dx, dy]
end

Random.seed!(434988934)
y = rand(m)
λ = rand(m)
t = rand()
dW = rand(m)

dgrad = zeros(length(p), m)
dλ = zeros(m, m)
dy = zeros(m)

for i in 1:m
    _dy, back = Zygote.pullback(y, p) do u, p
        f(u, p, t)[i]
    end
    tmp1, tmp2 = back(λ[i])
    dgrad[:, i] .= vec(tmp2)
    dλ[:, i] .= vec(tmp1)
    dy[i] = _dy
end

dy2, back = Zygote.pullback(y, p) do u, p
    f(u, p, t)
end
out = [back(x) for x in eachcol(Diagonal(λ))]
dgrad == stack(last.(out))  # true
dλ == stack(first.(out))    # true
```
|
whoaaaa nice!!! :0 |
Executing the MWE on SciMLSensitivity#master now yields this:
There is some sort of If I
|
Same error on CPU:
maybe something wrong with the MWE? |
I'm pretty sure I can solve this, it seems like |
seems like this workaround doesn't work on GPUs, as |
What is the bug exactly? Can we fix it? This workaround is O(n^2) |
< deleted because I had an incorrect theory here - see below > |
hmmm BUT we need to perturb this correctly with the noise. I can't figure out how torchsde / diffrax are doing this right now... |
In
Python:
So Python has this more generic |
update: there's a trivial solution!! don't compute

My comments about paramnoisemixing are not important; noisemixing has nothing to do with this, it just works. But the solver implementation hurdle is still relevant.

this is still nontrivial to implement in Julia because of the solver design issue mentioned above |
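Reading between the lines, the "trivial solution" is presumably the linearity of the vjp: the adjoint only ever needs the contractions dgrad * dW and dλ * dW, and by linearity those equal a single pullback seeded with λ .* dW. Here is a hedged sketch on the thread's toy f (my reconstruction, not the actual patch):

```julia
using Zygote, LinearAlgebra, Random

Random.seed!(434988934)

p = [1.5, 1.0, 3.0, 1.0]
m = 2

function f(u, p, t)
    dx = p[1] * u[1] - p[2] * u[1] * u[2] * t
    dy = -p[3] * u[2] + t * p[4] * u[1] * u[2]
    [dx, dy]
end

y, λ, t, dW = rand(m), rand(m), rand(), rand(m)

_, back = Zygote.pullback((u, p) -> f(u, p, t), y, p)

# O(m^2): one reverse pass per noise channel, then contract with dW.
out   = [back(col) for col in eachcol(Diagonal(λ))]
dλ    = stack(first.(out))
dgrad = stack(last.(out))

# O(m): a single reverse pass seeded with λ .* dW gives the same contractions,
# because the pullback is linear in its seed.
g_y, g_p = back(λ .* dW)

dλ * dW ≈ g_y
dgrad * dW ≈ g_p
```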
this of course gives us quite a performance boost:
|
Hi,
this line does scalar indexing inside a pullback:
SciMLSensitivity.jl/src/derivative_wrappers.jl, line 899 (commit 1997fb1)
This means you can't differentiate on a GPU in this case, as scalar indexing is not allowed there.
Excerpt from the error: