Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
ab91320
Adapt build function slightly
AlCap23 Jan 26, 2022
4e28a47
Move problem into subfolder, dispatch basis prob
AlCap23 Jan 26, 2022
57e0826
Add train_test split etc
AlCap23 Jan 26, 2022
50bf6a5
Remove old code
AlCap23 Jan 26, 2022
23f919b
Adapt DDD function signature
AlCap23 Jan 27, 2022
e0018ee
Add more dispatch and assert validity
AlCap23 Jan 27, 2022
3638c69
Name getter
AlCap23 Jan 27, 2022
100fcfb
Dispatch Base.in for Num and Symbolic
AlCap23 Jan 27, 2022
130ff0e
Add tests
AlCap23 Jan 27, 2022
4c22301
Adapt jacobian and add tests
AlCap23 Jan 27, 2022
9f98a6d
Move stuff and add samplers
AlCap23 Jan 28, 2022
d61278b
More sampling and selection
AlCap23 Jan 29, 2022
3914c04
Adapt for use with measurements
AlCap23 Jan 29, 2022
d536007
Adapt init and project.toml
AlCap23 Jan 29, 2022
994d0bb
Reintroduce progress
AlCap23 Jan 29, 2022
d43d214
Use symbol instead of string
AlCap23 Jan 29, 2022
8ae51a1
Remove the LinearProblem
AlCap23 Jan 29, 2022
b14fa6e
Adapt gitignore
AlCap23 Feb 2, 2022
5fc1e5e
Adapt gitignore
AlCap23 Feb 2, 2022
28f429b
Add CommonSolve
AlCap23 Feb 2, 2022
849ebb2
Cleanup
AlCap23 Feb 2, 2022
dc876b2
New solve structure SINDy
AlCap23 Feb 2, 2022
dcd0ad3
Correct ISINDy and adapt API
AlCap23 Feb 3, 2022
dcfa1ef
Koopman and export
AlCap23 Feb 3, 2022
06ab659
Draft the KO
AlCap23 Feb 3, 2022
7416a42
Finish up operation
AlCap23 Feb 3, 2022
7a358fa
Adapt tests
AlCap23 Feb 3, 2022
642a0f4
Fix interfaces and fallbacks
AlCap23 Feb 4, 2022
56d8453
Fix output behaviour and indexing
AlCap23 Feb 4, 2022
b643d7e
Fix wrong X in tests
AlCap23 Feb 4, 2022
8a3027b
Fix typo
AlCap23 Feb 4, 2022
f78fd18
Adapt test for new basis API and relax error
AlCap23 Feb 4, 2022
0c7c9d3
Remove local testfile
AlCap23 Feb 4, 2022
05e87e1
Adapt naming of problems to gensym
AlCap23 Feb 4, 2022
7664772
Adapt discrete eval
AlCap23 Feb 4, 2022
ff48065
Fix solution for direct prob
AlCap23 Feb 4, 2022
3774d14
Adapt kwargs in common options
AlCap23 Feb 4, 2022
ac3fc91
Adjusted tests for now
AlCap23 Feb 5, 2022
6aa0b41
Rmv progress and add info instead
AlCap23 Feb 5, 2022
4ed0bc3
Further adapt
AlCap23 Feb 5, 2022
21b02ed
Adapt cartpole
AlCap23 Feb 14, 2022
503d353
Add more tests and fix partioning
AlCap23 Feb 15, 2022
cb8f1d2
Rmv info
AlCap23 Feb 15, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
Manifest.toml
.DS_Store
/docs/build/*
hall_of_fame.*
7 changes: 7 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,19 @@ authors = ["Julius Martensen <julius.martensen@gmail.com>"]
version = "0.7.0"

[deps]
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Expand All @@ -22,13 +26,16 @@ Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
CommonSolve = "0.2"
Compat = "3.0"
DataInterpolations = "3"
DiffEqBase = "6"
Distributions = "^0.25.9, 0.25"
DocStringExtensions = "0.7, 0.8"
Flux = "^0.12.4"
Measurements = "2.7"
ModelingToolkit = "7, 8"
Parameters = "0.12"
ProgressMeter = "1.6"
QuadGK = "2.4"
RecipesBase = "1"
Expand Down
44 changes: 32 additions & 12 deletions src/DataDrivenDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,16 @@ module DataDrivenDiffEq
using DocStringExtensions
using LinearAlgebra
using DiffEqBase
using CommonSolve
using ModelingToolkit

using Distributions
using QuadGK
using Statistics
using DataInterpolations

using Parameters
using Random
using Measurements

using Requires
using ProgressMeter
Expand All @@ -21,6 +24,7 @@ using Compat
using DocStringExtensions
using RecipesBase

@reexport using DiffEqBase: solve
@reexport using ModelingToolkit: states, parameters, independent_variable, observed, controls, get_iv
@reexport using DataInterpolations: ConstantInterpolation, LinearInterpolation, QuadraticInterpolation, LagrangeInterpolation, QuadraticSpline, CubicSpline, BSplineInterpolation, BSplineApprox, Curvefit
using Symbolics: scalarize, variable
Expand All @@ -45,7 +49,11 @@ abstract type AbstractSymbolicRegression end
abstract type AbstractDataDrivenProblem{dType, cType, probType} end
abstract type AbstractDataDrivenSolution end


# Optimizer
abstract type AbstractProximalOperator end;
abstract type AbstractOptimizer{T} end;
abstract type AbstractSubspaceOptimizer{T} <: AbstractOptimizer{T} end;



## Basis
Expand All @@ -55,7 +63,7 @@ include("./basis/utils.jl")
include("./basis/type.jl")
export Basis
export jacobian, dynamics
export free_parameters
export free_parameters, implicit_variables

include("./utils/basis_generators.jl")
export chebyshev_basis, monomial_basis, polynomial_basis
Expand All @@ -76,11 +84,11 @@ export burst_sampling, subsample
## Sparse Regression

include("./optimizers/Optimize.jl")
@reexport using DataDrivenDiffEq.Optimize: sparse_regression!
@reexport using DataDrivenDiffEq.Optimize: set_threshold!, get_threshold
@reexport using DataDrivenDiffEq.Optimize: STLSQ, ADMM, SR3
@reexport using DataDrivenDiffEq.Optimize: ImplicitOptimizer
@reexport using DataDrivenDiffEq.Optimize: SoftThreshold, HardThreshold, ClippedAbsoluteDeviation
export SoftThreshold, HardThreshold,ClippedAbsoluteDeviation
export sparse_regression!
export init, init!, set_threshold!, get_threshold
export STLSQ, ADMM, SR3
export ImplicitOptimizer

## Koopman

Expand Down Expand Up @@ -110,21 +118,33 @@ const AbstractDiscreteProb{N,C} = AbstractDataDrivenProblem{N,C,DDProbType(2)}
const AbstracContProb{N,C} = AbstractDataDrivenProblem{N,C,DDProbType(3)}


include("./problem.jl")
include("./problem/type.jl")

export DataDrivenProblem
export DiscreteDataDrivenProblem, ContinuousDataDrivenProblem, DirectDataDrivenProblem
export is_autonomous, is_discrete, is_direct, is_continuous, is_parametrized, has_timepoints
export is_valid
export is_valid, @is_applicable, get_name

include("./problem/sample.jl")
export DataSampler, Split, Batcher

# Result selection
select_by(x, y::AbstractMatrix) = y
select_by(x, sol) = select_by(Val(x), sol)


include("./solution.jl")
export DataDrivenSolution
export result, parameters, parameter_map, algorithm
export output, metrics, error, aic, determination, get_problem

include("./solve/sindy.jl")

include("./solve/common.jl")
export DataDrivenCommonOptions
include("./solve/sparse_identification.jl")
#include("./solve/sindy.jl")
include("./solve/koopman.jl")
export solve
#export solve

include("./recipes/problem_result.jl")

Expand Down
83 changes: 41 additions & 42 deletions src/basis/build_function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,44 +9,43 @@ function _build_ddd_function(rhs, states, parameters, iv, eval_expression::Bool
end

function f(
u::AbstractVector{T} where T,
p::AbstractVector{T} where T,
u::AbstractVector,
p::AbstractVector,
t::T where T
)
return f_oop(u, p, t)
end

function f(
du::AbstractVector{T} where T,
u::AbstractVector{T} where T,
p::AbstractVector{T} where T,
du::AbstractVector,
u::AbstractVector,
p::AbstractVector,
t::T where T
)
return f_iip(du, u, p, t)
end

function f(
x::AbstractMatrix{T} where T,
p::AbstractVector{T} where T,
t::AbstractVector{T} where T
x::AbstractMatrix,
p::AbstractVector,
t::AbstractVector
)
@assert size(x, 2) == length(t) "Measurements and time points must be of equal length!"

return hcat([f(x[:,i], p, t[i]) for i in 1:size(x, 2)]...)

return reduce(hcat, map(i->f(x[:,i], p, t[i]), 1:length(t)))
end


function f(
y::AbstractMatrix{T} where T,
x::AbstractMatrix{T} where T,
p::AbstractVector{T} where T,
t::AbstractVector{T} where T
y::AbstractMatrix,
x::AbstractMatrix,
p::AbstractVector,
t::AbstractVector
)
@assert size(x, 2) == length(t) "Measurements and time points must be of equal length!"
@assert size(x, 2) == size(y, 2) "Measurements and preallocated output must be of equal length!"

for i = 1:size(x, 2)
@simd for i = 1:size(x, 2)
@views f(y[:, i], x[:, i], p, t[i])
end

Expand Down Expand Up @@ -104,46 +103,46 @@ function _build_ddd_function(
end

function f(
u::AbstractVector{T} where T,
p::AbstractVector{T} where T,
u::AbstractVector,
p::AbstractVector,
t::T where T,
c::AbstractVector{T} where T = zeros(eltype(u), size(controls)...),
c::AbstractVector = zeros(eltype(u), size(controls)...),
)
return c_oop(u, p, t, c)
end

function f(
du::AbstractVector{T} where T,
u::AbstractVector{T} where T,
p::AbstractVector{T} where T,
du::AbstractVector,
u::AbstractVector,
p::AbstractVector,
t::T where T,
c::AbstractVector{T} where T= zeros(eltype(u), size(controls)...),
c::AbstractVector= zeros(eltype(u), size(controls)...),
)
return c_iip(du, u, p, t, c)
end


function f(
x::AbstractMatrix{T} where T,
p::AbstractVector{T} where T,
t::AbstractVector{T} where T
x::AbstractMatrix,
p::AbstractVector,
t::AbstractVector
)
@assert size(x, 2) == length(t) "Measurements and time points must be of equal length!"

return hcat([f(x[:,i], p, t[i]) for i in 1:size(x, 2)]...)
return reduce(hcat, map(i->f(x[:,i], p, t[i]), 1:length(t)))

end

function f(
y::AbstractMatrix{T} where T,
x::AbstractMatrix{T} where T,
p::AbstractVector{T} where T,
t::AbstractVector{T} where T
y::AbstractMatrix,
x::AbstractMatrix,
p::AbstractVector,
t::AbstractVector
)
@assert size(x, 2) == length(t) "Measurements and time points must be of equal length!"
@assert size(x, 2) == size(y, 2) "Measurements and preallocated output must be of equal length!"

for i = 1:size(x, 2)
@simd for i = 1:length(t)
@views f(y[:, i], x[:, i], p, t[i])
end

Expand All @@ -152,31 +151,31 @@ function _build_ddd_function(


function f(
x::AbstractMatrix{T} where T,
p::AbstractVector{T} where T,
t::AbstractVector{T} where T,
u::AbstractMatrix{T} where T
x::AbstractMatrix,
p::AbstractVector,
t::AbstractVector,
u::AbstractMatrix
)
@assert size(x, 2) == length(t) "Measurements and time points must be of equal length!"
@assert size(x, 2) == size(u, 2) "Measurements and inputs must be of equal length!"


return hcat([f(x[:,i], p, t[i], u[:, i]) for i in 1:size(x, 2)]...)
return reduce(hcat, map(i->f(x[:,i], p, t[i], u[:, i]), 1:length(t)))

end

function f(
y::AbstractMatrix{T} where T,
x::AbstractMatrix{T} where T,
p::AbstractVector{T} where T,
t::AbstractVector{T} where T,
u::AbstractMatrix{T} where T
y::AbstractMatrix,
x::AbstractMatrix,
p::AbstractVector,
t::AbstractVector,
u::AbstractMatrix
)
@assert size(x, 2) == length(t) "Measurements and time points must be of equal length!"
@assert size(x, 2) == size(y, 2) "Measurements and preallocated output must be of equal length!"
@assert size(x, 2) == size(u, 2) "Measurements and inputs must be of equal length!"

for i = 1:size(x, 2)
@simd for i = 1:length(t)
@views f(y[:, i], x[:, i], p, t[i], u[:, i])
end

Expand Down
Loading