[SPARK-14605][ML][PYTHON] Changed Python to use unicode UIDs for spark.ml Identifiable

## What changes were proposed in this pull request?

Python spark.ml Identifiable classes use UIDs of type str, but in Python 2.x they should use unicode to match Java. The mismatch could cause problems if someone created a class in Java whose UID contained odd unicode characters, saved it, and loaded it in Python.

This PR uses unicode for UIDs everywhere in Python.
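
To illustrate the type mismatch described above, here is a minimal sketch that runs without Spark. The helper name `normalize_uid` and the sample UID are hypothetical; the version shim follows the same idea as the one this commit adds to `python/pyspark/ml/util.py`. On Python 2 a UID written as a plain literal is a byte `str`, while a `u"..."` literal is `unicode`, so the two differ in type until they are normalized. On Python 3 both are already `str`.

```python
import sys

# Same idea as the shim in python/pyspark/ml/util.py: Python 3 has no
# separate `unicode` builtin, so alias it to `str`.
if sys.version_info[0] >= 3:
    unicode = str


def normalize_uid(uid):
    """Hypothetical helper: coerce a UID to the text (unicode) type."""
    return unicode(uid)


# A UID created on the Python side and one that, by assumption, came
# back from a saved model as a unicode string.
py_uid = normalize_uid("LinearRegression_0123456789ab")
loaded_uid = normalize_uid(u"LinearRegression_0123456789ab")

# After normalization both the value and the type agree.
same_value = py_uid == loaded_uid
same_type = type(py_uid) is type(loaded_uid)
```

Normalizing at a single point keeps later comparisons, such as the `type(lr.uid)` check in the updated unit test, independent of where the UID came from.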

## How was this patch tested?

Updated the persistence unit test to check the uid type.

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #12368 from jkbradley/python-uid-unicode.
jkbradley committed Apr 16, 2016
1 parent 9f678e9 commit 36da5e3
Showing 4 changed files with 8 additions and 4 deletions.
3 changes: 2 additions & 1 deletion python/pyspark/ml/param/__init__.py
@@ -485,10 +485,11 @@ def _resetUid(self, newUid):
         Changes the uid of this instance. This updates both
         the stored uid and the parent uid of params and param maps.
         This is used by persistence (loading).
-        :param newUid: new uid to use
+        :param newUid: new uid to use, which is converted to unicode
         :return: same instance, but with the uid and Param.parent values
                  updated, including within param maps
         """
+        newUid = unicode(newUid)
         self.uid = newUid
         newDefaultParamMap = dict()
         newParamMap = dict()
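
The docstring above says that resetting the uid also updates the parent of each param and the param maps. A simplified sketch of that behavior, assuming toy stand-ins rather than the real pyspark classes (`ToyParam` and `ToyStage` are hypothetical names, and the real method also rebuilds the default param map):

```python
import sys

if sys.version_info[0] >= 3:
    unicode = str  # Python 3 has no separate unicode type


class ToyParam(object):
    """Hypothetical stand-in: a param that records its owner's uid."""

    def __init__(self, parent, name):
        self.parent = parent
        self.name = name


class ToyStage(object):
    """Hypothetical stand-in for an Identifiable object with one param."""

    def __init__(self, uid):
        self.uid = uid
        self.maxIter = ToyParam(uid, "maxIter")
        self._paramMap = {self.maxIter: 10}

    def _resetUid(self, newUid):
        # Normalize the type first, as the added line in the diff does.
        newUid = unicode(newUid)
        self.uid = newUid
        # Re-parent every param and rebuild the map around it.
        newParamMap = dict()
        for param, value in self._paramMap.items():
            param.parent = newUid
            newParamMap[param] = value
        self._paramMap = newParamMap
        return self


stage = ToyStage("ToyStage_old")
stage._resetUid("ToyStage_0123456789ab")
```

After the call, `stage.uid` and `stage.maxIter.parent` hold the same unicode value, which is the invariant the persistence test checks on the loaded `LinearRegression` instance.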
2 changes: 2 additions & 0 deletions python/pyspark/ml/tests.py
@@ -621,6 +621,8 @@ def test_linear_regression(self):
         lr_path = path + "/lr"
         lr.save(lr_path)
         lr2 = LinearRegression.load(lr_path)
+        self.assertEqual(lr.uid, lr2.uid)
+        self.assertEqual(type(lr.uid), type(lr2.uid))
         self.assertEqual(lr2.uid, lr2.maxIter.parent,
                          "Loaded LinearRegression instance uid (%s) did not match Param's uid (%s)"
                          % (lr2.uid, lr2.maxIter.parent))
5 changes: 3 additions & 2 deletions python/pyspark/ml/util.py
@@ -21,6 +21,7 @@

 if sys.version > '3':
     basestring = str
+    unicode = str

 from pyspark import SparkContext, since
 from pyspark.mllib.common import inherit_doc
@@ -67,10 +68,10 @@ def __repr__(self):
     @classmethod
     def _randomUID(cls):
         """
-        Generate a unique id for the object. The default implementation
+        Generate a unique unicode id for the object. The default implementation
         concatenates the class name, "_", and 12 random hex chars.
         """
-        return cls.__name__ + "_" + uuid.uuid4().hex[12:]
+        return unicode(cls.__name__ + "_" + uuid.uuid4().hex[12:])


 @inherit_doc
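
As a standalone sketch of the UID format outside pyspark: the function name `random_uid` and the `suffix_len` parameter are hypothetical. Note that `uuid.uuid4().hex` is a 32-character string, so the slice `hex[12:]` used in the diff keeps the last 20 characters, whereas `hex[-12:]` keeps 12. The sketch uses the latter to match the length stated in the docstring; that is a choice made for this example, not what the commit does.

```python
import sys
import uuid

if sys.version_info[0] >= 3:
    unicode = str  # Python 3 has no separate unicode type


def random_uid(class_name, suffix_len=12):
    """Hypothetical helper: class name, "_", and `suffix_len` hex chars."""
    suffix = uuid.uuid4().hex[-suffix_len:]
    return unicode(class_name + "_" + suffix)


uid = random_uid("LinearRegression")

# Length of the suffix produced by the slice that appears in the diff.
len_as_in_diff = len(uuid.uuid4().hex[12:])
```

Wrapping the result in `unicode(...)` is a no-op on Python 3 and turns the byte `str` into `unicode` on Python 2, so freshly generated UIDs have the same type as UIDs passed through `_resetUid`.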
2 changes: 1 addition & 1 deletion python/pyspark/ml/wrapper.py
@@ -254,7 +254,7 @@ def __init__(self, java_model=None):
         """
         super(JavaModel, self).__init__(java_model)
         if java_model is not None:
-            self.uid = java_model.uid()
+            self._resetUid(java_model.uid())

     def copy(self, extra=None):
         """