
Zygote Type error when using Rodas5() solver but only when manually setting sensealg #694

@infiniteFun

Description


From: https://discourse.julialang.org/t/zygote-type-error-when-using-rodas5-solver-but-only-when-manually-setting-sensealg/84434

I'm dealing with an error that I'd already posted about on Discourse; I'm reposting it here after @ChrisRackauckas asked me to. My apologies for the delay.

It happens when I call pullback on a loss function with:

  • the Rodas5() solver,
  • a Lux neural network defining the ODE. When I define the ODE without Lux, the gradients are computed without issue.
  • sensealg set manually to InterpolatingAdjoint(; autojacvec=ZygoteVJP()). If I let the solver choose it automatically, I get no error even when using a neural network.

I import the basic packages and set up the ODE functions and a function to solve them. The solve function takes a flag specifying whether sensealg is set manually or automatically, and which ODEFunction is to be solved. I've run all the permutations, and only the neural net + manual sensealg combination errors.

using DifferentialEquations, SciMLSensitivity, Zygote, PreallocationTools, Lux, Random, ComponentArrays

rng = Random.default_rng()
Random.seed!(rng, 0)

basic_tgrad(u, p, t) = zero(u)
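For context, basic_tgrad is, if I understand correctly, the usual trick for autonomous systems: Rosenbrock methods such as Rodas5 use the time-gradient ∂f/∂t internally, and this definition declares it identically zero so the solver doesn't have to differentiate f in t. It just returns a zero vector matching u:

```julia
# Declares the system autonomous: ∂f/∂t ≡ 0 for any state u.
basic_tgrad(u, p, t) = zero(u)

println(basic_tgrad(Float32[1.0, 2.0], nothing, 0.0f0))  # Float32[0.0, 0.0]
```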

############################ Lux NN #################################################

m = Lux.Dense(2, 1, tanh)
ps, st = Lux.setup(rng, m)
ps = ComponentArray(ps)

m(Float32.([1.f0, 0.f0]), ps, st)  # test call; Lux models return an (output, state) tuple

############################### Creating ODE Function ###################################

function f(u, p, t)
    du_1 = m(u, p, st)[1]  # [1] takes the output from the (output, state) tuple
    du_2 = u[2]
    return [du_1; du_2]
end

mass_matrix = Float32.([1.0 0.0; 0.0 0.0])
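A note on the mass matrix, in case it matters for the bug: with M = [1 0; 0 0], solve integrates M·du/dt = f(u, p, t), so the second row imposes the algebraic constraint 0 = f(u, p, t)[2], making this a DAE rather than a plain ODE. The structure can be checked without any solver packages:

```julia
M = Float32.([1.0 0.0; 0.0 0.0])
du = Float32[0.3, 7.0]  # an arbitrary derivative vector
# Row 2 of M*du is zero no matter what du is, so the second equation
# of M*du/dt = f(u, p, t) reduces to the constraint 0 = f(u, p, t)[2].
println(M * du)  # Float32[0.3, 0.0]
```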

f_ = ODEFunction{false}(f, mass_matrix=mass_matrix, tgrad=basic_tgrad)

function g(u, p, t)
    du_1 = p * u[1]
    du_2 = u[2]
    return [du_1; du_2]
end

g_ = ODEFunction{false}(g, mass_matrix=mass_matrix, tgrad=basic_tgrad)

################################# Solve/loss Function ###########################################

function solve_de(p; sense=true, func=f_)
    prob = ODEProblem{false}(func, [1.f0, 1.f0], (0.f0, 1.f0), p)
    if sense
        sensealg=InterpolatingAdjoint(; autojacvec=ZygoteVJP())
        return solve(prob, Rodas5(); saveat=0.1, sensealg=sensealg)
    else
        return solve(prob, Rodas5(); saveat=0.1)
    end
end

function loss(p; sense=true, func=f_)
    pred = Array(solve_de(p; sense=sense, func=func))
    return sum(abs2, pred)
end

Below we use the ODEFunction defined by the Lux neural net and set sensealg manually; this errors:

#Lux NN. Sensealg is manual
(l_1,), back_1 = pullback(p -> loss(p), ps)
back_1(one(l_1)) #Error: Expected Float32, Got ForwardDiff.Dual

The code below runs fine. Both the manual sensealg and the neural network seem to be necessary to reproduce the bug.

#Lux NN. Sensealg is auto.
(l_2,), back_2 = pullback(p -> loss(p; sense=false), ps)
back_2(one(l_2)) #This works

# No neural net. Sensealg is manual
(l_3,), back_3 = pullback(p -> loss(p; sense=true, func=g_), [2.f0])
back_3(one(l_3)) #This works

# No neural net. Sensealg is auto
(l_4,), back_4 = pullback(p -> loss(p; sense=false, func=g_), [2.f0])
back_4(one(l_4)) #This works

The stack trace is very long, so for readability I've put it on Pastebin and linked it here. I'm not sure of the etiquette around this; if that's not how it's done, I can certainly post the stack trace here instead.

Stacktrace of the bug: https://pastebin.com/FaMWfFAJ
