Skip to content

Commit

Permalink
[SPARK-22734][ML][PYSPARK] Added Python API for VectorSizeHint.
Browse files Browse the repository at this point in the history
(Please fill in changes proposed in this fix)

Adds a Python API for the VectorSizeHint transformer.

(Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests)

Tested with the doctests added to the new Python class.

Author: Bago Amirbekian <bago@databricks.com>

Closes #20112 from MrBago/vectorSizeHint-PythonAPI.
  • Loading branch information
MrBago authored and jkbradley committed Dec 30, 2017
1 parent 30fcdc0 commit 8169630
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import org.apache.spark.sql.types.StructType
* VectorAssembler needs size information for its input columns and cannot be used on streaming
* dataframes without this metadata.
*
* Note: VectorSizeHint modifies `inputCol` to include size metadata and does not have an outputCol.
*/
@Experimental
@Since("2.3.0")
Expand Down
79 changes: 79 additions & 0 deletions python/pyspark/ml/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
'Tokenizer',
'VectorAssembler',
'VectorIndexer', 'VectorIndexerModel',
'VectorSizeHint',
'VectorSlicer',
'Word2Vec', 'Word2VecModel']

Expand Down Expand Up @@ -3466,6 +3467,84 @@ def selectedFeatures(self):
return self._call_java("selectedFeatures")


@inherit_doc
class VectorSizeHint(JavaTransformer, HasInputCol, HasHandleInvalid, JavaMLReadable,
                     JavaMLWritable):
    """
    .. note:: Experimental

    A feature transformer that adds size information to the metadata of a vector column.
    VectorAssembler needs size information for its input columns and cannot be used on streaming
    dataframes without this metadata.

    .. note:: VectorSizeHint modifies `inputCol` to include size metadata and does not have an
        outputCol.

    >>> from pyspark.ml.linalg import Vectors
    >>> from pyspark.ml import Pipeline, PipelineModel
    >>> data = [(Vectors.dense([1., 2., 3.]), 4.)]
    >>> df = spark.createDataFrame(data, ["vector", "float"])
    >>>
    >>> sizeHint = VectorSizeHint(inputCol="vector", size=3, handleInvalid="skip")
    >>> vecAssembler = VectorAssembler(inputCols=["vector", "float"], outputCol="assembled")
    >>> pipeline = Pipeline(stages=[sizeHint, vecAssembler])
    >>>
    >>> pipelineModel = pipeline.fit(df)
    >>> pipelineModel.transform(df).head().assembled
    DenseVector([1.0, 2.0, 3.0, 4.0])
    >>> vectorSizeHintPath = temp_path + "/vector-size-hint-pipeline"
    >>> pipelineModel.save(vectorSizeHintPath)
    >>> loadedPipeline = PipelineModel.load(vectorSizeHintPath)
    >>> loaded = loadedPipeline.transform(df).head().assembled
    >>> expected = pipelineModel.transform(df).head().assembled
    >>> loaded == expected
    True

    .. versionadded:: 2.3.0
    """

    # Declared size of the vectors held in `inputCol`; converted to int on set.
    size = Param(Params._dummy(), "size", "Size of vectors in column.",
                 typeConverter=TypeConverters.toInt)

    # Policy for rows whose vector is null or has the wrong size; the default
    # ("error") is registered in __init__ via _setDefault.
    handleInvalid = Param(Params._dummy(), "handleInvalid",
                          "How to handle invalid vectors in inputCol. Invalid vectors include "
                          "nulls and vectors with the wrong size. The options are `skip` (filter "
                          "out rows with invalid vectors), `error` (throw an error) and "
                          "`optimistic` (do not check the vector size, and keep all rows). "
                          "`error` by default.",
                          typeConverter=TypeConverters.toString)

    @keyword_only
    def __init__(self, inputCol=None, size=None, handleInvalid="error"):
        """
        __init__(self, inputCol=None, size=None, handleInvalid="error")
        """
        super(VectorSizeHint, self).__init__()
        # Create the backing JVM transformer, sharing this instance's uid.
        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorSizeHint", self.uid)
        self._setDefault(handleInvalid="error")
        # Forward only the keyword arguments captured for this call.
        self.setParams(**self._input_kwargs)

    @keyword_only
    @since("2.3.0")
    def setParams(self, inputCol=None, size=None, handleInvalid="error"):
        """
        setParams(self, inputCol=None, size=None, handleInvalid="error")
        Sets params for this VectorSizeHint.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    @since("2.3.0")
    def getSize(self):
        """ Gets size param, the size of vectors in `inputCol`."""
        # The looked-up value must be returned; without `return` this getter
        # always yielded None.
        return self.getOrDefault(self.size)

    @since("2.3.0")
    def setSize(self, value):
        """ Sets size param, the size of vectors in `inputCol`."""
        # Return the result of `_set`, exactly as `setParams` does, so that
        # setter calls can be chained instead of evaluating to None.
        return self._set(size=value)


if __name__ == "__main__":
import doctest
import tempfile
Expand Down

0 comments on commit 8169630

Please sign in to comment.