Skip to content

Commit

Permalink
Improve type consistency of special functions
Browse files Browse the repository at this point in the history
  • Loading branch information
oxinabox authored and ararslan committed Feb 10, 2017
1 parent 67301bc commit 9b7abbe
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 59 deletions.
163 changes: 104 additions & 59 deletions src/gamma.jl
Expand Up @@ -2,6 +2,8 @@

using Base.MPFR: ROUNDING_MODE, big_ln2

typealias ComplexOrReal{T} Union{T,Complex{T}}

# Bernoulli numbers B_{2k}, using tabulated numerators and denominators from
# the online encyclopedia of integer sequences. (They actually have data
# up to k=249, but we stop here at k=20.) Used for generating the polygamma
Expand All @@ -15,7 +17,7 @@ using Base.MPFR: ROUNDING_MODE, big_ln2
Compute the digamma function of `x` (the logarithmic derivative of `gamma(x)`).
"""
function digamma(z::Union{Float64,Complex{Float64}})
function digamma(z::ComplexOrReal{Float64})
# Based on eq. (12), without looking at the accompanying source
# code, of: K. S. Kölbig, "Programs for computing the logarithm of
# the gamma function, and the digamma function, for complex
Expand Down Expand Up @@ -55,7 +57,7 @@ end
Compute the trigamma function of `x` (the logarithmic second derivative of `gamma(x)`).
"""
function trigamma(z::Union{Float64,Complex{Float64}})
function trigamma(z::ComplexOrReal{Float64})
# via the derivative of the Kölbig digamma formulation
x = real(z)
if x <= 0 # reflection formula
Expand Down Expand Up @@ -215,8 +217,7 @@ this definition is equivalent to the Hurwitz zeta function
``\\sum_{k=0}^\\infty (k+z)^{-s}``. For ``z=1``, it yields
the Riemann zeta function ``\\zeta(s)``.
"""
function zeta(s::Union{Int,Float64,Complex{Float64}},
z::Union{Float64,Complex{Float64}})
function zeta(s::ComplexOrReal{Float64}, z::ComplexOrReal{Float64})
ζ = zero(promote_type(typeof(s), typeof(z)))

(z == 1 || z == 0) && return oftype(ζ, zeta(s))
Expand Down Expand Up @@ -265,7 +266,8 @@ function zeta(s::Union{Int,Float64,Complex{Float64}},
minus_z = -z
ζ += pow_oftype(ζ, minus_z, minus_s) # ν = 0 term
if xf != z
ζ += pow_oftype(ζ, z - nx, minus_s) # real(z - nx) > 0, so use correct branch cut
ζ += pow_oftype(ζ, z - nx, minus_s)
# real(z - nx) > 0, so use correct branch cut
# otherwise, if xf==z, then the definition skips this term
end
# do loop in different order, depending on the sign of s,
Expand Down Expand Up @@ -318,10 +320,10 @@ end
"""
polygamma(m, x)
Compute the polygamma function of order `m` of argument `x` (the `(m+1)th` derivative of the
Compute the polygamma function of order `m` of argument `x` (the `(m+1)`th derivative of the
logarithm of `gamma(x)`)
"""
function polygamma(m::Integer, z::Union{Float64,Complex{Float64}})
function polygamma(m::Integer, z::ComplexOrReal{Float64})
m == 0 && return digamma(z)
m == 1 && return trigamma(z)

Expand All @@ -339,7 +341,9 @@ function polygamma(m::Integer, z::Union{Float64,Complex{Float64}})
# constants. We throw a DomainError() since the definition is unclear.
real(m) < 0 && throw(DomainError())

s = m+1
s = Float64(m+1)
# It is safe to convert any integer (including `BigInt`) to a float here
# as underflow occurs before precision issues.
if real(z) <= 0 # reflection formula
(zeta(s, 1-z) + signflip(m, cotderiv(m,z))) * (-gamma(s))
else
Expand All @@ -355,32 +359,15 @@ f32(z::Complex) = Complex64(z)
f16(x::Real) = Float16(x)
f16(z::Complex) = Complex32(z)

# If we really cared about single precision, we could make a faster
# Float32 version by truncating the Stirling series at a smaller cutoff.
for (f,T) in ((:f32,Float32),(:f16,Float16))
@eval begin
zeta(s::Integer, z::Union{$T,Complex{$T}}) = $f(zeta(Int(s), f64(z)))
zeta(s::Union{Float64,Complex128}, z::Union{$T,Complex{$T}}) = zeta(s, f64(z))
zeta(s::Number, z::Union{$T,Complex{$T}}) = $f(zeta(f64(s), f64(z)))
polygamma(m::Integer, z::Union{$T,Complex{$T}}) = $f(polygamma(Int(m), f64(z)))
digamma(z::Union{$T,Complex{$T}}) = $f(digamma(f64(z)))
trigamma(z::Union{$T,Complex{$T}}) = $f(trigamma(f64(z)))
end
end

zeta(s::Integer, z::Number) = zeta(Int(s), f64(z))
zeta(s::Number, z::Number) = zeta(f64(s), f64(z))
for f in (:digamma, :trigamma)
@eval begin
$f(z::Number) = $f(f64(z))
end
end
polygamma(m::Integer, z::Number) = polygamma(m, f64(z))
"""
invdigamma(x)
# Inverse digamma function:
# Implementation of fixed point algorithm described in
# "Estimating a Dirichlet distribution" by Thomas P. Minka, 2000
Compute the inverse [`digamma`](@ref) function of `x`.
"""
function invdigamma(y::Float64)
# Implementation of fixed point algorithm described in
# "Estimating a Dirichlet distribution" by Thomas P. Minka, 2000

# Closed form initial estimates
if y >= -2.22
x_old = exp(y) + 0.5
Expand All @@ -402,18 +389,16 @@ function invdigamma(y::Float64)

return x_new
end
invdigamma(x::Float32) = Float32(invdigamma(Float64(x)))

"""
invdigamma(x)
zeta(s)
Compute the inverse [`digamma`](@ref) function of `x`.
Riemann zeta function ``\\zeta(s)``.
"""
invdigamma(x::Real) = invdigamma(Float64(x))
function zeta(s::ComplexOrReal{Float64})
# Riemann zeta function; algorithm is based on specializing the Hurwitz
# zeta function above for z==1.

# Riemann zeta function; algorithm is based on specializing the Hurwitz
# zeta function above for z==1.
function zeta(s::Union{Float64,Complex{Float64}})
# blows up to ±Inf, but get correct sign of imaginary zero
s == 1 && return NaN + zero(s) * imag(s)

Expand Down Expand Up @@ -458,23 +443,18 @@ function zeta(s::Union{Float64,Complex{Float64}})
return ζ
end

zeta(x::Integer) = zeta(Float64(x))
zeta(x::Real) = oftype(float(x),zeta(Float64(x)))

"""
zeta(s)
Riemann zeta function ``\\zeta(s)``.
"""
zeta(z::Complex) = oftype(float(z),zeta(Complex128(z)))

function zeta(x::BigFloat)
z = BigFloat()
ccall((:mpfr_zeta, :libmpfr), Int32, (Ptr{BigFloat}, Ptr{BigFloat}, Int32), &z, &x, ROUNDING_MODE[])
return z
end

function eta(z::Union{Float64,Complex{Float64}})
"""
eta(x)
Dirichlet eta function ``\\eta(s) = \\sum^\\infty_{n=1}(-1)^{n-1}/n^{s}``.
"""
function eta(z::ComplexOrReal{Float64})
δz = 1 - z
if abs(real(δz)) + abs(imag(δz)) < 7e-3 # Taylor expand around z==1
return 0.6931471805599453094172321214581765 *
Expand All @@ -488,17 +468,82 @@ function eta(z::Union{Float64,Complex{Float64}})
return -zeta(z) * expm1(0.6931471805599453094172321214581765*δz)
end
end
eta(x::Integer) = eta(Float64(x))
eta(x::Real) = oftype(float(x),eta(Float64(x)))

"""
eta(x)
Dirichlet eta function ``\\eta(s) = \\sum^\\infty_{n=1}(-1)^{n-1}/n^{s}``.
"""
eta(z::Complex) = oftype(float(z),eta(Complex128(z)))

function eta(x::BigFloat)
x == 1 && return big_ln2()
return -zeta(x) * expm1(big_ln2()*(1-x))
end

# Converting types that we can convert, and not ones we can not
# Float16, and Float32 and their Complex equivalents can be converted to Float64
# and results converted back.
# Otherwise, we need to make things use their own `float` converting methods
# and in those cases, we do not convert back either as we assume
# they also implement their own versions of the functions, with the correct return types.
# This is the case for BitIntegers (which become `Float64` when `float`ed).
# Otherwise, if they do not implement their version of the functions we
# manually throw a `MethodError`.
# This case occurs, when calling `float` on a type does not change its type,
# as it is already a `float`, and would have hit own method, if one had existed.


# If we really cared about single precision, we could make a faster
# Float32 version by truncating the Stirling series at a smaller cutoff.
# and if we really cared about half precision, we could make a faster
# Float16 version, by using a precomputed table look-up.


for T in (Float16, Float32, Float64)
@eval f64(x::Complex{$T}) = Complex128(x)
@eval f64(x::$T) = Float64(x)
end


for f in (:digamma, :trigamma, :zeta, :eta, :invdigamma)
@eval begin
function $f(z::Union{ComplexOrReal{Float16}, ComplexOrReal{Float32}})
oftype(z, $f(f64(z)))
end

function $f(z::Number)
x = float(z)
typeof(x) === typeof(z) && throw(MethodError($f, (z,)))
# There is nothing to fallback to, as this didn't change the argument types
$f(x)
end
end
end


for T1 in (Float16, Float32, Float64), T2 in (Float16, Float32, Float64)
(T1 == T2 == Float64) && continue # Avoid redefining base definition

@eval function zeta(s::ComplexOrReal{$T1}, z::ComplexOrReal{$T2})
ζ = zeta(f64(s), f64(z))
convert(promote_type(typeof(s), typeof(z)), ζ)
end
end


function zeta(s::Number, z::Number)
t = float(s)
x = float(z)
if typeof(t) === typeof(s) && typeof(x) === typeof(z)
# There is nothing to fallback to, since this didn't work
throw(MethodError(zeta,(s,z)))
end
zeta(t, x)
end


function polygamma(m::Integer, z::Union{ComplexOrReal{Float16}, ComplexOrReal{Float32}})
oftype(z, polygamma(m, f64(z)))
end


function polygamma(m::Integer, z::Number)
x = float(z)
typeof(x) === typeof(z) && throw(MethodError(polygamma, (m,z)))
# There is nothing to fallback to, since this didn't work
polygamma(m, x)
end
32 changes: 32 additions & 0 deletions test/runtests.jl
Expand Up @@ -454,3 +454,35 @@ end
@test typeof(SF.erfc(a)) == BigFloat
end
end

@testset "Base Julia issue #17474" begin
@test SF.f64(complex(1f0,1f0)) == complex(1.0, 1.0)
@test SF.f64(1f0) == 1.0

@test typeof(SF.eta(big"2")) == BigFloat
@test typeof(SF.zeta(big"2")) == BigFloat
@test typeof(SF.digamma(big"2")) == BigFloat

@test_throws MethodError SF.trigamma(big"2")
@test_throws MethodError SF.trigamma(big"2.0")
@test_throws MethodError SF.invdigamma(big"2")
@test_throws MethodError SF.invdigamma(big"2.0")

@test_throws MethodError SF.eta(Complex(big"2"))
@test_throws MethodError SF.eta(Complex(big"2.0"))
@test_throws MethodError SF.zeta(Complex(big"2"))
@test_throws MethodError SF.zeta(Complex(big"2.0"))
@test_throws MethodError SF.zeta(1.0,big"2")
@test_throws MethodError SF.zeta(1.0,big"2.0")
@test_throws MethodError SF.zeta(big"1.0",2.0)
@test_throws MethodError SF.zeta(big"1",2.0)


@test typeof(SF.polygamma(3, 0x2)) == Float64
@test typeof(SF.polygamma(big"3", 2f0)) == Float32
@test typeof(SF.zeta(1, 2.0)) == Float64
@test typeof(SF.zeta(1, 2f0)) == Float64 # BitIntegers result in Float64 returns
@test typeof(SF.zeta(2f0, complex(2f0,0f0))) == Complex{Float32}
@test typeof(SF.zeta(complex(1,1), 2f0)) == Complex{Float64}
@test typeof(SF.zeta(complex(1), 2.0)) == Complex{Float64}
end

0 comments on commit 9b7abbe

Please sign in to comment.