Skip to content

Commit

Permalink
Merge 2376482 into 8fca6dd
Browse files Browse the repository at this point in the history
  • Loading branch information
AlCap23 committed Mar 25, 2020
2 parents 8fca6dd + 2376482 commit 574d17f
Show file tree
Hide file tree
Showing 8 changed files with 19 additions and 12 deletions.
7 changes: 4 additions & 3 deletions examples/SInDy_Examples.jl
Expand Up @@ -64,8 +64,9 @@ println(Ψ)

# Vary the sparsity threshold -> gives better results
λs = exp10.(-5:0.1:-1)
opt = ADMM(1e-2, 1.0)
Ψ = SInDy(sol[:,1:30], DX[:, 1:30], basis, λs, maxiter = 20, opt = opt)
# Use SR3 with high relaxation (allows the solution to diverge from LTSQ) and high iterations
opt = SR3(1e-2, 20.0)
Ψ = SInDy(sol[:,1:10], DX[:, 1:10], basis, λs, maxiter = 10000, opt = opt)
println(Ψ)

# Transform into ODE System
Expand All @@ -80,4 +81,4 @@ sol_ = solve(estimator, Tsit5(), saveat = sol.t)
scatter(sol[:,:]')
plot!(sol_[:,:]')
plot(sol.t, abs.(sol-sol_)')
norm(sol[:,:]-sol_[:,:]) # ≈ 1.89e-13
norm(sol[:,:]-sol_[:,:], 2)
2 changes: 1 addition & 1 deletion src/DataDrivenDiffEq.jl
Expand Up @@ -12,7 +12,7 @@ abstract type abstractKoopmanOperator end;

include("./optimisers/Optimise.jl")
using .Optimise
export set_threshold!
export set_threshold!, set_threshold
export STRRidge, ADMM, SR3
export ADM

Expand Down
2 changes: 1 addition & 1 deletion src/optimisers/Optimise.jl
Expand Up @@ -14,7 +14,7 @@ include("./sr3.jl")
#Nullspace for implicit sindy
include("./adm.jl")

export init, init!, fit!, set_threshold!
export init, init!, fit!, set_threshold!, set_threshold
export STRRidge, ADMM, SR3
export ADM

Expand Down
6 changes: 4 additions & 2 deletions src/optimisers/admm.jl
Expand Up @@ -10,6 +10,8 @@ function set_threshold!(opt::ADMM, threshold)
opt.λ = threshold*opt.ρ
end

get_threshold(opt::ADMM) = opt.λ/opt.ρ

init(o::ADMM, A::AbstractArray, Y::AbstractArray) = A \ Y
init!(X::AbstractArray, o::ADMM, A::AbstractArray, Y::AbstractArray) = ldiv!(X, qr(A, Val(true)), Y)

Expand All @@ -18,7 +20,7 @@ init!(X::AbstractArray, o::ADMM, A::AbstractArray, Y::AbstractArray) = ldiv!(X,
function fit!(X::AbstractArray, A::AbstractArray, Y::AbstractArray, opt::ADMM; maxiter::Int64 = 100)
n, m = size(A)

g = NormL1(opt.λ/opt.ρ)
g = NormL1(get_threshold(opt))

= deepcopy(X)
= zero(X)
Expand All @@ -32,5 +34,5 @@ function fit!(X::AbstractArray, A::AbstractArray, Y::AbstractArray, opt::ADMM; m
ŷ .=+ opt.ρ*(x̂ - X)
end

X[abs.(X) .< opt.λ/opt.ρ] .= zero(eltype(X))
X[abs.(X) .< get_threshold(opt)] .= zero(eltype(X))
end
7 changes: 4 additions & 3 deletions src/optimisers/sr3.jl
Expand Up @@ -14,16 +14,17 @@ function SR3(λ = 1e-1, ν = 1.0)
end

function set_threshold!(opt::SR3, threshold)
opt.λ = threshold
opt.λ = threshold^2*opt.ν /2
return
end

get_threshold(opt::SR3) = sqrt(2*opt.λ/opt.ν)

init(o::SR3, A::AbstractArray, Y::AbstractArray) = A \ Y
init!(X::AbstractArray, o::SR3, A::AbstractArray, Y::AbstractArray) = ldiv!(X, qr(A, Val(true)), Y)

function fit!(X::AbstractArray, A::AbstractArray, Y::AbstractArray, opt::SR3; maxiter::Int64 = 10)
f = opt.R(opt.λ)
f = opt.R(get_threshold(opt))

n, m = size(A)
W = copy(X)
Expand All @@ -41,6 +42,6 @@ function fit!(X::AbstractArray, A::AbstractArray, Y::AbstractArray, opt::SR3; ma
# This is the effective threshold of the SR3 algorithm
# See Unified Framework paper supplementary material S1
#η = sqrt(2*opt.λ*opt.ν)
X[abs.(X) .< opt.λ] .= zero(eltype(X))
X[abs.(X) .< get_threshold(opt)] .= zero(eltype(X))
return
end
4 changes: 3 additions & 1 deletion src/optimisers/strridge.jl
Expand Up @@ -12,6 +12,8 @@ function set_threshold!(opt::STRRidge, threshold)
opt.λ = threshold
end

get_threshold(opt::STRRidge) = opt.λ

init(o::STRRidge, A::AbstractArray, Y::AbstractArray) = A \ Y
init!(X::AbstractArray, o::STRRidge, A::AbstractArray, Y::AbstractArray) = ldiv!(X, qr(A, Val(true)), Y)

Expand All @@ -27,5 +29,5 @@ function fit!(X::AbstractArray, A::AbstractArray, Y::AbstractArray, opt::STRRidg
end
end

X[abs.(X) .< opt.λ] .= zero(eltype(X))
X[abs.(X) .< get_threshold(opt)] .= zero(eltype(X))
end
2 changes: 1 addition & 1 deletion src/sindy.jl
Expand Up @@ -156,7 +156,7 @@ function SInDy(X::AbstractArray{S, 2}, Ẋ::AbstractArray{S, 2}, Ψ::Basis, thre

sparse_regression!(ξ, θ, Ẋ, maxiter, opt, false, false)

[x[j, i, :] = [norm(xi, 0)/length(Ψ); norm(view(Ẋ , i, :) - θ'*xi, 2)] for (i, xi) in enumerate(eachcol(ξ))]
[x[j, i, :] = [norm(xi, 0); norm(view(Ẋ , i, :) - θ'*xi, 2)] for (i, xi) in enumerate(eachcol(ξ))]

normalize ? rescale_xi!(scales, ξ) : nothing
Ξ[j, :, :] = ξ[:, :]'
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Expand Up @@ -226,6 +226,7 @@ end
@test norm(sol[:,:] - sol_3[:,:], 2) < 1e-1

# Now use the threshold adaptation
opt = SR3(1e-2, 20.0)
λs = exp10.(-5:0.1:-1)
Ψ = SInDy(sol[:,:], DX[:, :], basis, λs, maxiter = 20, opt = opt)
estimator = ODEProblem(dynamics(Ψ), u0, tspan, [])
Expand Down

0 comments on commit 574d17f

Please sign in to comment.