From 23c531e55727e75c0fbfe80af5dd2c3c9af7420b Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace <51025924+simsurace@users.noreply.github.com> Date: Wed, 7 Feb 2024 13:46:25 +0100 Subject: [PATCH 1/6] Simplify rule according to feedback in #531 --- src/chainrules.jl | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/src/chainrules.jl b/src/chainrules.jl index 1f1a9a81b..321cf4f45 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -111,10 +111,9 @@ end function ChainRulesCore.rrule(s::Sinus, x::AbstractVector, y::AbstractVector) d = x - y - sind = sinpi.(d) - abs2_sind_r = abs2.(sind) ./ s.r .^ 2 + abs2_sind_r = (sinpi.(d) ./ s.r) .^ 2 val = sum(abs2_sind_r) - gradx = twoπ .* cospi.(d) .* sind ./ s.r .^ 2 + gradx = π .* cospi.(2 .* d) ./ s.r .^ 2 function evaluate_pullback(Δ::Any) r̄ = -2Δ .* abs2_sind_r ./ s.r s̄ = ChainRulesCore.Tangent{typeof(s)}(; r=r̄) @@ -136,7 +135,7 @@ function ChainRulesCore.rrule( for j in 1:n, i in 1:n xi = view(x, i, :) xj = view(x, j, :) - ds = twoπ .* Δ[i, j] .* sinpi.(xi .- xj) .* cospi.(xi .- xj) ./ d.r .^ 2 + ds = π .* Δ[i, j] .* cospi.(2 .* (xi .- xj)) ./ d.r .^ 2 r̄ .-= 2 .* Δ[i, j] .* sinpi.(xi .- xj) .^ 2 ./ d.r .^ 3 x̄[i, :] += ds x̄[j, :] -= ds @@ -147,8 +146,8 @@ function ChainRulesCore.rrule( xj = view(x, :, j) ds = twoπ .* Δ[i, j] .* sinpi.(xi .- xj) .* cospi.(xi .- xj) ./ d.r .^ 2 r̄ .-= 2 .* Δ[i, j] .* sinpi.(xi .- xj) .^ 2 ./ d.r .^ 3 - x̄[:, i] += ds - x̄[:, j] -= ds + x̄[:, i] .+= ds + x̄[:, j] .-= ds end end d̄ = ChainRulesCore.Tangent{typeof(d)}(; r=r̄) @@ -173,19 +172,19 @@ function ChainRulesCore.rrule( for j in 1:m, i in 1:n xi = view(x, i, :) yj = view(y, j, :) - ds = twoπ .* Δ[i, j] .* sinpi.(xi .- yj) .* cospi.(xi .- yj) ./ d.r .^ 2 + ds = π .* Δ[i, j] .* cospi.(2 .* (xi .- yj)) ./ d.r .^ 2 r̄ .-= 2 .* Δ[i, j] .* sinpi.(xi .- yj) .^ 2 ./ d.r .^ 3 - x̄[i, :] += ds - ȳ[j, :] -= ds + x̄[i, :] .+= ds + ȳ[j, :] .-= ds end elseif dims == 2 for j in 1:m, i in 1:n xi = view(x, :, i) yj = view(y, :, j) - ds = twoπ .* Δ[i, j] .* sinpi.(xi .- yj) .* cospi.(xi .- yj) ./ d.r .^ 2 + ds = π .* Δ[i, j] .* cospi.(2 .* (xi .- yj)) ./ d.r .^ 2 r̄ .-= 2 .* Δ[i, j] .* sinpi.(xi .- yj) .^ 2 ./ d.r .^ 3 - x̄[:, i] += ds - ȳ[:, j] -= ds + x̄[:, i] .+= ds + ȳ[:, j] .-= ds end end d̄ = ChainRulesCore.Tangent{typeof(d)}(; r=r̄) @@ -208,10 +207,10 @@ function ChainRulesCore.rrule( for i in 1:n xi = view(x, :, i) yi = view(y, :, i) - ds = twoπ .* Δ[i] .* sinpi.(xi .- yi) .* cospi.(xi .- yi) ./ d.r .^ 2 + ds = π .* Δ[i] .* cospi.(2 .* (xi .- yi)) ./ d.r .^ 2 r̄ .-= 2 .* Δ[i] .* sinpi.(xi .- yi) .^ 2 ./ d.r .^ 3 - x̄[:, i] += ds - ȳ[:, i] -= ds + x̄[:, i] .+= ds + ȳ[:, i] .-= ds end d̄ = ChainRulesCore.Tangent{typeof(d)}(; r=r̄) return NoTangent(), d̄, @thunk(project_x(x̄)), @thunk(project_y(ȳ)) From 63cc1189e9d4390c6a48e9d2caf45aa2d123b94b Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace <51025924+simsurace@users.noreply.github.com> Date: Wed, 7 Feb 2024 14:43:54 +0100 Subject: [PATCH 2/6] Correct rule Co-authored-by: David Widmann --- src/chainrules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/chainrules.jl b/src/chainrules.jl index 321cf4f45..d9261793b 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -113,7 +113,7 @@ function ChainRulesCore.rrule(s::Sinus, x::AbstractVector, y::AbstractVector) d = x - y abs2_sind_r = (sinpi.(d) ./ s.r) .^ 2 val = sum(abs2_sind_r) - gradx = π .* cospi.(2 .* d) ./ s.r .^ 2 + gradx = π .* sinpi.(2 .* d) ./ s.r .^ 2 function evaluate_pullback(Δ::Any) r̄ = -2Δ .* abs2_sind_r ./ s.r s̄ = ChainRulesCore.Tangent{typeof(s)}(; r=r̄) From d9259c7621e4452cc42480e90717e3a805caae81 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace <51025924+simsurace@users.noreply.github.com> Date: Wed, 7 Feb 2024 14:44:02 +0100 Subject: [PATCH 3/6] Correct rule Co-authored-by: David Widmann --- src/chainrules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/chainrules.jl b/src/chainrules.jl index d9261793b..e9298d761 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -135,7 +135,7 @@ function ChainRulesCore.rrule( for j in 1:n, i in 1:n xi = view(x, i, :) xj = view(x, j, :) - ds = π .* Δ[i, j] .* cospi.(2 .* (xi .- xj)) ./ d.r .^ 2 + ds = π .* Δ[i, j] .* sinpi.(2 .* (xi .- xj)) ./ d.r .^ 2 r̄ .-= 2 .* Δ[i, j] .* sinpi.(xi .- xj) .^ 2 ./ d.r .^ 3 x̄[i, :] += ds x̄[j, :] -= ds From 6be1104a4430a07cbf20c1c8dbb07a8c1a5bc2ca Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace <51025924+simsurace@users.noreply.github.com> Date: Wed, 7 Feb 2024 14:44:10 +0100 Subject: [PATCH 4/6] Correct rule Co-authored-by: David Widmann --- src/chainrules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/chainrules.jl b/src/chainrules.jl index e9298d761..5a1ec2688 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -172,7 +172,7 @@ function ChainRulesCore.rrule( for j in 1:m, i in 1:n xi = view(x, i, :) yj = view(y, j, :) - ds = π .* Δ[i, j] .* cospi.(2 .* (xi .- yj)) ./ d.r .^ 2 + ds = π .* Δ[i, j] .* sinpi.(2 .* (xi .- yj)) ./ d.r .^ 2 r̄ .-= 2 .* Δ[i, j] .* sinpi.(xi .- yj) .^ 2 ./ d.r .^ 3 x̄[i, :] .+= ds ȳ[j, :] .-= ds From 1f326f2f32a86f418776dfd8252a48bcece471ed Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace <51025924+simsurace@users.noreply.github.com> Date: Wed, 7 Feb 2024 14:44:19 +0100 Subject: [PATCH 5/6] Correct rule Co-authored-by: David Widmann --- src/chainrules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/chainrules.jl b/src/chainrules.jl index 5a1ec2688..785f684a6 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -181,7 +181,7 @@ function ChainRulesCore.rrule( for j in 1:m, i in 1:n xi = view(x, :, i) yj = view(y, :, j) - ds = π .* Δ[i, j] .* cospi.(2 .* (xi .- yj)) ./ d.r .^ 2 + ds = π .* Δ[i, j] .* sinpi.(2 .* (xi .- yj)) ./ d.r .^ 2 r̄ .-= 2 .* Δ[i, j] .* sinpi.(xi .- yj) .^ 2 ./ d.r .^ 3 x̄[:, i] .+= ds ȳ[:, j] .-= ds From 26a812a23da1403eb04c78a0de4a02a3bfa36f7d Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace <51025924+simsurace@users.noreply.github.com> Date: Wed, 7 Feb 2024 14:44:28 +0100 Subject: [PATCH 6/6] Correct rule Co-authored-by: David Widmann --- src/chainrules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/chainrules.jl b/src/chainrules.jl index 785f684a6..549373876 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -207,7 +207,7 @@ function ChainRulesCore.rrule( for i in 1:n xi = view(x, :, i) yi = view(y, :, i) - ds = π .* Δ[i] .* cospi.(2 .* (xi .- yi)) ./ d.r .^ 2 + ds = π .* Δ[i] .* sinpi.(2 .* (xi .- yi)) ./ d.r .^ 2 r̄ .-= 2 .* Δ[i] .* sinpi.(xi .- yi) .^ 2 ./ d.r .^ 3 x̄[:, i] .+= ds ȳ[:, i] .-= ds