diff --git a/src/KernelFunctions.jl b/src/KernelFunctions.jl index 63205b5bf..80711b4ec 100644 --- a/src/KernelFunctions.jl +++ b/src/KernelFunctions.jl @@ -48,7 +48,7 @@ export tensor, ⊗, compose using Compat using ChainRulesCore: ChainRulesCore, Tangent, ZeroTangent, NoTangent -using ChainRulesCore: @thunk, InplaceableThunk +using ChainRulesCore: @thunk, InplaceableThunk, ProjectTo, unthunk using CompositionsBase using Distances using FillArrays diff --git a/src/chainrules.jl b/src/chainrules.jl index d31ec97d1..1f1a9a81b 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -112,15 +112,113 @@ end function ChainRulesCore.rrule(s::Sinus, x::AbstractVector, y::AbstractVector) d = x - y sind = sinpi.(d) - abs2_sind_r = abs2.(sind) ./ s.r + abs2_sind_r = abs2.(sind) ./ s.r .^ 2 val = sum(abs2_sind_r) - gradx = twoπ .* cospi.(d) .* sind ./ (s.r .^ 2) + gradx = twoπ .* cospi.(d) .* sind ./ s.r .^ 2 function evaluate_pullback(Δ::Any) - return (r=-2Δ .* abs2_sind_r,), Δ * gradx, -Δ * gradx + r̄ = -2Δ .* abs2_sind_r ./ s.r + s̄ = ChainRulesCore.Tangent{typeof(s)}(; r=r̄) + return s̄, Δ * gradx, -Δ * gradx end return val, evaluate_pullback end +function ChainRulesCore.rrule( + ::typeof(Distances.pairwise), d::Sinus, x::AbstractMatrix; dims=2 +) + project_x = ProjectTo(x) + function pairwise_pullback(z̄) + Δ = unthunk(z̄) + n = size(x, dims) + x̄ = collect(zero(x)) + r̄ = zero(d.r) + if dims == 1 + for j in 1:n, i in 1:n + xi = view(x, i, :) + xj = view(x, j, :) + ds = twoπ .* Δ[i, j] .* sinpi.(xi .- xj) .* cospi.(xi .- xj) ./ d.r .^ 2 + r̄ .-= 2 .* Δ[i, j] .* sinpi.(xi .- xj) .^ 2 ./ d.r .^ 3 + x̄[i, :] += ds + x̄[j, :] -= ds + end + elseif dims == 2 + for j in 1:n, i in 1:n + xi = view(x, :, i) + xj = view(x, :, j) + ds = twoπ .* Δ[i, j] .* sinpi.(xi .- xj) .* cospi.(xi .- xj) ./ d.r .^ 2 + r̄ .-= 2 .* Δ[i, j] .* sinpi.(xi .- xj) .^ 2 ./ d.r .^ 3 + x̄[:, i] += ds + x̄[:, j] -= ds + end + end + d̄ = ChainRulesCore.Tangent{typeof(d)}(; r=r̄) + return NoTangent(), d̄, @thunk(project_x(x̄)) + end + return Distances.pairwise(d, x; dims), pairwise_pullback +end + +function ChainRulesCore.rrule( + ::typeof(Distances.pairwise), d::Sinus, x::AbstractMatrix, y::AbstractMatrix; dims=2 +) + project_x = ProjectTo(x) + project_y = ProjectTo(y) + function pairwise_pullback(z̄) + Δ = unthunk(z̄) + n = size(x, dims) + m = size(y, dims) + x̄ = collect(zero(x)) + ȳ = collect(zero(y)) + r̄ = zero(d.r) + if dims == 1 + for j in 1:m, i in 1:n + xi = view(x, i, :) + yj = view(y, j, :) + ds = twoπ .* Δ[i, j] .* sinpi.(xi .- yj) .* cospi.(xi .- yj) ./ d.r .^ 2 + r̄ .-= 2 .* Δ[i, j] .* sinpi.(xi .- yj) .^ 2 ./ d.r .^ 3 + x̄[i, :] += ds + ȳ[j, :] -= ds + end + elseif dims == 2 + for j in 1:m, i in 1:n + xi = view(x, :, i) + yj = view(y, :, j) + ds = twoπ .* Δ[i, j] .* sinpi.(xi .- yj) .* cospi.(xi .- yj) ./ d.r .^ 2 + r̄ .-= 2 .* Δ[i, j] .* sinpi.(xi .- yj) .^ 2 ./ d.r .^ 3 + x̄[:, i] += ds + ȳ[:, j] -= ds + end + end + d̄ = ChainRulesCore.Tangent{typeof(d)}(; r=r̄) + return NoTangent(), d̄, @thunk(project_x(x̄)), @thunk(project_y(ȳ)) + end + return Distances.pairwise(d, x, y; dims), pairwise_pullback +end + +function ChainRulesCore.rrule( + ::typeof(Distances.colwise), d::Sinus, x::AbstractMatrix, y::AbstractMatrix +) + project_x = ProjectTo(x) + project_y = ProjectTo(y) + function colwise_pullback(z̄) + Δ = unthunk(z̄) + n = size(x, 2) + x̄ = collect(zero(x)) + ȳ = collect(zero(y)) + r̄ = zero(d.r) + for i in 1:n + xi = view(x, :, i) + yi = view(y, :, i) + ds = twoπ .* Δ[i] .* sinpi.(xi .- yi) .* cospi.(xi .- yi) ./ d.r .^ 2 + r̄ .-= 2 .* Δ[i] .* sinpi.(xi .- yi) .^ 2 ./ d.r .^ 3 + x̄[:, i] += ds + ȳ[:, i] -= ds + end + d̄ = ChainRulesCore.Tangent{typeof(d)}(; r=r̄) + return NoTangent(), d̄, @thunk(project_x(x̄)), @thunk(project_y(ȳ)) + end + return Distances.colwise(d, x, y), colwise_pullback +end + ## Reverse Rules for matrix wrappers function ChainRulesCore.rrule(::Type{<:ColVecs}, X::AbstractMatrix) diff --git a/test/Project.toml b/test/Project.toml index 4df8b9cdb..7ef690d1d 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,5 +1,7 @@ [deps] AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" diff --git a/test/basekernels/periodic.jl b/test/basekernels/periodic.jl index 046687962..540947b1b 100644 --- a/test/basekernels/periodic.jl +++ b/test/basekernels/periodic.jl @@ -15,8 +15,6 @@ TestUtils.test_interface(PeriodicKernel(; r=[0.9, 0.9]), ColVecs{Float64}) TestUtils.test_interface(PeriodicKernel(; r=[0.8, 0.7]), RowVecs{Float64}) - # test_ADs(r->PeriodicKernel(r =exp.(r)), log.(r), ADs = [:ForwardDiff, :ReverseDiff]) - # Undefined adjoint for Sinus metric, and failing randomly for ForwardDiff and ReverseDiff - @test_broken false + test_ADs(r -> PeriodicKernel(; r=exp.(r)), log.(r)) test_params(k, (r,)) end diff --git a/test/chainrules.jl b/test/chainrules.jl index 5c3c5766b..f2a4ae3a5 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -19,4 +19,23 @@ compare_gradient(:Zygote, [x, y]) do xy KernelFunctions.Sinus(r)(xy[1], xy[2]) end + @testset "rrules for Sinus(r=$r)" for r in (rand(3),) + dist = KernelFunctions.Sinus(r) + @testset "$type" for type in (Vector, SVector{3}) + test_rrule(dist, type(rand(3)), type(rand(3))) + end + @testset "$type1, $type2" for type1 in (Matrix, SMatrix{3,2}), + type2 in (Matrix, SMatrix{3,4}) + + test_rrule(Distances.pairwise, dist, type1(rand(3, 2)); fkwargs=(dims=2,)) + test_rrule( + Distances.pairwise, + dist, + type1(rand(3, 2)), + type2(rand(3, 4)); + fkwargs=(dims=2,), + ) + test_rrule(Distances.colwise, dist, type1(rand(3, 2)), type1(rand(3, 2))) + end + end end diff --git a/test/runtests.jl b/test/runtests.jl index caf43cb91..e054b992a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,7 @@ using KernelFunctions using AxisArrays +using ChainRulesCore +using ChainRulesTestUtils using Distances using Documenter using Functors: functor