Skip to content

Commit

Permalink
[MRG+1] Add DBSCAN support for additional metric params (scikit-learn…
Browse files Browse the repository at this point in the history
…#8139)

* Add DBSCAN support for additional metric params
  • Loading branch information
naoyak authored and Sundrique committed Jun 14, 2017
1 parent 7541652 commit 146c829
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 3 deletions.
19 changes: 16 additions & 3 deletions sklearn/cluster/dbscan_.py
Expand Up @@ -20,7 +20,7 @@
from ._dbscan_inner import dbscan_inner


def dbscan(X, eps=0.5, min_samples=5, metric='minkowski',
def dbscan(X, eps=0.5, min_samples=5, metric='minkowski', metric_params=None,
algorithm='auto', leaf_size=30, p=2, sample_weight=None, n_jobs=1):
"""Perform DBSCAN clustering from vector array or distance matrix.
Expand Down Expand Up @@ -50,6 +50,11 @@ def dbscan(X, eps=0.5, min_samples=5, metric='minkowski',
must be square. X may be a sparse matrix, in which case only "nonzero"
elements may be considered neighbors for DBSCAN.
metric_params : dict, optional
Additional keyword arguments for the metric function.
.. versionadded:: 0.19
algorithm : {'auto', 'ball_tree', 'kd_tree', 'brute'}, optional
The algorithm to be used by the NearestNeighbors module
to compute pointwise distances and find nearest neighbors.
Expand Down Expand Up @@ -130,7 +135,8 @@ def dbscan(X, eps=0.5, min_samples=5, metric='minkowski',
else:
neighbors_model = NearestNeighbors(radius=eps, algorithm=algorithm,
leaf_size=leaf_size,
metric=metric, p=p,
metric=metric,
metric_params=metric_params, p=p,
n_jobs=n_jobs)
neighbors_model.fit(X)
# This has worst case O(n^2) memory complexity
Expand Down Expand Up @@ -184,6 +190,11 @@ class DBSCAN(BaseEstimator, ClusterMixin):
.. versionadded:: 0.17
metric *precomputed* to accept precomputed sparse matrix.
metric_params : dict, optional
Additional keyword arguments for the metric function.
.. versionadded:: 0.19
algorithm : {'auto', 'ball_tree', 'kd_tree', 'brute'}, optional
The algorithm to be used by the NearestNeighbors module
to compute pointwise distances and find nearest neighbors.
Expand Down Expand Up @@ -237,10 +248,12 @@ class DBSCAN(BaseEstimator, ClusterMixin):
"""

def __init__(self, eps=0.5, min_samples=5, metric='euclidean',
algorithm='auto', leaf_size=30, p=None, n_jobs=1):
metric_params=None, algorithm='auto', leaf_size=30, p=None,
n_jobs=1):
self.eps = eps
self.min_samples = min_samples
self.metric = metric
self.metric_params = metric_params
self.algorithm = algorithm
self.leaf_size = leaf_size
self.p = p
Expand Down
28 changes: 28 additions & 0 deletions sklearn/cluster/tests/test_dbscan.py
Expand Up @@ -133,6 +133,34 @@ def test_dbscan_callable():
assert_equal(n_clusters_2, n_clusters)


def test_dbscan_metric_params():
    """Check that DBSCAN accepts and honours the ``metric_params`` argument.

    Three estimators are fitted on the module-level data ``X`` (defined
    outside this function) with the ball-tree algorithm:

    1. Minkowski metric, exponent given through ``metric_params``.
    2. Minkowski metric, exponent given through the ``p`` argument.
    3. Manhattan metric, which Minkowski with ``p=1`` should match.

    The core sample indices and labels of fits 2 and 3 are each compared
    against those of fit 1.
    """
    eps = 0.8
    min_samples = 10
    p = 1

    # Settings shared by every estimator built below.
    shared = dict(eps=eps, min_samples=min_samples, algorithm='ball_tree')

    # Reference fit: the exponent travels through metric_params.
    via_params = DBSCAN(metric='minkowski', metric_params={'p': p},
                        **shared)
    via_params.fit(X)
    core_ref = via_params.core_sample_indices_
    labels_ref = via_params.labels_

    # Passing the Minkowski exponent directly must give the same clustering.
    via_p = DBSCAN(metric='minkowski', p=p, **shared)
    via_p.fit(X)
    assert_array_equal(core_ref, via_p.core_sample_indices_)
    assert_array_equal(labels_ref, via_p.labels_)

    # The Manhattan metric is expected to agree with Minkowski at p=1.
    manhattan = DBSCAN(metric='manhattan', **shared)
    manhattan.fit(X)
    assert_array_equal(core_ref, manhattan.core_sample_indices_)
    assert_array_equal(labels_ref, manhattan.labels_)


def test_dbscan_balltree():
# Tests the DBSCAN algorithm with balltree for neighbor calculation.
eps = 0.8
Expand Down

0 comments on commit 146c829

Please sign in to comment.