Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 36 additions & 7 deletions stdlib/LinearAlgebra/src/matmul.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

# Matrix-matrix multiplication

AdjOrTransStridedMat{T} = Union{Adjoint{T, <:StridedMatrix}, Transpose{T, <:StridedMatrix}}
StridedMaybeAdjOrTransMat{T} = Union{StridedMatrix{T}, Adjoint{T, <:StridedMatrix}, Transpose{T, <:StridedMatrix}}

# matmul.jl: Everything to do with dense matrix multiplication

matprod(x, y) = x*y + x*y
Expand Down Expand Up @@ -59,18 +64,26 @@ end
@inline mul!(y::StridedVector{T}, A::StridedVecOrMat{T}, x::StridedVector{T},
alpha::Number, beta::Number) where {T<:BlasFloat} =
gemv!(y, 'N', A, x, alpha, beta)

# Complex matrix times real vector. Reinterpret the matrix as a real matrix and do real matvec compuation.
for elty in (Float32,Float64)
for elty in (Float32, Float64)
@eval begin
@inline function mul!(y::StridedVector{Complex{$elty}}, A::StridedVecOrMat{Complex{$elty}}, x::StridedVector{$elty},
alpha::Real, beta::Real)
alpha::Real, beta::Real)
Afl = reinterpret($elty, A)
yfl = reinterpret($elty, y)
mul!(yfl, Afl, x, alpha, beta)
return y
end
end
end

# Real matrix times complex vector.
# Multiply the matrix with the real and imaginary parts separately
@inline mul!(y::StridedVector{Complex{T}}, A::StridedMaybeAdjOrTransMat{T}, x::StridedVector{Complex{T}},
alpha::Number, beta::Number) where {T<:BlasFloat} =
gemv!(y, A isa StridedArray ? 'N' : 'T', A isa StridedArray ? A : parent(A), x, alpha, beta)

@inline mul!(y::AbstractVector, A::AbstractVecOrMat, x::AbstractVector,
alpha::Number, beta::Number) =
generic_matvecmul!(y, 'N', A, x, MulAddMul(alpha, beta))
Expand Down Expand Up @@ -113,11 +126,6 @@ end
(*)(x::AdjointAbsVec, A::AbstractMatrix) = (A'*x')'
(*)(x::TransposeAbsVec, A::AbstractMatrix) = transpose(transpose(A)*transpose(x))

# Matrix-matrix multiplication

AdjOrTransStridedMat{T} = Union{Adjoint{T, <:StridedMatrix}, Transpose{T, <:StridedMatrix}}
StridedMaybeAdjOrTransMat{T} = Union{StridedMatrix{T}, Adjoint{T, <:StridedMatrix}, Transpose{T, <:StridedMatrix}}

_parent(A) = A
_parent(A::Adjoint) = parent(A)
_parent(A::Transpose) = parent(A)
Expand Down Expand Up @@ -525,6 +533,27 @@ function gemv!(y::StridedVector{T}, tA::AbstractChar, A::StridedVecOrMat{T}, x::
end
end

function gemv!(y::StridedVector{Complex{T}}, tA::AbstractChar, A::StridedVecOrMat{T}, x::StridedVector{Complex{T}},
α::Number = true, β::Number = false) where {T<:BlasFloat}
mA, nA = lapack_size(tA, A)
nA != length(x) &&
throw(DimensionMismatch(lazy"second dimension of A, $nA, does not match length of x, $(length(x))"))
mA != length(y) &&
throw(DimensionMismatch(lazy"first dimension of A, $mA, does not match length of y, $(length(y))"))
mA == 0 && return y
nA == 0 && return _rmul_or_fill!(y, β)
alpha, beta = promote(α, β, zero(T))
@views if alpha isa Union{Bool,T} && beta isa Union{Bool,T} && stride(A, 1) == 1 && stride(A, 2) >= size(A, 1)
xfl = reinterpret(reshape, T, x) # Use reshape here.
yfl = reinterpret(reshape, T, y)
BLAS.gemv!(tA, alpha, A, xfl[1, :], beta, yfl[1, :])
BLAS.gemv!(tA, alpha, A, xfl[2, :], beta, yfl[2, :])
return y
else
return generic_matvecmul!(y, tA, A, x, MulAddMul(α, β))
end
end

function syrk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat{T},
_add = MulAddMul()) where {T<:BlasFloat}
nC = checksquare(C)
Expand Down
Loading