diff --git a/Project.toml b/Project.toml index 5635552b1..b258fe502 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.26.1" +version = "1.27.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/rulesets/SparseArrays/sparsematrix.jl b/src/rulesets/SparseArrays/sparsematrix.jl index 22a5d0366..51b41421c 100644 --- a/src/rulesets/SparseArrays/sparsematrix.jl +++ b/src/rulesets/SparseArrays/sparsematrix.jl @@ -10,16 +10,42 @@ function rrule(::typeof(sparse), I::AbstractVector, J::AbstractVector, V::Abstra return sparse(I, J, V, m, n, combine), sparse_pullback end -function rrule(::Type{T}, A::AbstractMatrix) where T <: SparseMatrixCSC +function rrule(::Type{T}, A::AbstractMatrix) where T <: AbstractSparseMatrix function sparse_pullback(Ω̄) return NoTangent(), Ω̄ end return T(A), sparse_pullback end -function rrule(::Type{T}, v::AbstractVector) where T <: SparseVector +function rrule(::Type{T}, v::AbstractVector) where T <: AbstractSparseVector function sparse_pullback(Ω̄) return NoTangent(), Ω̄ end return T(v), sparse_pullback end + +function rrule(::typeof(findnz), A::AbstractSparseMatrix) + I, J, V = findnz(A) + m, n = size(A) + + function findnz_pullback(Δ) + _, _, V̄ = unthunk(Δ) + V̄ isa AbstractZero && return (NoTangent(), V̄) + return NoTangent(), sparse(I, J, V̄, m, n) + end + + return (I, J, V), findnz_pullback +end + +function rrule(::typeof(findnz), v::AbstractSparseVector) + I, V = findnz(v) + n = length(v) + + function findnz_pullback(Δ) + _, V̄ = unthunk(Δ) + V̄ isa AbstractZero && return (NoTangent(), V̄) + return NoTangent(), sparsevec(I, V̄, n) + end + + return (I, V), findnz_pullback +end diff --git a/test/rulesets/SparseArrays/sparsematrix.jl b/test/rulesets/SparseArrays/sparsematrix.jl index 3e239cb1c..a11a1e963 100644 --- a/test/rulesets/SparseArrays/sparsematrix.jl +++ b/test/rulesets/SparseArrays/sparsematrix.jl @@ -9,11 +9,27 @@ end @testset "SparseMatrixCSC(A)" begin A = rand(5, 3) test_rrule(SparseMatrixCSC, A) - test_rrule(SparseMatrixCSC{Float32,Int}, A, rtol=1e-5) + test_rrule(SparseMatrixCSC{Float32,Int}, A, rtol=1e-4) end @testset "SparseVector(v)" begin v = rand(5) test_rrule(SparseVector, v) - test_rrule(SparseVector{Float32}, Float32.(v), rtol=1e-5) + test_rrule(SparseVector{Float32}, Float32.(v), rtol=1e-4) +end + +@testset "findnz" begin + A = sprand(5, 5, 0.5) + dA = similar(A) + rand!(dA.nzval) + I, J, V = findnz(A) + V̄ = rand!(similar(V)) + test_rrule(findnz, A ⊢ dA, output_tangent=(zeros(length(I)), zeros(length(J)), V̄)) + + v = sprand(5, 0.5) + dv = similar(v) + rand!(dv.nzval) + I, V = findnz(v) + V̄ = rand!(similar(V)) + test_rrule(findnz, v ⊢ dv, output_tangent=(zeros(length(I)), V̄)) end