slow forward differentiation with Mooncake #854

@rveltz

Hi,

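I would like to speed up a derivative that is roughly 30x slower than the function being differentiated (median of about 2.25 μs versus 68 ns in the benchmarks below), and that allocates while the function itself does not. Hopefully, there is an easy fix.
I hope my code qualifies as an MWE.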

using BenchmarkTools
import AssociatedLegendrePolynomials as ALP
using DifferentiationInterface
import Mooncake

struct LegendreBuffer{T1, T2, T3}
    Λ::T1
    Yₗₘ::T2
    coeff::T3
end

function build_buffer(lmax, θ::T) where T
    Λ = ALP.λlm(0:lmax, 0:lmax, θ)
    Yₗₘ = zeros(T, 45)
    coeff = ALP.LegendreSphereCoeff{T}(lmax)
    work = ALP.Work(coeff, Λ, θ)

    LegendreBuffer(Λ, Yₗₘ, work)
end

function f(θ, Buffer, ::Val{lmax}) where {lmax}
    ALP._legendre_impl!(Buffer.coeff, Buffer.Λ, lmax, lmax, cos(θ))
end

legendre_buffer = build_buffer(8, 0.1)
@benchmark f(0.1, legendre_buffer, Val(8))

On Julia 1.11, this gives:

julia> @benchmark f(0.1, legendre_buffer, Val(8))
BenchmarkTools.Trial: 10000 samples with 977 evaluations per sample.
 Range (min … max):  67.468 ns … 98.730 ns  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     67.682 ns              ┊ GC (median):    0.00%
 Time  (mean ± σ):   68.795 ns ±  3.132 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

  █▇▂           ▃▂               ▁                            ▁
  ███▇▇▅▅▃▃▄▄▃▃▆███▇▇▇▇▆▆▆▆▆▆▇▆▇██▇▇▇▆▆▅▅▆▇▇▆▆▆▆▅▅▆▆▆▆▄▆▆▆▅▅▅ █
  67.5 ns      Histogram: log(frequency) by time      82.1 ns <

 Memory estimate: 0 bytes, allocs estimate: 0.

Now, here is the slow derivative:

backend = AutoMooncakeForward()
prep = prepare_derivative(f, backend, 0.1, Constant(legendre_buffer), Constant(Val(8)))
@benchmark DifferentiationInterface.derivative(f, $prep, $backend, 0.1, Constant(legendre_buffer), Constant(Val(8)))

which gives

julia> @benchmark DifferentiationInterface.derivative(f, $prep, $backend, 0.1, Constant(legendre_buffer), Constant(Val(8)))
BenchmarkTools.Trial: 10000 samples with 9 evaluations per sample.
 Range (min … max):  2.005 μs … 999.606 μs  ┊ GC (min … max): 0.00% … 99.52%
 Time  (median):     2.250 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   2.495 μs ±  11.746 μs  ┊ GC (mean ± σ):  6.46% ±  1.41%

     █▅▃ ▂▁▂▂▂                                                 
  ▁▄▇███▇█████▇█▅▄▃▃▂▃▂▃▂▃▃▃▃▃▂▂▂▂▁▁▁▂▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▂
  2 μs            Histogram: frequency by time        3.51 μs <

 Memory estimate: 5.64 KiB, allocs estimate: 48.
