Skip to content

Commit

Permalink
[SPARK-19852][PYSPARK][ML] Update Python API for StringIndexer setHan…
Browse files Browse the repository at this point in the history
…dleInvalid

This PR reflects the changes made in SPARK-17498 in PySpark, adding support for the
new 'keep' option in StringIndexer for handling unseen labels.

Signed-off-by: VinceShieh <vincent.xie@intel.com>
  • Loading branch information
VinceShieh committed Mar 10, 2017
1 parent d809cee commit d94dc68
Showing 1 changed file with 28 additions and 2 deletions.
30 changes: 28 additions & 2 deletions python/pyspark/ml/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -1917,8 +1917,7 @@ def mean(self):


@inherit_doc
class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, JavaMLReadable,
JavaMLWritable):
class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
"""
A label indexer that maps a string column of labels to an ML column of label indices.
If the input column is numeric, we cast it to string and index the string values.
Expand All @@ -1936,6 +1935,14 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid,
>>> sorted(set([(i[0], str(i[1])) for i in itd.select(itd.id, itd.label2).collect()]),
... key=lambda x: x[0])
[(0, 'a'), (1, 'b'), (2, 'c'), (3, 'a'), (4, 'a'), (5, 'c')]
>>> testData2 = sc.parallelize([Row(id=0, label="a"), Row(id=1, label="d"),
... Row(id=2, label="e")], 2)
>>> dfKeep = spark.createDataFrame(testData2)
>>> tdKeep = stringIndexer.setHandleInvalid("keep").fit(stringIndDf).transform(dfKeep)
>>> itdKeep = inverter.transform(tdKeep)
>>> sorted(set([(i[0], str(i[1])) for i in itdKeep.select(itdKeep.id, itdKeep.label2).collect()]),
... key=lambda x: x[0])
[(0, 'a'), (6, 'd'), (6, 'e')]
>>> stringIndexerPath = temp_path + "/string-indexer"
>>> stringIndexer.save(stringIndexerPath)
>>> loadedIndexer = StringIndexer.load(stringIndexerPath)
Expand All @@ -1955,6 +1962,11 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid,
.. versionadded:: 1.4.0
"""

# Param controlling how unseen labels are handled; the accepted option
# names are spelled out in the description string below.
handleInvalid = Param(
    Params._dummy(),
    "handleInvalid",
    "how to handle unseen labels. "
    "Options are 'skip' (filter out rows with unseen labels), "
    "error (throw an error), or 'keep' (put unseen labels in a special "
    "additional bucket, at index numLabels).",
    typeConverter=TypeConverters.toString)
@keyword_only
def __init__(self, inputCol=None, outputCol=None, handleInvalid="error"):
"""
Expand All @@ -1979,6 +1991,20 @@ def setParams(self, inputCol=None, outputCol=None, handleInvalid="error"):
def _create_model(self, java_model):
    """
    Wrap the given Java model object in a :py:class:`StringIndexerModel`.

    :param java_model: the Java-side model object to wrap.
    :return: a new :py:class:`StringIndexerModel` built from ``java_model``.
    """
    model = StringIndexerModel(java_model)
    return model

@since("2.2.0")
def setHandleInvalid(self, value):
    """
    Sets the value of :py:attr:`handleInvalid`.

    :param value: the handling strategy to store; the param description
                  lists 'skip', error and 'keep' as options.
    :return: the result of ``self._set(...)``.
    """
    updated = self._set(handleInvalid=value)
    return updated

@since("2.2.0")
def getHandleInvalid(self):
    """
    Gets the value of :py:attr:`handleInvalid` or its default value.

    :return: whatever ``getOrDefault`` yields for the
             :py:attr:`handleInvalid` param.
    """
    param = self.handleInvalid
    return self.getOrDefault(param)


class StringIndexerModel(JavaModel, JavaMLReadable, JavaMLWritable):
"""
Expand Down

0 comments on commit d94dc68

Please sign in to comment.