diff --git a/src/jacvec_operators.jl b/src/jacvec_operators.jl index e9b10f51c..c85722b44 100644 --- a/src/jacvec_operators.jl +++ b/src/jacvec_operators.jl @@ -56,7 +56,7 @@ end ### Operator Implementation -mutable struct JacVecOperator{T,F,T1,T2,uType,P,tType} <: DiffEqBase.AbstractDiffEqLinearOperator{T} +mutable struct JacVecOperator{T,F,T1,T2,uType,P,tType,O} <: DiffEqBase.AbstractDiffEqLinearOperator{T} f::F cache1::T1 cache2::T2 @@ -64,13 +64,15 @@ mutable struct JacVecOperator{T,F,T1,T2,uType,P,tType} <: DiffEqBase.AbstractDif p::P t::tType autodiff::Bool + ishermitian::Bool + opnorm::O - function JacVecOperator{T}(f,p=nothing,t::Union{Nothing,Number}=nothing;autodiff=true) where T + function JacVecOperator{T}(f,p=nothing,t::Union{Nothing,Number}=nothing;autodiff=true,ishermitian=false,opnorm=true) where T p===nothing ? P = Any : P = typeof(p) t===nothing ? tType = Any : tType = typeof(t) - new{T,typeof(f),Nothing,Nothing,Any,P,tType}(f,nothing,nothing,nothing,nothing,nothing,autodiff) + new{T,typeof(f),Nothing,Nothing,Any,P,tType,typeof(opnorm)}(f,nothing,nothing,nothing,nothing,nothing,autodiff,ishermitian) end - function JacVecOperator{T}(f,u::AbstractArray,p=nothing,t::Union{Nothing,Number}=nothing;autodiff=true) where T + function JacVecOperator{T}(f,u::AbstractArray,p=nothing,t::Union{Nothing,Number}=nothing;autodiff=true,ishermitian=false,opnorm=true) where T if autodiff cache1 = ForwardDiff.Dual{JacVecTag}.(u, u) cache2 = ForwardDiff.Dual{JacVecTag}.(u, u) @@ -80,14 +82,16 @@ mutable struct JacVecOperator{T,F,T1,T2,uType,P,tType} <: DiffEqBase.AbstractDif end p===nothing ? P = Any : P = typeof(p) t===nothing ? tType = Any : tType = typeof(t) - new{T,typeof(f),typeof(cache1),typeof(cache2),typeof(u),P,tType}(f,cache1,cache2,u,p,t,autodiff) + new{T,typeof(f),typeof(cache1),typeof(cache2),typeof(u),P,tType,typeof(opnorm)}(f,cache1,cache2,u,p,t,autodiff,ishermitian,opnorm) end function JacVecOperator(f,u,args...;kwargs...) JacVecOperator{eltype(u)}(f,u,args...;kwargs...) end - end +LinearAlgebra.opnorm(L::JacVecOperator, p::Real=2) = L.opnorm +LinearAlgebra.ishermitian(L::JacVecOperator) = L.ishermitian + Base.size(L::JacVecOperator) = (length(L.cache1),length(L.cache1)) Base.size(L::JacVecOperator,i::Int) = length(L.cache1) function update_coefficients!(L::JacVecOperator,u,p,t) diff --git a/src/matrixfree_operators.jl b/src/matrixfree_operators.jl index d7dc0f514..72ac5cbec 100644 --- a/src/matrixfree_operators.jl +++ b/src/matrixfree_operators.jl @@ -6,7 +6,7 @@ mutable struct MatrixFreeOperator{F,N,S,O} <: AbstractMatrixFreeOperator{F} opnorm::O ishermitian::Bool function MatrixFreeOperator(f::F, args::N; - size=nothing, opnorm=nothing, ishermitian=false) where {F,N} + size=nothing, opnorm=true, ishermitian=false) where {F,N} @assert (N <: Tuple && length(args) in (1,2)) "Arguments of a "* "MatrixFreeOperator must be a tuple with one or two elements" return new{F,N,typeof(size),typeof(opnorm)}(f, args, size, opnorm, ishermitian) diff --git a/test/jacvec_operators.jl b/test/jacvec_operators.jl index 087bdea6d..801e527f4 100644 --- a/test/jacvec_operators.jl +++ b/test/jacvec_operators.jl @@ -52,12 +52,13 @@ function lorenz(du,u,p,t) end u0 = [1.0;0.0;0.0] tspan = (0.0,100.0) -ff = ODEFunction(lorenz,jac_prototype=JacVecOperator{Float64}(lorenz,u0)) -prob = ODEProblem(ff,u0,tspan) -@test_broken sol = solve(prob,Rosenbrock23()) -@test_broken sol = solve(prob,Rosenbrock23(linsolve=LinSolveGMRES(tol=1e-10))) - -ff = ODEFunction(lorenz,jac_prototype=JacVecOperator{Float64}(lorenz,u0,autodiff=false)) -prob = ODEProblem(ff,u0,tspan) -@test_broken sol = solve(prob,Rosenbrock23()) -@test_broken sol = solve(prob,Rosenbrock23(linsolve=LinSolveGMRES(tol=1e-10))) +ff1 = ODEFunction(lorenz,jac_prototype=JacVecOperator{Float64}(lorenz,u0)) +ff2 = ODEFunction(lorenz,jac_prototype=JacVecOperator{Float64}(lorenz,u0,autodiff=false)) +for ff in [ff1, ff2] + prob = ODEProblem(ff,u0,tspan) + @test solve(prob,TRBDF2()).retcode == :Success + @test solve(prob,TRBDF2(linsolve=LinSolveGMRES(tol=1e-10))).retcode == :Success + @test solve(prob,Exprb32()).retcode == :Success + @test_broken sol = solve(prob,Rosenbrock23()) + @test_broken sol = solve(prob,Rosenbrock23(linsolve=LinSolveGMRES(tol=1e-10))) +end