
Commit

[SPARK-13302][PYSPARK][TESTS] Move the temp file creation and cleanup outside of the doctests

Some of the new doctests in ml/clustering.py have a lot of setup code; move that setup into the general test init so the doctests stay example-style.
In part this is a follow-up to #10999.
The same pattern is followed in regression and recommendation, so it makes sense to clean up all three at the same time.

Author: Holden Karau <holden@us.ibm.com>

Closes #11197 from holdenk/SPARK-13302-cleanup-doctests-in-ml-clustering.
holdenk authored and srowen committed Feb 20, 2016
1 parent dfb2ae2 commit 9ca79c1
Showing 3 changed files with 42 additions and 33 deletions.
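
All three modules make the same change to their doctest harness, as the hunks below show. For orientation, here is a minimal standalone sketch of that pattern: create one temp directory up front, expose it to the doctests as temp_path, and remove it in a finally block. The run_module_doctests name, the bare __main__ usage, and the omission of the SparkContext/SQLContext setup are simplifications for illustration only; in the actual patch the equivalent code lives inline in each module's test block and keeps sc.stop() inside the try.

import doctest
import tempfile
from shutil import rmtree


def run_module_doctests(globs):
    # Create one shared scratch directory and expose it to every doctest
    # as `temp_path`, so individual examples no longer need setup code.
    temp_path = tempfile.mkdtemp()
    globs['temp_path'] = temp_path
    try:
        # Doctests can now build paths directly, e.g. temp_path + "/model".
        failure_count, test_count = doctest.testmod(
            globs=globs, optionflags=doctest.ELLIPSIS)
    finally:
        # Clean up the scratch directory even if a doctest fails.
        try:
            rmtree(temp_path)
        except OSError:
            pass
    return failure_count, test_count


if __name__ == "__main__":
    failures, _ = run_module_doctests(globs={})
    if failures:
        exit(-1)
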
25 changes: 14 additions & 11 deletions python/pyspark/ml/clustering.py
@@ -70,25 +70,18 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol
     True
     >>> rows[2].prediction == rows[3].prediction
     True
-    >>> import os, tempfile
-    >>> path = tempfile.mkdtemp()
-    >>> kmeans_path = path + "/kmeans"
+    >>> kmeans_path = temp_path + "/kmeans"
     >>> kmeans.save(kmeans_path)
     >>> kmeans2 = KMeans.load(kmeans_path)
     >>> kmeans2.getK()
     2
-    >>> model_path = path + "/kmeans_model"
+    >>> model_path = temp_path + "/kmeans_model"
     >>> model.save(model_path)
     >>> model2 = KMeansModel.load(model_path)
     >>> model.clusterCenters()[0] == model2.clusterCenters()[0]
     array([ True,  True], dtype=bool)
     >>> model.clusterCenters()[1] == model2.clusterCenters()[1]
     array([ True,  True], dtype=bool)
-    >>> from shutil import rmtree
-    >>> try:
-    ...     rmtree(path)
-    ... except OSError:
-    ...     pass
 
     .. versionadded:: 1.5.0
     """
@@ -310,7 +303,17 @@ def _create_model(self, java_model):
     sqlContext = SQLContext(sc)
     globs['sc'] = sc
     globs['sqlContext'] = sqlContext
-    (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
-    sc.stop()
+    import tempfile
+    temp_path = tempfile.mkdtemp()
+    globs['temp_path'] = temp_path
+    try:
+        (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
+        sc.stop()
+    finally:
+        from shutil import rmtree
+        try:
+            rmtree(temp_path)
+        except OSError:
+            pass
     if failure_count:
         exit(-1)
25 changes: 14 additions & 11 deletions python/pyspark/ml/recommendation.py
@@ -82,14 +82,12 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
     Row(user=1, item=0, prediction=2.6258413791656494)
     >>> predictions[2]
     Row(user=2, item=0, prediction=-1.5018409490585327)
-    >>> import os, tempfile
-    >>> path = tempfile.mkdtemp()
-    >>> als_path = path + "/als"
+    >>> als_path = temp_path + "/als"
     >>> als.save(als_path)
     >>> als2 = ALS.load(als_path)
     >>> als.getMaxIter()
     5
-    >>> model_path = path + "/als_model"
+    >>> model_path = temp_path + "/als_model"
     >>> model.save(model_path)
     >>> model2 = ALSModel.load(model_path)
     >>> model.rank == model2.rank
@@ -98,11 +96,6 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
     True
     >>> sorted(model.itemFactors.collect()) == sorted(model2.itemFactors.collect())
     True
-    >>> from shutil import rmtree
-    >>> try:
-    ...     rmtree(path)
-    ... except OSError:
-    ...     pass
 
     .. versionadded:: 1.4.0
     """
@@ -340,7 +333,17 @@ def itemFactors(self):
     sqlContext = SQLContext(sc)
     globs['sc'] = sc
     globs['sqlContext'] = sqlContext
-    (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
-    sc.stop()
+    import tempfile
+    temp_path = tempfile.mkdtemp()
+    globs['temp_path'] = temp_path
+    try:
+        (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
+        sc.stop()
+    finally:
+        from shutil import rmtree
+        try:
+            rmtree(temp_path)
+        except OSError:
+            pass
     if failure_count:
         exit(-1)
25 changes: 14 additions & 11 deletions python/pyspark/ml/regression.py
@@ -68,25 +68,18 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
     Traceback (most recent call last):
         ...
     TypeError: Method setParams forces keyword arguments.
-    >>> import os, tempfile
-    >>> path = tempfile.mkdtemp()
-    >>> lr_path = path + "/lr"
+    >>> lr_path = temp_path + "/lr"
     >>> lr.save(lr_path)
     >>> lr2 = LinearRegression.load(lr_path)
     >>> lr2.getMaxIter()
     5
-    >>> model_path = path + "/lr_model"
+    >>> model_path = temp_path + "/lr_model"
     >>> model.save(model_path)
     >>> model2 = LinearRegressionModel.load(model_path)
     >>> model.coefficients[0] == model2.coefficients[0]
     True
     >>> model.intercept == model2.intercept
     True
-    >>> from shutil import rmtree
-    >>> try:
-    ...     rmtree(path)
-    ... except OSError:
-    ...     pass
 
     .. versionadded:: 1.4.0
     """
@@ -850,7 +843,17 @@ def predict(self, features):
     sqlContext = SQLContext(sc)
     globs['sc'] = sc
     globs['sqlContext'] = sqlContext
-    (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
-    sc.stop()
+    import tempfile
+    temp_path = tempfile.mkdtemp()
+    globs['temp_path'] = temp_path
+    try:
+        (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
+        sc.stop()
+    finally:
+        from shutil import rmtree
+        try:
+            rmtree(temp_path)
+        except OSError:
+            pass
     if failure_count:
         exit(-1)
