Skip to content

Commit

Permalink
Just Float64
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed May 21, 2024
1 parent b38fd72 commit 72631ea
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions docs/src/examples/ode/exogenous_input.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@ using OrdinaryDiffEq, Lux, ComponentArrays, Optimization,
OptimizationPolyalgorithms, OptimizationOptimisers, Plots, Random
rng = Random.default_rng()
tspan = (0.1f0, Float32(10.0))
tspan = (0.1, 10.0)
tsteps = range(tspan[1], tspan[2], length = 100)
t_vec = collect(tsteps)
ex = vec(ones(Float32, length(tsteps), 1))
ex = vec(ones(Float64, length(tsteps), 1))
f(x) = (atan(8.0 * x - 4.0) + atan(4.0)) / (2.0 * atan(4.0))
function hammerstein_system(u)
Expand All @@ -59,13 +59,13 @@ function hammerstein_system(u)
return y
end
y = Float32.(hammerstein_system(ex))
y = hammerstein_system(ex)
plot(collect(tsteps), y, ticks = :native)
nn_model = Lux.Chain(Lux.Dense(2, 8, tanh), Lux.Dense(8, 1))
p_model, st = Lux.setup(rng, nn_model)
u0 = Float32.([0.0])
u0 = Float64.([0.0])
function dudt(u, p, t)
global st
Expand All @@ -89,7 +89,7 @@ end
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ComponentArray(p_model))
optprob = Optimization.OptimizationProblem(optf, ComponentArray{Float64}(p_model))
res0 = Optimization.solve(optprob, PolyOpt(), maxiters = 100)
Expand Down

0 comments on commit 72631ea

Please sign in to comment.