Skip to content

Commit

Permalink
more
Browse files Browse the repository at this point in the history
  • Loading branch information
timholy committed Mar 19, 2020
1 parent e8b884c commit 80385cf
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 9 deletions.
6 changes: 4 additions & 2 deletions NEWS.md
Expand Up @@ -103,8 +103,10 @@ Standard library changes
* `normalize` now supports multidimensional arrays ([#34239])
* `lq` factorizations can now be used to compute the minimum-norm solution to under-determined systems ([#34350]).
* The BLAS submodule now supports the level-2 BLAS subroutine `spmv!` ([#34320]).
* `` (`\odotTAB`) can be used as an elementwise multiplication operator,
and `` (`\otimesTAB`) as the outer product operator ([#35150]).
* `.*`, meaning elementwise multiplication, can be passed as operator
to higher-order functions ([#35150]).
* `outermul`, `outermul!`, and `` (`\otimesTAB`) can be used to
compute the outer/tensor product ([#35150]).

#### Markdown

Expand Down
2 changes: 2 additions & 0 deletions stdlib/LinearAlgebra/docs/src/index.md
Expand Up @@ -321,6 +321,7 @@ LinearAlgebra.PosDefException
LinearAlgebra.ZeroPivotException
LinearAlgebra.dot
LinearAlgebra.cross
LinearAlgebra.outermul
LinearAlgebra.factorize
LinearAlgebra.Diagonal
LinearAlgebra.Bidiagonal
Expand Down Expand Up @@ -474,6 +475,7 @@ LinearAlgebra.lmul!
LinearAlgebra.rmul!
LinearAlgebra.ldiv!
LinearAlgebra.rdiv!
LinearAlgebra.outermul!
```

## BLAS functions
Expand Down
45 changes: 41 additions & 4 deletions stdlib/LinearAlgebra/src/LinearAlgebra.jl
Expand Up @@ -386,11 +386,48 @@ const ⋅ = dot
const × = cross
export , ×

# Allow passing ⋅, .*, and ⊗ as an operator to `map` and similar
# Allow passing .* as an operator to `map` and similar
const var".*" = (x...,) -> .*(x...,)
(a::AbstractVector, b::AbstractVector) = a * transpose(b)
(A::AbstractArray, B::AbstractArray) = A .* reshape(B, ntuple(_->Base.OneTo(1), ndims(A))..., axes(B)...)
export .*,

"""
outermul(A, B)
A ⊗ B
Compute the tensor product of `A` and `B` (also sometimes called "outer product").
If `C = A ⊗ B`, then `C[i1, ..., im, j1, ..., jn] = A[i1, ... im] * B[j1, ..., jn]`.
```jldoctest
julia> a = [2, 3]; b = [5, 7, 11];
julia> a ⊗ b
2×3 Array{$Int,2}:
10 14 22
15 21 33
```
See also: [`kron`](@ref).
"""
outermul(A::AbstractArray, B::AbstractArray) = A .* reshape(B, ntuple(_->Base.OneTo(1), ndims(A))..., axes(B)...)
const = outermul

"""
outermul!(dest, A, B)
Similar to `outermul(A, B)` (which can also be written `A ⊗ B`), but stores its results in the pre-allocated array `dest`.
"""
function outermul!(dest::AbstractArray, A::AbstractArray, B::AbstractArray)
axes(dest) == (axes(A)..., axes(B)...) ||
throw(DimensionMismatch("`axes(dest) = $(axes(dest))` must concatenate `axes(A) = $(axes(A))` and `axes(B) = $(axes(B))`"))
for IB in CartesianIndices(B)
b = B[IB]
@simd for IA in CartesianIndices(A)
@inbounds dest[IA,IB] = A[IA]*b
end
end
return dest
end

export .*, , outermul, outermul!

"""
LinearAlgebra.peakflops(n::Integer=2000; parallel::Bool=false)
Expand Down
6 changes: 3 additions & 3 deletions stdlib/LinearAlgebra/src/dense.jl
Expand Up @@ -341,9 +341,9 @@ end
Kronecker tensor product of two vectors or two matrices.
For vectors v and w, the Kronecker product is related to the outer product by
`kron(v,w) == vec(w*transpose(v))` or
`w*transpose(v) == reshape(kron(v,w), (length(w), length(v)))`.
For vectors v and w, the Kronecker product is related to the outer product [`outermul`](@ref), or `⊗`, by
`kron(v,w) == vec(w ⊗ v)` or
`w ⊗ v == reshape(kron(v,w), (length(w), length(v)))`.
Note how the ordering of `v` and `w` differs on the left and right
of these expressions (due to column-major storage).
Expand Down
3 changes: 3 additions & 0 deletions stdlib/LinearAlgebra/test/addmul.jl
Expand Up @@ -142,6 +142,9 @@ end
u, v = [2+2im, 3+5im], [1-3im, 7+3im]
@test u v == conj(u[1])*v[1] + conj(u[2])*v[2]
@test u v == [u[1]*v[1] u[1]*v[2]; u[2]*v[1] u[2]*v[2]]
@test outermul(u, v) == u v
dest = Matrix{Complex{Int}}(undef, 2, 2)
@test outermul!(dest, u, v) == u v
for (A, B, b) in (([1 2; 3 4], [5 6; 7 8], [5,6]),
([1+0.8im 2+0.7im; 3+0.6im 4+0.5im],
[5+0.4im 6+0.3im; 7+0.2im 8+0.1im],
Expand Down

0 comments on commit 80385cf

Please sign in to comment.