diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py index a1265294a1e9e..6b5f639496527 100644 --- a/python/pyspark/ml/param/__init__.py +++ b/python/pyspark/ml/param/__init__.py @@ -469,10 +469,11 @@ def _resetUid(self, newUid): Changes the uid of this instance. This updates both the stored uid and the parent uid of params and param maps. This is used by persistence (loading). - :param newUid: new uid to use + :param newUid: new uid to use, which is converted to unicode :return: same instance, but with the uid and Param.parent values updated, including within param maps """ + newUid = unicode(newUid) self.uid = newUid newDefaultParamMap = dict() newParamMap = dict() diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 0b0ad2377fd7d..d6a0983daf83c 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -618,6 +618,8 @@ def test_linear_regression(self): lr_path = path + "/lr" lr.save(lr_path) lr2 = LinearRegression.load(lr_path) + self.assertEqual(lr.uid, lr2.uid) + self.assertEqual(type(lr.uid), type(lr2.uid)) self.assertEqual(lr2.uid, lr2.maxIter.parent, "Loaded LinearRegression instance uid (%s) did not match Param's uid (%s)" % (lr2.uid, lr2.maxIter.parent)) diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index 9dfcef0e40d67..841bfb47e1b9d 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -21,6 +21,7 @@ if sys.version > '3': basestring = str + unicode = str from pyspark import SparkContext, since from pyspark.mllib.common import inherit_doc @@ -67,10 +68,10 @@ def __repr__(self): @classmethod def _randomUID(cls): """ - Generate a unique id for the object. The default implementation + Generate a unique unicode id for the object. The default implementation concatenates the class name, "_", and 12 random hex chars. """ - return cls.__name__ + "_" + uuid.uuid4().hex[12:] + return unicode(cls.__name__ + "_" + uuid.uuid4().hex[12:]) @inherit_doc diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index cd0e5b80d5559..e386c7095f63f 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -254,7 +254,7 @@ def __init__(self, java_model=None): """ super(JavaModel, self).__init__(java_model) if java_model is not None: - self.uid = java_model.uid() + self._resetUid(java_model.uid()) def copy(self, extra=None): """