Add distances, divergences, and deviations from StatsBase
ararslan committed Oct 27, 2016
1 parent f6cee36 commit 94bb3e2
Showing 3 changed files with 55 additions and 9 deletions.
13 changes: 12 additions & 1 deletion src/Distances.jl
@@ -32,6 +32,7 @@ export
CorrDist,
ChiSqDist,
KLDivergence,
GenKLDivergence,
JSDivergence,
RenyiDivergence,
SpanNormDist,
@@ -46,6 +47,11 @@ export
BhattacharyyaDist,
HellingerDist,

MeanAbsDeviation,
MeanSqDeviation,
RMSDeviation,
NormRMSDeviation,

# convenient functions
euclidean,
sqeuclidean,
@@ -61,6 +67,7 @@ export
corr_dist,
chisq_dist,
kl_divergence,
gkl_divergence,
js_divergence,
renyi_divergence,
spannorm_dist,
@@ -73,7 +80,11 @@ export
sqmahalanobis,
mahalanobis,
bhattacharyya,
hellinger
hellinger,

msd,
rmsd,
nrmsd

include("common.jl")
include("generic.jl")
36 changes: 33 additions & 3 deletions src/metrics.jl
@@ -24,6 +24,7 @@ type CorrDist <: SemiMetric end

type ChiSqDist <: SemiMetric end
type KLDivergence <: PreMetric end
type GenKLDivergence <: PreMetric end

immutable RenyiDivergence{T <: Real} <: PreMetric
p::T # order of power mean (order of divergence - 1)
@@ -37,10 +38,10 @@ immutable RenyiDivergence{T <: Real} <: PreMetric
is_zero = q ≈ zero(T)
is_one = q ≈ one(T)
is_inf = isinf(q)

# Only positive Rényi divergences are defined
!is_zero && q < zero(T) && throw(ArgumentError("Order of Rényi divergence not legal, $(q) < 0."))

new(q - 1, !(is_zero || is_one || is_inf), is_zero, is_one, is_inf)
end
end
@@ -50,8 +51,15 @@ type JSDivergence <: SemiMetric end

type SpanNormDist <: SemiMetric end

# Deviations are handled separately from the other distances/divergences and
# are excluded from `UnionMetrics`
type MeanAbsDeviation <: Metric end
type MeanSqDeviation <: SemiMetric end
type RMSDeviation <: Metric end
type NormRMSDeviation <: Metric end


typealias UnionMetrics Union{Euclidean, SqEuclidean, Chebyshev, Cityblock, Minkowski, Hamming, Jaccard, RogersTanimoto, CosineDist, CorrDist, ChiSqDist, KLDivergence, RenyiDivergence, JSDivergence, SpanNormDist}
typealias UnionMetrics Union{Euclidean, SqEuclidean, Chebyshev, Cityblock, Minkowski, Hamming, Jaccard, RogersTanimoto, CosineDist, CorrDist, ChiSqDist, KLDivergence, RenyiDivergence, JSDivergence, SpanNormDist, GenKLDivergence}
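Not part of the diff, but for context: folding GenKLDivergence into UnionMetrics routes it through the same eval_start/eval_op/eval_reduce machinery as the other metrics, so the generic colwise and pairwise helpers should work with it unchanged. A minimal sketch with made-up matrices whose columns are the compared vectors:

    using Distances

    A = rand(5, 3)           # non-negative first arguments
    B = rand(5, 3) .+ 0.1    # keep second arguments strictly positive
    colwise(GenKLDivergence(), A, B)   # 3-element vector, one divergence per column pair
    pairwise(GenKLDivergence(), A, B)  # 3x3 matrix over all column combinations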

###########################################################
#
@@ -163,6 +171,11 @@ chisq_dist(a::AbstractArray, b::AbstractArray) = evaluate(ChiSqDist(), a, b)
@inline eval_reduce(::KLDivergence, s1, s2) = s1 + s2
kl_divergence(a::AbstractArray, b::AbstractArray) = evaluate(KLDivergence(), a, b)

# GenKLDivergence
@inline eval_op(::GenKLDivergence, ai, bi) = ai > 0 ? ai * log(ai / bi) - ai + bi : bi
@inline eval_reduce(::GenKLDivergence, s1, s2) = s1 + s2
gkl_divergence(a::AbstractArray, b::AbstractArray) = evaluate(GenKLDivergence(), a, b)
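Not part of the diff: a quick sanity check of the generalized KL formula on made-up vectors, showing that each coordinate contributes ai * log(ai / bi) - ai + bi, falling back to bi when ai is zero:

    a = [0.0, 1.0, 2.0]
    b = [0.5, 1.5, 2.5]
    by_hand = sum(i -> a[i] > 0 ? a[i] * log(a[i] / b[i]) - a[i] + b[i] : b[i], 1:length(a))
    gkl_divergence(a, b) ≈ by_hand   # true, up to floating-point rounding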

# RenyiDivergence
function eval_start{T<:AbstractFloat}(::RenyiDivergence, a::AbstractArray{T}, b::AbstractArray{T})
zero(T), zero(T)
@@ -289,6 +302,23 @@ end
end
rogerstanimoto{T <: Bool}(a::AbstractArray{T}, b::AbstractArray{T}) = evaluate(RogersTanimoto(), a, b)

# Deviations

evaluate(::MeanAbsDeviation, a, b) = cityblock(a, b) / length(a)

evaluate(::MeanSqDeviation, a, b) = sqeuclidean(a, b) / length(a)
msd(a, b) = evaluate(MeanSqDeviation(), a, b)

evaluate(::RMSDeviation, a, b) = sqrt(evaluate(MeanSqDeviation(), a, b))
rmsd(a, b) = evaluate(RMSDeviation(), a, b)

function evaluate(::NormRMSDeviation, a, b)
amin, amax = extrema(a)
return evaluate(RMSDeviation(), a, b) / (amax - amin)
end
nrmsd(a, b) = evaluate(NormRMSDeviation(), a, b)
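Not part of the diff: how the new deviations relate to one another, checked on arbitrary vectors (the mean absolute deviation gets no exported shorthand in this commit, so it is called through evaluate):

    x = [1.0, 2.0, 3.0, 4.0]
    y = [1.5, 1.5, 3.5, 3.0]
    evaluate(MeanAbsDeviation(), x, y)                     # mean of |x - y|
    msd(x, y)                                              # mean of (x - y).^2
    rmsd(x, y) ≈ sqrt(msd(x, y))                           # true
    nrmsd(x, y) ≈ rmsd(x, y) / (maximum(x) - minimum(x))   # true; normalized by the range of x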


###########################################################
#
# Special method
15 changes: 10 additions & 5 deletions test/test_dists.jl
@@ -63,7 +63,12 @@ for (x, y) in (([4., 5., 6., 7.], [3., 9., 8., 1.]),
@test spannorm_dist(x, x) == 0.
@test spannorm_dist(x, y) == maximum(x - vec(y)) - minimum(x - vec(y))

@test gkl_divergence(x, y) ≈ sum(i -> x[i] * log(x[i] / y[i]) - x[i] + y[i], 1:length(x))

@test evaluate(MeanAbsDeviation(), x, y) ≈ mean(Float64[abs(x[i] - y[i]) for i in 1:length(x)])
@test msd(x, y) ≈ mean(Float64[abs2(x[i] - y[i]) for i in 1:length(x)])
@test rmsd(x, y) ≈ sqrt(msd(x, y))
@test nrmsd(x, y) ≈ sqrt(msd(x, y)) / (maximum(x) - minimum(x))

w = ones(4)
@test sqeuclidean(x, y) ≈ wsqeuclidean(x, y, w)
@@ -101,7 +106,7 @@ w = rand(size(a))
p = r = rand(12)
p[p .< 0.3] = 0.0
scale = sum(p) / sum(r)
r /= sum(r)
r /= sum(r)
p /= sum(p)
q = rand(12)
q /= sum(q)
@@ -119,14 +124,14 @@ end
@test renyi_divergence(p, p, rand()) ≈ 0
@test renyi_divergence(p, p, 1.0 + rand()) ≈ 0
@test renyi_divergence(p, p, Inf) ≈ 0
@test renyi_divergence(p, r, 0) ≈ -log(scale)
@test renyi_divergence(p, r, 1) ≈ -log(scale)
@test renyi_divergence(p, r, rand()) ≈ -log(scale)
@test renyi_divergence(p, r, 0) ≈ -log(scale)
@test renyi_divergence(p, r, 1) ≈ -log(scale)
@test renyi_divergence(p, r, rand()) ≈ -log(scale)
@test renyi_divergence(p, r, Inf) ≈ -log(scale)
@test isinf(renyi_divergence([0.0, 0.5, 0.5], [0.0, 1.0, 0.0], Inf))
@test renyi_divergence([0.0, 1.0, 0.0], [0.0, 0.5, 0.5], Inf) ≈ log(2.0)
@test renyi_divergence(p, q, 1) ≈ kl_divergence(p, q)

pm = (p + q) / 2
jsv = kl_divergence(p, pm) / 2 + kl_divergence(q, pm) / 2
@test js_divergence(p, p) ≈ 0.0
