Skip to content

Commit

Permalink
Simplify rule according to feedback in #531 (#549)
Browse files Browse the repository at this point in the history
Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
  • Loading branch information
simsurace and devmotion committed Feb 7, 2024
1 parent 49049b1 commit 16a9828
Showing 1 changed file with 14 additions and 15 deletions.
29 changes: 14 additions & 15 deletions src/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,9 @@ end

function ChainRulesCore.rrule(s::Sinus, x::AbstractVector, y::AbstractVector)
d = x - y
sind = sinpi.(d)
abs2_sind_r = abs2.(sind) ./ s.r .^ 2
abs2_sind_r = (sinpi.(d) ./ s.r) .^ 2
val = sum(abs2_sind_r)
gradx = twoπ .* cospi.(d) .* sind ./ s.r .^ 2
gradx = π .* sinpi.(2 .* d) ./ s.r .^ 2
function evaluate_pullback::Any)
= -2Δ .* abs2_sind_r ./ s.r
= ChainRulesCore.Tangent{typeof(s)}(; r=r̄)
Expand All @@ -136,7 +135,7 @@ function ChainRulesCore.rrule(
for j in 1:n, i in 1:n
xi = view(x, i, :)
xj = view(x, j, :)
ds = twoπ .* Δ[i, j] .* sinpi.(xi .- xj) .* cospi.(xi .- xj) ./ d.r .^ 2
ds = π .* Δ[i, j] .* sinpi.(2 .* (xi .- xj)) ./ d.r .^ 2
.-= 2 .* Δ[i, j] .* sinpi.(xi .- xj) .^ 2 ./ d.r .^ 3
x̄[i, :] += ds
x̄[j, :] -= ds
Expand All @@ -147,8 +146,8 @@ function ChainRulesCore.rrule(
xj = view(x, :, j)
ds = twoπ .* Δ[i, j] .* sinpi.(xi .- xj) .* cospi.(xi .- xj) ./ d.r .^ 2
.-= 2 .* Δ[i, j] .* sinpi.(xi .- xj) .^ 2 ./ d.r .^ 3
x̄[:, i] += ds
x̄[:, j] -= ds
x̄[:, i] .+= ds
x̄[:, j] .-= ds
end
end
= ChainRulesCore.Tangent{typeof(d)}(; r=r̄)
Expand All @@ -173,19 +172,19 @@ function ChainRulesCore.rrule(
for j in 1:m, i in 1:n
xi = view(x, i, :)
yj = view(y, j, :)
ds = twoπ .* Δ[i, j] .* sinpi.(xi .- yj) .* cospi.(xi .- yj) ./ d.r .^ 2
ds = π .* Δ[i, j] .* sinpi.(2 .* (xi .- yj)) ./ d.r .^ 2
.-= 2 .* Δ[i, j] .* sinpi.(xi .- yj) .^ 2 ./ d.r .^ 3
x̄[i, :] += ds
ȳ[j, :] -= ds
x̄[i, :] .+= ds
ȳ[j, :] .-= ds
end
elseif dims == 2
for j in 1:m, i in 1:n
xi = view(x, :, i)
yj = view(y, :, j)
ds = twoπ .* Δ[i, j] .* sinpi.(xi .- yj) .* cospi.(xi .- yj) ./ d.r .^ 2
ds = π .* Δ[i, j] .* sinpi.(2 .* (xi .- yj)) ./ d.r .^ 2
.-= 2 .* Δ[i, j] .* sinpi.(xi .- yj) .^ 2 ./ d.r .^ 3
x̄[:, i] += ds
ȳ[:, j] -= ds
x̄[:, i] .+= ds
ȳ[:, j] .-= ds
end
end
= ChainRulesCore.Tangent{typeof(d)}(; r=r̄)
Expand All @@ -208,10 +207,10 @@ function ChainRulesCore.rrule(
for i in 1:n
xi = view(x, :, i)
yi = view(y, :, i)
ds = twoπ .* Δ[i] .* sinpi.(xi .- yi) .* cospi.(xi .- yi) ./ d.r .^ 2
ds = π .* Δ[i] .* sinpi.(2 .* (xi .- yi)) ./ d.r .^ 2
.-= 2 .* Δ[i] .* sinpi.(xi .- yi) .^ 2 ./ d.r .^ 3
x̄[:, i] += ds
ȳ[:, i] -= ds
x̄[:, i] .+= ds
ȳ[:, i] .-= ds
end
= ChainRulesCore.Tangent{typeof(d)}(; r=r̄)
return NoTangent(), d̄, @thunk(project_x(x̄)), @thunk(project_y(ȳ))
Expand Down

0 comments on commit 16a9828

Please sign in to comment.