Skip to content

Multiple ODE solves from a single time series on GPU #279

@mkalia94

Description

@mkalia94

I have been trying to implement an autoencoder with an ODE solve in between, which uses several initial values from the input time series.

Say we have a time series y of dimension 2x100 and I'd like to solve the ODE over small time intervals [0,10] using initial conditions y[:,1:10:end]. It works fine on the cpu using hcat([Array(solve(...))]...), however using the gpu gives me the error:

ERROR: LoadError: CuArray only supports bits types
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] CuArrays.CuArray{ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,1,CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing}}},1,P} where P(::UndefInitializer, ::Tuple{Int64}) at /home/manu/.julia/packages/CuArrays/l0gXB/src/array.jl:106
 [3] similar(::CuArrays.CuArray{Float32,1,Nothing}, ::Type{ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,1,CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing}}}}, ::Tuple{Int64}) at /home/manu/.julia/packages/CuArrays/l0gXB/src/array.jl:139
 [4] similar(::CuArrays.CuArray{Float32,1,Nothing}, ::Type{ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,1,CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing}}}}, ::Int64) at ./abstractarray.jl:628
 [5] similar(::ReverseDiff.TrackedArray{Float32,Float32,1,CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing}}, ::Int64) at /home/manu/.julia/packages/ReverseDiff/uy0uk/src/tracked.jl:325
 [6] Zygote.Buffer(::ReverseDiff.TrackedArray{Float32,Float32,1,CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing}}, ::Int64) at /home/manu/.julia/packages/Zygote/YeCEW/src/tools/buffer.jl:42
 [7] lotka_volterra(::Array{ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}},1}, ::ReverseDiff.TrackedArray{Float32,Float32,1,CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing}}, ::ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}, ::ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}}) at /home/manu/Documents/Work/NormalFormAE/test/NLRAN_github.jl:15
[8] (::ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing})(::Array{ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}},1}, ::Vararg{Any,N} where N) at /home/manu/.julia/packages/DiffEqBase/KnYSY/src/diffeqfunction.jl:248
 [9] (::DiffEqSensitivity.var"#67#74"{ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing}})(::ReverseDiff.TrackedArray{Float32,Float32,1,CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing}}, ::ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}, ::ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}) at /home/manu/.julia/packages/DiffEqSensitivity/9c6qf/src/local_sensitivity/adjoint_common.jl:113
 [10] ReverseDiff.GradientTape(::Function, ::Tuple{CuArrays.CuArray{Float32,1,Nothing},Array{Float32,1},Array{Float32,1}}, ::ReverseDiff.GradientConfig{Tuple{ReverseDiff.TrackedArray{Float32,Float32,1,CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing}},ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}},ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}}}) at /home/manu/.julia/packages/ReverseDiff/uy0uk/src/api/tape.jl:207
 [11] ReverseDiff.GradientTape(::Function, ::Tuple{CuArrays.CuArray{Float32,1,Nothing},Array{Float32,1},Array{Float32,1}}) at /home/manu/.julia/packages/ReverseDiff/uy0uk/src/api/tape.jl:204
 [12] adjointdiffcache(::Function, ::InterpolatingAdjoint{0,true,Val{:central},Bool}, ::Bool, ::ODESolution{Float32,2,Array{CuArrays.CuArray{Float32,1,Nothing},1},Nothing,Nothing,Array{Float32,1},Array{Array{CuArrays.CuArray{Float32,1,Nothing},1},1},ODEProblem{CuArrays.CuArray{Float32,1,Nothing},Tuple{Float32,Float32},true,Array{Float32,1},ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{CuArrays.CuArray{Float32,1,Nothing},1},Array{Float32,1},Array{Array{CuArrays.CuArray{Float32,1,Nothing},1},1},OrdinaryDiffEq.Tsit5Cache{CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}}},DiffEqBase.DEStats}, ::Nothing, ::ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing}; quad::Bool, noiseterm::Bool) at /home/manu/.julia/packages/DiffEqSensitivity/9c6qf/src/local_sensitivity/adjoint_common.jl:111
 [13] adjointdiffcache at /home/manu/.julia/packages/DiffEqSensitivity/9c6qf/src/local_sensitivity/adjoint_common.jl:26 [inlined]
 [14] DiffEqSensitivity.ODEInterpolatingAdjointSensitivityFunction(::Function, ::InterpolatingAdjoint{0,true,Val{:central},Bool}, ::Bool, ::ODESolution{Float32,2,Array{CuArrays.CuArray{Float32,1,Nothing},1},Nothing,Nothing,Array{Float32,1},Array{Array{CuArrays.CuArray{Float32,1,Nothing},1},1},ODEProblem{CuArrays.CuArray{Float32,1,Nothing},Tuple{Float32,Float32},true,Array{Float32,1},ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{CuArrays.CuArray{Float32,1,Nothing},1},Array{Float32,1},Array{Array{CuArrays.CuArray{Float32,1,Nothing},1},1},OrdinaryDiffEq.Tsit5Cache{CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}}},DiffEqBase.DEStats}, ::Nothing, ::Array{Float32,1}, ::Nothing, ::NamedTuple{(:reltol, :abstol),Tuple{Float64,Float64}}) at /home/manu/.julia/packages/DiffEqSensitivity/9c6qf/src/local_sensitivity/interpolating_adjoint.jl:37
 [15] ODEAdjointProblem(::ODESolution{Float32,2,Array{CuArrays.CuArray{Float32,1,Nothing},1},Nothing,Nothing,Array{Float32,1},Array{Array{CuArrays.CuArray{Float32,1,Nothing},1},1},ODEProblem{CuArrays.CuArray{Float32,1,Nothing},Tuple{Float32,Float32},true,Array{Float32,1},ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{CuArrays.CuArray{Float32,1,Nothing},1},Array{Float32,1},Array{Array{CuArrays.CuArray{Float32,1,Nothing},1},1},OrdinaryDiffEq.Tsit5Cache{CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}}},DiffEqBase.DEStats}, ::InterpolatingAdjoint{0,true,Val{:central},Bool}, ::DiffEqSensitivity.var"#df#115"{CuArrays.CuArray{Float32,2,Nothing},CuArrays.CuArray{Float32,1,Nothing},Colon}, ::StepRangeLen{Float32,Float64,Float64}, ::Nothing; checkpoints::Array{Float32,1}, callback::CallbackSet{Tuple{},Tuple{}}, reltol::Float64, abstol::Float64, kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at /home/manu/.julia/packages/DiffEqSensitivity/9c6qf/src/local_sensitivity/interpolating_adjoint.jl:115
 [16] _adjoint_sensitivities(::ODESolution{Float32,2,Array{CuArrays.CuArray{Float32,1,Nothing},1},Nothing,Nothing,Array{Float32,1},Array{Array{CuArrays.CuArray{Float32,1,Nothing},1},1},ODEProblem{CuArrays.CuArray{Float32,1,Nothing},Tuple{Float32,Float32},true,Array{Float32,1},ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{CuArrays.CuArray{Float32,1,Nothing},1},Array{Float32,1},Array{Array{CuArrays.CuArray{Float32,1,Nothing},1},1},OrdinaryDiffEq.Tsit5Cache{CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}}},DiffEqBase.DEStats}, ::InterpolatingAdjoint{0,true,Val{:central},Bool}, ::Tsit5, ::DiffEqSensitivity.var"#df#115"{CuArrays.CuArray{Float32,2,Nothing},CuArrays.CuArray{Float32,1,Nothing},Colon}, ::StepRangeLen{Float32,Float64,Float64}, ::Nothing; abstol::Float64, reltol::Float64, checkpoints::Array{Float32,1}, kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at /home/manu/.julia/packages/DiffEqSensitivity/9c6qf/src/local_sensitivity/sensitivity_interface.jl:17
 [17] adjoint_sensitivities(::ODESolution{Float32,2,Array{CuArrays.CuArray{Float32,1,Nothing},1},Nothing,Nothing,Array{Float32,1},Array{Array{CuArrays.CuArray{Float32,1,Nothing},1},1},ODEProblem{CuArrays.CuArray{Float32,1,Nothing},Tuple{Float32,Float32},true,Array{Float32,1},ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{CuArrays.CuArray{Float32,1,Nothing},1},Array{Float32,1},Array{Array{CuArrays.CuArray{Float32,1,Nothing},1},1},OrdinaryDiffEq.Tsit5Cache{CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}}},DiffEqBase.DEStats}, ::Tsit5, ::Vararg{Any,N} where N; sensealg::InterpolatingAdjoint{0,true,Val{:central},Bool}, kwargs::Base.Iterators.Pairs{Symbol,Float64,Tuple{Symbol},NamedTuple{(:reltol,),Tuple{Float64}}}) at /home/manu/.julia/packages/DiffEqSensitivity/9c6qf/src/local_sensitivity/sensitivity_interface.jl:6
 [18] (::DiffEqSensitivity.var"#adjoint_sensitivity_backpass#114"{Tsit5,InterpolatingAdjoint{0,true,Val{:central},Bool},CuArrays.CuArray{Float32,1,Nothing},Array{Float32,1},Tuple{},Colon})(::CuArrays.CuArray{Float32,2,Nothing}) at /home/manu/.julia/packages/DiffEqSensitivity/9c6qf/src/local_sensitivity/concrete_solve.jl:107
 [19] #512#back at /home/manu/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:55 [inlined]
 [20] #174 at /home/manu/.julia/packages/Zygote/YeCEW/src/lib/lib.jl:182 [inlined]
 [21] (::Zygote.var"#347#back#176"{Zygote.var"#174#175"{DiffEqBase.var"#512#back#457"{DiffEqSensitivity.var"#adjoint_sensitivity_backpass#114"{Tsit5,InterpolatingAdjoint{0,true,Val{:central},Bool},CuArrays.CuArray{Float32,1,Nothing},Array{Float32,1},Tuple{},Colon}},Tuple{NTuple{6,Nothing},Tuple{Nothing}}}})(::CuArrays.CuArray{Float32,2,Nothing}) at /home/manu/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
 [22] #solve#443 at /home/manu/.julia/packages/DiffEqBase/KnYSY/src/solve.jl:69 [inlined]
 [23] (::typeof((#solve#443)))(::CuArrays.CuArray{Float32,2,Nothing}) at /home/manu/.julia/packages/Zygote/YeCEW/src/compiler/interface2.jl:0
 [24] (::Zygote.var"#174#175"{typeof((#solve#443)),Tuple{NTuple{6,Nothing},Tuple{Nothing}}})(::CuArrays.CuArray{Float32,2,Nothing}) at /home/manu/.julia/packages/Zygote/YeCEW/src/lib/lib.jl:182
 [25] (::Zygote.var"#347#back#176"{Zygote.var"#174#175"{typeof((#solve#443)),Tuple{NTuple{6,Nothing},Tuple{Nothing}}}})(::CuArrays.CuArray{Float32,2,Nothing}) at /home/manu/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
 [26] (::typeof((solve##kw)))(::CuArrays.CuArray{Float32,2,Nothing}) at /home/manu/.julia/packages/Zygote/YeCEW/src/compiler/interface2.jl:0
 [27] predict_ODE_solve at /home/manu/Documents/Work/NormalFormAE/test/NLRAN_github.jl:47 [inlined]
 [28] (::typeof((predict_ODE_solve)))(::CuArrays.CuArray{Float32,2,Nothing}) at /home/manu/.julia/packages/Zygote/YeCEW/src/compiler/interface2.jl:0
 [29] #41 at ./none:0 [inlined]
 [30] (::typeof((λ)))(::CuArrays.CuArray{Float32,2,Nothing}) at /home/manu/.julia/packages/Zygote/YeCEW/src/compiler/interface2.jl:0
 [31] #1187 at /home/manu/.julia/packages/Zygote/YeCEW/src/lib/array.jl:172 [inlined]
 [32] #3 at ./generator.jl:36 [inlined]
 [33] iterate at ./generator.jl:47 [inlined]
 [34] collect(::Base.Generator{Base.Iterators.Zip{Tuple{Array{typeof((λ)),1},NTuple{10,CuArrays.CuArray{Float32,2,Nothing}}}},Base.var"#3#4"{Zygote.var"#1187#1191"}}) at ./array.jl:665
 [35] map at ./abstractarray.jl:2154 [inlined]
 [36] (::Zygote.var"#1186#1190"{Array{typeof((λ)),1}})(::NTuple{10,CuArrays.CuArray{Float32,2,Nothing}}) at /home/manu/.julia/packages/Zygote/YeCEW/src/lib/array.jl:172
 [37] (::Zygote.var"#1194#1195"{Zygote.var"#1186#1190"{Array{typeof((λ)),1}}})(::NTuple{10,CuArrays.CuArray{Float32,2,Nothing}}) at /home/manu/.julia/packages/Zygote/YeCEW/src/lib/array.jl:187
 [38] loss_func at /home/manu/Documents/Work/NormalFormAE/test/NLRAN_github.jl:54 [inlined]
 [39] (::typeof((loss_func)))(::Float64) at /home/manu/.julia/packages/Zygote/YeCEW/src/compiler/interface2.jl:0
 [40] #16 at /home/manu/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:85 [inlined]
 [41] (::typeof((λ)))(::Float64) at /home/manu/.julia/packages/Zygote/YeCEW/src/compiler/interface2.jl:0
 [42] (::Zygote.var"#49#50"{Params,Zygote.Context,typeof((λ))})(::Float64) at /home/manu/.julia/packages/Zygote/YeCEW/src/compiler/interface.jl:179
 [43] gradient(::Function, ::Params) at /home/manu/.julia/packages/Zygote/YeCEW/src/compiler/interface.jl:55
 [44] macro expansion at /home/manu/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:84 [inlined]
 [45] macro expansion at /home/manu/.julia/packages/Juno/tLMZd/src/progress.jl:134 [inlined]
 [46] train!(::typeof(loss_func), ::Params, ::Array{CuArrays.CuArray{Float32,2,Nothing},1}, ::ADAM; cb::Flux.Optimise.var"#18#26") at /home/manu/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:81
 [47] train!(::Function, ::Params, ::Array{CuArrays.CuArray{Float32,2,Nothing},1}, ::ADAM) at /home/manu/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:79
 [48] top-level scope at /home/manu/Documents/Work/NormalFormAE/test/NLRAN_github.jl:69
 [49] include(::String) at ./client.jl:439
in expression starting at /home/manu/Documents/Work/NormalFormAE/test/NLRAN_github.jl:66

here is the code:

using Pkg
Pkg.activate(".")
using DiffEqFlux, DiffEqSensitivity, Flux, OrdinaryDiffEq, Zygote, Test, DifferentialEquations

# Training hyperparameters
nBatchsize = 10
nEpochs = 10
tspan = 5.0
tsize = 100
nbatches = div(tsize,nBatchsize)

# ODE solve
function lotka_volterra(du,u,p,t)
    dx = Zygote.Buffer(u,size(u)[1])
    dx[1] = u[1]*(p[1]-p[2]*u[2])
    dx[2] = u[2]*(p[3]*u[1]-p[4])
    du .= copy(dx)
    nothing
end

# Define parameters and initial conditions for data
p = Float32[2.2, 1.0, 2.0, 0.4] 
u0 = Float32[0.01, 0.01]

t = range(0.0,tspan,length=tsize)

# Define ODE problem and generate data
prob = ODEProblem(lotka_volterra,u0,(0.0,tspan),p)
yy = Array(solve(prob,saveat=t))
y_original = Array(solve(prob,saveat=t))
yy = yy .+ yy*(0.01.*rand(size(yy)[2],size(yy)[2])) # Creates noisy, translated data

data = Float32.(yy) |> gpu

# Define autoencoder networks
NN_encode = Chain(Dense(2,10),Dense(10,10),Dense(10,2)) |> gpu
NN_decode = Chain(Dense(2,10),Dense(10,10),Dense(10,2)) |> gpu

# Define new ODE problem for "batch" evolution
t_batch = range(0.0f0,Float32(tspan/nbatches),length = nBatchsize)
prob2 = ODEProblem(lotka_volterra,u0,(0.0f0,Float32(tspan/nbatches)),p)

# ODE solve to be used for training
function predict_ODE_solve(x)
    return Array(solve(prob2,Tsit5(),u0=x,saveat=t_batch,reltol=1e-4)) 
end

function loss_func(data_)
    enc_ = NN_encode(data_)
    # Solve ODE using initial values from multiple points in enc_.
    # Note: reduce(hcat,[..]) gives a mutating arrays error
    enc_ODE_solve = hcat([predict_ODE_solve(enc_[:,(i-1)*nBatchsize+1]) for i in 1:nbatches]...) |> gpu
    dec_1 = NN_decode(enc_ODE_solve)
    dec_2 = NN_decode(enc_)
    loss = Flux.mse(data_,dec_1) + Flux.mse(data_,dec_2) + 0.001*Flux.mse(enc_,enc_ODE_solve)
    args_["loss"] = loss
    return loss
end

opt = ADAM(0.001)

loss_func(data) # This works

for ep in 1:nEpochs
    global args_
    @info "Epoch $ep"
    Flux.train!(loss_func, Flux.params(NN_encode,NN_decode), [(data)], opt)
    loss_ = args_["loss"]
    println("loss: $(loss_)")
end

and here is the current status of packages

[c7e460c6] ArgParse v1.1.0
  [fbb218c0] BSON v0.2.6
  [6e4b80f9] BenchmarkTools v0.5.0
  [3895d2a7] CUDAapi v4.0.0
  [c5f51814] CUDAdrv v6.3.0
  [be33ccc6] CUDAnative v3.1.0
  [3a865a2d] CuArrays v2.2.0
  [31a5f54b] Debugger v0.6.4
  [aae7a2af] DiffEqFlux v1.12.0
  [41bf760c] DiffEqSensitivity v6.19.0
  [0c46a032] DifferentialEquations v6.14.0
  [31c24e10] Distributions v0.23.3
  [5789e2e9] FileIO v1.3.0
  [587475ba] Flux v0.10.4
  [0c68f7d7] GPUArrays v3.4.1
  [033835bb] JLD2 v0.1.13
  [429524aa] Optim v0.21.0
  [1dea7af3] OrdinaryDiffEq v5.39.1
  [91a5bcdd] Plots v1.3.6
  [8d666b04] PolyChaos v0.2.1
  [ee283ea6] Rebugger v0.3.3
  [295af30f] Revise v2.7.1
  [9f7883ad] Tracker v0.2.6
  [e88e6eb3] Zygote v0.4.20
  [9a3f8284] Random 

Is there a way to push this to the GPU efficiently? Any help would be appreciated. Thanks for the fantastic work on this package! :)

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions