Description
I'm dealing with an error that I'd already posted about on Discourse; I'm reposting it here after @ChrisRackauckas asked me to. My apologies for the delay.
It happens when I try to take the pullback of a loss function with:
- the Rodas5() solver
- a Lux neural network defining the ODE. When I define the ODE without Lux, the gradients are computed without a problem.
- sensealg set manually to InterpolatingAdjoint(; autojacvec=ZygoteVJP()). If I let the solver pick it automatically, I get no error even when using the neural network.
I import the basic packages and set up the ODE Function along with a function to solve it. The solve function takes a flag specifying whether sensealg is set manually or automatically, as well as which ODE Function is to be solved. I've run the different permutations of these, and only the neural net + manual sensealg combination bugs out.
using DifferentialEquations, SciMLSensitivity, Zygote, PreallocationTools, Lux, Random, ComponentArrays
rng = Random.default_rng()
Random.seed!(rng, 0)
basic_tgrad(u, p, t) = zero(u)
############################ Lux NN #################################################
m = Lux.Dense(2, 1, tanh)
ps, st = Lux.setup(rng, m)
ps = ComponentArray(ps)
m(Float32.([1.f0, 0.f0]), ps, st)
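For context: Lux's stateless API returns an (output, state) tuple, which is why the ODE function below indexes the call with [1] to extract the 1-element output vector of the Dense layer. A quick shape check, just for illustration (not part of the failing code):
y, _ = m([1.0f0, 0.0f0], ps, st)
@assert y isa AbstractVector && length(y) == 1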
############################### Creating ODE Function ###################################
function f(u, p, t)
    du_1 = m(u, p, st)[1]
    du_2 = u[2]
    return [du_1; du_2]
end
mass_matrix = Float32.([1.0 0.0; 0.0 0.0])
f_ = ODEFunction{false}(f, mass_matrix=mass_matrix, tgrad=basic_tgrad)
function g(u, p, t)
    du_1 = p * u[1]
    du_2 = u[2]
    return [du_1; du_2]
end
g_ = ODEFunction{false}(g, mass_matrix=mass_matrix, tgrad=basic_tgrad)
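For context: the singular mass matrix makes both systems DAEs. With M = [1 0; 0 0], M*du = f(u, p, t) gives one differential equation for u[1] and the algebraic constraint 0 = u[2], which is why I'm using a mass-matrix-capable stiff solver like Rodas5 below. The matrix is carried on the ODEFunction itself:
f_.mass_matrix  # Float32[1.0 0.0; 0.0 0.0]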
################################# Solve/loss Function ###########################################
function solve_de(p; sense=true, func=f_)
    prob = ODEProblem{false}(func, [1.f0, 1.f0], (0.f0, 1.f0), p)
    if sense
        sensealg = InterpolatingAdjoint(; autojacvec=ZygoteVJP())
        return solve(prob, Rodas5(); saveat=0.1, sensealg=sensealg)
    else
        return solve(prob, Rodas5(); saveat=0.1)
    end
end
function loss(p; sense=true, func=f_)
    pred = Array(solve_de(p; sense=sense, func=func))
    return sum(abs2, pred)
end
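For completeness, the four permutations correspond to the following loss calls. The pullback calls further down all return the loss values l_1 through l_4, so the forward pass succeeds in every case and only the first reverse pass throws:
loss(ps)                              # Lux NN, manual sensealg
loss(ps; sense=false)                 # Lux NN, automatic sensealg
loss([2.0f0]; func=g_)                # no NN, manual sensealg
loss([2.0f0]; sense=false, func=g_)   # no NN, automatic sensealg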
Below, we use the ODE Function defined by the Lux neural net with sensealg set manually, and it bugs out:
#Lux NN. Sensealg is manual
(l_1,), back_1 = pullback(p -> loss(p), ps)
back_1(one(l_1)) #Error: Expected Float32, Got ForwardDiff.Dual

The code below runs just fine. Both the manual sensealg and the neural network seem to be necessary to reproduce the bug.
#Lux NN. Sensealg is auto.
(l_2,), back_2 = pullback(p -> loss(p; sense=false), ps)
back_2(one(l_2)) #This works
# No neural net. Sensealg is manual
(l_3,), back_3 = pullback(p -> loss(p; sense=true, func=g_), [2.f0])
back_3(one(l_3)) #This works
# No neural net. Sensealg is auto
(l_4,), back_4 = pullback(p -> loss(p; sense=false, func=g_), [2.f0])
back_4(one(l_4)) #This works

The stack trace is very long, so for readability I've put it up on Pastebin and shared the link below. I'm not sure of the etiquette surrounding this, so if that's not how it's done then I can certainly post the stack trace here instead.
Stacktrace of the bug: https://pastebin.com/FaMWfFAJ