/
metrics.py
118 lines (95 loc) · 4.05 KB
/
metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
#!/usr/bin/python
# -*- coding: utf-8 -*-
##
# metrics.py: Metrics for use with SciPy and sklearn functions.
##
# © 2012 Chris Ferrie (csferrie@gmail.com) and
# Christopher E. Granade (cgranade@gmail.com)
#
# This file is a part of the Qinfer project.
# Licensed under the AGPL version 3.
##
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
##
## FEATURES ###################################################################
from __future__ import absolute_import
from __future__ import division
## ALL ########################################################################
# Only the names listed here are exported by a star import of this module;
# anything else defined below must be imported by name explicitly.
__all__ = ['weighted_pairwise_distances']
## IMPORTS ####################################################################
import numpy as np
import scipy.linalg as la
import warnings
from qinfer.utils import outer_product
try:
    import sklearn
    import sklearn.metrics
    import sklearn.metrics.pairwise
except ImportError:
    # scikit-learn is optional. Record its absence by setting ``sklearn`` to
    # None so that functions needing it (weighted_pairwise_distances) can
    # raise a clear ImportError when called.
    try:
        import logging
        # Bind the logger so that the handler and message below reach it.
        logger = logging.getLogger(__name__)
        # A NullHandler avoids "no handler found" noise for applications
        # that have not configured logging themselves.
        logger.addHandler(logging.NullHandler())
        logger.info("Could not import scikit-learn. Clustering metrics are disabled.")
    except Exception:
        # Logging is best-effort; it must never prevent this module from
        # being imported.
        pass
    sklearn = None
## FUNCTIONS ##################################################################
def rescaled_distance_mtx(p, q, Q=None):
    r"""
    Given two particle updaters for the same model, returns a matrix
    :math:`\matr{d}` with elements

    .. math::

        \matr{d}_{i,j} = \left\Vert \sqrt{\matr{Q}} \cdot
            (\vec{x}_{p, i} - \vec{x}_{q, j}) \right\Vert_2,

    where :math:`\matr{Q}` is the scale matrix of the model,
    :math:`\vec{x}_{p,i}` is the :math:`i`\ th particle of ``p``, and where
    :math:`\vec{x}_{q,j}` is the :math:`j`\ th particle of ``q``.

    :param p: SMC updater for the distribution :math:`p(\vec{x})`, or
        directly the particle locations array of such an updater.
    :type p: :class:`~qinfer.smc.SMCUpdater` or array-like
    :param q: SMC updater for the distribution :math:`q(\vec{x})`, or
        directly the particle locations array of such an updater.
    :type q: :class:`~qinfer.smc.SMCUpdater` or array-like
    :param Q: Optional scale to use instead of the model's ``Q``. Its square
        root is multiplied elementwise against the last (model-parameter)
        axis of the particle differences, so it must broadcast against that
        axis, e.g. a 1-D array holding the diagonal of :math:`\matr{Q}`.
        If omitted, it is taken from ``p.model.Q`` when ``p`` is an updater,
        and otherwise from ``q.model.Q``.
    :return: Array of distances whose element ``[i, j]`` is given above. For
        2-D locations arrays of shapes ``(n_p, n_modelparams)`` and
        ``(n_q, n_modelparams)`` its shape is ``(n_p, n_q)``.
    :raises ValueError: If ``Q`` is not given and neither ``p`` nor ``q``
        is an :class:`~qinfer.smc.SMCUpdater`, so no model can supply it.
    """
    # TODO: check that models are actually the same!
    updater_cls = qinfer.smc.SMCUpdater
    p_is_updater = isinstance(p, updater_cls)
    q_is_updater = isinstance(q, updater_cls)

    p_locs = np.asarray(p.particle_locations if p_is_updater else p)
    q_locs = np.asarray(q.particle_locations if q_is_updater else q)

    if Q is None:
        # A bare locations array carries no model, so take the scale from
        # whichever argument actually is an updater.
        if p_is_updater:
            Q = p.model.Q
        elif q_is_updater:
            Q = q.model.Q
        else:
            raise ValueError(
                "Q must be given explicitly when neither p nor q is an "
                "SMCUpdater."
            )

    # Insert singleton axes so every particle of p is paired with every
    # particle of q. The model-parameter axis stays last, so sqrt(Q)
    # broadcasts along it.
    delta = np.sqrt(Q) * (
        p_locs[:, np.newaxis, :] -
        q_locs[np.newaxis, :, :]
    )
    return np.sqrt(np.sum(delta ** 2, axis=-1))
def weighted_pairwise_distances(X, w, metric='euclidean', w_pow=0.5):
    r"""
    Given a feature matrix ``X`` with weights ``w``, calculates the modified
    distance metric

    .. math::

        \tilde{d}(x, y) = \frac{d(x, y)}{\left(w(x)\, w(y)\, N^2\right)^{\alpha}},

    where :math:`d` is the base metric selected by ``metric``, :math:`N` is
    the number of rows of ``X`` and :math:`\alpha` is ``w_pow``. For
    :math:`\alpha > 0`, pairs of "heavy" feature vectors are considered to
    be closer to each other than pairs of "light" feature vectors at the same
    base distance. Heavy vectors are hence correspondingly more likely, and
    light vectors less likely, to be placed in the same cluster.

    :param X: Feature matrix with one feature vector per row. It is passed
        unchanged to :func:`sklearn.metrics.pairwise.pairwise_distances`.
    :param w: Weights, one per row of ``X``.
    :param str metric: Base metric name passed to
        :func:`sklearn.metrics.pairwise.pairwise_distances`.
    :param float w_pow: Exponent :math:`\alpha` applied to the weight
        denominator.
    :return: Square matrix of modified distances between all rows of ``X``.
    :raises ImportError: If scikit-learn could not be imported.
    :raises ValueError: If ``w`` does not contain one weight per row of
        ``X``.
    """
    if sklearn is None:
        raise ImportError("This function requires scikit-learn.")

    w = np.asarray(w)
    N = w.shape[0]
    n_rows = np.shape(X)[0]
    if N != n_rows:
        raise ValueError(
            "Expected one weight per row of X ({} rows), got {} weights.".format(
                n_rows, N
            )
        )

    base_metric = sklearn.metrics.pairwise.pairwise_distances(X, metric=metric)
    # Scaling by N**2 means uniform normalized weights (w == 1/N everywhere)
    # give a denominator of exactly 1, recovering the base metric.
    w_matrix = outer_product(w) * N**2
    return base_metric / (w_matrix ** w_pow)
## FINAL IMPORTS ##############################################################
import qinfer.smc