Skip to content
Permalink
Browse files

Additional unit tests for birch example reproducibility

  • Loading branch information...
rth committed Oct 1, 2017
1 parent fe645a4 commit 24b1f6e25012d36b122bf027737b0d21a31476a0
Showing with 40 additions and 8 deletions.
  1. +1 −0 circle.yml
  2. +5 −5 examples/python/birch_cluster_hierarchy.py
  3. +34 −3 freediscovery/tests/test_cluster.py
@@ -21,6 +21,7 @@ dependencies:
# run sphinx to build the doc.
override:
- source build_tools/circle/install.sh
- source activate testenv && py.test -sv .
- source activate testenv && freediscovery run -y --cache-dir ../freediscovery_shared:
background: true
- sleep 20
@@ -55,10 +55,10 @@
# We have a hierarchy 2 levels deep, with 78 sub-clusters and a total
# of 1000 samples.
#
# For instance, let's consider the subcluster with ``cluster_id=12``.
# For instance, let's consider the subcluster with ``cluster_id=34``.
# We can access it by id with the flattened representation of the hierarchy,

sc = htree.flatten()[12]
sc = htree.flatten()[34]
print(sc)

###############################################################################
@@ -78,14 +78,14 @@
###############################################################################
#
# For instance, we can select only subclusters that are one level deep
# (this includes ``cluster_id=12``) and compute their centroids,
# (this includes ``cluster_id=34``) and compute their centroids,

htree_depth_1 = [sc for sc in htree.flatten() if sc.current_depth == 1]

for sc in htree_depth_1:
sc['centroid'] = X[sc['document_id_accumulated'], :].mean(axis=0)

print('Centroid for cluster_id=12:\n', htree.flatten()[12])
print('Centroid for cluster_id=34:\n', htree.flatten()[34]['centroid'])


###############################################################################
@@ -123,4 +123,4 @@ def max_tree_depth(self):

print('Tree depth from the root node:', htree_new.max_tree_depth)

print('Tree depth from cluster_id=12:', htree_new.flatten()[12].max_tree_depth)
print('Tree depth from cluster_id=34:', htree_new.flatten()[34].max_tree_depth)
@@ -4,15 +4,18 @@

import numpy as np
from unittest import SkipTest
from numpy.testing import assert_allclose, assert_equal
from numpy.testing import assert_allclose, assert_equal, assert_array_equal
import pytest

from sklearn.preprocessing import normalize
from sklearn.exceptions import NotFittedError
from sklearn.datasets import make_blobs

from freediscovery.cluster import select_top_words
from freediscovery.cluster.hierarchy import _check_birch_tree_consistency
from freediscovery.cluster import compute_optimal_sampling, centroid_similarity
from freediscovery.cluster import Birch, birch_hierarchy_wrapper
from sklearn.preprocessing import normalize
from sklearn.exceptions import NotFittedError



NCLUSTERS = 2
@@ -86,6 +89,34 @@ def test_birch_hierarchy_validation():
birch_hierarchy_wrapper("some other object")


@pytest.mark.parametrize('example_id', [12, 34])
def test_birch_example_reproducibility(example_id):
# check reproducibility of the Birch example
rng = np.random.RandomState(42)

X, y = make_blobs(n_samples=1000, n_features=10, random_state=rng)

cluster_model = Birch(threshold=0.9, branching_factor=20,
compute_sample_indices=True)
cluster_model.fit(X)
#assert len(cluster_model.root_.subclusters_[1].child_.subclusters_) == 3

htree, n_subclusters = birch_hierarchy_wrapper(cluster_model)

assert htree.tree_size == n_subclusters

# same random seed as in the birch hierarchy example
assert htree.tree_size == 78
sc = htree.flatten()[example_id]
if example_id == 34:
# this is true in both cases, but example_id fails on circle ci
assert sc.current_depth == 1
assert len(sc.children) == 3

assert_array_equal([sc['cluster_id'] for sc in htree.flatten()],
np.arange(htree.tree_size))


def test_denrogram_children():
# temporary solution for
# https://stackoverflow.com/questions/40239956/node-indexing-in-hierarachical-clustering-dendrograms

0 comments on commit 24b1f6e

Please sign in to comment.
You can’t perform that action at this time.