Skip to content

Commit

Permalink
add kullback leibler divergence
Browse files Browse the repository at this point in the history
  • Loading branch information
alexandrebarachant committed Jul 27, 2015
1 parent b45f358 commit 417d0f7
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 15 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ install:
- virtualenv --system-site-packages testvenv
- source testvenv/bin/activate
- pip install pip --upgrade
- pip install coverage python-coveralls
- pip install coverage python-coveralls joblib
- pip install nose scikit-learn --upgrade
- python setup.py build install
# command to run tests
Expand Down
33 changes: 31 additions & 2 deletions pyriemann/utils/distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,30 @@
# distances
###############################################################

def distance_kullback(A,B):
"""Return the Kullback leibler divergence between
two covariance matrices A and B :
:param A: First covariance matrix
:param B: Second covariance matrix
:returns: Kullback leibler divergence between A and B
"""
dim = A.shape[0]
logdet = numpy.log(numpy.linalg.det(B) / numpy.linalg.det(A))
kl = numpy.trace(numpy.dot(numpy.linalg.inv(B), A)) - dim + logdet
return 0.5 * kl


def distance_kullback_right(A, B):
"""wrapper for right kullblack leibler div."""
return distance_kullback(B, A)


def distance_kullback_sym(A, B):
"""Symetrized kullback leibler divergence."""
return distance_kullback(A, B) + distance_kullback_right(A, B)


def distance_euclid(A, B):
"""Return the Euclidean distance (Froebenius norm) between
Expand Down Expand Up @@ -78,14 +102,19 @@ def distance(A, B, metric='riemann'):
:param A: First covariance matrix
:param B: Second covariance matrix
:param metric: the metric (Default value 'riemann'), can be : 'riemann' , 'logeuclid' , 'euclid' , 'logdet'
:param metric: the metric (Default value 'riemann'), can be : 'riemann' ,
'logeuclid' , 'euclid' , 'logdet', 'kullback', 'kullback_right',
'kullback_sym'.
:returns: the distance between A and B
"""
distance_methods = {'riemann': distance_riemann,
'logeuclid': distance_logeuclid,
'euclid': distance_euclid,
'logdet': distance_logdet}
'logdet': distance_logdet,
'kullback': distance_kullback,
'kullback_right': distance_kullback_right,
'kullback_sym': distance_kullback_sym}
if len(A.shape) == 3:
d = numpy.empty((len(A), 1))
for i in range(len(A)):
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
numpy
scipy
scikit-learn
pandas
pandas
joblib
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,5 @@
author_email='alexandre.barachant@gmail.com',
license='BSD (3-clause)',
packages=find_packages(),
install_requires=['numpy','scipy','scikit-learn'],
install_requires=['numpy','scipy','scikit-learn', 'joblib'],
zip_safe=False)
38 changes: 28 additions & 10 deletions tests/test_utils_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,54 +6,72 @@
distance_euclid,
distance_logeuclid,
distance_logdet,
distance_kullback,
distance_kullback_right,
distance_kullback_sym,
distance)


def test_distance_riemann():
"""Test riemannian distance"""
A = 2*np.eye(3)
B = 2*np.eye(3)
assert_array_almost_equal(distance_riemann(A,B),0)


def test_distance_kullback():
"""Test kullback divergence"""
A = 2*np.eye(3)
B = 2*np.eye(3)
assert_array_almost_equal(distance_kullback(A,B),0)
assert_array_almost_equal(distance_kullback_right(A,B),0)
assert_array_almost_equal(distance_kullback_sym(A,B),0)

def test_distance_euclid():
"""Test euclidean distance"""
A = 2*np.eye(3)
B = 2*np.eye(3)
assert_equal(distance_euclid(A,B),0)

def test_distance_logeuclid():
"""Test logeuclid distance"""
A = 2*np.eye(3)
B = 2*np.eye(3)
assert_equal(distance_logeuclid(A,B),0)

def test_distance_logdet():
"""Test logdet distance"""
A = 2*np.eye(3)
B = 2*np.eye(3)
assert_equal(distance_logdet(A,B),0)

def test_distance_generic_riemann():
"""Test riemannian distance for generic function"""
A = 2*np.eye(3)
B = 2*np.eye(3)
assert_equal(distance(A,B,metric='riemann'),distance_riemann(A,B))

def test_distance_generic_euclid():
"""Test euclidean distance for generic function"""
A = 2*np.eye(3)
B = 2*np.eye(3)
assert_equal(distance(A,B,metric='euclid'),distance_euclid(A,B))

def test_distance_generic_logdet():
"""Test logdet distance for generic function"""
A = 2*np.eye(3)
B = 2*np.eye(3)
assert_equal(distance(A,B,metric='logdet'),distance_logdet(A,B))

def test_distance_generic_logeuclid():
"""Test logeuclid distance for generic function"""
A = 2*np.eye(3)
B = 2*np.eye(3)
assert_equal(distance(A,B,metric='logeuclid'),distance_logeuclid(A,B))


def test_distance_generic_kullback():
"""Test logeuclid distance for generic function"""
A = 2*np.eye(3)
B = 2*np.eye(3)
assert_equal(distance(A,B,metric='kullback'),distance_kullback(A,B))
assert_equal(distance(A,B,metric='kullback_right'),distance_kullback_right(A,B))
assert_equal(distance(A,B,metric='kullback_sym'),distance_kullback_sym(A,B))

0 comments on commit 417d0f7

Please sign in to comment.