Skip to content

Commit

Permalink
Total variation distance (#118)
Browse files Browse the repository at this point in the history
* Add total variation distance

* Add TotalVariation to README and benchmarks
  • Loading branch information
devmotion authored and KristofferC committed Jan 22, 2019
1 parent 1d43a80 commit 785aaab
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 1 deletion.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ This package also provides optimized functions to compute column-wise and pairwi
* Euclidean distance
* Squared Euclidean distance
* Cityblock distance
* Total variation distance
* Jaccard distance
* Rogers-Tanimoto distance
* Chebyshev distance
Expand Down Expand Up @@ -136,6 +137,7 @@ Each distance corresponds to a distance type. The type name and the correspondin
| Euclidean | `euclidean(x, y)` | `sqrt(sum((x - y) .^ 2))` |
| SqEuclidean | `sqeuclidean(x, y)` | `sum((x - y).^2)` |
| Cityblock | `cityblock(x, y)` | `sum(abs(x - y))` |
| TotalVariation | `totalvariation(x, y)` | `sum(abs(x - y)) / 2` |
| Chebyshev | `chebyshev(x, y)` | `max(abs(x - y))` |
| Minkowski | `minkowski(x, y, p)` | `sum(abs(x - y).^p) ^ (1/p)` |
| Hamming | `hamming(k, l)` | `sum(k .!= l)` |
Expand Down
1 change: 1 addition & 0 deletions benchmark/benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ function create_distances(w, Q)
SqEuclidean(),
Euclidean(),
Cityblock(),
TotalVariation(),
Chebyshev(),
Minkowski(3.0),
Hamming(),
Expand Down
1 change: 1 addition & 0 deletions benchmark/print_table.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ order = [
:SqEuclidean,
:Euclidean,
:Cityblock,
:TotalVariation,
:Chebyshev,
:Minkowski,
:Hamming,
Expand Down
2 changes: 2 additions & 0 deletions src/Distances.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ export
Euclidean,
SqEuclidean,
Cityblock,
TotalVariation,
Chebyshev,
Minkowski,
Jaccard,
Expand Down Expand Up @@ -61,6 +62,7 @@ export
euclidean,
sqeuclidean,
cityblock,
totalvariation,
jaccard,
braycurtis,
rogerstanimoto,
Expand Down
10 changes: 9 additions & 1 deletion src/metrics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ struct SqEuclidean <: SemiMetric
end
struct Chebyshev <: Metric end
struct Cityblock <: Metric end
struct TotalVariation <: Metric end
struct Jaccard <: Metric end
struct RogersTanimoto <: Metric end

Expand Down Expand Up @@ -99,7 +100,7 @@ struct RMSDeviation <: Metric end
struct NormRMSDeviation <: Metric end


const UnionMetrics = Union{Euclidean,SqEuclidean,Chebyshev,Cityblock,Minkowski,Hamming,Jaccard,RogersTanimoto,CosineDist,CorrDist,ChiSqDist,KLDivergence,RenyiDivergence,BrayCurtis,JSDivergence,SpanNormDist,GenKLDivergence}
const UnionMetrics = Union{Euclidean,SqEuclidean,Chebyshev,Cityblock,TotalVariation,Minkowski,Hamming,Jaccard,RogersTanimoto,CosineDist,CorrDist,ChiSqDist,KLDivergence,RenyiDivergence,BrayCurtis,JSDivergence,SpanNormDist,GenKLDivergence}

"""
Euclidean([thresh])
Expand Down Expand Up @@ -219,6 +220,13 @@ euclidean(a::Number, b::Number) = evaluate(Euclidean(), a, b)
cityblock(a::AbstractArray, b::AbstractArray) = evaluate(Cityblock(), a, b)
cityblock(a::T, b::T) where {T <: Number} = evaluate(Cityblock(), a, b)

# Total variation
@inline eval_op(::TotalVariation, ai, bi) = abs(ai - bi)
@inline eval_reduce(::TotalVariation, s1, s2) = s1 + s2
eval_end(::TotalVariation, s) = s / 2
totalvariation(a::AbstractArray, b::AbstractArray) = evaluate(TotalVariation(), a, b)
totalvariation(a::T, b::T) where {T <: Number} = evaluate(TotalVariation(), a, b)

# Chebyshev
@inline eval_op(::Chebyshev, ai, bi) = abs(ai - bi)
@inline eval_reduce(::Chebyshev, s1, s2) = max(s1, s2)
Expand Down
7 changes: 7 additions & 0 deletions test/test_dists.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ end
test_metricity(SqEuclidean(), x, y, z)
test_metricity(Euclidean(), x, y, z)
test_metricity(Cityblock(), x, y, z)
test_metricity(TotalVariation(), x, y, z)
test_metricity(Chebyshev(), x, y, z)
test_metricity(Minkowski(2.5), x, y, z)

Expand Down Expand Up @@ -122,6 +123,7 @@ end

@test euclidean(a, b) == 1.0
@test cityblock(a, b) == 1.0
@test totalvariation(a, b) == 0.5
@test chebyshev(a, b) == 1.0
@test minkowski(a, b, 2) == 1.0
@test hamming(a, b) == 1
Expand All @@ -140,6 +142,7 @@ end
@test euclidean(x, y) == sqrt(57.0)
@test jaccard(x, y) == 13.0 / 28
@test cityblock(x, y) == 13.0
@test totalvariation(x, y) == 6.5
@test chebyshev(x, y) == 6.0
@test braycurtis(x, y) == 1.0 - (30.0 / 43.0)
@test minkowski(x, y, 2) == sqrt(57.0)
Expand Down Expand Up @@ -242,6 +245,8 @@ end #testset
@test isa(euclidean(a, b), T)
@test cityblock(a, b) == 0.0
@test isa(cityblock(a, b), T)
@test totalvariation(a, b) == 0.0
@test isa(totalvariation(a, b), T)
@test chebyshev(a, b) == 0.0
@test isa(chebyshev(a, b), T)
@test braycurtis(a, b) == 0.0
Expand Down Expand Up @@ -385,6 +390,7 @@ end
test_colwise(SqEuclidean(), X, Y, T)
test_colwise(Euclidean(), X, Y, T)
test_colwise(Cityblock(), X, Y, T)
test_colwise(TotalVariation(), X, Y, T)
test_colwise(Chebyshev(), X, Y, T)
test_colwise(Minkowski(2.5), X, Y, T)
test_colwise(Hamming(), A, B, T)
Expand Down Expand Up @@ -459,6 +465,7 @@ end
test_pairwise(SqEuclidean(), X, Y, T)
test_pairwise(Euclidean(), X, Y, T)
test_pairwise(Cityblock(), X, Y, T)
test_pairwise(TotalVariation(), X, Y, T)
test_pairwise(Chebyshev(), X, Y, T)
test_pairwise(Minkowski(2.5), X, Y, T)
test_pairwise(Hamming(), A, B, T)
Expand Down

0 comments on commit 785aaab

Please sign in to comment.