Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-9654][ML][PYSPARK] Add IndexToString to PySpark #7976

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
1dc4579
SPARK-9654 Add string indexer inverse in PySpark
holdenk Aug 5, 2015
0445fcc
doc fix
holdenk Aug 5, 2015
af2f869
Don't changge the base class init, fill out the doctest for the invert.
holdenk Aug 6, 2015
510bce5
remove extra blank line
holdenk Aug 6, 2015
c6da160
get rid of unicude specificers in doctest
holdenk Aug 6, 2015
9f5af3a
Deal with the difference between 2.X and 3.X with the output by just …
holdenk Aug 6, 2015
7b3b5ca
Use the standard constructor method for the StringIndexInverse
holdenk Aug 12, 2015
244e083
Update for index to string changeover
holdenk Aug 14, 2015
e95b61b
Move the property on to the model, remove references to old class name
holdenk Aug 14, 2015
b1795aa
CR feedback
holdenk Aug 18, 2015
ab90dcd
switch link to pydoc style
holdenk Aug 18, 2015
43ae197
Merge in master
holdenk Aug 18, 2015
c400e16
remove getLabels function (CR feedback) now that labels is public.
holdenk Aug 18, 2015
64de5c9
Some CR feedback
holdenk Aug 28, 2015
2316a90
Use None instead of empty array
holdenk Aug 28, 2015
15390bb
merge in master
holdenk Sep 1, 2015
28afcfd
Some CR feedback (note: still sorting our one of the params)
holdenk Sep 1, 2015
f19445d
Change description text
holdenk Sep 1, 2015
51ae7ee
merge in master
holdenk Sep 1, 2015
ed0ca91
moar merge
holdenk Sep 1, 2015
8fca8b3
punctuation
holdenk Sep 1, 2015
3ef852f
remove unrelated change
holdenk Sep 1, 2015
41d0d27
long line fix
holdenk Sep 1, 2015
cd5d418
Add missing period
holdenk Sep 9, 2015
4f56b17
Fix link to transformer class, copy scala doc for labels
holdenk Sep 9, 2015
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod
* [[StringIndexerModel.transform]] would return the input dataset unmodified.
* This is a temporary fix for the case when target labels do not exist during prediction.
*
* @param labels Ordered list of labels, corresponding to indices to be assigned
* @param labels Ordered list of labels, corresponding to indices to be assigned.
*/
@Experimental
class StringIndexerModel (
Expand Down
74 changes: 70 additions & 4 deletions python/pyspark/ml/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,11 @@
from pyspark.mllib.linalg import _convert_to_vector

__all__ = ['Binarizer', 'Bucketizer', 'DCT', 'ElementwiseProduct', 'HashingTF', 'IDF', 'IDFModel',
'NGram', 'Normalizer', 'OneHotEncoder', 'PolynomialExpansion', 'RegexTokenizer',
'SQLTransformer', 'StandardScaler', 'StandardScalerModel', 'StringIndexer',
'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer', 'Word2Vec',
'Word2VecModel', 'PCA', 'PCAModel', 'RFormula', 'RFormulaModel', 'StopWordsRemover']
'IndexToString', 'NGram', 'Normalizer', 'OneHotEncoder', 'PolynomialExpansion',
'RegexTokenizer', 'SQLTransformer', 'StandardScaler', 'StandardScalerModel',
'StringIndexer', 'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer',
'Word2Vec', 'Word2VecModel', 'PCA', 'PCAModel', 'RFormula', 'RFormulaModel',
'StopWordsRemover']


@inherit_doc
Expand Down Expand Up @@ -902,6 +903,11 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol):
>>> sorted(set([(i[0], i[1]) for i in td.select(td.id, td.indexed).collect()]),
... key=lambda x: x[0])
[(0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)]
>>> inverter = IndexToString(inputCol="indexed", outputCol="label2", labels=model.labels())
>>> itd = inverter.transform(td)
>>> sorted(set([(i[0], str(i[1])) for i in itd.select(itd.id, itd.label2).collect()]),
... key=lambda x: x[0])
[(0, 'a'), (1, 'b'), (2, 'c'), (3, 'a'), (4, 'a'), (5, 'c')]
"""

@keyword_only
Expand Down Expand Up @@ -931,6 +937,66 @@ class StringIndexerModel(JavaModel):
"""
Model fitted by StringIndexer.
"""
@property
def labels(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

copy Scala doc: "Ordered list of labels, corresponding to indices to be assigned"

"""
Ordered list of labels, corresponding to indices to be assigned.
"""
return self._java_obj.labels


@inherit_doc
class IndexToString(JavaTransformer, HasInputCol, HasOutputCol):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use inherit_doc tag

"""
.. note:: Experimental
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: line break


A :py:class:`Transformer` that maps a column of string indices back to a new column of
corresponding string values using either the ML attributes of the input column, or if
provided using the labels supplied by the user.
All original columns are kept during transformation.
See L{StringIndexer} for converting strings into indices.
"""

# a placeholder to make the labels show up in generated doc
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

insert newline above

labels = Param(Params._dummy(), "labels",
"Optional array of labels to be provided by the user, if not supplied or " +
"empty, column metadata is read for labels")

@keyword_only
def __init__(self, inputCol=None, outputCol=None, labels=None):
"""
__init__(self, inputCol=None, outputCol=None, labels=None)
"""
super(IndexToString, self).__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.IndexToString",
self.uid)
self.labels = Param(self, "labels",
"Optional array of labels to be provided by the user, if not " +
"supplied or empty, column metadata is read for labels")
kwargs = self.__init__._input_kwargs
self.setParams(**kwargs)

@keyword_only
def setParams(self, inputCol=None, outputCol=None, labels=None):
"""
setParams(self, inputCol=None, outputCol=None, labels=None)
Sets params for this IndexToString.
"""
kwargs = self.setParams._input_kwargs
return self._set(**kwargs)

def setLabels(self, value):
"""
Sets the value of :py:attr:`labels`.
"""
self._paramMap[self.labels] = value
return self

def getLabels(self):
"""
Gets the value of :py:attr:`labels` or its default value.
"""
return self.getOrDefault(self.labels)


class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol):
Expand Down
3 changes: 2 additions & 1 deletion python/pyspark/ml/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,8 @@ def _fit(self, dataset):
class JavaTransformer(Transformer, JavaWrapper):
"""
Base class for :py:class:`Transformer`s that wrap Java/Scala
implementations.
implementations. Subclasses should ensure they have the transformer Java object
available as _java_obj.
"""

__metaclass__ = ABCMeta
Expand Down