Skip to content

Commit

Permalink
Don't thunk if only one adjoint exists (#235)
Browse files Browse the repository at this point in the history
* Don't thunk where only one adjoint exists

* Bump ChainRulesTestUtils compat

* Increment version number
  • Loading branch information
sethaxen committed Jul 13, 2020
1 parent d3cd83e commit 9a364d6
Show file tree
Hide file tree
Showing 8 changed files with 36 additions and 40 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "0.7.8"
version = "0.7.9"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -12,7 +12,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
ChainRulesCore = "0.9"
ChainRulesTestUtils = "0.4.2"
ChainRulesTestUtils = "0.4.2, 0.5"
Compat = "3"
FiniteDifferences = "0.10"
Reexport = "0.2"
Expand Down
8 changes: 4 additions & 4 deletions src/rulesets/Base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@

function rrule(::typeof(reshape), A::AbstractArray, dims::Tuple{Vararg{Int}})
function reshape_pullback(Ȳ)
return (NO_FIELDS, @thunk(reshape(Ȳ, dims)), DoesNotExist())
return (NO_FIELDS, reshape(Ȳ, dims), DoesNotExist())
end
return reshape(A, dims), reshape_pullback
end

function rrule(::typeof(reshape), A::AbstractArray, dims::Int...)
function reshape_pullback(Ȳ)
∂A = @thunk(reshape(Ȳ, dims))
∂A = reshape(Ȳ, dims)
return (NO_FIELDS, ∂A, fill(DoesNotExist(), length(dims))...)
end
return reshape(A, dims...), reshape_pullback
Expand Down Expand Up @@ -63,14 +63,14 @@ end

function rrule(::typeof(fill), value::Any, dims::Tuple{Vararg{Int}})
function fill_pullback(Ȳ)
return (NO_FIELDS, @thunk(sum(Ȳ)), DoesNotExist())
return (NO_FIELDS, sum(Ȳ), DoesNotExist())
end
return fill(value, dims), fill_pullback
end

function rrule(::typeof(fill), value::Any, dims::Int...)
function fill_pullback(Ȳ)
return (NO_FIELDS, @thunk(sum(Ȳ)), ntuple(_->DoesNotExist(), length(dims))...)
return (NO_FIELDS, sum(Ȳ), ntuple(_->DoesNotExist(), length(dims))...)
end
return fill(value, dims), fill_pullback
end
4 changes: 2 additions & 2 deletions src/rulesets/Base/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ function rrule(::typeof(sum), x::AbstractArray{T}; dims=:) where {T<:Number}
y = sum(sum, x; dims=dims)
function sum_pullback(ȳ)
# broadcasting the two works out the size no-matter `dims`
= @thunk broadcast(x, ȳ) do xi, ȳi
= broadcast(x, ȳ) do xi, ȳi
ȳi
end
return (NO_FIELDS, x̄)
Expand Down Expand Up @@ -44,7 +44,7 @@ function rrule(
) where {T<:Union{Real,Complex}}
y = sum(abs2, x; dims=dims)
function sum_abs2_pullback(ȳ)
return (NO_FIELDS, DoesNotExist(), @thunk(2 .* real.(ȳ) .* x))
return (NO_FIELDS, DoesNotExist(), 2 .* real.(ȳ) .* x)
end
return y, sum_abs2_pullback
end
2 changes: 1 addition & 1 deletion src/rulesets/LinearAlgebra/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ function rrule(::typeof(BLAS.asum), n, X, incx)
function asum_pullback(ΔΩ)
# BLAS.scal! requires s has the same eltype as X
s = eltype(X)(real(ΔΩ))
∂X = @thunk scal!(n, s, blascopy!(n, _signcomp.(X), incx, _zeros(X), incx), incx)
∂X = scal!(n, s, blascopy!(n, _signcomp.(X), incx, _zeros(X), incx), incx)
return (NO_FIELDS, DoesNotExist(), ∂X, DoesNotExist())
end
return Ω, asum_pullback
Expand Down
2 changes: 1 addition & 1 deletion src/rulesets/LinearAlgebra/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ function rrule(::typeof(tr), x)
# This should really be a FillArray
# see https://github.com/JuliaDiff/ChainRules.jl/issues/46
function tr_pullback(ΔΩ)
return (NO_FIELDS, @thunk Diagonal(fill(ΔΩ, size(x, 1))))
return (NO_FIELDS, Diagonal(fill(ΔΩ, size(x, 1))))
end
return tr(x), tr_pullback
end
Expand Down
8 changes: 4 additions & 4 deletions src/rulesets/LinearAlgebra/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ using LinearAlgebra.BLAS: gemv, gemv!, gemm!, trsm!, axpy!, ger!
function rrule(::typeof(svd), X::AbstractMatrix{<:Real})
F = svd(X)
function svd_pullback::Composite)
∂X = @thunk(svd_rev(F, Ȳ.U, Ȳ.S, Ȳ.V))
∂X = svd_rev(F, Ȳ.U, Ȳ.S, Ȳ.V)
return (NO_FIELDS, ∂X)
end
return F, svd_pullback
Expand Down Expand Up @@ -73,9 +73,9 @@ function rrule(::typeof(cholesky), X::AbstractMatrix{<:Real})
F = cholesky(X)
function cholesky_pullback::Composite)
∂X = if F.uplo === 'U'
@thunk(chol_blocked_rev.U, F.U, 25, true))
chol_blocked_rev.U, F.U, 25, true)
else
@thunk(chol_blocked_rev.L, F.L, 25, false))
chol_blocked_rev.L, F.L, 25, false)
end
return (NO_FIELDS, ∂X)
end
Expand All @@ -85,7 +85,7 @@ end
function rrule(::typeof(getproperty), F::T, x::Symbol) where T <: Cholesky
function getproperty_cholesky_pullback(Ȳ)
C = Composite{T}
∂F = @thunk if x === :U
∂F = if x === :U
if F.uplo === 'U'
C(U=UpperTriangular(Ȳ),)
else
Expand Down
42 changes: 20 additions & 22 deletions src/rulesets/LinearAlgebra/structured.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@ end

function rrule(::typeof(diag), A::AbstractMatrix)
function diag_pullback(ȳ)
return (NO_FIELDS, @thunk(Diagonal(ȳ)))
return (NO_FIELDS, Diagonal(ȳ))
end
return diag(A), diag_pullback
end
if VERSION v"1.3"
function rrule(::typeof(diag), A::AbstractMatrix, k::Integer)
function diag_pullback(ȳ)
return (NO_FIELDS, @thunk(diagm(size(A)..., k =>)), DoesNotExist())
return (NO_FIELDS, diagm(size(A)..., k => ȳ), DoesNotExist())
end
return diag(A, k), diag_pullback
end
Expand All @@ -48,11 +48,9 @@ function rrule(::typeof(diagm), kv::Pair{<:Integer,<:AbstractVector}...)
end

function _diagm_back(p, ȳ)
return Thunk() do
k, v = p
d = diag(ȳ, k)[1:length(v)] # handle if diagonal was smaller than matrix
return Composite{typeof(p)}(second = d)
end
k, v = p
d = diag(ȳ, k)[1:length(v)] # handle if diagonal was smaller than matrix
return Composite{typeof(p)}(second = d)
end

function rrule(::typeof(*), D::Diagonal{<:Real}, V::AbstractVector{<:Real})
Expand All @@ -73,7 +71,7 @@ end
function rrule(T::Type{<:LinearAlgebra.HermOrSym}, A::AbstractMatrix, uplo)
Ω = T(A, uplo)
function HermOrSym_pullback(ΔΩ)
return (NO_FIELDS, @thunk(_symherm_back(T, ΔΩ, Ω.uplo)), DoesNotExist())
return (NO_FIELDS, _symherm_back(T, ΔΩ, Ω.uplo), DoesNotExist())
end
return Ω, HermOrSym_pullback
end
Expand Down Expand Up @@ -149,28 +147,28 @@ end
# ✖️✖️✖️TODO: Deal with complex-valued arrays as well
function rrule(::Type{<:Adjoint}, A::AbstractMatrix{<:Real})
function Adjoint_pullback(ȳ)
return (NO_FIELDS, @thunk(adjoint(ȳ)))
return (NO_FIELDS, adjoint(ȳ))
end
return Adjoint(A), Adjoint_pullback
end

function rrule(::Type{<:Adjoint}, A::AbstractVector{<:Real})
function Adjoint_pullback(ȳ)
return (NO_FIELDS, @thunk(vec(adjoint(ȳ))))
return (NO_FIELDS, vec(adjoint(ȳ)))
end
return Adjoint(A), Adjoint_pullback
end

function rrule(::typeof(adjoint), A::AbstractMatrix{<:Real})
function adjoint_pullback(ȳ)
return (NO_FIELDS, @thunk(adjoint(ȳ)))
return (NO_FIELDS, adjoint(ȳ))
end
return adjoint(A), adjoint_pullback
end

function rrule(::typeof(adjoint), A::AbstractVector{<:Real})
function adjoint_pullback(ȳ)
return (NO_FIELDS, @thunk(vec(adjoint(ȳ))))
return (NO_FIELDS, vec(adjoint(ȳ)))
end
return adjoint(A), adjoint_pullback
end
Expand All @@ -181,28 +179,28 @@ end

function rrule(::Type{<:Transpose}, A::AbstractMatrix)
function Transpose_pullback(ȳ)
return (NO_FIELDS, @thunk transpose(ȳ))
return (NO_FIELDS, transpose(ȳ))
end
return Transpose(A), Transpose_pullback
end

function rrule(::Type{<:Transpose}, A::AbstractVector)
function Transpose_pullback(ȳ)
return (NO_FIELDS, @thunk vec(transpose(ȳ)))
return (NO_FIELDS, vec(transpose(ȳ)))
end
return Transpose(A), Transpose_pullback
end

function rrule(::typeof(transpose), A::AbstractMatrix)
function transpose_pullback(ȳ)
return (NO_FIELDS, @thunk transpose(ȳ))
return (NO_FIELDS, transpose(ȳ))
end
return transpose(A), transpose_pullback
end

function rrule(::typeof(transpose), A::AbstractVector)
function transpose_pullback(ȳ)
return (NO_FIELDS, @thunk vec(transpose(ȳ)))
return (NO_FIELDS, vec(transpose(ȳ)))
end
return transpose(A), transpose_pullback
end
Expand All @@ -213,40 +211,40 @@ end

function rrule(::Type{<:UpperTriangular}, A::AbstractMatrix)
function UpperTriangular_pullback(ȳ)
return (NO_FIELDS, @thunk Matrix(ȳ))
return (NO_FIELDS, Matrix(ȳ))
end
return UpperTriangular(A), UpperTriangular_pullback
end

function rrule(::Type{<:LowerTriangular}, A::AbstractMatrix)
function LowerTriangular_pullback(ȳ)
return (NO_FIELDS, @thunk Matrix(ȳ))
return (NO_FIELDS, Matrix(ȳ))
end
return LowerTriangular(A), LowerTriangular_pullback
end

function rrule(::typeof(triu), A::AbstractMatrix, k::Integer)
function triu_pullback(ȳ)
return (NO_FIELDS, @thunk(triu(ȳ, k)), DoesNotExist())
return (NO_FIELDS, triu(ȳ, k), DoesNotExist())
end
return triu(A, k), triu_pullback
end
function rrule(::typeof(triu), A::AbstractMatrix)
function triu_pullback(ȳ)
return (NO_FIELDS, @thunk triu(ȳ))
return (NO_FIELDS, triu(ȳ))
end
return triu(A), triu_pullback
end

function rrule(::typeof(tril), A::AbstractMatrix, k::Integer)
function tril_pullback(ȳ)
return (NO_FIELDS, @thunk(tril(ȳ, k)), DoesNotExist())
return (NO_FIELDS, tril(ȳ, k), DoesNotExist())
end
return tril(A, k), tril_pullback
end
function rrule(::typeof(tril), A::AbstractMatrix)
function tril_pullback(ȳ)
return (NO_FIELDS, @thunk tril(ȳ))
return (NO_FIELDS, tril(ȳ))
end
return tril(A), tril_pullback
end
6 changes: 2 additions & 4 deletions src/rulesets/Statistics/statistics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,8 @@ function rrule(::typeof(mean), x::AbstractArray{<:Real}; dims=:)
y_sum, sum_pullback = rrule(sum, x; dims=dims)
n = _denom(x, dims)
function mean_pullback(ȳ)
∂x = Thunk() do
_, ∂sum_x = sum_pullback(ȳ)
extern(∂sum_x) / n
end
_, ∂sum_x = sum_pullback(ȳ)
∂x = extern(∂sum_x) / n
return (NO_FIELDS, ∂x)
end
return y_sum / n, mean_pullback
Expand Down

2 comments on commit 9a364d6

@sethaxen
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/17870

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.7.9 -m "<description of version>" 9a364d6419a8cd74b7b449f69b2470f540bc3d9d
git push origin v0.7.9

Please sign in to comment.