Skip to content

make float32 bessels type stable #120

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/BesselFunctions/U_polynomials.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ function besseljy_debye(v, x::T) where T
p2 = v^2 / vmx

Uk_Jn, Uk_Yn = Uk_poly_Jn(p, v, p2, T(x))
return coef_Jn * Uk_Jn, coef_Yn * Uk_Yn
return T(coef_Jn * Uk_Jn), T(coef_Yn * Uk_Yn)
end

# Cutoffs for besseljy_debye expansions
Expand Down Expand Up @@ -102,7 +102,7 @@ function hankel_debye(v, x::T) where T
p2 = v^2 / vmx

_, Uk_Yn = Uk_poly_Hankel(p*im, v, -p2, T(x))
return coef_Yn * Uk_Yn
return T(coef_Yn * Uk_Yn)
end

# Cutoffs for hankel_debye expansions
Expand Down
8 changes: 4 additions & 4 deletions src/BesselFunctions/asymptotics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ function besseljy_large_argument(v, x::T) where T
s3 = CMS * sα
s4 = CPS * cα

return SQ2O2(S) * (s1 + s2) * b, SQ2O2(S) * (s3 - s4) * b
return T(SQ2O2(S) * (s1 + s2) * b), T(SQ2O2(S) * (s3 - s4) * b)
end

# Float64
Expand Down Expand Up @@ -101,11 +101,11 @@ end
function _α_αp_asymptotic(v, x::Float32)
v, x = Float64(v), Float64(x)
if x > 4*v
return _α_αp_poly_5(v, x)
return Float32.(_α_αp_poly_5(v, x))
elseif x > 1.8*v
return _α_αp_poly_10(v, x)
return Float32.(_α_αp_poly_10(v, x))
else
return _α_αp_poly_30(v, x)
return Float32.(_α_αp_poly_30(v, x))
end
end

Expand Down
6 changes: 3 additions & 3 deletions src/BesselFunctions/bessely.jl
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ function bessely_power_series(v, x::T) where T
out2 = zero(S)
a = (x/2)^v
# check for underflow and return limit for small arguments
iszero(a) && return (-T(Inf), a)
iszero(a) && return (-T(Inf), T(a))

b = inv(a)
a /= gamma(v + one(S))
Expand All @@ -434,7 +434,7 @@ function bessely_power_series(v, x::T) where T
b *= -inv((-v + i + one(S)) * (i + one(S))) * t2
end
s, c = sincospi(v)
return (out*c - out2) / s, out
return T((out*c - out2) / s), T(out)
end
bessely_series_cutoff(v, x::Float64) = (x < 7.0) || v > 1.35*x - 4.5
bessely_series_cutoff(v, x::Float32) = (x < 21.0f0) || v > 1.38f0*x - 12.5f0
Expand Down Expand Up @@ -484,7 +484,7 @@ function bessely_chebyshev_low_orders(v, x)
x1 = 2*(x - 6)/13 - 1
v1 = v - 1
v2 = v
a = clenshaw_chebyshev.(x1, bessely_cheb_weights)
a = clenshaw_chebyshev.(x1, map(Base.Fix1(map, typeof(x1)),bessely_cheb_weights))
return clenshaw_chebyshev(v1, a), clenshaw_chebyshev(v2, a)
end

Expand Down
2 changes: 1 addition & 1 deletion src/BesselFunctions/hankel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ function besseljy_positive_args(nu::Real, x::T) where T

# at this point x > 19.0 (for Float64) and fairly close to nu
# shift nu down and use the debye expansion for Hankel function (valid x > nu) then use forward recurrence
nu_shift = ceil(nu) - floor(Int, -1.5 + x + Base.Math._approx_cbrt(-411*x)) + 2
nu_shift = ceil(nu) - floor(Int, -3//2 + x + Base.Math._approx_cbrt(-411*x)) + 2
v2 = maximum((nu - nu_shift, modf(nu)[1] + 1))

Hnu = hankel_debye(v2, x)
Expand Down
64 changes: 36 additions & 28 deletions test/hankel_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,35 +3,43 @@
# here we will test just a few cases of the overall hankel function
# focusing on negative arguments and reflection

v, x = 1.5, 1.3
@test isapprox(hankelh1(v, x), SpecialFunctions.hankelh1(v, x), rtol=2e-13)
@test isapprox(hankelh2(v, x), SpecialFunctions.hankelh2(v, x), rtol=2e-13)
@test isapprox(besselh(v, 1, x), SpecialFunctions.besselh(v, 1, x), rtol=2e-13)
@test isapprox(besselh(v, 2, x), SpecialFunctions.besselh(v, 2, x), rtol=2e-13)
@testset "$T" for T in (Float32, Float64)
rtol = T == Float64 ? 2e-13 : 2e-6
v, x = T(1.5), T(1.3)
@test isapprox(hankelh1(v, x), SpecialFunctions.hankelh1(v, x); rtol)
@test isapprox(hankelh2(v, x), SpecialFunctions.hankelh2(v, x); rtol)
@test isapprox(besselh(v, 1, x), SpecialFunctions.besselh(v, 1, x); rtol)
@test isapprox(besselh(v, 2, x), SpecialFunctions.besselh(v, 2, x); rtol)
@inferred besselh(v, 2, x)

v, x = T(-2.6), T(9.2)
@test isapprox(hankelh1(v, x), SpecialFunctions.hankelh1(v, x); rtol)
@test isapprox(hankelh2(v, x), SpecialFunctions.hankelh2(v, x); rtol)
@test isapprox(besselh(v, 1, x), SpecialFunctions.besselh(v, 1, x); rtol)
@test isapprox(besselh(v, 2, x), SpecialFunctions.besselh(v, 2, x); rtol)
@inferred besselh(v, 2, x)

v, x = -2.6, 9.2
@test isapprox(hankelh1(v, x), SpecialFunctions.hankelh1(v, x), rtol=2e-13)
@test isapprox(hankelh2(v, x), SpecialFunctions.hankelh2(v, x), rtol=2e-13)
@test isapprox(besselh(v, 1, x), SpecialFunctions.besselh(v, 1, x), rtol=2e-13)
@test isapprox(besselh(v, 2, x), SpecialFunctions.besselh(v, 2, x), rtol=2e-13)
v, x = T(-4.0), T(11.4)
@test isapprox(hankelh1(v, x), SpecialFunctions.hankelh1(v, x); rtol)
@test isapprox(hankelh2(v, x), SpecialFunctions.hankelh2(v, x); rtol)
@test isapprox(besselh(v, 1, x), SpecialFunctions.besselh(v, 1, x); rtol)
@test isapprox(besselh(v, 2, x), SpecialFunctions.besselh(v, 2, x); rtol)
@inferred besselh(v, 2, x)

v, x = -4.0, 11.4
@test isapprox(hankelh1(v, x), SpecialFunctions.hankelh1(v, x), rtol=2e-13)
@test isapprox(hankelh2(v, x), SpecialFunctions.hankelh2(v, x), rtol=2e-13)
@test isapprox(besselh(v, 1, x), SpecialFunctions.besselh(v, 1, x), rtol=2e-13)
@test isapprox(besselh(v, 2, x), SpecialFunctions.besselh(v, 2, x), rtol=2e-13)
v, x = T(14.3), T(29.4)
@test isapprox(hankelh1(v, x), SpecialFunctions.hankelh1(v, x); rtol)
@test isapprox(hankelh2(v, x), SpecialFunctions.hankelh2(v, x); rtol)
@test isapprox(besselh(v, 1, x), SpecialFunctions.besselh(v, 1, x); rtol)
@test isapprox(besselh(v, 2, x), SpecialFunctions.besselh(v, 2, x); rtol)
@inferred besselh(v, 2, x)

v, x = 14.3, 29.4
@test isapprox(hankelh1(v, x), SpecialFunctions.hankelh1(v, x), rtol=2e-13)
@test isapprox(hankelh2(v, x), SpecialFunctions.hankelh2(v, x), rtol=2e-13)
@test isapprox(besselh(v, 1, x), SpecialFunctions.besselh(v, 1, x), rtol=2e-13)
@test isapprox(besselh(v, 2, x), SpecialFunctions.besselh(v, 2, x), rtol=2e-13)
@test isapprox(hankelh1(1:50, T(10)), SpecialFunctions.hankelh1.(1:50, 10.0); rtol)
@test isapprox(hankelh1(T(0.5):T(25.5), T(15)), SpecialFunctions.hankelh1.(0.5:1:25.5, 15.0); rtol)
@test isapprox(hankelh1(1:50, T(100)), SpecialFunctions.hankelh1.(1:50, 100.0); 2*rtol)
@test isapprox(hankelh2(1:50, T(10)), SpecialFunctions.hankelh2.(1:50, 10.0); rtol)
@inferred hankelh2(1:50, T(10))

@test isapprox(hankelh1(1:50, 10.0), SpecialFunctions.hankelh1.(1:50, 10.0), rtol=2e-13)
@test isapprox(hankelh1(0.5:1:25.5, 15.0), SpecialFunctions.hankelh1.(0.5:1:25.5, 15.0), rtol=2e-13)
@test isapprox(hankelh1(1:50, 100.0), SpecialFunctions.hankelh1.(1:50, 100.0), rtol=2e-13)
@test isapprox(hankelh2(1:50, 10.0), SpecialFunctions.hankelh2.(1:50, 10.0), rtol=2e-13)

#test 2 arg version
@test besselh(v, 1, x) == besselh(v, x)
@test besselh(1:50, 1, 10.0) == besselh(1:50, 10.0)
#test 2 arg version
@test besselh(v, 1, x) == besselh(v, x)
@test besselh(1:50, 1, T(10.0)) == besselh(1:50, T(10.0))
end
Loading