Skip to content

Commit

Permalink
Add scalar-array / and \
Browse files Browse the repository at this point in the history
  • Loading branch information
oxinabox committed Sep 28, 2020
1 parent 9663f93 commit cf992ea
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 1 deletion.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "0.7.20"
version = "0.7.21"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
19 changes: 19 additions & 0 deletions src/rulesets/LinearAlgebra/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,25 @@ function rrule(::typeof(*), B::AbstractArray{<:Real}, A::Real)
return A * B, times_pullback
end

#####
##### `\`, `/` matrix-scalar_rule

function rrule(::typeof(/), A::AbstractArray{<:Real}, b::Real)
Y = A/b
function slash_pullback(Ȳ)
return (NO_FIELDS, @thunk(Ȳ/b), @thunk(-dot(Ȳ, Y)/b))
end
return Y, slash_pullback
end

function rrule(::typeof(\), b::Real, A::AbstractArray{<:Real})
Y = b\A
function backslash_pullback(Ȳ)
return (NO_FIELDS, @thunk(-dot(Ȳ, Y)/b), @thunk(Ȳ/b))
end
return Y, backslash_pullback
end


#####
##### `pinv`
Expand Down
7 changes: 7 additions & 0 deletions test/rulesets/LinearAlgebra/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,13 @@
end
end
end
@testset "/ and \\ Scalar-AbstractArray" begin
A = randn(3, 4, 5)
= randn(3, 4, 5)
= randn(3, 4, 5)
rrule_test(/, Ȳ, (A, Ā), (7.2, 2.3))
rrule_test(\, Ȳ, (7.2, 2.3), (A, Ā))
end
@testset "norm" begin
for dims in [(), (5,), (3, 2), (7, 3, 2)]
A = randn(dims...)
Expand Down

0 comments on commit cf992ea

Please sign in to comment.