diff --git a/src/chainrules.jl b/src/chainrules.jl index 1f1a9a81b..549373876 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -111,10 +111,9 @@ end function ChainRulesCore.rrule(s::Sinus, x::AbstractVector, y::AbstractVector) d = x - y - sind = sinpi.(d) - abs2_sind_r = abs2.(sind) ./ s.r .^ 2 + abs2_sind_r = (sinpi.(d) ./ s.r) .^ 2 val = sum(abs2_sind_r) - gradx = twoπ .* cospi.(d) .* sind ./ s.r .^ 2 + gradx = π .* sinpi.(2 .* d) ./ s.r .^ 2 function evaluate_pullback(Δ::Any) r̄ = -2Δ .* abs2_sind_r ./ s.r s̄ = ChainRulesCore.Tangent{typeof(s)}(; r=r̄) @@ -136,7 +135,7 @@ function ChainRulesCore.rrule( for j in 1:n, i in 1:n xi = view(x, i, :) xj = view(x, j, :) - ds = twoπ .* Δ[i, j] .* sinpi.(xi .- xj) .* cospi.(xi .- xj) ./ d.r .^ 2 + ds = π .* Δ[i, j] .* sinpi.(2 .* (xi .- xj)) ./ d.r .^ 2 r̄ .-= 2 .* Δ[i, j] .* sinpi.(xi .- xj) .^ 2 ./ d.r .^ 3 x̄[i, :] += ds x̄[j, :] -= ds @@ -147,8 +146,8 @@ function ChainRulesCore.rrule( xj = view(x, :, j) ds = twoπ .* Δ[i, j] .* sinpi.(xi .- xj) .* cospi.(xi .- xj) ./ d.r .^ 2 r̄ .-= 2 .* Δ[i, j] .* sinpi.(xi .- xj) .^ 2 ./ d.r .^ 3 - x̄[:, i] += ds - x̄[:, j] -= ds + x̄[:, i] .+= ds + x̄[:, j] .-= ds end end d̄ = ChainRulesCore.Tangent{typeof(d)}(; r=r̄) @@ -173,19 +172,19 @@ function ChainRulesCore.rrule( for j in 1:m, i in 1:n xi = view(x, i, :) yj = view(y, j, :) - ds = twoπ .* Δ[i, j] .* sinpi.(xi .- yj) .* cospi.(xi .- yj) ./ d.r .^ 2 + ds = π .* Δ[i, j] .* sinpi.(2 .* (xi .- yj)) ./ d.r .^ 2 r̄ .-= 2 .* Δ[i, j] .* sinpi.(xi .- yj) .^ 2 ./ d.r .^ 3 - x̄[i, :] += ds - ȳ[j, :] -= ds + x̄[i, :] .+= ds + ȳ[j, :] .-= ds end elseif dims == 2 for j in 1:m, i in 1:n xi = view(x, :, i) yj = view(y, :, j) - ds = twoπ .* Δ[i, j] .* sinpi.(xi .- yj) .* cospi.(xi .- yj) ./ d.r .^ 2 + ds = π .* Δ[i, j] .* sinpi.(2 .* (xi .- yj)) ./ d.r .^ 2 r̄ .-= 2 .* Δ[i, j] .* sinpi.(xi .- yj) .^ 2 ./ d.r .^ 3 - x̄[:, i] += ds - ȳ[:, j] -= ds + x̄[:, i] .+= ds + ȳ[:, j] .-= ds end end d̄ = ChainRulesCore.Tangent{typeof(d)}(; r=r̄) @@ -208,10 +207,10 @@ function ChainRulesCore.rrule( for i in 1:n xi = view(x, :, i) yi = view(y, :, i) - ds = twoπ .* Δ[i] .* sinpi.(xi .- yi) .* cospi.(xi .- yi) ./ d.r .^ 2 + ds = π .* Δ[i] .* sinpi.(2 .* (xi .- yi)) ./ d.r .^ 2 r̄ .-= 2 .* Δ[i] .* sinpi.(xi .- yi) .^ 2 ./ d.r .^ 3 - x̄[:, i] += ds - ȳ[:, i] -= ds + x̄[:, i] .+= ds + ȳ[:, i] .-= ds end d̄ = ChainRulesCore.Tangent{typeof(d)}(; r=r̄) return NoTangent(), d̄, @thunk(project_x(x̄)), @thunk(project_y(ȳ))