-
-
Notifications
You must be signed in to change notification settings - Fork 161
Open
Description
Hi,
Is there a way to combine the ensembleproblem command to solve SDEs in combination with Flux Neural Nets. Right now, when I try to combine them, I get the error when I try to train the Neural Net. "Info: Epoch 1
└ @ Main /home/ec2-user/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121
Mutating arrays is not supported
Stacktrace:
[1] error(::String) at ./error.jl:33
[2] (::Zygote.var"#459#460")(::Nothing) at /home/ec2-user/.julia/packages/Zygote/1GXzF/src/lib/array.jl:67
[3] (::Zygote.var"#1009#back#461"{Zygote.var"#459#460"})(::Nothing) at /home/ec2-user/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
[4] materialize! at ./broadcast.jl:826 [inlined]"
Here is the code:
### Packages
import Pkg; Pkg.build("DifferentialEquations")
import Pkg; Pkg.add("DifferentialEquations")
using DifferentialEquations
using Flux
using DifferentialEquations.EnsembleAnalysis
using IterTools: ncycle
using Flux: @epochs
using DiffEqSensitivity
# Input parameters
T = 1.0f0
tspan = (0.0f0,T)
m=Float32[1.0]
v=Float32[1.0]
true_y0 = exp(T)*m[1] + (exp(T)-1)*( sqrt( 2*v[1]/(1-exp(-2*T)) )+1 )
true_Z = sqrt(2*v[1]/(1-exp(-2*T)))
u0 = Float32[true_y0 + 0.25] # initial guess
Z = Float32[true_Z + 0.25]
ps = Flux.params(Z,u0) # or Flux.params(p1,u0) if you want to also optimize over all u0 parameters
function drift_nn(u,p,t)
[-(u[1] + p[1] + 1.0f0)]
end
function stoch_nn(u,p,t)
[p[1]]
end
probSDE_nn = SDEProblem{false}(drift_nn,stoch_nn,u0,tspan)
ensemble_prob = EnsembleProblem(probSDE_nn)
function yT_nn(p)
sim = solve(ensemble_prob,EM(),trajectories=1000,p = p[1][1],u0=p[2],dt = 0.005)
end
function loss_nn(m,v)
m_val,v_val = timepoint_meanvar(yT_nn(ps),1.0)
loss = (m_val[1]-m[1])^2 + (v_val[1]-v[1])^2
end
yT_nn(ps)
cb2 = function() #callback function to observe training
display(loss_nn(m,v))
end
cb2()
Z_pre = deepcopy(Z)
u0_pre = deepcopy(u0)
# train model on the same data num_cycle times
num_cycles = 1
data = ncycle([(m, v)], num_cycles)
opt = ADAM(0.01)
@epochs 5 Flux.train!(loss_nn, ps , data, opt, cb=cb2);
Thank you!
Metadata
Metadata
Assignees
Labels
No labels