diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSizeHint.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSizeHint.scala
index 1fe3cfc74c76d..f5947d61fe349 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSizeHint.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSizeHint.scala
@@ -35,6 +35,7 @@ import org.apache.spark.sql.types.StructType
  * VectorAssembler needs size information for its input columns and cannot be used on streaming
  * dataframes without this metadata.
  *
+ * Note: VectorSizeHint modifies `inputCol` to include size metadata and does not have an outputCol.
  */
 @Experimental
 @Since("2.3.0")
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 608f2a5715497..5094324e5c1fe 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -57,6 +57,7 @@
            'Tokenizer',
            'VectorAssembler', 'VectorIndexer', 'VectorIndexerModel',
+           'VectorSizeHint',
            'VectorSlicer',
            'Word2Vec', 'Word2VecModel']
@@ -3466,6 +3467,84 @@ def selectedFeatures(self):
         return self._call_java("selectedFeatures")
 
 
+@inherit_doc
+class VectorSizeHint(JavaTransformer, HasInputCol, HasHandleInvalid, JavaMLReadable,
+                     JavaMLWritable):
+    """
+    .. note:: Experimental
+
+    A feature transformer that adds size information to the metadata of a vector column.
+    VectorAssembler needs size information for its input columns and cannot be used on streaming
+    dataframes without this metadata.
+
+    .. note:: VectorSizeHint modifies `inputCol` to include size metadata and does not have an
+        outputCol.
+
+    >>> from pyspark.ml.linalg import Vectors
+    >>> from pyspark.ml import Pipeline, PipelineModel
+    >>> data = [(Vectors.dense([1., 2., 3.]), 4.)]
+    >>> df = spark.createDataFrame(data, ["vector", "float"])
+    >>>
+    >>> sizeHint = VectorSizeHint(inputCol="vector", size=3, handleInvalid="skip")
+    >>> vecAssembler = VectorAssembler(inputCols=["vector", "float"], outputCol="assembled")
+    >>> pipeline = Pipeline(stages=[sizeHint, vecAssembler])
+    >>>
+    >>> pipelineModel = pipeline.fit(df)
+    >>> pipelineModel.transform(df).head().assembled
+    DenseVector([1.0, 2.0, 3.0, 4.0])
+    >>> vectorSizeHintPath = temp_path + "/vector-size-hint-pipeline"
+    >>> pipelineModel.save(vectorSizeHintPath)
+    >>> loadedPipeline = PipelineModel.load(vectorSizeHintPath)
+    >>> loaded = loadedPipeline.transform(df).head().assembled
+    >>> expected = pipelineModel.transform(df).head().assembled
+    >>> loaded == expected
+    True
+
+    .. versionadded:: 2.3.0
+    """
+
+    size = Param(Params._dummy(), "size", "Size of vectors in column.",
+                 typeConverter=TypeConverters.toInt)
+
+    handleInvalid = Param(Params._dummy(), "handleInvalid",
+                          "How to handle invalid vectors in inputCol. Invalid vectors include "
+                          "nulls and vectors with the wrong size. The options are `skip` (filter "
+                          "out rows with invalid vectors), `error` (throw an error) and "
+                          "`optimistic` (do not check the vector size, and keep all rows). "
+                          "`error` by default.",
+                          typeConverter=TypeConverters.toString)
+
+    @keyword_only
+    def __init__(self, inputCol=None, size=None, handleInvalid="error"):
+        """
+        __init__(self, inputCol=None, size=None, handleInvalid="error")
+        """
+        super(VectorSizeHint, self).__init__()
+        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorSizeHint", self.uid)
+        self._setDefault(handleInvalid="error")
+        self.setParams(**self._input_kwargs)
+
+    @keyword_only
+    @since("2.3.0")
+    def setParams(self, inputCol=None, size=None, handleInvalid="error"):
+        """
+        setParams(self, inputCol=None, size=None, handleInvalid="error")
+        Sets params for this VectorSizeHint.
+        """
+        kwargs = self._input_kwargs
+        return self._set(**kwargs)
+
+    @since("2.3.0")
+    def getSize(self):
+        """ Gets size param, the size of vectors in `inputCol`."""
+        return self.getOrDefault(self.size)
+
+    @since("2.3.0")
+    def setSize(self, value):
+        """ Sets size param, the size of vectors in `inputCol`."""
+        return self._set(size=value)
+
+
 if __name__ == "__main__":
     import doctest
     import tempfile