Skip to content

Commit

Permalink
Support and as elementwise- and tensor-product operators (#35150
Browse files Browse the repository at this point in the history
)

While we have broadcasting and `a*b'`, sometimes you need to pass
an operator as an argument to a function. Since we already have `dot` or `⋅`
for the inner product, these elementwise and tensor products fill out
the space of possibilities.
  • Loading branch information
timholy committed May 2, 2020
1 parent 9507225 commit f36036c
Show file tree
Hide file tree
Showing 5 changed files with 206 additions and 3 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Expand Up @@ -169,6 +169,8 @@ Standard library changes
* The BLAS submodule now supports the level-2 BLAS subroutine `spmv!` ([#34320]).
* The BLAS submodule now supports the level-1 BLAS subroutine `rot!` ([#35124]).
* New generic `rotate!(x, y, c, s)` and `reflect!(x, y, c, s)` functions ([#35124]).
* `hadamard` or `` (`\odotTAB`) can be used as an elementwise multiplication operator,
and `tensor` or `` (`\otimesTAB`) as the tensor product operator ([#35150]).

#### Markdown

Expand Down
4 changes: 4 additions & 0 deletions stdlib/LinearAlgebra/docs/src/index.md
Expand Up @@ -321,6 +321,8 @@ LinearAlgebra.PosDefException
LinearAlgebra.ZeroPivotException
LinearAlgebra.dot
LinearAlgebra.cross
LinearAlgebra.hadamard
LinearAlgebra.tensor
LinearAlgebra.factorize
LinearAlgebra.Diagonal
LinearAlgebra.Bidiagonal
Expand Down Expand Up @@ -474,6 +476,8 @@ LinearAlgebra.lmul!
LinearAlgebra.rmul!
LinearAlgebra.ldiv!
LinearAlgebra.rdiv!
LinearAlgebra.hadamard!
LinearAlgebra.tensor!
```

## BLAS functions
Expand Down
137 changes: 137 additions & 0 deletions stdlib/LinearAlgebra/src/LinearAlgebra.jl
Expand Up @@ -388,6 +388,143 @@ const ⋅ = dot
const × = cross
export , ×

"""
hadamard(a, b)
a ⊙ b
For arrays `a` and `b`, perform elementwise multiplication.
`a` and `b` must have identical `axes`.
`⊙` can be passed as an operator to higher-order functions.
```jldoctest
julia> a = [2, 3]; b = [5, 7];
julia> a ⊙ b
2-element Array{$Int,1}:
10
21
julia> a ⊙ [5]
ERROR: DimensionMismatch("Axes of `A` and `B` must match, got (Base.OneTo(2),) and (Base.OneTo(1),)")
[...]
```
!!! compat "Julia 1.5"
This function requires at least Julia 1.5. In Julia 1.0-1.4 it is available from
the `Compat` package.
"""
function hadamard(A::AbstractArray, B::AbstractArray)
@noinline throw_dmm(axA, axB) = throw(DimensionMismatch("Axes of `A` and `B` must match, got $axA and $axB"))

axA, axB = axes(A), axes(B)
axA == axB || throw_dmm(axA, axB)
return map(*, A, B)
end
const = hadamard

"""
hadamard!(dest, A, B)
Similar to `hadamard(A, B)` (which can also be written `A ⊙ B`), but stores its results in
the pre-allocated array `dest`.
!!! compat "Julia 1.5"
This function requires at least Julia 1.5. In Julia 1.0-1.4 it is available from
the `Compat` package.
"""
function hadamard!(dest::AbstractArray, A::AbstractArray, B::AbstractArray)
@noinline function throw_dmm(axA, axB, axdest)
throw(DimensionMismatch("`axes(dest) = $axdest` must be equal to `axes(A) = $axA` and `axes(B) = $axB`"))
end

axA, axB, axdest = axes(A), axes(B), axes(dest)
((axdest == axA) & (axdest == axB)) || throw_dmm(axA, axB, axdest)
@simd for I in eachindex(dest, A, B)
@inbounds dest[I] = A[I] * B[I]
end
return dest
end

export , hadamard, hadamard!

"""
tensor(A, B)
A ⊗ B
Compute the tensor product of `A` and `B`.
If `C = A ⊗ B`, then `C[i1, ..., im, j1, ..., jn] = A[i1, ... im] * B[j1, ..., jn]`.
```jldoctest
julia> a = [2, 3]; b = [5, 7, 11];
julia> a ⊗ b
2×3 Array{$Int,2}:
10 14 22
15 21 33
```
See also: [`kron`](@ref).
!!! compat "Julia 1.5"
This function requires at least Julia 1.5. In Julia 1.0-1.4 it is available from
the `Compat` package.
"""
tensor(A::AbstractArray, B::AbstractArray) = [a*b for a in A, b in B]
const = tensor

const CovectorLike{T} = Union{Adjoint{T,<:AbstractVector},Transpose{T,<:AbstractVector}}
function tensor(u::AbstractArray, v::CovectorLike)
# If `v` is thought of as a covector, you might want this to be two-dimensional,
# but thought of as a matrix it should be three-dimensional.
# The safest is to avoid supporting it at all. See discussion in #35150.
error("`tensor` is not defined for co-vectors, perhaps you meant `*`?")
end
function tensor(u::CovectorLike, v::AbstractArray)
error("`tensor` is not defined for co-vectors, perhaps you meant `*`?")
end
function tensor(u::CovectorLike, v::CovectorLike)
error("`tensor` is not defined for co-vectors, perhaps you meant `*`?")
end

"""
tensor!(dest, A, B)
Similar to `tensor(A, B)` (which can also be written `A ⊗ B`), but stores its results in
the pre-allocated array `dest`.
!!! compat "Julia 1.5"
This function requires at least Julia 1.5. In Julia 1.0-1.4 it is available from
the `Compat` package.
"""
function tensor!(dest::AbstractArray, A::AbstractArray, B::AbstractArray)
@noinline function throw_dmm(axA, axB, axdest)
throw(DimensionMismatch("`axes(dest) = $axdest` must concatenate `axes(A) = $axA` and `axes(B) = $axB`"))
end

axA, axB, axdest = axes(A), axes(B), axes(dest)
axes(dest) == (axA..., axB...) || throw_dmm(axA, axB, axdest)
if IndexStyle(dest) === IndexCartesian()
for IB in CartesianIndices(axB)
@inbounds b = B[IB]
@simd for IA in CartesianIndices(axA)
@inbounds dest[IA,IB] = A[IA]*b
end
end
else
i = firstindex(dest)
@inbounds for b in B
@simd for a in A
dest[i] = a*b
i += 1
end
end
end
return dest
end

export , tensor, tensor!

"""
LinearAlgebra.peakflops(n::Integer=2000; parallel::Bool=false)
Expand Down
6 changes: 3 additions & 3 deletions stdlib/LinearAlgebra/src/dense.jl
Expand Up @@ -341,9 +341,9 @@ end
Kronecker tensor product of two vectors or two matrices.
For vectors v and w, the Kronecker product is related to the outer product by
`kron(v,w) == vec(w*transpose(v))` or
`w*transpose(v) == reshape(kron(v,w), (length(w), length(v)))`.
For vectors v and w, the Kronecker product is related to the tensor product [`tensor`](@ref), or `⊗`, by
`kron(v,w) == vec(w ⊗ v)` or
`w ⊗ v == reshape(kron(v,w), (length(w), length(v)))`.
Note how the ordering of `v` and `w` differs on the left and right
of these expressions (due to column-major storage).
Expand Down
60 changes: 60 additions & 0 deletions stdlib/LinearAlgebra/test/addmul.jl
Expand Up @@ -131,6 +131,66 @@ for cmat in mattypes,
push!(testdata, (cmat{celt}, amat{aelt}, bmat{belt}))
end

@testset "Alternative multiplication operators" begin
for T in (Int, Float32, Float64, BigFloat)
a = [T[1, 2], T[-3, 7]]
b = [T[5, 11], T[-13, 17]]
@test map(, a, b) == map(dot, a, b) == [27, 158]
@test map(, a, b) == map(hadamard, a, b) == [a[1].*b[1], a[2].*b[2]]
@test map(, a, b) == map(tensor, a, b) == [a[1]*transpose(b[1]), a[2]*transpose(b[2])]
@test hadamard!(fill(typemax(Int), 2), T[1, 2], T[-3, 7]) == [-3, 14]
@test tensor!(fill(typemax(Int), 2, 2), T[1, 2], T[-3, 7]) == [-3 7; -6 14]
end

@test_throws DimensionMismatch [1,2] [3]
@test_throws DimensionMismatch hadamard!([0, 0, 0], [1,2], [-3,7])
@test_throws DimensionMismatch hadamard!([0, 0], [1,2], [-3])
@test_throws DimensionMismatch hadamard!([0, 0], [1], [-3,7])
@test_throws DimensionMismatch tensor!(Matrix{Int}(undef, 2, 2), [1], [-3,7])
@test_throws DimensionMismatch tensor!(Matrix{Int}(undef, 2, 2), [1,2], [-3])

u, v = [2+2im, 3+5im], [1-3im, 7+3im]
@test u v == conj(u[1])*v[1] + conj(u[2])*v[2]
@test u v == [u[1]*v[1], u[2]*v[2]]
@test u v == [u[1]*v[1] u[1]*v[2]; u[2]*v[1] u[2]*v[2]]
@test hadamard(u, v) == u v
@test tensor(u, v) == u v
dest = similar(u)
@test hadamard!(dest, u, v) == u v
dest = Matrix{Complex{Int}}(undef, 2, 2)
@test tensor!(dest, u, v) == u v

for (A, B, b) in (([1 2; 3 4], [5 6; 7 8], [5,6]),
([1+0.8im 2+0.7im; 3+0.6im 4+0.5im],
[5+0.4im 6+0.3im; 7+0.2im 8+0.1im],
[5+0.6im,6+0.3im]))
@test A b == cat(A*b[1], A*b[2]; dims=3)
@test A B == cat(cat(A*B[1,1], A*B[2,1]; dims=3),
cat(A*B[1,2], A*B[2,2]; dims=3); dims=4)
end

A, B = reshape(1:27, 3, 3, 3), reshape(1:4, 2, 2)
@test A B == [a*b for a in A, b in B]

# Adjoint/transpose is a dual vector, not an AbstractMatrix
v = [1,2]
@test_throws ErrorException v v'
@test_throws ErrorException v transpose(v)
@test_throws ErrorException v' v
@test_throws ErrorException transpose(v) v
@test_throws ErrorException v' v'
@test_throws ErrorException transpose(v) transpose(v)
@test_throws ErrorException v' transpose(v)
@test_throws ErrorException transpose(v) v'
@test_throws ErrorException A v'
@test_throws ErrorException A transpose(v)

# Docs comparison to `kron`
v, w = [1,2,3], [5,7]
@test kron(v,w) == vec(w v)
@test w v == reshape(kron(v,w), (length(w), length(v)))
end

@testset "mul!(::$TC, ::$TA, ::$TB, α, β)" for (TC, TA, TB) in testdata
if needsquare(TA)
na1 = na2 = rand(sizecandidates)
Expand Down

0 comments on commit f36036c

Please sign in to comment.