diff --git a/src/LinearOperators.jl b/src/LinearOperators.jl index e72d20f6..f5025060 100644 --- a/src/LinearOperators.jl +++ b/src/LinearOperators.jl @@ -9,6 +9,7 @@ export AbstractLinearOperator, LinearOperator, opInverse, opCholesky, opLDL, opHouseholder, opHermitian, check_ctranspose, check_hermitian, check_positive_definite, shape, hermitian, ishermitian, symmetric, issymmetric, + nprod, ntprod, nctprod, opRestriction, opExtension @@ -53,6 +54,32 @@ mutable struct LinearOperator{T} <: AbstractLinearOperator{T} prod # apply the operator to a vector tprod # apply the transpose operator to a vector ctprod # apply the transpose conjugate operator to a vector + nprod :: Int + ntprod :: Int + nctprod :: Int +end + +LinearOperator{T}(nrow::Int, ncol::Int, symmetric::Bool, hermitian::Bool, prod, tprod, ctprod) where T = + LinearOperator{T}(nrow, ncol, symmetric, hermitian, prod, tprod, ctprod, 0, 0, 0) + +nprod(op::AbstractLinearOperator) = op.nprod +ntprod(op::AbstractLinearOperator) = op.ntprod +nctprod(op::AbstractLinearOperator) = op.nctprod + +increase_nprod(op::AbstractLinearOperator) = (op.nprod += 1) +increase_ntprod(op::AbstractLinearOperator) = (op.ntprod += 1) +increase_nctprod(op::AbstractLinearOperator) = (op.nctprod += 1) + +""" + reset!(op) + +Reset the product counters of a linear operator. +""" +function reset!(op::AbstractLinearOperator) + op.nprod = 0 + op.ntprod = 0 + op.nctprod = 0 + return op end """ @@ -115,9 +142,9 @@ function show(io :: IO, op :: AbstractLinearOperator) s *= @sprintf(" eltype: %s\n", eltype(op)) s *= @sprintf(" symmetric: %s\n", op.symmetric) s *= @sprintf(" hermitian: %s\n", op.hermitian) - #s *= @sprintf(" prod: %s\n", string(op.prod)) - #s *= @sprintf(" tprod: %s\n", string(op.tprod)) - #s *= @sprintf(" ctprod: %s", string(op.ctprod)) + s *= @sprintf(" nprod: %d\n", nprod(op)) + s *= @sprintf(" ntprod: %d\n", ntprod(op)) + s *= @sprintf(" nctprod: %d\n", nctprod(op)) s *= "\n" print(io, s) end @@ -216,6 +243,7 @@ end # Apply an operator to a vector. function *(op :: AbstractLinearOperator, v :: AbstractVector) size(v, 1) == size(op, 2) || throw(LinearOperatorException("shape mismatch")) + increase_nprod(op) op.prod(v) end diff --git a/src/PreallocatedLinearOperators.jl b/src/PreallocatedLinearOperators.jl index 01e8edae..84778c8a 100644 --- a/src/PreallocatedLinearOperators.jl +++ b/src/PreallocatedLinearOperators.jl @@ -25,8 +25,14 @@ mutable struct PreallocatedLinearOperator{T} <: AbstractPreallocatedLinearOperat prod # apply the operator to a vector tprod # apply the transpose operator to a vector ctprod # apply the transpose conjugate operator to a vector + nprod :: Int + ntprod :: Int + nctprod :: Int end +PreallocatedLinearOperator{T}(nrow::Int, ncol::Int, symmetric::Bool, hermitian::Bool, prod, tprod, ctprod) where T = + PreallocatedLinearOperator{T}(nrow, ncol, symmetric, hermitian, prod, tprod, ctprod, 0, 0, 0) + """ show(io, op) @@ -39,6 +45,9 @@ function show(io :: IO, op :: AbstractPreallocatedLinearOperator) s *= @sprintf(" eltype: %s\n", eltype(op)) s *= @sprintf(" symmetric: %s\n", op.symmetric) s *= @sprintf(" hermitian: %s\n", op.hermitian) + s *= @sprintf(" nprod: %d\n", nprod(op)) + s *= @sprintf(" ntprod: %d\n", ntprod(op)) + s *= @sprintf(" nctprod: %d\n", nctprod(op)) s *= "\n" print(io, s) end diff --git a/src/adjtrans.jl b/src/adjtrans.jl index b8b0efd7..758a166b 100644 --- a/src/adjtrans.jl +++ b/src/adjtrans.jl @@ -28,6 +28,20 @@ conj(A :: TransposeLinearOperator) = adjoint(A.parent) transpose(A :: AdjointLinearOperator) = conj(A.parent) transpose(A :: ConjugateLinearOperator) = adjoint(A.parent) +nprod(A::AdjointLinearOperator) = nctprod(A.parent) +ntprod(A::AdjointLinearOperator) = nprod(A.parent) # transpose(A') = conj(A) +nctprod(A::AdjointLinearOperator) = nprod(A.parent) # (A')' == A + +nprod(A::TransposeLinearOperator) = ntprod(A.parent) +ntprod(A::TransposeLinearOperator) = nprod(A.parent) +nctprod(A::TransposeLinearOperator) = nprod(A.parent) # (transpose(A))' = conj(A) + +for f in [:nprod, :ntprod, :nctprod] + @eval begin + $f(A::ConjugateLinearOperator) = $f(A.parent) + end +end + const AdjTrans = Union{AdjointLinearOperator,TransposeLinearOperator} size(A :: AdjTrans) = size(A.parent)[[2;1]] @@ -61,15 +75,25 @@ function *(op :: AdjointLinearOperator, v :: AbstractVector) length(v) == size(op.parent, 1) || throw(LinearOperatorException("shape mismatch")) p = op.parent ishermitian(p) && return p * v - p.ctprod !== nothing && return p.ctprod(v) + if p.ctprod !== nothing + increase_nctprod(p) + return p.ctprod(v) + end tprod = p.tprod + increment_tprod = true if p.tprod === nothing if issymmetric(p) + increment_tprod = false tprod = p.prod else throw(LinearOperatorException("unable to infer conjugate transpose operator")) end end + if increment_tprod + increase_ntprod(p) + else + increase_nprod(p) + end return conj.(tprod(conj.(v))) end @@ -77,15 +101,25 @@ function *(op :: TransposeLinearOperator, v :: AbstractVector) length(v) == size(op.parent, 1) || throw(LinearOperatorException("shape mismatch")) p = op.parent issymmetric(p) && return p * v - p.tprod !== nothing && return p.tprod(v) + if p.tprod !== nothing + increase_ntprod(p) + return p.tprod(v) + end + increment_ctprod = true ctprod = p.ctprod if p.ctprod === nothing if ishermitian(p) + increment_ctprod = false ctprod = p.prod else throw(LinearOperatorException("unable to infer transpose operator")) end end + if increment_ctprod + increase_nctprod(p) + else + increase_nprod(p) + end return conj.(ctprod(conj.(v))) end diff --git a/src/lbfgs.jl b/src/lbfgs.jl index 76abcd87..9d9d85d7 100644 --- a/src/lbfgs.jl +++ b/src/lbfgs.jl @@ -48,8 +48,13 @@ mutable struct LBFGSOperator{T} <: AbstractLinearOperator{T} ctprod # apply the transpose conjugate operator to a vector inverse :: Bool data :: LBFGSData{T} + nprod :: Int + ntprod :: Int + nctprod :: Int end +LBFGSOperator{T}(nrow::Int, ncol::Int, symmetric::Bool, hermitian::Bool, prod, tprod, ctprod, inverse::Bool, data::LBFGSData{T}) where T = + LBFGSOperator{T}(nrow, ncol, symmetric, hermitian, prod, tprod, ctprod, inverse, data, 0, 0, 0) """ InverseLBFGSOperator(T, n, [mem=5; scaling=true]) @@ -259,5 +264,8 @@ Resets the LBFGS data of the given operator. """ function reset!(op :: LBFGSOperator) reset!(op.data) + op.nprod = 0 + op.ntprod = 0 + op.nctprod = 0 return op end diff --git a/src/lsr1.jl b/src/lsr1.jl index f6185b01..df84dd56 100644 --- a/src/lsr1.jl +++ b/src/lsr1.jl @@ -38,8 +38,14 @@ mutable struct LSR1Operator{T} <: AbstractLinearOperator{T} ctprod # apply the transpose conjugate operator to a vector inverse :: Bool data :: LSR1Data{T} + nprod :: Int + ntprod :: Int + nctprod :: Int end +LSR1Operator{T}(nrow::Int, ncol::Int, symmetric::Bool, hermitian::Bool, prod, tprod, ctprod, inverse::Bool, data::LSR1Data{T}) where T = + LSR1Operator{T}(nrow, ncol, symmetric, hermitian, prod, tprod, ctprod, inverse, data, 0, 0, 0) + """ LSR1Operator(T, n, [mem=5; scaling=false) LSR1Operator(n, [mem=5; scaling=false) @@ -183,5 +189,8 @@ Resets the LSR1 data of the given operator. """ function reset!(op :: LSR1Operator) reset!(op.data) + op.nprod = 0 + op.ntprod = 0 + op.nctprod = 0 return op end diff --git a/test/test_linop.jl b/test/test_linop.jl index 8fdd61b5..8e881f7d 100644 --- a/test/test_linop.jl +++ b/test/test_linop.jl @@ -551,6 +551,47 @@ function test_linop() @test Matrix([adjtrans(opA) opA; opA adjtrans(opA)]) == Matrix([adjtrans(A) A; A adjtrans(A)]) end end + + @testset ExtendedTestSet "Counters" begin + op = LinearOperator(rand(3,4) + im * rand(3,4)) + @test nprod(op) == 0 + @test ntprod(op) == 0 + @test nctprod(op) == 0 + nprods = 5 + ntprods = 4 + nctprods = 7 + for _ = 1 : nprods + op * rand(4) + end + for _ = 1 : ntprods + transpose(op) * rand(3) + end + for _ = 1 : nctprods + op' * rand(3) + end + @test nprod(op) == nprods + @test ntprod(op) == ntprods + @test nctprod(op) == nctprods + for _ = 1 : nprods + conj(op) * rand(4) + end + @test nprod(op) == 2 * nprods + + opᵀ = transpose(op) + @test nprod(opᵀ) == ntprod(op) + @test ntprod(opᵀ) == nprod(op) + @test nctprod(opᵀ) == nprod(op) + + opᴴ = op' + @test nprod(opᴴ) == nctprod(op) + @test ntprod(opᴴ) == nprod(op) + @test nctprod(opᴴ) == nprod(op) + + reset!(op) + @test nprod(op) == 0 + @test ntprod(op) == 0 + @test nctprod(op) == 0 + end end test_linop()