Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"

[compat]
FastClosures = "0.2, 0.3"
Expand Down
28 changes: 15 additions & 13 deletions src/LinearOperators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,9 @@ import Base.hcat, Base.vcat, Base.hvcat
abstract type AbstractLinearOperator{T} end
OperatorOrMatrix = Union{AbstractLinearOperator, AbstractMatrix}

include("adjtrans.jl")

eltype(A :: AbstractLinearOperator{T}) where {T} = T
isreal(A :: AbstractLinearOperator{T}) where {T} = T <: Real

include("PreallocatedLinearOperators.jl")

"""
Base type to represent a linear operator.
The usual arithmetic operations may be applied to operators
Expand Down Expand Up @@ -95,11 +91,12 @@ size(op :: AbstractLinearOperator) = (op.nrow, op.ncol)
Return the size of a linear operator along dimension `d`.
"""
function size(op :: AbstractLinearOperator, d :: Int)
nrow, ncol = size(op)
if d == 1
return op.nrow
return nrow
end
if d == 2
return op.ncol
return ncol
end
throw(LinearOperatorException("Linear operators only have 2 dimensions for now"))
end
Expand Down Expand Up @@ -137,11 +134,12 @@ Display basic information about a linear operator.
"""
function show(io :: IO, op :: AbstractLinearOperator)
s = "Linear operator\n"
s *= @sprintf(" nrow: %s\n", op.nrow)
s *= @sprintf(" ncol: %d\n", op.ncol)
nrow, ncol = size(op)
s *= @sprintf(" nrow: %s\n", nrow)
s *= @sprintf(" ncol: %d\n", ncol)
s *= @sprintf(" eltype: %s\n", eltype(op))
s *= @sprintf(" symmetric: %s\n", op.symmetric)
s *= @sprintf(" hermitian: %s\n", op.hermitian)
s *= @sprintf(" symmetric: %s\n", symmetric(op))
s *= @sprintf(" hermitian: %s\n", hermitian(op))
s *= @sprintf(" nprod: %d\n", nprod(op))
s *= @sprintf(" ntprod: %d\n", ntprod(op))
s *= @sprintf(" nctprod: %d\n", nctprod(op))
Expand Down Expand Up @@ -352,6 +350,13 @@ end
-(x :: Number, op :: AbstractLinearOperator) = x + (-op)


include("adjtrans.jl")
include("PreallocatedLinearOperators.jl")
include("qn.jl") # quasi-Newton operators
include("kron.jl")
include("TimedOperators.jl")


# Utility functions.

"""
Expand Down Expand Up @@ -730,9 +735,6 @@ function opHermitian(T :: AbstractMatrix)
opHermitian(d, T)
end

include("qn.jl") # quasi-Newton operators
include("kron.jl")

"""
Z = opRestriction(I, ncol)
Z = opRestriction(:, ncol)
Expand Down
36 changes: 36 additions & 0 deletions src/TimedOperators.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
using TimerOutputs

export TimedLinearOperator

mutable struct TimedLinearOperator{T} <: AbstractLinearOperator{T}
timer :: TimerOutput
op :: AbstractLinearOperator{T}
prod
tprod
ctprod
end

function TimedLinearOperator(op::AbstractLinearOperator{T}) where T
timer = TimerOutput()
prod(x) = @timeit timer "prod" op.prod(x)
tprod(x) = @timeit timer "tprod" op.tprod(x)
ctprod(x) = @timeit timer "ctprod" op.ctprod(x)
TimedLinearOperator{T}(timer, op, prod, tprod, ctprod)
end

TimedLinearOperator(op::AdjointLinearOperator) = adjoint(TimedLinearOperator(op.parent))
TimedLinearOperator(op::TransposeLinearOperator) = transpose(TimedLinearOperator(op.parent))
TimedLinearOperator(op::ConjugateLinearOperator) = conj(TimedLinearOperator(op.parent))

for fn ∈ (:size, :shape, :symmetric, :issymmetric, :hermitian, :ishermitian,
:nprod, :ntprod, :nctprod, :increase_nprod, :increase_ntprod, :increase_nctprod, :reset!)
@eval begin
$fn(A::TimedLinearOperator) = $fn(A.op)
end
end

function show(io :: IO, op :: TimedLinearOperator)
show(io, op.op)
show(io, op.timer)
end

2 changes: 1 addition & 1 deletion src/adjtrans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ nprod(A::TransposeLinearOperator) = ntprod(A.parent)
ntprod(A::TransposeLinearOperator) = nprod(A.parent)
nctprod(A::TransposeLinearOperator) = nprod(A.parent) # (transpose(A))' = conj(A)

for f in [:nprod, :ntprod, :nctprod]
for f in [:nprod, :ntprod, :nctprod, :increase_nprod, :increase_ntprod, :increase_nctprod]
@eval begin
$f(A::ConjugateLinearOperator) = $f(A.parent)
end
Expand Down
39 changes: 39 additions & 0 deletions test/test_linop.jl
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,45 @@ function test_linop()
@test ntprod(op) == 0
@test nctprod(op) == 0
end

@testset ExtendedTestSet "Timers" begin
op = LinearOperator(rand(3,4) + im * rand(3,4))
top = TimedLinearOperator(op)
nprods = 5
ntprods = 4
nctprods = 7
for _ = 1 : nprods
op * rand(4)
end
for _ = 1 : ntprods
transpose(op) * rand(3)
end
for _ = 1 : nctprods
op' * rand(3)
end
for fn ∈ (:size, :shape, :symmetric, :issymmetric, :hermitian, :ishermitian, :nprod, :ntprod, :nctprod)
@eval begin
@test $fn($top) == $fn($top.op)
end
end

reset!(op)
reset!(top)

top2 = TimedLinearOperator(op') # the same as top'
nrow, ncol = size(op)
u = rand(nrow) + im * rand(nrow)
@test all(top' * u .== top2 * u)
v = rand(ncol) + im * rand(ncol)
@test all(top * v .== top2' * v)

top3 = TimedLinearOperator(transpose(op)) # the same as transpose(top)
nrow, ncol = size(op)
u = rand(nrow) + im * rand(nrow)
@test all(transpose(top) * u .== top3 * u)
v = rand(ncol) + im * rand(ncol)
@test all(top * v .== transpose(top3) * v)
end
end

test_linop()