diff --git a/src/sparsematrix.jl b/src/sparsematrix.jl index 60968c9e..bff7df5a 100644 --- a/src/sparsematrix.jl +++ b/src/sparsematrix.jl @@ -1653,7 +1653,7 @@ end ## fkeep! and children tril!, triu!, droptol!, dropzeros[!] """ - fkeep!(A::AbstractSparseArray, f) + fkeep!(f, A::AbstractSparseArray) Keep elements of `A` for which test `f` returns `true`. `f`'s signature should be @@ -1673,7 +1673,7 @@ julia> A = sparse(Diagonal([1, 2, 3, 4])) ⋅ ⋅ 3 ⋅ ⋅ ⋅ ⋅ 4 -julia> SparseArrays.fkeep!(A, (i, j, v) -> isodd(v)) +julia> SparseArrays.fkeep!((i, j, v) -> isodd(v), A) 4×4 SparseMatrixCSC{Int64, Int64} with 2 stored entries: 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ @@ -1681,7 +1681,7 @@ julia> SparseArrays.fkeep!(A, (i, j, v) -> isodd(v)) ⋅ ⋅ ⋅ ⋅ ``` """ -function _fkeep!(A::AbstractSparseMatrixCSC, f::F) where F +function _fkeep!(f::F, A::AbstractSparseMatrixCSC) where F<:Function An = size(A, 2) Acolptr = getcolptr(A) Arowval = rowvals(A) @@ -1716,13 +1716,11 @@ function _fkeep!(A::AbstractSparseMatrixCSC, f::F) where F return A end -function _fkeep!_fixed(A::AbstractSparseMatrixCSC, f::F) where F +function _fkeep!_fixed(f::F, A::AbstractSparseMatrixCSC) where F<:Function @inbounds for j in axes(A, 2) for k in getcolptr(A)[j]:getcolptr(A)[j+1]-1 - i = rowvals(A)[k] - x = nonzeros(A)[k] # If this element should be kept, rewrite in new position - if f(Ai, Aj, Ax) + if !f(rowvals(A)[k], j, nonzeros(A)[k]) nonzeros(A)[k] = zero(eltype(A)) end end @@ -1730,12 +1728,19 @@ function _fkeep!_fixed(A::AbstractSparseMatrixCSC, f::F) where F return A end -fkeep!(A::AbstractSparseMatrixCSC, f::F) where F= _is_fixed(A) ? _fkeep!_fixed(A, f) : _fkeep!(A, f) +fkeep!(f::F, A::AbstractSparseMatrixCSC) where F<:Function = _is_fixed(A) ? _fkeep!_fixed(f, A) : _fkeep!(f, A) + +# deprecated syntax +function fkeep!(x::Union{AbstractSparseMatrixCSC,AbstractCompressedVector},f::F) where F<:Function + Base.depwarn("`fkeep!(x, f::Function)` is deprecated, use `fkeep!(f::Function, x)` instead.", :fkeep!) + return fkeep!(f, x) +end + tril!(A::AbstractSparseMatrixCSC, k::Integer = 0) = - fkeep!(A, (i, j, x) -> i + k >= j) + fkeep!((i, j, x) -> i + k >= j, A) triu!(A::AbstractSparseMatrixCSC, k::Integer = 0) = - fkeep!(A, (i, j, x) -> j >= i + k) + fkeep!((i, j, x) -> j >= i + k, A) """ droptol!(A::AbstractSparseMatrixCSC, tol) @@ -1743,7 +1748,7 @@ triu!(A::AbstractSparseMatrixCSC, k::Integer = 0) = Removes stored values from `A` whose absolute value is less than or equal to `tol`. """ droptol!(A::AbstractSparseMatrixCSC, tol) = - fkeep!(A, (i, j, x) -> abs(x) > tol) + fkeep!((i, j, x) -> abs(x) > tol, A) """ dropzeros!(A::AbstractSparseMatrixCSC;) @@ -1754,7 +1759,7 @@ For an out-of-place version, see [`dropzeros`](@ref). For algorithmic information, see `fkeep!`. """ -dropzeros!(A::AbstractSparseMatrixCSC) = _is_fixed(A) ? A : fkeep!(A, (i, j, x) -> _isnotzero(x)) +dropzeros!(A::AbstractSparseMatrixCSC) = _is_fixed(A) ? A : fkeep!((i, j, x) -> _isnotzero(x), A) """ dropzeros(A::AbstractSparseMatrixCSC;) diff --git a/src/sparsevector.jl b/src/sparsevector.jl index f293a210..a9568455 100644 --- a/src/sparsevector.jl +++ b/src/sparsevector.jl @@ -2074,31 +2074,36 @@ function sort(x::AbstractCompressedVector{Tv,Ti}; kws...) where {Tv,Ti} typeof(x)(n,newnzind,newnzvals) end -function fkeep!(x::AbstractCompressedVector, f) - _is_fixed(x) && return x - - nzind = nonzeroinds(x) - nzval = nonzeros(x) - - x_writepos = 1 - @inbounds for xk in 1:nnz(x) - xi = nzind[xk] - xv = nzval[xk] - # If this element should be kept, rewrite in new position - if f(xi, xv) - if x_writepos != xk - nzind[x_writepos] = xi - nzval[x_writepos] = xv +function fkeep!(f, x::AbstractCompressedVector{Tv}) where Tv + if _is_fixed(x) + for i in 1:nnz(x) + if !f(nonzeroinds(x)[i], nonzeros(x)[i]) + nonzeros(x)[i] = zero(Tv) + end + end + else + nzind = nonzeroinds(x) + nzval = nonzeros(x) + + x_writepos = 1 + @inbounds for xk in 1:nnz(x) + xi = nzind[xk] + xv = nzval[xk] + # If this element should be kept, rewrite in new position + if f(xi, xv) + if x_writepos != xk + nzind[x_writepos] = xi + nzval[x_writepos] = xv + end + x_writepos += 1 end - x_writepos += 1 end - end - - # Trim x's storage if necessary - x_nnz = x_writepos - 1 - resize!(nzval, x_nnz) - resize!(nzind, x_nnz) + # Trim x's storage if necessary + x_nnz = x_writepos - 1 + resize!(nzval, x_nnz) + resize!(nzind, x_nnz) + end return x end @@ -2109,7 +2114,7 @@ end Removes stored values from `x` whose absolute value is less than or equal to `tol`. """ -droptol!(x::AbstractCompressedVector, tol) = fkeep!(x, (i, x) -> abs(x) > tol) +droptol!(x::AbstractCompressedVector, tol) = fkeep!((i, x) -> abs(x) > tol, x) """ dropzeros!(x::AbstractCompressedVector) @@ -2119,7 +2124,7 @@ Removes stored numerical zeros from `x`. For an out-of-place version, see [`dropzeros`](@ref). For algorithmic information, see `fkeep!`. """ -dropzeros!(x::AbstractCompressedVector) = _is_fixed(x) ? x : fkeep!(x, (i, x) -> _isnotzero(x)) +dropzeros!(x::AbstractCompressedVector) = _is_fixed(x) ? x : fkeep!((i, x) -> _isnotzero(x), x) """ diff --git a/test/fixed.jl b/test/fixed.jl index f15087df..b9d8bf87 100644 --- a/test/fixed.jl +++ b/test/fixed.jl @@ -1,6 +1,6 @@ using Test, SparseArrays, LinearAlgebra using SparseArrays: AbstractSparseVector, AbstractSparseMatrixCSC, FixedSparseCSC, FixedSparseVector, ReadOnly, - getcolptr, rowvals, nonzeros, nonzeroinds, _is_fixed, fixed, move_fixed + getcolptr, rowvals, nonzeros, nonzeroinds, _is_fixed, fixed, move_fixed, fkeep! @testset "ReadOnly" begin v = randn(100) @@ -124,3 +124,14 @@ end @test b == a end +always_false(x...) = false +@testset "Test fkeep!" begin + for a in [sprandn(10, 10, 0.99) + I, sprandn(10, 0.1) .+ 1] + a = fixed(a) + b = copy(a) + fkeep!(always_false, b) + @test nnz(a) == nnz(b) + @test all(iszero, nonzeros(b)) + + end +end \ No newline at end of file diff --git a/test/sparsevector.jl b/test/sparsevector.jl index 790d4aa3..76d7446b 100644 --- a/test/sparsevector.jl +++ b/test/sparsevector.jl @@ -913,7 +913,7 @@ end @test_throws ArgumentError findmin(x) @test_throws ArgumentError findmax(x) end - + let v = spzeros(3) #Julia #44978 v[1] = 2 @test argmin(v) == 2 @@ -1298,8 +1298,9 @@ end xdrop = copy(x) # This will keep index 1, 3, 4, 7 in xdrop f_drop(i, x) = (abs(x) == 1.) || (i in [1, 7]) - SparseArrays.fkeep!(xdrop, f_drop) + SparseArrays.fkeep!(f_drop, xdrop) @test exact_equal(xdrop, SparseVector(7, [1, 3, 4, 7], [3., -1., 1., 3.])) + @test_deprecated SparseArrays.fkeep!(xdrop, f_drop) end @testset "dropzeros[!] with length=$m" for m in (10, 20, 30)