Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

check sizes of arguments in dot; fixes #28617 #28666

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 3 additions & 1 deletion stdlib/LinearAlgebra/src/bitarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

function dot(x::BitVector, y::BitVector)
# simplest way to mimic Array dot behavior
length(x) == length(y) || throw(DimensionMismatch())
if axes(x) != axes(y)
throw(DimensionMismatch("The first array has axes $(axes(x)) that do not match the axes of the second, $(axes(y)). You might want to use `dot(vec(x), vec(y))` if `length(x) == length(y)`."))
end
s = 0
xc = x.chunks
yc = y.chunks
Expand Down
56 changes: 6 additions & 50 deletions stdlib/LinearAlgebra/src/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -353,71 +353,27 @@ for (fname, elty) in ((:cblas_zdotu_sub,:ComplexF64),
end
end

@inline function _dot_length_check(x,y)
n = length(x)
if n != length(y)
throw(DimensionMismatch("dot product arguments have lengths $(length(x)) and $(length(y))"))
end
n
end

for (elty, f) in ((Float32, :dot), (Float64, :dot),
(ComplexF32, :dotc), (ComplexF64, :dotc),
(ComplexF32, :dotu), (ComplexF64, :dotu))
@eval begin
function $f(x::DenseArray{$elty}, y::DenseArray{$elty})
n = _dot_length_check(x,y)
$f(n, x, 1, y, 1)
end

function $f(x::StridedVector{$elty}, y::DenseArray{$elty})
n = _dot_length_check(x,y)
xstride = stride(x,1)
ystride = stride(y,1)
x_delta = xstride < 0 ? n : 1
GC.@preserve x $f(n,pointer(x,x_delta),xstride,y,ystride)
end

function $f(x::DenseArray{$elty}, y::StridedVector{$elty})
n = _dot_length_check(x,y)
xstride = stride(x,1)
ystride = stride(y,1)
y_delta = ystride < 0 ? n : 1
GC.@preserve y $f(n,x,xstride,pointer(y,y_delta),ystride)
end

function $f(x::StridedVector{$elty}, y::StridedVector{$elty})
n = _dot_length_check(x,y)
xstride = stride(x,1)
ystride = stride(y,1)
x_delta = xstride < 0 ? n : 1
y_delta = ystride < 0 ? n : 1
GC.@preserve x y $f(n,pointer(x,x_delta),xstride,pointer(y,y_delta),ystride)
end
end
end

function dot(DX::Union{DenseArray{T},AbstractVector{T}}, DY::Union{DenseArray{T},AbstractVector{T}}) where T<:BlasReal
require_one_based_indexing(DX, DY)
n = length(DX)
if n != length(DY)
throw(DimensionMismatch("dot product arguments have lengths $(length(DX)) and $(length(DY))"))
if axes(DX) != axes(DY)
throw(DimensionMismatch("The first array has axes $(axes(DX)) that do not match the axes of the second, $(axes(DY)). You might want to use `dot(vec(x), vec(y))` if `length(x) == length(y)`."))
end
return dot(n, DX, stride(DX, 1), DY, stride(DY, 1))
end
function dotc(DX::Union{DenseArray{T},AbstractVector{T}}, DY::Union{DenseArray{T},AbstractVector{T}}) where T<:BlasComplex
require_one_based_indexing(DX, DY)
n = length(DX)
if n != length(DY)
throw(DimensionMismatch("dot product arguments have lengths $(length(DX)) and $(length(DY))"))
if axes(DX) != axes(DY)
throw(DimensionMismatch("The first array has axes $(axes(DX)) that do not match the axes of the second, $(axes(DY)). You might want to use `dot(vec(x), vec(y))` if `length(x) == length(y)`."))
end
return dotc(n, DX, stride(DX, 1), DY, stride(DY, 1))
end
function dotu(DX::Union{DenseArray{T},AbstractVector{T}}, DY::Union{DenseArray{T},AbstractVector{T}}) where T<:BlasComplex
require_one_based_indexing(DX, DY)
n = length(DX)
if n != length(DY)
throw(DimensionMismatch("dot product arguments have lengths $(length(DX)) and $(length(DY))"))
if axes(DX) != axes(DY)
throw(DimensionMismatch("The first array has axes $(axes(DX)) that do not match the axes of the second, $(axes(DY)). You might want to use `dot(vec(x), vec(y))` if `length(x) == length(y)`."))
end
return dotu(n, DX, stride(DX, 1), DY, stride(DY, 1))
end
Expand Down
26 changes: 22 additions & 4 deletions stdlib/LinearAlgebra/src/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -904,12 +904,30 @@ end

dot(x::Number, y::Number) = conj(x) * y

"""
dot(x, y)
x ⋅ y

Compute the dot product between two arrays with the same [`axes`](@ref) as if they
were vectors. For complex arrays, the elements of the first array are conjugated.
This is the classical dot product for vectors and the Hilbert-Schmidt dot
product `tr(x' * y)` for matrices. When the arrays have equal axes, calling
`dot` is semantically equivalent to `sum(dot(vx,vy) for (vx,vy) in zip(x, y))`.

# Examples
```jldoctest
julia> dot([1; 1], [2; 3])
5

julia> dot([im; im], [1; 1])
0 - 2im
```
"""
function dot(x::AbstractArray, y::AbstractArray)
lx = length(x)
if lx != length(y)
throw(DimensionMismatch("first array has length $(lx) which does not match the length of the second, $(length(y))."))
if axes(x) != axes(y)
throw(DimensionMismatch("The first array has axes $(axes(x)) that do not match the axes of the second, $(axes(y)). You might want to use `dot(vec(x), vec(y))` if `length(x) == length(y)`."))
end
if lx == 0
if length(x) == 0
return dot(zero(eltype(x)), zero(eltype(y)))
end
s = zero(dot(first(x), first(y)))
Expand Down
3 changes: 0 additions & 3 deletions stdlib/LinearAlgebra/src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,6 @@ matprod(x, y) = x*y + x*y

# dot products

dot(x::Union{DenseArray{T},StridedVector{T}}, y::Union{DenseArray{T},StridedVector{T}}) where {T<:BlasReal} = BLAS.dot(x, y)
dot(x::Union{DenseArray{T},StridedVector{T}}, y::Union{DenseArray{T},StridedVector{T}}) where {T<:BlasComplex} = BLAS.dotc(x, y)

function dot(x::Vector{T}, rx::AbstractRange{TI}, y::Vector{T}, ry::AbstractRange{TI}) where {T<:BlasReal,TI<:Integer}
if length(rx) != length(ry)
throw(DimensionMismatch("length of rx, $(length(rx)), does not equal length of ry, $(length(ry))"))
Expand Down
10 changes: 10 additions & 0 deletions stdlib/LinearAlgebra/test/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,23 @@ Random.seed!(100)
x2 = convert(Vector{elty}, randn(n))
@test BLAS.dot(x1,x2) ≈ sum(x1.*x2)
@test_throws DimensionMismatch BLAS.dot(x1,rand(elty, n + 1))
y1 = convert(Matrix{elty}, randn(4,4))
y2 = convert(Matrix{elty}, randn(2,8))
@test_throws DimensionMismatch BLAS.dot(y1, y2)
@test sum(y1[i] * y2[i] for i in 1:16) ≈ BLAS.dot(vec(y1), vec(y2))
else
z1 = convert(Vector{elty}, complex.(randn(n),randn(n)))
z2 = convert(Vector{elty}, complex.(randn(n),randn(n)))
@test BLAS.dotc(z1,z2) ≈ sum(conj(z1).*z2)
@test BLAS.dotu(z1,z2) ≈ sum(z1.*z2)
@test_throws DimensionMismatch BLAS.dotc(z1,rand(elty, n + 1))
@test_throws DimensionMismatch BLAS.dotu(z1,rand(elty, n + 1))
y1 = convert(Matrix{elty}, complex.(randn(4,4),randn(4,4)))
y2 = convert(Matrix{elty}, complex.(randn(2,8),randn(2,8)))
@test_throws DimensionMismatch BLAS.dotc(y1, y2)
@test_throws DimensionMismatch BLAS.dotu(y1, y2)
@test sum(conj(y1[i]) * y2[i] for i in 1:16) ≈ BLAS.dotc(vec(y1), vec(y2))
@test sum(y1[i] * y2[i] for i in 1:16) ≈ BLAS.dotu(vec(y1), vec(y2))
end
end
@testset "iamax" begin
Expand Down
19 changes: 19 additions & 0 deletions stdlib/LinearAlgebra/test/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,10 @@ end
@test dot(X, Y) == convert(elty, 35.0)
Z = convert(Vector{Matrix{elty}},[reshape(1:4, 2, 2), fill(1, 2, 2)])
@test dot(Z, Z) == convert(elty, 34.0)
Y2 = convert(Matrix{elty},[1.5 3.5 2.5 4.5])
@test_throws DimensionMismatch dot(X, Y2)
@test_throws DimensionMismatch dot(vec(X), Y2)
@test dot(X, Y) == dot(vec(X), vec(Y2))
end

dot1(x,y) = invoke(dot, Tuple{Any,Any}, x,y)
Expand All @@ -454,6 +458,21 @@ dot2(x,y) = invoke(dot, Tuple{AbstractArray,AbstractArray}, x,y)
end
end
end
for elty in (Float32, Float64, ComplexF32, ComplexF64)
XX = convert(Matrix{elty},[1.0 2.0; 3.0 4.0])
YY = convert(Matrix{elty},[1.5 2.5; 3.5 4.5])
YY2 = convert(Matrix{elty},[1.5 3.5 2.5 4.5])
for X in (copy(XX), view(XX, 1:2, 1:2)), Y in (copy(YY), view(YY, 1:2, 1:2)), Y2 in (copy(YY2), view(YY2, 1:1, 1:4))
@test dot1(X, Y) == convert(elty, 35.0)
@test dot2(X, Y) == convert(elty, 35.0)
@test dot1(X, Y2) == convert(elty, 35.0) # dot1 considers general iterators and cannot check sizes
@test_throws DimensionMismatch dot2(X, Y2)
@test dot1(vec(X), Y2) == convert(elty, 35.0) # dot1 considers general iterators and cannot check sizes
@test_throws DimensionMismatch dot2(vec(X), Y2)
@test dot1(X, Y) == dot1(vec(X), vec(Y2))
@test dot2(X, Y) == dot2(vec(X), vec(Y2))
end
end
end

@testset "Issue 11978" begin
Expand Down
4 changes: 3 additions & 1 deletion stdlib/SparseArrays/src/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,9 @@ ilog2(n::Integer) = sizeof(n)<<3 - leading_zeros(n)
# Frobenius dot/inner product: trace(A'B)
function dot(A::AbstractSparseMatrixCSC{T1,S1},B::AbstractSparseMatrixCSC{T2,S2}) where {T1,T2,S1,S2}
m, n = size(A)
size(B) == (m,n) || throw(DimensionMismatch("matrices must have the same dimensions"))
if size(B) != (m,n)
Copy link
Sponsor Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is axes purposely not used here for some reason?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought SparseMatrixCSC uses definitely classical indexing as in line 208. Thus, I've just used size instead of axes. Should I change that?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For now, the SparseArrays code has not been modified to support arbitrary indices.
I suppose it would be generally less useful than for dense arrays.

throw(DimensionMismatch("The first array has size $(size(A)) which does not match the size of the second, $(size(B)). You might want to use `dot(vec(x), vec(y))` if `length(x) == length(y)`."))
end
r = dot(zero(T1), zero(T2))
@inbounds for j = 1:n
ia = getcolptr(A)[j]; ia_nxt = getcolptr(A)[j+1]
Expand Down
12 changes: 9 additions & 3 deletions stdlib/SparseArrays/src/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1487,7 +1487,9 @@ end
function dot(x::AbstractVector{Tx}, y::SparseVectorUnion{Ty}) where {Tx<:Number,Ty<:Number}
require_one_based_indexing(x, y)
n = length(x)
length(y) == n || throw(DimensionMismatch())
if axes(x) != axes(y)
throw(DimensionMismatch("The first array has axes $(axes(x)) that do not match the axes of the second, $(axes(y)). You might want to use `dot(vec(x), vec(y))` if `length(x) == length(y)`."))
end
nzind = nonzeroinds(y)
nzval = nonzeros(y)
s = dot(zero(Tx), zero(Ty))
Expand All @@ -1500,7 +1502,9 @@ end
function dot(x::SparseVectorUnion{Tx}, y::AbstractVector{Ty}) where {Tx<:Number,Ty<:Number}
require_one_based_indexing(x, y)
n = length(y)
length(x) == n || throw(DimensionMismatch())
if axes(x) != axes(y)
throw(DimensionMismatch("The first array has axes $(axes(x)) that do not match the axes of the second, $(axes(y)). You might want to use `dot(vec(x), vec(y))` if `length(x) == length(y)`."))
end
nzind = nonzeroinds(x)
nzval = nonzeros(x)
s = dot(zero(Tx), zero(Ty))
Expand Down Expand Up @@ -1534,7 +1538,9 @@ end
function dot(x::SparseVectorUnion{<:Number}, y::SparseVectorUnion{<:Number})
x === y && return sum(abs2, x)
n = length(x)
length(y) == n || throw(DimensionMismatch())
if axes(x) != axes(y)
throw(DimensionMismatch("The first array has axes $(axes(x)) that do not match the axes of the second, $(axes(y)). You might want to use `dot(vec(x), vec(y))` if `length(x) == length(y)`."))
end

xnzind = nonzeroinds(x)
ynzind = nonzeroinds(y)
Expand Down