From 0cf2566cf5b4ef60f413ba9c0c1153fe05f7a5a3 Mon Sep 17 00:00:00 2001
From: Yanbo Liang
Date: Thu, 24 Dec 2015 17:14:29 +0800
Subject: [PATCH 1/9] PySpark support model export/import and take LinearRegression as example

---
 python/pyspark/ml/regression.py |  34 ++++++-
 python/pyspark/ml/util.py       | 159 +++++++++++++++++++++++++++++++-
 python/pyspark/ml/wrapper.py    |  13 +--
 3 files changed, 188 insertions(+), 18 deletions(-)

diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index 74a2248ed07c8..a3fa48ea20898 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -18,9 +18,9 @@
 import warnings
 
 from pyspark import since
-from pyspark.ml.util import keyword_only
-from pyspark.ml.wrapper import JavaEstimator, JavaModel
 from pyspark.ml.param.shared import *
+from pyspark.ml.util import *
+from pyspark.ml.wrapper import JavaEstimator, JavaModel
 from pyspark.mllib.common import inherit_doc
 
 
@@ -35,7 +35,8 @@
 @inherit_doc
 class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
                        HasRegParam, HasTol, HasElasticNetParam, HasFitIntercept,
-                       HasStandardization, HasSolver, HasWeightCol):
+                       HasStandardization, HasSolver, HasWeightCol, MLWritable,
+                       EstimatorMLReadable):
     """
     Linear regression.
 
@@ -68,6 +69,28 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
     Traceback (most recent call last):
         ...
     TypeError: Method setParams forces keyword arguments.
+    >>> import os, tempfile
+    >>> path = tempfile.mkdtemp()
+    >>> lr_path = path + "/lr"
+    >>> lr.save(lr_path)
+    >>> lr2 = LinearRegression.load(lr_path)
+    >>> model2 = lr2.fit(df)
+    >>> abs(model.coefficients[0] - model2.coefficients[0]) < 0.001
+    True
+    >>> abs(model.intercept - model2.intercept) < 0.001
+    True
+    >>> model_path = path + "/model"
+    >>> model.save(model_path)
+    >>> model3 = LinearRegressionModel.load(model_path)
+    >>> abs(model.coefficients[0] - model3.coefficients[0]) < 0.001
+    True
+    >>> abs(model.intercept - model3.intercept) < 0.001
+    True
+    >>> from shutil import rmtree
+    >>> try:
+    ...     rmtree(path)
+    ... except OSError:
+    ...     pass
 
     .. versionadded:: 1.4.0
     """
@@ -106,7 +129,7 @@ def _create_model(self, java_model):
         return LinearRegressionModel(java_model)
 
 
-class LinearRegressionModel(JavaModel):
+class LinearRegressionModel(JavaModel, MLWritable, TransformerMLReadable):
     """
     Model fitted by LinearRegression.
 
@@ -821,9 +844,10 @@ def predict(self, features):
 
 if __name__ == "__main__":
     import doctest
+    import pyspark.ml.regression
     from pyspark.context import SparkContext
     from pyspark.sql import SQLContext
-    globs = globals().copy()
+    globs = pyspark.ml.regression.__dict__.copy()
     # The small batch size here ensures that we see multiple batches,
     # even in these small test examples:
     sc = SparkContext("local[2]", "ml.regression tests")
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index cee9d67b05325..6ee676d52fe1b 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -15,8 +15,27 @@
 # limitations under the License.
 #
 
-from functools import wraps
+import sys
 import uuid
+from functools import wraps
+
+if sys.version > '3':
+    basestring = str
+
+from pyspark import SparkContext, since
+from pyspark.mllib.common import inherit_doc
+
+
+def _jvm():
+    """
+    Returns the JVM view associated with SparkContext. Must be called
+    after SparkContext is initialized.
+    """
+    jvm = SparkContext._jvm
+    if jvm:
+        return jvm
+    else:
+        raise AttributeError("Cannot load _jvm from SparkContext. Is SparkContext initialized?")
 
 
 def keyword_only(func):
@@ -52,3 +71,141 @@ def _randomUID(cls):
     concatenates the class name, "_", and 12 random hex chars.
     """
     return cls.__name__ + "_" + uuid.uuid4().hex[12:]
+
+
+@inherit_doc
+class MLWriter(object):
+    """
+    Abstract class for utility classes that can save ML instances.
+
+    .. versionadded:: 2.0.0
+    """
+
+    def __init__(self, instance):
+        self._jwrite = instance._java_obj.write()
+
+    @since("2.0.0")
+    def save(self, path):
+        """Saves the ML instances to the input path."""
+        self._jwrite.save(path)
+
+    @since("2.0.0")
+    def overwrite(self):
+        """Overwrites if the output path already exists."""
+        self._jwrite.overwrite()
+        return self
+
+    @since("2.0.0")
+    def context(self, sqlContext):
+        """Sets the SQL context to use for saving."""
+        self._jwrite.context(sqlContext._ssql_ctx)
+        return self
+
+
+@inherit_doc
+class MLWritable(object):
+    """
+    Mixin for ML instances that provide MLWriter through their Scala
+    implementation.
+
+    .. versionadded:: 2.0.0
+    """
+
+    @since("2.0.0")
+    def write(self):
+        """Returns an MLWriter instance for this ML instance."""
+        return MLWriter(self)
+
+    @since("2.0.0")
+    def save(self, path):
+        """Save this ML instance to the given path, a shortcut of `write().save(path)`."""
+        if not isinstance(path, basestring):
+            raise TypeError("path should be a basestring, got type %s" % type(path))
+        self._java_obj.save(path)
+
+
+@inherit_doc
+class MLReader(object):
+    """
+    Abstract class for utility classes that can load ML instances.
+
+    .. versionadded:: 2.0.0
+    """
+
+    def __init__(self, instance):
+        self._instance = instance
+        self._jread = instance._java_obj.read()
+
+    @since("2.0.0")
+    def load(self, path):
+        """Loads the ML component from the input path."""
+        self._instance.load(path)
+
+    @since("2.0.0")
+    def context(self, sqlContext):
+        """Sets the SQL context to use for loading."""
+        self._jread.context(sqlContext._ssql_ctx)
+        return self
+
+
+@inherit_doc
+class MLReadable(object):
+    """
+    Mixin for objects that provide MLReader using its Scala implementation.
+
+    .. versionadded:: 2.0.0
+    """
+
+    @classmethod
+    def _java_loader_class(cls):
+        """
+        Returns the full class name of the Java loader. The default
+        implementation replaces "pyspark" by "org.apache.spark" in
+        the Python full class name.
+        """
+        java_package = cls.__module__.replace("pyspark", "org.apache.spark")
+        return ".".join([java_package, cls.__name__])
+
+    @classmethod
+    def _load_java(cls, path):
+        """
+        Load a Java model from the given path.
+        """
+        java_class = cls._java_loader_class()
+        java_obj = _jvm()
+        for name in java_class.split("."):
+            java_obj = getattr(java_obj, name)
+        return java_obj.load(path)
+
+    @classmethod
+    @since("2.0.0")
+    def read(self):
+        """Returns an MLReader instance for this class."""
+        return MLReader(self)
+
+
+@inherit_doc
+class TransformerMLReadable(MLReadable):
+
+    @classmethod
+    @since("2.0.0")
+    def load(cls, path):
+        """Load a model from the given path."""
+        java_obj = cls._load_java(path)
+        new_instance = cls(java_obj)
+        new_instance._transfer_params_from_java()
+        return new_instance
+
+
+@inherit_doc
+class EstimatorMLReadable(MLReadable):
+
+    @classmethod
+    @since("2.0.0")
+    def load(cls, path):
+        """Load a model from the given path."""
+        java_obj = cls._load_java(path)
+        new_instance = cls()
+        new_instance._java_obj = java_obj
+        new_instance._transfer_params_from_java()
+        return new_instance
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index dd1d4b076eddd..f2b5471551067 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -21,21 +21,10 @@
 from pyspark.sql import DataFrame
 from pyspark.ml.param import Params
 from pyspark.ml.pipeline import Estimator, Transformer, Model
+from pyspark.ml.util import _jvm
 from pyspark.mllib.common import inherit_doc, _java2py, _py2java
 
 
-def _jvm():
-    """
-    Returns the JVM view associated with SparkContext. Must be called
-    after SparkContext is initialized.
-    """
-    jvm = SparkContext._jvm
-    if jvm:
-        return jvm
-    else:
-        raise AttributeError("Cannot load _jvm from SparkContext. Is SparkContext initialized?")
-
-
 @inherit_doc
 class JavaWrapper(Params):
     """

From 61324d3aa526ff9fc2456bbc81e7a11a3319d1ad Mon Sep 17 00:00:00 2001
From: Yanbo Liang
Date: Mon, 25 Jan 2016 18:47:46 +0800
Subject: [PATCH 2/9] Address comments

---
 python/pyspark/ml/regression.py |  6 +++---
 python/pyspark/ml/util.py       | 15 +++++++++++----
 2 files changed, 14 insertions(+), 7 deletions(-)

diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index a3fa48ea20898..b42651be10777 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -79,12 +79,12 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
     True
     >>> abs(model.intercept - model2.intercept) < 0.001
     True
-    >>> model_path = path + "/model"
+    >>> model_path = path + "/lr_model"
     >>> model.save(model_path)
     >>> model3 = LinearRegressionModel.load(model_path)
-    >>> abs(model.coefficients[0] - model3.coefficients[0]) < 0.001
+    >>> model.coefficients[0] == model3.coefficients[0]
     True
-    >>> abs(model.intercept - model3.intercept) < 0.001
+    >>> model.intercept == model3.intercept
     True
     >>> from shutil import rmtree
     >>> try:
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index 6ee676d52fe1b..fa5e179b19798 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -76,7 +76,9 @@ def _randomUID(cls):
 @inherit_doc
 class MLWriter(object):
     """
-    Abstract class for utility classes that can save ML instances.
+    .. note:: Experimental
+
+    Utility class that can save ML instances.
 
     .. versionadded:: 2.0.0
     """
@@ -105,6 +107,8 @@ def context(self, sqlContext):
 @inherit_doc
 class MLWritable(object):
     """
+    .. note:: Experimental
+
     Mixin for ML instances that provide MLWriter through their Scala
     implementation.
 
@@ -127,19 +131,20 @@ def save(self, path):
 @inherit_doc
 class MLReader(object):
     """
-    Abstract class for utility classes that can load ML instances.
+    .. note:: Experimental
+
+    Utility class that can load ML instances.
 
     .. versionadded:: 2.0.0
     """
 
     def __init__(self, instance):
-        self._instance = instance
         self._jread = instance._java_obj.read()
 
     @since("2.0.0")
     def load(self, path):
         """Loads the ML component from the input path."""
-        self._instance.load(path)
+        self._jread.load(path)
 
     @since("2.0.0")
     def context(self, sqlContext):
@@ -151,6 +156,8 @@ def context(self, sqlContext):
 @inherit_doc
 class MLReadable(object):
     """
+    .. note:: Experimental
+
     Mixin for objects that provide MLReader using its Scala implementation.
 
     .. versionadded:: 2.0.0

From 63db658d21dbf205623ca030cbc70f847342ebb1 Mon Sep 17 00:00:00 2001
From: Yanbo Liang
Date: Tue, 26 Jan 2016 18:05:21 +0800
Subject: [PATCH 3/9] Combine Estimator & Transformer MLReadable

---
 python/pyspark/ml/regression.py |  5 ++--
 python/pyspark/ml/util.py       | 44 +++++++++++----------------------
 python/pyspark/ml/wrapper.py    | 12 ++++++---
 3 files changed, 25 insertions(+), 36 deletions(-)

diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index b42651be10777..5d0506e47cd0f 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -35,8 +35,7 @@
 @inherit_doc
 class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
                        HasRegParam, HasTol, HasElasticNetParam, HasFitIntercept,
-                       HasStandardization, HasSolver, HasWeightCol, MLWritable,
-                       EstimatorMLReadable):
+                       HasStandardization, HasSolver, HasWeightCol, MLWritable, MLReadable):
     """
     Linear regression.
 
@@ -129,7 +128,7 @@ def _create_model(self, java_model):
         return LinearRegressionModel(java_model)
 
 
-class LinearRegressionModel(JavaModel, MLWritable, TransformerMLReadable):
+class LinearRegressionModel(JavaModel, MLWritable, MLReadable):
     """
     Model fitted by LinearRegression.
 
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index fa5e179b19798..8a97c1af6b0d0 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -139,12 +139,17 @@ class MLReader(object):
     """
 
     def __init__(self, instance):
+        self._instance = instance
         self._jread = instance._java_obj.read()
 
     @since("2.0.0")
     def load(self, path):
         """Loads the ML component from the input path."""
-        self._jread.load(path)
+        java_obj = self._jread.load(path)
+        self._instance._java_obj = java_obj
+        self._instance.uid = java_obj.uid()
+        self._instance._transfer_params_from_java(True)
+        return self._instance
 
     @since("2.0.0")
     def context(self, sqlContext):
@@ -174,45 +179,26 @@ def _java_loader_class(cls):
         return ".".join([java_package, cls.__name__])
 
     @classmethod
-    def _load_java(cls, path):
+    def _load_java_obj(cls):
         """
-        Load a Java model from the given path.
+        Load the peer Java object.
""" java_class = cls._java_loader_class() java_obj = _jvm() for name in java_class.split("."): java_obj = getattr(java_obj, name) - return java_obj.load(path) + return java_obj @classmethod @since("2.0.0") - def read(self): + def read(cls): """Returns an MLReader instance for this class.""" - return MLReader(self) - - -@inherit_doc -class TransformerMLReadable(MLReadable): - - @classmethod - @since("2.0.0") - def load(cls, path): - """Load a model from the given path.""" - java_obj = cls._load_java(path) - new_instance = cls(java_obj) - new_instance._transfer_params_from_java() - return new_instance - - -@inherit_doc -class EstimatorMLReadable(MLReadable): + instance = cls() + instance._java_obj = cls._load_java_obj() + return MLReader(instance) @classmethod @since("2.0.0") def load(cls, path): - """Load a model from the given path.""" - java_obj = cls._load_java(path) - new_instance = cls() - new_instance._java_obj = java_obj - new_instance._transfer_params_from_java() - return new_instance + """Reads an ML instance from the input path, a shortcut of `read().load(path)`.""" + return cls.read().load(path) diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index f2b5471551067..80054978a7ead 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -71,13 +71,16 @@ def _transfer_params_to_java(self): pair = self._make_java_param_pair(param, paramMap[param]) self._java_obj.set(pair) - def _transfer_params_from_java(self): + def _transfer_params_from_java(self, withParent=False): """ Transforms the embedded params from the companion Java object. """ sc = SparkContext._active_spark_context + parent = self._java_obj.uid() for param in self.params: if self._java_obj.hasParam(param.name): + if withParent: + param.parent = parent java_param = self._java_obj.getParam(param.name) value = _java2py(sc, self._java_obj.getOrDefault(java_param)) self._paramMap[param] = value @@ -148,15 +151,16 @@ class JavaModel(Model, JavaTransformer): __metaclass__ = ABCMeta - def __init__(self, java_model): + def __init__(self, java_model=None): """ Initialize this instance with a Java model object. Subclasses should call this constructor, initialize params, and then call _transformer_params_from_java. 
""" super(JavaModel, self).__init__() - self._java_obj = java_model - self.uid = java_model.uid() + if java_model is not None: + self._java_obj = java_model + self.uid = java_model.uid() def copy(self, extra=None): """ From 0ccb13082d4edc187af051492252b7866cbeb84f Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 26 Jan 2016 18:16:52 +0800 Subject: [PATCH 4/9] MLWritable should _transfer_params_to_java --- python/pyspark/ml/regression.py | 2 ++ python/pyspark/ml/util.py | 3 ++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 5d0506e47cd0f..22de81e33cde6 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -73,6 +73,8 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction >>> lr_path = path + "/lr" >>> lr.save(lr_path) >>> lr2 = LinearRegression.load(lr_path) + >>> lr2.getOrDefault(lr2.getParam("maxIter")) + 5 >>> model2 = lr2.fit(df) >>> abs(model.coefficients[0] - model2.coefficients[0]) < 0.001 True diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index 8a97c1af6b0d0..6e6fbda6c60c7 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -84,6 +84,7 @@ class MLWriter(object): """ def __init__(self, instance): + instance._transfer_params_to_java() self._jwrite = instance._java_obj.write() @since("2.0.0") @@ -125,7 +126,7 @@ def save(self, path): """Save this ML instance to the given path, a shortcut of `write().save(path)`.""" if not isinstance(path, basestring): raise TypeError("path should be a basestring, got type %s" % type(path)) - self._java_obj.save(path) + self.write().save(path) @inherit_doc From e9ea63d7f82d9a87126dd517569234f4ebfe36b6 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 26 Jan 2016 18:38:13 +0800 Subject: [PATCH 5/9] update docs --- python/pyspark/ml/util.py | 4 +--- python/pyspark/ml/wrapper.py | 5 +++-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index 6e6fbda6c60c7..59cbdcc7fe528 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -181,9 +181,7 @@ def _java_loader_class(cls): @classmethod def _load_java_obj(cls): - """ - Load the peer Java object. 
-        """
+        """Load the peer Java object."""
         java_class = cls._java_loader_class()
         java_obj = _jvm()
         for name in java_class.split("."):
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index 80054978a7ead..7a5654ea9186e 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -175,8 +175,9 @@ def copy(self, extra=None):
         if extra is None:
             extra = dict()
         that = super(JavaModel, self).copy(extra)
-        that._java_obj = self._java_obj.copy(self._empty_java_param_map())
-        that._transfer_params_to_java()
+        if self._java_obj is not None:
+            that._java_obj = self._java_obj.copy(self._empty_java_param_map())
+            that._transfer_params_to_java()
         return that
 
     def _call_java(self, name, *args):

From 62e31b47d6d038d4db2cf3977f7976af4ac5d55e Mon Sep 17 00:00:00 2001
From: Yanbo Liang
Date: Wed, 27 Jan 2016 11:46:10 +0800
Subject: [PATCH 6/9] Make MLReadable general and not specific to Java wrappers

---
 python/pyspark/ml/util.py | 39 +++++++++++++++++++--------------------
 1 file changed, 19 insertions(+), 20 deletions(-)

diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index 59cbdcc7fe528..6230a27125f96 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -141,7 +141,8 @@ class MLReader(object):
 
     def __init__(self, instance):
         self._instance = instance
-        self._jread = instance._java_obj.read()
+        self._instance._java_obj = self._load_java_obj(self._instance)
+        self._jread = self._instance._java_obj.read()
 
     @since("2.0.0")
     def load(self, path):
@@ -158,43 +159,41 @@ def context(self, sqlContext):
         self._jread.context(sqlContext._ssql_ctx)
         return self
 
-
-@inherit_doc
-class MLReadable(object):
-    """
-    .. note:: Experimental
-
-    Mixin for objects that provide MLReader using its Scala implementation.
-
-    .. versionadded:: 2.0.0
-    """
-
     @classmethod
-    def _java_loader_class(cls):
+    def _java_loader_class(cls, instance):
         """
         Returns the full class name of the Java loader. The default
         implementation replaces "pyspark" by "org.apache.spark" in
         the Python full class name.
         """
-        java_package = cls.__module__.replace("pyspark", "org.apache.spark")
-        return ".".join([java_package, cls.__name__])
+        java_package = instance.__module__.replace("pyspark", "org.apache.spark")
+        return ".".join([java_package, instance.__class__.__name__])
 
     @classmethod
-    def _load_java_obj(cls):
+    def _load_java_obj(cls, instance):
         """Load the peer Java object."""
-        java_class = cls._java_loader_class()
+        java_class = cls._java_loader_class(instance)
         java_obj = _jvm()
         for name in java_class.split("."):
             java_obj = getattr(java_obj, name)
         return java_obj
 
+
+@inherit_doc
+class MLReadable(object):
+    """
+    .. note:: Experimental
+
+    Mixin for objects that provide MLReader using its Scala implementation.
+
+    .. versionadded:: 2.0.0
+    """
+
     @classmethod
     @since("2.0.0")
     def read(cls):
         """Returns an MLReader instance for this class."""
-        instance = cls()
-        instance._java_obj = cls._load_java_obj()
-        return MLReader(instance)
+        return MLReader(cls())
 
     @classmethod
     @since("2.0.0")
     def load(cls, path):
         """Reads an ML instance from the input path, a shortcut of `read().load(path)`."""
         return cls.read().load(path)

From bbd032ffe8a41de9502418f1890f49f93bdc78b2 Mon Sep 17 00:00:00 2001
From: Yanbo Liang
Date: Wed, 27 Jan 2016 15:08:37 +0800
Subject: [PATCH 7/9] address comments

---
 python/pyspark/ml/regression.py | 13 ++++-------
 python/pyspark/ml/util.py       | 40 +++++++++++++--------------------
 python/pyspark/ml/wrapper.py    |  3 +++
 3 files changed, 23 insertions(+), 33 deletions(-)

diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index 22de81e33cde6..20dc6c2db91f3 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -73,19 +73,14 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
     >>> lr_path = path + "/lr"
     >>> lr.save(lr_path)
     >>> lr2 = LinearRegression.load(lr_path)
-    >>> lr2.getOrDefault(lr2.getParam("maxIter"))
+    >>> lr2.getMaxIter()
     5
-    >>> model2 = lr2.fit(df)
-    >>> abs(model.coefficients[0] - model2.coefficients[0]) < 0.001
-    True
-    >>> abs(model.intercept - model2.intercept) < 0.001
-    True
     >>> model_path = path + "/lr_model"
     >>> model.save(model_path)
-    >>> model3 = LinearRegressionModel.load(model_path)
-    >>> model.coefficients[0] == model3.coefficients[0]
+    >>> model2 = LinearRegressionModel.load(model_path)
+    >>> model.coefficients[0] == model2.coefficients[0]
     True
-    >>> model.intercept == model3.intercept
+    >>> model.intercept == model2.intercept
     True
     >>> from shutil import rmtree
     >>> try:
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index 6230a27125f96..09ca7cc8b3399 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -74,11 +74,11 @@ def _randomUID(cls):
 
 @inherit_doc
-class MLWriter(object):
+class JavaMLWriter(object):
     """
     .. note:: Experimental
 
-    Utility class that can save ML instances.
+    Utility class that can save ML instances through their Scala implementation.
 
     .. versionadded:: 2.0.0
     """
@@ -87,18 +87,15 @@ def __init__(self, instance):
         instance._transfer_params_to_java()
         self._jwrite = instance._java_obj.write()
 
-    @since("2.0.0")
     def save(self, path):
-        """Saves the ML instances to the input path."""
+        """Save the ML instance to the input path."""
         self._jwrite.save(path)
 
-    @since("2.0.0")
     def overwrite(self):
         """Overwrites if the output path already exists."""
         self._jwrite.overwrite()
         return self
 
-    @since("2.0.0")
     def context(self, sqlContext):
         """Sets the SQL context to use for saving."""
         self._jwrite.context(sqlContext._ssql_ctx)
@@ -110,18 +107,15 @@ class MLWritable(object):
     """
     .. note:: Experimental
 
-    Mixin for ML instances that provide MLWriter through their Scala
-    implementation.
+    Mixin for ML instances that provide JavaMLWriter.
 
     .. versionadded:: 2.0.0
     """
 
-    @since("2.0.0")
    def write(self):
-        """Returns an MLWriter instance for this ML instance."""
-        return MLWriter(self)
+        """Returns an JavaMLWriter instance for this ML instance."""
+        return JavaMLWriter(self)
 
-    @since("2.0.0")
     def save(self, path):
         """Save this ML instance to the given path, a shortcut of `write().save(path)`."""
         if not isinstance(path, basestring):
@@ -130,11 +124,11 @@ def save(self, path):
 
 
 @inherit_doc
-class MLReader(object):
+class JavaMLReader(object):
     """
     .. note:: Experimental
 
-    Utility class that can load ML instances.
+    Utility class that can load ML instances through their Scala implementation.
 
     .. versionadded:: 2.0.0
     """
 
     def __init__(self, instance):
         self._instance = instance
         self._instance._java_obj = self._load_java_obj(self._instance)
         self._jread = self._instance._java_obj.read()
 
-    @since("2.0.0")
     def load(self, path):
-        """Loads the ML component from the input path."""
+        """Load the ML instance from the input path."""
         java_obj = self._jread.load(path)
         self._instance._java_obj = java_obj
         self._instance.uid = java_obj.uid()
         self._instance._transfer_params_from_java(True)
         return self._instance
 
-    @since("2.0.0")
     def context(self, sqlContext):
         """Sets the SQL context to use for loading."""
         self._jread.context(sqlContext._ssql_ctx)
         return self
 
     @classmethod
     def _java_loader_class(cls, instance):
         """
-        Returns the full class name of the Java loader. The default
+        Returns the full class name of the Java ML instance. The default
         implementation replaces "pyspark" by "org.apache.spark" in
         the Python full class name.
         """
         java_package = instance.__module__.replace("pyspark", "org.apache.spark")
         return ".".join([java_package, instance.__class__.__name__])
 
     @classmethod
     def _load_java_obj(cls, instance):
-        """Load the peer Java object."""
+        """Load the peer Java object of the ML instance."""
         java_class = cls._java_loader_class(instance)
         java_obj = _jvm()
         for name in java_class.split("."):
             java_obj = getattr(java_obj, name)
         return java_obj
 
 
 @inherit_doc
 class MLReadable(object):
     """
     .. note:: Experimental
 
-    Mixin for objects that provide MLReader using its Scala implementation.
+    Mixin for instances that provide JavaMLReader.
 
     .. versionadded:: 2.0.0
     """
 
     @classmethod
-    @since("2.0.0")
     def read(cls):
-        """Returns an MLReader instance for this class."""
-        return MLReader(cls())
+        """Returns an JavaMLReader instance for this class."""
+        return JavaMLReader(cls())
 
     @classmethod
-    @since("2.0.0")
     def load(cls, path):
         """Reads an ML instance from the input path, a shortcut of `read().load(path)`."""
+        if not isinstance(path, basestring):
+            raise TypeError("path should be a basestring, got type %s" % type(path))
         return cls.read().load(path)
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index 7a5654ea9186e..249d6ffbc4b96 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -156,6 +156,9 @@ def __init__(self, java_model=None):
         Initialize this instance with a Java model object.
         Subclasses should call this constructor, initialize params,
         and then call _transformer_params_from_java.
+        This instance can be instantiated without specifying java_model,
+        it will be assigned after that, but this scenario only used by
+        :py:class:`JavaMLReader` to load models.
         """
         super(JavaModel, self).__init__()
         if java_model is not None:

From 08b97607d8fa8385e05b6b490e23788fa5008520 Mon Sep 17 00:00:00 2001
From: "Joseph K. Bradley"
Bradley" Date: Wed, 27 Jan 2016 21:57:39 -0800 Subject: [PATCH 8/9] added unit test for persistence to check for param UIDs carefully, and fixed current issues --- python/pyspark/ml/param/__init__.py | 24 ++++++++++++++++++ python/pyspark/ml/tests.py | 39 ++++++++++++++++++++++++----- python/pyspark/ml/util.py | 36 +++++++++++++------------- python/pyspark/ml/wrapper.py | 11 +++++--- 4 files changed, 82 insertions(+), 28 deletions(-) diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py index 3da36d32c5af0..b412e9cd6c856 100644 --- a/python/pyspark/ml/param/__init__.py +++ b/python/pyspark/ml/param/__init__.py @@ -314,3 +314,27 @@ def _copyValues(self, to, extra=None): if p in paramMap and to.hasParam(p.name): to._set(**{p.name: paramMap[p]}) return to + + def _resetUid(self, newUid): + """ + Changes the uid of this instance. This updates both + the stored uid and the parent uid of params and param maps. + This is used by persistence (loading). + :param newUid: new uid to use + :return: same instance, but with the uid and Param.parent values + updated, including within param maps + """ + self.uid = newUid + newDefaultParamMap = dict() + newParamMap = dict() + for param in self.params: + newParam = copy.copy(param) + newParam.parent = newUid + if param in self._defaultParamMap: + newDefaultParamMap[newParam] = self._defaultParamMap[param] + if param in self._paramMap: + newParamMap[newParam] = self._paramMap[param] + param.parent = newUid + self._defaultParamMap = newDefaultParamMap + self._paramMap = newParamMap + return self diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index c45a159c460f3..559e720c42388 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -34,18 +34,22 @@ else: import unittest -from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase -from pyspark.sql import DataFrame, SQLContext, Row -from pyspark.sql.functions import rand +from shutil import rmtree +import tempfile + +from pyspark.ml import Estimator, Model, Pipeline, Transformer from pyspark.ml.classification import LogisticRegression from pyspark.ml.evaluation import RegressionEvaluator +from pyspark.ml.feature import * from pyspark.ml.param import Param, Params from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed -from pyspark.ml.util import keyword_only -from pyspark.ml import Estimator, Model, Pipeline, Transformer -from pyspark.ml.feature import * +from pyspark.ml.regression import LinearRegression from pyspark.ml.tuning import ParamGridBuilder, CrossValidator, CrossValidatorModel +from pyspark.ml.util import keyword_only from pyspark.mllib.linalg import DenseVector +from pyspark.sql import DataFrame, SQLContext, Row +from pyspark.sql.functions import rand +from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase class MockDataset(DataFrame): @@ -405,6 +409,29 @@ def test_fit_maximize_metric(self): self.assertEqual(1.0, bestModelMetric, "Best model has R-squared of 1") +class PersistenceTest(PySparkTestCase): + + def test_linear_regression(self): + lr = LinearRegression(maxIter = 1) + path = tempfile.mkdtemp() + lr_path = path + "/lr" + lr.save(lr_path) + lr2 = LinearRegression.load(lr_path) + lr2.getMaxIter() + print lr.uid + print lr2.uid + print lr2.maxIter.parent + self.assertEqual(lr2.uid, lr2.maxIter.parent, + "Loaded LinearRegression instance uid (%s) did not match Param's uid (%s)" + % (lr2.uid, lr2.maxIter.parent)) + self.assertEqual(lr._defaultParamMap[lr.maxIter], 
+                "Loaded LinearRegression instance defaults did not match original defaults")
+        try:
+            rmtree(path)
+        except OSError:
+            pass
+
+
 if __name__ == "__main__":
     from pyspark.ml.tests import *
     if xmlrunner:
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index 09ca7cc8b3399..d7a813f56cd57 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -89,6 +89,8 @@ def __init__(self, instance):
 
     def save(self, path):
         """Save the ML instance to the input path."""
+        if not isinstance(path, basestring):
+            raise TypeError("path should be a basestring, got type %s" % type(path))
         self._jwrite.save(path)
 
     def overwrite(self):
@@ -118,8 +120,6 @@ def write(self):
 
     def save(self, path):
         """Save this ML instance to the given path, a shortcut of `write().save(path)`."""
-        if not isinstance(path, basestring):
-            raise TypeError("path should be a basestring, got type %s" % type(path))
         self.write().save(path)
 
 
@@ -133,18 +133,20 @@ class JavaMLReader(object):
     .. versionadded:: 2.0.0
     """
 
-    def __init__(self, instance):
-        self._instance = instance
-        self._instance._java_obj = self._load_java_obj(self._instance)
-        self._jread = self._instance._java_obj.read()
+    def __init__(self, clazz):
+        self._clazz = clazz
+        self._jread = self._load_java_obj(clazz).read()
 
     def load(self, path):
         """Load the ML instance from the input path."""
+        if not isinstance(path, basestring):
+            raise TypeError("path should be a basestring, got type %s" % type(path))
         java_obj = self._jread.load(path)
-        self._instance._java_obj = java_obj
-        self._instance.uid = java_obj.uid()
-        self._instance._transfer_params_from_java(True)
-        return self._instance
+        instance = self._clazz()
+        instance._java_obj = java_obj
+        instance._resetUid(java_obj.uid())
+        instance._transfer_params_from_java()
+        return instance
 
     def context(self, sqlContext):
         """Sets the SQL context to use for loading."""
@@ -152,19 +154,19 @@ def context(self, sqlContext):
         return self
 
     @classmethod
-    def _java_loader_class(cls, instance):
+    def _java_loader_class(cls, clazz):
         """
         Returns the full class name of the Java ML instance. The default
         implementation replaces "pyspark" by "org.apache.spark" in
         the Python full class name.
""" - java_package = instance.__module__.replace("pyspark", "org.apache.spark") - return ".".join([java_package, instance.__class__.__name__]) + java_package = clazz.__module__.replace("pyspark", "org.apache.spark") + return ".".join([java_package, clazz.__name__]) @classmethod - def _load_java_obj(cls, instance): + def _load_java_obj(cls, clazz): """Load the peer Java object of the ML instance.""" - java_class = cls._java_loader_class(instance) + java_class = cls._java_loader_class(clazz) java_obj = _jvm() for name in java_class.split("."): java_obj = getattr(java_obj, name) @@ -184,11 +186,9 @@ class MLReadable(object): @classmethod def read(cls): """Returns an JavaMLReader instance for this class.""" - return JavaMLReader(cls()) + return JavaMLReader(cls) @classmethod def load(cls, path): """Reads an ML instance from the input path, a shortcut of `read().load(path)`.""" - if not isinstance(path, basestring): - raise TypeError("path should be a basestring, got type %s" % type(path)) return cls.read().load(path) diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index 249d6ffbc4b96..67be49a18f7b1 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -71,7 +71,7 @@ def _transfer_params_to_java(self): pair = self._make_java_param_pair(param, paramMap[param]) self._java_obj.set(pair) - def _transfer_params_from_java(self, withParent=False): + def _transfer_params_from_java(self): """ Transforms the embedded params from the companion Java object. """ @@ -79,8 +79,6 @@ def _transfer_params_from_java(self, withParent=False): parent = self._java_obj.uid() for param in self.params: if self._java_obj.hasParam(param.name): - if withParent: - param.parent = parent java_param = self._java_obj.getParam(param.name) value = _java2py(sc, self._java_obj.getOrDefault(java_param)) self._paramMap[param] = value @@ -156,9 +154,14 @@ def __init__(self, java_model=None): Initialize this instance with a Java model object. Subclasses should call this constructor, initialize params, and then call _transformer_params_from_java. + This instance can be instantiated without specifying java_model, it will be assigned after that, but this scenario only used by - :py:class:`JavaMLReader` to load models. + :py:class:`JavaMLReader` to load models. This is a bit of a + hack, but it is easiest since a proper fix would require + MLReader (in pyspark.ml.util) to depend on these wrappers, but + these wrappers depend on pyspark.ml.util (both directly and via + other ML classes). """ super(JavaModel, self).__init__() if java_model is not None: From 7334be978468cee2bb4b9e5b13e296392080246b Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 28 Jan 2016 16:21:33 +0800 Subject: [PATCH 9/9] fix typos & docs --- python/pyspark/ml/param/__init__.py | 2 +- python/pyspark/ml/tests.py | 13 +++++-------- python/pyspark/ml/wrapper.py | 1 - 3 files changed, 6 insertions(+), 10 deletions(-) diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py index b412e9cd6c856..ea86d6aeb8b31 100644 --- a/python/pyspark/ml/param/__init__.py +++ b/python/pyspark/ml/param/__init__.py @@ -317,7 +317,7 @@ def _copyValues(self, to, extra=None): def _resetUid(self, newUid): """ - Changes the uid of this instance. This updates both + Changes the uid of this instance. This updates both the stored uid and the parent uid of params and param maps. This is used by persistence (loading). 
         :param newUid: new uid to use
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 559e720c42388..54806ee336666 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -412,20 +412,17 @@ class PersistenceTest(PySparkTestCase):
 
     def test_linear_regression(self):
-        lr = LinearRegression(maxIter = 1)
+        lr = LinearRegression(maxIter=1)
         path = tempfile.mkdtemp()
         lr_path = path + "/lr"
         lr.save(lr_path)
         lr2 = LinearRegression.load(lr_path)
-        lr2.getMaxIter()
-        print lr.uid
-        print lr2.uid
-        print lr2.maxIter.parent
         self.assertEqual(lr2.uid, lr2.maxIter.parent,
-                "Loaded LinearRegression instance uid (%s) did not match Param's uid (%s)"
-                % (lr2.uid, lr2.maxIter.parent))
+                         "Loaded LinearRegression instance uid (%s) did not match Param's uid (%s)"
+                         % (lr2.uid, lr2.maxIter.parent))
         self.assertEqual(lr._defaultParamMap[lr.maxIter], lr2._defaultParamMap[lr2.maxIter],
-                "Loaded LinearRegression instance defaults did not match original defaults")
+                         "Loaded LinearRegression instance default params did not match " +
+                         "original defaults")
         try:
             rmtree(path)
         except OSError:
             pass
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index 67be49a18f7b1..d4d48eb2150e3 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -76,7 +76,6 @@ def _transfer_params_from_java(self):
         Transforms the embedded params from the companion Java object.
         """
         sc = SparkContext._active_spark_context
-        parent = self._java_obj.uid()
         for param in self.params:
             if self._java_obj.hasParam(param.name):
                 java_param = self._java_obj.getParam(param.name)
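
Usage note (not part of the patch series): a minimal end-to-end sketch of the
persistence API in its final form after patch 9, mirroring the LinearRegression
doctest above. It assumes an already-initialized SparkContext and an SQLContext
named `sqlContext`; the toy DataFrame and the temporary directory are
illustrative assumptions, not taken from the patches.

    import tempfile
    from shutil import rmtree

    from pyspark.mllib.linalg import Vectors
    from pyspark.ml.regression import LinearRegression, LinearRegressionModel

    # Two-point toy dataset, enough to fit one coefficient and an intercept.
    df = sqlContext.createDataFrame(
        [(1.0, Vectors.dense(1.0)), (0.0, Vectors.sparse(1, [], []))],
        ["label", "features"])
    lr = LinearRegression(maxIter=5, regParam=0.0)
    model = lr.fit(df)

    path = tempfile.mkdtemp()

    # Estimator round trip: save(path) is a shortcut for write().save(path),
    # and load() delegates to the Scala implementation via JavaMLReader.
    lr.save(path + "/lr")
    lr2 = LinearRegression.load(path + "/lr")
    assert lr2.getMaxIter() == 5

    # Fitted-model round trip; the loaded instance gets its uid and params
    # restored on the Python side (see _resetUid in patch 8).
    model.save(path + "/lr_model")
    model2 = LinearRegressionModel.load(path + "/lr_model")
    assert model2.intercept == model.intercept

    rmtree(path)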