From 840a19324f54a1077d59a7cd6f8e911f57505370 Mon Sep 17 00:00:00 2001
From: Ajay Saini
Date: Tue, 8 Aug 2017 16:24:23 -0700
Subject: [PATCH 01/11] Pipeline persistence commit with tests.

---
 python/pyspark/ml/pipeline.py | 158 ++++++++++++++++++++++++++++++++--
 python/pyspark/ml/tests.py    |  35 +++++++-
 2 files changed, 185 insertions(+), 8 deletions(-)

diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
index a8dc76b846c24..79633532b243e 100644
--- a/python/pyspark/ml/pipeline.py
+++ b/python/pyspark/ml/pipeline.py
@@ -16,6 +16,7 @@
 #

 import sys
+import os

 if sys.version > '3':
     basestring = str
@@ -23,7 +24,8 @@
 from pyspark import since, keyword_only, SparkContext
 from pyspark.ml.base import Estimator, Model, Transformer
 from pyspark.ml.param import Param, Params
-from pyspark.ml.util import JavaMLWriter, JavaMLReader, MLReadable, MLWritable
+from pyspark.ml.util import JavaMLWriter, JavaMLReader, MLReadable, MLWritable, \
+    DefaultParamsWriter, DefaultParamsReader
 from pyspark.ml.wrapper import JavaParams
 from pyspark.ml.common import inherit_doc
@@ -130,13 +132,20 @@ def copy(self, extra=None):
     @since("2.0.0")
     def write(self):
         """Returns an MLWriter instance for this ML instance."""
-        return JavaMLWriter(self)
+        allStagesAreJava = True
+        stages = self.getStages()
+        for stage in stages:
+            if not isinstance(stage, JavaMLWritable):
+                allStagesAreJava = False
+        if allStagesAreJava:
+            return JavaMLWriter(self)
+        return PipelineWriter(self)

     @classmethod
     @since("2.0.0")
     def read(cls):
         """Returns an MLReader instance for this class."""
-        return JavaMLReader(cls)
+        return PipelineReader(cls)

     @classmethod
     def _from_java(cls, java_stage):
@@ -171,6 +180,76 @@ def _to_java(self):

         return _java_obj

+
+@inherit_doc
+class PipelineWriter(MLWriter):
+    """
+    (Private) Specialization of :py:class:`MLWriter` for :py:class:`Pipeline` types
+    """
+
+    def __init__(self, instance):
+        super(PipelineWriter, self).__init__()
+        self.instance = instance
+
+    def saveImpl(self, path):
+        stages = self.instance.getStages()
+        SharedReadWrite.validateStages(stages)
+        SharedReadWrite.saveImpl(self.instance, stages, self.sc, path)
+
+
+@inherit_doc
+class PipelineReader(MLReader):
+    """
+    (Private) Specialization of :py:class:`MLReader` for :py:class:`Pipeline` types
+    """
+
+    def __init__(self, cls):
+        super(PipelineReader, self).__init__()
+        self.cls = cls
+
+    def load(self, path):
+        metadata = DefaultParamsReader.loadMetadata(path, self.sc)
+        if 'savedAsPython' not in metadata['paramMap']:
+            return JavaMLReader(self.cls).load(path)
+        else:
+            uid, stages = SharedReadWrite.load(metadata, self.sc, path)
+            return Pipeline(stages=stages)._resetUid(uid)
+
+
+@inherit_doc
+class PipelineModelWriter(MLWriter):
+    """
+    (Private) Specialization of :py:class:`MLWriter` for :py:class:`PipelineModel` types
+    """
+
+    def __init__(self, instance):
+        super(PipelineModelWriter, self).__init__()
+        self.instance = instance
+
+    def saveImpl(self, path):
+        stages = self.instance.stages
+        SharedReadWrite.validateStages(stages)
+        SharedReadWrite.saveImpl(self.instance, stages, self.sc, path)
+
+
+@inherit_doc
+class PipelineModelReader(MLReader):
+    """
+    (Private) Specialization of :py:class:`MLReader` for :py:class:`PipelineModel` types
+    """
+
+    def __init__(self, cls):
+        super(PipelineModelReader, self).__init__()
+        self.cls = cls
+
+    def load(self, path):
+        metadata = DefaultParamsReader.loadMetadata(path, self.sc)
+        if 'savedAsPython' not in metadata['paramMap']:
+            return JavaMLReader(self.cls).load(path)
+        else:
+            uid, stages = SharedReadWrite.load(metadata, self.sc, path)
+            return PipelineModel(stages=stages)._resetUid(uid)
+
+
 @inherit_doc
 class PipelineModel(Model, MLReadable, MLWritable):
     """
@@ -204,13 +283,20 @@ def copy(self, extra=None):
     @since("2.0.0")
     def write(self):
         """Returns an MLWriter instance for this ML instance."""
-        return JavaMLWriter(self)
+        allStagesAreJava = True
+        stages = self.stages
+        for stage in stages:
+            if not isinstance(stage, JavaMLWritable):
+                allStagesAreJava = False
+        if allStagesAreJava:
+            return JavaMLWriter(self)
+        return PipelineModelWriter(self)

     @classmethod
     @since("2.0.0")
     def read(cls):
         """Returns an MLReader instance for this class."""
-        return JavaMLReader(cls)
+        return PipelineModelReader(cls)

     @classmethod
     def _from_java(cls, java_stage):
@@ -242,3 +328,65 @@ def _to_java(self):
             JavaParams._new_java_obj("org.apache.spark.ml.PipelineModel", self.uid, java_stages)

         return _java_obj
+
+
+@inherit_doc
+class SharedReadWrite():
+    """
+    Functions for :py:class:`MLReader` and :py:class:`MLWriter` shared between
+    :py:class:'Pipeline' and :py:class'PipelineModel'
+
+    .. versionadded:: 2.3.0
+    """
+
+    @staticmethod
+    def validateStages(stages):
+        """
+        Check that all stages are Writable
+        """
+        for stage in stages:
+            if not isinstance(stage, MLWritable):
+                raise ValueError("Pipeline write will fail on this pipline " +
+                                 "because stage %s of type %s is not MLWritable",
+                                 stage.uid, type(stage))
+
+    @staticmethod
+    def saveImpl(instance, stages, sc, path):
+        """
+        Save metadata and stages for a :py:class:`Pipeline` or :py:class:`PipelineModel`
+        - save metadata to path/metadata
+        - save stages to stages/IDX_UID
+        """
+        stageUids = map(lambda x: x.uid, stages)
+        jsonParams = {'stageUids': stageUids, 'savedAsPython': True}
+        DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap=jsonParams)
+        stagesDir = os.path.join(path, "stages")
+        for index, stage in enumerate(stages):
+            stage.write().save(SharedReadWrite
+                               .getStagePath(stage.uid, index, len(stages), stagesDir))
+
+    @staticmethod
+    def load(metadata, sc, path):
+        """
+        Load metadata and stages for a :py:class:`Pipeline` or :py:class:`PipelineModel`
+
+        :return:  (UID, list of stages)
+        """
+        stagesDir = os.path.join(path, "stages")
+        stageUids = metadata['paramMap']['stageUids']
+        stages = []
+        for index, stageUid in enumerate(stageUids):
+            stagePath = SharedReadWrite.getStagePath(stageUid, index, len(stageUids), stagesDir)
+            stage = DefaultParamsReader.loadParamsInstance(stagePath, sc)
+            stages.append(stage)
+        return (metadata['uid'], stages)
+
+    @staticmethod
+    def getStagePath(stageUid, stageIdx, numStages, stagesDir):
+        """
+        Get path for saving the given stage.
+        """
+        stageIdxDigits = len(str(numStages))
+        stageDir = str(stageIdxDigits) + "_" + stageUid
+        stagePath = os.path.join(stagesDir, stageDir)
+        return stagePath
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 6aecc7fe87074..0495973d2f625 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -123,7 +123,7 @@ def _transform(self, dataset):
         return dataset


-class MockUnaryTransformer(UnaryTransformer):
+class MockUnaryTransformer(UnaryTransformer, DefaultParamsReadable, DefaultParamsWritable):

     shift = Param(Params._dummy(), "shift", "The amount by which to shift " +
                   "data in a DataFrame",
@@ -150,7 +150,7 @@ def outputDataType(self):
     def validateInputType(self, inputType):
         if inputType != DoubleType():
             raise TypeError("Bad input type: {}. ".format(inputType) +
-                            "Requires Integer.")
+                            "Requires Double.")


 class MockEstimator(Estimator, HasFake):
@@ -1063,7 +1063,7 @@ def _compare_pipelines(self, m1, m2):
         """
         self.assertEqual(m1.uid, m2.uid)
         self.assertEqual(type(m1), type(m2))
-        if isinstance(m1, JavaParams):
+        if isinstance(m1, JavaParams) or isinstance(m1, Transformer):
             self.assertEqual(len(m1.params), len(m2.params))
             for p in m1.params:
                 self._compare_params(m1, m2, p)
@@ -1142,6 +1142,35 @@ def test_nested_pipeline_persistence(self):
             except OSError:
                 pass

+    def test_python_transformer_pipeline_persistence(self):
+        """
+        Pipeline[MockUnaryTransformer, Binarizer]
+        """
+        temp_path = tempfile.mkdtemp()
+
+        try:
+            df = self.spark.range(0, 10).toDF('input')
+            tf = MockUnaryTransformer(shiftVal=2)\
+                .setInputCol("input").setOutputCol("shiftedInput")
+            tf2 = Binarizer(threshold=6, inputCol="shiftedInput", outputCol="binarized")
+            pl = Pipeline(stages=[tf, tf2])
+            model = pl.fit(df)
+
+            pipeline_path = temp_path + "/pipeline"
+            pl.save(pipeline_path)
+            loaded_pipeline = Pipeline.load(pipeline_path)
+            self._compare_pipelines(pl, loaded_pipeline)
+
+            model_path = temp_path + "/pipeline-model"
+            model.save(model_path)
+            loaded_model = PipelineModel.load(model_path)
+            self._compare_pipelines(model, loaded_model)
+        finally:
+            try:
+                rmtree(temp_path)
+            except OSError:
+                pass
+
     def test_onevsrest(self):
         temp_path = tempfile.mkdtemp()
         df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)),
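
The test above exercises the new dispatch end to end. As a minimal standalone sketch of what this first patch enables (assuming a running SparkSession and a writable temp path; the path itself is illustrative):

    from pyspark.ml import Pipeline
    from pyspark.ml.feature import Binarizer

    # All stages here are Java-backed, so write() returns a JavaMLWriter and
    # the existing Scala persistence path is used unchanged.
    pipeline = Pipeline(stages=[Binarizer(threshold=0.5, inputCol="x", outputCol="y")])
    pipeline.save("/tmp/java-only-pipeline")

    # A pipeline containing any pure-Python stage (such as MockUnaryTransformer
    # in the test) instead falls through to PipelineWriter, which saves metadata
    # plus one directory per stage via SharedReadWrite.saveImpl.
    loaded = Pipeline.load("/tmp/java-only-pipeline")  # read() picks the reader from metadata
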
".format(inputType) + - "Requires Integer.") + "Requires Double.") class MockEstimator(Estimator, HasFake): @@ -1063,7 +1063,7 @@ def _compare_pipelines(self, m1, m2): """ self.assertEqual(m1.uid, m2.uid) self.assertEqual(type(m1), type(m2)) - if isinstance(m1, JavaParams): + if isinstance(m1, JavaParams) or isinstance(m1, Transformer): self.assertEqual(len(m1.params), len(m2.params)) for p in m1.params: self._compare_params(m1, m2, p) @@ -1142,6 +1142,35 @@ def test_nested_pipeline_persistence(self): except OSError: pass + def test_python_transformer_pipeline_persistence(self): + """ + Pipeline[MockUnaryTransformer, Binarizer] + """ + temp_path = tempfile.mkdtemp() + + try: + df = self.spark.range(0, 10).toDF('input') + tf = MockUnaryTransformer(shiftVal=2)\ + .setInputCol("input").setOutputCol("shiftedInput") + tf2 = Binarizer(threshold=6, inputCol="shiftedInput", outputCol="binarized") + pl = Pipeline(stages=[tf, tf2]) + model = pl.fit(df) + + pipeline_path = temp_path + "/pipeline" + pl.save(pipeline_path) + loaded_pipeline = Pipeline.load(pipeline_path) + self._compare_pipelines(pl, loaded_pipeline) + + model_path = temp_path + "/pipeline-model" + model.save(model_path) + loaded_model = PipelineModel.load(model_path) + self._compare_pipelines(model, loaded_model) + finally: + try: + rmtree(temp_path) + except OSError: + pass + def test_onevsrest(self): temp_path = tempfile.mkdtemp() df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)), From 85a98d6fa45e7f6bb32af9b949913e70b32f8ca9 Mon Sep 17 00:00:00 2001 From: Ajay Saini Date: Tue, 8 Aug 2017 16:30:33 -0700 Subject: [PATCH 02/11] Fixed import --- python/pyspark/ml/pipeline.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py index 79633532b243e..e3e99209960f6 100644 --- a/python/pyspark/ml/pipeline.py +++ b/python/pyspark/ml/pipeline.py @@ -24,8 +24,7 @@ from pyspark import since, keyword_only, SparkContext from pyspark.ml.base import Estimator, Model, Transformer from pyspark.ml.param import Param, Params -from pyspark.ml.util import JavaMLWriter, JavaMLReader, MLReadable, MLWritable, \ - DefaultParamsWriter, DefaultParamsReader +from pyspark.ml.util import * from pyspark.ml.wrapper import JavaParams from pyspark.ml.common import inherit_doc From 0eb0494e6ac8c54458ff6477f0cbe1591d50a69a Mon Sep 17 00:00:00 2001 From: Ajay Saini Date: Tue, 8 Aug 2017 16:58:00 -0700 Subject: [PATCH 03/11] Fixed python 3 issue with updating a dictionary --- python/pyspark/ml/util.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index 67772910c0d38..0c4496b67ab49 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -419,7 +419,8 @@ def _get_metadata_to_save(instance, sc, extraMetadata=None, paramMap=None): basicMetadata = {"class": cls, "timestamp": long(round(time.time() * 1000)), "sparkVersion": sc.version, "uid": uid, "paramMap": jsonParams} if extraMetadata is not None: - basicMetadata.update(extraMetadata) + for key, value in extraMetadata: + basicMetadata[key] = value return json.dumps(basicMetadata, separators=[',', ':']) From ba4402cac561714850ac2eabc38c9891af6d1960 Mon Sep 17 00:00:00 2001 From: Ajay Saini Date: Tue, 8 Aug 2017 16:59:17 -0700 Subject: [PATCH 04/11] Fixed dictionary update --- python/pyspark/ml/util.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index 
From ba4402cac561714850ac2eabc38c9891af6d1960 Mon Sep 17 00:00:00 2001
From: Ajay Saini
Date: Tue, 8 Aug 2017 16:59:17 -0700
Subject: [PATCH 04/11] Fixed dictionary update

---
 python/pyspark/ml/util.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index 0c4496b67ab49..e92cdeefd7b26 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -419,8 +419,8 @@ def _get_metadata_to_save(instance, sc, extraMetadata=None, paramMap=None):
     basicMetadata = {"class": cls, "timestamp": long(round(time.time() * 1000)),
                      "sparkVersion": sc.version, "uid": uid, "paramMap": jsonParams}
     if extraMetadata is not None:
-        for key, value in extraMetadata:
-            basicMetadata[key] = value
+        for key in extraMetadata:
+            basicMetadata[key] = extraMetadata[key]
     return json.dumps(basicMetadata, separators=[',', ':'])

From 22ebe3eadf86a2c62dcb4c14f048b40b41fc7569 Mon Sep 17 00:00:00 2001
From: Ajay Saini
Date: Tue, 8 Aug 2017 17:25:52 -0700
Subject: [PATCH 05/11] Fixed map serialization issue in python 3

---
 python/pyspark/ml/pipeline.py | 2 +-
 python/pyspark/ml/util.py     | 3 +--
 2 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
index e3e99209960f6..8d1f27c7f563e 100644
--- a/python/pyspark/ml/pipeline.py
+++ b/python/pyspark/ml/pipeline.py
@@ -356,7 +356,7 @@ def saveImpl(instance, stages, sc, path):
         - save metadata to path/metadata
         - save stages to stages/IDX_UID
         """
-        stageUids = map(lambda x: x.uid, stages)
+        stageUids = [stage.uid for stage in stages]
         jsonParams = {'stageUids': stageUids, 'savedAsPython': True}
         DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap=jsonParams)
         stagesDir = os.path.join(path, "stages")
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index e92cdeefd7b26..67772910c0d38 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -419,8 +419,7 @@ def _get_metadata_to_save(instance, sc, extraMetadata=None, paramMap=None):
     basicMetadata = {"class": cls, "timestamp": long(round(time.time() * 1000)),
                      "sparkVersion": sc.version, "uid": uid, "paramMap": jsonParams}
     if extraMetadata is not None:
-        for key in extraMetadata:
-            basicMetadata[key] = extraMetadata[key]
+        basicMetadata.update(extraMetadata)
     return json.dumps(basicMetadata, separators=[',', ':'])
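
PATCH 05 is the real Python 3 fix: `map()` returns a lazy iterator in Python 3 (not a list, as in Python 2), and `json.dumps` cannot encode it, while a list comprehension yields a plain list under both versions. A standalone sketch with illustrative stage uids:

    import json

    uids = ["binarizer_a", "mock_b"]                      # illustrative uids
    # json.dumps({'stageUids': map(lambda x: x, uids)})   # TypeError on Python 3
    json.dumps({'stageUids': [uid for uid in uids]})      # works on 2 and 3
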
From 6a094f02b789a99b55223c0f2668a560c12b3e5d Mon Sep 17 00:00:00 2001
From: Ajay Saini
Date: Tue, 8 Aug 2017 17:32:07 -0700
Subject: [PATCH 06/11] Removed extra space

---
 python/pyspark/ml/pipeline.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
index 8d1f27c7f563e..73702099f8532 100644
--- a/python/pyspark/ml/pipeline.py
+++ b/python/pyspark/ml/pipeline.py
@@ -369,7 +369,7 @@ def load(metadata, sc, path):
         """
         Load metadata and stages for a :py:class:`Pipeline` or :py:class:`PipelineModel`

-        :return:  (UID, list of stages)
+        :return: (UID, list of stages)
         """
         stagesDir = os.path.join(path, "stages")
         stageUids = metadata['paramMap']['stageUids']

From 4d2caf89b00256b07717afe99090fd24e49bc9f3 Mon Sep 17 00:00:00 2001
From: Ajay Saini
Date: Wed, 9 Aug 2017 11:52:15 -0700
Subject: [PATCH 07/11] Removed duplicated java stage check logic

---
 python/pyspark/ml/pipeline.py | 22 ++++++++++++----------
 1 file changed, 12 insertions(+), 10 deletions(-)

diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
index 73702099f8532..b2b8ceb0daa76 100644
--- a/python/pyspark/ml/pipeline.py
+++ b/python/pyspark/ml/pipeline.py
@@ -131,11 +131,7 @@ def copy(self, extra=None):
     @since("2.0.0")
     def write(self):
         """Returns an MLWriter instance for this ML instance."""
-        allStagesAreJava = True
-        stages = self.getStages()
-        for stage in stages:
-            if not isinstance(stage, JavaMLWritable):
-                allStagesAreJava = False
+        allStagesAreJava = SharedReadWrite.checkStagesForJava(self.getStages())
         if allStagesAreJava:
             return JavaMLWriter(self)
         return PipelineWriter(self)
@@ -282,11 +278,7 @@ def copy(self, extra=None):
     @since("2.0.0")
     def write(self):
         """Returns an MLWriter instance for this ML instance."""
-        allStagesAreJava = True
-        stages = self.stages
-        for stage in stages:
-            if not isinstance(stage, JavaMLWritable):
-                allStagesAreJava = False
+        allStagesAreJava = SharedReadWrite.checkStagesForJava(self.stages)
         if allStagesAreJava:
             return JavaMLWriter(self)
         return PipelineModelWriter(self)
@@ -338,6 +330,16 @@ class SharedReadWrite():
     .. versionadded:: 2.3.0
     """

+    @staticmethod
+    def checkStagesForJava(stages):
+        allStagesAreJava = True
+        stages = self.stages
+        for stage in stages:
+            if not isinstance(stage, JavaMLWritable):
+                allStagesAreJava = False
+                break
+        return allStagesAreJava
+
     @staticmethod
     def validateStages(stages):
         """

From cdcd1cc5c6d9cd9e819cd3ce7882870ad88e46b0 Mon Sep 17 00:00:00 2001
From: Ajay Saini
Date: Wed, 9 Aug 2017 11:54:01 -0700
Subject: [PATCH 08/11] Fixed small bug

---
 python/pyspark/ml/pipeline.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
index b2b8ceb0daa76..e9473a3f03053 100644
--- a/python/pyspark/ml/pipeline.py
+++ b/python/pyspark/ml/pipeline.py
@@ -333,7 +333,6 @@ class SharedReadWrite():
     @staticmethod
     def checkStagesForJava(stages):
         allStagesAreJava = True
-        stages = self.stages
         for stage in stages:
             if not isinstance(stage, JavaMLWritable):
                 allStagesAreJava = False
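
The "small bug" removed in PATCH 08 is the stray `stages = self.stages` line that PATCH 07 copied into the new `@staticmethod`: a staticmethod receives no `self`, so the name is undefined at call time. A minimal reproduction outside Spark:

    class Demo(object):
        @staticmethod
        def check(stages):
            # stages = self.stages    # NameError: name 'self' is not defined
            return all(isinstance(s, str) for s in stages)

    print(Demo.check(["a", "b"]))     # True
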
""" stageIdxDigits = len(str(numStages)) - stageDir = str(stageIdxDigits) + "_" + stageUid + stageDir = str(stageIdx).zfill(stageIdxDigits) + "_" + stageUid stagePath = os.path.join(stagesDir, stageDir) return stagePath From 18c902c36bf6550e69a7d01d0ec70a5c226fa473 Mon Sep 17 00:00:00 2001 From: Ajay Saini Date: Fri, 11 Aug 2017 11:55:56 -0700 Subject: [PATCH 10/11] Fixed based on comments --- python/pyspark/ml/pipeline.py | 40 ++++++++++++++++------------------- 1 file changed, 18 insertions(+), 22 deletions(-) diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py index 1bd798c133db5..09e0748ffbb3b 100644 --- a/python/pyspark/ml/pipeline.py +++ b/python/pyspark/ml/pipeline.py @@ -131,7 +131,7 @@ def copy(self, extra=None): @since("2.0.0") def write(self): """Returns an MLWriter instance for this ML instance.""" - allStagesAreJava = SharedReadWrite.checkStagesForJava(self.getStages()) + allStagesAreJava = PipelineSharedReadWrite.checkStagesForJava(self.getStages()) if allStagesAreJava: return JavaMLWriter(self) return PipelineWriter(self) @@ -187,8 +187,8 @@ def __init__(self, instance): def saveImpl(self, path): stages = self.instance.getStages() - SharedReadWrite.validateStages(stages) - SharedReadWrite.saveImpl(self.instance, stages, self.sc, path) + PipelineSharedReadWrite.validateStages(stages) + PipelineSharedReadWrite.saveImpl(self.instance, stages, self.sc, path) @inherit_doc @@ -203,10 +203,10 @@ def __init__(self, cls): def load(self, path): metadata = DefaultParamsReader.loadMetadata(path, self.sc) - if 'savedAsPython' not in metadata['paramMap']: + if 'language' not in metadata['paramMap'] or metadata['paramMap']['language'] != 'Python': return JavaMLReader(self.cls).load(path) else: - uid, stages = SharedReadWrite.load(metadata, self.sc, path) + uid, stages = PipelineSharedReadWrite.load(metadata, self.sc, path) return Pipeline(stages=stages)._resetUid(uid) @@ -222,8 +222,8 @@ def __init__(self, instance): def saveImpl(self, path): stages = self.instance.stages - SharedReadWrite.validateStages(stages) - SharedReadWrite.saveImpl(self.instance, stages, self.sc, path) + PipelineSharedReadWrite.validateStages(stages) + PipelineSharedReadWrite.saveImpl(self.instance, stages, self.sc, path) @inherit_doc @@ -238,10 +238,10 @@ def __init__(self, cls): def load(self, path): metadata = DefaultParamsReader.loadMetadata(path, self.sc) - if 'savedAsPython' not in metadata['paramMap']: + if 'language' not in metadata['paramMap'] or metadata['paramMap']['language'] != 'Python': return JavaMLReader(self.cls).load(path) else: - uid, stages = SharedReadWrite.load(metadata, self.sc, path) + uid, stages = PipelineSharedReadWrite.load(metadata, self.sc, path) return PipelineModel(stages=stages)._resetUid(uid) @@ -278,7 +278,7 @@ def copy(self, extra=None): @since("2.0.0") def write(self): """Returns an MLWriter instance for this ML instance.""" - allStagesAreJava = SharedReadWrite.checkStagesForJava(self.stages) + allStagesAreJava = PipelineSharedReadWrite.checkStagesForJava(self.stages) if allStagesAreJava: return JavaMLWriter(self) return PipelineModelWriter(self) @@ -322,22 +322,17 @@ def _to_java(self): @inherit_doc -class SharedReadWrite(): +class PipelineSharedReadWrite(): """ Functions for :py:class:`MLReader` and :py:class:`MLWriter` shared between - :py:class:`Pipeline` and :py:class`PipelineModel` + :py:class:`Pipeline` and :py:class:`PipelineModel` .. 
From 18c902c36bf6550e69a7d01d0ec70a5c226fa473 Mon Sep 17 00:00:00 2001
From: Ajay Saini
Date: Fri, 11 Aug 2017 11:55:56 -0700
Subject: [PATCH 10/11] Fixed based on comments

---
 python/pyspark/ml/pipeline.py | 40 ++++++++++++++++------------------
 1 file changed, 18 insertions(+), 22 deletions(-)

diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
index 1bd798c133db5..09e0748ffbb3b 100644
--- a/python/pyspark/ml/pipeline.py
+++ b/python/pyspark/ml/pipeline.py
@@ -131,7 +131,7 @@ def copy(self, extra=None):
     @since("2.0.0")
     def write(self):
         """Returns an MLWriter instance for this ML instance."""
-        allStagesAreJava = SharedReadWrite.checkStagesForJava(self.getStages())
+        allStagesAreJava = PipelineSharedReadWrite.checkStagesForJava(self.getStages())
         if allStagesAreJava:
             return JavaMLWriter(self)
         return PipelineWriter(self)
@@ -187,8 +187,8 @@ def __init__(self, instance):

     def saveImpl(self, path):
         stages = self.instance.getStages()
-        SharedReadWrite.validateStages(stages)
-        SharedReadWrite.saveImpl(self.instance, stages, self.sc, path)
+        PipelineSharedReadWrite.validateStages(stages)
+        PipelineSharedReadWrite.saveImpl(self.instance, stages, self.sc, path)
@@ -203,10 +203,10 @@ def __init__(self, cls):

     def load(self, path):
         metadata = DefaultParamsReader.loadMetadata(path, self.sc)
-        if 'savedAsPython' not in metadata['paramMap']:
+        if 'language' not in metadata['paramMap'] or metadata['paramMap']['language'] != 'Python':
             return JavaMLReader(self.cls).load(path)
         else:
-            uid, stages = SharedReadWrite.load(metadata, self.sc, path)
+            uid, stages = PipelineSharedReadWrite.load(metadata, self.sc, path)
             return Pipeline(stages=stages)._resetUid(uid)
@@ -222,8 +222,8 @@ def __init__(self, instance):

     def saveImpl(self, path):
         stages = self.instance.stages
-        SharedReadWrite.validateStages(stages)
-        SharedReadWrite.saveImpl(self.instance, stages, self.sc, path)
+        PipelineSharedReadWrite.validateStages(stages)
+        PipelineSharedReadWrite.saveImpl(self.instance, stages, self.sc, path)
@@ -238,10 +238,10 @@ def __init__(self, cls):

     def load(self, path):
         metadata = DefaultParamsReader.loadMetadata(path, self.sc)
-        if 'savedAsPython' not in metadata['paramMap']:
+        if 'language' not in metadata['paramMap'] or metadata['paramMap']['language'] != 'Python':
             return JavaMLReader(self.cls).load(path)
         else:
-            uid, stages = SharedReadWrite.load(metadata, self.sc, path)
+            uid, stages = PipelineSharedReadWrite.load(metadata, self.sc, path)
             return PipelineModel(stages=stages)._resetUid(uid)
@@ -278,7 +278,7 @@ def copy(self, extra=None):
     @since("2.0.0")
     def write(self):
         """Returns an MLWriter instance for this ML instance."""
-        allStagesAreJava = SharedReadWrite.checkStagesForJava(self.stages)
+        allStagesAreJava = PipelineSharedReadWrite.checkStagesForJava(self.stages)
         if allStagesAreJava:
             return JavaMLWriter(self)
         return PipelineModelWriter(self)
@@ -322,22 +322,17 @@ def _to_java(self):


 @inherit_doc
-class SharedReadWrite():
+class PipelineSharedReadWrite():
     """
     Functions for :py:class:`MLReader` and :py:class:`MLWriter` shared between
-    :py:class:`Pipeline` and :py:class`PipelineModel`
+    :py:class:`Pipeline` and :py:class:`PipelineModel`

     .. versionadded:: 2.3.0
     """

     @staticmethod
     def checkStagesForJava(stages):
-        allStagesAreJava = True
-        for stage in stages:
-            if not isinstance(stage, JavaMLWritable):
-                allStagesAreJava = False
-                break
-        return allStagesAreJava
+        return all(isinstance(stage, JavaMLWritable) for stage in stages)

     @staticmethod
     def validateStages(stages):
@@ -346,7 +341,7 @@ def validateStages(stages):
         """
         for stage in stages:
             if not isinstance(stage, MLWritable):
-                raise ValueError("Pipeline write will fail on this pipline " +
+                raise ValueError("Pipeline write will fail on this pipeline " +
                                  "because stage %s of type %s is not MLWritable",
                                  stage.uid, type(stage))
@@ -358,11 +353,11 @@ def saveImpl(instance, stages, sc, path):
         - save stages to stages/IDX_UID
         """
         stageUids = [stage.uid for stage in stages]
-        jsonParams = {'stageUids': stageUids, 'savedAsPython': True}
+        jsonParams = {'stageUids': stageUids, 'language': 'Python'}
         DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap=jsonParams)
         stagesDir = os.path.join(path, "stages")
         for index, stage in enumerate(stages):
-            stage.write().save(SharedReadWrite
+            stage.write().save(PipelineSharedReadWrite
                                .getStagePath(stage.uid, index, len(stages), stagesDir))
@@ -376,7 +371,8 @@ def load(metadata, sc, path):
         stageUids = metadata['paramMap']['stageUids']
         stages = []
         for index, stageUid in enumerate(stageUids):
-            stagePath = SharedReadWrite.getStagePath(stageUid, index, len(stageUids), stagesDir)
+            stagePath = \
+                PipelineSharedReadWrite.getStagePath(stageUid, index, len(stageUids), stagesDir)
             stage = DefaultParamsReader.loadParamsInstance(stagePath, sc)
             stages.append(stage)
         return (metadata['uid'], stages)

From 2b63eeaaf071b81af26eaccf4cffd5ac9104ee7d Mon Sep 17 00:00:00 2001
From: Ajay Saini
Date: Fri, 11 Aug 2017 12:26:06 -0700
Subject: [PATCH 11/11] Marked PipelineReadWrite class as developer API

---
 python/pyspark/ml/pipeline.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
index 09e0748ffbb3b..097530230cbca 100644
--- a/python/pyspark/ml/pipeline.py
+++ b/python/pyspark/ml/pipeline.py
@@ -324,6 +324,8 @@ def _to_java(self):
 @inherit_doc
 class PipelineSharedReadWrite():
     """
+    .. note:: DeveloperApi
+
     Functions for :py:class:`MLReader` and :py:class:`MLWriter` shared between
     :py:class:`Pipeline` and :py:class:`PipelineModel`

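
After PATCH 10's renames and the metadata key change to `language: 'Python'`, the on-disk layout written by `PipelineSharedReadWrite.saveImpl` is root metadata plus one zero-padded stage directory per stage, each saved by that stage's own writer. A sketch of the resulting tree (path and uids illustrative):

    pipeline-demo/
        metadata/                      # JSON with uid, stageUids, 'language': 'Python'
        stages/
            0_mockUnaryTransformer_a/  # written via the stage's own write()
            1_binarizer_b/
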