-
Notifications
You must be signed in to change notification settings - Fork 28.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-9654][ML][PYSPARK] Add IndexToString to PySpark #7976
Changes from all commits
1dc4579
0445fcc
af2f869
510bce5
c6da160
9f5af3a
7b3b5ca
244e083
e95b61b
b1795aa
ab90dcd
43ae197
c400e16
64de5c9
2316a90
15390bb
28afcfd
f19445d
51ae7ee
ed0ca91
8fca8b3
3ef852f
41d0d27
cd5d418
4f56b17
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,10 +27,11 @@ | |
from pyspark.mllib.linalg import _convert_to_vector | ||
|
||
__all__ = ['Binarizer', 'Bucketizer', 'DCT', 'ElementwiseProduct', 'HashingTF', 'IDF', 'IDFModel', | ||
'NGram', 'Normalizer', 'OneHotEncoder', 'PolynomialExpansion', 'RegexTokenizer', | ||
'SQLTransformer', 'StandardScaler', 'StandardScalerModel', 'StringIndexer', | ||
'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer', 'Word2Vec', | ||
'Word2VecModel', 'PCA', 'PCAModel', 'RFormula', 'RFormulaModel', 'StopWordsRemover'] | ||
'IndexToString', 'NGram', 'Normalizer', 'OneHotEncoder', 'PolynomialExpansion', | ||
'RegexTokenizer', 'SQLTransformer', 'StandardScaler', 'StandardScalerModel', | ||
'StringIndexer', 'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer', | ||
'Word2Vec', 'Word2VecModel', 'PCA', 'PCAModel', 'RFormula', 'RFormulaModel', | ||
'StopWordsRemover'] | ||
|
||
|
||
@inherit_doc | ||
|
@@ -902,6 +903,11 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol): | |
>>> sorted(set([(i[0], i[1]) for i in td.select(td.id, td.indexed).collect()]), | ||
... key=lambda x: x[0]) | ||
[(0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)] | ||
>>> inverter = IndexToString(inputCol="indexed", outputCol="label2", labels=model.labels()) | ||
>>> itd = inverter.transform(td) | ||
>>> sorted(set([(i[0], str(i[1])) for i in itd.select(itd.id, itd.label2).collect()]), | ||
... key=lambda x: x[0]) | ||
[(0, 'a'), (1, 'b'), (2, 'c'), (3, 'a'), (4, 'a'), (5, 'c')] | ||
""" | ||
|
||
@keyword_only | ||
|
@@ -931,6 +937,66 @@ class StringIndexerModel(JavaModel): | |
""" | ||
Model fitted by StringIndexer. | ||
""" | ||
@property
def labels(self):
    """
    Ordered list of labels, corresponding to indices to be assigned.

    NOTE(review): this hands back the ``labels`` attribute of the wrapped
    Java object without calling it. If ``_java_obj`` is a py4j object,
    that attribute is presumably a callable member rather than the list
    itself, which would explain why callers write ``model.labels()``
    -- confirm before changing either side.
    """
    java_model = self._java_obj
    return java_model.labels
|
||
|
||
@inherit_doc | ||
class IndexToString(JavaTransformer, HasInputCol, HasOutputCol): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. use inherit_doc tag |
||
""" | ||
.. note:: Experimental | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: line break |
||
|
||
A :py:class:`Transformer` that maps a column of indices back to a new column of | ||
corresponding string values, using the labels supplied by the user if provided, or | ||
otherwise the ML attributes of the input column. | ||
All original columns are kept during transformation. | ||
See :py:class:`StringIndexer` for converting strings into indices. | ||
""" | ||
|
||
# a placeholder to make the labels show up in generated doc | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. insert newline above |
||
labels = Param(Params._dummy(), "labels", | ||
"Optional array of labels to be provided by the user, if not supplied or " + | ||
"empty, column metadata is read for labels") | ||
|
||
@keyword_only
def __init__(self, inputCol=None, outputCol=None, labels=None):
    """
    __init__(self, inputCol=None, outputCol=None, labels=None)

    Creates the backing ``org.apache.spark.ml.feature.IndexToString``
    Java object, attaches an instance-level ``labels`` param and then
    forwards the keyword arguments stored in
    ``self.__init__._input_kwargs`` (presumably recorded by the
    ``keyword_only`` decorator) to :py:meth:`setParams`.
    """
    super(IndexToString, self).__init__()
    java_class = "org.apache.spark.ml.feature.IndexToString"
    self._java_obj = self._new_java_obj(java_class, self.uid)
    labels_doc = ("Optional array of labels to be provided by the user, if not " +
                  "supplied or empty, column metadata is read for labels")
    # Re-create the param with this instance as its parent.
    self.labels = Param(self, "labels", labels_doc)
    self.setParams(**self.__init__._input_kwargs)
|
||
@keyword_only
def setParams(self, inputCol=None, outputCol=None, labels=None):
    """
    setParams(self, inputCol=None, outputCol=None, labels=None)
    Sets params for this IndexToString.

    Passes the keyword arguments stored in
    ``self.setParams._input_kwargs`` to ``self._set`` and returns
    whatever ``_set`` returns.
    """
    supplied = self.setParams._input_kwargs
    return self._set(**supplied)
|
||
def setLabels(self, value):
    """
    Sets the value of :py:attr:`labels`.

    Stores ``value`` in this instance's param map under the ``labels``
    param and returns ``self`` so calls can be chained.
    """
    labels_param = self.labels
    self._paramMap[labels_param] = value
    return self
|
||
def getLabels(self):
    """
    Gets the value of :py:attr:`labels` or its default value.

    Delegates the lookup to ``self.getOrDefault``.
    """
    labels_param = self.labels
    return self.getOrDefault(labels_param)
|
||
|
||
class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
copy Scala doc: "Ordered list of labels, corresponding to indices to be assigned"