[SPARK-8003][SQL] Added virtual column support to Spark
Added virtual column support by adding a new resolution rule to the query analyzer. Additional virtual columns can be added by adding case expressions to [the new rule](https://github.com/JDrit/spark/blob/virt_columns/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala#L1026) and by modifying the [logical plan](https://github.com/JDrit/spark/blob/virt_columns/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala#L216) to resolve them.
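As a rough sketch of the registration pattern this change uses (only `spark__partition__id` is registered here; the commented-out `MyVirtualColumn` entry is hypothetical and only marks where an additional virtual-column expression would go):

```scala
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.{expression => FunctionExpression, FunctionBuilder}
import org.apache.spark.sql.catalyst.expressions.ExpressionInfo
import org.apache.spark.sql.execution.expressions.SparkPartitionID

// Each virtual column is an Expression registered under a SQL-callable name,
// mirroring the functionRegistry override in SQLContext below.
val reg = FunctionRegistry.builtin
val extendedFunctions = List[(String, (ExpressionInfo, FunctionBuilder))](
  FunctionExpression[SparkPartitionID]("spark__partition__id")
  // FunctionExpression[MyVirtualColumn]("my__virtual__column")  // hypothetical extra column
)
extendedFunctions.foreach { case (name, (info, builder)) =>
  reg.registerFunction(name, info, builder)
}
```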

This also solves [SPARK-8003](https://issues.apache.org/jira/browse/SPARK-8003)

This allows you to perform queries such as:
```sql
select spark__partition__id, count(*) as c from table group by spark__partition__id;
```
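The same value is also available through the DataFrame API via the `sparkPartitionId()` function touched by this change; a minimal usage sketch, assuming an existing `sqlContext` and a registered table named `table`:

```scala
import org.apache.spark.sql.functions.sparkPartitionId

// Count rows per partition, equivalent to the SQL query above.
val df = sqlContext.table("table")
df.groupBy(sparkPartitionId()).count().show()
```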

Author: Joseph Batchik <josephbatchik@gmail.com>
Author: JD <jd@csh.rit.edu>

Closes #7478 from JDrit/virt_columns and squashes the following commits:

7932bf0 [Joseph Batchik] adding spark__partition__id to hive as well
f8a9c6c [Joseph Batchik] merging in master
e49da48 [JD] fixes for @rxin's suggestions
60e120b [JD] fixing test in merge
4bf8554 [JD] merging in master
c68bc0f [Joseph Batchik] Adding function register ability to SQLContext and adding a function for spark__partition__id()
JDrit authored and rxin committed Jul 28, 2015
1 parent 8d5bb52 commit b88b868
Showing 8 changed files with 40 additions and 8 deletions.
@@ -239,7 +239,7 @@ object FunctionRegistry {
}

/** See usage above. */
- private def expression[T <: Expression](name: String)
+ def expression[T <: Expression](name: String)
(implicit tag: ClassTag[T]): (String, (ExpressionInfo, FunctionBuilder)) = {

// See if we can find a constructor that accepts Seq[Expression]
11 changes: 10 additions & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -31,6 +31,8 @@ import org.apache.spark.SparkContext
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.rdd.RDD
+ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.{expression => FunctionExpression, FunctionBuilder}
+ import org.apache.spark.sql.execution.expressions.SparkPartitionID
import org.apache.spark.sql.SQLConf.SQLConfEntry
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.errors.DialectException
@@ -140,7 +142,14 @@ class SQLContext(@transient val sparkContext: SparkContext)

// TODO how to handle the temp function per user session?
@transient
- protected[sql] lazy val functionRegistry: FunctionRegistry = FunctionRegistry.builtin
+ protected[sql] lazy val functionRegistry: FunctionRegistry = {
+   val reg = FunctionRegistry.builtin
+   val extendedFunctions = List[(String, (ExpressionInfo, FunctionBuilder))](
+     FunctionExpression[SparkPartitionID]("spark__partition__id")
+   )
+   extendedFunctions.foreach { case(name, (info, fun)) => reg.registerFunction(name, info, fun) }
+   reg
+ }

@transient
protected[sql] lazy val analyzer: Analyzer =
@@ -27,7 +27,7 @@ import org.apache.spark.sql.types.{IntegerType, DataType}
/**
* Expression that returns the current partition id of the Spark task.
*/
- private[sql] case object SparkPartitionID extends LeafExpression with Nondeterministic {
+ private[sql] case class SparkPartitionID() extends LeafExpression with Nondeterministic {

override def nullable: Boolean = false

@@ -741,7 +741,7 @@ object functions {
* @group normal_funcs
* @since 1.4.0
*/
- def sparkPartitionId(): Column = execution.expressions.SparkPartitionID
+ def sparkPartitionId(): Column = execution.expressions.SparkPartitionID()

/**
* Computes the square root of the specified float value.
7 changes: 7 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -51,6 +51,13 @@ class UDFSuite extends QueryTest {
df.selectExpr("count(distinct a)")
}

test("SPARK-8003 spark__partition__id") {
val df = Seq((1, "Tearing down the walls that divide us")).toDF("id", "saying")
df.registerTempTable("tmp_table")
checkAnswer(ctx.sql("select spark__partition__id() from tmp_table").toDF(), Row(0))
ctx.dropTempTable("tmp_table")
}

test("error reporting for incorrect number of arguments") {
val df = ctx.emptyDataFrame
val e = intercept[AnalysisException] {
@@ -27,6 +27,6 @@ class NondeterministicSuite extends SparkFunSuite with ExpressionEvalHelper {
}

test("SparkPartitionID") {
- checkEvaluation(SparkPartitionID, 0)
+ checkEvaluation(SparkPartitionID(), 0)
}
}
@@ -38,6 +38,9 @@ import org.apache.spark.Logging
import org.apache.spark.SparkContext
import org.apache.spark.annotation.Experimental
import org.apache.spark.sql._
+ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.{expression => FunctionExpression, FunctionBuilder}
+ import org.apache.spark.sql.catalyst.expressions.ExpressionInfo
+ import org.apache.spark.sql.execution.expressions.SparkPartitionID
import org.apache.spark.sql.SQLConf.SQLConfEntry
import org.apache.spark.sql.SQLConf.SQLConfEntry._
import org.apache.spark.sql.catalyst.{TableIdentifier, ParserDialect}
@@ -372,8 +375,14 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging {

// Note that HiveUDFs will be overridden by functions registered in this context.
@transient
- override protected[sql] lazy val functionRegistry: FunctionRegistry =
-   new HiveFunctionRegistry(FunctionRegistry.builtin)
+ override protected[sql] lazy val functionRegistry: FunctionRegistry = {
+   val reg = new HiveFunctionRegistry(FunctionRegistry.builtin)
+   val extendedFunctions = List[(String, (ExpressionInfo, FunctionBuilder))](
+     FunctionExpression[SparkPartitionID]("spark__partition__id")
+   )
+   extendedFunctions.foreach { case(name, (info, fun)) => reg.registerFunction(name, info, fun) }
+   reg
+ }

/* An analyzer that uses the Hive metastore. */
@transient
@@ -17,13 +17,14 @@

package org.apache.spark.sql.hive

- import org.apache.spark.sql.QueryTest
+ import org.apache.spark.sql.{Row, QueryTest}

case class FunctionResult(f1: String, f2: String)

class UDFSuite extends QueryTest {

private lazy val ctx = org.apache.spark.sql.hive.test.TestHive
+ import ctx.implicits._

test("UDF case insensitive") {
ctx.udf.register("random0", () => { Math.random() })
@@ -33,4 +34,10 @@ class UDFSuite extends QueryTest {
assert(ctx.sql("SELECT RANDOm1() FROM src LIMIT 1").head().getDouble(0) >= 0.0)
assert(ctx.sql("SELECT strlenscala('test', 1) FROM src LIMIT 1").head().getInt(0) === 5)
}

test("SPARK-8003 spark__partition__id") {
val df = Seq((1, "Two Fiiiiive")).toDF("id", "saying")
ctx.registerDataFrameAsTable(df, "test_table")
checkAnswer(ctx.sql("select spark__partition__id() from test_table LIMIT 1").toDF(), Row(0))
}
}
