Skip to content

Commit

Permalink
Merge 923ad31 into bbe68cc
Browse files Browse the repository at this point in the history
  • Loading branch information
oxinabox committed Oct 14, 2020
2 parents bbe68cc + 923ad31 commit eb7aa95
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 0 deletions.
11 changes: 11 additions & 0 deletions src/rulesets/Base/arraymath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,14 @@ function rrule(::typeof(\), b::Real, A::AbstractArray{<:Real})
end
return Y, backslash_pullback
end

#####
##### Negation (Unary -)
#####

function rrule(::typeof(-), x::AbstractArray)
function negation_pullback(ȳ)
return NO_FIELDS, InplaceableThunk(@thunk(-ȳ), ā -> _subtract!!(ā, ȳ))
end
return -x, negation_pullback
end
3 changes: 3 additions & 0 deletions src/rulesets/LinearAlgebra/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ _mulsubtrans!!(X::AbstractZero, F::AbstractZero) = X
_mulsubtrans!!(X::AbstractZero, F::AbstractMatrix{<:Real}) = X
_mulsubtrans!!(X::AbstractMatrix{<:Real}, F::AbstractZero) = F

_subtract!!(x, y) = x - y
_subtract!!(x::Array, y::AbstractArray) = x .-= y

# I - X, overwrites X
function _eyesubx!(X::AbstractMatrix)
n, m = size(X)
Expand Down
9 changes: 9 additions & 0 deletions test/rulesets/Base/arraymath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,4 +63,13 @@
rrule_test(/, Ȳ, (A, Ā), (7.2, 2.3))
rrule_test(\, Ȳ, (7.2, 2.3), (A, Ā))
end


@testset "negation" begin
A = randn(4, 4)
= randn(4, 4)
= randn(4, 4)
rrule_test(-, Ȳ, (A, Ā))
rrule_test(-, Diagonal(Ȳ), (Diagonal(A), Diagonal(Ā)))
end
end
27 changes: 27 additions & 0 deletions test/rulesets/LinearAlgebra/utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
@testset "utils.jl" begin
@testset "_subtract!!" begin
_subtract!! = ChainRules._subtract!!

@testset "Inplace" begin
x = [1, 2, 3]
ret = _subtract!!(x, ones(3))
@test ret === x
@test ret == [0, 1, 2]
end

@testset "Out of place" begin
x = Diagonal([2, 2])
ret = _subtract!!(x, ones(2, 2))
@test ret !== x
@test ret == [1 -1; -1 1]
end

@testset "Currently out of place, but this could change" begin
x = Diagonal([3, 3])
ret = _subtract!!(x, Diagonal([1,1]))
@test ret !== x
@test ret isa Diagonal
@test ret == Diagonal([2, 2])
end
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ println("Testing ChainRules.jl")
include(joinpath("rulesets", "LinearAlgebra", "structured.jl"))
include(joinpath("rulesets", "LinearAlgebra", "factorization.jl"))
include(joinpath("rulesets", "LinearAlgebra", "blas.jl"))
include(joinpath("rulesets", "LinearAlgebra", "utils.jl"))
end

print(" ")
Expand Down

0 comments on commit eb7aa95

Please sign in to comment.