Skip to content

Warning: dt <= dtmin. Aborting. There is either an error in your model specification or the true solution is unstable. #64

@manyfeatures

Description

@manyfeatures

I also described the problem here (the link was not preserved in this copy).
I have trouble fitting simple data.
The code:

using DiffEqFlux
using OrdinaryDiffEq, Flux, Optim, Plots
using Flux, OrdinaryDiffEq
using Zygote 
using DiffEqSensitivity # for ZygteVJP?


# Abstract supertype for the neural differential-equation layers defined in this script.
abstract type NeuralDELayer <: Function end

"""
    basic_tgrad(u, p, t)

Return `zero(u)` irrespective of `p` and `t`: a time gradient suitable for an
ODE right-hand side that has no explicit dependence on `t`.
"""
function basic_tgrad(u, p, t)
    return zero(u)
end

"""
    LTC(model, tspan, τ, A, args...; p = nothing, kwargs...)

Liquid time-constant layer wrapping a Flux `model`.

`model`'s parameters are flattened with `Flux.destructure`; that call only sees
`model`, so the flat vector does NOT already contain `τ` and `A` — those are
stored separately and concatenated as `[p; τ; A]` by the caller of the layer.
Pass `p` to override the flattened model weights. Extra positional `args` and
keyword `kwargs` are kept for forwarding to the ODE solver.
"""
struct LTC{M,P,RE,T,TA,AB,A,K} <: NeuralDELayer
    model::M     # the wrapped Flux network
    p::P         # flattened network weights
    p_len::Int   # length(p), used to slice the combined parameter vector
    re::RE       # rebuilds the network from a flat weight vector
    tspan::T     # integration interval
    τ::TA        # raw time-constant weights
    τ_len::Int   # length(τ)
    A::AB        # bias weights
    A_len::Int   # length(A)
    args::A      # positional arguments kept for the solver
    kwargs::K    # keyword arguments kept for the solver

    function LTC(model, tspan, τ, A, args...; p = nothing, kwargs...)
        flat, rebuild = Flux.destructure(model)
        weights = p === nothing ? flat : p
        return new{typeof(model), typeof(weights), typeof(rebuild),
                   typeof(tspan), typeof(τ), typeof(A),
                   typeof(args), typeof(kwargs)}(
            model, weights, length(weights), rebuild, tspan,
            τ, length(τ), A, length(A), args, kwargs)
    end
end

"""
    (n::LTC)(x)

Integrate the liquid time-constant dynamics from initial state `x` over `n.tspan`
and return whatever `solve` returns.

The solver receives the flat parameter vector `[n.p; n.τ; n.A]`. The right-hand
side slices it back into network weights, raw time constants and bias, and then
evaluates

    du/dt = -(1 ./ τ .+ f(u)) .* u .+ f(u) .* A

where `f` is the network rebuilt from the weight slice.
"""
function (n::LTC)(x)
    function dudt_(u, p, t)
        p_ = @view p[1:n.p_len]
        # softplus maps the raw time constants to non-negative values
        τ_ = Flux.softplus.(@view p[n.p_len+1:n.p_len+n.τ_len])
        A_ = @view p[n.p_len+n.τ_len+1:end]
        f = n.re(p_)(u)   # rebuild and evaluate the network once per call
        return -(1 ./ τ_ .+ f) .* u .+ f .* A_
    end
    ff = ODEFunction{false}(dudt_, tgrad = basic_tgrad)
    prob = ODEProblem{false}(ff, x, n.tspan, [n.p; n.τ; n.A])
    sense = InterpolatingAdjoint(autojacvec = ZygoteVJP())
    # The sensitivity algorithm must be passed as `sensealg`; the previous
    # `sense = ...` keyword was not the sensitivity keyword, so this adjoint was
    # never selected. It comes before `n.kwargs...`, so a caller-supplied
    # `sensealg` still takes precedence (the rightmost keyword wins).
    return solve(prob, n.args...; sensealg = sense, n.kwargs...)
end

# Expose the network weights, time constants and bias of an `LTC` as its trainable leaves.
function Flux.trainable(m::LTC)
    return (m.p, m.τ, m.A)
end


# Example: initial condition and time-grid settings.
u0 = Float32[2.0, 0.0]    # initial state
datasize = 30             # number of saved time points
tspan = (0.0f0, 3.0f0)    # integration interval

"""
    trueODEfunc(du, u, p, t)

In-place right-hand side of the reference system: writes `true_A' * (u .^ 3)`
into `du` and returns `du`. `p` and `t` are unused.
"""
function trueODEfunc(du, u, p, t)
    coupling = [-0.1 2.0; -2.0 -0.1]
    cubed = u .^ 3
    du .= coupling' * cubed
    return du
end
t = range(tspan[1], tspan[2], length = datasize)
prob = ODEProblem(trueODEfunc, u0, tspan)
ode_data = Array(solve(prob, Tsit5(), saveat = t))

# Initial LTC time constants and bias, one entry per state component.
# Float32 to match the Float32 `u0`/`tspan`: the layer builds its solver
# parameters as `[p; τ; A]`, and a Float64 vector here would promote that whole
# vector (and with it the right-hand-side output) to Float64.
τ = rand(Float32, 2)
A = zeros(Float32, 2)

dudt = Chain(x -> x .^ 3,
             Dense(2, 10, tanh),
             Dense(10, 2))

# The layer being trained; previously `n_ode` was used below without ever being
# defined. `Tsit5()` is forwarded to `solve` positionally and `saveat = t` as a
# keyword, so predictions are sampled at the same times as `ode_data`.
n_ode = LTC(dudt, tspan, τ, A, Tsit5(), saveat = t)
ps = Flux.params(n_ode)

pred = n_ode(u0) # Get the prediction using the correct initial condition
scatter(t, ode_data[1, :], label = "data")
scatter!(t, pred[1, :], label = "prediction")

"""
    predict_n_ode()

Return `n_ode(u0)`: the layer's output for the fixed initial state `u0`.
"""
predict_n_ode() = n_ode(u0)
"""
    loss_n_ode()

Sum of squared differences between the reference data `ode_data` and the current
model prediction.
"""
function loss_n_ode()
    residual = ode_data .- predict_n_ode()
    return sum(abs2, residual)
end


# 1000 training steps. Each data item is an empty tuple because `loss_n_ode`
# takes no arguments.
data = Iterators.repeated((), 1000)
opt = ADAM(0.01)
# Training callback; currently a no-op placeholder where progress reporting can go.
cb = () -> nothing
Flux.train!(loss_n_ode, ps, data, opt; cb = cb)

The warning:

┌ Warning: dt <= dtmin. Aborting. There is either an error in your model specification or the true solution is unstable.
└ @ SciMLBase /home/solar/.julia/packages/SciMLBase/HbD6U/src/integrator_interface.jl:345
┌ Warning: dt <= dtmin. Aborting. There is either an error in your model specification or the true solution is unstable.
└ @ SciMLBase /home/solar/.julia/packages/SciMLBase/HbD6U/src/integrator_interface.jl:345

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions