Skip to content

Commit

Permalink
re-allow vector*(1-row matrix) and transpose thereof (#20423)
Browse files Browse the repository at this point in the history
* re-allow vector*(1-row matrix) and transpose (closes #20389)
  • Loading branch information
stevengj committed Feb 5, 2017
1 parent 7aab89d commit e3bb870
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 39 deletions.
7 changes: 0 additions & 7 deletions base/linalg/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -256,16 +256,9 @@ end

# Methods to resolve ambiguities with `Diagonal`
@inline *(rowvec::RowVector, D::Diagonal) = transpose(D * transpose(rowvec))
*(::Diagonal, ::RowVector) = throw(DimensionMismatch("Cannot right-multiply matrix by transposed vector"))

@inline A_mul_Bt(D::Diagonal, rowvec::RowVector) = D*transpose(rowvec)

At_mul_B(rowvec::RowVector, ::Diagonal) = throw(DimensionMismatch("Cannot left-multiply matrix by vector"))

@inline A_mul_Bc(D::Diagonal, rowvec::RowVector) = D*ctranspose(rowvec)

Ac_mul_B(rowvec::RowVector, ::Diagonal) = throw(DimensionMismatch("Cannot left-multiply matrix by vector"))

conj(D::Diagonal) = Diagonal(conj(D.diag))
transpose(D::Diagonal) = D
ctranspose(D::Diagonal) = conj(D)
Expand Down
7 changes: 7 additions & 0 deletions base/linalg/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,13 @@ function (*){T,S}(A::AbstractMatrix{T}, x::AbstractVector{S})
A_mul_B!(similar(x,TS,size(A,1)),A,x)
end

# these will throw a DimensionMismatch unless B has 1 row (or 1 col for transposed case):
A_mul_Bt(a::AbstractVector, B::AbstractMatrix) = A_mul_Bt(reshape(a,length(a),1),B)
A_mul_Bt(A::AbstractMatrix, b::AbstractVector) = A_mul_Bt(A,reshape(b,length(b),1))
A_mul_Bc(a::AbstractVector, B::AbstractMatrix) = A_mul_Bc(reshape(a,length(a),1),B)
A_mul_Bc(A::AbstractMatrix, b::AbstractVector) = A_mul_Bc(A,reshape(b,length(b),1))
(*)(a::AbstractVector, B::AbstractMatrix) = reshape(a,length(a),1)*B

A_mul_B!{T<:BlasFloat}(y::StridedVector{T}, A::StridedVecOrMat{T}, x::StridedVector{T}) = gemv!(y, 'N', A, x)
for elty in (Float32,Float64)
@eval begin
Expand Down
15 changes: 0 additions & 15 deletions base/linalg/rowvector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,26 +154,19 @@ end
sum(@inbounds(return rowvec[i]*vec[i]) for i = 1:length(vec))
end
@inline *(rowvec::RowVector, mat::AbstractMatrix) = transpose(mat.' * transpose(rowvec))
*(vec::AbstractVector, mat::AbstractMatrix) = throw(DimensionMismatch(
"Cannot left-multiply a matrix by a vector")) # Should become a deprecation
*(::RowVector, ::RowVector) = throw(DimensionMismatch("Cannot multiply two transposed vectors"))
@inline *(vec::AbstractVector, rowvec::RowVector) = vec .* rowvec
*(vec::AbstractVector, rowvec::AbstractVector) = throw(DimensionMismatch("Cannot multiply two vectors"))
*(mat::AbstractMatrix, rowvec::RowVector) = throw(DimensionMismatch("Cannot right-multiply matrix by transposed vector"))

# Transposed forms
A_mul_Bt(::RowVector, ::AbstractVector) = throw(DimensionMismatch("Cannot multiply two transposed vectors"))
@inline A_mul_Bt(rowvec::RowVector, mat::AbstractMatrix) = transpose(mat * transpose(rowvec))
A_mul_Bt(vec::AbstractVector, mat::AbstractMatrix) = throw(DimensionMismatch(
"Cannot left-multiply a matrix by a vector"))
@inline A_mul_Bt(rowvec1::RowVector, rowvec2::RowVector) = rowvec1*transpose(rowvec2)
A_mul_Bt(vec::AbstractVector, rowvec::RowVector) = throw(DimensionMismatch("Cannot multiply two vectors"))
@inline A_mul_Bt(vec1::AbstractVector, vec2::AbstractVector) = vec1 * transpose(vec2)
@inline A_mul_Bt(mat::AbstractMatrix, rowvec::RowVector) = mat * transpose(rowvec)

@inline At_mul_Bt(rowvec::RowVector, vec::AbstractVector) = transpose(rowvec) * transpose(vec)
At_mul_Bt(rowvec::RowVector, mat::AbstractMatrix) = throw(DimensionMismatch(
"Cannot left-multiply matrix by vector"))
@inline At_mul_Bt(vec::AbstractVector, mat::AbstractMatrix) = transpose(mat * vec)
At_mul_Bt(rowvec1::RowVector, rowvec2::RowVector) = throw(DimensionMismatch("Cannot multiply two vectors"))
@inline At_mul_Bt(vec::AbstractVector, rowvec::RowVector) = transpose(vec)*transpose(rowvec)
Expand All @@ -182,42 +175,34 @@ At_mul_Bt(vec::AbstractVector, rowvec::AbstractVector) = throw(DimensionMismatch
@inline At_mul_Bt(mat::AbstractMatrix, rowvec::RowVector) = mat.' * transpose(rowvec)

At_mul_B(::RowVector, ::AbstractVector) = throw(DimensionMismatch("Cannot multiply two vectors"))
At_mul_B(rowvec::RowVector, mat::AbstractMatrix) = throw(DimensionMismatch(
"Cannot left-multiply matrix by vector"))
@inline At_mul_B(vec::AbstractVector, mat::AbstractMatrix) = transpose(At_mul_B(mat,vec))
@inline At_mul_B(rowvec1::RowVector, rowvec2::RowVector) = transpose(rowvec1) * rowvec2
At_mul_B(vec::AbstractVector, rowvec::RowVector) = throw(DimensionMismatch(
"Cannot multiply two transposed vectors"))
@inline At_mul_B{T<:Real}(vec1::AbstractVector{T}, vec2::AbstractVector{T}) =
reduce(+, map(At_mul_B, vec1, vec2)) # Seems to be overloaded...
@inline At_mul_B(vec1::AbstractVector, vec2::AbstractVector) = transpose(vec1) * vec2
At_mul_B(mat::AbstractMatrix, rowvec::RowVector) = throw(DimensionMismatch(
"Cannot right-multiply matrix by transposed vector"))

# Conjugated forms
A_mul_Bc(::RowVector, ::AbstractVector) = throw(DimensionMismatch("Cannot multiply two transposed vectors"))
@inline A_mul_Bc(rowvec::RowVector, mat::AbstractMatrix) = ctranspose(mat * ctranspose(rowvec))
A_mul_Bc(vec::AbstractVector, mat::AbstractMatrix) = throw(DimensionMismatch("Cannot left-multiply a matrix by a vector"))
@inline A_mul_Bc(rowvec1::RowVector, rowvec2::RowVector) = rowvec1 * ctranspose(rowvec2)
A_mul_Bc(vec::AbstractVector, rowvec::RowVector) = throw(DimensionMismatch("Cannot multiply two vectors"))
@inline A_mul_Bc(vec1::AbstractVector, vec2::AbstractVector) = vec1 * ctranspose(vec2)
@inline A_mul_Bc(mat::AbstractMatrix, rowvec::RowVector) = mat * ctranspose(rowvec)

@inline Ac_mul_Bc(rowvec::RowVector, vec::AbstractVector) = ctranspose(rowvec) * ctranspose(vec)
Ac_mul_Bc(rowvec::RowVector, mat::AbstractMatrix) = throw(DimensionMismatch("Cannot left-multiply matrix by vector"))
@inline Ac_mul_Bc(vec::AbstractVector, mat::AbstractMatrix) = ctranspose(mat * vec)
Ac_mul_Bc(rowvec1::RowVector, rowvec2::RowVector) = throw(DimensionMismatch("Cannot multiply two vectors"))
@inline Ac_mul_Bc(vec::AbstractVector, rowvec::RowVector) = ctranspose(vec)*ctranspose(rowvec)
Ac_mul_Bc(vec::AbstractVector, rowvec::AbstractVector) = throw(DimensionMismatch("Cannot multiply two transposed vectors"))
@inline Ac_mul_Bc(mat::AbstractMatrix, rowvec::RowVector) = mat' * ctranspose(rowvec)

Ac_mul_B(::RowVector, ::AbstractVector) = throw(DimensionMismatch("Cannot multiply two vectors"))
Ac_mul_B(rowvec::RowVector, mat::AbstractMatrix) = throw(DimensionMismatch("Cannot left-multiply matrix by vector"))
@inline Ac_mul_B(vec::AbstractVector, mat::AbstractMatrix) = ctranspose(Ac_mul_B(mat,vec))
@inline Ac_mul_B(rowvec1::RowVector, rowvec2::RowVector) = ctranspose(rowvec1) * rowvec2
Ac_mul_B(vec::AbstractVector, rowvec::RowVector) = throw(DimensionMismatch("Cannot multiply two transposed vectors"))
@inline Ac_mul_B(vec1::AbstractVector, vec2::AbstractVector) = ctranspose(vec1)*vec2
Ac_mul_B(mat::AbstractMatrix, rowvec::RowVector) = throw(DimensionMismatch("Cannot right-multiply matrix by transposed vector"))

# Left Division #

Expand Down
15 changes: 1 addition & 14 deletions base/linalg/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1585,7 +1585,7 @@ for (f, g) in ((:\, :A_ldiv_B!), (:Ac_ldiv_B, :Ac_ldiv_B!), (:At_ldiv_B, :At_ldi
end
### Multiplication with triangle to the rigth and hence lhs cannot be transposed.
for (f, g) in ((:*, :A_mul_B!), (:A_mul_Bc, :A_mul_Bc!), (:A_mul_Bt, :A_mul_Bt!))
@eval begin
mat != :AbstractVector && @eval begin
function ($f)(A::$mat, B::AbstractTriangular)
TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B)))
AA = similar(A, TAB, size(A))
Expand Down Expand Up @@ -1637,25 +1637,12 @@ At_mul_Bt(A::AbstractMatrix, B::AbstractTriangular) = A_mul_Bt(A.', B)

# Specializations for RowVector
@inline *(rowvec::RowVector, A::AbstractTriangular) = transpose(A * transpose(rowvec))
*(::AbstractTriangular, ::RowVector) = throw(DimensionMismatch("Cannot right-multiply matrix by transposed vector"))

@inline A_mul_Bt(rowvec::RowVector, A::AbstractTriangular) = transpose(A * transpose(rowvec))
@inline A_mul_Bt(A::AbstractTriangular, rowvec::RowVector) = A * transpose(rowvec)

@inline At_mul_Bt(A::AbstractTriangular, rowvec::RowVector) = A.' * transpose(rowvec)
At_mul_Bt(::RowVector, ::AbstractTriangular) = throw(DimensionMismatch("Cannot left-multiply matrix by vector"))

At_mul_B(::AbstractTriangular, ::RowVector) = throw(DimensionMismatch("Cannot right-multiply matrix by transposed vector"))
At_mul_B(::RowVector, ::AbstractTriangular) = throw(DimensionMismatch("Cannot left-multiply matrix by vector"))

@inline A_mul_Bc(rowvec::RowVector, A::AbstractTriangular) = ctranspose(A * ctranspose(rowvec))
@inline A_mul_Bc(A::AbstractTriangular, rowvec::RowVector) = A * ctranspose(rowvec)

@inline Ac_mul_Bc(A::AbstractTriangular, rowvec::RowVector) = A' * ctranspose(rowvec)
Ac_mul_Bc(::RowVector, ::AbstractTriangular) = throw(DimensionMismatch("Cannot left-multiply matrix by vector"))

Ac_mul_B(::AbstractTriangular, ::RowVector) = throw(DimensionMismatch("Cannot right-multiply matrix by transposed vector"))
Ac_mul_B(::RowVector, ::AbstractTriangular) = throw(DimensionMismatch("Cannot left-multiply matrix by vector"))

@inline /(rowvec::RowVector, A::Union{UpperTriangular,LowerTriangular}) = transpose(transpose(A) \ transpose(rowvec))
@inline /(rowvec::RowVector, A::Union{UnitUpperTriangular,UnitLowerTriangular}) = transpose(transpose(A) \ transpose(rowvec))
Expand Down
19 changes: 16 additions & 3 deletions test/linalg/rowvector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,15 +104,15 @@ end

@test (rv*v) === 14
@test (rv*mat)::RowVector == [1 4 9]
@test_throws DimensionMismatch [1]*reshape([1],(1,1)) # no longer permitted
@test [1]*reshape([1],(1,1)) == reshape([1],(1,1))
@test_throws DimensionMismatch rv*rv
@test (v*rv)::Matrix == [1 2 3; 2 4 6; 3 6 9]
@test_throws DimensionMismatch v*v # Was previously a missing method error, now an error message
@test_throws DimensionMismatch mat*rv

@test_throws DimensionMismatch rv*v.'
@test (rv*mat.')::RowVector == [1 4 9]
@test_throws DimensionMismatch [1]*reshape([1],(1,1)).' # no longer permitted
@test [1]*reshape([1],(1,1)).' == reshape([1],(1,1))
@test rv*rv.' === 14
@test_throws DimensionMismatch v*rv.'
@test (v*v.')::Matrix == [1 2 3; 2 4 6; 3 6 9]
Expand Down Expand Up @@ -142,7 +142,7 @@ end

@test_throws DimensionMismatch cz*z'
@test (cz*mat')::RowVector == [-2im 4 9]
@test_throws DimensionMismatch [1]*reshape([1],(1,1))' # no longer permitted
@test [1]*reshape([1],(1,1))' == reshape([1],(1,1))
@test cz*cz' === 15 + 0im
@test_throws DimensionMismatch z*cz'
@test (z*z')::Matrix == [2 2+2im 3+3im; 2-2im 4 6; 3-3im 6 9]
Expand Down Expand Up @@ -238,5 +238,18 @@ end
@test_throws DimensionMismatch ut\rv
end

# issue #20389
@testset "1 row/col vec*mat" begin
let x=[1,2,3], A=ones(1,4), y=x', B=A', C=x.*A
@test x*A == y'*A == x*B' == y'*B' == C
@test A'*x' == A'*y == B*x' == B*y == C'
end
end
@testset "complex 1 row/col vec*mat" begin
let x=[1,2,3]*im, A=ones(1,4)*im, y=x', B=A', C=x.*A
@test x*A == y'*A == x*B' == y'*B' == C
@test A'*x' == A'*y == B*x' == B*y == C'
end
end

end # @testset "RowVector"

0 comments on commit e3bb870

Please sign in to comment.