Skip to content
This repository has been archived by the owner on Jul 19, 2023. It is now read-only.

Basic operator composition #67

Merged
merged 18 commits into from
Jul 16, 2018
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
10 changes: 7 additions & 3 deletions src/DiffEqOperators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@ __precompile__()

module DiffEqOperators

import Base: *, getindex
import Base: *, /, \, size, getindex, setindex!, Matrix
using DiffEqBase, StaticArrays, LinearAlgebra
import DiffEqBase: update_coefficients, update_coefficients!
import LinearAlgebra: mul!, lmul!, rmul!, axpy!, opnorm, factorize
import DiffEqBase: AbstractDiffEqLinearOperator, update_coefficients!, is_constant

abstract type AbstractDerivativeOperator{T} <: DiffEqBase.AbstractDiffEqLinearOperator{T} end
abstract type AbstractDerivativeOperator{T} <: AbstractDiffEqLinearOperator{T} end

struct DEFAULT_UPDATE_FUNC end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why make it an overloaded singleton instead of a function?

(::DEFAULT_UPDATE_FUNC)(A,u,p,t) = A

### Basic Operators
include("diffeqscalar.jl")
Expand Down
200 changes: 56 additions & 144 deletions src/array_operator.jl
Original file line number Diff line number Diff line change
@@ -1,148 +1,60 @@
### AbstractDiffEqLinearOperator defined by an array and update functions
mutable struct DiffEqArrayOperator{T,Arr<:Union{T,AbstractMatrix{T}},Sca,F} <: DiffEqBase.AbstractDiffEqLinearOperator{T}
A::Arr
α::Sca
_isreal::Bool
_issymmetric::Bool
_ishermitian::Bool
_isposdef::Bool
update_func::F
end

DEFAULT_UPDATE_FUNC = (L,u,p,t)->nothing

function DiffEqArrayOperator(A::Number,α=1.0,
update_func = DEFAULT_UPDATE_FUNC)
if (typeof(α) <: Number)
_α = DiffEqScalar(nothing,α)
elseif (typeof(α) <: DiffEqScalar) # Must be a DiffEqScalar already
_α = α
else # Assume it's some kind of function
# Wrapping the function call in one() should solve any cases
# where the function is not well-behaved at 0.0, as long as
# the return type is correct.
_α = DiffEqScalar(α,one(α(0.0)))
end
DiffEqArrayOperator{typeof(A),typeof(A),typeof(_α),
typeof(update_func)}(
A,_α,isreal(A),issymmetric(A),ishermitian(A),
isposdef(A),update_func)
end

function DiffEqArrayOperator(A::AbstractMatrix{T},α=1.0,
update_func = DEFAULT_UPDATE_FUNC) where T
if (typeof(α) <: Number)
_α = DiffEqScalar(nothing,α)
elseif (typeof(α) <: DiffEqScalar) # Must be a DiffEqScalar already
_α = α
else # Assume it's some kind of function
# Wrapping the function call in one() should solve any cases
# where the function is not well-behaved at 0.0, as long as
# the return type is correct.
_α = DiffEqScalar(α,one(α(0.0)))
end
DiffEqArrayOperator{T,typeof(A),typeof(_α),
typeof(update_func)}(
A,_α,isreal(A),issymmetric(A),ishermitian(A),
isposdef(A),update_func)
end

Base.isreal(L::DiffEqArrayOperator) = L._isreal
Base.issymmetric(L::DiffEqArrayOperator) = L._issymmetric
Base.ishermitian(L::DiffEqArrayOperator) = L._ishermitian
Base.isposdef(L::DiffEqArrayOperator) = L._isposdef
DiffEqBase.is_constant(L::DiffEqArrayOperator) = L.update_func == DEFAULT_UPDATE_FUNC
Base.full(L::DiffEqArrayOperator) = full(L.A) .* L.α.coeff
Base.exp(L::DiffEqArrayOperator) = exp(full(L))
DiffEqBase.has_exp(L::DiffEqArrayOperator) = true
Base.size(L::DiffEqArrayOperator) = size(L.A)
Base.size(L::DiffEqArrayOperator, m::Integer) = size(L.A, m)
LinearAlgebra.opnorm(L::DiffEqArrayOperator, p::Real=2) = opnorm(L.A, p) * abs(L.α.coeff)
DiffEqBase.update_coefficients!(L::DiffEqArrayOperator,u,p,t) = (L.update_func(L.A,u,p,t); L.α = L.α(t); nothing)
DiffEqBase.update_coefficients(L::DiffEqArrayOperator,u,p,t) = (L.update_func(L.A,u,p,t); L.α = L.α(t); L)

function (L::DiffEqArrayOperator)(u,p,t)
update_coefficients!(L,u,p,t)
L*u
end

function (L::DiffEqArrayOperator)(du,u,p,t)
update_coefficients!(L,u,p,t)
mul!(du,L,u)
end

### Forward some extra operations
function Base.:*(α::Number,L::DiffEqArrayOperator)
DiffEqArrayOperator(L.A,DiffEqScalar(L.α.func,L.α.coeff*α),L.update_func)
end

function Base.:*(α::Number,L::DiffEqArrayOperator{T,Arr,Sca,F}) where {T,Arr<:Number,Sca,F}
L.α.coeff*α*L.A
end

Base.:*(L::DiffEqArrayOperator,α::Number) = α*L
Base.:*(L::DiffEqArrayOperator,b::AbstractVector) = L.α.coeff*L.A*b
Base.:*(L::DiffEqArrayOperator,b::AbstractArray) = L.α.coeff*L.A*b

function LinearAlgebra.mul!(v::AbstractVector,L::DiffEqArrayOperator,b::AbstractVector)
mul!(v,L.A,b)
rmul!(v,L.α.coeff)
end

function LinearAlgebra.mul!(v::AbstractArray,L::DiffEqArrayOperator,b::AbstractArray)
mul!(v,L.A,b)
rmul!(v,L.α.coeff)
end

function Base.A_ldiv_B!(x,L::DiffEqArrayOperator, b::AbstractArray)
A_ldiv_B!(x,L.A,b)
rmul!(x,inv(L.α.coeff))
end

function Base.:/(x,L::DiffEqArrayOperator)
x/(L.α.coeff*L.A)
end

function Base.:/(L::DiffEqArrayOperator,x)
L.α.coeff*L.A/x
end

"""
FactorizedDiffEqArrayOperator{T,I}

A helper function for holding factorized version of the DiffEqArrayOperator
"""
struct FactorizedDiffEqArrayOperator{T,I}
A::T
inv_coeff::I
end
DiffEqArrayOperator(A[; update_func])

Base.factorize(L::DiffEqArrayOperator) = FactorizedDiffEqArrayOperator(factorize(L.A),inv(L.α.coeff))
Base.lufact(L::DiffEqArrayOperator,args...) = FactorizedDiffEqArrayOperator(lufact(L.A,args...),inv(L.α.coeff))
Base.lufact!(L::DiffEqArrayOperator,args...) = FactorizedDiffEqArrayOperator(lufact!(L.A,args...),inv(L.α.coeff))
Base.qrfact(L::DiffEqArrayOperator,args...) = FactorizedDiffEqArrayOperator(qrfact(L.A,args...),inv(L.α.coeff))
Base.qrfact!(L::DiffEqArrayOperator,args...) = FactorizedDiffEqArrayOperator(qrfact!(L.A,args...),inv(L.α.coeff))
Base.cholfact(L::DiffEqArrayOperator,args...) = FactorizedDiffEqArrayOperator(cholfact(L.A,args...),inv(L.α.coeff))
Base.cholfact!(L::DiffEqArrayOperator,args...) = FactorizedDiffEqArrayOperator(cholfact!(L.A,args...),inv(L.α.coeff))
Base.ldltfact(L::DiffEqArrayOperator,args...) = FactorizedDiffEqArrayOperator(ldltfact(L.A,args...),inv(L.α.coeff))
Base.ldltfact!(L::DiffEqArrayOperator,args...) = FactorizedDiffEqArrayOperator(ldltfact!(L.A,args...),inv(L.α.coeff))
Base.bkfact(L::DiffEqArrayOperator,args...) = FactorizedDiffEqArrayOperator(bkfact(L.A,args...),inv(L.α.coeff))
Base.bkfact!(L::DiffEqArrayOperator,args...) = FactorizedDiffEqArrayOperator(bkfact!(L.A,args...),inv(L.α.coeff))
Base.lqfact(L::DiffEqArrayOperator,args...) = FactorizedDiffEqArrayOperator(lqfact(L.A,args...),inv(L.α.coeff))
Base.lqfact!(L::DiffEqArrayOperator,args...) = FactorizedDiffEqArrayOperator(lqfact!(L.A,args...),inv(L.α.coeff))
Base.svdfact(L::DiffEqArrayOperator,args...) = FactorizedDiffEqArrayOperator(svdfact(L.A,args...),inv(L.α.coeff))
Base.svdfact!(L::DiffEqArrayOperator,args...) = FactorizedDiffEqArrayOperator(svdfact!(L.A,args...),inv(L.α.coeff))
Represents a time-dependent linear operator given by an AbstractMatrix. The
update function is called by `update_coefficients!` and is assumed to have
the following signature:

function Base.A_ldiv_B!(x,L::FactorizedDiffEqArrayOperator, b::AbstractArray)
A_ldiv_B!(x,L.A,b)
rmul!(x,inv(L.inv_coeff))
end

function Base.:\(L::FactorizedDiffEqArrayOperator, b::AbstractArray)
(L.A \ b) * L.inv_coeff
end
update_func(A::AbstractMatrix,u,p,t) -> [modifies A]

@inline Base.getindex(L::DiffEqArrayOperator,i::Int) = L.A[i]
@inline Base.getindex(L::DiffEqArrayOperator,I::Vararg{Int, N}) where {N} = L.A[I...]
@inline Base.setindex!(L::DiffEqArrayOperator, v, i::Int) = (L.A[i]=v)
@inline Base.setindex!(L::DiffEqArrayOperator, v, I::Vararg{Int, N}) where {N} = (L.A[I...]=v)
You can also use `setval!(α,A)` to bypass the `update_coefficients!` interface
and directly mutate the array's value.
"""
mutable struct DiffEqArrayOperator{T,AType<:AbstractMatrix{T},F} <: AbstractDiffEqLinearOperator{T}
A::AType
update_func::F
DiffEqArrayOperator(A::AType; update_func=DEFAULT_UPDATE_FUNC()) where {AType} =
new{eltype(A),AType,typeof(update_func)}(A, update_func)
end

update_coefficients!(L::DiffEqArrayOperator,u,p,t) = (L.update_func(L.A,u,p,t); L)
setval!(L::DiffEqArrayOperator, A) = (L.A = A; L)
is_constant(L::DiffEqArrayOperator) = L.update_func == DEFAULT_UPDATE_FUNC()
(L::DiffEqArrayOperator)(u,p,t) = (update_coefficients!(L,u,p,t); L.A * u)
(L::DiffEqArrayOperator)(du,u,p,t) = (update_coefficients!(L,u,p,t); mul!(du, L.A, u))

# Forward operations that use the underlying array
for pred in (:isreal, :issymmetric, :ishermitian, :isposdef)
@eval LinearAlgebra.$pred(L::DiffEqArrayOperator) = $pred(L.A)
end
size(L::DiffEqArrayOperator) = size(L.A)
size(L::DiffEqArrayOperator, m) = size(L.A, m)
opnorm(L::DiffEqArrayOperator, p::Real=2) = opnorm(L.A, p)
getindex(L::DiffEqArrayOperator, i::Int) = L.A[i]
getindex(L::DiffEqArrayOperator, I::Vararg{Int, N}) where {N} = L.A[I...]
setindex!(L::DiffEqArrayOperator, v, i::Int) = (L.A[i] = v)
setindex!(L::DiffEqArrayOperator, v, I::Vararg{Int, N}) where {N} = (L.A[I...] = v)
*(L::DiffEqArrayOperator, x) = L.A * x
*(x, L::DiffEqArrayOperator) = x * L.A
/(L::DiffEqArrayOperator, x) = L.A / x
/(x, L::DiffEqArrayOperator) = x / L.A
mul!(Y, L::DiffEqArrayOperator, B) = mul!(Y, L.A, B)
ldiv!(Y, L::DiffEqArrayOperator, B) = ldiv!(Y, L.A, B)

# Forward operations that use the full matrix
Matrix(L::DiffEqArrayOperator) = Matrix(L.A)
Base.exp(L::DiffEqArrayOperator) = exp(Matrix(L))

# Factorization
struct FactorizedDiffEqArrayOperator{T<:Number,FType<:Factorization{T}} <: AbstractDiffEqLinearOperator{T}
F::FType
end

factorize(L::DiffEqArrayOperator) = FactorizedDiffEqArrayOperator(factorize(L.A))
for fact in (:lu, :lu!, :qr, :qr!, :chol, :chol!, :ldlt, :ldlt!,
:bkfact, :bkfact!, :lq, :lq!, :svd, :svd!)
@eval LinearAlgebra.$fact(L::DiffEqArrayOperator, args...) = FactorizedDiffEqArrayOperator($fact(L.A, args...))
end

ldiv!(Y, L::FactorizedDiffEqArrayOperator, B) = ldiv!(Y, L.F, B)
\(L::FactorizedDiffEqArrayOperator, x) = L.F \ x
42 changes: 23 additions & 19 deletions src/diffeqscalar.jl
Original file line number Diff line number Diff line change
@@ -1,28 +1,32 @@
"""
DiffEqScalar Interface
DiffEqScalar(val[; update_func])

DiffEqScalar(func,coeff=1.0)
Represents a time-dependent scalar/scaling operator. The update function
is called by `update_coefficients!` and is assumed to have the following
signature:

This is a function with a coefficient.
update_func(oldval,u,p,t) -> newval

α(t) returns a new DiffEqScalar with an updated coefficient.
You can also use `setval!(α,val)` to bypass the `update_coefficients!`
interface and directly mutate the scalar's value.
"""
struct DiffEqScalar{F,T}
func::F
coeff::T
DiffEqScalar{T}(func) where T = new{typeof(func),T}(func,one(T))
DiffEqScalar{F,T}(func,coeff) where {F,T} = new{F,T}(func,coeff)
mutable struct DiffEqScalar{T<:Number,F} <: AbstractDiffEqLinearOperator{T}
val::T
update_func::F
DiffEqScalar(val::T; update_func=DEFAULT_UPDATE_FUNC()) where {T} =
new{T,typeof(update_func)}(val, update_func)
end

DiffEqScalar(func,coeff=1.0) = DiffEqScalar{typeof(func),typeof(coeff)}(func,coeff)
update_coefficients!(α::DiffEqScalar,u,p,t) = (α.val = α.update_func(α.val,u,p,t); α)
setval!(α::DiffEqScalar, val) = (α.val = val; α)
is_constant(α::DiffEqScalar) = α.update_func == DEFAULT_UPDATE_FUNC()

function (α::DiffEqScalar)(t)
if α.func == nothing
return DiffEqScalar(α.func,α.coeff)
else
return DiffEqScalar(α.func,α.func(t))
end
end
*(α::DiffEqScalar, x) = α.val * x
*(x, α::DiffEqScalar) = x * α.val
lmul!(α::DiffEqScalar, B) = lmul!(α.val, B)
rmul!(B, α::DiffEqScalar) = rmul!(B, α.val)
mul!(Y, α::DiffEqScalar, B) = mul!(Y, α.val, B)
axpy!(α::DiffEqScalar, X, Y) = axpy!(α.val, X, Y)

Base.:*(α::Number,B::DiffEqScalar) = DiffEqScalar(B.func,B.coeff*α)
Base.:*(B::DiffEqScalar,α::Number) = DiffEqScalar(B.func,B.coeff*α)
(α::DiffEqScalar)(u,p,t) = (update_coefficients!(α,u,p,t); α.val * u)
::DiffEqScalar)(du,u,p,t) = (update_coefficients!(α,u,p,t); @. du = α.val * u)
18 changes: 8 additions & 10 deletions test/array_operators_interface.jl
Original file line number Diff line number Diff line change
@@ -1,20 +1,18 @@
using DiffEqOperators
using DiffEqOperators, Random, LinearAlgebra
using Test

N = 5
srand(0); A = rand(N,N); u = rand(N)
L = DiffEqArrayOperator(A)
a = 3.5
La = L * a

@test La * u ≈ (a*A) * u
@test lufact(La) \ u ≈ (a*A) \ u
@test opnorm(La) ≈ opnorm(a*A)
@test exp(La) ≈ exp(a*A)
@test La[2,3] ≈ A[2,3] # should this be La[2,3] == a*A[2,3]?
@test L * u ≈ A * u
@test lu(L) \ u ≈ A \ u
@test opnorm(L) ≈ opnorm(A)
@test exp(L) ≈ exp(A)
@test L[2,3] == A[2,3]

update_func = (_A,u,p,t) -> _A .= t * A
t = 3.0
Atmp = zeros(N,N)
Lt = DiffEqArrayOperator(Atmp, a, update_func)
@test Lt(u,nothing,t) ≈ (a*t*A) * u
Lt = DiffEqArrayOperator(Atmp; update_func=update_func)
@test Lt(u,nothing,t) ≈ (t*A) * u