Skip to content

Commit

Permalink
add mulstyle feature
Browse files Browse the repository at this point in the history
refactor linear combination multiplication

try to get rid of allocations in lincomb mul

show coefficients

more details in uniform scaling mul

opt-out of allocation testing

fix mulstyle error in kronecker, improve coverage
  • Loading branch information
dkarrasch committed Nov 21, 2019
1 parent 8201996 commit 88171b8
Show file tree
Hide file tree
Showing 12 changed files with 134 additions and 89 deletions.
17 changes: 17 additions & 0 deletions src/LinearMaps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,23 @@ const MapOrMatrix{T} = Union{LinearMap{T},AbstractMatrix{T}}

Base.eltype(::LinearMap{T}) where {T} = T

abstract type MulStyle end

struct FiveArg <: MulStyle end
struct ThreeArg <: MulStyle end

mulstyle(::Type{FiveArg}, ::Type{FiveArg}) = FiveArg
mulstyle(::Type{ThreeArg}, ::Type{FiveArg}) = ThreeArg
mulstyle(::Type{FiveArg}, ::Type{ThreeArg}) = ThreeArg
mulstyle(::Type{ThreeArg}, ::Type{ThreeArg}) = ThreeArg
mulstyle(::LinearMap) = ThreeArg # default
if VERSION v"1.3.0-alpha.115"
mulstyle(::AbstractMatrix) = FiveArg
else
mulstyle(::AbstractMatrix) = ThreeArg
end
mulstyle(A::LinearMap, As::LinearMap...) = mulstyle(mulstyle(A), mulstyle(As...))

Base.isreal(A::LinearMap) = eltype(A) <: Real
LinearAlgebra.issymmetric(::LinearMap) = false # default assumptions
LinearAlgebra.ishermitian(A::LinearMap{<:Real}) = issymmetric(A)
Expand Down
2 changes: 2 additions & 0 deletions src/blockmap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ end

BlockMap{T}(maps::As, rows::S) where {T,As<:Tuple{Vararg{LinearMap}},S} = BlockMap{T,As,S}(maps, rows)

mulstyle(A::BlockMap) = mulstyle(A.maps...)

function check_dim(A::LinearMap, dim, n)
n == size(A, dim) || throw(DimensionMismatch("Expected $n, got $(size(A, dim))"))
return nothing
Expand Down
129 changes: 54 additions & 75 deletions src/linearcombination.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
struct LinearCombination{T, As<:Tuple{Vararg{LinearMap}}} <: LinearMap{T}
struct LinearCombination{T, MS<:MulStyle, As<:Tuple{Vararg{LinearMap}}} <: LinearMap{T}
maps::As
function LinearCombination{T, As}(maps::As) where {T, As}
function LinearCombination{T, MS, As}(maps::As) where {T, MS<:MulStyle, As}
N = length(maps)
sz = size(maps[1])
for n in 1:N
size(maps[n]) == sz || throw(DimensionMismatch("LinearCombination"))
promote_type(T, eltype(maps[n])) == T || throw(InexactError())
for Ai in maps
size(Ai) == sz || throw(DimensionMismatch("LinearCombination"))
promote_type(T, eltype(Ai)) == T || throw(InexactError())
end
new{T, As}(maps)
MS === FiveArg && mulstyle(maps...) === ThreeArg && throw("wrong mulstyle in constructor")
new{T, MS, As}(maps)
end
end

LinearCombination{T}(maps::As) where {T, As} = LinearCombination{T, As}(maps)
LinearCombination{T,MS}(maps::As) where {T, MS<:MulStyle, As} = LinearCombination{T, mulstyle(maps...), As}(maps)

mulstyle(::LinearCombination{T,MS}) where {T, MS<:MulStyle} = MS

# basic methods
Base.size(A::LinearCombination) = size(A.maps[1])
Expand Down Expand Up @@ -39,106 +42,82 @@ julia> LinearMap(ones(Int, 3, 3)) + CS + I + rand(3, 3);
function Base.:(+)(A₁::LinearMap, A₂::LinearMap)
size(A₁) == size(A₂) || throw(DimensionMismatch("+"))
T = promote_type(eltype(A₁), eltype(A₂))
return LinearCombination{T}(tuple(A₁, A₂))
return LinearCombination{T, mulstyle(A₁, A₂)}(tuple(A₁, A₂))
end
function Base.:(+)(A₁::LinearMap, A₂::LinearCombination)
size(A₁) == size(A₂) || throw(DimensionMismatch("+"))
T = promote_type(eltype(A₁), eltype(A₂))
return LinearCombination{T}(tuple(A₁, A₂.maps...))
return LinearCombination{T, mulstyle(A₁, A₂)}(tuple(A₁, A₂.maps...))
end
Base.:(+)(A₁::LinearCombination, A₂::LinearMap) = +(A₂, A₁)
function Base.:(+)(A₁::LinearCombination, A₂::LinearCombination)
size(A₁) == size(A₂) || throw(DimensionMismatch("+"))
T = promote_type(eltype(A₁), eltype(A₂))
return LinearCombination{T}(tuple(A₁.maps..., A₂.maps...))
return LinearCombination{T, mulstyle(A₁, A₂)}(tuple(A₁.maps..., A₂.maps...))
end
Base.:(-)(A₁::LinearMap, A₂::LinearMap) = +(A₁, -A₂)

# comparison of LinearCombination objects, sufficient but not necessary
Base.:(==)(A::LinearCombination, B::LinearCombination) = (eltype(A) == eltype(B) && A.maps == B.maps)

# special transposition behavior
LinearAlgebra.transpose(A::LinearCombination) = LinearCombination{eltype(A)}(map(transpose, A.maps))
LinearAlgebra.adjoint(A::LinearCombination) = LinearCombination{eltype(A)}(map(adjoint, A.maps))
LinearAlgebra.transpose(A::LinearCombination) = LinearCombination{eltype(A), mulstyle(A)}(map(transpose, A.maps))
LinearAlgebra.adjoint(A::LinearCombination) = LinearCombination{eltype(A), mulstyle(A)}(map(adjoint, A.maps))

# multiplication with vectors
if VERSION < v"1.3.0-alpha.115"

function A_mul_B!(y::AbstractVector, A::LinearCombination, x::AbstractVector)
# no size checking, will be done by individual maps
A_mul_B!(y, A.maps[1], x)
l = length(A.maps)
if l>1
z = similar(y)
for n in 2:l
A_mul_B!(z, A.maps[n], x)
y .+= z
end
for Atype in (AbstractVector, AbstractMatrix)
@eval Base.@propagate_inbounds function LinearAlgebra.mul!(y::$Atype, A::LinearCombination, x::$Atype,
α::Number=true, β::Number=false)
@boundscheck check_dim_mul(y, A, x)
return _lincombmul!(y, A, x, α, β)
end
return y
end

else # 5-arg mul! is available for matrices

# map types that have an allocation-free 5-arg mul! implementation
const FreeMap = Union{MatrixMap,UniformScalingMap}

function A_mul_B!(y::AbstractVector, A::LinearCombination{T,As}, x::AbstractVector) where {T, As<:Tuple{Vararg{FreeMap}}}
# no size checking, will be done by individual maps
A_mul_B!(y, A.maps[1], x)
for n in 2:length(A.maps)
mul!(y, A.maps[n], x, true, true)
end
return y
end
function A_mul_B!(y::AbstractVector, A::LinearCombination, x::AbstractVector)
# no size checking, will be done by individual maps
A_mul_B!(y, A.maps[1], x)
l = length(A.maps)
if l>1
z = similar(y)
for n in 2:l
An = A.maps[n]
if An isa FreeMap
mul!(y, An, x, true, true)
else
A_mul_B!(z, A.maps[n], x)
y .+= z
end
@inline function _lincombmul!(y, A::LinearCombination{<:Any,FiveArg}, x, α::Number, β::Number)
if iszero(α) # trivial cases
iszero(β) && (fill!(y, zero(eltype(y))); return y)
isone(β) && return y
# β != 0, 1
rmul!(y, β)
return y
else
mul!(y, first(A.maps), x, α, β)
@inbounds for An in Base.tail(A.maps)
mul!(y, An, x, α, true)
end
return y
end
return y
end

function LinearAlgebra.mul!(y::AbstractVector, A::LinearCombination{T,As}, x::AbstractVector, α::Number=true, β::Number=false) where {T, As<:Tuple{Vararg{FreeMap}}}
length(y) == size(A, 1) || throw(DimensionMismatch("mul!"))
if isone(α)
iszero(β) && (A_mul_B!(y, A, x); return y)
!isone(β) && rmul!(y, β)
elseif iszero(α)
@inline function _lincombmul!(y, A::LinearCombination{<:Any,ThreeArg}, x, α::Number, β::Number)
if iszero(α)
iszero(β) && (fill!(y, zero(eltype(y))); return y)
isone(β) && return y
# β != 0, 1
rmul!(y, β)
return y
else # α != 0, 1
if iszero(β)
A_mul_B!(y, A, x)
rmul!(y, α)
return y
elseif !isone(β)
rmul!(y, β)
end # β-cases
end # α-cases

for An in A.maps
mul!(y, An, x, α, true)
else
mul!(y, first(A.maps), x, α, β)
l = length(A.maps)
if l>1
z = similar(y)
@inbounds for n in 2:l
An = A.maps[n]
muladd!(mulstyle(An), y, An, x, α, z)
end
end
return y
end
return y
end

end # VERSION
@inline muladd!(::Type{FiveArg}, y, A, x, α, _) = mul!(y, A, x, α, true)
@inline function muladd!(::Type{ThreeArg}, y, A, x, α, z)
A_mul_B!(z, A, x)
y .+= isone(α) ? z : z .* α
end

A_mul_B!(y::AbstractVector, A::LinearCombination, x::AbstractVector) = mul!(y, A, x)

At_mul_B!(y::AbstractVector, A::LinearCombination, x::AbstractVector) = A_mul_B!(y, transpose(A), x)
At_mul_B!(y::AbstractVector, A::LinearCombination, x::AbstractVector) = mul!(y, transpose(A), x)

Ac_mul_B!(y::AbstractVector, A::LinearCombination, x::AbstractVector) = A_mul_B!(y, adjoint(A), x)
Ac_mul_B!(y::AbstractVector, A::LinearCombination, x::AbstractVector) = mul!(y, adjoint(A), x)
2 changes: 2 additions & 0 deletions src/transpose.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ struct AdjointMap{T, A<:LinearMap{T}} <: LinearMap{T}
lmap::A
end

mulstyle(A::Union{TransposeMap,AdjointMap}) = mulstyle(A.lmap)

# transposition behavior of LinearMap objects
LinearAlgebra.transpose(A::TransposeMap) = A.lmap
LinearAlgebra.adjoint(A::AdjointMap) = A.lmap
Expand Down
23 changes: 18 additions & 5 deletions src/uniformscalingmap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ UniformScalingMap(λ::Number, M::Int, N::Int) =
UniformScalingMap::T, sz::Dims{2}) where {T} =
(sz[1] == sz[2] ? UniformScalingMap(λ, sz[1]) : error("UniformScalingMap needs to be square"))

mulstyle(::UniformScalingMap) = FiveArg

# properties
Base.size(A::UniformScalingMap) = (A.M, A.M)
Base.isreal(A::UniformScalingMap) = isreal(A.λ)
Expand Down Expand Up @@ -91,11 +93,22 @@ function _scaling!(y, J::UniformScalingMap, x, α::Number=true, β::Number=false
rmul!(y, β)
return y
else # α != 0, 1
iszero(β) && (y .= λ .* x .* α; return y)
isone(β) && (y .+= λ .* x .* α; return y)
# β != 0, 1
y .= y .* β .+ λ .* x .* α
return y
if iszero(β)
iszero(λ) && return fill!(y, zero(eltype(y)))
isone(λ) && return y .= x .* α
y .= λ .* x .* α
return y
elseif isone(β)
iszero(λ) && return y
isone(λ) && return y .+= x .* α
y .+= λ .* x .* α
return y
else # β != 0, 1
iszero(λ) && (rmul!(y, β); return y)
isone(λ) && (y .= y .* β .+ x .* α; return y)
y .= y .* β .+ λ .* x .* α
return y
end
end
end

Expand Down
2 changes: 2 additions & 0 deletions src/wrappedmap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ end

const MatrixMap{T} = WrappedMap{T,<:AbstractMatrix}

mulstyle(A::WrappedMap) = mulstyle(A.lmap)

LinearAlgebra.transpose(A::MatrixMap{T}) where {T} =
WrappedMap{T}(transpose(A.lmap); issymmetric=A._issymmetric, ishermitian=A._ishermitian, isposdef=A._isposdef)
LinearAlgebra.adjoint(A::MatrixMap{T}) where {T} =
Expand Down
6 changes: 5 additions & 1 deletion test/blockmap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using Test, LinearMaps, LinearAlgebra
A11 = rand(elty, 10, 10)
A12 = rand(elty, 10, n2)
L = @inferred hcat(LinearMap(A11), LinearMap(A12))
@test @inferred(LinearMaps.mulstyle(L)) == matrixstyle
@test L isa LinearMaps.BlockMap{elty}
A = [A11 A12]
x = rand(10+n2)
Expand Down Expand Up @@ -36,6 +37,7 @@ using Test, LinearMaps, LinearAlgebra
A21 = rand(elty, 20, 10)
L = @inferred vcat(LinearMap(A11), LinearMap(A21))
@test L isa LinearMaps.BlockMap{elty}
@test @inferred(LinearMaps.mulstyle(L)) == matrixstyle
A = [A11; A21]
x = rand(10)
@test size(L) == size(A)
Expand All @@ -62,6 +64,7 @@ using Test, LinearMaps, LinearAlgebra
A = [A11 A12; A21 A22]
@inferred hvcat((2,2), LinearMap(A11), LinearMap(A12), LinearMap(A21), LinearMap(A22))
L = [LinearMap(A11) LinearMap(A12); LinearMap(A21) LinearMap(A22)]
@test @inferred(LinearMaps.mulstyle(L)) == matrixstyle
@test @inferred !issymmetric(L)
@test @inferred !ishermitian(L)
x = rand(30)
Expand Down Expand Up @@ -102,12 +105,13 @@ using Test, LinearMaps, LinearAlgebra
@test Matrix(adjoint(B)) == C'
end
end

@testset "adjoint/transpose" begin
for elty in (Float32, Float64, ComplexF64), transform in (transpose, adjoint)
A12 = rand(elty, 10, 10)
A = [I A12; transform(A12) I]
L = [I LinearMap(A12); transform(LinearMap(A12)) I]
@test @inferred(LinearMaps.mulstyle(L)) == matrixstyle
if elty <: Complex
if transform == transpose
@test @inferred issymmetric(L)
Expand Down
8 changes: 7 additions & 1 deletion test/kronecker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ using Test, LinearMaps, LinearAlgebra
LB = LinearMap(B)
LK = @inferred kron(LA, LB)
@test @inferred size(LK) == size(K)
@test LinearMaps.mulstyle(LK) == LinearMaps.ThreeArg
for i in (1, 2)
@test @inferred size(LK, i) == size(K, i)
end
Expand All @@ -31,6 +32,11 @@ using Test, LinearMaps, LinearAlgebra
@test @inferred kron(LA, LB)' == @inferred kron(LA', LB')
@test (@inferred kron(LA, B)) == (@inferred kron(LA, LB)) == (@inferred kron(A, LB))
@test @inferred ishermitian(kron(LA'LA, LB'LB))
A = rand(2, 5); B = rand(4, 2)
K = @inferred kron(A, LinearMap(B))
@test Matrix(K) kron(A, B)
K = @inferred kron(LinearMap(B), A)
@test Matrix(K) kron(B, A)
A = rand(3, 3); B = rand(2, 2); LA = LinearMap(A); LB = LinearMap(B)
@test @inferred issymmetric(kron(LA'LA, LB'LB))
@test @inferred ishermitian(kron(LA'LA, LB'LB))
Expand Down Expand Up @@ -59,7 +65,7 @@ using Test, LinearMaps, LinearAlgebra
@test Matrix(kronsum(transform(LA), transform(LB))) transform(KSmat)
@test Matrix(transform(LinearMap(kronsum(LA, LB)))) Matrix(transform(KS)) transform(KSmat)
end
@inferred kronsum(A, A, LB)
@test @inferred(kronsum(A, A, LB)) == @inferred((A, A, B))
@test Matrix(@inferred LA^⊕(3)) == Matrix(@inferred A^⊕(3)) Matrix(kronsum(LA, A, A))
@test @inferred(kronsum(LA, LA, LB)) == @inferred(kronsum(LA, kronsum(LA, LB))) == @inferred(kronsum(A, A, B))
@test Matrix(@inferred kronsum(A, B, A, B, A, B)) Matrix(@inferred kronsum(LA, LB, LA, LB, LA, LB))
Expand Down
22 changes: 17 additions & 5 deletions test/linearcombination.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using Test, LinearMaps, LinearAlgebra, BenchmarkTools
CS! = LinearMap{ComplexF64}(cumsum!,
(y, x) -> (copyto!(y, x); reverse!(y); cumsum!(y, y)), 10;
ismutating=true)
v = rand(10)
v = rand(ComplexF64, 10)
u = similar(v)
b = @benchmarkable mul!($u, $CS!, $v)
@test run(b, samples=3).allocs == 0
Expand All @@ -13,12 +13,20 @@ using Test, LinearMaps, LinearAlgebra, BenchmarkTools
@test mul!(u, L, v) n * cumsum(v)
b = @benchmarkable mul!($u, $L, $v)
@test run(b, samples=5).allocs <= 1
for α in (false, true, rand(ComplexF64)), β in (false, true, rand(ComplexF64))
@test mul!(copy(u), L, v, α, β) Matrix(L)*v*α + u*β
end

A = 2 * rand(ComplexF64, (10, 10)) .- 1
B = rand(ComplexF64, size(A)...)
M = @inferred LinearMap(A)
N = @inferred LinearMap(B)
@test @inferred(LinearMaps.mulstyle(M)) == matrixstyle
@test @inferred(LinearMaps.mulstyle(N)) == matrixstyle
LC = @inferred M + N
@test @inferred(LinearMaps.mulstyle(LC)) == matrixstyle
@test @inferred(LinearMaps.mulstyle(LC + I)) == matrixstyle
@test @inferred(LinearMaps.mulstyle(LC + 2.0*I)) == matrixstyle
v = rand(ComplexF64, 10)
w = similar(v)
b = @benchmarkable mul!($w, $M, $v)
Expand All @@ -27,10 +35,14 @@ using Test, LinearMaps, LinearAlgebra, BenchmarkTools
b = @benchmarkable mul!($w, $LC, $v)
@test run(b, samples=3).allocs == 0
for α in (false, true, rand(ComplexF64)), β in (false, true, rand(ComplexF64))
b = @benchmarkable mul!($w, $LC, $v, $α, $β)
@test run(b, samples=3).allocs == 0
b = @benchmarkable mul!($w, $(LC + I), $v, $α, $β)
@test run(b, samples=3).allocs == 0
if testallocs
b = @benchmarkable mul!($w, $LC, $v, $α, $β)
@test run(b, samples=3).allocs == 0
b = @benchmarkable mul!($w, $(I + LC), $v, $α, $β)
@test run(b, samples=3).allocs == 0
b = @benchmarkable mul!($w, $(LC + I), $v, $α, $β)
@test run(b, samples=3).allocs == 0
end
y = rand(ComplexF64, size(v))
@test mul!(copy(y), LC, v, α, β) Matrix(LC)*v*α + y*β
@test mul!(copy(y), LC+I, v, α, β) Matrix(LC + I)*v*α + y*β
Expand Down
5 changes: 5 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
using Test, LinearMaps
import LinearMaps: FiveArg, ThreeArg

const matrixstyle = VERSION v"1.3.0-alpha.115" ? FiveArg : ThreeArg

const testallocs = false

include("linearmaps.jl")

Expand Down

0 comments on commit 88171b8

Please sign in to comment.