Skip to content

Commit

Permalink
Fix periodic kernel AD (#531)
Browse files Browse the repository at this point in the history
  • Loading branch information
simsurace committed Feb 7, 2024
1 parent 3b7a2df commit 49049b1
Show file tree
Hide file tree
Showing 6 changed files with 126 additions and 7 deletions.
2 changes: 1 addition & 1 deletion src/KernelFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ export tensor, ⊗, compose

using Compat
using ChainRulesCore: ChainRulesCore, Tangent, ZeroTangent, NoTangent
using ChainRulesCore: @thunk, InplaceableThunk
using ChainRulesCore: @thunk, InplaceableThunk, ProjectTo, unthunk
using CompositionsBase
using Distances
using FillArrays
Expand Down
104 changes: 101 additions & 3 deletions src/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,15 +112,113 @@ end
function ChainRulesCore.rrule(s::Sinus, x::AbstractVector, y::AbstractVector)
d = x - y
sind = sinpi.(d)
abs2_sind_r = abs2.(sind) ./ s.r
abs2_sind_r = abs2.(sind) ./ s.r .^ 2
val = sum(abs2_sind_r)
gradx = twoπ .* cospi.(d) .* sind ./ (s.r .^ 2)
gradx = twoπ .* cospi.(d) .* sind ./ s.r .^ 2
function evaluate_pullback::Any)
return (r=-2Δ .* abs2_sind_r,), Δ * gradx, -Δ * gradx
= -2Δ .* abs2_sind_r ./ s.r
= ChainRulesCore.Tangent{typeof(s)}(; r=r̄)
return s̄, Δ * gradx, -Δ * gradx
end
return val, evaluate_pullback
end

function ChainRulesCore.rrule(
::typeof(Distances.pairwise), d::Sinus, x::AbstractMatrix; dims=2
)
project_x = ProjectTo(x)
function pairwise_pullback(z̄)
Δ = unthunk(z̄)
n = size(x, dims)
= collect(zero(x))
= zero(d.r)
if dims == 1
for j in 1:n, i in 1:n
xi = view(x, i, :)
xj = view(x, j, :)
ds = twoπ .* Δ[i, j] .* sinpi.(xi .- xj) .* cospi.(xi .- xj) ./ d.r .^ 2
.-= 2 .* Δ[i, j] .* sinpi.(xi .- xj) .^ 2 ./ d.r .^ 3
x̄[i, :] += ds
x̄[j, :] -= ds
end
elseif dims == 2
for j in 1:n, i in 1:n
xi = view(x, :, i)
xj = view(x, :, j)
ds = twoπ .* Δ[i, j] .* sinpi.(xi .- xj) .* cospi.(xi .- xj) ./ d.r .^ 2
.-= 2 .* Δ[i, j] .* sinpi.(xi .- xj) .^ 2 ./ d.r .^ 3
x̄[:, i] += ds
x̄[:, j] -= ds
end
end
= ChainRulesCore.Tangent{typeof(d)}(; r=r̄)
return NoTangent(), d̄, @thunk(project_x(x̄))
end
return Distances.pairwise(d, x; dims), pairwise_pullback
end

function ChainRulesCore.rrule(
::typeof(Distances.pairwise), d::Sinus, x::AbstractMatrix, y::AbstractMatrix; dims=2
)
project_x = ProjectTo(x)
project_y = ProjectTo(y)
function pairwise_pullback(z̄)
Δ = unthunk(z̄)
n = size(x, dims)
m = size(y, dims)
= collect(zero(x))
= collect(zero(y))
= zero(d.r)
if dims == 1
for j in 1:m, i in 1:n
xi = view(x, i, :)
yj = view(y, j, :)
ds = twoπ .* Δ[i, j] .* sinpi.(xi .- yj) .* cospi.(xi .- yj) ./ d.r .^ 2
.-= 2 .* Δ[i, j] .* sinpi.(xi .- yj) .^ 2 ./ d.r .^ 3
x̄[i, :] += ds
ȳ[j, :] -= ds
end
elseif dims == 2
for j in 1:m, i in 1:n
xi = view(x, :, i)
yj = view(y, :, j)
ds = twoπ .* Δ[i, j] .* sinpi.(xi .- yj) .* cospi.(xi .- yj) ./ d.r .^ 2
.-= 2 .* Δ[i, j] .* sinpi.(xi .- yj) .^ 2 ./ d.r .^ 3
x̄[:, i] += ds
ȳ[:, j] -= ds
end
end
= ChainRulesCore.Tangent{typeof(d)}(; r=r̄)
return NoTangent(), d̄, @thunk(project_x(x̄)), @thunk(project_y(ȳ))
end
return Distances.pairwise(d, x, y; dims), pairwise_pullback
end

function ChainRulesCore.rrule(
::typeof(Distances.colwise), d::Sinus, x::AbstractMatrix, y::AbstractMatrix
)
project_x = ProjectTo(x)
project_y = ProjectTo(y)
function colwise_pullback(z̄)
Δ = unthunk(z̄)
n = size(x, 2)
= collect(zero(x))
= collect(zero(y))
= zero(d.r)
for i in 1:n
xi = view(x, :, i)
yi = view(y, :, i)
ds = twoπ .* Δ[i] .* sinpi.(xi .- yi) .* cospi.(xi .- yi) ./ d.r .^ 2
.-= 2 .* Δ[i] .* sinpi.(xi .- yi) .^ 2 ./ d.r .^ 3
x̄[:, i] += ds
ȳ[:, i] -= ds
end
= ChainRulesCore.Tangent{typeof(d)}(; r=r̄)
return NoTangent(), d̄, @thunk(project_x(x̄)), @thunk(project_y(ȳ))
end
return Distances.colwise(d, x, y), colwise_pullback
end

## Reverse Rules for matrix wrappers

function ChainRulesCore.rrule(::Type{<:ColVecs}, X::AbstractMatrix)
Expand Down
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
[deps]
AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Expand Down
4 changes: 1 addition & 3 deletions test/basekernels/periodic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
TestUtils.test_interface(PeriodicKernel(; r=[0.9, 0.9]), ColVecs{Float64})
TestUtils.test_interface(PeriodicKernel(; r=[0.8, 0.7]), RowVecs{Float64})

# test_ADs(r->PeriodicKernel(r =exp.(r)), log.(r), ADs = [:ForwardDiff, :ReverseDiff])
# Undefined adjoint for Sinus metric, and failing randomly for ForwardDiff and ReverseDiff
@test_broken false
test_ADs(r -> PeriodicKernel(; r=exp.(r)), log.(r))
test_params(k, (r,))
end
19 changes: 19 additions & 0 deletions test/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,23 @@
compare_gradient(:Zygote, [x, y]) do xy
KernelFunctions.Sinus(r)(xy[1], xy[2])
end
@testset "rrules for Sinus(r=$r)" for r in (rand(3),)
dist = KernelFunctions.Sinus(r)
@testset "$type" for type in (Vector, SVector{3})
test_rrule(dist, type(rand(3)), type(rand(3)))
end
@testset "$type1, $type2" for type1 in (Matrix, SMatrix{3,2}),
type2 in (Matrix, SMatrix{3,4})

test_rrule(Distances.pairwise, dist, type1(rand(3, 2)); fkwargs=(dims=2,))
test_rrule(
Distances.pairwise,
dist,
type1(rand(3, 2)),
type2(rand(3, 4));
fkwargs=(dims=2,),
)
test_rrule(Distances.colwise, dist, type1(rand(3, 2)), type1(rand(3, 2)))
end
end
end
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
using KernelFunctions
using AxisArrays
using ChainRulesCore
using ChainRulesTestUtils
using Distances
using Documenter
using Functors: functor
Expand Down

0 comments on commit 49049b1

Please sign in to comment.