Implicit differential equation solver with internal autodiff of the Jacobian #31

Open
000Justin000 opened this issue Feb 28, 2019 · 1 comment

@000Justin000 000Justin000 commented Feb 28, 2019

Hello,

I keep getting the "WARNING: Instability detected. Aborting" message during training. I think this might be because the differential equation I set up is stiff, so I need an implicit solver. However, when I replace the default Tsit5() solver with Rosenbrock23(), I get a stack-overflow error. Could you help me with this?

Thanks!
Junteng

using DifferentialEquations
using Flux
using DiffEqFlux
using Plots

u0 = Float32[2.; 0.]
datasize = 30
tspan = (0.0f0,1.5f0)

function trueODEfunc(du,u,p,t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= ((u.^3)'true_A)'
end
t = range(tspan[1],tspan[2],length=datasize)
prob = ODEProblem(trueODEfunc,u0,tspan)
ode_data = Array(solve(prob,Tsit5(),saveat=t))

dudt = Chain(x -> x.^3,
             Dense(2,50,tanh),
             Dense(50,2))
n_ode(x) = neural_ode(dudt,x,tspan,Rosenbrock23(),saveat=t,reltol=1e-7,abstol=1e-9)

function predict_n_ode()
  n_ode(u0)
end
loss_n_ode() = sum(abs2,ode_data .- predict_n_ode())

data = Iterators.repeated((), 1000)
opt = ADAM(0.1)
cb = function () #callback function to observe training
  display(loss_n_ode())
  # plot current prediction against data
  cur_pred = Flux.data(predict_n_ode())
  pl = scatter(t,ode_data[1,:],label="data")
  scatter!(pl,t,cur_pred[1,:],label="prediction")
  display(plot(pl))
end

# Display the ODE with the initial parameter values.
cb()

ps = Flux.params(dudt)
Flux.train!(loss_n_ode, ps, data, opt, cb = cb)

@ChrisRackauckas ChrisRackauckas commented Mar 3, 2019

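Rosenbrock23(autodiff=false) works. It turns off the solver's internal autodiff of the Jacobian and falls back to finite differences, so the AD passes are no longer nested. For example: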

using DifferentialEquations
using Flux
using DiffEqFlux
using Plots

u0 = Float32[2.; 0.]
datasize = 30
tspan = (0.0f0,1.5f0)

function trueODEfunc(du,u,p,t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= ((u.^3)'true_A)'
end
t = range(tspan[1],tspan[2],length=datasize)
prob = ODEProblem(trueODEfunc,u0,tspan)
ode_data = Array(solve(prob,Tsit5(),saveat=t))

dudt = Chain(x -> x.^3,
             Dense(2,50,tanh),
             Dense(50,2))
n_ode(x) = neural_ode(dudt,x,tspan,Rosenbrock23(autodiff=false),saveat=t,reltol=1e-7,abstol=1e-9)

function predict_n_ode()
  n_ode(u0)
end
loss_n_ode() = sum(abs2,ode_data .- predict_n_ode())

data = Iterators.repeated((), 1000)
opt = ADAM(0.1)
cb = function () #callback function to observe training
  display(loss_n_ode())
  # plot current prediction against data
  cur_pred = Flux.data(predict_n_ode())
  pl = scatter(t,ode_data[1,:],label="data")
  scatter!(pl,t,cur_pred[1,:],label="prediction")
  display(plot(pl))
end

# Display the ODE with the initial parameter values.
cb()

ps = Flux.params(dudt)
Flux.train!(loss_n_ode, ps, data, opt, cb = cb)

So at least that's a workaround for now. Nesting the ADs seems to be a point that needs some work; maybe @YingboMa or @MikeInnes can pitch in.
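
To see what the autodiff flag changes in isolation, here is a minimal sketch outside of DiffEqFlux. It assumes an OrdinaryDiffEq version that accepts the Bool form autodiff=false (the form used in this thread). The Van der Pol test problem and the names vanderpol!, sol_ad and sol_fd are illustrative and not taken from the issue.

```julia
using OrdinaryDiffEq

# Illustrative stiff test problem (Van der Pol oscillator with large μ);
# this is an assumption for the sketch, not the system from the issue.
function vanderpol!(du, u, p, t)
    μ = p[1]
    du[1] = u[2]
    du[2] = μ * (1 - u[1]^2) * u[2] - u[1]
end

u0 = [2.0, 0.0]
tspan = (0.0, 1.0)
p = [1000.0]
prob = ODEProblem(vanderpol!, u0, tspan, p)

# Default: the solver builds its Jacobian with forward-mode autodiff.
sol_ad = solve(prob, Rosenbrock23(), reltol=1e-6, abstol=1e-8)

# Workaround form: the Jacobian is approximated by finite differences,
# so no dual numbers are pushed through the right-hand side.
sol_fd = solve(prob, Rosenbrock23(autodiff=false), reltol=1e-6, abstol=1e-8)

# The two final states should agree closely on a plain ODE like this one.
println(maximum(abs.(sol_ad.u[end] .- sol_fd.u[end])))
```

On a plain ODE both variants should give nearly the same answer. The difference only matters when an outer AD pass (such as Flux tracking the neural ODE) is wrapped around the solve.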

@ChrisRackauckas ChrisRackauckas changed the title Implicit differential equation solver? Implicit differential equation solver with internal autodiff of the Jacobian Apr 9, 2019