From 94bb3e2a50386c24bf54839d9d9449111e5e72e9 Mon Sep 17 00:00:00 2001
From: Alex Arslan
Date: Thu, 27 Oct 2016 14:53:47 -0700
Subject: [PATCH] Add distances, divergences, and deviations from StatsBase

---
 src/Distances.jl   | 13 ++++++++++++-
 src/metrics.jl     | 36 +++++++++++++++++++++++++++++++++---
 test/test_dists.jl | 15 ++++++++++-----
 3 files changed, 55 insertions(+), 9 deletions(-)

diff --git a/src/Distances.jl b/src/Distances.jl
index 6b4aa42..88d6497 100644
--- a/src/Distances.jl
+++ b/src/Distances.jl
@@ -32,6 +32,7 @@ export
     CorrDist,
     ChiSqDist,
     KLDivergence,
+    GenKLDivergence,
     JSDivergence,
     RenyiDivergence,
     SpanNormDist,
@@ -46,6 +47,11 @@ export
     BhattacharyyaDist,
     HellingerDist,
 
+    MeanAbsDeviation,
+    MeanSqDeviation,
+    RMSDeviation,
+    NormRMSDeviation,
+
     # convenient functions
     euclidean,
     sqeuclidean,
@@ -61,6 +67,7 @@ export
     corr_dist,
     chisq_dist,
     kl_divergence,
+    gkl_divergence,
     js_divergence,
     renyi_divergence,
     spannorm_dist,
@@ -73,7 +80,11 @@ export
     sqmahalanobis,
     mahalanobis,
     bhattacharyya,
-    hellinger
+    hellinger,
+
+    msd,
+    rmsd,
+    nrmsd
 
 include("common.jl")
 include("generic.jl")
diff --git a/src/metrics.jl b/src/metrics.jl
index fde6e48..394453b 100644
--- a/src/metrics.jl
+++ b/src/metrics.jl
@@ -24,6 +24,7 @@ type CorrDist <: SemiMetric end
 
 type ChiSqDist <: SemiMetric end
 type KLDivergence <: PreMetric end
+type GenKLDivergence <: PreMetric end
 
 immutable RenyiDivergence{T <: Real} <: PreMetric
     p::T # order of power mean (order of divergence - 1)
@@ -37,10 +38,10 @@ immutable RenyiDivergence{T <: Real} <: PreMetric
         is_zero = q ≈ zero(T)
         is_one = q ≈ one(T)
         is_inf = isinf(q)
-        
+
         # Only positive Rényi divergences are defined
         !is_zero && q < zero(T) && throw(ArgumentError("Order of Rényi divergence not legal, $(q) < 0."))
-        
+
         new(q - 1, !(is_zero || is_one || is_inf), is_zero, is_one, is_inf)
     end
 end
@@ -50,8 +51,15 @@ type JSDivergence <: SemiMetric end
 
 type SpanNormDist <: SemiMetric end
 
+# Deviations are handled separately from the other distances/divergences and
+# are excluded from `UnionMetrics`
+type MeanAbsDeviation <: Metric end
+type MeanSqDeviation <: SemiMetric end
+type RMSDeviation <: Metric end
+type NormRMSDeviation <: Metric end
+
 
-typealias UnionMetrics Union{Euclidean, SqEuclidean, Chebyshev, Cityblock, Minkowski, Hamming, Jaccard, RogersTanimoto, CosineDist, CorrDist, ChiSqDist, KLDivergence, RenyiDivergence, JSDivergence, SpanNormDist}
+typealias UnionMetrics Union{Euclidean, SqEuclidean, Chebyshev, Cityblock, Minkowski, Hamming, Jaccard, RogersTanimoto, CosineDist, CorrDist, ChiSqDist, KLDivergence, RenyiDivergence, JSDivergence, SpanNormDist, GenKLDivergence}
 
 ###########################################################
 #
@@ -163,6 +171,11 @@ chisq_dist(a::AbstractArray, b::AbstractArray) = evaluate(ChiSqDist(), a, b)
 @inline eval_reduce(::KLDivergence, s1, s2) = s1 + s2
 kl_divergence(a::AbstractArray, b::AbstractArray) = evaluate(KLDivergence(), a, b)
 
+# GenKLDivergence
+@inline eval_op(::GenKLDivergence, ai, bi) = ai > 0 ? ai * log(ai / bi) - ai + bi : bi
+@inline eval_reduce(::GenKLDivergence, s1, s2) = s1 + s2
+gkl_divergence(a::AbstractArray, b::AbstractArray) = evaluate(GenKLDivergence(), a, b)
+
 # RenyiDivergence
 function eval_start{T<:AbstractFloat}(::RenyiDivergence, a::AbstractArray{T}, b::AbstractArray{T})
     zero(T), zero(T)
@@ -289,6 +302,23 @@ end
 end
 rogerstanimoto{T <: Bool}(a::AbstractArray{T}, b::AbstractArray{T}) = evaluate(RogersTanimoto(), a, b)
 
+# Deviations
+
+evaluate(::MeanAbsDeviation, a, b) = cityblock(a, b) / length(a)
+
+evaluate(::MeanSqDeviation, a, b) = sqeuclidean(a, b) / length(a)
+msd(a, b) = evaluate(MeanSqDeviation(), a, b)
+
+evaluate(::RMSDeviation, a, b) = sqrt(evaluate(MeanSqDeviation(), a, b))
+rmsd(a, b) = evaluate(RMSDeviation(), a, b)
+
+function evaluate(::NormRMSDeviation, a, b)
+    amin, amax = extrema(a)
+    return evaluate(RMSDeviation(), a, b) / (amax - amin)
+end
+nrmsd(a, b) = evaluate(NormRMSDeviation(), a, b)
+
+
 ###########################################################
 #
 # Special method
diff --git a/test/test_dists.jl b/test/test_dists.jl
index 9f63fa3..27c7f52 100644
--- a/test/test_dists.jl
+++ b/test/test_dists.jl
@@ -63,7 +63,12 @@ for (x, y) in (([4., 5., 6., 7.], [3., 9., 8., 1.]),
     @test spannorm_dist(x, x) == 0.
     @test spannorm_dist(x, y) == maximum(x - vec(y)) - minimum(x - vec(y))
 
+    @test gkl_divergence(x, y) ≈ sum(i -> x[i] * log(x[i] / y[i]) - x[i] + y[i], 1:length(x))
 
+    @test evaluate(MeanAbsDeviation(), x, y) ≈ mean(Float64[abs(x[i] - y[i]) for i in 1:length(x)])
+    @test msd(x, y) ≈ mean(Float64[abs2(x[i] - y[i]) for i in 1:length(x)])
+    @test rmsd(x, y) ≈ sqrt(msd(x, y))
+    @test nrmsd(x, y) ≈ sqrt(msd(x, y)) / (maximum(x) - minimum(x))
 
     w = ones(4)
     @test sqeuclidean(x, y) ≈ wsqeuclidean(x, y, w)
@@ -101,7 +106,7 @@ w = rand(size(a))
 p = r = rand(12)
 p[p .< 0.3] = 0.0
 scale = sum(p) / sum(r)
-r /= sum(r) 
+r /= sum(r)
 p /= sum(p)
 q = rand(12)
 q /= sum(q)
@@ -119,14 +124,14 @@
 @test renyi_divergence(p, p, rand()) ≈ 0
 @test renyi_divergence(p, p, 1.0 + rand()) ≈ 0
 @test renyi_divergence(p, p, Inf) ≈ 0
-@test renyi_divergence(p, r, 0) ≈ -log(scale) 
-@test renyi_divergence(p, r, 1) ≈ -log(scale) 
-@test renyi_divergence(p, r, rand()) ≈ -log(scale) 
+@test renyi_divergence(p, r, 0) ≈ -log(scale)
+@test renyi_divergence(p, r, 1) ≈ -log(scale)
+@test renyi_divergence(p, r, rand()) ≈ -log(scale)
 @test renyi_divergence(p, r, Inf) ≈ -log(scale)
 @test isinf(renyi_divergence([0.0, 0.5, 0.5], [0.0, 1.0, 0.0], Inf))
 @test renyi_divergence([0.0, 1.0, 0.0], [0.0, 0.5, 0.5], Inf) ≈ log(2.0)
 @test renyi_divergence(p, q, 1) ≈ kl_divergence(p, q)
- 
+
 pm = (p + q) / 2
 jsv = kl_divergence(p, pm) / 2 + kl_divergence(q, pm) / 2
 @test js_divergence(p, p) ≈ 0.0
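
For reviewers, a quick usage sketch of the API this patch adds. This is not part of the diff; it assumes a Julia 0.5-era session (matching the `typealias`/`immutable` syntax above) with this branch of Distances.jl loaded, and the input vectors below are made up for illustration:

    using Distances

    a = [0.1, 0.2, 0.3, 0.4]
    b = [0.3, 0.3, 0.2, 0.2]

    # Generalized KL divergence: sum over i of a[i]*log(a[i]/b[i]) - a[i] + b[i],
    # defined for nonnegative vectors that need not sum to one.
    gkl_divergence(a, b)

    # Deviations reuse the existing cityblock/sqeuclidean kernels:
    evaluate(MeanAbsDeviation(), a, b)  # mean absolute deviation
    msd(a, b)                           # mean squared deviation
    rmsd(a, b)                          # sqrt of the mean squared deviation
    nrmsd(a, b)                         # rmsd normalized by maximum(a) - minimum(a)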