-
-
Notifications
You must be signed in to change notification settings - Fork 195
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Lazy W operator support #443
Changes from 15 commits
71147ee
42ce8c5
a4addeb
197fc40
1a1f5ad
2f17963
df84f70
c57604b
9f40d55
66c62aa
2c5850e
f7e425e
4bd3617
bb185eb
d175f9b
7b7a1cd
08f460d
54b9029
fcc937f
9fe4202
e6f639c
5cedc29
bf600ea
c48a0de
7983293
21b1a5b
be258f8
66a38fb
f01872d
e3c1505
3f8cefe
c4437ea
1eba5e9
b8ae169
e0eea48
72cde28
80d072b
87e174d
90f377b
e2beccb
b520fe0
c1e53da
284fa70
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -54,10 +54,129 @@ function calc_J!(integrator, cache::OrdinaryDiffEqMutableCache, is_compos) | |
is_compos && (integrator.eigen_est = opnorm(J, Inf)) | ||
end | ||
|
||
""" | ||
WOperator(mass_matrix,gamma,J[;cache=nothing,transform=false]) | ||
|
||
A linear operator that represents the W matrix of an ODEProblem, defined as | ||
|
||
```math | ||
W = MM - \\gamma J | ||
``` | ||
|
||
or, if `transform=true`: | ||
|
||
```math | ||
W = \\frac{1}{\\gamma}MM - J | ||
``` | ||
|
||
where `MM` is the mass matrix (a regular `AbstractMatrix` or a `UniformScaling`), | ||
`γ` is a real number proportional to the time step, and `J` is the Jacobian | ||
operator (must be a `AbstractDiffEqLinearOperator`). | ||
|
||
`WOperator` supports lazy `*` and `mul!` operations, the latter utilizing an | ||
internal cache (can be specified in the constructor; default to regular `Vector`). | ||
It supports all of `AbstractDiffEqLinearOperator`'s interface. | ||
""" | ||
mutable struct WOperator{T, | ||
MType <: Union{UniformScaling,AbstractMatrix}, | ||
GType <: Real, | ||
JType <: DiffEqBase.AbstractDiffEqLinearOperator{T}, | ||
CType <: AbstractVector | ||
} <: DiffEqBase.AbstractDiffEqLinearOperator{T} | ||
mass_matrix::MType | ||
gamma::GType | ||
J::JType | ||
cache::CType | ||
transform::Bool | ||
function WOperator(mass_matrix, gamma, J; cache=nothing, transform=false) | ||
T = eltype(J) | ||
# Convert mass_matrix, if needed | ||
if !isa(mass_matrix, Union{AbstractMatrix,UniformScaling}) | ||
mass_matrix = convert(AbstractMatrix, mass_matrix) | ||
end | ||
# Construct the cache, default to regular vector | ||
if cache == nothing | ||
cache = Vector{T}(undef, size(J, 1)) | ||
end | ||
new{T,typeof(mass_matrix),typeof(gamma),typeof(J),typeof(cache)}(mass_matrix,gamma,J,cache,transform) | ||
end | ||
end | ||
set_gamma!(W::WOperator, gamma) = (W.gamma = gamma; W) | ||
DiffEqBase.update_coefficients!(W::WOperator,u,p,t) = (update_coefficients!(W.J,u,p,t); W) | ||
function Base.convert(::Type{AbstractMatrix}, W::WOperator) | ||
if W.transform | ||
W.mass_matrix / W.gamma - convert(AbstractMatrix,W.J) | ||
else | ||
W.mass_matrix - W.gamma * convert(AbstractMatrix,W.J) | ||
end | ||
end | ||
function Base.convert(::Type{Number}, W::WOperator) | ||
if W.transform | ||
W.mass_matrix / W.gamma - convert(Number,W.J) | ||
else | ||
W.mass_matrix - W.gamma * convert(Number,W.J) | ||
end | ||
end | ||
Base.size(W::WOperator, args...) = size(W.J, args...) | ||
function Base.getindex(W::WOperator, i::Int) | ||
if W.transform | ||
W.mass_matrix[i] / W.gamma - W.J[i] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since this shows up everywhere and it's constant given by the algorithm, it might make sense to make this a type parameter so the code can simply dispatch off of it, or at least not have large performance cuts because of this conditioning in indexing. I'm not sure this indexing is really used all that much though. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
This is true both for factorization (it indexes the concrete matrix instead of W) and There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sounds good |
||
else | ||
W.mass_matrix[i] - W.gamma * W.J[i] | ||
end | ||
end | ||
function Base.getindex(W::WOperator, I::Vararg{Int,N}) where {N} | ||
if W.transform | ||
W.mass_matrix[I...] / W.gamma - W.J[I...] | ||
else | ||
W.mass_matrix[I...] - W.gamma * W.J[I...] | ||
end | ||
end | ||
function Base.:*(W::WOperator, x::Union{AbstractVecOrMat,Number}) | ||
if W.transform | ||
(W.mass_matrix*x) / W.gamma - W.J*x | ||
else | ||
W.mass_matrix*x - W.gamma * (W.J*x) | ||
end | ||
end | ||
function Base.:\(W::WOperator, x::Union{AbstractVecOrMat,Number}) | ||
if size(W) == () # scalar operator | ||
convert(Number,W) \ x | ||
else | ||
convert(AbstractMatrix,W) \ x | ||
end | ||
end | ||
function LinearAlgebra.mul!(Y::AbstractVecOrMat, W::WOperator, B::AbstractVecOrMat) | ||
if W.transform | ||
# Compute mass_matrix * B | ||
if isa(W.mass_matrix, UniformScaling) | ||
a = W.mass_matrix.λ / W.gamma | ||
@. Y = a * B | ||
else | ||
mul!(Y, W.mass_matrix, B) | ||
lmul!(1/W.gamma, Y) | ||
end | ||
# Compute J * B and subtract | ||
mul!(W.cache, W.J, B) | ||
Y .-= W.cache | ||
else | ||
# Compute mass_matrix * B | ||
if isa(W.mass_matrix, UniformScaling) | ||
@. Y = W.mass_matrix.λ * B | ||
else | ||
mul!(Y, W.mass_matrix, B) | ||
end | ||
# Compute J * B | ||
mul!(W.cache, W.J, B) | ||
# Subtract result | ||
axpy!(-W.gamma, W.cache, Y) | ||
end | ||
end | ||
|
||
function calc_W!(integrator, cache::OrdinaryDiffEqMutableCache, dtgamma, repeat_step, W_transform=false) | ||
@inbounds begin | ||
@unpack t,dt,uprev,u,f,p = integrator | ||
@unpack J,W,jac_config = cache | ||
@unpack J,W = cache | ||
mass_matrix = integrator.sol.prob.mass_matrix | ||
is_compos = typeof(integrator.alg) <: CompositeAlgorithm | ||
alg = unwrap_alg(integrator, true) | ||
|
@@ -84,13 +203,18 @@ function calc_W!(integrator, cache::OrdinaryDiffEqMutableCache, dtgamma, repeat_ | |
if !repeat_step && (!alg_can_repeat_jac(alg) || | ||
(integrator.iter < 1 || new_jac || | ||
abs(dt - (t-integrator.tprev)) > 100eps(typeof(integrator.t)))) | ||
if W_transform | ||
for j in 1:length(u), i in 1:length(u) | ||
W[i,j] = mass_matrix[i,j]/dtgamma - J[i,j] | ||
end | ||
else | ||
for j in 1:length(u), i in 1:length(u) | ||
W[i,j] = mass_matrix[i,j] - dtgamma*J[i,j] | ||
if DiffEqBase.has_jac(f) && isa(f.jac_prototype, DiffEqBase.AbstractDiffEqLinearOperator) | ||
set_gamma!(W, dtgamma) | ||
# W.transform = W_transform # necessary? | ||
else # compute W as a dense matrix | ||
if W_transform | ||
for j in 1:length(u), i in 1:length(u) | ||
W[i,j] = mass_matrix[i,j]/dtgamma - J[i,j] | ||
end | ||
else | ||
for j in 1:length(u), i in 1:length(u) | ||
W[i,j] = mass_matrix[i,j] - dtgamma*J[i,j] | ||
end | ||
end | ||
end | ||
else | ||
|
@@ -104,25 +228,32 @@ end | |
function calc_W!(integrator, cache::OrdinaryDiffEqConstantCache, dtgamma, repeat_step, W_transform=false) | ||
@unpack t,uprev,f = integrator | ||
@unpack uf = cache | ||
mass_matrix = integrator.sol.prob.mass_matrix | ||
# calculate W | ||
uf.t = t | ||
isarray = typeof(uprev) <: AbstractArray | ||
iscompo = typeof(integrator.alg) <: CompositeAlgorithm | ||
if !W_transform | ||
if isarray | ||
J = ForwardDiff.jacobian(uf,uprev) | ||
W = I - dtgamma*J | ||
if DiffEqBase.has_jac(f) && isa(f.jac_prototype, DiffEqBase.AbstractDiffEqLinearOperator) | ||
W = WOperator(mass_matrix, dtgamma, deepcopy(f.jac_prototype); transform=false) | ||
else | ||
J = ForwardDiff.derivative(uf,uprev) | ||
W = 1 - dtgamma*J | ||
if isarray | ||
J = ForwardDiff.jacobian(uf,uprev) | ||
else | ||
J = ForwardDiff.derivative(uf,uprev) | ||
end | ||
W = mass_matrix - dtgamma*J | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. out of place isn't using WOperator yet? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's used in the branch where There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why not all of the time? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess it's not necessary all of the time, and if we don't allow a linsolve option on it then it's pointless for now. |
||
end | ||
else | ||
if isarray | ||
J = ForwardDiff.jacobian(uf,uprev) | ||
W = I*inv(dtgamma) - J | ||
if DiffEqBase.has_jac(f) && isa(f.jac_prototype, DiffEqBase.AbstractDiffEqLinearOperator) | ||
W = WOperator(mass_matrix, dtgamma, deepcopy(f.jac_prototype); transform=true) | ||
else | ||
J = ForwardDiff.derivative(uf,uprev) | ||
W = inv(dtgamma) - J | ||
if isarray | ||
J = ForwardDiff.jacobian(uf,uprev) | ||
else | ||
J = ForwardDiff.derivative(uf,uprev) | ||
end | ||
W = mass_matrix*inv(dtgamma) - J | ||
end | ||
end | ||
iscompo && (integrator.eigen_est = isarray ? opnorm(J, Inf) : J) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
diffeqoperators are <: AbstractMatrix?