diff --git a/Project.toml b/Project.toml index abe71c972..d4bcbd5d8 100644 --- a/Project.toml +++ b/Project.toml @@ -4,13 +4,11 @@ version = "0.9.24" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [compat] BenchmarkTools = "0.5" FiniteDifferences = "0.10" -MuladdMacro = "0.2.1" StaticArrays = "0.11, 0.12" julia = "1" diff --git a/docs/Manifest.toml b/docs/Manifest.toml index 686c556fc..e9ddb3cd1 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -10,10 +10,10 @@ uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232" version = "0.5.10" [[ChainRulesCore]] -deps = ["LinearAlgebra", "MuladdMacro", "SparseArrays"] +deps = ["LinearAlgebra", "SparseArrays"] path = ".." uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.9.17" +version = "0.9.24" [[Dates]] deps = ["Printf"] @@ -30,20 +30,26 @@ uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" version = "0.8.3" [[Documenter]] -deps = ["Base64", "Dates", "DocStringExtensions", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "REPL", "Test", "Unicode"] -git-tree-sha1 = "fb1ff838470573adc15c71ba79f8d31328f035da" +deps = ["Base64", "Dates", "DocStringExtensions", "IOCapture", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "REPL", "Test", "Unicode"] +git-tree-sha1 = "a4875e0763112d6d017126f3944f4133abb342ae" uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -version = "0.25.2" +version = "0.25.5" [[DocumenterTools]] deps = ["Base64", "DocStringExtensions", "Documenter", "FileWatching", "LibGit2", "Sass"] -git-tree-sha1 = "6fa30234228d9020cbe31e393e9d183e944845bb" +git-tree-sha1 = "9b40fd93f54ba5ef9d364981124a8ed389fd634e" uuid = "35a29f4d-8980-5a13-9543-d66fff28ecb8" -version = "0.1.7" +version = "0.1.9" [[FileWatching]] uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" +[[IOCapture]] +deps = ["Logging"] +git-tree-sha1 = "377252859f740c217b936cebcd918a44f9b53b59" +uuid = "b5f81e59-6552-4d32-b1f0-c071b021bf89" +version = "0.1.1" + [[InteractiveUtils]] deps = ["Markdown"] uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" @@ -75,16 +81,11 @@ uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" [[Mmap]] uuid = "a63ad114-7e13-5084-954f-fe012c677804" -[[MuladdMacro]] -git-tree-sha1 = "c6190f9a7fc5d9d5915ab29f2134421b12d24a68" -uuid = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221" -version = "0.2.2" - [[Parsers]] deps = ["Dates"] -git-tree-sha1 = "6fa4202675c05ba0f8268a6ddf07606350eda3ce" +git-tree-sha1 = "50c9a9ed8c714945e01cd53a21007ed3865ed714" uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" -version = "1.0.11" +version = "1.0.15" [[Pkg]] deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"] diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index 400a34540..5361b9a81 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -2,7 +2,6 @@ module ChainRulesCore using Base.Broadcast: broadcasted, Broadcasted, broadcastable, materialize, materialize! using LinearAlgebra: LinearAlgebra using SparseArrays: SparseVector, SparseMatrixCSC -using MuladdMacro: @muladd export on_new_rule, refresh_rules # generation tools export frule, rrule # core function diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index aeed853a5..b68717693 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -207,28 +207,32 @@ end """ function propagation_expr(Δs, ∂s, _conj = false) # This is basically Δs ⋅ ∂s - ∂s = map(esc, ∂s) - n∂s = length(∂s) - - # Due to bugs in Julia 1.0, we can't use `.+` or `.*` inside expression literals. - ∂_mul_Δs = if _conj - ntuple(i->:(conj($(∂s[i])) * $(Δs[i])), n∂s) - else - ntuple(i->:($(∂s[i]) * $(Δs[i])), n∂s) + _∂s = map(∂s) do ∂s_i + if _conj + :(conj($(esc(∂s_i)))) + else + esc(∂s_i) + end end + n∂s = length(_∂s) + + summed_∂_mul_Δs = if n∂s > 1 + # Explicit multiplication is only performed for the first pair + # of partial and gradient. + init_expr = :((*).($(_∂s[1]), $(Δs[1]))) - # Avoiding the extra `+` operation, it is potentially expensive for vector mode AD. - sumed_∂_mul_Δs = if n∂s > 1 - # we use `@.` to broadcast `*` and `+` - :(@. +($(∂_mul_Δs...))) + # Apply `muladd` iteratively. + foldl(Iterators.drop(zip(_∂s, Δs), 1); init=init_expr) do ex, (∂s_i, Δs_i) + :((muladd).($∂s_i, $Δs_i, $ex)) + end else # Note: we don't want to do broadcasting with only 1 multiply (no `+`), # because some arrays overload multiply with scalar. Avoiding # broadcasting saves compilation time. - ∂_mul_Δs[1] + :($(_∂s[1]) * $(Δs[1])) end - return :(@muladd $sumed_∂_mul_Δs) + return summed_∂_mul_Δs end """