-
-
Notifications
You must be signed in to change notification settings - Fork 161
Closed
Description
I have been trying to implement an autoencoder with an ODE solve in between, which uses several initial values from the input time series.
Say we have a time series y of dimension 2x100 and I'd like to solve the ODE over small time intervals [0,10] using initial conditions y[:,1:10:end]. It works fine on the cpu using hcat([Array(solve(...))]...), however using the gpu gives me the error:
ERROR: LoadError: CuArray only supports bits types
Stacktrace:
[1] error(::String) at ./error.jl:33
[2] CuArrays.CuArray{ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,1,CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing}}},1,P} where P(::UndefInitializer, ::Tuple{Int64}) at /home/manu/.julia/packages/CuArrays/l0gXB/src/array.jl:106
[3] similar(::CuArrays.CuArray{Float32,1,Nothing}, ::Type{ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,1,CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing}}}}, ::Tuple{Int64}) at /home/manu/.julia/packages/CuArrays/l0gXB/src/array.jl:139
[4] similar(::CuArrays.CuArray{Float32,1,Nothing}, ::Type{ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,1,CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing}}}}, ::Int64) at ./abstractarray.jl:628
[5] similar(::ReverseDiff.TrackedArray{Float32,Float32,1,CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing}}, ::Int64) at /home/manu/.julia/packages/ReverseDiff/uy0uk/src/tracked.jl:325
[6] Zygote.Buffer(::ReverseDiff.TrackedArray{Float32,Float32,1,CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing}}, ::Int64) at /home/manu/.julia/packages/Zygote/YeCEW/src/tools/buffer.jl:42
[7] lotka_volterra(::Array{ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}},1}, ::ReverseDiff.TrackedArray{Float32,Float32,1,CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing}}, ::ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}, ::ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}}) at /home/manu/Documents/Work/NormalFormAE/test/NLRAN_github.jl:15
[8] (::ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing})(::Array{ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}},1}, ::Vararg{Any,N} where N) at /home/manu/.julia/packages/DiffEqBase/KnYSY/src/diffeqfunction.jl:248
[9] (::DiffEqSensitivity.var"#67#74"{ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing}})(::ReverseDiff.TrackedArray{Float32,Float32,1,CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing}}, ::ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}, ::ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}) at /home/manu/.julia/packages/DiffEqSensitivity/9c6qf/src/local_sensitivity/adjoint_common.jl:113
[10] ReverseDiff.GradientTape(::Function, ::Tuple{CuArrays.CuArray{Float32,1,Nothing},Array{Float32,1},Array{Float32,1}}, ::ReverseDiff.GradientConfig{Tuple{ReverseDiff.TrackedArray{Float32,Float32,1,CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing}},ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}},ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}}}) at /home/manu/.julia/packages/ReverseDiff/uy0uk/src/api/tape.jl:207
[11] ReverseDiff.GradientTape(::Function, ::Tuple{CuArrays.CuArray{Float32,1,Nothing},Array{Float32,1},Array{Float32,1}}) at /home/manu/.julia/packages/ReverseDiff/uy0uk/src/api/tape.jl:204
[12] adjointdiffcache(::Function, ::InterpolatingAdjoint{0,true,Val{:central},Bool}, ::Bool, ::ODESolution{Float32,2,Array{CuArrays.CuArray{Float32,1,Nothing},1},Nothing,Nothing,Array{Float32,1},Array{Array{CuArrays.CuArray{Float32,1,Nothing},1},1},ODEProblem{CuArrays.CuArray{Float32,1,Nothing},Tuple{Float32,Float32},true,Array{Float32,1},ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{CuArrays.CuArray{Float32,1,Nothing},1},Array{Float32,1},Array{Array{CuArrays.CuArray{Float32,1,Nothing},1},1},OrdinaryDiffEq.Tsit5Cache{CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}}},DiffEqBase.DEStats}, ::Nothing, ::ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing}; quad::Bool, noiseterm::Bool) at /home/manu/.julia/packages/DiffEqSensitivity/9c6qf/src/local_sensitivity/adjoint_common.jl:111
[13] adjointdiffcache at /home/manu/.julia/packages/DiffEqSensitivity/9c6qf/src/local_sensitivity/adjoint_common.jl:26 [inlined]
[14] DiffEqSensitivity.ODEInterpolatingAdjointSensitivityFunction(::Function, ::InterpolatingAdjoint{0,true,Val{:central},Bool}, ::Bool, ::ODESolution{Float32,2,Array{CuArrays.CuArray{Float32,1,Nothing},1},Nothing,Nothing,Array{Float32,1},Array{Array{CuArrays.CuArray{Float32,1,Nothing},1},1},ODEProblem{CuArrays.CuArray{Float32,1,Nothing},Tuple{Float32,Float32},true,Array{Float32,1},ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{CuArrays.CuArray{Float32,1,Nothing},1},Array{Float32,1},Array{Array{CuArrays.CuArray{Float32,1,Nothing},1},1},OrdinaryDiffEq.Tsit5Cache{CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}}},DiffEqBase.DEStats}, ::Nothing, ::Array{Float32,1}, ::Nothing, ::NamedTuple{(:reltol, :abstol),Tuple{Float64,Float64}}) at /home/manu/.julia/packages/DiffEqSensitivity/9c6qf/src/local_sensitivity/interpolating_adjoint.jl:37
[15] ODEAdjointProblem(::ODESolution{Float32,2,Array{CuArrays.CuArray{Float32,1,Nothing},1},Nothing,Nothing,Array{Float32,1},Array{Array{CuArrays.CuArray{Float32,1,Nothing},1},1},ODEProblem{CuArrays.CuArray{Float32,1,Nothing},Tuple{Float32,Float32},true,Array{Float32,1},ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{CuArrays.CuArray{Float32,1,Nothing},1},Array{Float32,1},Array{Array{CuArrays.CuArray{Float32,1,Nothing},1},1},OrdinaryDiffEq.Tsit5Cache{CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}}},DiffEqBase.DEStats}, ::InterpolatingAdjoint{0,true,Val{:central},Bool}, ::DiffEqSensitivity.var"#df#115"{CuArrays.CuArray{Float32,2,Nothing},CuArrays.CuArray{Float32,1,Nothing},Colon}, ::StepRangeLen{Float32,Float64,Float64}, ::Nothing; checkpoints::Array{Float32,1}, callback::CallbackSet{Tuple{},Tuple{}}, reltol::Float64, abstol::Float64, kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at /home/manu/.julia/packages/DiffEqSensitivity/9c6qf/src/local_sensitivity/interpolating_adjoint.jl:115
[16] _adjoint_sensitivities(::ODESolution{Float32,2,Array{CuArrays.CuArray{Float32,1,Nothing},1},Nothing,Nothing,Array{Float32,1},Array{Array{CuArrays.CuArray{Float32,1,Nothing},1},1},ODEProblem{CuArrays.CuArray{Float32,1,Nothing},Tuple{Float32,Float32},true,Array{Float32,1},ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{CuArrays.CuArray{Float32,1,Nothing},1},Array{Float32,1},Array{Array{CuArrays.CuArray{Float32,1,Nothing},1},1},OrdinaryDiffEq.Tsit5Cache{CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}}},DiffEqBase.DEStats}, ::InterpolatingAdjoint{0,true,Val{:central},Bool}, ::Tsit5, ::DiffEqSensitivity.var"#df#115"{CuArrays.CuArray{Float32,2,Nothing},CuArrays.CuArray{Float32,1,Nothing},Colon}, ::StepRangeLen{Float32,Float64,Float64}, ::Nothing; abstol::Float64, reltol::Float64, checkpoints::Array{Float32,1}, kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at /home/manu/.julia/packages/DiffEqSensitivity/9c6qf/src/local_sensitivity/sensitivity_interface.jl:17
[17] adjoint_sensitivities(::ODESolution{Float32,2,Array{CuArrays.CuArray{Float32,1,Nothing},1},Nothing,Nothing,Array{Float32,1},Array{Array{CuArrays.CuArray{Float32,1,Nothing},1},1},ODEProblem{CuArrays.CuArray{Float32,1,Nothing},Tuple{Float32,Float32},true,Array{Float32,1},ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{CuArrays.CuArray{Float32,1,Nothing},1},Array{Float32,1},Array{Array{CuArrays.CuArray{Float32,1,Nothing},1},1},OrdinaryDiffEq.Tsit5Cache{CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing},CuArrays.CuArray{Float32,1,Nothing},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}}},DiffEqBase.DEStats}, ::Tsit5, ::Vararg{Any,N} where N; sensealg::InterpolatingAdjoint{0,true,Val{:central},Bool}, kwargs::Base.Iterators.Pairs{Symbol,Float64,Tuple{Symbol},NamedTuple{(:reltol,),Tuple{Float64}}}) at /home/manu/.julia/packages/DiffEqSensitivity/9c6qf/src/local_sensitivity/sensitivity_interface.jl:6
[18] (::DiffEqSensitivity.var"#adjoint_sensitivity_backpass#114"{Tsit5,InterpolatingAdjoint{0,true,Val{:central},Bool},CuArrays.CuArray{Float32,1,Nothing},Array{Float32,1},Tuple{},Colon})(::CuArrays.CuArray{Float32,2,Nothing}) at /home/manu/.julia/packages/DiffEqSensitivity/9c6qf/src/local_sensitivity/concrete_solve.jl:107
[19] #512#back at /home/manu/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:55 [inlined]
[20] #174 at /home/manu/.julia/packages/Zygote/YeCEW/src/lib/lib.jl:182 [inlined]
[21] (::Zygote.var"#347#back#176"{Zygote.var"#174#175"{DiffEqBase.var"#512#back#457"{DiffEqSensitivity.var"#adjoint_sensitivity_backpass#114"{Tsit5,InterpolatingAdjoint{0,true,Val{:central},Bool},CuArrays.CuArray{Float32,1,Nothing},Array{Float32,1},Tuple{},Colon}},Tuple{NTuple{6,Nothing},Tuple{Nothing}}}})(::CuArrays.CuArray{Float32,2,Nothing}) at /home/manu/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
[22] #solve#443 at /home/manu/.julia/packages/DiffEqBase/KnYSY/src/solve.jl:69 [inlined]
[23] (::typeof(∂(#solve#443)))(::CuArrays.CuArray{Float32,2,Nothing}) at /home/manu/.julia/packages/Zygote/YeCEW/src/compiler/interface2.jl:0
[24] (::Zygote.var"#174#175"{typeof(∂(#solve#443)),Tuple{NTuple{6,Nothing},Tuple{Nothing}}})(::CuArrays.CuArray{Float32,2,Nothing}) at /home/manu/.julia/packages/Zygote/YeCEW/src/lib/lib.jl:182
[25] (::Zygote.var"#347#back#176"{Zygote.var"#174#175"{typeof(∂(#solve#443)),Tuple{NTuple{6,Nothing},Tuple{Nothing}}}})(::CuArrays.CuArray{Float32,2,Nothing}) at /home/manu/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
[26] (::typeof(∂(solve##kw)))(::CuArrays.CuArray{Float32,2,Nothing}) at /home/manu/.julia/packages/Zygote/YeCEW/src/compiler/interface2.jl:0
[27] predict_ODE_solve at /home/manu/Documents/Work/NormalFormAE/test/NLRAN_github.jl:47 [inlined]
[28] (::typeof(∂(predict_ODE_solve)))(::CuArrays.CuArray{Float32,2,Nothing}) at /home/manu/.julia/packages/Zygote/YeCEW/src/compiler/interface2.jl:0
[29] #41 at ./none:0 [inlined]
[30] (::typeof(∂(λ)))(::CuArrays.CuArray{Float32,2,Nothing}) at /home/manu/.julia/packages/Zygote/YeCEW/src/compiler/interface2.jl:0
[31] #1187 at /home/manu/.julia/packages/Zygote/YeCEW/src/lib/array.jl:172 [inlined]
[32] #3 at ./generator.jl:36 [inlined]
[33] iterate at ./generator.jl:47 [inlined]
[34] collect(::Base.Generator{Base.Iterators.Zip{Tuple{Array{typeof(∂(λ)),1},NTuple{10,CuArrays.CuArray{Float32,2,Nothing}}}},Base.var"#3#4"{Zygote.var"#1187#1191"}}) at ./array.jl:665
[35] map at ./abstractarray.jl:2154 [inlined]
[36] (::Zygote.var"#1186#1190"{Array{typeof(∂(λ)),1}})(::NTuple{10,CuArrays.CuArray{Float32,2,Nothing}}) at /home/manu/.julia/packages/Zygote/YeCEW/src/lib/array.jl:172
[37] (::Zygote.var"#1194#1195"{Zygote.var"#1186#1190"{Array{typeof(∂(λ)),1}}})(::NTuple{10,CuArrays.CuArray{Float32,2,Nothing}}) at /home/manu/.julia/packages/Zygote/YeCEW/src/lib/array.jl:187
[38] loss_func at /home/manu/Documents/Work/NormalFormAE/test/NLRAN_github.jl:54 [inlined]
[39] (::typeof(∂(loss_func)))(::Float64) at /home/manu/.julia/packages/Zygote/YeCEW/src/compiler/interface2.jl:0
[40] #16 at /home/manu/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:85 [inlined]
[41] (::typeof(∂(λ)))(::Float64) at /home/manu/.julia/packages/Zygote/YeCEW/src/compiler/interface2.jl:0
[42] (::Zygote.var"#49#50"{Params,Zygote.Context,typeof(∂(λ))})(::Float64) at /home/manu/.julia/packages/Zygote/YeCEW/src/compiler/interface.jl:179
[43] gradient(::Function, ::Params) at /home/manu/.julia/packages/Zygote/YeCEW/src/compiler/interface.jl:55
[44] macro expansion at /home/manu/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:84 [inlined]
[45] macro expansion at /home/manu/.julia/packages/Juno/tLMZd/src/progress.jl:134 [inlined]
[46] train!(::typeof(loss_func), ::Params, ::Array{CuArrays.CuArray{Float32,2,Nothing},1}, ::ADAM; cb::Flux.Optimise.var"#18#26") at /home/manu/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:81
[47] train!(::Function, ::Params, ::Array{CuArrays.CuArray{Float32,2,Nothing},1}, ::ADAM) at /home/manu/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:79
[48] top-level scope at /home/manu/Documents/Work/NormalFormAE/test/NLRAN_github.jl:69
[49] include(::String) at ./client.jl:439
in expression starting at /home/manu/Documents/Work/NormalFormAE/test/NLRAN_github.jl:66
here is the code:
using Pkg
Pkg.activate(".")
using DiffEqFlux, DiffEqSensitivity, Flux, OrdinaryDiffEq, Zygote, Test, DifferentialEquations
# Training hyperparameters
nBatchsize = 10
nEpochs = 10
tspan = 5.0
tsize = 100
nbatches = div(tsize,nBatchsize)
# ODE solve
function lotka_volterra(du,u,p,t)
dx = Zygote.Buffer(u,size(u)[1])
dx[1] = u[1]*(p[1]-p[2]*u[2])
dx[2] = u[2]*(p[3]*u[1]-p[4])
du .= copy(dx)
nothing
end
# Define parameters and initial conditions for data
p = Float32[2.2, 1.0, 2.0, 0.4]
u0 = Float32[0.01, 0.01]
t = range(0.0,tspan,length=tsize)
# Define ODE problem and generate data
prob = ODEProblem(lotka_volterra,u0,(0.0,tspan),p)
yy = Array(solve(prob,saveat=t))
y_original = Array(solve(prob,saveat=t))
yy = yy .+ yy*(0.01.*rand(size(yy)[2],size(yy)[2])) # Creates noisy, translated data
data = Float32.(yy) |> gpu
# Define autoencoder networks
NN_encode = Chain(Dense(2,10),Dense(10,10),Dense(10,2)) |> gpu
NN_decode = Chain(Dense(2,10),Dense(10,10),Dense(10,2)) |> gpu
# Define new ODE problem for "batch" evolution
t_batch = range(0.0f0,Float32(tspan/nbatches),length = nBatchsize)
prob2 = ODEProblem(lotka_volterra,u0,(0.0f0,Float32(tspan/nbatches)),p)
# ODE solve to be used for training
function predict_ODE_solve(x)
return Array(solve(prob2,Tsit5(),u0=x,saveat=t_batch,reltol=1e-4))
end
function loss_func(data_)
enc_ = NN_encode(data_)
# Solve ODE using initial values from multiple points in enc_.
# Note: reduce(hcat,[..]) gives a mutating arrays error
enc_ODE_solve = hcat([predict_ODE_solve(enc_[:,(i-1)*nBatchsize+1]) for i in 1:nbatches]...) |> gpu
dec_1 = NN_decode(enc_ODE_solve)
dec_2 = NN_decode(enc_)
loss = Flux.mse(data_,dec_1) + Flux.mse(data_,dec_2) + 0.001*Flux.mse(enc_,enc_ODE_solve)
args_["loss"] = loss
return loss
end
opt = ADAM(0.001)
loss_func(data) # This works
for ep in 1:nEpochs
global args_
@info "Epoch $ep"
Flux.train!(loss_func, Flux.params(NN_encode,NN_decode), [(data)], opt)
loss_ = args_["loss"]
println("loss: $(loss_)")
endand here is the current status of packages
[c7e460c6] ArgParse v1.1.0
[fbb218c0] BSON v0.2.6
[6e4b80f9] BenchmarkTools v0.5.0
[3895d2a7] CUDAapi v4.0.0
[c5f51814] CUDAdrv v6.3.0
[be33ccc6] CUDAnative v3.1.0
[3a865a2d] CuArrays v2.2.0
[31a5f54b] Debugger v0.6.4
[aae7a2af] DiffEqFlux v1.12.0
[41bf760c] DiffEqSensitivity v6.19.0
[0c46a032] DifferentialEquations v6.14.0
[31c24e10] Distributions v0.23.3
[5789e2e9] FileIO v1.3.0
[587475ba] Flux v0.10.4
[0c68f7d7] GPUArrays v3.4.1
[033835bb] JLD2 v0.1.13
[429524aa] Optim v0.21.0
[1dea7af3] OrdinaryDiffEq v5.39.1
[91a5bcdd] Plots v1.3.6
[8d666b04] PolyChaos v0.2.1
[ee283ea6] Rebugger v0.3.3
[295af30f] Revise v2.7.1
[9f7883ad] Tracker v0.2.6
[e88e6eb3] Zygote v0.4.20
[9a3f8284] Random Is there a way to push this to the GPU efficiently? Any help would be appreciated. Thanks for the fantastic work on this package! :)
Metadata
Metadata
Assignees
Labels
No labels