/
custom.jl
285 lines (201 loc) · 10.3 KB
/
custom.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
# # Defining custom `LinearMap` types
# In this section, we want to demonstrate on a simple, actually built-in, linear map type
# how to define custom `LinearMap` subtypes. First of all, `LinearMap{T}` is an extendable
# abstract type, where `T` denotes the `eltype`.
# ## Basics
# As an example, we want to define a map type whose objects correspond to lazy analogues
# of `fill`ed matrices. Naturally, we need to store the filled value `λ` and the `size`
# of the linear map.
using LinearMaps, LinearAlgebra
## Lazy analogue of a `fill`ed matrix: every entry equals `λ`; `eltype` is `T`.
struct MyFillMap{T} <: LinearMaps.LinearMap{T}
    ## value shared by all entries of the represented matrix
    λ::T
    ## (number of rows, number of columns)
    size::Dims{2}
    ## Inner constructor. `T` is bound to the type of `λ` by dispatch, hence no separate
    ## element-type check is required. Throws an `ArgumentError` for negative dimensions.
    function MyFillMap(λ::T, dims::Dims{2}) where {T}
        all(≥(0), dims) || throw(ArgumentError("dims of MyFillMap must be non-negative"))
        return new{T}(λ, dims)
    end
end
# By default, for any `A::MyFillMap{T}`, `eltype(A)` returns `T`. Upon application to a
# vector `x` and/or interaction with other `LinearMap` objects, we need to check consistent
# sizes.
## Report the stored `(rows, columns)` tuple as the size of the map.
function Base.size(A::MyFillMap)
    return A.size
end
# By a couple of defaults provided for all subtypes of `LinearMap`, we only need to define
# a `LinearMaps._unsafe_mul!` method to have a minimal, operational type. The (internal)
# function `_unsafe_mul!` is called by `LinearAlgebra.mul!`, constructors, and conversions
# and only needs to be concerned with the bare computing kernel. Dimension checking is done
# on the level of `mul!` etc. Factoring out dimension checking is done to minimise overhead
# caused by repetitive checking.
# !!! note
# Multiple dispatch at the `_unsafe_mul!` level happens via the second (the map type)
# and the third arguments (`AbstractVector` or `AbstractMatrix`, see the
# [Application to matrices](@ref) section below). For that reason, the output argument
# can remain type-unbound.
## Bare kernel for `y = A*x`: all entries of the result equal `λ * sum(x)`.
## Size checks are not performed here.
function LinearMaps._unsafe_mul!(y, A::MyFillMap, x::AbstractVector)
    fillvalue = if iszero(A.λ)
        ## a zero map yields zeros of the output's element type; `sum(x)` is not evaluated
        zero(eltype(y))
    else
        A.λ * sum(x)
    end
    return fill!(y, fillvalue)
end
# Again, due to generic fallbacks the following now "just work":
# * out-of-place multiplication `A*x`,
# * in-place multiplication with vectors `mul!(y, A, x)`,
# * in-place multiply-and-add with vectors `mul!(y, A, x, α, β)`,
# * in-place multiplication and multiply-and-add with matrices `mul!(Y, A, X, α, β)`,
# * conversion to a (sparse) matrix `Matrix(A)` and `sparse(A)`,
# * complete slicing of columns (and rows if the adjoint action is defined).
## 3×3 fill map with value 5.0 and a vector of ones; the kernel scales `sum(x)` by `λ`
A = MyFillMap(5.0, (3, 3)); x = ones(3); sum(x)
#-
## out-of-place product
A * x
#-
## in-place product into a freshly allocated output vector
mul!(zeros(3), A, x)
#-
## in-place multiply-and-add with α = β = 2
mul!(ones(3), A, x, 2, 2)
#-
## multiply-and-add with a 3×3 matrix as right-hand side
mul!(ones(3,3), A, reshape(collect(1:9), 3, 3), 2, 2)
# ## Multiply-and-add and the `MulStyle` trait
# While the above function calls work out of the box due to generic fallbacks, the latter
# may be suboptimally implemented for your custom map type. Let's see some benchmarks.
using BenchmarkTools
## 3-arg in-place product; all arguments are `$`-interpolated into the benchmark
@benchmark mul!($(zeros(3)), $A, $x)
#-
## 5-arg multiply-and-add with two random scalars as α and β
@benchmark mul!($(zeros(3)), $A, $x, $(rand()), $(rand()))
# The second benchmark indicates the allocation of an intermediate vector `z`
# which stores the result of `A*x` before it gets scaled and added to (the scaled)
# `y = zeros(3)`. For that reason, it is beneficial to provide a custom "5-arg
# `_unsafe_mul!`" if you can avoid the allocation of an intermediate vector. To indicate
# that there exists an allocation-free implementation of multiply-and-add, you should set
# the `MulStyle` trait, whose default is `ThreeArg()`, to `FiveArg()`.
## Trait: declare that `MyFillMap` provides its own 5-arg multiply-and-add kernel.
## `FiveArg` is qualified so that the name resolves without relying on it being exported.
LinearMaps.MulStyle(A::MyFillMap) = LinearMaps.FiveArg()
## Bare kernel for the in-place multiply-and-add `y = A*x*α + y*β`, computed without an
## intermediate vector for `A*x`. Size checks are not performed here. Returns `y`.
function LinearMaps._unsafe_mul!(y, A::MyFillMap, x::AbstractVector, α, β)
    if iszero(α)
        ## The `A*x*α` term drops out and only `y*β` remains. For `β == 0` the old content
        ## of `y` is overwritten rather than scaled, so that `NaN`/`Inf` entries in `y` do
        ## not survive; this agrees with the `iszero(β)` branch below.
        iszero(β) && return fill!(y, zero(eltype(y)))
        isone(β) || rmul!(y, β)
        return y
    end
    ## every entry of `A*x` equals `λ * sum(x)`
    temp = A.λ * sum(x) * α
    if iszero(β)
        y .= temp
    elseif isone(β)
        y .+= temp
    else
        y .= y .* β .+ temp
    end
    return y
end
# With this function at hand, let's redo the benchmark.
## same 5-arg benchmark as before, this time with the custom multiply-and-add kernel defined
@benchmark mul!($(zeros(3)), $A, $x, $(rand()), $(rand()))
# There you go, the allocation is gone and the computation time is significantly reduced.
# ## Adjoints and transposes
# Generically, taking the transpose of a map (or the adjoint of a real map) wraps the
# linear map in a `TransposeMap`, whereas taking the adjoint of a complex map wraps it in
# an `AdjointMap`.
typeof(A')
# Not surprisingly, without further definitions, multiplying `A'` by `x` yields an error.
try
    ## no kernel exists yet for the wrapped map, so this call throws
    A'x
catch err
    println(err)
end
# If the operator is symmetric or Hermitian, the transpose and the adjoint, respectively,
# of the linear map `A` is given by `A` itself. So let us define corresponding checks.
## A fill map is reported as symmetric exactly when it is square.
function LinearAlgebra.issymmetric(A::MyFillMap)
    nrows, ncols = A.size
    return nrows == ncols
end
## A fill map is reported as Hermitian when `λ` is real and the map is square.
function LinearAlgebra.ishermitian(A::MyFillMap)
    isreal(A.λ) || return false
    nrows, ncols = A.size
    return nrows == ncols
end
## Positive definiteness is reported only for 1×1 maps whose `λ` is positive definite.
function LinearAlgebra.isposdef(A::MyFillMap)
    size(A) == (1, 1) || return false
    return isposdef(A.λ)
end
## `hash` is kept consistent with `==` below, so that equal maps are found as the same key
## in `Dict`s and `Set`s.
Base.hash(A::MyFillMap, h::UInt) = hash(A.λ, hash(A.size, hash(:MyFillMap, h)))
## Two fill maps are equal when their fill values and their sizes are equal.
Base.:(==)(A::MyFillMap, B::MyFillMap) = A.λ == B.λ && A.size == B.size
# These are used, for instance, in checking symmetry or positive definiteness of
# higher-order `LinearMap`s, like products or linear combinations of linear maps, or signal
# to iterative eigenproblem solvers that real eigenvalues are to be computed.
# Without these definitions, the first three functions would return `false` (by default),
# and the last one would fall back to `===`.
# With this at hand, we note that `A` above is symmetric, and we can compute
transpose(A) * x
# This, however, does not work for nonsquare maps
try
    ## a 3×4 map is not square, hence not symmetric, and its adjoint has no kernel yet
    MyFillMap(5.0, (3, 4))' * ones(3)
catch err
    println(err)
end
# which require explicit adjoint/transpose handling, for which there exist two *distinct*
# paths.
# ### Path 1: Generic, non-invariant `LinearMap` subtypes
# The first option is to write `LinearMaps._unsafe_mul!` methods for the corresponding
# wrapped map types; for instance,
## Bare kernel for `y = transpose(A)*x` with `A::MyFillMap`: all entries of the result
## equal `transpose(λ) * sum(x)`. Size checks are not performed here.
function LinearMaps._unsafe_mul!(
    y, transA::LinearMaps.TransposeMap{<:Any,<:MyFillMap}, x::AbstractVector
)
    ## the wrapped `MyFillMap` is stored in the `lmap` field of the wrapper
    fillvalue = transA.lmap.λ
    iszero(fillvalue) && return fill!(y, zero(eltype(y)))
    return fill!(y, transpose(fillvalue) * sum(x))
end
# Now, the adjoint multiplication works.
## adjoint of a 3×4 fill map applied to a vector of length 3
MyFillMap(5.0, (3, 4))' * ones(3)
# If you have set the `MulStyle` trait to `FiveArg()`, you should provide a corresponding
# 5-arg `_unsafe_mul!` method for `LinearMaps.TransposeMap{<:Any,<:MyFillMap}` and
# `LinearMaps.AdjointMap{<:Any,<:MyFillMap}`.
# ### Path 2: Invariant `LinearMap` subtypes
# Before we start, let us delete the previously defined method to make sure we use the
# following definitions.
## Look up the methods of `_unsafe_mul!` matching the `TransposeMap`-wrapped signature and
## delete the first match, so that the definitions that follow are the ones being used.
Base.delete_method(
first(methods(
LinearMaps._unsafe_mul!,
(Any, LinearMaps.TransposeMap{<:Any,<:MyFillMap}, AbstractVector))
)
)
# The second option applies when the class of linear maps modelled by your custom
# `LinearMap` subtype is invariant under taking adjoints and transposes.
## The adjoint of a fill map is again a fill map: adjoint fill value, swapped dimensions.
function LinearAlgebra.adjoint(A::MyFillMap)
    nrows, ncols = A.size
    return MyFillMap(adjoint(A.λ), (ncols, nrows))
end
## The transpose of a fill map is again a fill map: transposed fill value, swapped
## dimensions.
function LinearAlgebra.transpose(A::MyFillMap)
    nrows, ncols = A.size
    return MyFillMap(transpose(A.λ), (ncols, nrows))
end
# With such invariant definitions, i.e., the adjoint/transpose of a `MyFillMap` is again
# a `MyFillMap`, no further method definitions are required, and the entire functionality
# listed above just works for adjoints/transposes of your custom map type.
## multiply-and-add with the adjoint, which is now itself a `MyFillMap`
mul!(ones(3), A', x, 2, 2)
#-
## adjoint of a nonsquare (3×4) map applied to a vector of length 3
MyFillMap(5.0, (3, 4))' * ones(3)
# Now that we have defined the action of adjoints/transposes, the
# following right action on vectors is automatically defined:
ones(3)' * MyFillMap(5.0, (3, 4))
# and `transpose(x) * A` correspondingly, as well as in-place multiplication
mul!(similar(x)', x', A)
# and `mul!(transpose(y), transpose(x), A)`.
# and `mul!(transpose(y), transpose(x), A)`.
# ## Application to matrices
# By default, applying a `LinearMap` `A` to a matrix `X` via `A*X` does
# *not* apply `A` to each column of `X` viewed as a vector, but interprets
# `X` as a linear map, wraps it as such and returns `(A*X)::CompositeMap`.
# Calling the in-place multiplication function `mul!(Y, A, X)` for matrices,
# however, does compute the columnwise action of `A` on `X` and stores the
# result in `Y`. In case there is a more efficient implementation for the
# matrix application, you can provide `_unsafe_mul!` methods with signature
# `_unsafe_mul!(Y, A::MyFillMap, X::AbstractMatrix)`, and, depending
# on the chosen path to handle adjoints/transposes, corresponding methods
# for wrapped maps of type `AdjointMap` or `TransposeMap`, plus potentially
# corresponding 5-arg `_unsafe_mul!` methods. This may seem like a lot of methods to
# be implemented, but note that adding such methods is only necessary/recommended
# for increased performance.
# ## Computing a matrix representation
# In some cases, it might be necessary to compute a matrix representation of a `LinearMap`.
# This is essentially done via the
# [`LinearMaps._unsafe_mul!(::Matrix,::LinearMap,::Number)`](@ref) method, for which a
# generic fallback exists: it applies the `LinearMap` successively to the standard unit
# vectors.
## 100×100 integer fill map and a preallocated output matrix of matching eltype and size
F = MyFillMap(5, (100,100))
M = Matrix{eltype(F)}(undef, size(F))
@benchmark Matrix($F)
#-
## in-place matrix filling with scalar factor `true`; reuses `M` and `F` from above instead
## of rebuilding identical objects inside the benchmark expression
@benchmark LinearMaps._unsafe_mul!($M, $F, true)
# If a more performant implementation exists, it is recommended to overwrite this method,
# for instance (as before, size checks need not be included here since they are handled by
# the corresponding `LinearAlgebra.mul!` method):
## Bare kernel writing the matrix representation of `A`, scaled by `s`, into `M`:
## every entry is set to `λ * s`. Size checks are not performed here.
function LinearMaps._unsafe_mul!(M, A::MyFillMap, s::Number)
    return fill!(M, A.λ * s)
end
@benchmark Matrix($F)
#-
## reuse the output matrix `M` and the map `F` created earlier rather than rebuilding
## identical objects inside the benchmark expression
@benchmark LinearMaps._unsafe_mul!($M, $F, true)
# As one can see, the above runtimes are dominated by the allocation of the output matrix,
# but overwriting the multiplication kernel still yields a speed-up of about a factor of 3
# for the matrix filling part.
# ## Slicing
# As usual, generic fallbacks for `LinearMap` slicing exist and are handled by the following
# method hierarchy, where at least one of `I` and `J` has to be a `Colon`:
#
# Base.getindex(::LinearMap, I, J)
# -> LinearMaps._getindex(::LinearMap, I, J)
#
# The method `Base.getindex` checks the validity of the requested indices and calls
# `LinearMaps._getindex`, which should be overloaded for custom `LinearMap` subtypes.
# For instance:
## `F` is a non-constant global: interpolate it, as in the other benchmarks, so that only
## the slicing itself is measured
@benchmark ($F)[1, :]
#-
## Row slice `A[i, :]`: every entry of the requested row equals `λ`, so the result is
## built directly with `fill` over the axes of the slice.
function LinearMaps._getindex(A::MyFillMap, ::Integer, J::Base.Slice)
    return fill(A.λ, axes(J))
end
## `F` is a non-constant global: interpolate it, as in the other benchmarks, so that only
## the slicing itself is measured
@benchmark ($F)[1, :]
# Note that in `Base.getindex` `Colon`s are converted to `Base.Slice` via
# `Base.to_indices`, thus the dispatch must be on `Base.Slice` rather than on `Colon`.