From 753cf98d04ab1452f07e12541e6ac972db31ccfb Mon Sep 17 00:00:00 2001 From: Sobhan Mohammadpour Date: Mon, 5 Sep 2022 22:00:07 -0400 Subject: [PATCH 1/6] add == for vectors Closes #246 --- src/sparsevector.jl | 39 +++++++++++++++++++++++++++++++++++++++ test/sparsematrix_ops.jl | 16 ++++++++++++++++ 2 files changed, 55 insertions(+) diff --git a/src/sparsevector.jl b/src/sparsevector.jl index a9568455..063f6f09 100644 --- a/src/sparsevector.jl +++ b/src/sparsevector.jl @@ -879,6 +879,45 @@ end ### Generic functions operating on AbstractSparseVector +## Explicit efficient comparisons with vectors + +function ==(A::AbstractCompressedVector, + B::AbstractCompressedVector) + # Different sizes are always different + size(A) ≠ size(B) && return false + # Compare nonzero elements + i, j = 1, 1 + @inbounds while i <= nnz(A) && j <= nnz(B) + if nonzeroinds(A)[i] == nonzeroinds(B)[j] + nonzeros(A)[i] == nonzeros(B)[j] || return false + i += 1 + j += 1 + elseif nonzeroinds(A)[i] <= nonzeroinds(B)[j] + iszero(nonzeros(A)[i]) || return false + i += 1 + else # nonzeroinds(A)[i] >= nonzeroinds(B)[j] + iszero(nonzeros(B)[j]) || return false + j += 1 + end + end + + @inbounds for k in i:nnz(A) + iszero(nonzeros(A)[k]) || return false + end + + @inbounds for k in j:nnz(B) + iszero(nonzeros(B)[k]) || return false + end + + return true +end + +==(A::Transpose{<:Any,<:AbstractCompressedVector}, + B::Transpose{<:Any,<:AbstractCompressedVector}) = transpose(A) == transpose(B) + +==(A::Adjoint{<:Any,<:AbstractCompressedVector}, + B::Adjoint{<:Any,<:AbstractCompressedVector}) = adjoint(A) == adjoint(B) + ### getindex function _spgetindex(m::Int, nzind::AbstractVector{Ti}, nzval::AbstractVector{Tv}, i::Integer) where {Tv,Ti} diff --git a/test/sparsematrix_ops.jl b/test/sparsematrix_ops.jl index e047b156..f9a4e52d 100644 --- a/test/sparsematrix_ops.jl +++ b/test/sparsematrix_ops.jl @@ -530,4 +530,20 @@ Base.transpose(x::Counting) = Counting(transpose(x.elt)) end end + +@testset "Issue #246" begin + for t in [Int, UInt8, Float64] + a = Counting.(sprand(t, 100, 0.5)) + b = Counting.(sprand(t, 100, 0.5)) + for m in [identity, transpose, adjoint] + ma = m(a) + mb = m(b) + + resetcounter() + ma == mb + @test getcounter() <= nnz(a) + nnz(b) + end + end +end + end # module From a0790ebddad6b905a1103b2f6267d6e46e9035ca Mon Sep 17 00:00:00 2001 From: Sobhan Mohammadpour Date: Mon, 5 Sep 2022 22:21:22 -0400 Subject: [PATCH 2/6] i seem to have fixed a test??? --- test/higherorderfns.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/higherorderfns.jl b/test/higherorderfns.jl index f2749584..033dcd84 100644 --- a/test/higherorderfns.jl +++ b/test/higherorderfns.jl @@ -652,7 +652,7 @@ end @test ((_, x) -> x).(Int, spzeros(3)) == spzeros(3) @test ((_, _, x) -> x).(Int, Int, spzeros(3)) == spzeros(3) @test ((_, _, _, x) -> x).(Int, Int, Int, spzeros(3)) == spzeros(3) - @test_broken ((_, _, _, _, x) -> x).(Int, Int, Int, Int, spzeros(3)) == spzeros(3) + @test ((_, _, _, _, x) -> x).(Int, Int, Int, Int, spzeros(3)) == spzeros(3) end using SparseArrays.HigherOrderFns: SparseVecStyle, SparseMatStyle From 3feb0c54520daecb96febe70a00394f442da12f2 Mon Sep 17 00:00:00 2001 From: Sobhan Mohammadpour Date: Mon, 5 Sep 2022 22:24:37 -0400 Subject: [PATCH 3/6] make better broken test --- test/issues.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/issues.jl b/test/issues.jl index 6a131d41..01b53d85 100644 --- a/test/issues.jl +++ b/test/issues.jl @@ -13,7 +13,7 @@ include("simplesmatrix.jl") @testset "Issue #15" begin s = sparse([1, 2], [1, 2], [10, missing]) d = Matrix(s) - + s2 = sparse(d) @test s2[1, 1] == 10 From f7cc5c9103e456b77551902311bcb016a344a1ad Mon Sep 17 00:00:00 2001 From: Sobhan Mohammadpour Date: Mon, 5 Sep 2022 22:25:15 -0400 Subject: [PATCH 4/6] make better broken test --- test/higherorderfns.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/higherorderfns.jl b/test/higherorderfns.jl index 033dcd84..0e06816f 100644 --- a/test/higherorderfns.jl +++ b/test/higherorderfns.jl @@ -653,6 +653,7 @@ end @test ((_, _, x) -> x).(Int, Int, spzeros(3)) == spzeros(3) @test ((_, _, _, x) -> x).(Int, Int, Int, spzeros(3)) == spzeros(3) @test ((_, _, _, _, x) -> x).(Int, Int, Int, Int, spzeros(3)) == spzeros(3) + @test_broken typeof(((_, _, _, _, x) -> x).(Int, Int, Int, Int, spzeros(3))) == typeof(spzeros(3)) end using SparseArrays.HigherOrderFns: SparseVecStyle, SparseMatStyle From 8670a27be01a80683d4f857ec1a3b24054916fbe Mon Sep 17 00:00:00 2001 From: Sobhan Mohammadpour Date: Wed, 7 Sep 2022 12:39:15 -0400 Subject: [PATCH 5/6] improv test --- test/sparsematrix_ops.jl | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/test/sparsematrix_ops.jl b/test/sparsematrix_ops.jl index f9a4e52d..69b4afbd 100644 --- a/test/sparsematrix_ops.jl +++ b/test/sparsematrix_ops.jl @@ -4,7 +4,7 @@ module SparseTests using Test using SparseArrays -using SparseArrays: getcolptr, nonzeroinds, _show_with_braille_patterns, _isnotzero +using SparseArrays: getcolptr, nonzeros, nonzeroinds, _show_with_braille_patterns, _isnotzero using LinearAlgebra using Printf: @printf # for debug using Random @@ -535,13 +535,27 @@ end for t in [Int, UInt8, Float64] a = Counting.(sprand(t, 100, 0.5)) b = Counting.(sprand(t, 100, 0.5)) + + c = if nnz(a) != 0 + c = copy(a) + nonzeros(c)[1] = 0 + c + else + c = copy(a) + push!(nonzeros(c), zero(t)) + push!(nonzerosinds(c), 1) + c + end + d = dropzeros(c) + for m in [identity, transpose, adjoint] - ma = m(a) - mb = m(b) + ma, mb, mc, md = m.([a, b, c, d]) resetcounter() ma == mb @test getcounter() <= nnz(a) + nnz(b) + + @test (mc == md) == (Array(mc) == Array(md)) end end end From e4cece195715916855a2f9a43f6b238d0d26db7d Mon Sep 17 00:00:00 2001 From: Sobhan Mohammadpour Date: Tue, 20 Sep 2022 16:59:33 -0400 Subject: [PATCH 6/6] Update test/sparsematrix_ops.jl Co-authored-by: Daniel Karrasch --- test/sparsematrix_ops.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/sparsematrix_ops.jl b/test/sparsematrix_ops.jl index 69b4afbd..68d0a8c8 100644 --- a/test/sparsematrix_ops.jl +++ b/test/sparsematrix_ops.jl @@ -4,7 +4,7 @@ module SparseTests using Test using SparseArrays -using SparseArrays: getcolptr, nonzeros, nonzeroinds, _show_with_braille_patterns, _isnotzero +using SparseArrays: getcolptr, nonzeroinds, _show_with_braille_patterns, _isnotzero using LinearAlgebra using Printf: @printf # for debug using Random