Skip to content

Commit

Permalink
restrict to CommutativeMulNumber
Browse files Browse the repository at this point in the history
  • Loading branch information
oxinabox committed Oct 20, 2020
1 parent 4a459e6 commit 84b5f71
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 3 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ version = "0.7.27"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand Down
2 changes: 2 additions & 0 deletions src/ChainRules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ if VERSION < v"1.3.0-DEV.142"
import LinearAlgebra: dot
end

# numbers that we know commute under multiplication
const CommutativeMulNumber = Union{Real,Complex}

include("rulesets/Core/core.jl")

Expand Down
14 changes: 11 additions & 3 deletions src/rulesets/Base/arraymath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@ end
##### `*`
#####

function rrule(::typeof(*), A::AbstractMatrix{<:Number}, B::AbstractMatrix{<:Number})
function rrule(
::typeof(*),
A::AbstractMatrix{<:CommutativeMulNumber},
B::AbstractMatrix{<:CommutativeMulNumber},
)
function times_pullback(Ȳ)
return (
NO_FIELDS,
Expand All @@ -36,7 +40,9 @@ function rrule(::typeof(*), A::AbstractMatrix{<:Number}, B::AbstractMatrix{<:Num
return A * B, times_pullback
end

function rrule(::typeof(*), A::Number, B::AbstractArray{<:Number})
function rrule(
::typeof(*), A::CommutativeMulNumber, B::AbstractArray{<:CommutativeMulNumber}
)
function times_pullback(Ȳ)
return (
NO_FIELDS,
Expand All @@ -50,7 +56,9 @@ function rrule(::typeof(*), A::Number, B::AbstractArray{<:Number})
return A * B, times_pullback
end

function rrule(::typeof(*), B::AbstractArray{<:Number}, A::Number)
function rrule(
::typeof(*), B::AbstractArray{<:CommutativeMulNumber}, A::CommutativeMulNumber
)
function times_pullback(Ȳ)
return (
NO_FIELDS,
Expand Down

0 comments on commit 84b5f71

Please sign in to comment.