Skip to content

Commit

Permalink
sparsevec test
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Feb 18, 2022
1 parent 3d98d9d commit 5d5741c
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/rulesets/SparseArrays/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ function rrule(::typeof(findnz), v::AbstractSparseVector)
function findnz_pullback(Δ)
_, V̄ = unthunk(Δ)
isa AbstractZero && return (NoTangent(), V̄)
return NoTangent(), sparse(I, V̄, n)
return NoTangent(), sparsevec(I, V̄, n)
end

return (I, V), findnz_pullback
Expand Down
7 changes: 7 additions & 0 deletions test/rulesets/SparseArrays/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,11 @@ end
I, J, V = findnz(A)
= rand!(similar(V))
test_rrule(findnz, A dA, output_tangent=(zeros(length(I)), zeros(length(J)), V̄))

v = sprand(5, 0.5)
dv = similar(v)
rand!(dv.nzval)
I, V = findnz(v)
= rand!(similar(V))
test_rrule(findnz, v dv, output_tangent=(zeros(length(I)), V̄))
end

0 comments on commit 5d5741c

Please sign in to comment.