Skip to content

Commit

Permalink
Merge d895fc7 into c2e9901
Browse files Browse the repository at this point in the history
  • Loading branch information
AlCap23 committed Feb 9, 2020
2 parents c2e9901 + d895fc7 commit 407a533
Show file tree
Hide file tree
Showing 12 changed files with 469 additions and 74 deletions.
3 changes: 3 additions & 0 deletions Project.toml
Expand Up @@ -8,12 +8,15 @@ Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
ProximalOperators = "a725b495-10eb-56fe-b38b-717eba820537"
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
Compat = "2.2, 3.0"
ModelingToolkit = "1.1.3"
ProximalOperators = "0.10"
QuadGK = "2.3.1"
julia = "1"

[extras]
Expand Down
66 changes: 66 additions & 0 deletions examples/Basis_Creation.jl
@@ -0,0 +1,66 @@
using LinearAlgebra
using DataDrivenDiffEq
using Plots
using ModelingToolkit

# Frist we define a set of variables and parameters
@variables u[1:3]
@parameters w[1:2]

# Now the equations which form our basis
h = [u[1]; u[2]; cos(w[1]*u[2]+w[2]*u[3]); u[3]+u[2]]

# Then we simply create a basis
b = Basis(h, u, parameters = w)

# And look at the corresponding eqs
println(b)

# Suppose we want to add another equation, say sin(u[1])
# The basis behaves like an array
push!(b, sin(u[1]))
size(b) # (5)

# Adding an equation which is already present, does not change the basis
push!(b, sin(u[1]))
size(b) # Still 5

# We can iterate over the basis
for bi in b
println(bi)
end

# Index specific eqs
b[3]

# And of course evaluate
# With fixed parameters
b([1;2;3], p = [2; 4])
# And without
b([1;2;3])
# Or for trajectories
X = randn(3, 40)
Y_p = b(X)
Y = b(X, p = [2;4])


# This allows you to transform a basis simply via
@variables x[1:2]
y = [sin(x[1]); cos(x[1]); x[2]]
b2 = Basis(b(y), x, parameters = w)
println(b2)


# We can merge basis
b3 = merge(b, b2)

# Also in place
merge!(b3, b2)
println(b3)

# Get the variables or parameters
variables(b)
parameters(b)

# We can also check if two bases are equal
b == b
52 changes: 32 additions & 20 deletions examples/SInDy_Examples.jl
Expand Up @@ -7,14 +7,14 @@ gr()



# Create a test problem
# Create a
function pendulum(u, p, t)
x = u[2]
y = -9.81sin(u[1]) - 0.1u[2]
y = -9.81sin(u[1]) - 0.1u[2]^3 -0.2*cos(u[1])
return [x;y]
end

u0 = [0.2π; -1.0]
u0 = [0.99π; -1.0]
tspan = (0.0, 20.0)
prob = ODEProblem(pendulum, u0, tspan)
sol = solve(prob, Tsit5(), saveat = 0.3)
Expand All @@ -31,31 +31,42 @@ end
@variables u[1:2]

# Lots of polynomials
polys = [u[1]^0]
for i 1:3
for j 1:3
polys = Operation[1]
for i 1:5
push!(polys, u.^i...)
for j 1:i-1
push!(polys, u[1]^i*u[2]^j)
end
end

# And some other stuff
h = [1u[1];1u[2]; cos(u[1]); sin(u[1]); u[1]*u[2]; u[1]*sin(u[2]); u[2]*cos(u[2]); polys...]
h = [cos(u[1]); sin(u[1]); u[1]*u[2]; u[1]*sin(u[2]); u[2]*cos(u[2]); polys...]

basis = Basis(h, u)
println(basis)

# Get the reduced basis via the sparse regression
opt = STRRidge(1e-10/0.05)
Ψ = SInDy(sol[:,:], DX, basis, maxiter = 100, opt = opt)
println.basis)


opt = ADMM(1e-10, 0.05)
Ψ = SInDy(sol[:,:], DX, basis, maxiter = 2000, opt = opt)
println.basis)

opt = SR3(1e-2, 1.8)
Ψ = SInDy(sol[:,:], DX, basis, maxiter = 2000, opt = opt)
println.basis)
# Thresholded Sequential Least Squares, works fine for more data
# than assumptions, converges fast but fails sometimes with too much noise
opt = STRRidge(1e-2)
Ψ = SInDy(sol[:,1:25], DX[:, 1:25], basis, maxiter = 100, opt = opt)
println(Ψ)

# Lasso as ADMM, typically needs more information, more tuning
opt = ADMM(1e-2, 1.0)
Ψ = SInDy(sol[:,1:50], DX[:, 1:50], basis, maxiter = 5000, opt = opt)
println(Ψ)

# SR3, works good with lesser data and tuning
opt = SR3(1e-2, 1.0)
Ψ = SInDy(sol[:,1:30], DX[:, 1:30], basis, maxiter = 5000, opt = opt)
println(Ψ)

# Vary the sparsity threshold -> gives better results
λs = exp10.(-5:0.1:-1)
opt = ADMM(1e-2, 1.0)
Ψ = SInDy(sol[:,1:30], DX[:, 1:30], basis, λs, maxiter = 20, opt = opt)
println(Ψ)

# Transform into ODE System
sys = ODESystem(Ψ)
Expand All @@ -64,8 +75,9 @@ sys = ODESystem(Ψ)
estimator = ODEProblem(dynamics(Ψ), u0, tspan)
sol_ = solve(estimator, Tsit5(), saveat = sol.t)


# Yeah! We got it right
scatter(sol[:,:]')
plot!(sol_[:,:]')

plot(sol.t, abs.(sol-sol_)')
norm(sol[:,:]-sol_[:,:]) # ≈ 1.89e-13
7 changes: 7 additions & 0 deletions src/DataDrivenDiffEq.jl
Expand Up @@ -2,19 +2,23 @@ module DataDrivenDiffEq

using LinearAlgebra
using ModelingToolkit
using QuadGK, Statistics
using Compat

abstract type abstractBasis end;
abstract type abstractKoopmanOperator end;

include("./optimisers/Optimise.jl")
using .Optimise
export set_threshold!
export STRRidge, ADMM, SR3
export ADM

include("./basis.jl")
export Basis
export variables, jacobian, dynamics
export free_parameters, parameters, variables


include("./exact_dmd.jl")
export ExactDMD
Expand All @@ -39,5 +43,8 @@ export SInDy
include("./isindy.jl")
export ISInDy

include("./utils.jl")
export AIC, AICC, BIC
export hankel, optimal_shrinkage, optimal_shrinkage!

end # module
101 changes: 95 additions & 6 deletions src/basis.jl
@@ -1,15 +1,31 @@
import Base.==


mutable struct Basis{O, V, P} <: abstractBasis
basis::O
variables::V
parameter::P
f_
end

is_independent(o::Operation) = isempty(o.args)

Base.print(io::IO, x::Basis) = show(io, x)
Base.show(io::IO, x::Basis) = print(io, "$(length(x.basis)) dimensional basis in ", "$(String.([v.op.name for v in x.variables]))")

@inline function Base.print(io::IO, x::Basis)
show(io, x)
println()
if length(x.variables) == length(x.basis)
for (i, bi) in enumerate(x.basis)
println("d$(x.variables[i]) = $bi")
end
else
for (i, bi) in enumerate(x.basis)
println("f_$i = $bi")
end
end
end

is_independent(o::Operation) = isempty(o.args)

function Basis(basis::AbstractVector{Operation}, variables::AbstractVector{Operation}; parameters = [])
@assert all(is_independent.(variables)) "Please provide independent variables for base."

Expand All @@ -30,6 +46,15 @@ function update!(b::Basis)
b.f_ = ModelingToolkit.build_function(b.basis, vs, ps, (), simplified_expr, Val{false})[1]
return
end

function Base.push!(b::Basis, ops::AbstractArray{Operation})
@inbounds for o in ops
push!(b.basis, o)
end
unique!(b.basis)
update!(b)
return
end

function Base.push!(b::Basis, op₀::Operation)
op = simplify_constants(op₀)
Expand All @@ -47,17 +72,70 @@ function Base.deleteat!(b::Basis, inds)
return
end

(b::Basis)(x::AbstractArray{T, 1}; p::AbstractArray = []) where T <: Number = b.f_(x, p)
function Base.merge(basis_a::Basis, basis_b::Basis)
b = unique(vcat(basis_a.basis, basis_b.basis))
vs = unique(vcat(basis_a.variables, basis_b.variables))
ps = unique(vcat(basis_a.parameter, basis_b.parameter))
return Basis(b, vs, parameters = ps)
end

function Base.merge!(basis_a::Basis, basis_b::Basis)
push!(basis_a, basis_b.basis)
basis_a.variables = unique(vcat(basis_a.variables, basis_b.variables))
basis_a.parameter = unique(vcat(basis_a.parameter, basis_b.parameter))
update!(basis_a)
return
end

Base.getindex(b::Basis, idx::Int64) = b.basis[idx]
Base.getindex(b::Basis, ids::UnitRange{Int64}) = b.basis[ids]
Base.getindex(b::Basis, ::Colon) = b.basis
Base.firstindex(b::Basis) = firstindex(b.basis)
Base.lastindex(b::Basis) = lastindex(b.basis)
Base.iterate(b::Basis) = iterate(b.basis)
Base.iterate(b::Basis, id::Int64) = iterate(b.basis, id)

function (==)(x::Basis, y::Basis)
n = zeros(Bool, length(x.basis))
@inbounds for (i, xi) in enumerate(x)
n[i] = any(isequal.(xi, y.basis))
end
return all(n)
end

function count_operation(o::Expression, ops::AbstractArray)
if isa(o, ModelingToolkit.Constant)
return 0
end
k = o.op ops ? 1 : 0
if !isempty(o.args)
k += sum([count_operation(ai, ops) for ai in o.args])
end
return k
end

free_parameters(b::Basis; operations = [+]) = sum([count_operation(bi, operations) for bi in b.basis]) + length(b.basis)

(b::Basis)(x::AbstractArray{T, 1}; p::AbstractArray = []) where T <: Number = b.f_(x, isempty(p) ? parameters(b) : p)



function (b::Basis)(x::AbstractArray{T, 2}; p::AbstractArray = []) where T <: Number
res = zeros(eltype(x), length(b.basis), size(x)[2])
if (isempty(p) || eltype(p) <: Expression) && !isempty(parameters(b))
pi = isempty(p) ? parameters(b) : p
res = zeros(eltype(pi), length(b), size(x)[2])
else
pi = p
res = zeros(eltype(x), length(b), size(x)[2])
end
@inbounds for i in 1:size(x)[2]
res[:, i] .= b.f_(x[:, i], p)
res[:, i] .= b.f_(x[:, i], pi)
end
return res
end

Base.size(b::Basis) = size(b.basis)
Base.length(b::Basis) = length(b.basis)
ModelingToolkit.parameters(b::Basis) = b.parameter
variables(b::Basis) = b.variables

Expand All @@ -68,13 +146,15 @@ function jacobian(b::Basis)
return ModelingToolkit.build_function(expand_derivatives.(j), vs, ps, (), simplified_expr, Val{false})[1]
end


function Base.unique!(b::Basis)
N = length(b.basis)
removes = Vector{Bool}()
for i 1:N
push!(removes, any([isequal(b.basis[i], b.basis[j]) for j in i+1:N]))
end
deleteat!(b, removes)
update!(b)
end

function Base.unique(b::Basis)
Expand All @@ -98,6 +178,15 @@ function Base.unique(b₀::AbstractVector{Operation})
return b[returns]
end

function Base.unique!(b::AbstractArray{Operation})
N = length(b)
removes = Vector{Bool}()
for i 1:N
push!(removes, any([isequal(b[i], b[j]) for j in i+1:N]))
end
deleteat!(b, removes)
end

function fix_single_vars_in_basis!(basis,variables)
for (ind, el) in enumerate(basis)
for (ind_var, var) in enumerate(variables)
Expand Down
2 changes: 1 addition & 1 deletion src/optimisers/Optimise.jl
Expand Up @@ -14,7 +14,7 @@ include("./sr3.jl")
#Nullspace for implicit sindy
include("./adm.jl")

export init, fit!
export init, init!, fit!, set_threshold!
export STRRidge, ADMM, SR3
export ADM

Expand Down

0 comments on commit 407a533

Please sign in to comment.