Skip to content

Commit

Permalink
Improve accuracy of logistic (#94)
Browse files Browse the repository at this point in the history
Co-authored-by: John Myles White <johnmyleswhite@fb.com>
  • Loading branch information
johnmyleswhite and John Myles White committed May 10, 2020
1 parent 8d9b975 commit e9c2bb2
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 0 deletions.
28 changes: 28 additions & 0 deletions src/basicfuns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,20 @@ Return `x * log(y)` for `y > 0` with correct limit at `x = 0`.
xlogy(x::T, y::T) where {T<:Real} = x > zero(T) ? x * log(y) : zero(log(x))
xlogy(x::Real, y::Real) = xlogy(promote(x, y)...)

# The following bounds are precomputed versions of the following abstract
# function, but the implicit interface for AbstractFloat doesn't uniformly
# enforce that all floating point types implement nextfloat and prevfloat.
# @inline function _logistic_bounds(x::AbstractFloat)
# (
# logit(nextfloat(zero(float(x)))),
# logit(prevfloat(one(float(x)))),
# )
# end

@inline _logistic_bounds(x::Float16) = (Float16(-16.64), Float16(7.625))
@inline _logistic_bounds(x::Float32) = (-103.27893f0, 16.635532f0)
@inline _logistic_bounds(x::Float64) = (-744.4400719213812, 36.7368005696771)

"""
logistic(x::Real)
Expand All @@ -33,6 +47,20 @@ Its inverse is the [`logit`](@ref) function.
"""
logistic(x::Real) = inv(exp(-x) + one(x))

function logistic(x::Union{Float16, Float32, Float64})
e = exp(x)
lower, upper = _logistic_bounds(x)
ifelse(
x < lower,
zero(x),
ifelse(
x > upper,
one(x),
e / (one(x) + e)
)
)
end

"""
logit(p::Real)
Expand Down
4 changes: 4 additions & 0 deletions test/basicfuns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ end

@testset "logistic & logit" begin
@test logistic(2) 1.0 / (1.0 + exp(-2.0))
@test logistic(-750.0) === 0.0
@test logistic(-740.0) > 0.0
@test logistic(+36.0) < 1.0
@test logistic(+750.0) === 1.0
@test iszero(logit(0.5))
@test logit(logistic(2)) 2.0
end
Expand Down

0 comments on commit e9c2bb2

Please sign in to comment.