Skip to content

Commit

Permalink
Merge 765ae96 into a3ec0cf
Browse files Browse the repository at this point in the history
  • Loading branch information
AlCap23 committed Feb 20, 2020
2 parents a3ec0cf + 765ae96 commit ab4f845
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 15 deletions.
51 changes: 46 additions & 5 deletions src/sindy.jl
Expand Up @@ -33,32 +33,65 @@ function simplified_matvec(Ξ::AbstractArray{T,1}, basis) where T <: Real
eq
end


function normalize_theta!(scales::AbstractArray, θ::AbstractArray)
@assert length(scales) == size(θ, 1)
@inbounds for (i, ti) in enumerate(eachrow(θ))
scales[i] = norm(ti, 2)
normalize!(ti, 2)
end
return
end

function rescale_xi!(scales::AbstractArray, Ξ::AbstractArray)
@inbounds for (si, ti) in zip(scales, eachrow(Ξ))
ti .= ti / si
end
return
end


# One Variable on multiple derivatives
function SInDy(X::AbstractArray{S, 1}, Ẋ::AbstractArray, Ψ::Basis; kwargs...) where S <: Number
return SInDy(X', Ẋ, Ψ; kwargs...)
end

# Multiple on one
function SInDy(X::AbstractArray{S, 2}, Ẋ::AbstractArray{S, 1}, Ψ::Basis; kwargs...) where S <: Number
return SInDy(X, Ẋ', Ψ; kwargs...)
end

# Returns a basis for the differential state
function SInDy(X::AbstractArray{S, 2}, Ẋ::AbstractArray{S, 2}, Ψ::Basis; p::AbstractArray = [], maxiter::Int64 = 10, opt::T = Optimise.STRRidge()) where {T <: Optimise.AbstractOptimiser, S <: Number}
# General
function SInDy(X::AbstractArray{S, 2}, Ẋ::AbstractArray{S, 2}, Ψ::Basis; p::AbstractArray = [], maxiter::Int64 = 10, opt::T = Optimise.STRRidge(), denoise::Bool = false, normalize::Bool = true) where {T <: Optimise.AbstractOptimiser, S <: Number}
@assert size(X)[end] == size(Ẋ)[end]
nx, nm = size(X)
ny, nm = size(Ẋ)

Ξ = zeros(eltype(X), length(Ψ), ny)
scales = ones(eltype(X), length(Ψ))
θ = Ψ(X, p = p)

denoise ? optimal_shrinkage!') : nothing
normalize ? normalize_theta!(scales, θ) : nothing
# Initial estimate
Optimise.init!(Ξ, opt, θ', Ẋ')
Optimise.fit!(Ξ, θ', Ẋ', opt, maxiter = maxiter)

normalize ? rescale_xi!(scales, Ξ) : nothing

return Basis(simplified_matvec(Ξ, Ψ.basis), variables(Ψ), parameters = p)
end


function SInDy(X::AbstractArray{S, 1}, Ẋ::AbstractArray, Ψ::Basis, thresholds::AbstractArray; kwargs...) where S <: Number
return SInDy(X', Ẋ, Ψ, thresholds; kwargs...)
end

function SInDy(X::AbstractArray{S, 2}, Ẋ::AbstractArray{S, 1}, Ψ::Basis, thresholds::AbstractArray; kwargs...) where S <: Number
return SInDy(X, Ẋ', Ψ, thresholds; kwargs...)
end

# Returns an array of basis for all values of lambda
function SInDy(X::AbstractArray{S, 2}, Ẋ::AbstractArray{S, 2}, Ψ::Basis, thresholds::AbstractArray ; p::AbstractArray = [], maxiter::Int64 = 10, opt::T = Optimise.STRRidge()) where {T <: Optimise.AbstractOptimiser, S <: Number}
function SInDy(X::AbstractArray{S, 2}, Ẋ::AbstractArray{S, 2}, Ψ::Basis, thresholds::AbstractArray ; p::AbstractArray = [], maxiter::Int64 = 10, opt::T = Optimise.STRRidge(),denoise::Bool = false, normalize::Bool = true) where {T <: Optimise.AbstractOptimiser, S <: Number}
@assert size(X)[end] == size(Ẋ)[end]
nx, nm = size(X)
ny, nm = size(Ẋ)
Expand All @@ -70,13 +103,21 @@ function SInDy(X::AbstractArray{S, 2}, Ẋ::AbstractArray{S, 2}, Ψ::Basis, thre
Ξ = zeros(eltype(X), length(thresholds), ny, length(Ψ))
x = zeros(eltype(X), length(thresholds), ny, 2)
pareto = zeros(eltype(X), ny, length(thresholds))
scales = ones(eltype(X), length(Ψ))

denoise ? optimal_shrinkage!') : nothing
normalize ? normalize_theta!(scales, θ) : nothing

@inbounds for (j, threshold) in enumerate(thresholds)
set_threshold!(opt, threshold)

Optimise.init!(ξ, opt, θ', Ẋ')
Optimise.fit!(ξ, θ', Ẋ', opt, maxiter = maxiter)
Ξ[j, :, :] = ξ[:, :]'

[x[j, i, :] = [norm(xi, 0)/length(Ψ); norm(view(Ẋ , i, :) - θ'*xi, 2)] for (i, xi) in enumerate(eachcol(ξ))]

normalize ? rescale_xi!(scales, ξ) : nothing
Ξ[j, :, :] = ξ[:, :]'
end

# Create the evaluation
Expand Down
30 changes: 20 additions & 10 deletions test/runtests.jl
Expand Up @@ -157,9 +157,10 @@ end
end

u0 = [0.99π; -1.0]
tspan = (0.0, 10.0)
tspan = (0.0, 20.0)
dt = 0.3
prob = ODEProblem(pendulum, u0, tspan)
sol = solve(prob, Tsit5(), saveat = 0.1)
sol = solve(prob, Tsit5(), saveat = dt)


# Create the differential data
Expand Down Expand Up @@ -188,20 +189,20 @@ end
opt = STRRidge(1e-2)
basis = Basis(h, u, parameters = [])
Ψ = SInDy(sol[:,:], DX, basis, opt = opt, maxiter = 2000)
@test_nowarn set_threshold!(opt, 0.1)
@test_nowarn set_threshold!(opt, 1e-2)
@test size(Ψ)[1] == 2

# Simulate
estimator = ODEProblem(dynamics(Ψ), u0, tspan, [])
sol_ = solve(estimator,Tsit5(), saveat = 0.1)
sol_ = solve(estimator,Tsit5(), saveat = dt)
@test sol[:,:] sol_[:,:]

opt = ADMM(1e-2, 0.7)
Ψ = SInDy(sol[:,:], DX, basis, maxiter = 5000, opt = opt)
@test_nowarn set_threshold!(opt, 0.1)
@test_nowarn set_threshold!(opt, 1e-2)
# Simulate
estimator = ODEProblem(dynamics(Ψ), u0, tspan, [])
sol_2 = solve(estimator,Tsit5(), saveat = 0.1)
sol_2 = solve(estimator,Tsit5(), saveat = dt)
@test norm(sol[:,:] - sol_2[:,:], 2) < 2e-1
#@test sol[:,:] ≈ sol_2[:,:]

Expand All @@ -211,20 +212,29 @@ end

# Simulate
estimator = ODEProblem(dynamics(Ψ), u0, tspan, [])
sol_3 = solve(estimator,Tsit5(), saveat = 0.1)
sol_3 = solve(estimator,Tsit5(), saveat = dt)
@test norm(sol[:,:] - sol_3[:,:], 2) < 1e-1

# Now use the threshold adaptation
λs = exp10.(-5:0.1:-1)
Ψ = SInDy(sol[:,:], DX[:, :], basis, λs, maxiter = 20, opt = opt)
estimator = ODEProblem(dynamics(Ψ), u0, tspan, [])
sol_4 = solve(estimator,Tsit5(), saveat = 0.1)
sol_4 = solve(estimator,Tsit5(), saveat = dt)
@test norm(sol[:,:] - sol_4[:,:], 2) < 1e-1

# Check for errors
# TODO infer the type of array and automatically push this
@test_nowarn SInDy(sol[:,:], DX[1,:], basis, λs, maxiter = 1, opt = opt)
@test_nowarn SInDy(sol[:, :], DX[1, :], basis, maxiter = 1, opt = opt)
@test_nowarn SInDy(sol[:, :], DX[1, :], basis, λs, maxiter = 1, opt = opt, denoise = true, normalize = true)

# Check with noise
X = sol[:, :] + 1e-3*randn(size(sol[:,:])...)
set_threshold!(opt, 3.5e-1)
Ψ = SInDy(X, DX, basis, maxiter = 10000, opt = opt, denoise = true, normalize = true)

estimator = ODEProblem(dynamics(Ψ), u0, tspan, [])
sol_4 = solve(estimator,Tsit5(), saveat = dt)
@test norm(sol[:,:] - sol_4[:,:], 2) < 5e-1

end

@testset "ISInDy" begin
Expand Down

0 comments on commit ab4f845

Please sign in to comment.