Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 31 additions & 3 deletions src/LinearOperators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ export AbstractLinearOperator, LinearOperator,
opInverse, opCholesky, opLDL, opHouseholder, opHermitian,
check_ctranspose, check_hermitian, check_positive_definite,
shape, hermitian, ishermitian, symmetric, issymmetric,
nprod, ntprod, nctprod,
opRestriction, opExtension


Expand Down Expand Up @@ -53,6 +54,32 @@ mutable struct LinearOperator{T} <: AbstractLinearOperator{T}
prod # apply the operator to a vector
tprod # apply the transpose operator to a vector
ctprod # apply the transpose conjugate operator to a vector
nprod :: Int
ntprod :: Int
nctprod :: Int
end

LinearOperator{T}(nrow::Int, ncol::Int, symmetric::Bool, hermitian::Bool, prod, tprod, ctprod) where T =
LinearOperator{T}(nrow, ncol, symmetric, hermitian, prod, tprod, ctprod, 0, 0, 0)

nprod(op::AbstractLinearOperator) = op.nprod
ntprod(op::AbstractLinearOperator) = op.ntprod
nctprod(op::AbstractLinearOperator) = op.nctprod

increase_nprod(op::AbstractLinearOperator) = (op.nprod += 1)
increase_ntprod(op::AbstractLinearOperator) = (op.ntprod += 1)
increase_nctprod(op::AbstractLinearOperator) = (op.nctprod += 1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you need to define it for Adjoint, Transpose and Conjugate as well

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right. It's a non-issue at this point because we're not wrapping operators into other operators, but it will be done in the next PR.


"""
reset!(op)

Reset the product counters of a linear operator.
"""
function reset!(op::AbstractLinearOperator)
op.nprod = 0
op.ntprod = 0
op.nctprod = 0
return op
end

"""
Expand Down Expand Up @@ -115,9 +142,9 @@ function show(io :: IO, op :: AbstractLinearOperator)
s *= @sprintf(" eltype: %s\n", eltype(op))
s *= @sprintf(" symmetric: %s\n", op.symmetric)
s *= @sprintf(" hermitian: %s\n", op.hermitian)
#s *= @sprintf(" prod: %s\n", string(op.prod))
#s *= @sprintf(" tprod: %s\n", string(op.tprod))
#s *= @sprintf(" ctprod: %s", string(op.ctprod))
s *= @sprintf(" nprod: %d\n", nprod(op))
s *= @sprintf(" ntprod: %d\n", ntprod(op))
s *= @sprintf(" nctprod: %d\n", nctprod(op))
s *= "\n"
print(io, s)
end
Expand Down Expand Up @@ -216,6 +243,7 @@ end
# Apply an operator to a vector.
function *(op :: AbstractLinearOperator, v :: AbstractVector)
size(v, 1) == size(op, 2) || throw(LinearOperatorException("shape mismatch"))
increase_nprod(op)
op.prod(v)
end

Expand Down
9 changes: 9 additions & 0 deletions src/PreallocatedLinearOperators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,14 @@ mutable struct PreallocatedLinearOperator{T} <: AbstractPreallocatedLinearOperat
prod # apply the operator to a vector
tprod # apply the transpose operator to a vector
ctprod # apply the transpose conjugate operator to a vector
nprod :: Int
ntprod :: Int
nctprod :: Int
end

PreallocatedLinearOperator{T}(nrow::Int, ncol::Int, symmetric::Bool, hermitian::Bool, prod, tprod, ctprod) where T =
PreallocatedLinearOperator{T}(nrow, ncol, symmetric, hermitian, prod, tprod, ctprod, 0, 0, 0)

"""
show(io, op)

Expand All @@ -39,6 +45,9 @@ function show(io :: IO, op :: AbstractPreallocatedLinearOperator)
s *= @sprintf(" eltype: %s\n", eltype(op))
s *= @sprintf(" symmetric: %s\n", op.symmetric)
s *= @sprintf(" hermitian: %s\n", op.hermitian)
s *= @sprintf(" nprod: %d\n", nprod(op))
s *= @sprintf(" ntprod: %d\n", ntprod(op))
s *= @sprintf(" nctprod: %d\n", nctprod(op))
s *= "\n"
print(io, s)
end
Expand Down
38 changes: 36 additions & 2 deletions src/adjtrans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,20 @@ conj(A :: TransposeLinearOperator) = adjoint(A.parent)
transpose(A :: AdjointLinearOperator) = conj(A.parent)
transpose(A :: ConjugateLinearOperator) = adjoint(A.parent)

nprod(A::AdjointLinearOperator) = nctprod(A.parent)
ntprod(A::AdjointLinearOperator) = nprod(A.parent) # transpose(A') = conj(A)
nctprod(A::AdjointLinearOperator) = nprod(A.parent) # (A')' == A

nprod(A::TransposeLinearOperator) = ntprod(A.parent)
ntprod(A::TransposeLinearOperator) = nprod(A.parent)
nctprod(A::TransposeLinearOperator) = nprod(A.parent) # (transpose(A))' = conj(A)

for f in [:nprod, :ntprod, :nctprod]
@eval begin
$f(A::ConjugateLinearOperator) = $f(A.parent)
end
end

const AdjTrans = Union{AdjointLinearOperator,TransposeLinearOperator}

size(A :: AdjTrans) = size(A.parent)[[2;1]]
Expand Down Expand Up @@ -61,31 +75,51 @@ function *(op :: AdjointLinearOperator, v :: AbstractVector)
length(v) == size(op.parent, 1) || throw(LinearOperatorException("shape mismatch"))
p = op.parent
ishermitian(p) && return p * v
p.ctprod !== nothing && return p.ctprod(v)
if p.ctprod !== nothing
increase_nctprod(p)
return p.ctprod(v)
end
tprod = p.tprod
increment_tprod = true
if p.tprod === nothing
if issymmetric(p)
increment_tprod = false
tprod = p.prod
else
throw(LinearOperatorException("unable to infer conjugate transpose operator"))
end
end
if increment_tprod
increase_ntprod(p)
else
increase_nprod(p)
end
return conj.(tprod(conj.(v)))
end

function *(op :: TransposeLinearOperator, v :: AbstractVector)
length(v) == size(op.parent, 1) || throw(LinearOperatorException("shape mismatch"))
p = op.parent
issymmetric(p) && return p * v
p.tprod !== nothing && return p.tprod(v)
if p.tprod !== nothing
increase_ntprod(p)
return p.tprod(v)
end
increment_ctprod = true
ctprod = p.ctprod
if p.ctprod === nothing
if ishermitian(p)
increment_ctprod = false
ctprod = p.prod
else
throw(LinearOperatorException("unable to infer transpose operator"))
end
end
if increment_ctprod
increase_nctprod(p)
else
increase_nprod(p)
end
return conj.(ctprod(conj.(v)))
end

Expand Down
8 changes: 8 additions & 0 deletions src/lbfgs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,13 @@ mutable struct LBFGSOperator{T} <: AbstractLinearOperator{T}
ctprod # apply the transpose conjugate operator to a vector
inverse :: Bool
data :: LBFGSData{T}
nprod :: Int
ntprod :: Int
nctprod :: Int
end

LBFGSOperator{T}(nrow::Int, ncol::Int, symmetric::Bool, hermitian::Bool, prod, tprod, ctprod, inverse::Bool, data::LBFGSData{T}) where T =
LBFGSOperator{T}(nrow, ncol, symmetric, hermitian, prod, tprod, ctprod, inverse, data, 0, 0, 0)

"""
InverseLBFGSOperator(T, n, [mem=5; scaling=true])
Expand Down Expand Up @@ -259,5 +264,8 @@ Resets the LBFGS data of the given operator.
"""
function reset!(op :: LBFGSOperator)
reset!(op.data)
op.nprod = 0
op.ntprod = 0
op.nctprod = 0
return op
end
9 changes: 9 additions & 0 deletions src/lsr1.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,14 @@ mutable struct LSR1Operator{T} <: AbstractLinearOperator{T}
ctprod # apply the transpose conjugate operator to a vector
inverse :: Bool
data :: LSR1Data{T}
nprod :: Int
ntprod :: Int
nctprod :: Int
end

LSR1Operator{T}(nrow::Int, ncol::Int, symmetric::Bool, hermitian::Bool, prod, tprod, ctprod, inverse::Bool, data::LSR1Data{T}) where T =
LSR1Operator{T}(nrow, ncol, symmetric, hermitian, prod, tprod, ctprod, inverse, data, 0, 0, 0)

"""
LSR1Operator(T, n, [mem=5; scaling=false)
LSR1Operator(n, [mem=5; scaling=false)
Expand Down Expand Up @@ -183,5 +189,8 @@ Resets the LSR1 data of the given operator.
"""
function reset!(op :: LSR1Operator)
reset!(op.data)
op.nprod = 0
op.ntprod = 0
op.nctprod = 0
return op
end
41 changes: 41 additions & 0 deletions test/test_linop.jl
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,47 @@ function test_linop()
@test Matrix([adjtrans(opA) opA; opA adjtrans(opA)]) == Matrix([adjtrans(A) A; A adjtrans(A)])
end
end

@testset ExtendedTestSet "Counters" begin
op = LinearOperator(rand(3,4) + im * rand(3,4))
@test nprod(op) == 0
@test ntprod(op) == 0
@test nctprod(op) == 0
nprods = 5
ntprods = 4
nctprods = 7
for _ = 1 : nprods
op * rand(4)
end
for _ = 1 : ntprods
transpose(op) * rand(3)
end
for _ = 1 : nctprods
op' * rand(3)
end
@test nprod(op) == nprods
@test ntprod(op) == ntprods
@test nctprod(op) == nctprods
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could add a test here for transpose, adjoint and conjugate operators? So we can cover these lines. 💯

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea. Done. Thanks.

for _ = 1 : nprods
conj(op) * rand(4)
end
@test nprod(op) == 2 * nprods

opᵀ = transpose(op)
@test nprod(opᵀ) == ntprod(op)
@test ntprod(opᵀ) == nprod(op)
@test nctprod(opᵀ) == nprod(op)

opᴴ = op'
@test nprod(opᴴ) == nctprod(op)
@test ntprod(opᴴ) == nprod(op)
@test nctprod(opᴴ) == nprod(op)

reset!(op)
@test nprod(op) == 0
@test ntprod(op) == 0
@test nctprod(op) == 0
end
end

test_linop()