[SPARK-13032][ML][PYSPARK] PySpark support model export/import and take LinearRegression as example

* Implement ```MLWriter/MLWritable/MLReader/MLReadable``` for PySpark.
* Make ```LinearRegression``` support ```save/load``` as an example (see the usage sketch below). Once this is merged, adding support to the other transformers/estimators will be easy, and we can list and distribute those tasks to the community.
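
For reference, a minimal usage sketch of the new API (paths are illustrative; it assumes an active SparkContext and a training DataFrame ```df``` with "features" and "label" columns):

```python
from pyspark.ml.regression import LinearRegression, LinearRegressionModel

lr = LinearRegression(maxIter=5, regParam=0.1)
lr.save("/tmp/lr")                                    # persist the estimator
lr2 = LinearRegression.load("/tmp/lr")                # round-trip its params

model = lr.fit(df)                                    # `df` assumed to exist
model.save("/tmp/lr_model")                           # persist the fitted model
model2 = LinearRegressionModel.load("/tmp/lr_model")
```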

cc mengxr jkbradley

Author: Yanbo Liang <ybliang8@gmail.com>
Author: Joseph K. Bradley <joseph@databricks.com>

Closes #10469 from yanboliang/spark-11939.
yanboliang authored and jkbradley committed Jan 29, 2016
1 parent 55561e7 commit e51b6ea
Showing 5 changed files with 236 additions and 29 deletions.
24 changes: 24 additions & 0 deletions python/pyspark/ml/param/__init__.py
@@ -314,3 +314,27 @@ def _copyValues(self, to, extra=None):
if p in paramMap and to.hasParam(p.name):
to._set(**{p.name: paramMap[p]})
return to

def _resetUid(self, newUid):
"""
Changes the uid of this instance. This updates both
the stored uid and the parent uid of params and param maps.
This is used by persistence (loading).
:param newUid: new uid to use
:return: same instance, but with the uid and Param.parent values
updated, including within param maps
"""
self.uid = newUid
newDefaultParamMap = dict()
newParamMap = dict()
for param in self.params:
newParam = copy.copy(param)
newParam.parent = newUid
if param in self._defaultParamMap:
newDefaultParamMap[newParam] = self._defaultParamMap[param]
if param in self._paramMap:
newParamMap[newParam] = self._paramMap[param]
param.parent = newUid
self._defaultParamMap = newDefaultParamMap
self._paramMap = newParamMap
return self
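
For context, a brief sketch of what ```_resetUid``` does to Param ownership when ```JavaMLReader``` rebinds a freshly constructed instance to the uid stored on disk (illustrative only: the uid string is made up, and an active SparkContext is assumed so the Java-backed estimator can be constructed):

```python
from pyspark.ml.regression import LinearRegression

lr = LinearRegression(maxIter=3)
lr._resetUid("LinearRegression_loaded_1234")  # as JavaMLReader.load does

assert lr.uid == "LinearRegression_loaded_1234"
# The instance's own Param objects are re-parented in place...
assert lr.maxIter.parent == lr.uid
# ...and both maps are re-keyed by copies carrying the new parent uid.
assert all(p.parent == lr.uid for p in lr._paramMap)
assert all(p.parent == lr.uid for p in lr._defaultParamMap)
```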
30 changes: 25 additions & 5 deletions python/pyspark/ml/regression.py
@@ -18,9 +18,9 @@
import warnings

from pyspark import since
from pyspark.ml.util import keyword_only
from pyspark.ml.wrapper import JavaEstimator, JavaModel
from pyspark.ml.param.shared import *
from pyspark.ml.util import *
from pyspark.ml.wrapper import JavaEstimator, JavaModel
from pyspark.mllib.common import inherit_doc


@@ -35,7 +35,7 @@
@inherit_doc
class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
HasRegParam, HasTol, HasElasticNetParam, HasFitIntercept,
HasStandardization, HasSolver, HasWeightCol):
HasStandardization, HasSolver, HasWeightCol, MLWritable, MLReadable):
"""
Linear regression.
@@ -68,6 +68,25 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
Traceback (most recent call last):
...
TypeError: Method setParams forces keyword arguments.
>>> import os, tempfile
>>> path = tempfile.mkdtemp()
>>> lr_path = path + "/lr"
>>> lr.save(lr_path)
>>> lr2 = LinearRegression.load(lr_path)
>>> lr2.getMaxIter()
5
>>> model_path = path + "/lr_model"
>>> model.save(model_path)
>>> model2 = LinearRegressionModel.load(model_path)
>>> model.coefficients[0] == model2.coefficients[0]
True
>>> model.intercept == model2.intercept
True
>>> from shutil import rmtree
>>> try:
... rmtree(path)
... except OSError:
... pass
.. versionadded:: 1.4.0
"""
@@ -106,7 +125,7 @@ def _create_model(self, java_model):
return LinearRegressionModel(java_model)


class LinearRegressionModel(JavaModel):
class LinearRegressionModel(JavaModel, MLWritable, MLReadable):
"""
Model fitted by LinearRegression.
@@ -821,9 +840,10 @@ def predict(self, features):

if __name__ == "__main__":
import doctest
import pyspark.ml.regression
from pyspark.context import SparkContext
from pyspark.sql import SQLContext
globs = globals().copy()
globs = pyspark.ml.regression.__dict__.copy()
# The small batch size here ensures that we see multiple batches,
# even in these small test examples:
sc = SparkContext("local[2]", "ml.regression tests")
36 changes: 30 additions & 6 deletions python/pyspark/ml/tests.py
@@ -34,18 +34,22 @@
else:
import unittest

from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
from pyspark.sql import DataFrame, SQLContext, Row
from pyspark.sql.functions import rand
from shutil import rmtree
import tempfile

from pyspark.ml import Estimator, Model, Pipeline, Transformer
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.feature import *
from pyspark.ml.param import Param, Params
from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed
from pyspark.ml.util import keyword_only
from pyspark.ml import Estimator, Model, Pipeline, Transformer
from pyspark.ml.feature import *
from pyspark.ml.regression import LinearRegression
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator, CrossValidatorModel
from pyspark.ml.util import keyword_only
from pyspark.mllib.linalg import DenseVector
from pyspark.sql import DataFrame, SQLContext, Row
from pyspark.sql.functions import rand
from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase


class MockDataset(DataFrame):
@@ -405,6 +409,26 @@ def test_fit_maximize_metric(self):
self.assertEqual(1.0, bestModelMetric, "Best model has R-squared of 1")


class PersistenceTest(PySparkTestCase):

def test_linear_regression(self):
lr = LinearRegression(maxIter=1)
path = tempfile.mkdtemp()
lr_path = path + "/lr"
lr.save(lr_path)
lr2 = LinearRegression.load(lr_path)
self.assertEqual(lr2.uid, lr2.maxIter.parent,
"Loaded LinearRegression instance uid (%s) did not match Param's uid (%s)"
% (lr2.uid, lr2.maxIter.parent))
self.assertEqual(lr._defaultParamMap[lr.maxIter], lr2._defaultParamMap[lr2.maxIter],
"Loaded LinearRegression instance default params did not match " +
"original defaults")
try:
rmtree(path)
except OSError:
pass


if __name__ == "__main__":
from pyspark.ml.tests import *
if xmlrunner:
142 changes: 141 additions & 1 deletion python/pyspark/ml/util.py
@@ -15,8 +15,27 @@
# limitations under the License.
#

from functools import wraps
import sys
import uuid
from functools import wraps

if sys.version > '3':
basestring = str

from pyspark import SparkContext, since
from pyspark.mllib.common import inherit_doc


def _jvm():
"""
Returns the JVM view associated with SparkContext. Must be called
after SparkContext is initialized.
"""
jvm = SparkContext._jvm
if jvm:
return jvm
else:
raise AttributeError("Cannot load _jvm from SparkContext. Is SparkContext initialized?")


def keyword_only(func):
@@ -52,3 +71,124 @@ def _randomUID(cls):
concatenates the class name, "_", and 12 random hex chars.
"""
return cls.__name__ + "_" + uuid.uuid4().hex[12:]


@inherit_doc
class JavaMLWriter(object):
"""
.. note:: Experimental
Utility class that can save ML instances through their Scala implementation.
.. versionadded:: 2.0.0
"""

def __init__(self, instance):
instance._transfer_params_to_java()
self._jwrite = instance._java_obj.write()

def save(self, path):
"""Save the ML instance to the input path."""
if not isinstance(path, basestring):
raise TypeError("path should be a basestring, got type %s" % type(path))
self._jwrite.save(path)

def overwrite(self):
"""Overwrites if the output path already exists."""
self._jwrite.overwrite()
return self

def context(self, sqlContext):
"""Sets the SQL context to use for saving."""
self._jwrite.context(sqlContext._ssql_ctx)
return self
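
As a small, hedged usage sketch (names and the path are illustrative), the writer returned by an ```MLWritable``` instance can be chained:

```python
# `lr` is assumed to be an MLWritable instance, e.g. a configured LinearRegression.
lr.write().overwrite().save("/tmp/lr")          # overwrite an existing path

# Optionally pin the SQLContext used for saving (`sqlContext` assumed to exist):
lr.write().context(sqlContext).save("/tmp/lr")
```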


@inherit_doc
class MLWritable(object):
"""
.. note:: Experimental
Mixin for ML instances that provide JavaMLWriter.
.. versionadded:: 2.0.0
"""

def write(self):
"""Returns an JavaMLWriter instance for this ML instance."""
return JavaMLWriter(self)

def save(self, path):
"""Save this ML instance to the given path, a shortcut of `write().save(path)`."""
self.write().save(path)


@inherit_doc
class JavaMLReader(object):
"""
.. note:: Experimental
Utility class that can load ML instances through their Scala implementation.
.. versionadded:: 2.0.0
"""

def __init__(self, clazz):
self._clazz = clazz
self._jread = self._load_java_obj(clazz).read()

def load(self, path):
"""Load the ML instance from the input path."""
if not isinstance(path, basestring):
raise TypeError("path should be a basestring, got type %s" % type(path))
java_obj = self._jread.load(path)
instance = self._clazz()
instance._java_obj = java_obj
instance._resetUid(java_obj.uid())
instance._transfer_params_from_java()
return instance

def context(self, sqlContext):
"""Sets the SQL context to use for loading."""
self._jread.context(sqlContext._ssql_ctx)
return self

@classmethod
def _java_loader_class(cls, clazz):
"""
Returns the full class name of the Java ML instance. The default
implementation replaces "pyspark" by "org.apache.spark" in
the Python full class name.
"""
java_package = clazz.__module__.replace("pyspark", "org.apache.spark")
return ".".join([java_package, clazz.__name__])

@classmethod
def _load_java_obj(cls, clazz):
"""Load the peer Java object of the ML instance."""
java_class = cls._java_loader_class(clazz)
java_obj = _jvm()
for name in java_class.split("."):
java_obj = getattr(java_obj, name)
return java_obj


@inherit_doc
class MLReadable(object):
"""
.. note:: Experimental
Mixin for instances that provide JavaMLReader.
.. versionadded:: 2.0.0
"""

@classmethod
def read(cls):
"""Returns an JavaMLReader instance for this class."""
return JavaMLReader(cls)

@classmethod
def load(cls, path):
"""Reads an ML instance from the input path, a shortcut of `read().load(path)`."""
return cls.read().load(path)
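
To illustrate the default Python-to-Scala class-name mapping that ```JavaMLReader``` relies on (output shown as a comment; no SparkContext is needed for this particular call):

```python
from pyspark.ml.regression import LinearRegression
from pyspark.ml.util import JavaMLReader

# "pyspark.ml.regression" is rewritten to "org.apache.spark.ml.regression",
# then the class name is appended to locate the Scala peer used by read()/load().
print(JavaMLReader._java_loader_class(LinearRegression))
# org.apache.spark.ml.regression.LinearRegression
```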
33 changes: 16 additions & 17 deletions python/pyspark/ml/wrapper.py
@@ -21,21 +21,10 @@
from pyspark.sql import DataFrame
from pyspark.ml.param import Params
from pyspark.ml.pipeline import Estimator, Transformer, Model
from pyspark.ml.util import _jvm
from pyspark.mllib.common import inherit_doc, _java2py, _py2java


def _jvm():
"""
Returns the JVM view associated with SparkContext. Must be called
after SparkContext is initialized.
"""
jvm = SparkContext._jvm
if jvm:
return jvm
else:
raise AttributeError("Cannot load _jvm from SparkContext. Is SparkContext initialized?")


@inherit_doc
class JavaWrapper(Params):
"""
@@ -159,15 +148,24 @@ class JavaModel(Model, JavaTransformer):

__metaclass__ = ABCMeta

def __init__(self, java_model):
def __init__(self, java_model=None):
"""
Initialize this instance with a Java model object.
Subclasses should call this constructor, initialize params,
and then call _transfer_params_from_java.
This instance can also be instantiated without specifying java_model;
it will be assigned afterwards, but this scenario is only used by
:py:class:`JavaMLReader` to load models. This is a bit of a
hack, but it is easiest since a proper fix would require
MLReader (in pyspark.ml.util) to depend on these wrappers, but
these wrappers depend on pyspark.ml.util (both directly and via
other ML classes).
"""
super(JavaModel, self).__init__()
self._java_obj = java_model
self.uid = java_model.uid()
if java_model is not None:
self._java_obj = java_model
self.uid = java_model.uid()

def copy(self, extra=None):
"""
@@ -182,8 +180,9 @@ def copy(self, extra=None):
if extra is None:
extra = dict()
that = super(JavaModel, self).copy(extra)
that._java_obj = self._java_obj.copy(self._empty_java_param_map())
that._transfer_params_to_java()
if self._java_obj is not None:
that._java_obj = self._java_obj.copy(self._empty_java_param_map())
that._transfer_params_to_java()
return that

def _call_java(self, name, *args):
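
For context, a condensed sketch of the loading path that motivates the optional ```java_model``` argument above (this mirrors what ```JavaMLReader.load``` in pyspark.ml.util does; ```java_obj``` stands for the already-loaded Java model object):

```python
# Effectively what JavaMLReader.load does for a model class:
model = LinearRegressionModel()          # empty Python shell, no Java peer yet
model._java_obj = java_obj               # attach the loaded Java object
model._resetUid(java_obj.uid())          # adopt the persisted uid
model._transfer_params_from_java()       # pull param values across
```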
