From 5ec3f1e77bf3d3ca9522bcd6b579698290f8bbca Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 2 Jul 2015 12:08:39 +0800 Subject: [PATCH 1/2] fix PySpark PowerIterationClustering test issue --- python/pyspark/mllib/clustering.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index a3eab635282f6..d279582e230a2 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -282,18 +282,30 @@ class PowerIterationClusteringModel(JavaModelWrapper, JavaSaveable, JavaLoader): Model produced by [[PowerIterationClustering]]. - >>> data = [(0, 1, 1.0), (0, 2, 1.0), (1, 3, 1.0), (2, 3, 1.0), - ... (0, 3, 1.0), (1, 2, 1.0), (0, 4, 0.1)] + >>> data = [(0, 1, 1.0), (0, 2, 1.0), (0, 3, 1.0), (1, 2, 1.0), (1, 3, 1.0), + ... (2, 3, 1.0), (3, 4, 0.1), (4, 5, 1.0), (4, 15, 1.0), (5, 6, 1.0), + ... (6, 7, 1.0), (7, 8, 1.0), (8, 9, 1.0), (9, 10, 1.0), (10, 11, 1.0), + ... (11, 12, 1.0), (12, 13, 1.0), (13, 14, 1.0), (14, 15, 1.0)] >>> rdd = sc.parallelize(data, 2) >>> model = PowerIterationClustering.train(rdd, 2, 100) >>> model.k 2 + >>> result = sorted(model.assignments().collect(), key=lambda x: x.id) + >>> sum([x.cluster != result[3].cluster for x in result if x.id < 3]) + 0 + >>> sum([x.cluster != result[4].cluster for x in result if x.id > 4]) + 0 >>> import os, tempfile >>> path = tempfile.mkdtemp() >>> model.save(sc, path) >>> sameModel = PowerIterationClusteringModel.load(sc, path) >>> sameModel.k 2 + >>> result = sorted(model.assignments().collect(), key=lambda x: x.id) + >>> sum([x.cluster != result[3].cluster for x in result if x.id < 3]) + 0 + >>> sum([x.cluster != result[4].cluster for x in result if x.id > 4]) + 0 >>> from shutil import rmtree >>> try: ... rmtree(path) From 392ae5429f7daa6f5b06daabfde467a596162cfe Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 2 Jul 2015 13:45:33 +0800 Subject: [PATCH 2/2] fix model.assignments output --- python/pyspark/mllib/clustering.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index d279582e230a2..ed4d78a2c6788 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -291,10 +291,10 @@ class PowerIterationClusteringModel(JavaModelWrapper, JavaSaveable, JavaLoader): >>> model.k 2 >>> result = sorted(model.assignments().collect(), key=lambda x: x.id) - >>> sum([x.cluster != result[3].cluster for x in result if x.id < 3]) - 0 - >>> sum([x.cluster != result[4].cluster for x in result if x.id > 4]) - 0 + >>> result[0].cluster == result[1].cluster == result[2].cluster == result[3].cluster + True + >>> result[4].cluster == result[5].cluster == result[6].cluster == result[7].cluster + True >>> import os, tempfile >>> path = tempfile.mkdtemp() >>> model.save(sc, path) @@ -302,10 +302,10 @@ class PowerIterationClusteringModel(JavaModelWrapper, JavaSaveable, JavaLoader): >>> sameModel.k 2 >>> result = sorted(model.assignments().collect(), key=lambda x: x.id) - >>> sum([x.cluster != result[3].cluster for x in result if x.id < 3]) - 0 - >>> sum([x.cluster != result[4].cluster for x in result if x.id > 4]) - 0 + >>> result[0].cluster == result[1].cluster == result[2].cluster == result[3].cluster + True + >>> result[4].cluster == result[5].cluster == result[6].cluster == result[7].cluster + True >>> from shutil import rmtree >>> try: ... rmtree(path)