[SPARK-10355] [ML] [PySpark] Add Python API for SQLTransformer
Add Python API for SQLTransformer

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #8527 from yanboliang/spark-10355.
yanboliang authored and mengxr committed Aug 31, 2015
1 parent fe16fd0 commit 52ea399
Showing 1 changed file with 54 additions and 3 deletions.

python/pyspark/ml/feature.py
@@ -28,9 +28,9 @@

__all__ = ['Binarizer', 'Bucketizer', 'DCT', 'ElementwiseProduct', 'HashingTF', 'IDF', 'IDFModel',
           'NGram', 'Normalizer', 'OneHotEncoder', 'PolynomialExpansion', 'RegexTokenizer',
-          'StandardScaler', 'StandardScalerModel', 'StringIndexer', 'StringIndexerModel',
-          'Tokenizer', 'VectorAssembler', 'VectorIndexer', 'Word2Vec', 'Word2VecModel',
-          'PCA', 'PCAModel', 'RFormula', 'RFormulaModel']
+          'SQLTransformer', 'StandardScaler', 'StandardScalerModel', 'StringIndexer',
+          'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer', 'Word2Vec',
+          'Word2VecModel', 'PCA', 'PCAModel', 'RFormula', 'RFormulaModel']


@inherit_doc
@@ -743,6 +743,57 @@ def getPattern(self):
        return self.getOrDefault(self.pattern)


+@inherit_doc
+class SQLTransformer(JavaTransformer):
+    """
+    Implements the transforms which are defined by a SQL statement.
+    Currently we only support SQL syntax like 'SELECT ... FROM __THIS__',
+    where '__THIS__' represents the underlying table of the input dataset.
+
+    >>> df = sqlContext.createDataFrame([(0, 1.0, 3.0), (2, 2.0, 5.0)], ["id", "v1", "v2"])
+    >>> sqlTrans = SQLTransformer(
+    ...     statement="SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__")
+    >>> sqlTrans.transform(df).head()
+    Row(id=0, v1=1.0, v2=3.0, v3=4.0, v4=3.0)
+    """
+
+    # a placeholder to make it appear in the generated doc
+    statement = Param(Params._dummy(), "statement", "SQL statement")
+
+    @keyword_only
+    def __init__(self, statement=None):
+        """
+        __init__(self, statement=None)
+        """
+        super(SQLTransformer, self).__init__()
+        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.SQLTransformer", self.uid)
+        self.statement = Param(self, "statement", "SQL statement")
+        kwargs = self.__init__._input_kwargs
+        self.setParams(**kwargs)
+
+    @keyword_only
+    def setParams(self, statement=None):
+        """
+        setParams(self, statement=None)
+        Sets params for this SQLTransformer.
+        """
+        kwargs = self.setParams._input_kwargs
+        return self._set(**kwargs)
+
+    def setStatement(self, value):
+        """
+        Sets the value of :py:attr:`statement`.
+        """
+        self._paramMap[self.statement] = value
+        return self
+
+    def getStatement(self):
+        """
+        Gets the value of statement or its default value.
+        """
+        return self.getOrDefault(self.statement)
+
+
@inherit_doc
class StandardScaler(JavaEstimator, HasInputCol, HasOutputCol):
    """
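For readers trying out the new API, here is a minimal usage sketch; it is not part of the diff. It assumes a running SparkContext plus the illustrative DataFrame from the doctest, and, since the docstring says '__THIS__' represents the underlying table of the input dataset, it also assumes an ordinary WHERE clause over that table is accepted.

from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.ml.feature import SQLTransformer

sc = SparkContext(appName="SQLTransformerExample")
sqlContext = SQLContext(sc)

# Same illustrative data as the doctest above.
df = sqlContext.createDataFrame([(0, 1.0, 3.0), (2, 2.0, 5.0)], ["id", "v1", "v2"])

# Derive new columns; '__THIS__' stands for a table backed by `df`.
sqlTrans = SQLTransformer(
    statement="SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__")
sqlTrans.transform(df).show()

# `statement` is an ordinary Param, so it can be changed after construction.
# Assumes any valid SELECT over '__THIS__' works, including row filters.
sqlTrans.setStatement("SELECT id, v1 FROM __THIS__ WHERE v2 >= 5.0")
print(sqlTrans.getStatement())
sqlTrans.transform(df).show()

Because SQLTransformer is a JavaTransformer with no fitting step, it should also slot into a pyspark.ml Pipeline ahead of stages that expect the derived columns.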
