Skip to content
This repository was archived by the owner on Jul 19, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions src/jacvec_operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,21 +56,23 @@ end

### Operator Implementation

mutable struct JacVecOperator{T,F,T1,T2,uType,P,tType} <: DiffEqBase.AbstractDiffEqLinearOperator{T}
mutable struct JacVecOperator{T,F,T1,T2,uType,P,tType,O} <: DiffEqBase.AbstractDiffEqLinearOperator{T}
f::F
cache1::T1
cache2::T2
u::uType
p::P
t::tType
autodiff::Bool
ishermitian::Bool
opnorm::O

function JacVecOperator{T}(f,p=nothing,t::Union{Nothing,Number}=nothing;autodiff=true) where T
function JacVecOperator{T}(f,p=nothing,t::Union{Nothing,Number}=nothing;autodiff=true,ishermitian=false,opnorm=true) where T
p===nothing ? P = Any : P = typeof(p)
t===nothing ? tType = Any : tType = typeof(t)
new{T,typeof(f),Nothing,Nothing,Any,P,tType}(f,nothing,nothing,nothing,nothing,nothing,autodiff)
new{T,typeof(f),Nothing,Nothing,Any,P,tType,typeof(opnorm)}(f,nothing,nothing,nothing,nothing,nothing,autodiff,ishermitian)
end
function JacVecOperator{T}(f,u::AbstractArray,p=nothing,t::Union{Nothing,Number}=nothing;autodiff=true) where T
function JacVecOperator{T}(f,u::AbstractArray,p=nothing,t::Union{Nothing,Number}=nothing;autodiff=true,ishermitian=false,opnorm=true) where T
if autodiff
cache1 = ForwardDiff.Dual{JacVecTag}.(u, u)
cache2 = ForwardDiff.Dual{JacVecTag}.(u, u)
Expand All @@ -80,14 +82,16 @@ mutable struct JacVecOperator{T,F,T1,T2,uType,P,tType} <: DiffEqBase.AbstractDif
end
p===nothing ? P = Any : P = typeof(p)
t===nothing ? tType = Any : tType = typeof(t)
new{T,typeof(f),typeof(cache1),typeof(cache2),typeof(u),P,tType}(f,cache1,cache2,u,p,t,autodiff)
new{T,typeof(f),typeof(cache1),typeof(cache2),typeof(u),P,tType,typeof(opnorm)}(f,cache1,cache2,u,p,t,autodiff,ishermitian,opnorm)
end
function JacVecOperator(f,u,args...;kwargs...)
JacVecOperator{eltype(u)}(f,u,args...;kwargs...)
end

end

LinearAlgebra.opnorm(L::JacVecOperator, p::Real=2) = L.opnorm
LinearAlgebra.ishermitian(L::JacVecOperator) = L.ishermitian

Base.size(L::JacVecOperator) = (length(L.cache1),length(L.cache1))
Base.size(L::JacVecOperator,i::Int) = length(L.cache1)
function update_coefficients!(L::JacVecOperator,u,p,t)
Expand Down
2 changes: 1 addition & 1 deletion src/matrixfree_operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ mutable struct MatrixFreeOperator{F,N,S,O} <: AbstractMatrixFreeOperator{F}
opnorm::O
ishermitian::Bool
function MatrixFreeOperator(f::F, args::N;
size=nothing, opnorm=nothing, ishermitian=false) where {F,N}
size=nothing, opnorm=true, ishermitian=false) where {F,N}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is this true/false?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

opnorm is defaulted to be one, as opnorm is only used to test relative tolerance in expv.

@assert (N <: Tuple && length(args) in (1,2)) "Arguments of a "*
"MatrixFreeOperator must be a tuple with one or two elements"
return new{F,N,typeof(size),typeof(opnorm)}(f, args, size, opnorm, ishermitian)
Expand Down
19 changes: 10 additions & 9 deletions test/jacvec_operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,13 @@ function lorenz(du,u,p,t)
end
u0 = [1.0;0.0;0.0]
tspan = (0.0,100.0)
ff = ODEFunction(lorenz,jac_prototype=JacVecOperator{Float64}(lorenz,u0))
prob = ODEProblem(ff,u0,tspan)
@test_broken sol = solve(prob,Rosenbrock23())
@test_broken sol = solve(prob,Rosenbrock23(linsolve=LinSolveGMRES(tol=1e-10)))

ff = ODEFunction(lorenz,jac_prototype=JacVecOperator{Float64}(lorenz,u0,autodiff=false))
prob = ODEProblem(ff,u0,tspan)
@test_broken sol = solve(prob,Rosenbrock23())
@test_broken sol = solve(prob,Rosenbrock23(linsolve=LinSolveGMRES(tol=1e-10)))
ff1 = ODEFunction(lorenz,jac_prototype=JacVecOperator{Float64}(lorenz,u0))
ff2 = ODEFunction(lorenz,jac_prototype=JacVecOperator{Float64}(lorenz,u0,autodiff=false))
for ff in [ff1, ff2]
prob = ODEProblem(ff,u0,tspan)
@test solve(prob,TRBDF2()).retcode == :Success
@test solve(prob,TRBDF2(linsolve=LinSolveGMRES(tol=1e-10))).retcode == :Success
@test solve(prob,Exprb32()).retcode == :Success
@test_broken sol = solve(prob,Rosenbrock23())
@test_broken sol = solve(prob,Rosenbrock23(linsolve=LinSolveGMRES(tol=1e-10)))
end