Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

clean up affine operator #21

Merged
merged 11 commits into from
Jun 4, 2022
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 34 additions & 62 deletions src/sciml.jl
Original file line number Diff line number Diff line change
Expand Up @@ -177,79 +177,51 @@ LinearAlgebra.ldiv!(Y::AbstractVector, L::FactorizedOperator, B::AbstractVector)
LinearAlgebra.ldiv!(L::FactorizedOperator, B::AbstractVector) = ldiv!(L.F, B)

"""
AffineOperator{T} <: AbstractSciMLOperator{T}

`Ex: (A₁(t) + ... + Aₙ(t))*u + B₁(t) + ... + Bₘ(t)`
L = AffineOperator(A, b)
L(u) = A*u + b
"""
struct AffineOperator{T,AType,bType} <: AbstractSciMLOperator{T}
A::AType
b::bType

AffineOperator{T}(As,Bs,du_cache=nothing)
function AffineOperator(A::AbstractSciMLOperator, b::AbstractVector)
T = promote_type(eltype.((A,b))...)
new{T,typeof(A),typeof(b)}(A, b)
end
end

Takes in two tuples for split Affine DiffEqs
getops(L::AffineOperator) = (L.A, L.b)
Base.size(L::AffineOperator) = size(L.A)

1. update_coefficients! works by updating the coefficients of the component
operators.
2. Function calls L(u, p, t) and L(du, u, p, t) are fallbacks interpretted in this form.
This will allow them to work directly in the nonlinear ODE solvers without
modification.
3. f(du, u, p, t) is only allowed if a du_cache is given
4. B(t) can be Union{Number,AbstractArray}, in which case they are constants.
Otherwise they are interpreted they are functions v=B(t) and B(v,t)
islinear(::AffineOperator) = false
Base.iszero(L::AffineOperator) = all(iszero, getops(L))
has_adjoint(L::AffineOperator) = all(has_adjoint, L.ops)
has_mul!(L::AffineOperator) = has_mul!(L.A)
has_ldiv(L::AffineOperator) = has_ldiv(L.A)
has_ldiv!(L::AffineOperator) = has_ldiv!(L.A)

Solvers will see this operator from integrator.f and can interpret it by
checking the internals of As and Bs. For example, it can check isconstant(As[1])
etc.
"""
struct AffineOperator{T,T1,T2,U} <: AbstractSciMLOperator{T}
As::T1
Bs::T2
du_cache::U
function AffineOperator{T}(As,Bs,du_cache=nothing) where T
all([size(a) == size(As[1])
for a in As]) || error("Operator sizes do not agree")
new{T,typeof(As),typeof(Bs),typeof(du_cache)}(As,Bs,du_cache)
end
end

Base.size(L::AffineOperator) = size(L.As[1])
Base.:*(L::AffineOperator, u::AbstractVector) = L.A * u + L.b
Base.:\(L::AffineOperator, u::AbstractVector) = L.A \ (u - L.b)

getops(L::AffineOperator) = (L.As..., L.Bs...)
function LinearAlgebra.mul!(v::AbstractVector, L::AffineOperator, u::AbstractVector)
mul!(v, L.A, u)
axpy!(true, L.b, v)
end

function (L::AffineOperator)(u,p,t::Number)
update_coefficients!(L,u,p,t)
du = sum(A*u for A in L.As)
for B in L.Bs
if typeof(B) <: Union{Number,AbstractArray}
du .+= B
else
du .+= B(t)
end
end
du
function LinearAlgebra.mul!(v::AbstractVector, L::AffineOperator, u::AbstractVector, α::Number, β::Number)
mul!(v, L.A, u, α, β)
axpy!(α, L.b, v)
end

function (L::AffineOperator)(du,u,p,t::Number)
update_coefficients!(L,u,p,t)
L.du_cache === nothing && error("Can only use inplace AffineOperator if du_cache is given.")
du_cache = L.du_cache
fill!(du,zero(first(du)))
# TODO: Make type-stable via recursion
for A in L.As
mul!(du_cache,A,u)
du .+= du_cache
end
for B in L.Bs
if typeof(B) <: Union{Number,AbstractArray}
du .+= B
else
B(du_cache,t)
du .+= du_cache
end
end
function LinearAlgebra.ldiv!(v::AbstractVector, L::AffineOperator, u::AbstractVector)
copy!(v, u)
ldiv!(L, v)
end

function update_coefficients!(L::AffineOperator,u,p,t)
# TODO: Make type-stable via recursion
for A in L.As; update_coefficients!(A,u,p,t); end
for B in L.Bs; update_coefficients!(B,u,p,t); end
function LinearAlgebra.ldiv!(L::AffineOperator, u::AbstractVector)
axpy!(-true, L.b, u)
ldiv!(L.A, u)
end

"""
Expand Down
25 changes: 25 additions & 0 deletions test/sciml.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,31 @@ N = 8
@test At \ u ≈ AAt \ u ≈ FFt \ u
end

@testset "AffineOperator" begin
u = rand(N)
A = rand(N,N)
D = Diagonal(A)
b = rand(N)
α = rand()
β = rand()

L = AffineOperator(MatrixOperator(A), b)

@test L * u ≈ A * u + b
v=rand(N); @test mul!(v, L, u) ≈ A * u + b
v=rand(N); w=copy(v); @test mul!(v, L, u, α, β) ≈ α*(A*u + b) + β*w

L = AffineOperator(MatrixOperator(D), b)
@test L \ u ≈ D \ (u - b)
#
# TODO uncomment later
# ldiv! for MatrixOperator defined in
# https://github.com/SciML/SciMLOperators.jl/pull/22
#
# v=rand(N); @test ldiv!(v, L, u) ≈ D \ (u-b)
# v=rand(N); @test ldiv!(L, u) ≈ D \ (u-b)
ChrisRackauckas marked this conversation as resolved.
Show resolved Hide resolved
end

@testset "SciMLFunctionOperator" begin
end

Expand Down