[SPARK-5579][SQL][DataFrame] Support for project/filter using SQL expressions #4348

Closed · wants to merge 3 commits
@@ -132,14 +132,14 @@ class LogisticRegressionModel private[ml] (
   override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
     transformSchema(dataset.schema, paramMap, logging = true)
     val map = this.paramMap ++ paramMap
-    val scoreFunction: Vector => Double = (v) => {
+    val scoreFunction = udf((v: Vector) => {
       val margin = BLAS.dot(v, weights)
       1.0 / (1.0 + math.exp(-margin))
-    }
+    } : Double)
     val t = map(threshold)
-    val predictFunction: Double => Double = (score) => { if (score > t) 1.0 else 0.0 }
+    val predictFunction = udf((score: Double) => { if (score > t) 1.0 else 0.0 } : Double)
     dataset
-      .select($"*", callUDF(scoreFunction, col(map(featuresCol))).as(map(scoreCol)))
-      .select($"*", callUDF(predictFunction, col(map(scoreCol))).as(map(predictionCol)))
+      .select($"*", scoreFunction(col(map(featuresCol))).as(map(scoreCol)))
+      .select($"*", predictFunction(col(map(scoreCol))).as(map(predictionCol)))
   }
 }
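The change above swaps the explicit `Function1`-plus-`callUDF` pattern for `udf(...)` values that are applied directly to columns. A minimal sketch of the pattern outside this patch — the DataFrame `df` and its `age` column are invented for illustration, and the `org.apache.spark.sql.Dsl._` import is assumed to be what provided `udf`, `col`, and `$"..."` in this era of Spark:

    import org.apache.spark.sql.Dsl._  // udf, col, $"..." helpers (assumed)

    // Wrap a plain Scala function as a UDF once...
    val plusTen = udf((x: Double) => x + 10.0)

    // ...then apply it like any other column expression, instead of
    // defining a Function1 and writing callUDF(f, col("age")):
    val out = df.select($"*", plusTen(col("age")).as("agePlusTen"))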
@@ -81,10 +81,8 @@ class StandardScalerModel private[ml] (
   override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
     transformSchema(dataset.schema, paramMap, logging = true)
     val map = this.paramMap ++ paramMap
-    val scale: (Vector) => Vector = (v) => {
-      scaler.transform(v)
-    }
-    dataset.select($"*", callUDF(scale, col(map(inputCol))).as(map(outputCol)))
+    val scale = udf((v: Vector) => { scaler.transform(v) } : Vector)
+    dataset.select($"*", scale(col(map(inputCol))).as(map(outputCol)))
   }

   private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
@@ -126,22 +126,20 @@ class ALSModel private[ml] (
     val map = this.paramMap ++ paramMap
     val users = userFactors.toDataFrame("id", "features")
     val items = itemFactors.toDataFrame("id", "features")
-    val predict: (Seq[Float], Seq[Float]) => Float = (userFeatures, itemFeatures) => {
+
+    // Register a UDF for DataFrame, and then
+    // create a new column named map(predictionCol) by running the predict UDF.
+    val predict = udf((userFeatures: Seq[Float], itemFeatures: Seq[Float]) => {
       if (userFeatures != null && itemFeatures != null) {
         blas.sdot(k, userFeatures.toArray, 1, itemFeatures.toArray, 1)
       } else {
         Float.NaN
       }
-    }
-    val inputColumns = dataset.schema.fieldNames
-    val prediction = callUDF(predict, users("features"), items("features")).as(map(predictionCol))
-    val outputColumns = inputColumns.map(f => dataset(f)) :+ prediction
+    } : Float)
     dataset
       .join(users, dataset(map(userCol)) === users("id"), "left")
       .join(items, dataset(map(itemCol)) === items("id"), "left")
-      .select(outputColumns: _*)
-      // TODO: Just use a dataset("*")
-      // .select(dataset("*"), prediction)
+      .select(dataset("*"), predict(users("features"), items("features")).as(map(predictionCol)))
   }

   override private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
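The rewritten select also resolves the old TODO: `dataset("*")` expands to all columns of that particular DataFrame even after the joins, removing the manual `inputColumns`/`outputColumns` bookkeeping. A before/after sketch under that reading (`joined` and `prediction` are shorthand for the expressions in the diff, not real names):

    // Before: enumerate every input column by hand, then append the prediction.
    val outputColumns = dataset.schema.fieldNames.map(f => dataset(f)) :+ prediction
    joined.select(outputColumns: _*)

    // After: let the star, scoped to `dataset`, do the enumeration.
    joined.select(dataset("*"), prediction)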
5 changes: 2 additions & 3 deletions python/pyspark/sql.py
@@ -2126,10 +2126,9 @@ def sort(self, *cols):
         """
         if not cols:
             raise ValueError("should sort by at least one column")
-        jcols = ListConverter().convert([_to_java_column(c) for c in cols[1:]],
+        jcols = ListConverter().convert([_to_java_column(c) for c in cols],
                                         self._sc._gateway._gateway_client)
-        jdf = self._jdf.sort(_to_java_column(cols[0]),
-                             self._sc._jvm.Dsl.toColumns(jcols))
+        jdf = self._jdf.sort(self._sc._jvm.Dsl.toColumns(jcols))
         return DataFrame(jdf, self.sql_ctx)

     sortBy = sort

Review comment from the Contributor Author on this hunk: @davies take a look at the Python changes.
@@ -36,6 +36,16 @@ import org.apache.spark.sql.types._
  * for a SQL like language should checkout the HiveQL support in the sql/hive sub-project.
  */
 class SqlParser extends AbstractSparkSQLParser {
+
+  def parseExpression(input: String): Expression = {
+    // Initialize the Keywords.
+    lexical.initialize(reservedWords)
+    phrase(expression)(new lexical.Scanner(input)) match {
+      case Success(plan, _) => plan
+      case failureOrError => sys.error(failureOrError.toString)
+    }
+  }
+
   // Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword`
   // properties via reflection the class in runtime for constructing the SqlLexical object
   protected val ABS = Keyword("ABS")
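Unlike the existing statement-level entry point, `parseExpression` parses a single expression string into an unresolved catalyst tree. A rough usage sketch — the expression string is an arbitrary example, and the imports match the ones this PR adds to DataFrameImpl:

    import org.apache.spark.sql.catalyst.SqlParser
    import org.apache.spark.sql.catalyst.expressions.Expression

    val parser = new SqlParser()
    // Produces an unresolved Expression; names like `colC` are resolved later,
    // when the expression is attached to a logical plan.
    val e: Expression = parser.parseExpression("abs(colC) + 1")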
27 changes: 23 additions & 4 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -54,10 +54,10 @@ private[sql] object DataFrame {
  * }}}
  *
  * Note that the [[Column]] type can also be manipulated through its various functions.
- * {{
+ * {{{
  * // The following creates a new column that increases everybody's age by 10.
  * people("age") + 10 // in Scala
- * }}
+ * }}}
  *
  * A more concrete example:
  * {{{
@@ -173,7 +173,7 @@ trait DataFrame extends RDDApi[Row] {
   * }}}
   */
  @scala.annotation.varargs
-  def sort(sortExpr: Column, sortExprs: Column*): DataFrame
+  def sort(sortExprs: Column*): DataFrame

  /**
   * Returns a new [[DataFrame]] sorted by the given expressions.
@@ -187,7 +187,7 @@ trait DataFrame extends RDDApi[Row] {
   * This is an alias of the `sort` function.
   */
  @scala.annotation.varargs
-  def orderBy(sortExpr: Column, sortExprs: Column*): DataFrame
+  def orderBy(sortExprs: Column*): DataFrame

  /**
   * Selects column based on the column name and return it as a [[Column]].
@@ -236,6 +236,17 @@
  @scala.annotation.varargs
  def select(col: String, cols: String*): DataFrame

+  /**
+   * Selects a set of SQL expressions. This is a variant of `select` that accepts
+   * SQL expressions.
+   *
+   * {{{
+   *   df.selectExpr("colA", "colB as newName", "abs(colC)")
+   * }}}
+   */
+  @scala.annotation.varargs
+  def selectExpr(exprs: String*): DataFrame
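As a rough illustration of how the new variant relates to plain `select` (the DataFrame and column names are assumptions), the two calls below should be equivalent — `selectExpr` simply runs each string through `SqlParser.parseExpression` and wraps the result in a `Column`:

    df.select(df("colA"), df("colB").as("newName"), df("colC") + 1)
    df.selectExpr("colA", "colB as newName", "colC + 1")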

  /**
   * Filters rows using the given condition.
   * {{{
@@ -247,6 +258,14 @@
   */
  def filter(condition: Column): DataFrame

+  /**
+   * Filters rows using the given SQL expression.
+   * {{{
+   *   peopleDf.filter("age > 15")
+   * }}}
+   */
+  def filter(conditionExpr: String): DataFrame

  /**
   * Filters rows using the given condition. This is an alias for `filter`.
   * {{{
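The string overload is sugar over the Column overload; both should keep the same rows. Reusing the doc comment's `peopleDf` and `age` names:

    peopleDf.filter(peopleDf("age") > 15)  // Column-based predicate
    peopleDf.filter("age > 15")            // same predicate, parsed from SQL text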
24 changes: 17 additions & 7 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
@@ -27,7 +27,7 @@ import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.api.python.SerDeUtil
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
-import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection}
 import org.apache.spark.sql.catalyst.analysis.{ResolvedStar, UnresolvedRelation}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.{JoinType, Inner}
Expand Down Expand Up @@ -124,11 +124,11 @@ private[sql] class DataFrameImpl protected[sql](
}

override def sort(sortCol: String, sortCols: String*): DataFrame = {
orderBy(apply(sortCol), sortCols.map(apply) :_*)
sort((sortCol +: sortCols).map(apply) :_*)
}

override def sort(sortExpr: Column, sortExprs: Column*): DataFrame = {
val sortOrder: Seq[SortOrder] = (sortExpr +: sortExprs).map { col =>
override def sort(sortExprs: Column*): DataFrame = {
val sortOrder: Seq[SortOrder] = sortExprs.map { col =>
col.expr match {
case expr: SortOrder =>
expr
Expand All @@ -143,8 +143,8 @@ private[sql] class DataFrameImpl protected[sql](
sort(sortCol, sortCols :_*)
}

override def orderBy(sortExpr: Column, sortExprs: Column*): DataFrame = {
sort(sortExpr, sortExprs :_*)
override def orderBy(sortExprs: Column*): DataFrame = {
sort(sortExprs :_*)
}

override def col(colName: String): Column = colName match {
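With the leading `Column` parameter folded into the varargs, and the `String` overload now delegating to the `Column` one, every call site funnels into a single implementation. A usage sketch — `df` and its columns are invented, and `$`/`col` assume the `Dsl._` import:

    df.sort($"age".desc)               // one sort expression
    df.sort($"age".desc, $"name".asc)  // several
    df.orderBy(col("age"))             // orderBy is an alias of sort
    // Note: df.sort() with no arguments now type-checks on the Scala side;
    // the Python wrapper above guards the empty case with a ValueError.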
@@ -179,10 +179,20 @@
     select((col +: cols).map(Column(_)) :_*)
   }

+  override def selectExpr(exprs: String*): DataFrame = {
+    select(exprs.map { expr =>
+      Column(new SqlParser().parseExpression(expr))
+    } :_*)
+  }
+
   override def filter(condition: Column): DataFrame = {
     Filter(condition.expr, logicalPlan)
   }

+  override def filter(conditionExpr: String): DataFrame = {
+    filter(Column(new SqlParser().parseExpression(conditionExpr)))
+  }
+
   override def where(condition: Column): DataFrame = {
     filter(condition)
   }

Review thread on selectExpr:

Contributor: I think this one could be merged into select(); a column is also a valid expression.

Contributor Author: Not if it has a space ... it will just fail.

Contributor: It should work in these cases with this implementation: select('*', 'a', '`the name`', 'a + 1', 'min(b) * 3')

Contributor Author: Yea - but asking users to wrap a column name in backticks in strings is fairly annoying.
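A sketch of the disagreement in the thread above, taking both claims at face value (hypothetical DataFrame; the behavior is as asserted by the commenters, not verified here):

    // If select(String*) parsed its arguments as expressions, a column whose
    // name contains a space could no longer be passed as a bare string:
    df.select("the name")    // would fail to parse as an expression

    // The parser does accept backtick-quoted identifiers, so this works...
    df.selectExpr("`the name`", "a + 1", "min(b) * 3")

    // ...but forcing backticks onto plain column lookups is the ergonomic
    // cost that keeping select and selectExpr separate avoids.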
@@ -329,7 +339,7 @@ private[sql] class DataFrameImpl protected[sql](

   override def save(path: String): Unit = {
     val dataSourceName = sqlContext.conf.defaultDataSourceName
-    save(dataSourceName, ("path" -> path))
+    save(dataSourceName, "path" -> path)
   }

   override def save(