Bregman Divergence (#99)
* bregman

* squash bugs

* add commas

* fix typo

* fix other typo

* fix bregman test function

* fix other sqeuclidean call: bregman

* fix colwise test

* fix colwise test again

* move \del

* add Fs

* this build actually passes

* remove faulty pairwise test

* foo-bar

* add back premetric checks

* modernize

* docs + coverage

* cache size

* new tests + type signature

* unindent

* suppress some unnecessary output

* rand fix

* Add Bregman to README
Arnav Sood authored and KristofferC committed Jul 3, 2018
1 parent b05f5c8 commit 0257f33
Showing 4 changed files with 84 additions and 2 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -37,6 +37,7 @@ This package also provides optimized functions to compute column-wise and pairwise distances.
* Root mean squared deviation
* Normalized root mean squared deviation
* Bray-Curtis dissimilarity
* Bregman divergence

For ``Euclidean distance``, ``Squared Euclidean distance``, ``Cityblock distance``, ``Minkowski distance``, and ``Hamming distance``, a weighted version is also provided.

@@ -163,6 +164,7 @@ Each distance corresponds to a distance type. The type name and the corresponding mathematical definitions of the distances are listed in the following table.
| WeightedCityblock | `wcityblock(x, y, w)` | `sum(abs(x - y) .* w)` |
| WeightedMinkowski | `wminkowski(x, y, w, p)` | `sum(abs(x - y).^p .* w) ^ (1/p)` |
| WeightedHamming | `whamming(x, y, w)` | `sum((x .!= y) .* w)` |
| Bregman | `bregman(F, ∇, x, y; inner = LinearAlgebra.dot)` | `F(x) - F(y) - inner(∇(y), x - y)` |

**Note:** The formulas above use *Julia* functions; they are given mainly to convey the mathematical concepts concisely, and the actual implementations may use faster methods. The arguments `x` and `y` are arrays of real numbers; `k` and `l` are arrays of distinct elements of any kind; `a` and `b` are arrays of Bools; and finally, `p` and `q` are arrays forming a discrete probability distribution and are therefore both expected to sum to one.
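For example, the squared Euclidean distance is the Bregman divergence generated by `F(x) = ‖x‖²`. A minimal sketch of the new API (the `F` and `∇F` definitions below are illustrative, not part of the package):

```julia
using Distances, LinearAlgebra

# F(x) = ||x||^2 generates squared Euclidean distance; its gradient is 2x.
F(x) = sum(abs2, x)
∇F(x) = 2 .* x

x, y = [1.0, 2.0, 3.0], [4.0, 5.0, 6.0]

bregman(F, ∇F, x, y)            # 27.0, equal to sqeuclidean(x, y)
evaluate(Bregman(F, ∇F), x, y)  # equivalent type-based form
```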

3 changes: 3 additions & 0 deletions src/Distances.jl
@@ -54,6 +54,7 @@ export
MeanSqDeviation,
RMSDeviation,
NormRMSDeviation,
Bregman,

# convenient functions
euclidean,
Expand Down Expand Up @@ -84,6 +85,7 @@ export
mahalanobis,
bhattacharyya,
hellinger,
bregman,

haversine,

@@ -99,5 +101,6 @@ include("wmetrics.jl")
include("haversine.jl")
include("mahalanobis.jl")
include("bhattacharyya.jl")
include("bregman.jl")

end # module end
48 changes: 48 additions & 0 deletions src/bregman.jl
@@ -0,0 +1,48 @@
# Bregman divergence

"""
    Bregman(F, ∇[, inner])

Implements the Bregman divergence, a friendly introduction to which can be found
[here](http://mark.reid.name/blog/meet-the-bregman-divergences.html).
Bregman divergences are the unique family of divergences with the "mean-minimizer" property:
the mean of a collection of points minimizes the expected divergence to those points.
It is assumed that the (convex, differentiable) function `F` maps vectors (of any type or size)
to real numbers, and that the gradient `∇` maps vectors to vectors of the same size.
The inner product defaults to `LinearAlgebra.dot`; another one can be used by passing it as the
third constructor argument or via the `inner` keyword of the `bregman` convenience function.
If an analytic gradient isn't available, Julia offers a suite of good automatic differentiation packages.
"""
struct Bregman{T1 <: Function, T2 <: Function, T3 <: Function} <: PreMetric
F::T1
∇::T2
inner::T3
end

# Default constructor.
Bregman(F, ∇) = Bregman(F, ∇, LinearAlgebra.dot)

# Evaluation function.
function evaluate(dist::Bregman, p::AbstractVector, q::AbstractVector)
# Create cache vals.
FP_val = dist.F(p);
FQ_val = dist.F(q);
DQ_val = dist.∇(q);
p_size = size(p);
# Check F codomain.
if !(isa(FP_val, Real) && isa(FQ_val, Real))
throw(ArgumentError("F Codomain Error: F doesn't map the vectors to real numbers"))
end
# Check vector size.
if !(p_size == size(q))
throw(DimensionMismatch("The vectors p ($(size(p))) and q ($(size(q))) have different sizes."))
end
# Check gradient size.
if !(size(DQ_val) == p_size)
throw(DimensionMismatch("The gradient result is not the same size as p and q"))
end
# Return the Bregman divergence.
return FP_val - FQ_val - dist.inner(DQ_val, p-q);
end

# Convenience function.
bregman(F, ∇, x, y; inner = LinearAlgebra.dot) = evaluate(Bregman(F, ∇, inner), x, y)
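
As the docstring suggests, an automatic-differentiation package can supply the gradient when an analytic one isn't handy. A minimal sketch, assuming `ForwardDiff` is available (it is not a dependency added by this commit):

```julia
using Distances, ForwardDiff

# Generating function of the generalized KL divergence.
F(p) = sum(p .* log.(p))

# Let ForwardDiff compute the gradient instead of deriving log.(p) .+ 1 by hand.
∇F(p) = ForwardDiff.gradient(F, p)

p = [0.2, 0.3, 0.5]
q = [0.4, 0.4, 0.2]

bregman(F, ∇F, p, q)  # ≈ gkl_divergence(p, q)
```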
33 changes: 31 additions & 2 deletions test/test_dists.jl
@@ -8,7 +8,7 @@ function test_metricity(dist, x, y, z)
if isa(dist, PreMetric)
# Unfortunately small non-zero numbers (~10^-16) are appearing
# in our tests due to accumulating floating point rounding errors.
# We either need to allow small errors in our tests or change the
# way we do accumulations...
@test evaluate(dist, x, x) + one(eltype(x)) ≈ one(eltype(x))
@test evaluate(dist, y, y) + one(eltype(y)) ≈ one(eltype(y))
@@ -59,6 +59,8 @@ end

test_metricity(BhattacharyyaDist(), x, y, z)
test_metricity(HellingerDist(), x, y, z)
test_metricity(Bregman(x -> sqeuclidean(x, zero(x)), x -> 2*x), x, y, z);


x₁ = rand(T, 2)
x₂ = rand(T, 2)
@@ -276,6 +278,9 @@ end # testset
@test_throws DimensionMismatch colwise!(mat23, Euclidean(), mat23, mat23)
@test_throws DimensionMismatch colwise!(mat23, Euclidean(), mat23, q)
@test_throws DimensionMismatch colwise!(mat23, Euclidean(), mat23, mat22)
@test_throws DimensionMismatch colwise!(mat23, Bregman(x -> sqeuclidean(x, zero(x)), x -> 2*x), mat23, mat22)
@test_throws DimensionMismatch evaluate(Bregman(x -> sqeuclidean(x, zero(x)), x -> 2*x), [1, 2, 3], [1, 2])
@test_throws DimensionMismatch evaluate(Bregman(x -> sqeuclidean(x, zero(x)), x -> [1, 2]), [1, 2, 3], [1, 2, 3])
end # testset

@testset "mahalanobis" begin
@@ -382,6 +387,7 @@ end
test_colwise(Chebyshev(), X, Y, T)
test_colwise(Minkowski(2.5), X, Y, T)
test_colwise(Hamming(), A, B, T)
test_colwise(Bregman(x -> sqeuclidean(x, zero(x)), x -> 2*x), X, Y, T);

test_colwise(CosineDist(), X, Y, T)
test_colwise(CorrDist(), X, Y, T)
@@ -416,7 +422,6 @@ end
test_colwise(Mahalanobis(Q), X, Y, T)
end


function test_pairwise(dist, x, y, T)
@testset "Pairwise test for $(typeof(dist))" begin
nx = size(x, 2)
@@ -472,6 +477,7 @@ end
test_pairwise(BhattacharyyaDist(), X, Y, T)
test_pairwise(HellingerDist(), X, Y, T)
test_pairwise(BrayCurtis(), X, Y, T)
test_pairwise(Bregman(x -> sqeuclidean(x, zero(x)), x -> 2*x), X, Y, T)

w = rand(m)

@@ -503,3 +509,26 @@ end
@test pd[1, 1] == 0
@test pd[2, 2] == 0
end

@testset "Bregman Divergence" begin
# Some basic tests.
@test_throws ArgumentError bregman(x -> x, x -> 2*x, [1, 2, 3], [1, 2, 3])
# Test that Bregman() correctly implements the generalized KL (gkl) divergence between two random probability vectors.
F(p) = LinearAlgebra.dot(p, log.(p));
∇(p) = map(x -> log(x) + 1, p)
testDist = Bregman(F, ∇)
p = rand(4)
q = rand(4)
p = p/sum(p);
q = q/sum(q);
@test evaluate(testDist, p, q) ≈ gkl_divergence(p, q)
# Test that Bregman() correctly implements the squared Euclidean distance between them.
@test bregman(x -> norm(x)^2, x -> 2*x, p, q) ≈ sqeuclidean(p, q)
# Test that Bregman() correctly implements the Itakura-Saito (IS) distance.
F(p) = -1 * sum(log.(p))
∇(p) = map(x -> -1 * x^(-1), p)
function ISdist(p::AbstractVector, q::AbstractVector)
    return sum([p[i]/q[i] - log(p[i]/q[i]) - 1 for i in 1:length(p)])
end
@test bregman(F, ∇, p, q) ≈ ISdist(p, q)
end
