-
-
Notifications
You must be signed in to change notification settings - Fork 161
Closed
FluxML/Zygote.jl
#880

Description
MWE:
using DiffEqFlux, Flux, Optim, OrdinaryDiffEq, CUDA, DiffEqSensitivity, Plots

# Initial state (moved to the GPU) and the integration time span.
u0 = gpu([1.1; 1.1])
tspan = (0.0f0, 25.0f0)

# Neural network driving the first state equation, plus its flattened parameters
# (FastChain parameters live outside the layers and are passed in explicitly).
ann = gpu(FastChain(FastDense(2, 16, tanh), FastDense(16, 16, tanh), FastDense(16, 1)))
p1 = gpu(initial_params(ann))

# Scalar coefficients for the hand-written part of the dynamics.
p2 = Float32[0.5, -0.5]

# All ODE parameters: NN weights first, then the two scalars.
p3 = [p1; p2]

# Full optimization vector: initial condition followed by every parameter.
θ = Float32[u0; p3]
# Hybrid right-hand side: a neural network produces the derivative of the first
# state while a hand-written scalar expression produces the second.
# NOTE(review): the per-call gpu/cpu round-trip here is the pattern the issue is
# about — heavy scalar code mixed with a GPU neural network.
function dudt_(u, p, t)
    x, y = u
    # Evaluate the NN (first length(p1) entries of p are its weights) on the
    # GPU, then pull the single scalar output back to the CPU.
    nn_out = cpu(ann(gpu(u), p[1:length(p1)]))[1]
    return [nn_out, p[end - 1] * y + p[end] * x]
end
prob = ODEProblem{false}(dudt_,u0,tspan,p3)
# Solve the ODE with the initial condition and parameters unpacked from θ
# (first two entries are u0, the rest are p), saving once per time unit, and
# move the dense solution array back to the GPU.
function predict_adjoint(θ)
    sol = solve(prob, Tsit5();
                u0 = cpu(θ[1:2]),
                p = θ[3:end],
                saveat = 0.0:1:25.0,
                sensealg = InterpolatingAdjoint())
    return gpu(Array(sol))
end
# Squared-error loss pushing the second state toward 1 at every save point.
loss_adjoint(θ) = sum(abs2,predict_adjoint(θ)[2,:].-1)
# Loss at the initial optimization vector (sanity check before training).
l = loss_adjoint(θ)
# Training callback: print the current loss each iteration.
# Returning false tells sciml_train to continue (true would halt training).
cb = function (θ, l)
    println(l)
    # display(plot(solve(remake(prob, p = Flux.data(p3), u0 = Flux.data(u0)), Tsit5(), saveat = 0.1), ylim = (0, 6)))
    return false
end
loss1 = loss_adjoint(θ)
res = DiffEqFlux.sciml_train(loss_adjoint, θ, ADAM(), cb = cb, maxiters=10)

This is for the case with heavy scalar nonlinear code + a neural network. We'll need to figure out how to handle the backpass effectively.
Metadata
Assignees
Labels
No labels