-
Notifications
You must be signed in to change notification settings - Fork 96
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* bregman * squash bugs * add commas * fix typo * fix other typo * fix bregman test function * fix other sqeuclidean call: bregman * fix colwise test * fix colwise test again * move \del * add Fs * this build actually passes * remove faulty pairwise test * foo-bar * add back premetric checks * modernize * docs + coverage * cache size * new tests + type signature * unindent * suppress some unnecessary output * rand fix * Add Bregman to README
- Loading branch information
1 parent
b05f5c8
commit 0257f33
Showing
4 changed files
with
84 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
# Bregman divergence | ||
|
||
""" | ||
Implements the Bregman divergence, a friendly introduction to which can be found | ||
[here](http://mark.reid.name/blog/meet-the-bregman-divergences.html). | ||
Bregman divergences are a minimal implementation of the "mean-minimizer" property. | ||
It is assumed that the (convex differentiable) function F maps vectors (of any type or size) to real numbers. | ||
The inner product used is `Base.dot`, but one can be passed in either by defining `inner` or by | ||
passing in a keyword argument. If an analytic gradient isn't available, Julia offers a suite | ||
of good automatic differentiation packages. | ||
function evaluate(dist::Bregman, p::AbstractVector, q::AbstractVector) | ||
""" | ||
struct Bregman{T1 <: Function, T2 <: Function, T3 <: Function} <: PreMetric | ||
F::T1 | ||
∇::T2 | ||
inner::T3 | ||
end | ||
|
||
# Default costructor. | ||
Bregman(F, ∇) = Bregman(F, ∇, LinearAlgebra.dot) | ||
|
||
# Evaluation fuction | ||
function evaluate(dist::Bregman, p::AbstractVector, q::AbstractVector) | ||
# Create cache vals. | ||
FP_val = dist.F(p); | ||
FQ_val = dist.F(q); | ||
DQ_val = dist.∇(q); | ||
p_size = size(p); | ||
# Check F codomain. | ||
if !(isa(FP_val, Real) && isa(FQ_val, Real)) | ||
throw(ArgumentError("F Codomain Error: F doesn't map the vectors to real numbers")) | ||
end | ||
# Check vector size. | ||
if !(p_size == size(q)) | ||
throw(DimensionMismatch("The vector p ($(size(p))) and q ($(size(q))) are different sizes.")) | ||
end | ||
# Check gradient size. | ||
if !(size(DQ_val) == p_size) | ||
throw(DimensionMismatch("The gradient result is not the same size as p and q")) | ||
end | ||
# Return the Bregman divergence. | ||
return FP_val - FQ_val - dist.inner(DQ_val, p-q); | ||
end | ||
|
||
# Convenience function. | ||
bregman(F, ∇, x, y; inner = LinearAlgebra.dot) = evaluate(Bregman(F, ∇, inner), x, y) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters