
EnsembleProblem with timepoint_meanvar errors due to internal mutation #446

@charles-r-smith


Hi,

Is there a way to use EnsembleProblem to solve SDEs in combination with Flux neural nets? Right now, when I try to combine them, I get the following error as soon as I start training the neural net:

Info: Epoch 1
└ @ Main /home/ec2-user/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121
Mutating arrays is not supported

Stacktrace:
[1] error(::String) at ./error.jl:33
[2] (::Zygote.var"#459#460")(::Nothing) at /home/ec2-user/.julia/packages/Zygote/1GXzF/src/lib/array.jl:67
[3] (::Zygote.var"#1009#back#461"{Zygote.var"#459#460"})(::Nothing) at /home/ec2-user/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
[4] materialize! at ./broadcast.jl:826 [inlined]
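
The message comes from Zygote, which Flux uses to compute gradients. Frame [4] (materialize!) is an in-place broadcast (`.=`), and Zygote cannot differentiate code that writes into an existing array. The sketch below uses only the Statistics standard library and is not run under Zygote. It contrasts two ways of collecting per-trajectory values before taking their mean and variance: one mutates a preallocated array, and the other builds a fresh array with a comprehension. Both give the same numbers, but only the second is written in the style Zygote can usually differentiate. The data and the function names are hypothetical.

```julia
using Statistics

# Stand-in data (hypothetical): one 1-element state vector per trajectory,
# like the final states of a 1-D ensemble solve.
finals = [randn(Float32, 1) for _ in 1:1000]

# Mutating style: preallocate, then write into the array element by element.
# Zygote raises "Mutating arrays is not supported" on writes like this
# (and on in-place broadcasts such as `out .= ...`).
function meanvar_mutating(xs)
    out = similar(xs, eltype(first(xs)))
    for i in eachindex(xs)
        out[i] = xs[i][1]
    end
    return mean(out), var(out)
end

# Non-mutating style: build a fresh array with a comprehension, then reduce.
function meanvar_nonmutating(xs)
    vals = [x[1] for x in xs]
    return mean(vals), var(vals)
end

println(meanvar_mutating(finals))
println(meanvar_nonmutating(finals))
```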

Here is the code:


### Packages
import Pkg
Pkg.add(["DifferentialEquations", "Flux", "IterTools", "DiffEqSensitivity"])

using DifferentialEquations
using Flux
using DifferentialEquations.EnsembleAnalysis
using IterTools: ncycle
using Flux: @epochs
using DiffEqSensitivity

# Input parameters
T = 1.0f0
tspan = (0.0f0,T)
m=Float32[1.0]
v=Float32[1.0]
true_y0 = exp(T)*m[1] + (exp(T)-1)*( sqrt(  2*v[1]/(1-exp(-2*T)) )+1 )
true_Z = sqrt(2*v[1]/(1-exp(-2*T)))
u0 = Float32[true_y0 + 0.25] # initial guess
Z = Float32[true_Z + 0.25]
ps = Flux.params(Z, u0) # trainable parameters: the diffusion coefficient Z and the initial condition u0

function drift_nn(u,p,t)
    [-(u[1] + p[1] + 1.0f0)]
end

function stoch_nn(u,p,t)
    [p[1]]
end

probSDE_nn = SDEProblem{false}(drift_nn, stoch_nn, u0, tspan) # out-of-place SDE
ensemble_prob = EnsembleProblem(probSDE_nn)

# Solve 1000 Euler-Maruyama trajectories with p = Z[1] (ps[1]) and u0 = ps[2]
function yT_nn(p)
    solve(ensemble_prob, EM(), trajectories=1000, p=p[1][1], u0=p[2], dt=0.005)
end

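function loss_nn(m, v)
    # ensemble mean and variance at t = 1.0; this timepoint_meanvar call is where the internal mutation happens
    m_val, v_val = timepoint_meanvar(yT_nn(ps), 1.0)
    (m_val[1] - m[1])^2 + (v_val[1] - v[1])^2
end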

yT_nn(ps)

cb2 = function() #callback function to observe training
  display(loss_nn(m,v))
end

cb2()

Z_pre = deepcopy(Z)
u0_pre = deepcopy(u0)

# train the model on the same data num_cycles times
num_cycles = 1
data = ncycle([(m, v)], num_cycles)

opt = ADAM(0.01)
@epochs 5 Flux.train!(loss_nn, ps, data, opt, cb=cb2);
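
The target values true_y0 and true_Z can be checked by hand. For dY = -(Y + Z + 1) dt + Z dW, the mean satisfies mu' = -(mu + Z + 1) and the variance satisfies s' = -2s + Z^2 with s(0) = 0. At time T this gives:

- mean = exp(-T)*y0 - (Z + 1)*(1 - exp(-T))
- variance = Z^2*(1 - exp(-2T))/2

Setting these equal to m and v yields exactly the true_Z and true_y0 formulas in the script.

Below is a minimal sketch in plain Julia (Random and Statistics only, no DifferentialEquations) that checks this numerically with a hand-written Euler-Maruyama ensemble. It also computes the terminal mean and variance without array mutation: a comprehension plus mean/var stands in for timepoint_meanvar. The names em_final and final_meanvar are hypothetical, this is not the DifferentialEquations.jl API, and it has not been checked under Zygote.

```julia
using Random, Statistics

# Same model as in the script: drift -(u + Z + 1), constant diffusion Z.
drift(u, Z) = -(u + Z + 1.0f0)
diffusion(u, Z) = Z

# One Euler-Maruyama path for a scalar state; `u` is rebound at each step,
# so no array is written in place.
function em_final(u0, Z, T, dt, rng)
    n = round(Int, T / dt)
    u = u0
    for _ in 1:n
        u = u + drift(u, Z) * dt + diffusion(u, Z) * sqrt(dt) * randn(rng, Float32)
    end
    return u
end

# Hypothetical stand-in for timepoint_meanvar at t = T: collect the final
# values with a comprehension, then reduce with mean/var.
function final_meanvar(u0, Z; T = 1.0f0, dt = 0.005f0, trajectories = 10_000,
                       rng = Random.default_rng())
    finals = [em_final(u0, Z, T, dt, rng) for _ in 1:trajectories]
    return mean(finals), var(finals)
end

T = 1.0f0
m, v = 1.0f0, 1.0f0
true_Z = sqrt(2v / (1 - exp(-2T)))
true_y0 = exp(T) * m + (exp(T) - 1) * (true_Z + 1)

m_val, v_val = final_meanvar(true_y0, true_Z; T = T, rng = MersenneTwister(1))
println((m_val, v_val)) # should be close to (m, v) = (1, 1), up to sampling error
```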




Thank you!
