diff --git a/src/sparsevector.jl b/src/sparsevector.jl index a9568455..063f6f09 100644 --- a/src/sparsevector.jl +++ b/src/sparsevector.jl @@ -879,6 +879,45 @@ end ### Generic functions operating on AbstractSparseVector +## Explicit efficient comparisons with vectors + +function ==(A::AbstractCompressedVector, + B::AbstractCompressedVector) + # Different sizes are always different + size(A) ≠ size(B) && return false + # Compare nonzero elements + i, j = 1, 1 + @inbounds while i <= nnz(A) && j <= nnz(B) + if nonzeroinds(A)[i] == nonzeroinds(B)[j] + nonzeros(A)[i] == nonzeros(B)[j] || return false + i += 1 + j += 1 + elseif nonzeroinds(A)[i] <= nonzeroinds(B)[j] + iszero(nonzeros(A)[i]) || return false + i += 1 + else # nonzeroinds(A)[i] >= nonzeroinds(B)[j] + iszero(nonzeros(B)[j]) || return false + j += 1 + end + end + + @inbounds for k in i:nnz(A) + iszero(nonzeros(A)[k]) || return false + end + + @inbounds for k in j:nnz(B) + iszero(nonzeros(B)[k]) || return false + end + + return true +end + +==(A::Transpose{<:Any,<:AbstractCompressedVector}, + B::Transpose{<:Any,<:AbstractCompressedVector}) = transpose(A) == transpose(B) + +==(A::Adjoint{<:Any,<:AbstractCompressedVector}, + B::Adjoint{<:Any,<:AbstractCompressedVector}) = adjoint(A) == adjoint(B) + ### getindex function _spgetindex(m::Int, nzind::AbstractVector{Ti}, nzval::AbstractVector{Tv}, i::Integer) where {Tv,Ti} diff --git a/test/higherorderfns.jl b/test/higherorderfns.jl index f2749584..0e06816f 100644 --- a/test/higherorderfns.jl +++ b/test/higherorderfns.jl @@ -652,7 +652,8 @@ end @test ((_, x) -> x).(Int, spzeros(3)) == spzeros(3) @test ((_, _, x) -> x).(Int, Int, spzeros(3)) == spzeros(3) @test ((_, _, _, x) -> x).(Int, Int, Int, spzeros(3)) == spzeros(3) - @test_broken ((_, _, _, _, x) -> x).(Int, Int, Int, Int, spzeros(3)) == spzeros(3) + @test ((_, _, _, _, x) -> x).(Int, Int, Int, Int, spzeros(3)) == spzeros(3) + @test_broken typeof(((_, _, _, _, x) -> x).(Int, Int, Int, Int, spzeros(3))) == typeof(spzeros(3)) end using SparseArrays.HigherOrderFns: SparseVecStyle, SparseMatStyle diff --git a/test/issues.jl b/test/issues.jl index 6a131d41..01b53d85 100644 --- a/test/issues.jl +++ b/test/issues.jl @@ -13,7 +13,7 @@ include("simplesmatrix.jl") @testset "Issue #15" begin s = sparse([1, 2], [1, 2], [10, missing]) d = Matrix(s) - + s2 = sparse(d) @test s2[1, 1] == 10 diff --git a/test/sparsematrix_ops.jl b/test/sparsematrix_ops.jl index e047b156..68d0a8c8 100644 --- a/test/sparsematrix_ops.jl +++ b/test/sparsematrix_ops.jl @@ -530,4 +530,34 @@ Base.transpose(x::Counting) = Counting(transpose(x.elt)) end end + +@testset "Issue #246" begin + for t in [Int, UInt8, Float64] + a = Counting.(sprand(t, 100, 0.5)) + b = Counting.(sprand(t, 100, 0.5)) + + c = if nnz(a) != 0 + c = copy(a) + nonzeros(c)[1] = 0 + c + else + c = copy(a) + push!(nonzeros(c), zero(t)) + push!(nonzerosinds(c), 1) + c + end + d = dropzeros(c) + + for m in [identity, transpose, adjoint] + ma, mb, mc, md = m.([a, b, c, d]) + + resetcounter() + ma == mb + @test getcounter() <= nnz(a) + nnz(b) + + @test (mc == md) == (Array(mc) == Array(md)) + end + end +end + end # module