Continuous-adjoint methods for diagonal-noise SDEs scale in the square of number of dimensions #854

Open
linusheck opened this issue Jul 28, 2023 · 18 comments

Comments

@linusheck
Contributor

linusheck commented Jul 28, 2023

Hi,
this line is scalar indexing in a pullback:


This means you can't differentiate on a GPU in this case, because scalar indexing is not allowed.
Excerpt from the error:

Scalar indexing is disallowed.
Invocation of getindex resulted in scalar indexing of a GPU array.
This is typically caused by calling an iterating implementation of a method.
Such implementations *do not* execute on the GPU, but very slowly on the CPU,
and therefore are only permitted from the REPL for prototyping purposes.
If you did intend to index this array, annotate the caller with @allowscalar.

error(::String)@error.jl:35
assertscalar(::String)@GPUArraysCore.jl:103
getindex@indexing.jl:9[inlined]
adjoint@array.jl:44[inlined]
_pullback(::Zygote.Context{false}, ::typeof(getindex), ::CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::Int64)@adjoint.jl:66
_pullback@derivative_wrappers.jl:899[inlined]

@linusheck linusheck changed the title Scalar indexing in pullback for diagonal noise Continuous adjoint methods for diagonal-noise SDEs not compatible with GPUs Aug 1, 2023
@linusheck
Contributor Author

linusheck commented Aug 2, 2023

MWE:

using Lux, Zygote, DifferentialEquations, ComponentArrays, Random, CUDA, SciMLSensitivity

rng = Xoshiro()

drift_net = Dense(2 => 2)
diffusion_net = Dense(2 => 2)

ps_drift_, st_drift = Lux.setup(rng, drift_net)

ps_diffusion_, st_diffusion = Lux.setup(rng, diffusion_net)

ps_ = ComponentArray((ps_drift=ps_drift_,ps_diffusion=ps_diffusion_)) |> Lux.gpu

function drift(u, ps, t)
    drift_net(u, ps.ps_drift, st_drift)[1]
end
function diffusion(u, ps, t)
    diffusion_net(u, ps.ps_diffusion, st_diffusion)[1]
end

u0 = [1f0, 1f0] |> Lux.gpu

tspan = (0f0, 1f0)
datasize = 10

solver = EulerHeun()
sensealg = InterpolatingAdjoint(autojacvec=ZygoteVJP())

function loss(ps)
    problem = SDEProblem(drift, diffusion, u0, tspan, ps)
    solution = solve(problem, solver; sensealg=sensealg, saveat=collect(range(tspan[1], tspan[end], datasize)), dt=(tspan[end] / datasize))
    return sum(vec(solution |> Lux.gpu))
end

println(loss(ps_))
println(Zygote.gradient(ps -> loss(ps), ps_))

On CPU:

┌ Info: The GPU function is being called but the GPU is not accessible.
│ Defaulting back to the CPU. (No action is required if you want
└ to run on the CPU).
-0.445899
((ps_drift = (weight = Float32[11.908233 11.131403; 1.8911543 1.8107269], bias = Float32[10.83848; 1.7313616;;]), ps_diffusion = (weight = Float32[-22.582191 -22.426481; -0.6261419 -1.0105969], bias = Float32[-21.103664; -0.61384314;;])),)

On GPU:

41.25828
ERROR: LoadError: Scalar indexing is disallowed.
Invocation of getindex resulted in scalar indexing of a GPU array.
This is typically caused by calling an iterating implementation of a method.
Such implementations *do not* execute on the GPU, but very slowly on the CPU,
and therefore are only permitted from the REPL for prototyping purposes.
If you did intend to index this array, annotate the caller with @allowscalar.
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] assertscalar(op::String)
    @ GPUArraysCore ~/.julia/packages/GPUArraysCore/HaQcr/src/GPUArraysCore.jl:103
  [3] getindex
    @ ~/.julia/packages/GPUArrays/TnEpb/src/host/indexing.jl:9 [inlined]
  [4] adjoint
    @ ~/.julia/packages/Zygote/SuKWp/src/lib/array.jl:44 [inlined]
  [5] _pullback(__context__::Zygote.Context{false}, 568::typeof(getindex), x::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, inds::Int64)
    @ Zygote ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:66
  [6] _pullback
    @ ~/.julia/packages/SciMLSensitivity/E8w3Z/src/derivative_wrappers.jl:899 [inlined]
...

@ChrisRackauckas
Member

Wait why is the indexing required there? Why not just compute all derivatives together, i.e.:

                _dy, back = Zygote.pullback(y, p) do u, p
                    f(u, p, t)
                end
                tmp1, tmp2 = back(λ)
                if dgrad !== nothing
                    if tmp2 !== nothing
                        !isempty(dgrad) && (vec(dgrad) .= vec(tmp2))
                    end
                end
                dλ !== nothing && (vec(dλ) .= vec(tmp1))
                dy !== nothing && (dy = _dy)

?

@frankschae
Member

because if the primal noise process has diagonal noise, the adjoint has commutative noise [see (14) in App. 9.5 of https://arxiv.org/pdf/2001.01328.pdf]

@ChrisRackauckas
Member

I get that, but I don't see why the piece of code right there needs to be indexed. That's exactly the same result as what I posted?

@frankschae
Member

Maybe there is a trivial solution..

using Lux, Zygote, DifferentialEquations, ComponentArrays, Random, SciMLSensitivity

p = [1.5, 1.0, 3.0, 1.0]
m = 2

function f(u, p, t)
    dx = p[1] * u[1] - p[2] * u[1] * u[2] * t
    dy = -p[3] * u[2] + t * p[4] * u[1] * u[2]
    [dx, dy]
end
Random.seed!(434988934)
y = rand(m)
λ = rand(m)
t = rand()

dW = rand(m)


dgrad = zeros(length(p),m)
dλ = zeros(m,m)
dy = zeros(m)

for i in 1:m
    _dy, back = Zygote.pullback(y, p) do u, p
        f(u, p, t)[i]
    end
    tmp1, tmp2 = back(λ[i])
    dgrad[:, i] .= vec(tmp2)
    dλ[:, i] .= vec(tmp1)
    dy[i] = _dy
end

dy2, back = Zygote.pullback(y, p) do u, p
    f(u, p, t)
end
tmp1, tmp2 = back(λ)
julia> dgrad
4×2 Matrix{Float64}:
  0.0813625   0.0
 -0.0261179   0.0
  0.0        -0.409558
  0.0         0.168718

vs.

tmp2
4-element Vector{Float64}:
  0.08136250711023468
 -0.02611788331081627
 -0.40955766401332616
  0.16871806153418192
  # how to multiply tmp2 with dW such that dgrad * dW  ==  tmp2 (*) dW?

and

julia> dλ
2×2 Matrix{Float64}:
  0.107301    0.188725
 -0.0374919  -1.52155

julia> tmp1
2-element Vector{Float64}:
  0.2960254183310152
 -1.5590457490405374
 
  # how to multiply tmp1 with dW such that dλ * dW  ==  tmp1 (*) dW?

@ChrisRackauckas
Member

Zygote has a bug here that's easy to work around:

using Lux, Zygote, DifferentialEquations, ComponentArrays, Random, SciMLSensitivity, LinearAlgebra

p = [1.5, 1.0, 3.0, 1.0]
m = 2

function f(u, p, t)
    dx = p[1] * u[1] - p[2] * u[1] * u[2] * t
    dy = -p[3] * u[2] + t * p[4] * u[1] * u[2]
    [dx, dy]
end
Random.seed!(434988934)
y = rand(m)
λ = rand(m)
t = rand()

dW = rand(m)


dgrad = zeros(length(p),m)
dλ = zeros(m,m)
dy = zeros(m)

for i in 1:m
    _dy, back = Zygote.pullback(y, p) do u, p
        f(u, p, t)[i]
    end
    tmp1, tmp2 = back(λ[i])
    dgrad[:, i] .= vec(tmp2)
    dλ[:, i] .= vec(tmp1)
    dy[i] = _dy
end

dy2, back = Zygote.pullback(y, p) do u, p
    f(u, p, t)
end
out = [back(x) for x in eachcol(Diagonal(λ))]

dgrad == stack(last.(out)) # true
dλ == stack(first.(out)) # true

@linusheck
Contributor Author

whoaaaa nice!!! :0

@linusheck
Contributor Author

Executing the MWE on SciMLSensitivity#master now yields this:

julia> println(Zygote.gradient(ps -> loss(ps), ps_))
ERROR: Scalar indexing is disallowed.
Invocation of getindex resulted in scalar indexing of a GPU array.
This is typically caused by calling an iterating implementation of a method.
Such implementations *do not* execute on the GPU, but very slowly on the CPU,
and therefore are only permitted from the REPL for prototyping purposes.
If you did intend to index this array, annotate the caller with @allowscalar.
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] assertscalar(op::String)
    @ GPUArraysCore ~/.julia/packages/GPUArraysCore/uOYfN/src/GPUArraysCore.jl:103
  [3] getindex
    @ ~/.julia/packages/GPUArrays/5XhED/src/host/indexing.jl:9 [inlined]
  [4] generic_matvecmul!(C::Vector{Float32}, tA::Char, A::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, B::SubArray{Float32, 1, LinearAlgebra.Diagonal{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, false}, _add::LinearAlgebra.MulAddMul{true, true, Bool, Bool})
    @ LinearAlgebra ~/julia-1.9.0/share/julia/stdlib/v1.9/LinearAlgebra/src/matmul.jl:791
  [5] mul!
    @ ~/julia-1.9.0/share/julia/stdlib/v1.9/LinearAlgebra/src/matmul.jl:115 [inlined]
  [6] mul!
    @ ~/julia-1.9.0/share/julia/stdlib/v1.9/LinearAlgebra/src/matmul.jl:276 [inlined]
  [7] *
    @ ~/julia-1.9.0/share/julia/stdlib/v1.9/LinearAlgebra/src/matmul.jl:105 [inlined]
  [8] #1480
    @ ~/.julia/packages/ChainRules/9sNmB/src/rulesets/Base/arraymath.jl:60 [inlined]
  [9] unthunk
    @ ~/.julia/packages/ChainRulesCore/0t04l/src/tangent_types/thunks.jl:204 [inlined]
 [10] unthunk
    @ ~/.julia/packages/ChainRulesCore/0t04l/src/tangent_types/thunks.jl:237 [inlined]
 [11] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/JeHtr/src/compiler/chainrules.jl:110 [inlined]
...
[17] (::Zygote.Pullback{Tuple{typeof(diffusion), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(ps_drift = ViewAxis(1:6, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))), ps_diffusion = ViewAxis(7:12, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, Float32}, Tuple{Zygote.ZBack{ComponentArrays.var"#getproperty_adjoint#85"{ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(ps_drift = ViewAxis(1:6, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))), ps_diffusion = ViewAxis(7:12, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, Symbol}}, Zygote.var"#2033#back#209"{Zygote.var"#back#207"{2, 1, Zygote.Context{false}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, Zygote.var"#1990#back#190"{Zygote.var"#186#189"{Zygote.Context{false}, GlobalRef, NamedTuple{(), Tuple{}}}}, Zygote.Pullback{Tuple{Dense{true, typeof(identity), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))}}}, NamedTuple{(), Tuple{}}}, Tuple{Zygote.ZBack{ChainRules.var"#times_pullback#1481"{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, Zygote.ZBack{Lux.var"#vec_pullback#193"{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, Zygote.var"#3754#back#1177"{Zygote.var"#1171#1175"{Tuple{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, 
Zygote.ZBack{ComponentArrays.var"#getproperty_adjoint#85"{ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))}}}, Symbol}}, Zygote.var"#2184#back#299"{Zygote.var"#back#298"{:activation, Zygote.Context{false}, Dense{true, typeof(identity), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, typeof(identity)}}, Zygote.ZBack{ComponentArrays.var"#getproperty_adjoint#85"{ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))}}}, Symbol}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{}}, Zygote.var"#2017#back#200"{typeof(identity)}, Zygote.Pullback{Tuple{typeof(Lux.__apply_activation), typeof(identity), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{}}}}, Zygote.var"#1990#back#190"{Zygote.var"#186#189"{Zygote.Context{false}, GlobalRef, Dense{true, typeof(identity), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}}}}})(Δ::SubArray{Float32, 1, LinearAlgebra.Diagonal{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, false})
    @ Zygote ~/.julia/packages/Zygote/JeHtr/src/compiler/interface2.jl:0
 [18] #287
    @ ~/.julia/packages/Zygote/JeHtr/src/lib/lib.jl:206 [inlined]
 [19] #2173#back
    @ ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:71 [inlined]
 [20] Pullback
    @ ~/.julia/packages/SciMLBase/kTUaf/src/scimlfunctions.jl:2127 [inlined]
 [21] Pullback
    @ ~/.julia/packages/SciMLSensitivity/bCIak/src/derivative_wrappers.jl:911 [inlined]

There is some sort of C::Vector{Float32} involved which shouldn't be there I think.

If I set CUDA.allowscalar(true), it yields

julia> println(Zygote.gradient(ps -> loss(ps), ps_))
ERROR: BoundsError: attempt to access 12×1 ComponentMatrix{Float32, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(ps_drift = ViewAxis(1:6, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))), ps_diffusion = ViewAxis(7:12, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}, FlatAxis}} with indices 1:1:12×1:1:1 at index [13:24]
Stacktrace:
  [1] copyto!(dest::ComponentMatrix{Float32, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(ps_drift = ViewAxis(1:6, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))), ps_diffusion = ViewAxis(7:12, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}, FlatAxis}}, dstart::Int64, src::ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(ps_drift = ViewAxis(1:6, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))), ps_diffusion = ViewAxis(7:12, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, sstart::Int64, n::Int64)
    @ Base ./abstractarray.jl:1137
  [2] copyto!
    @ ./abstractarray.jl:1121 [inlined]
  [3] _typed_stack(::Colon, ::Type{Float32}, ::Type{ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(ps_drift = ViewAxis(1:6, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))), ps_diffusion = ViewAxis(7:12, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}}, A::Vector{ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(ps_drift = ViewAxis(1:6, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))), ps_diffusion = ViewAxis(7:12, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}}, Aax::Tuple{Base.OneTo{Int64}})
    @ Base ./abstractarray.jl:2803
  [4] _typed_stack
    @ ./abstractarray.jl:2793 [inlined]
  [5] _stack
    @ ./abstractarray.jl:2783 [inlined]
  [6] _stack
    @ ./abstractarray.jl:2775 [inlined]
  [7] #stack#178
    @ ./abstractarray.jl:2743 [inlined]
  [8] stack(iter::Vector{ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(ps_drift = ViewAxis(1:6, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))), ps_diffusion = ViewAxis(7:12, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}})
    @ Base ./abstractarray.jl:2743
  [9] _jacNoise!(λ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, y::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, p::ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(ps_drift = ViewAxis(1:6, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))), ps_diffusion = ViewAxis(7:12, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, t::Float32, S::SciMLSensitivity.ODEInterpolatingAdjointSensitivityFunction{SciMLSensitivity.AdjointDiffCache{Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Base.OneTo{Int64}, UnitRange{Int64}, LinearAlgebra.UniformScaling{Bool}}, InterpolatingAdjoint{0, true, Val{:central}, ZygoteVJP}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, RODESolution{Float32, 2, Vector{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Nothing, Nothing, Vector{Float32}, NoiseProcess{Float32, 2, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Nothing, Nothing, typeof(DiffEqNoiseProcess.WHITE_NOISE_DIST), typeof(DiffEqNoiseProcess.WHITE_NOISE_BRIDGE), false, ResettableStacks.ResettableStack{Tuple{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Nothing}, false}, ResettableStacks.ResettableStack{Tuple{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Nothing}, false}, RSWM{Float64}, Nothing, RandomNumbers.Xorshifts.Xoroshiro128Plus}, SDEProblem{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Float32, Float32}, false, ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(ps_drift = ViewAxis(1:6, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))), ps_diffusion = ViewAxis(7:12, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, Nothing, SDEFunction{false, 
SciMLBase.FullSpecialize, typeof(drift), typeof(diffusion), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, typeof(diffusion), Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Nothing}, EulerHeun, StochasticDiffEq.LinearInterpolationData{Vector{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Vector{Float32}}, DiffEqBase.Stats, Nothing}, SciMLSensitivity.CheckpointSolution{RODESolution{Float32, 2, Vector{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Nothing, Nothing, Vector{Float32}, NoiseWrapper{Float32, 2, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Nothing, NoiseProcess{Float32, 2, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Nothing, Nothing, typeof(DiffEqNoiseProcess.WHITE_NOISE_DIST), typeof(DiffEqNoiseProcess.WHITE_NOISE_BRIDGE), false, ResettableStacks.ResettableStack{Tuple{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Nothing}, false}, ResettableStacks.ResettableStack{Tuple{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Nothing}, false}, RSWM{Float64}, Nothing, RandomNumbers.Xorshifts.Xoroshiro128Plus}, Nothing, false}, SDEProblem{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Float32, Float32}, false, ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(ps_drift = ViewAxis(1:6, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))), ps_diffusion = ViewAxis(7:12, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, NoiseWrapper{Float32, 2, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Nothing, NoiseProcess{Float32, 2, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Nothing, Nothing, typeof(DiffEqNoiseProcess.WHITE_NOISE_DIST), typeof(DiffEqNoiseProcess.WHITE_NOISE_BRIDGE), 
false, ResettableStacks.ResettableStack{Tuple{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Nothing}, false}, ResettableStacks.ResettableStack{Tuple{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Nothing}, false}, RSWM{Float64}, Nothing, RandomNumbers.Xorshifts.Xoroshiro128Plus}, Nothing, false}, SDEFunction{false, SciMLBase.FullSpecialize, typeof(drift), typeof(diffusion), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, typeof(diffusion), Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Nothing}, EulerHeun, StochasticDiffEq.LinearInterpolationData{Vector{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Vector{Float32}}, DiffEqBase.Stats, Nothing}, Vector{Tuple{Float32, Float32}}, NamedTuple{(:reltol, :abstol), Tuple{Float64, Float64}}, Nothing}, SDEProblem{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Float32, Float32}, false, ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(ps_drift = ViewAxis(1:6, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))), ps_diffusion = ViewAxis(7:12, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, Nothing, SDEFunction{false, SciMLBase.FullSpecialize, typeof(drift), typeof(diffusion), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, typeof(diffusion), Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Nothing}, ODEFunction{false, true, typeof(diffusion), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, 
Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}}, isnoise::ZygoteVJP, dgrad::SubArray{Float32, 2, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Tuple{UnitRange{Int64}, UnitRange{Int64}}, false}, dλ::Nothing, dy::Nothing)
    @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/bCIak/src/derivative_wrappers.jl:918

@linusheck
Contributor Author

linusheck commented Aug 10, 2023

Same error on CPU:

julia> println(Zygote.gradient(ps -> loss(ps), ps_))
ERROR: BoundsError: attempt to access 12×1 ComponentMatrix{Float32, Matrix{Float32}, Tuple{Axis{(ps_drift = ViewAxis(1:6, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))), ps_diffusion = ViewAxis(7:12, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}, FlatAxis}} with indices 1:1:12×1:1:1 at index [13:24]
Stacktrace:
  [1] copyto!(dest::ComponentMatrix{Float32, Matrix{Float32}, Tuple{Axis{(ps_drift = ViewAxis(1:6, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))), ps_diffusion = ViewAxis(7:12, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}, FlatAxis}}, dstart::Int64, src::ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(ps_drift = ViewAxis(1:6, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))), ps_diffusion = ViewAxis(7:12, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, sstart::Int64, n::Int64)
    @ Base ./abstractarray.jl:1137
  [2] copyto!
    @ ./abstractarray.jl:1121 [inlined]
  [3] _typed_stack(::Colon, ::Type{Float32}, ::Type{ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(ps_drift = ViewAxis(1:6, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))), ps_diffusion = ViewAxis(7:12, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}}, A::Vector{ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(ps_drift = ViewAxis(1:6, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))), ps_diffusion = ViewAxis(7:12, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}}, Aax::Tuple{Base.OneTo{Int64}})
    @ Base ./abstractarray.jl:2803
  [4] _typed_stack
    @ ./abstractarray.jl:2793 [inlined]
  [5] _stack
    @ ./abstractarray.jl:2783 [inlined]
  [6] _stack
...

maybe something wrong with the MWE?
my component vector seems okay

@linusheck
Contributor Author

I'm pretty sure I can solve this, it seems like stack on a ComponentVector isn't behaving as expected

@linusheck
Contributor Author

Zygote has a bug here that's easy to work around:

using Lux, Zygote, DifferentialEquations, ComponentArrays, Random, SciMLSensitivity, LinearAlgebra

p = [1.5, 1.0, 3.0, 1.0]
m = 2

function f(u, p, t)
    dx = p[1] * u[1] - p[2] * u[1] * u[2] * t
    dy = -p[3] * u[2] + t * p[4] * u[1] * u[2]
    [dx, dy]
end
Random.seed!(434988934)
y = rand(m)
λ = rand(m)
t = rand()

dW = rand(m)


dgrad = zeros(length(p),m)
dλ = zeros(m,m)
dy = zeros(m)

for i in 1:m
    _dy, back = Zygote.pullback(y, p) do u, p
        f(u, p, t)[i]
    end
    tmp1, tmp2 = back(λ[i])
    dgrad[:, i] .= vec(tmp2)
    dλ[:, i] .= vec(tmp1)
    dy[i] = _dy
end

dy2, back = Zygote.pullback(y, p) do u, p
    f(u, p, t)
end
out = [back(x) for x in eachcol(Diagonal(λ))]

dgrad == stack(last.(out)) # true
dλ == stack(first.(out)) # true

It seems this workaround doesn't work on GPUs: neither eachcol(Diagonal(...)) nor any alternative I've tried works with CUDA.

@linusheck
Contributor Author

linusheck commented Aug 12, 2023

Zygote has a bug here

What is the bug exactly? Can we fix it? This workaround is O(n^2)

@linusheck
Contributor Author

linusheck commented Aug 12, 2023

< deleted because I had an incorrect theory here - see below >

@linusheck
Contributor Author

hmmm BUT we need to perturb this correctly with the noise. I can't figure out how torchsde / diffrax are doing this right now...

@linusheck
Contributor Author

linusheck commented Aug 14, 2023

In torchsde, they never actually define g in the adjoint SDE; they only define g_prod, which is the product of g and the noise. Compare the implementations for EulerHeun. Julia:

integrator.f(ftmp1,uprev,p,t)
integrator.g(gtmp1,uprev,p,t)

if is_diagonal_noise(integrator.sol.prob)
  @.. nrtmp=gtmp1*W.dW
else
  mul!(nrtmp,gtmp1,W.dW)
end

Python:

f, g_prod = self.sde.f_and_g_prod(t0, y0, I_k)

y_prime = y0 + g_prod

g_prod_prime = self.sde.g_prod(t1, y_prime, I_k)

y1 = y0 + dt * f + (g_prod + g_prod_prime) * 0.5

So Python has this more generic g_prod, which is used to avoid computing the large matrix. Julia needs something like this as well; otherwise we can't train any diagonal-noise SDEs efficiently. But the assumption that we compute g seems to be deeply embedded in the library.
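To make the comparison concrete, here is a minimal sketch of what an Euler-Heun step written purely in terms of a noise product could look like in Julia. The names `diag_g_prod` and `euler_heun_step` are hypothetical and are not part of StochasticDiffEq.jl or torchsde; the step simply mirrors the structure of the Python excerpt above and assumes diagonal noise.

```julia
# Hypothetical helper names; this is NOT the StochasticDiffEq.jl or torchsde API.

# Diagonal noise: the product g(u) * dW reduces to an elementwise multiply,
# so the full m×m noise matrix never has to be formed.
diag_g_prod(g) = (u, p, t, dW) -> g(u, p, t) .* dW

# One Euler-Heun step expressed only through f and g_prod,
# following the structure of the Python excerpt.
function euler_heun_step(f, g_prod, u, p, t, dt, dW)
    gp = g_prod(u, p, t, dW)                 # g(t, u) * dW
    u_pred = u .+ gp                         # predictor state
    gp_pred = g_prod(u_pred, p, t + dt, dW)  # g(t + dt, u_pred) * dW
    return u .+ dt .* f(u, p, t) .+ 0.5 .* (gp .+ gp_pred)
end

# Toy usage with linear drift and diffusion (illustrative only).
f(u, p, t) = -p[1] .* u
g(u, p, t) = p[2] .* u
u1 = euler_heun_step(f, diag_g_prod(g), [1.0, 2.0], [0.5, 0.1], 0.0, 0.1, [0.0, 0.0])
```

Because the solver only ever sees `g_prod`, an adjoint SDE could supply its own product rule here without materializing g, which is the design point made above.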

@linusheck
Contributor Author

linusheck commented Sep 9, 2023

update: there's a trivial solution!!

  # how to multiply tmp2 with dW such that dgrad * dW  ==  tmp2 (*) dW?
  # how to multiply tmp1 with dW such that dλ * dW  ==  tmp1 (*) dW?

don't compute dgrad; move the multiplication with dW into the vjp. that's it

my comments about paramnoisemixing are not important, noisemixing has nothing to do with this, it just works. but the solver implementation hurdle is still relevant.

using Lux, Zygote, DifferentialEquations, ComponentArrays, Random, SciMLSensitivity, LinearAlgebra

p = [1.5, 1.0, 3.0, 1.0]
m = 2

function f(u, p, t)
    dx = p[1] * u[1] - p[2] * u[1] * u[2] * t
    dy = -p[3] * u[2] + t * p[4] * u[1] * u[2]
    [dx, dy]
end

Random.seed!(434988934)
y = rand(m)
λ = rand(m)
t = rand()

dgrad = zeros(length(p),m)
dλ = zeros(m,m)
dy = zeros(m)

dy, back = Zygote.pullback(y, p) do u, p
    f(u, p, t)
end
out = [back(x) for x in eachcol(Diagonal(λ))]

dλ1 = stack(first.(out))
dgrad1 = stack(last.(out))

dW = rand(m)

println("Computed with a vjp for each dimension: $(dλ1 * dW)")
println("Computed with a vjp for each dimension: $(dgrad1 * dW)")

dy2, back = Zygote.pullback(y, p) do u, p
    f(u, p, t) .* dW
end

out2 = back(λ)

resλ = first(out2)
resgrad = last(out2)

println("Computed with a single vjp: $resλ")
println("Computed with a single vjp: $resgrad")

println(dλ1 * dW ≈ resλ) # true
println(dgrad1 * dW ≈ resgrad) # true

this is still nontrivial to implement in Julia because of the solver design issue mentioned above
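For reference, the agreement between the two computations follows from the linearity of the vjp. A short derivation, using the names from the snippet above and writing $J_u = \partial f / \partial u$ and $J_p = \partial f / \partial p$ (notation introduced here, not taken from the library); column $i$ of `dλ1` is $\lambda_i \nabla_u f_i$ and column $i$ of `dgrad1` is $\lambda_i \nabla_p f_i$:

```latex
\begin{aligned}
\texttt{dλ1}\, dW
  &= \sum_{i=1}^{m} dW_i \,\lambda_i \,\nabla_u f_i
   = J_u^{\top} (\lambda \odot dW)
   = \left( \frac{\partial (f \odot dW)}{\partial u} \right)^{\!\top} \lambda , \\
\texttt{dgrad1}\, dW
  &= \sum_{i=1}^{m} dW_i \,\lambda_i \,\nabla_p f_i
   = J_p^{\top} (\lambda \odot dW)
   = \left( \frac{\partial (f \odot dW)}{\partial p} \right)^{\!\top} \lambda .
\end{aligned}
```

The right-hand sides are exactly what a single pullback of `f(u, p, t) .* dW` with cotangent `λ` returns, so the m separate pullbacks are not needed.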

@linusheck
Contributor Author

this of course gives us quite a performance boost:

m = 10000
(...)
function f(u, p, t)
    [p[1] * x - p[2] * t + p[3] * p[4] * t * x * x for x in u]
end

(...)

# this scales in O(m^2)
# precompile, then execute
[back(x) for x in eachcol(Diagonal(λ))]
@time out = [back(x) for x in eachcol(Diagonal(λ))]

(...)

# this scales in O(m)
back(λ)
@time out2 = back(λ)

println(dλ1 * dW ≈ resλ) # true
println(dgrad1 * dW ≈ resgrad) # true

=>>

 29.061289 seconds (400.42 M allocations: 78.243 GiB, 21.32% gc time, 0.13% compilation time)
  0.002238 seconds (40.05 k allocations: 8.241 MiB)
true
true
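The same comparison can be reproduced without Zygote by writing the pullback by hand. The sketch below is only an illustration under stated assumptions: the function `g_toy`, the matrix `A`, and the helper `pullback_g` are hypothetical, and the pullback is derived manually for this one function. It also shows an equivalent variant of the trick: since the pullback is linear, `dW` can be folded into the cotangent instead of into the function.

```julia
using LinearAlgebra, Random

Random.seed!(1)
m = 5
A = randn(m, m)
u = randn(m); λ = randn(m); dW = randn(m)

# Hypothetical diffusion function with a non-diagonal Jacobian.
g_toy(u) = tanh.(A * u)

# Hand-written pullback of g_toy at u: v -> J' * v,
# where J = Diagonal(1 .- y.^2) * A and y = g_toy(u).
function pullback_g(u)
    y = g_toy(u)
    back(v) = A' * ((1 .- y .^ 2) .* v)
    return y, back
end

_, back = pullback_g(u)

# Loop version: one pullback call per noise dimension (m calls in total).
cols = [back(collect(c)) for c in eachcol(Diagonal(λ))]
dλ_full = reduce(hcat, cols)
slow = dλ_full * dW

# Single-call version: fold dW into the cotangent.
fast = back(λ .* dW)

println(slow ≈ fast)
```

The loop version calls the pullback m times and builds an m×m matrix that is immediately contracted with `dW`; the single-call version never forms that matrix.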

@linusheck linusheck changed the title Continuous adjoint methods for diagonal-noise SDEs not compatible with GPUs Continuous-adjoint methods for diagonal-noise SDEs scale in the square of number of dimensions Sep 9, 2023