Skip to content

Commit

Permalink
Gradient for dot(x,A,y) (#261)
Browse files Browse the repository at this point in the history
* rrule for 3-arg dot, take 1

* fix complex

* Update test/rulesets/LinearAlgebra/dense.jl

Co-authored-by: willtebbutt <wt0881@my.bristol.ac.uk>

* add method for Diagonal

* restrict to AbstractVector{<:Number} etc

* add PermutedDimsArray tests

* v0.7.30

Co-authored-by: Michael Abbott <me@escbook>
Co-authored-by: willtebbutt <wt0881@my.bristol.ac.uk>
  • Loading branch information
3 people committed Oct 23, 2020
1 parent 27516a9 commit cf36ac6
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 2 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "0.7.29"
version = "0.7.30"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
29 changes: 29 additions & 0 deletions src/rulesets/LinearAlgebra/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,35 @@ function rrule(::typeof(dot), x, y)
return dot(x, y), dot_pullback
end

function frule((_, Δx, ΔA, Δy), ::typeof(dot), x::AbstractVector{<:Number}, A::AbstractMatrix{<:Number}, y::AbstractVector{<:Number})
return dot(x, A, y), dot(Δx, A, y) + dot(x, ΔA, y) + dot(x, A, Δy)
end

function rrule(::typeof(dot), x::AbstractVector{<:Number}, A::AbstractMatrix{<:Number}, y::AbstractVector{<:Number})
Ay = A * y
z = adjoint(x) * Ay
function dot_pullback(ΔΩ)
dx = @thunk conj(ΔΩ) .* Ay
dA = @thunk ΔΩ .* x .* adjoint(y)
dy = @thunk ΔΩ .* (adjoint(A) * x)
return (NO_FIELDS, dx, dA, dy)
end
dot_pullback(::Zero) = (NO_FIELDS, Zero(), Zero(), Zero())
return z, dot_pullback
end

function rrule(::typeof(dot), x::AbstractVector{<:Number}, A::Diagonal{<:Number}, y::AbstractVector{<:Number})
z = dot(x,A,y)
function dot_pullback(ΔΩ)
dx = @thunk conj(ΔΩ) .* A.diag .* y # A*y is this broadcast, can be fused
dA = @thunk Diagonal(ΔΩ .* x .* conj(y)) # calculate N not N^2 elements
dy = @thunk ΔΩ .* conj.(A.diag) .* x
return (NO_FIELDS, dx, dA, dy)
end
dot_pullback(::Zero) = (NO_FIELDS, Zero(), Zero(), Zero())
return z, dot_pullback
end

#####
##### `cross`
#####
Expand Down
17 changes: 17 additions & 0 deletions test/rulesets/LinearAlgebra/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,23 @@
frule_test(dot, (x, ẋ), (y, ẏ))
rrule_test(dot, randn(T), (x, x̄), (y, ȳ))
end
@testset "3-arg dot, Array{$T}" for T in (Float64, ComplexF64)
M, N = 3, 4
x, A, y = randn(T, M), randn(T, M, N), randn(T, N)
ẋ, Adot, ẏ = randn(T, M), randn(T, M, N), randn(T, N)
x̄, Abar, ȳ = similar(x), similar(A), similar(y)
frule_test(dot, (x, ẋ), (A, Adot), (y, ẏ))
rrule_test(dot, randn(T), (x, x̄), (A, Abar), (y, ȳ))
end
permuteddimsarray(A) = PermutedDimsArray(A, (2,1))
@testset "3-arg dot, $F{$T}" for T in (Float32, ComplexF32), F in (adjoint, permuteddimsarray)
M, N = 3, 4
x, A, y = rand(T, M), F(rand(T, N, M)), rand(T, N)
ẋ, Adot, ẏ = rand(T, M), F(rand(T, N, M)), rand(T, N)
x̄, Abar, ȳ = similar(x), F(rand(T, N, M)), similar(y)
frule_test(dot, (x, ẋ), (A, Adot), (y, ẏ); rtol=1f-3)
rrule_test(dot, float(rand(T)), (x, x̄), (A, Abar), (y, ȳ); rtol=1f-3)
end
end
@testset "cross" begin
@testset "frule" begin
Expand Down
7 changes: 6 additions & 1 deletion test/rulesets/LinearAlgebra/structured.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,12 @@
comp = Composite{typeof(res)}(; diag=10*res.diag) # this is the structure of Diagonal
@test pb(comp) == (NO_FIELDS, [10, 40])
end

@testset "dot(x, ::Diagonal, y)" begin
N = 4
x, d, y = randn(ComplexF64, N), randn(ComplexF64, N), randn(ComplexF64, N)
D = Diagonal(d)
rrule_test(dot, rand(ComplexF64), (x,similar(x)), (D,similar(D)), (y,similar(y)))
end
@testset "::Diagonal * ::AbstractVector" begin
N = 3
rrule_test(
Expand Down

2 comments on commit cf36ac6

@willtebbutt
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/23507

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.7.30 -m "<description of version>" cf36ac6a2a245adbd565a7815c3e33e05e2adc65
git push origin v0.7.30

Please sign in to comment.