Skip to content

Commit

Permalink
Merge pull request #125 from bsipocz/GMM_to_GaussianMixture
Browse files Browse the repository at this point in the history
GMM to GaussianMixture
  • Loading branch information
bsipocz committed Nov 17, 2018
2 parents 75f9537 + 7a12575 commit c7a99c9
Show file tree
Hide file tree
Showing 11 changed files with 66 additions and 41 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,5 @@ book_figures/*/*.pkl
# pyenv
.python-version

.coverage
.pytest_cache
12 changes: 8 additions & 4 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,9 @@ addons:

env:
global:
- PYTHON_VERSION=3.6
- NUMPY_VERSION=1.12
- PYTHON_VERSION=3.7
- ASTROPY_VERSION=stable
- CONDA_DEPENDENCIES='scipy scikit-learn<0.19 nose matplotlib pymc'
- CONDA_DEPENDENCIES='scipy scikit-learn nose matplotlib pymc'
- PIP_DEPENDENCIES='astroML_addons'
- DEBUG=true

Expand All @@ -30,12 +29,17 @@ matrix:

- env: NUMPY_VERSION=1.14

- env: NUMPY_VERSION=1.13
- env: PYTHON_VERSION=3.6 NUMPY_VERSION=1.13

- env: PYTHON_VERSION=3.6 NUMPY_VERSION=1.12

- env: PYTHON_VERSION=3.5 NUMPY_VERSION=1.11

- env: PYTHON_VERSION=3.4 NUMPY_VERSION=1.10 SCIKIT_LEARN_VERSION=0.18

- env: PYTHON_VERSION=3.4 CONDA_DEPENDENCIES='scipy matplotlib nose pymc' NUMPY_VERSION=1.9
PIP_DEPENDENCIES='astropy==2.0.9 scikit-learn==0.18'

- env: PYTHON_VERSION=2.7

install:
Expand Down
2 changes: 2 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ API Changes and Other Changes

- Removed deprecated KDE class. [#119]

- Switched to use the updated scikit-learn API for GaussianMixture. This
change depends on scikit-learn 0.18+. [#125]

0.3.1
=====
Expand Down
4 changes: 2 additions & 2 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,12 @@ The core ``astroML`` package requires the following:
- Python_ version 2.6-2.7 or 3.3+
- Numpy_ >= 1.4
- Scipy_ >= 0.7
- Scikit-learn_ >= 0.10
- Scikit-learn_ >= 0.18
- Matplotlib_ >= 0.99
- AstroPy_ > 0.2.5
AstroPy is required to read Flexible Image Transport
System (FITS) files, which are used by several datasets.

This configuration matches the Ubuntu 10.04 LTS release from April 2010,
with the addition of scikit-learn.

Expand Down
27 changes: 19 additions & 8 deletions astroML/classification/gmm_bayes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,16 @@
This implements generative classification based on mixtures of gaussians
to model the probability density of each class.
"""

import warnings
import numpy as np
from sklearn.mixture import GMM

from sklearn.naive_bayes import BaseNB
from sklearn.mixture import GaussianMixture


class GMMBayes(BaseNB):
"""GMM Bayes Classifier
"""GaussianMixture Bayes Classifier
This is a generalization to the Naive Bayes classifier: rather than
modeling the distribution of each class with axis-aligned gaussians,
Expand All @@ -20,11 +23,12 @@ class GMMBayes(BaseNB):
Parameters
----------
n_components : int or list
number of components to use in the gmm. If specified as a list, it
must match the number of class labels
other keywords are passed directly to GMM
number of components to use in the GaussianMixture. If specified as
a list, it must match the number of class labels. Default is 1.
**kwargs : dict, optional
other keywords are passed directly to GaussianMixture
"""

def __init__(self, n_components=1, **kwargs):
self.n_components = np.atleast_1d(n_components)
self.kwargs = kwargs
Expand Down Expand Up @@ -54,12 +58,19 @@ def fit(self, X, y):
n_comp = np.zeros(len(self.classes_), dtype=int) + self.n_components

for i, y_i in enumerate(unique_y):
self.gmms_[i] = GMM(n_comp[i], **self.kwargs).fit(X[y == y_i])
if n_comp[i] > X[y == y_i].shape[0]:
warnstr = ("Expected n_samples >= n_components but got "
"n_samples={0}, n_components={1}, "
"n_components set to {0}.")
warnings.warn(warnstr.format(X[y == y_i].shape[0], n_comp[i]))
n_comp[i] = y_i
self.gmms_[i] = GaussianMixture(n_comp[i], **self.kwargs).fit(X[y == y_i])
self.class_prior_[i] = np.float(np.sum(y == y_i)) / n_samples

return self

def _joint_log_likelihood(self, X):
    """Return the unnormalised log posterior of every class for each sample.

    Parameters
    ----------
    X : array-like
        Samples to evaluate. The input is promoted to at least two
        dimensions before being handed to each per-class mixture model.

    Returns
    -------
    ndarray
        Per-sample log-likelihoods from each model in ``self.gmms_``
        (transposed so that each model contributes one column) plus the
        log of ``self.class_prior_``.
    """
    X = np.asarray(np.atleast_2d(X))
    # Only score_samples is evaluated: it yields one log-likelihood per
    # sample, which is what the per-class comparison needs. A preceding
    # g.score(X) pass was a dead store overwritten by this line.
    logprobs = np.array([g.score_samples(X) for g in self.gmms_]).T
    return logprobs + np.log(self.class_prior_)
9 changes: 3 additions & 6 deletions astroML/clustering/mst_clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@

from scipy import sparse
from sklearn.neighbors import kneighbors_graph
from sklearn.mixture import GMM

try:
from scipy.sparse.csgraph import \
minimum_spanning_tree, connected_components
except:
from scipy.sparse.csgraph import (
minimum_spanning_tree, connected_components)
except ImportError:
raise ValueError("scipy v0.11 or greater required "
"for minimum spanning tree")

Expand Down Expand Up @@ -176,8 +175,6 @@ def get_graph_segments(X, G):
if (X.ndim != 2) or (X.shape[1] != 2):
raise ValueError('shape of X should be (n_samples, 2)')

n_samples = X.shape[0]

G = sparse.coo_matrix(G)
A = X[G.row].T
B = X[G.col].T
Expand Down
23 changes: 14 additions & 9 deletions astroML/density_estimation/gauss_mixture.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
from sklearn.mixture import GMM
from sklearn.mixture import GaussianMixture


class GaussianMixture1D(object):
Expand All @@ -18,29 +18,34 @@ class GaussianMixture1D(object):
def __init__(self, means=0, sigmas=1, weights=1):
    """Build a fixed one-dimensional Gaussian mixture from explicit values.

    Parameters
    ----------
    means, sigmas, weights : float or array-like
        Centre, standard deviation and relative weight of each component.
        The three arguments are broadcast against each other, and the
        weights are normalised so that they sum to one.
    """
    # One row per component: (mean, sigma, weight).
    data = np.array(list(np.broadcast(means, sigmas, weights)))
    n_components = data.shape[0]

    self._gmm = GaussianMixture(n_components, covariance_type='spherical')

    # Populate the fitted attributes by hand instead of calling fit().
    self._gmm.means_ = data[:, :1]
    self._gmm.weights_ = data[:, 2] / data[:, 2].sum()
    self._gmm.covariances_ = data[:, 1] ** 2

    # Derived from the spherical covariances set just above.
    self._gmm.precisions_cholesky_ = 1 / np.sqrt(self._gmm.covariances_)

    self._gmm.fit = None  # disable fit method for safety

def sample(self, size):
    """Return whatever the wrapped mixture model draws for ``size`` samples."""
    drawn = self._gmm.sample(size)
    return drawn

def pdf(self, x):
    """Evaluate the mixture probability density at ``x``.

    Parameters
    ----------
    x : array-like
        Evaluation points. One-dimensional input is turned into a single
        column (one sample per row) before being scored.

    Returns
    -------
    ndarray
        Exponential of the per-sample log-probabilities returned by the
        wrapped model's ``score_samples``.
    """
    x = np.asarray(x)
    if x.ndim == 1:
        x = x[:, np.newaxis]
    # score_samples is used as a single array of log-probabilities; it is
    # not unpacked into a (logprob, responsibilities) pair.
    logprob = self._gmm.score_samples(x)
    return np.exp(logprob)

def pdf_individual(self, x):
    """Evaluate each component's share of the density at ``x``.

    Parameters
    ----------
    x : array-like
        Evaluation points. One-dimensional input is turned into a single
        column (one sample per row) before being scored.

    Returns
    -------
    ndarray
        The responsibilities from the wrapped model's ``predict_proba``
        scaled row-wise by the exponential of the ``score_samples``
        log-probabilities.
    """
    x = np.asarray(x)
    if x.ndim == 1:
        x = x[:, np.newaxis]
    # Log-probabilities and responsibilities come from two separate calls;
    # score_samples is not unpacked into a pair.
    logprob = self._gmm.score_samples(x)
    responsibilities = self._gmm.predict_proba(x)
    return responsibilities * np.exp(logprob[:, np.newaxis])
5 changes: 4 additions & 1 deletion astroML/density_estimation/tests/test_density.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,10 @@ def test_gaussian1d():
gauss = GaussianMixture1D(means=means, sigmas=sigmas, weights=weights)
y = gauss.pdf(x)

# Check whether sampling works
gauss.sample(10)

dx = x[1] - x[0]
integral = np.sum(y*dx)

assert np.round(integral, 1) == 1
assert_allclose(integral, 1., atol=0.02)
19 changes: 10 additions & 9 deletions astroML/density_estimation/xdeconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import numpy as np
from scipy import linalg

from sklearn.mixture import GMM
from sklearn.mixture import GaussianMixture
from ..utils import logsumexp, log_multivariate_gaussian, check_random_state


Expand All @@ -27,7 +27,7 @@ class XDGMM(object):
----------
n_components: integer
number of gaussian components to fit to the data
n_iter: integer (optional)
max_iter: integer (optional)
maximum number of EM iterations to perform (default=100)
tol: float (optional)
stopping criterion for EM iterations (default=1E-5)
Expand All @@ -36,10 +36,10 @@ class XDGMM(object):
-----
This implementation follows Bovy et al. arXiv 0905.2979
"""
def __init__(self, n_components, n_iter=100, tol=1E-5, verbose=False,
def __init__(self, n_components, max_iter=100, tol=1E-5, verbose=False,
random_state = None):
self.n_components = n_components
self.n_iter = n_iter
self.max_iter = max_iter
self.tol = tol
self.verbose = verbose
self.random_state = random_state
Expand Down Expand Up @@ -73,17 +73,18 @@ def fit(self, X, Xerr, R=None):
# assume full covariances of data
assert Xerr.shape == (n_samples, n_features, n_features)

# initialize components via a few steps of GMM
# initialize components via a few steps of GaussianMixture
# this doesn't take into account errors, but is a fast first-guess
gmm = GMM(self.n_components, n_iter=10, covariance_type='full',
random_state=self.random_state).fit(X)
gmm = GaussianMixture(self.n_components, max_iter=10,
covariance_type='full',
random_state=self.random_state).fit(X)
self.mu = gmm.means_
self.alpha = gmm.weights_
self.V = gmm.covars_
self.V = gmm.covariances_

logL = self.logL(X, Xerr)

for i in range(self.n_iter):
for i in range(self.max_iter):
t0 = time()
self._EMstep(X, Xerr)
logL_next = self.logL(X, Xerr)
Expand Down
2 changes: 1 addition & 1 deletion doc/user_guide/installation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ The core ``astroML`` package requires the following:
- `Python <http://python.org>`_ version 2.6 - 2.7 or 3.3+
- `Numpy <http://numpy.scipy.org/>`_ >= 1.4
- `Scipy <http://www.scipy.org/>`_ >= 0.7
- `scikit-learn <http://scikit-learn.org/>`_ >= 0.10
- `scikit-learn <http://scikit-learn.org/>`_ >= 0.18
- `matplotlib <http://matplotlib.org/>`_ >= 0.99
- `astropy <http://www.astropy.org/>`_ >= 0.2.5
AstroPy is required to read Flexible Image Transport
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import astroML
VERSION = astroML.__version__

install_requires = ['scikit-learn>=0.10',
install_requires = ['scikit-learn>=0.18',
'numpy>=1.4',
'scipy>=0.7',
'matplotlib>=0.99',
Expand Down

0 comments on commit c7a99c9

Please sign in to comment.