Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 17 additions & 12 deletions src/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1653,7 +1653,7 @@ end
## fkeep! and children tril!, triu!, droptol!, dropzeros[!]

"""
fkeep!(A::AbstractSparseArray, f)
fkeep!(f, A::AbstractSparseArray)

Keep elements of `A` for which test `f` returns `true`. `f`'s signature should be

Expand All @@ -1673,15 +1673,15 @@ julia> A = sparse(Diagonal([1, 2, 3, 4]))
⋅ ⋅ 3 ⋅
⋅ ⋅ ⋅ 4

julia> SparseArrays.fkeep!(A, (i, j, v) -> isodd(v))
julia> SparseArrays.fkeep!((i, j, v) -> isodd(v), A)
4×4 SparseMatrixCSC{Int64, Int64} with 2 stored entries:
1 ⋅ ⋅ ⋅
⋅ ⋅ ⋅ ⋅
⋅ ⋅ 3 ⋅
⋅ ⋅ ⋅ ⋅
```
"""
function _fkeep!(A::AbstractSparseMatrixCSC, f::F) where F
function _fkeep!(f::F, A::AbstractSparseMatrixCSC) where F<:Function
An = size(A, 2)
Acolptr = getcolptr(A)
Arowval = rowvals(A)
Expand Down Expand Up @@ -1716,34 +1716,39 @@ function _fkeep!(A::AbstractSparseMatrixCSC, f::F) where F
return A
end

function _fkeep!_fixed(A::AbstractSparseMatrixCSC, f::F) where F
function _fkeep!_fixed(f::F, A::AbstractSparseMatrixCSC) where F<:Function
@inbounds for j in axes(A, 2)
for k in getcolptr(A)[j]:getcolptr(A)[j+1]-1
i = rowvals(A)[k]
x = nonzeros(A)[k]
# If this element should be kept, rewrite in new position
if f(Ai, Aj, Ax)
if !f(rowvals(A)[k], j, nonzeros(A)[k])
nonzeros(A)[k] = zero(eltype(A))
end
end
end
return A
end

fkeep!(A::AbstractSparseMatrixCSC, f::F) where F= _is_fixed(A) ? _fkeep!_fixed(A, f) : _fkeep!(A, f)
fkeep!(f::F, A::AbstractSparseMatrixCSC) where F<:Function = _is_fixed(A) ? _fkeep!_fixed(f, A) : _fkeep!(f, A)

# deprecated syntax
function fkeep!(x::Union{AbstractSparseMatrixCSC,AbstractCompressedVector},f::F) where F<:Function
Base.depwarn("`fkeep!(x, f::Function)` is deprecated, use `fkeep!(f::Function, x)` instead.", :fkeep!)
return fkeep!(f, x)
end


tril!(A::AbstractSparseMatrixCSC, k::Integer = 0) =
fkeep!(A, (i, j, x) -> i + k >= j)
fkeep!((i, j, x) -> i + k >= j, A)
triu!(A::AbstractSparseMatrixCSC, k::Integer = 0) =
fkeep!(A, (i, j, x) -> j >= i + k)
fkeep!((i, j, x) -> j >= i + k, A)

"""
droptol!(A::AbstractSparseMatrixCSC, tol)

Removes stored values from `A` whose absolute value is less than or equal to `tol`.
"""
droptol!(A::AbstractSparseMatrixCSC, tol) =
fkeep!(A, (i, j, x) -> abs(x) > tol)
fkeep!((i, j, x) -> abs(x) > tol, A)

"""
dropzeros!(A::AbstractSparseMatrixCSC;)
Expand All @@ -1754,7 +1759,7 @@ For an out-of-place version, see [`dropzeros`](@ref). For
algorithmic information, see `fkeep!`.
"""

dropzeros!(A::AbstractSparseMatrixCSC) = _is_fixed(A) ? A : fkeep!(A, (i, j, x) -> _isnotzero(x))
dropzeros!(A::AbstractSparseMatrixCSC) = _is_fixed(A) ? A : fkeep!((i, j, x) -> _isnotzero(x), A)

"""
dropzeros(A::AbstractSparseMatrixCSC;)
Expand Down
53 changes: 29 additions & 24 deletions src/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2074,31 +2074,36 @@ function sort(x::AbstractCompressedVector{Tv,Ti}; kws...) where {Tv,Ti}
typeof(x)(n,newnzind,newnzvals)
end

function fkeep!(x::AbstractCompressedVector, f)
_is_fixed(x) && return x

nzind = nonzeroinds(x)
nzval = nonzeros(x)

x_writepos = 1
@inbounds for xk in 1:nnz(x)
xi = nzind[xk]
xv = nzval[xk]
# If this element should be kept, rewrite in new position
if f(xi, xv)
if x_writepos != xk
nzind[x_writepos] = xi
nzval[x_writepos] = xv
function fkeep!(f, x::AbstractCompressedVector{Tv}) where Tv
if _is_fixed(x)
for i in 1:nnz(x)
if !f(nonzeroinds(x)[i], nonzeros(x)[i])
nonzeros(x)[i] = zero(Tv)
end
end
else
nzind = nonzeroinds(x)
nzval = nonzeros(x)

x_writepos = 1
@inbounds for xk in 1:nnz(x)
xi = nzind[xk]
xv = nzval[xk]
# If this element should be kept, rewrite in new position
if f(xi, xv)
if x_writepos != xk
nzind[x_writepos] = xi
nzval[x_writepos] = xv
end
x_writepos += 1
end
x_writepos += 1
end
end

# Trim x's storage if necessary
x_nnz = x_writepos - 1
resize!(nzval, x_nnz)
resize!(nzind, x_nnz)

# Trim x's storage if necessary
x_nnz = x_writepos - 1
resize!(nzval, x_nnz)
resize!(nzind, x_nnz)
end
return x
end

Expand All @@ -2109,7 +2114,7 @@ end

Removes stored values from `x` whose absolute value is less than or equal to `tol`.
"""
droptol!(x::AbstractCompressedVector, tol) = fkeep!(x, (i, x) -> abs(x) > tol)
droptol!(x::AbstractCompressedVector, tol) = fkeep!((i, x) -> abs(x) > tol, x)

"""
dropzeros!(x::AbstractCompressedVector)
Expand All @@ -2119,7 +2124,7 @@ Removes stored numerical zeros from `x`.
For an out-of-place version, see [`dropzeros`](@ref). For
algorithmic information, see `fkeep!`.
"""
dropzeros!(x::AbstractCompressedVector) = _is_fixed(x) ? x : fkeep!(x, (i, x) -> _isnotzero(x))
dropzeros!(x::AbstractCompressedVector) = _is_fixed(x) ? x : fkeep!((i, x) -> _isnotzero(x), x)


"""
Expand Down
13 changes: 12 additions & 1 deletion test/fixed.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using Test, SparseArrays, LinearAlgebra
using SparseArrays: AbstractSparseVector, AbstractSparseMatrixCSC, FixedSparseCSC, FixedSparseVector, ReadOnly,
getcolptr, rowvals, nonzeros, nonzeroinds, _is_fixed, fixed, move_fixed
getcolptr, rowvals, nonzeros, nonzeroinds, _is_fixed, fixed, move_fixed, fkeep!

@testset "ReadOnly" begin
v = randn(100)
Expand Down Expand Up @@ -124,3 +124,14 @@ end
@test b == a
end

always_false(x...) = false
@testset "Test fkeep!" begin
for a in [sprandn(10, 10, 0.99) + I, sprandn(10, 0.1) .+ 1]
a = fixed(a)
b = copy(a)
fkeep!(always_false, b)
@test nnz(a) == nnz(b)
@test all(iszero, nonzeros(b))

end
end
5 changes: 3 additions & 2 deletions test/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -913,7 +913,7 @@ end
@test_throws ArgumentError findmin(x)
@test_throws ArgumentError findmax(x)
end

let v = spzeros(3) #Julia #44978
v[1] = 2
@test argmin(v) == 2
Expand Down Expand Up @@ -1298,8 +1298,9 @@ end
xdrop = copy(x)
# This will keep index 1, 3, 4, 7 in xdrop
f_drop(i, x) = (abs(x) == 1.) || (i in [1, 7])
SparseArrays.fkeep!(xdrop, f_drop)
SparseArrays.fkeep!(f_drop, xdrop)
@test exact_equal(xdrop, SparseVector(7, [1, 3, 4, 7], [3., -1., 1., 3.]))
@test_deprecated SparseArrays.fkeep!(xdrop, f_drop)
end

@testset "dropzeros[!] with length=$m" for m in (10, 20, 30)
Expand Down