Skip to content

Commit

Permalink
Merge 016e618 into 0a57a69
Browse files Browse the repository at this point in the history
  • Loading branch information
gabrevaya committed Jan 29, 2020
2 parents 0a57a69 + 016e618 commit 02bb4a5
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 3 deletions.
26 changes: 25 additions & 1 deletion src/basis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ function Basis(basis::AbstractVector{Operation}, variables::AbstractVector{Opera
@assert all(is_independent.(variables)) "Please provide independent variables for base."

bs = unique(basis)
fix_single_vars_in_basis!(bs, variables)

vs = sort!([b for b in [ModelingToolkit.vars(bs)...] if !b.known], by = x -> x.name)
ps = sort!([b for b in [ModelingToolkit.vars(bs)...] if b.known], by = x -> x.name )
Expand All @@ -30,7 +31,9 @@ function update!(b::Basis)
return
end

function Base.push!(b::Basis, op::Operation)
function Base.push!(b::Basis, op₀::Operation)
op = simplify_constants(op₀)
fix_single_vars_in_basis!(op, b.variables)
push!(b.basis, op)
# Check for uniqueness
unique!(b)
Expand Down Expand Up @@ -84,6 +87,27 @@ function Base.unique(b::Basis)
return Basis(b.basis[returns], variables(b), parameters = parameters(b))
end

function Base.unique(b₀::AbstractVector{Operation})
b = simplify_constants.(b₀)
N = length(b)
returns = Vector{Bool}()
for i 1:N
push!(returns, any([isequal(b[i], b[j]) for j in i+1:N]))
end
returns = [!r for r in returns]
return b[returns]
end

function fix_single_vars_in_basis!(basis,variables)
for (ind, el) in enumerate(basis)
for (ind_var, var) in enumerate(variables)
if isequal(el,var)
basis[ind] = 1var
end
end
end
end

function dynamics(b::Basis)
return b.f_
end
Expand Down
5 changes: 3 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@ using Test
@testset "Basis" begin
@variables u[1:3]
@parameters w[1:2]
h = [1u[1]; 1u[2]; cos(w[1]*u[2]+w[2]*u[3]); 1u[3]+1u[2]]
basis = Basis(h, u, parameters = w)
h = [u[1]; u[2]; cos(w[1]*u[2]+w[2]*u[3]); u[3]+u[2]]
h_not_unique = [1u[1]; u[1]; 1u[1]^1; h]
basis = Basis(h_not_unique, u, parameters = w)
basis_2 = unique(basis)
@test size(basis) == size(h)
@test basis([1.0; 2.0; π], p = [0. 1.]) [1.0; 2.0; -1.0; π+2.0]
Expand Down

0 comments on commit 02bb4a5

Please sign in to comment.