Skip to content

Commit

Permalink
Merge 3638892 into e9c2bb2
Browse files Browse the repository at this point in the history
  • Loading branch information
cossio committed May 19, 2020
2 parents e9c2bb2 + 3638892 commit 2b07bce
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 15 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
*.jl.cov
*.jl.mem

Manifest.toml
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "StatsFuns"
uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
version = "0.9.4"
version = "0.9.5"

[deps]
Rmath = "79098fc4-a85e-5d69-aa6a-4863f24498fa"
Expand Down
1 change: 1 addition & 0 deletions src/StatsFuns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ export
log1pmx, # log(1 + x) - x
logmxp1, # log(x) - x + 1
logaddexp, # log(exp(x) + exp(y))
logsubexp, # log(abs(e^x - e^y))
logsumexp, # log(sum(exp(x)))
softmax, # exp(x_i) / sum(exp(x)), for i
softmax!, # inplace softmax
Expand Down
31 changes: 20 additions & 11 deletions src/basicfuns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,20 @@ julia> StatsFuns.xlogx(0)
0.0
```
"""
xlogx(x::Real) = x > zero(x) ? x * log(x) : zero(log(x))
function xlogx(x)
result = x * log(x)
ifelse(iszero(x), zero(result), result)
end

"""
xlogy(x::Real, y::Real)
Return `x * log(y)` for `y > 0` with correct limit at `x = 0`.
"""
xlogy(x::T, y::T) where {T<:Real} = x > zero(T) ? x * log(y) : zero(log(x))
xlogy(x::Real, y::Real) = xlogy(promote(x, y)...)
function xlogy(x, y)
result = x * log(y)
ifelse(iszero(x), zero(result), result)
end

# The following bounds are precomputed versions of the following abstract
# function, but the implicit interface for AbstractFloat doesn't uniformly
Expand Down Expand Up @@ -196,16 +201,20 @@ end
Return `log(exp(x) + exp(y))`, avoiding intermediate overflow/undeflow, and handling non-finite values.
"""
function logaddexp(x::T, y::T) where T<:Real
# x or y is NaN => NaN
# x or y is +Inf => +Inf
# x or y is -Inf => other value
isfinite(x) && isfinite(y) || return max(x,y)
x > y ? x + log1p(exp(y - x)) : y + log1p(exp(x - y))
function logaddexp(x::Real, y::Real)
# ensure Δ = 0 if x = y = Inf
Δ = ifelse(x == y, zero(x - y), abs(x - y))
max(x, y) + log1pexp(-Δ)
end
logaddexp(x::Real, y::Real) = logaddexp(promote(x, y)...)

Base.@deprecate logsumexp(x::Real, y::Real) logaddexp(x,y)
Base.@deprecate logsumexp(x::Real, y::Real) logaddexp(x, y)

"""
logsubexp(x, y)
Return `log(abs(e^x - e^y))`, preserving numerical accuracy.
"""
logsubexp(x::Real, y::Real) = max(x, y) + log1mexp(-abs(x - y))

"""
logsumexp(X)
Expand Down
26 changes: 23 additions & 3 deletions test/basicfuns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@ using StatsFuns, Test
@testset "xlogx & xlogy" begin
@test iszero(xlogx(0))
@test xlogx(2) 2.0 * log(2.0)
@test_throws DomainError xlogx(-1)

@test iszero(xlogy(0, 1))
@test xlogy(2, 3) 2.0 * log(3.0)
@test_throws DomainError xlogy(1, -1)
end

@testset "logistic & logit" begin
Expand Down Expand Up @@ -88,15 +90,33 @@ end
([-Inf, Inf], Inf),
([-Inf, 9.0], 9.0),
([Inf, 9.0], Inf),
([NaN, 9.0], NaN), # NaN propagation
([NaN, Inf], NaN), # NaN propagation
([NaN, -Inf], NaN), # NaN propagation
([0, 0], log(2.0))] # non-float arguments
for (arguments, result) in cases
@test logaddexp(arguments...) result
@test logsumexp(arguments) result
end
end

@test isnan(logsubexp(Inf, Inf))
@test isnan(logsubexp(-Inf, -Inf))
@test logsubexp(Inf, 9.0) Inf
@test logsubexp(-Inf, 9.0) 9.0
@test logsubexp(1f2, 1f2) -Inf32
@test logsubexp(0, 0) -Inf
@test logsubexp(3, 2) 2.541324854612918108978

# NaN propagation
@test isnan(logaddexp(NaN, 9.0))
@test isnan(logaddexp(NaN, Inf))
@test isnan(logaddexp(NaN, -Inf))

@test isnan(logsubexp(NaN, 9.0))
@test isnan(logsubexp(NaN, Inf))
@test isnan(logsubexp(NaN, -Inf))

@test isnan(logsumexp([NaN, 9.0]))
@test isnan(logsumexp([NaN, Inf]))
@test isnan(logsumexp([NaN, -Inf]))
end

@testset "softmax" begin
Expand Down

0 comments on commit 2b07bce

Please sign in to comment.