# TensorCore.jl
module TensorCore
using LinearAlgebra
export ⊙, hadamard, hadamard!
export ⊗, tensor, tensor!
export ⊡, boxdot, boxdot!
"""
    hadamard(a, b)
    a ⊙ b

Multiply the arrays `a` and `b` elementwise. Both arguments must have identical `axes`,
otherwise a `DimensionMismatch` is thrown. The operator form `⊙` is an ordinary function,
so it can be passed to higher-order functions.

# Examples
```jldoctest; setup=:(using TensorCore)
julia> a = [2, 3]; b = [5, 7];

julia> a ⊙ b
2-element Array{$Int,1}:
 10
 21

julia> a ⊙ [5]
ERROR: DimensionMismatch("Axes of `A` and `B` must match, got (Base.OneTo(2),) and (Base.OneTo(1),)")
[...]
```

See also `hadamard!(y, a, b)`.
"""
function hadamard(A::AbstractArray, B::AbstractArray)
    # Kept out of line so the error-message construction does not bloat the fast path.
    @noinline mismatch(axesA, axesB) =
        throw(DimensionMismatch("Axes of `A` and `B` must match, got $axesA and $axesB"))
    axesA = axes(A)
    axesB = axes(B)
    if axesA != axesB
        mismatch(axesA, axesB)
    end
    return map(*, A, B)
end
const ⊙ = hadamard
"""
    hadamard!(dest, A, B)

In-place variant of `hadamard(A, B)` (also written `A ⊙ B`). The elementwise product is
written into the pre-allocated array `dest`, which is then returned. `dest`, `A` and `B`
must all have identical `axes`, otherwise a `DimensionMismatch` is thrown.
"""
function hadamard!(dest::AbstractArray, A::AbstractArray, B::AbstractArray)
    @noinline function mismatch(axesA, axesB, axesdest)
        msg = "`axes(dest) = $axesdest` must be equal to `axes(A) = $axesA` and `axes(B) = $axesB`"
        throw(DimensionMismatch(msg))
    end
    axesA, axesB, axesdest = axes(A), axes(B), axes(dest)
    if !(axesdest == axesA && axesdest == axesB)
        mismatch(axesA, axesB, axesdest)
    end
    # All three share the same axes (checked above), so skipping bounds checks is safe.
    @inbounds @simd for k in eachindex(dest, A, B)
        dest[k] = A[k] * B[k]
    end
    return dest
end
"""
    tensor(A, B)
    A ⊗ B

Compute the tensor (outer) product of the arrays `A` and `B`.
If `C = A ⊗ B`, then `C[i1, ..., im, j1, ..., jn] = A[i1, ..., im] * B[j1, ..., jn]`,
so `ndims(C) == ndims(A) + ndims(B)`.

For vectors `v` and `w`, the Kronecker product is related to the tensor product by
`kron(v,w) == vec(w ⊗ v)` or `w ⊗ v == reshape(kron(v,w), (length(w), length(v)))`.

# Examples
```jldoctest; setup=:(using TensorCore)
julia> a = [2, 3]; b = [5, 7, 11];

julia> a ⊗ b
2×3 Array{$Int,2}:
 10  14  22
 15  21  33
```

See also `tensor!(Y,A,B)`.
"""
function tensor(A::AbstractArray, B::AbstractArray)
    # The product iterator has the shape (axes(A)..., axes(B)...), so collecting the
    # generator yields an array with all of A's dimensions followed by all of B's.
    grid = Iterators.product(A, B)
    return collect(a * b for (a, b) in grid)
end
const ⊗ = tensor
# `Adjoint`/`Transpose`-wrapped vectors, i.e. row-vector ("co-vector") types.
const CovectorLike{T} = Union{Adjoint{T,<:AbstractVector},Transpose{T,<:AbstractVector}}

# Shared error raised by every tensor product that involves a co-vector.
@noinline _covector_tensor_error() =
    error("`tensor` is not defined for co-vectors, perhaps you meant `*`?")

# If `v` is thought of as a covector, `u ⊗ v` might be expected to be two-dimensional,
# but thought of as a matrix it should be three-dimensional. The safest choice is not to
# support it at all. See discussion in #35150.
tensor(::AbstractArray, ::CovectorLike) = _covector_tensor_error()
tensor(::CovectorLike, ::AbstractArray) = _covector_tensor_error()
# Covers the case matched by both methods above, which would otherwise be ambiguous.
tensor(::CovectorLike, ::CovectorLike) = _covector_tensor_error()
"""
    tensor!(dest, A, B)

In-place variant of `tensor(A, B)` (also written `A ⊗ B`). The tensor product is written
into the pre-allocated array `dest`, which is then returned. `axes(dest)` must equal
`(axes(A)..., axes(B)...)`, otherwise a `DimensionMismatch` is thrown.
"""
function tensor!(dest::AbstractArray, A::AbstractArray, B::AbstractArray)
    @noinline function mismatch(axesA, axesB, axesdest)
        msg = "`axes(dest) = $axesdest` must concatenate `axes(A) = $axesA` and `axes(B) = $axesB`"
        throw(DimensionMismatch(msg))
    end
    axesA, axesB, axesdest = axes(A), axes(B), axes(dest)
    if axesdest != (axesA..., axesB...)
        mismatch(axesA, axesB, axesdest)
    end
    if IndexStyle(dest) !== IndexCartesian()
        # Linear indexing: fill `dest` in linear-index order. A's dimensions come first,
        # so A varies fastest, followed by one step of B.
        k = firstindex(dest)
        @inbounds for b in B
            @simd for a in A
                dest[k] = a * b
                k += 1
            end
        end
    else
        # Cartesian indexing: address each entry of `dest` by the index pair (IA, IB).
        for IB in CartesianIndices(axesB)
            @inbounds b = B[IB]
            @simd for IA in CartesianIndices(axesA)
                @inbounds dest[IA, IB] = A[IA] * b
            end
        end
    end
    return dest
end
export boxdot, ⊡, boxdot!
"""
    boxdot(A,B) = A ⊡ B    # \\boxdot

Generalised matrix multiplication: contract the last dimension of `A` with the first
dimension of `B`, for any `ndims(A)` and `ndims(B)`, keeping all other dimensions in order.
If both arguments are vectors, the result is the scalar `sum(A .* B)`.

# Examples
```jldoctest; setup=:(using TensorCore)
julia> A = rand(3,4,5); B = rand(5,6,7);

julia> size(A ⊡ B)
(3, 4, 6, 7)

julia> typeof(rand(5) ⊡ rand(5))
Float64

julia> try B ⊡ A catch err println(err) end
DimensionMismatch("neighbouring axes of `A` and `B` must match, got Base.OneTo(7) and Base.OneTo(3)")
```

This is the same behaviour as Mathematica's function `Dot[A, B]`.
It is not identical to Python's `numpy.dot(A, B)`, which contracts with the second-last
dimension of `B` instead of the first, but both keep all the other dimensions.
Unlike Julia's `LinearAlgebra.dot`, it does not conjugate `A`, so these two agree only
for real-valued vectors.

When interacting with `Adjoint` vectors, this always obeys `(x ⊡ y)' == y' ⊡ x'`,
and hence may sometimes return another `Adjoint` vector. (And similarly for `Transpose`.)

```jldoctest; setup=:(using TensorCore)
julia> M = rand(5,5); v = rand(5);

julia> typeof(v ⊡ M')
Array{Float64,1}

julia> typeof(M ⊡ v') # adjoint of the previous line
Adjoint{Float64,Array{Float64,1}}

julia> typeof(v' ⊡ M') # same as *, and equal to adjoint(M ⊡ v)
Adjoint{Float64,Array{Float64,1}}

julia> typeof(v' ⊡ v)
Float64
```

See also `boxdot!(Y,A,B)`, which is to `⊡` as `mul!` is to `*`.
"""
function boxdot(A::AbstractArray, B::AbstractArray)
    # View A as (other dims) × (last dim) and B as (first dim) × (other dims), so the
    # contraction becomes an ordinary matrix product.
    left = _squash_left(A)
    right = _squash_right(B)
    inner_left = axes(left, 2)
    inner_right = axes(right, 1)
    if inner_left != inner_right
        _throw_dmm(inner_left, inner_right)
    end
    return _boxdot_reshape(left * right, A, B)
end
const ⊡ = boxdot

@noinline _throw_dmm(axA, axB) =
    throw(DimensionMismatch("neighbouring axes of `A` and `B` must match, got $axA and $axB"))

# Reshape `A` into a matrix whose columns run along its last dimension
# (a vector becomes a single row). Matrices are returned unchanged.
function _squash_left(A::AbstractArray)
    ncols = size(A, ndims(A))
    return reshape(A, :, ncols)
end
_squash_left(A::AbstractMatrix) = A

# Reshape `B` into a matrix whose rows run along its first dimension.
# Vectors and matrices are returned unchanged.
function _squash_right(B::AbstractArray)
    nrows = size(B, 1)
    return reshape(B, nrows, :)
end
_squash_right(B::AbstractVecOrMat) = B

# Restore the surviving dimensions of the matrix product `AB`: all but the last axis of `A`,
# followed by all but the first axis of `B`.
function _boxdot_reshape(AB::AbstractArray, A::AbstractArray{T,N}, B::AbstractArray{S,M}) where {T,N,S,M}
    outer = ntuple(Val(N + M - 2)) do d
        d < N ? axes(A, d) : axes(B, d - N + 2)
    end
    # Some cases don't come here, so this doesn't really support OffsetArrays.
    return reshape(AB, outer)
end
# Matrix times vector-or-matrix already has the right shape: skip the final reshape.
_boxdot_reshape(AB::AbstractVecOrMat, A::AbstractMatrix, B::AbstractVecOrMat) = AB
# Two vectors are contracted completely, producing a scalar.
function boxdot(A::AbstractVector, B::AbstractVector)
    axA = axes(A, 1)
    axB = axes(B, 1)
    axA == axB || _throw_dmm(axA, axB)
    # For numeric entries `transpose(A) * B` is the plain, non-conjugating dot product.
    # `transpose` acts recursively on array-valued entries, so other element types are
    # instead multiplied pairwise and summed.
    eltype(A) <: Number && return transpose(A) * B
    return sum(a * b for (a, b) in zip(A, B))
end
# Contracting with a scalar: there is no shared dimension, so `⊡` is ordinary scaling.
# The scalar stays on the side it was given.
function boxdot(A::AbstractArray, b::Number)
    return A * b
end
function boxdot(a::Number, B::AbstractArray)
    return a * B
end
function boxdot(a::Number, b::Number)
    return a * b
end
using LinearAlgebra: AdjointAbsVec, TransposeAbsVec, AdjOrTransAbsVec
# Adjoint and Transpose vectors are treated as co-vectors. These methods keep the
# identity `(x ⊡ y)' == y' ⊡ x'` (and its `transpose` analogue) true.

# Co-vector with a vector or another co-vector: contracts completely (a scalar for numbers).
boxdot(A::AdjointAbsVec, B::AbstractVector) = A * B
boxdot(A::TransposeAbsVec, B::AbstractVector) = A * B
boxdot(A::AbstractVector, B::AdjointAbsVec) = boxdot(A, vec(B))
boxdot(A::AbstractVector, B::TransposeAbsVec) = boxdot(A, vec(B))
boxdot(A::AdjointAbsVec, B::AdjointAbsVec) = adjoint(boxdot(adjoint(B), adjoint(A)))
boxdot(A::AdjointAbsVec, B::TransposeAbsVec) = boxdot(vec(A), vec(B))
boxdot(A::TransposeAbsVec, B::AdjointAbsVec) = boxdot(vec(A), vec(B))
boxdot(A::TransposeAbsVec, B::TransposeAbsVec) = transpose(boxdot(transpose(B), transpose(A)))

# Co-vector with a matrix: the result is again an Adjoint/Transpose co-vector.
boxdot(A::AdjointAbsVec, B::AbstractMatrix) = A * B
boxdot(A::TransposeAbsVec, B::AbstractMatrix) = A * B
boxdot(A::AbstractMatrix, B::AdjointAbsVec) = adjoint(boxdot(adjoint(B), adjoint(A)))
boxdot(A::AbstractMatrix, B::TransposeAbsVec) = transpose(boxdot(transpose(B), transpose(A)))

# Co-vector with a higher-dimensional array: unwrap with `vec` and return a plain array.
boxdot(A::AdjointAbsVec, B::AbstractArray) = boxdot(vec(A), B)
boxdot(A::TransposeAbsVec, B::AbstractArray) = boxdot(vec(A), B)
boxdot(A::AbstractArray, B::AdjointAbsVec) = boxdot(A, vec(B))
boxdot(A::AbstractArray, B::TransposeAbsVec) = boxdot(A, vec(B))
"""
    boxdot!(Y, A, B, α=true, β=false)

In-place version of `boxdot`, i.e. `Y .= (A ⊡ B) .* α .+ Y .* β`, following the
convention of 5-argument `mul!(C, A, B, α, β)`. `Y` is reshaped to a matrix of size
`(prod(size(A)[1:end-1]), prod(size(B)[2:end]))` and filled by `mul!`, so it must have
that many elements. Returns `Y`.

Like 5-argument `mul!`, the use of `α, β` here requires Julia 1.3 or later.
"""
function boxdot! end
if VERSION < v"1.3" # Then 5-arg mul! isn't defined
    function boxdot!(Y::AbstractArray, A::AbstractArray, B::AbstractArray)
        szY = prod(size(A)[1:end-1]), prod(size(B)[2:end])
        mul!(reshape(Y, szY), _squash_left(A), _squash_right(B))
        return Y
    end
    # `_squash_right` leaves an Adjoint/Transpose vector as a 1×n matrix, so `mul!` would
    # not contract it as a co-vector; `vec(B)` matches what `boxdot` does for this case.
    boxdot!(Y::AbstractArray, A::AbstractArray, B::AdjOrTransAbsVec) = boxdot!(Y, A, vec(B))
else
    function boxdot!(Y::AbstractArray, A::AbstractArray, B::AbstractArray, α::Number=true, β::Number=false)
        szY = prod(size(A)[1:end-1]), prod(size(B)[2:end])
        mul!(reshape(Y, szY), _squash_left(A), _squash_right(B), α, β)
        return Y
    end
    # `_squash_right` leaves an Adjoint/Transpose vector as a 1×n matrix, so `mul!` would
    # not contract it as a co-vector; `vec(B)` matches what `boxdot` does for this case.
    # The scaling factors must be forwarded, otherwise they are silently reset to the
    # defaults (α=true, β=false).
    boxdot!(Y::AbstractArray, A::AbstractArray, B::AdjOrTransAbsVec,
            α::Number=true, β::Number=false) = boxdot!(Y, A, vec(B), α, β)
end
"""
    TensorCore._adjoint(A)

Extend `adjoint` to arrays of any dimension. The order of dimensions is always reversed
(via `PermutedDimsArray`), and each element is conjugated (numbers) or adjointed (other
element types). Vectors, matrices and non-array arguments are simply passed to `adjoint`.

On Julia 1.5 and later the symbol `'` can be overloaded locally as `var"'"`, as shown
below. Then `(x ⊡ y)' == y' ⊡ x'` holds for `x` and `y` arrays of any dimension.

# Examples
```jldoctest; setup=:(using TensorCore)
julia> T3 = rand(3,4,5); v = rand(5);

julia> size(T3 ⊡ v')
(3, 4)

julia> let var"'" = TensorCore._adjoint
           v ⊡ T3' ≈ (T3 ⊡ v')'
       end
true
```

`TensorCore._transpose` is the analogous extension of `transpose`, without conjugation.
"""
_adjoint(x) = adjoint(x)
_adjoint(x::AbstractVecOrMat) = adjoint(x)
# A numeric vector or matrix matches both the `AbstractVecOrMat` method above and the
# `AbstractArray{<:Number}` method below, and neither is more specific than the other.
# Without this method such calls (e.g. `_adjoint(rand(3))`) are ambiguous.
_adjoint(x::AbstractVecOrMat{<:Number}) = adjoint(x)
_adjoint(x::AbstractArray{T,N}) where {T<:Number,N} = conj(PermutedDimsArray(x, ntuple(i -> N-i+1, N)))
_adjoint(x::AbstractArray{T,N}) where {T,N} = adjoint.(PermutedDimsArray(x, ntuple(i -> N-i+1, N)))

_transpose(x) = transpose(x)
_transpose(x::AbstractVecOrMat) = transpose(x)
# Same disambiguation as for `_adjoint`: covers numeric vectors and matrices explicitly.
_transpose(x::AbstractVecOrMat{<:Number}) = transpose(x)
_transpose(x::AbstractArray{T,N}) where {T<:Number,N} = PermutedDimsArray(x, ntuple(i -> N-i+1, N))
_transpose(x::AbstractArray{T,N}) where {T,N} = transpose.(PermutedDimsArray(x, ntuple(i -> N-i+1, N)))
end