Half-integer order for besselk #25

Closed
cgeoga opened this issue Jul 26, 2022 · 12 comments · Fixed by #47

@cgeoga
Contributor

cgeoga commented Jul 26, 2022

What do you think about something like this to handle half-integer orders of besselk? It uses the fact that the asymptotic expansion terminates and is exact for half-integer orders.

function _besselk_vhalfint(v, x) 
  v<zero(v)  && return _besselk_vhalfint(-v, x)  
  (v == 1/2) && return sqrt(pi/(2*x))*exp(-x)
  (v == 3/2) && return sqrt(pi/(2*x))*exp(-x)*(one(x) + inv(x))
  b0 = sqrt(pi/(2*x))*exp(-x)                    # v=1/2
  b1 = sqrt(pi/(2*x))*exp(-x)*(one(x) + inv(x))  # v=3/2
  _v = 3/2
  twodx = 2/x
  b2 = twodx*_v*b1 + b0
  for _ in 1:Int(floor(v-2)) # standard up-recurrence
    _v += 1
    b0  = b1
    b1  = b2
    b2  = muladd(twodx*_v, b1, b0)
  end
  b2
end

You can't pass AD through this w.r.t. v, of course. In my experience that was the hardest derivative to get, and it's why I ended up coding up the very expensive Temme routine, which is by far our slowest routine. But if you want something in the meantime, this should be very competitively fast except for outrageously large v.
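For reference, these are the standard identities the approach relies on (modified Bessel function of the second kind, n a nonnegative integer): the order symmetry, the two base cases, the upward three-term recurrence, and the terminating expansion mentioned at the top of this comment.

```math
\begin{aligned}
K_{-\nu}(x) &= K_{\nu}(x), \\
K_{1/2}(x) &= \sqrt{\frac{\pi}{2x}}\, e^{-x}, \qquad
K_{3/2}(x) = \sqrt{\frac{\pi}{2x}}\, e^{-x}\left(1 + \frac{1}{x}\right), \\
K_{\nu+1}(x) &= K_{\nu-1}(x) + \frac{2\nu}{x}\, K_{\nu}(x), \\
K_{n+1/2}(x) &= \sqrt{\frac{\pi}{2x}}\, e^{-x} \sum_{k=0}^{n} \frac{(n+k)!}{k!\,(n-k)!}\,(2x)^{-k}.
\end{aligned}
```

For example, n = 2 in the sum gives K_{5/2}(x) = sqrt(pi/(2x)) e^{-x} (1 + 3/x + 3/x^2), which is exactly one step of the recurrence from K_{1/2} and K_{3/2} with nu = 3/2.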

@heltonmc
Member

I'm not sure if there is another function I'm missing, but that allocates because the name is just misspelled (missing an _ in front), so it can't infer the correct return type. Using @code_warntype is always helpful for those kinds of weird things.

One thing I would be concerned about is just the number of branches. For the v<zero(v) case, I would probably just use v = abs(v), since K_{-v}(x) = K_{v}(x)?

Though there should be a way to eliminate those branches for constants like that... see JuliaMath/SpecialFunctions.jl#178 (comment) and the related notes.

@cgeoga
Contributor Author

cgeoga commented Jul 26, 2022

Okay, well that's embarrassing about the _. I actually did look at the output of @code_warntype, but I did not understand it well enough to see that that was the problem. Oof. I've edited the top post so that code works correctly on a copy+paste. Also, makes sense about abs(v), although -v would give the same output in this case.

With regard to the branches, I'm not sure how they could really be removed, and if there's something in that issue explaining how, I don't think I understand it. Even with manual methods, how would you avoid eventually checking whether v is one of the two base cases?

In any case, I know you're working on bessely. I'm just bringing this up because stats people love half-integer orders, so a special branch like this that takes the extra-fast path would probably be appealing.

@oscardssmith
Member

I believe it should be something like the following (I probably have the details slightly wrong). The key is that you use the for loop running 0 times instead of the if statement.

function _besselk_vhalfint(v, x) 
  v = abs(v)
  invx = inv(x)
  b0 = b1 = sqrt(invx*(pi/2))*exp(-x) 
  twodx = 2*invx
  for _v in (1/2) : 1 : (v-1)
    b0, b1 = b1, muladd(b1, twodx*_v, b0)
  end
  b1
end
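Oscar's seed works because K_{-1/2}(x) = K_{1/2}(x). Setting b0 = b1 = K_{1/2}(x) and starting the loop at _v = 1/2 makes the very first recurrence step return K_{3/2}(x), so neither base case needs its own branch. Below is a small sketch that checks this version against the terminating expansion directly. `besselk_halfint_series` is a hypothetical helper written only for this comparison, and the test orders and arguments are arbitrary choices.

```julia
# Copy of the recurrence above. b0 is seeded with K_{-1/2}(x), which equals
# K_{1/2}(x), so the first step (with _v = 1/2) already yields K_{3/2}(x).
function _besselk_vhalfint(v, x)
    v = abs(v)
    invx = inv(x)
    b0 = b1 = sqrt(invx * (pi / 2)) * exp(-x)
    twodx = 2 * invx
    for _v in (1/2):1:(v - 1)
        b0, b1 = b1, muladd(b1, twodx * _v, b0)
    end
    return b1
end

# Hypothetical reference: the terminating expansion, valid for v = n + 1/2 only.
function besselk_halfint_series(v, x)
    n = Int(abs(v) - 1/2)   # throws InexactError if v is not a half-integer
    s = 0.0
    for k in 0:n
        # (n+k)! / (k! (n-k)!) is an integer; BigInt avoids overflow for large n
        c = div(factorial(big(n + k)), factorial(big(k)) * factorial(big(n - k)))
        s += Float64(c) / (2x)^k
    end
    return sqrt(pi / (2x)) * exp(-x) * s
end

# Compare the two on a few orders and arguments.
for n in 0:20, x in (0.5, 5.5, 30.0)
    v = n + 1/2
    @assert isapprox(_besselk_vhalfint(v, x), besselk_halfint_series(v, x); rtol = 1e-12)
end
```

Both sides only add positive terms, so neither suffers from cancellation. That is why a tight relative tolerance is reasonable here.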

@heltonmc
Member

Haha ya, I was about to comment the same thing... You can completely avoid these branches.

function besslk_halfint(nu, x)
    nu = abs(nu)
    k0 = sqrt(pi/(2*x))*exp(-x) 
    k1 = k0*(one(x) + inv(x))
    
    k2 = k1
    x2 = 2 / x
    arr = range(start=1.5, stop=nu, step=1)
    for n in arr
        a = x2 * n
        k2 = muladd(a, k1, k0)
        k0 = k1
        k1 = k2
    end
    return k0
end

@oscardssmith
Member

oh, and your version is even right!

@cgeoga
Contributor Author

cgeoga commented Jul 26, 2022

Oh wow, look at that! That's very clever to make k0 the return object and loop over that range. It's impressive that the compiler can optimize that so well that it beats the branches. TIL!
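Returning k0 works because the loop in heltonmc's version always runs one order ahead. Below is a lightly restructured copy with the invariant spelled out as comments. It uses colon syntax instead of range(start=..., ...) and a tuple swap instead of the temporary k2, but the arithmetic per step is the same. The test values come from the closed forms K_{3/2} = base*(1 + 1/x) and K_{7/2} = base*(1 + 6/x + 15/x^2 + 15/x^3).

```julia
# Restructured copy of besslk_halfint above; same arithmetic per step.
function besslk_halfint(nu, x)
    nu = abs(nu)
    k0 = sqrt(pi / (2 * x)) * exp(-x)   # K_{1/2}(x)
    k1 = k0 * (one(x) + inv(x))         # K_{3/2}(x)
    x2 = 2 / x
    for n in 1.5:1:nu
        # before this step: k0 = K_{n-1}(x), k1 = K_{n}(x)
        # after this step:  k0 = K_{n}(x),   k1 = K_{n+1}(x)
        k0, k1 = k1, muladd(x2 * n, k1, k0)
    end
    # The last step has n = nu, so k0 = K_{nu}(x); the final k1 = K_{nu+1}(x)
    # is computed but unused. For nu = 1/2 the loop is empty and k0 = K_{1/2}(x).
    return k0
end
```

The price of this trick is one extra recurrence step whose result is thrown away. In exchange, the loop body has no special case.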

@heltonmc
Member

And you will probably want to combine Oscar's version with mine to avoid the excess divisions. You'll also want to make sqrt(pi/2) a constant; then you can do this with a single division. I've found the branches hard to measure in microbenchmarks until you piece together the full function. Oftentimes, when you run @benchmark on an individual function, the compiler completely eliminates branches that it may not be able to eliminate within the main function.

Regarding the constant propagation mentioned in the other thread: I haven't looked into that, but I think @oscardssmith would be better able to answer...

@oscardssmith
Member

branches are expensive. The for loop is 1 branch per iteration, so this version is strictly better. What about constant-prop?

@heltonmc
Member

JuliaMath/SpecialFunctions.jl#178 (comment): this comment is what we are referring to.

@cgeoga
Contributor Author

cgeoga commented Jul 26, 2022

Amazing, thank you both. I'll tinker with this and try to come back with a refined version.

@cgeoga
Contributor Author

cgeoga commented Jul 27, 2022

Okay, so here's a weird one: when I use those fancier range iterators, it just destroys performance for me on this setup:

julia> versioninfo()
Julia Version 1.7.1
Commit ac5cc99908 (2021-12-22 19:35 UTC)
Platform Info:
  OS: Linux (x86_64-pc-linux-gnu)
  CPU: 11th Gen Intel(R) Core(TM) i5-11600K @ 3.90GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-12.0.1 (ORCJIT, icelake-client)
Environment:
  JULIA_HOME = /home/cg/Scratch/julia-1.7.1/bin/

But I think I have a modification that works very well for me and still addresses your helpful points about performance:

const SQRT_PID2 = sqrt(pi/2)

function besk_halfint(v, x)
  v       = abs(v)
  invx    = inv(x)
  b0 = b1 = SQRT_PID2*sqrt(invx)*exp(-x) 
  twodx   = 2*invx
  for _v in (1/2) : 1 : (v-1)
    b0, b1 = b1, muladd(b1, twodx*_v, b0)
  end
  b1
end

function besk_halfint2(v, x)
  v       = abs(v)
  invx    = inv(x)
  b0 = b1 = SQRT_PID2*sqrt(invx)*exp(-x) 
  twodx   = 2*invx
  _v      = convert(eltype(v), 1/2)
  while _v < v
    b0, b1 = b1, muladd(b1, twodx*_v, b0)
    _v    += one(eltype(_v))
  end
  b1
end

With benchmark timings:

v = 21.5
x = 5.5
@benchmark besk_halfint($v, $x)
BenchmarkTools.Trial: 10000 samples with 842 evaluations.
 Range (min … max):  144.914 ns … 220.671 ns  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     146.314 ns               ┊ GC (median):    0.00%
 Time  (mean ± σ):   148.091 ns ±   4.831 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

   ▄▇███▇▆▅▄▄▃▂                       ▁ ▁▂▃▄▄▃▂▁▁               ▂
  ▇█████████████▇▆▆▆▇▇▆▅▇▇▇▆█▇▇▆▇▇▇█▇██████████████▇█▇▇▇▆▅▆▆▇▆▆ █
  145 ns        Histogram: log(frequency) by time        160 ns <

 Memory estimate: 0 bytes, allocs estimate: 0.

Writing the loop with range(...) instead of the colon syntax performs the same. With the simple while loop in besk_halfint2, I get:

BenchmarkTools.Trial: 10000 samples with 998 evaluations.
 Range (min … max):  17.895 ns … 23.917 ns  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     18.024 ns              ┊ GC (median):    0.00%
 Time  (mean ± σ):   18.158 ns ±  0.415 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

    ▂▆▅▇█▂                                                     
  ▂▅██████▄▂▂▂▂▂▂▁▂▂▂▂▁▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▃▃▃▃▃▃▂ ▃
  17.9 ns         Histogram: frequency by time        19.3 ns <

 Memory estimate: 0 bytes, allocs estimate: 0.

Hard to imagine doing much better than that. But I've said that before and been very wrong, as evidenced by a lot of things in this package... so, thoughts?

I should say that I have an Intel i5-11600K CPU, which does have AVX-512. I've always worried that the compiler is trying to use it but doing a bad job, and that this is somehow killing performance for me. So maybe this is a me problem that I should sort out independently of this issue.
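About the range-vs-while gap above: one plausible contributor (not confirmed anywhere in this thread) is how Base handles Float64 ranges. A range such as (1/2):1:(v-1) is built as a StepRangeLen backed by Base.TwicePrecision, so constructing it and producing each element takes extended-precision arithmetic rather than a single addition. Below is a sketch of a third variant. It keeps the while loop's cheap `_v += 1` update but counts iterations with an integer range. `besk_halfint3` is a hypothetical name, and like the versions above it assumes v is a half-integer.

```julia
const SQRT_PID2 = sqrt(pi / 2)

# Hypothetical variant: an integer loop counter replaces the Float64 range,
# while _v is still advanced with a plain addition as in besk_halfint2.
# Assumes v is a half-integer, v = n + 1/2 with integer n >= 0.
function besk_halfint3(v, x)
    v = abs(v)
    invx = inv(x)
    b0 = b1 = SQRT_PID2 * sqrt(invx) * exp(-x)   # K_{-1/2}(x) = K_{1/2}(x)
    twodx = 2 * invx
    _v = one(v) / 2
    n = round(Int, v - _v)                      # number of upward steps needed
    for _ in 1:n
        b0, b1 = b1, muladd(b1, twodx * _v, b0)
        _v += one(v)
    end
    return b1                                   # K_{n+1/2}(x) = K_v(x)
end
```

Whether this actually beats besk_halfint2 would have to be checked with @benchmark on the machines discussed here. No timings are claimed for it.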

@heltonmc
Member

I get a similar benchmark on my CPU, an Intel(R) Core(TM) i7-8700K CPU @ 3.70GHz, which is interesting...

Also, you'll probably want to hardcode those constants and compute them in extended precision so that you get correctly rounded results:

SQRT_PID2(::Type{Float64}) = 1.2533141373155003  # no `const` needed: a method definition already binds a constant name
function besk_halfint2(v::T, x) where T
    v       = abs(v)
    invx    = inv(x)
    b0 = b1 = SQRT_PID2(T)*sqrt(invx)*exp(-x) 
    twodx   = 2*invx
    _v      = T(1/2)
    while _v < v
      b0, b1 = b1, muladd(b1, twodx*_v, b0)
      _v    += one(T)
    end
    b1
end
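Building on the suggestion to hardcode constants computed in extended precision, here is a sketch that generates one hardcoded method per float type at definition time. Each literal comes from evaluating sqrt(pi/2) in BigFloat and rounding once to the target type. The @eval loop and the name `besk_halfint_generic` are assumptions for illustration, not code from this thread.

```julia
# Hypothetical: one method per type, each returning a literal computed in BigFloat
# (256-bit by default) and rounded once to T when the methods are defined.
for T in (Float32, Float64)
    @eval SQRT_PID2(::Type{$T}) = $(T(sqrt(big(pi) / 2)))
end

# Same loop as besk_halfint2 above, restricted to matching float types.
# Assumes v is a half-integer.
function besk_halfint_generic(v::T, x::T) where {T<:AbstractFloat}
    v = abs(v)
    invx = inv(x)
    b0 = b1 = SQRT_PID2(T) * sqrt(invx) * exp(-x)   # K_{1/2}(x), seeded twice
    twodx = 2 * invx
    _v = one(T) / 2
    while _v < v
        b0, b1 = b1, muladd(b1, twodx * _v, b0)
        _v += one(T)
    end
    return b1
end
```

The literal is spliced in when the methods are defined, so a call like SQRT_PID2(Float32) should not do any BigFloat work at run time.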

Labels
None yet
Projects
None yet

3 participants