[SPARK-13306] [SQL] uncorrelated scalar subquery #11190

Closed · wants to merge 9 commits
@@ -205,6 +205,8 @@ atomExpression
| whenExpression
| (functionName LPAREN) => function
| tableOrColumn
| (LPAREN KW_SELECT) => subQueryExpression
-> ^(TOK_SUBQUERY_EXPR ^(TOK_SUBQUERY_OP) subQueryExpression)
| LPAREN! expression RPAREN!
;
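
The (LPAREN KW_SELECT) syntactic predicate is what disambiguates a parenthesized subquery from an ordinary parenthesized expression, since both begin with LPAREN. A minimal sketch of the two paths through this rule (parser class taken from the test suite below; table names are invented):

    import org.apache.spark.sql.catalyst.CatalystQl

    val parser = new CatalystQl()
    parser.parsePlan("select (select max(b) from s) from t")  // subquery expression
    parser.parsePlan("select (a + b) from t")                 // plain parenthesized expression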
@@ -667,6 +667,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
UnresolvedAttribute(nameParts :+ cleanIdentifier(attr))
case other => UnresolvedExtractValue(other, Literal(cleanIdentifier(attr)))
}
case Token("TOK_SUBQUERY_EXPR", Token("TOK_SUBQUERY_OP", Nil) :: subquery :: Nil) =>
ScalarSubquery(nodeToPlan(subquery))
Contributor: This might sound exceedingly dumb, but I cannot find ScalarSubquery or SubqueryExpression. Are they already in the code base? Or did you create this branch on top of another branch?

Contributor: Never mind, I just found the other PR...

Contributor (author): I missed a file, sorry.


/* Stars (*) */
case Token("TOK_ALLCOLREF", Nil) => UnresolvedStar(None)
@@ -80,6 +80,7 @@ class Analyzer(
ResolveGenerate ::
ResolveFunctions ::
ResolveAliases ::
ResolveSubquery ::
ResolveWindowOrder ::
ResolveWindowFrame ::
ResolveNaturalJoin ::
@@ -120,7 +121,14 @@
withAlias.getOrElse(relation)
}
substituted.getOrElse(u)
case other =>
Contributor: Quick comment on why this isn't in ResolveSubquery?

Contributor (author): done

// This can't be done in ResolveSubquery because that rule does not know the CTE relations.
other transformExpressions {
case e: SubqueryExpression =>
e.withNewPlan(substituteCTE(e.query, cteRelations))
}
}

}
}
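
A hedged illustration of why the substitution has to happen here (sqlContext and the tables s and t are assumed to exist): in the query below, the CTE w is only referenced inside the scalar subquery, so substituteCTE must also be applied to e.query.

    sqlContext.sql(
      "with w as (select b from s) select (select max(b) from w) from t")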

@@ -693,6 +701,30 @@
}
}

/**
* This rule resolves subqueries inside expressions.
*
* Note: CTEs are handled in CTESubstitution.
*/
object ResolveSubquery extends Rule[LogicalPlan] with PredicateHelper {

private def hasSubquery(e: Expression): Boolean = {
e.find(_.isInstanceOf[SubqueryExpression]).isDefined
}

private def hasSubquery(q: LogicalPlan): Boolean = {
q.expressions.exists(hasSubquery)
}

def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case q: LogicalPlan if q.childrenResolved && hasSubquery(q) =>
q transformExpressions {
case e: SubqueryExpression if !e.query.resolved =>
e.withNewPlan(execute(e.query))
}
}
}
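
Since the rule re-enters the analyzer's own execute on each subquery plan, a subquery is analyzed with the full rule stack before the enclosing query continues resolving. A sketch of the observable effect (assuming a SQLContext with tables t and s registered):

    val df = sqlContext.sql("select (select max(b) from s) + 1 from t")
    // Once ResolveSubquery has run, the inner plan is resolved, so the
    // ScalarSubquery's dataType (the type of max(b)) is available and the
    // outer "+ 1" can be type-checked.
    df.explain(true)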

/**
* Turns projections that contain aggregate expressions into aggregations.
*/
@@ -0,0 +1,82 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Subquery}
import org.apache.spark.sql.types.DataType

/**
* An interface for subqueries that are used in expressions.
*/
abstract class SubqueryExpression extends LeafExpression {

/**
* The logical plan of the query.
*/
def query: LogicalPlan
Contributor: why is this needed?

Contributor (author): This is a helper function used in the Analyzer and Optimizer; without it we would need to do type conversion.

Contributor (author): This is the base class for both the logical plan and the physical plan, which is kind of weird. This is to make generateTreeString work in QueryPlan.

Contributor: The Analyzer and Optimizer only apply to logical plans, right?

Contributor (author): yes


/**
* Either a logical plan or a physical plan. The generated tree string (explain output) uses this
* field to explain the subquery.
*/
def plan: QueryPlan[_]

/**
* Updates the query with a new logical plan.
*/
def withNewPlan(plan: LogicalPlan): SubqueryExpression
Contributor: Scala doc?

Contributor: can't this be just in the logical plan itself?

Contributor (author): This should be copy(), but I did not figure out how to make copy() work for the different kinds of SubqueryExpression.

Contributor: I think you can just remove this and move it into the logical subquery expression, since it's only used for logical plans anyway?

Contributor (author): Then should we have a LogicalSubqueryExpression?

Contributor: I meant ScalarSubquery. That's already the one, isn't it?

Contributor (author): We will have ExistsSubquery and InSubquery shortly (or in the next release).

}

/**
* A subquery that will return only one row and one column.
*
* This will be converted into [[execution.ScalarSubquery]] during physical planning.
*
* Note: `exprId` is used to generate a unique name in the explain string output.
*/
case class ScalarSubquery(
query: LogicalPlan,
exprId: ExprId = NamedExpression.newExprId)
extends SubqueryExpression with Unevaluable {

override def plan: LogicalPlan = Subquery(toString, query)

override lazy val resolved: Boolean = query.resolved

override def dataType: DataType = query.schema.fields.head.dataType

override def checkInputDataTypes(): TypeCheckResult = {
if (query.schema.length != 1) {
TypeCheckResult.TypeCheckFailure("Scalar subquery must return only one column, but got " +
query.schema.length.toString)
} else {
TypeCheckResult.TypeCheckSuccess
}
}

override def foldable: Boolean = false
override def nullable: Boolean = true

override def withNewPlan(plan: LogicalPlan): ScalarSubquery = ScalarSubquery(plan, exprId)

override def toString: String = s"subquery#${exprId.id}"

// TODO: support sql()
}
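
A minimal test-style sketch of the arity check (relation and column names are invented; resolved attributes are used so that query.schema is available):

    import org.apache.spark.sql.catalyst.expressions.AttributeReference
    import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
    import org.apache.spark.sql.types.IntegerType

    val a = AttributeReference("a", IntegerType)()
    val b = AttributeReference("b", IntegerType)()
    val rel = LocalRelation(a, b)
    ScalarSubquery(Project(Seq(a), rel)).checkInputDataTypes()    // TypeCheckSuccess
    ScalarSubquery(Project(Seq(a, b), rel)).checkInputDataTypes() // TypeCheckFailure: ... but got 2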
@@ -88,7 +88,19 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] {
Batch("Decimal Optimizations", FixedPoint(100),
DecimalAggregates) ::
Batch("LocalRelation", FixedPoint(100),
ConvertToLocalRelation) :: Nil
ConvertToLocalRelation) ::
Batch("Subquery", Once,
OptimizeSubqueries) :: Nil
}

/**
* Optimizes all the subqueries inside expressions.
*/
object OptimizeSubqueries extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
case subquery: SubqueryExpression =>
subquery.withNewPlan(Optimizer.this.execute(subquery.query))
}
}
}
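
Because OptimizeSubqueries re-enters the whole optimizer (Optimizer.this.execute) on each subquery plan, every optimizer rule also applies inside subqueries. An illustrative effect (sqlContext and tables assumed, output shape approximate):

    // Constant folding now reaches inside the subquery: the optimized
    // subquery plan shows a folded Project [2 AS ...].
    sqlContext.sql("select (select 1 + 1 from s) from t").explain(true)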

@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.plans

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.Subquery
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.types.{DataType, StructType}

@@ -226,4 +227,9 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
protected def statePrefix = if (missingInput.nonEmpty && children.nonEmpty) "!" else ""

override def simpleString: String = statePrefix + super.simpleString

override def treeChildren: Seq[PlanType] = {
val subqueries = expressions.flatMap(_.collect { case e: SubqueryExpression => e })
children ++ subqueries.map(e => e.plan.asInstanceOf[PlanType])
}
}
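
With subquery plans included in treeChildren, generateTreeString (below) renders each subquery as an extra child of the operator that holds it, so explain output looks roughly like this (shape and expression ids are illustrative):

    Project [subquery#1 AS ss#2]
    :- Relation t
    +- subquery#1
       +- Aggregate [max(b#3) AS max(b)#4]
          +- Relation s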
@@ -448,6 +448,11 @@
}
}

/**
* All the nodes that will be used to generate the tree string.
*/
protected def treeChildren: Seq[BaseType] = children

/**
* Appends the string representation of this node and its children to the given StringBuilder.
*
@@ -470,9 +475,9 @@
builder.append(simpleString)
builder.append("\n")

if (children.nonEmpty) {
children.init.foreach(_.generateTreeString(depth + 1, lastChildren :+ false, builder))
children.last.generateTreeString(depth + 1, lastChildren :+ true, builder)
if (treeChildren.nonEmpty) {
treeChildren.init.foreach(_.generateTreeString(depth + 1, lastChildren :+ false, builder))
treeChildren.last.generateTreeString(depth + 1, lastChildren :+ true, builder)
}

builder
@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types.BooleanType
import org.apache.spark.unsafe.types.CalendarInterval

class CatalystQlSuite extends PlanTest {
@@ -201,4 +202,10 @@ class CatalystQlSuite extends PlanTest {
parser.parsePlan("select sum(product + 1) over (partition by (product + (1)) order by 2) " +
"from windowData")
}

test("subquery") {
parser.parsePlan("select (select max(b) from s) ss from t")
Contributor: The only thing we are testing here is that things don't go really, really wrong. I'd prefer it if we tested the plan as well.

Contributor (author): Plan checks turned out to be too easy to break; I added tests for the plans but ultimately removed them.

Contributor: Ok, that makes sense.


parser.parsePlan("select * from t where a = (select b from s)")
parser.parsePlan("select * from t group by g having a > (select b from s)")
}
}
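
For reference, a plan assertion of the kind discussed above might look like this (a sketch only; it checks for the presence of a ScalarSubquery rather than comparing whole plans, which is what proved brittle):

    val parsed = parser.parsePlan("select (select max(b) from s) ss from t")
    assert(parsed.expressions.exists(
      _.find(_.isInstanceOf[ScalarSubquery]).isDefined))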
@@ -113,6 +113,17 @@ class AnalysisErrorSuite extends AnalysisTest {

val dateLit = Literal.create(null, DateType)

errorTest(
"scalar subquery with 2 columns",
testRelation.select(
(ScalarSubquery(testRelation.select('a, dateLit.as('b))) + Literal(1)).as('a)),
"Scalar subquery must return only one column, but got 2" :: Nil)

errorTest(
"scalar subquery with no column",
testRelation.select(ScalarSubquery(LocalRelation()).as('a)),
"Scalar subquery must return only one column, but got 0" :: Nil)

errorTest(
"single invalid type, single arg",
testRelation.select(TestFunction(dateLit :: Nil, IntegerType :: Nil).as('a)),
@@ -884,6 +884,7 @@ class SQLContext private[sql](
@transient
protected[sql] val prepareForExecution = new RuleExecutor[SparkPlan] {
val batches = Seq(
Batch("Subquery", Once, PlanSubqueries(self)),
Batch("Add exchange", Once, EnsureRequirements(self)),
Batch("Whole stage codegen", Once, CollapseCodegenStages(self))
)
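
PlanSubqueries itself is not part of this excerpt. A plausible sketch of what it does (the planner entry point and the execution-side ScalarSubquery constructor are assumptions, not the PR's literal code): it plans each logical ScalarSubquery's query into a SparkPlan and swaps in the executable expression that prepareSubqueries() later collects.

    // Hypothetical sketch, names assumed.
    case class PlanSubqueries(sqlContext: SQLContext) extends Rule[SparkPlan] {
      def apply(plan: SparkPlan): SparkPlan = plan transformAllExpressions {
        case subquery: expressions.ScalarSubquery =>
          val executedPlan = sqlContext.planner.plan(subquery.query).next()
          execution.ScalarSubquery(executedPlan, subquery.exprId)
      }
    }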
@@ -20,6 +20,8 @@ package org.apache.spark.sql.execution
import java.util.concurrent.atomic.AtomicBoolean

import scala.collection.mutable.ArrayBuffer
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration._

import org.apache.spark.Logging
import org.apache.spark.rdd.{RDD, RDDOperationScope}
@@ -31,6 +33,7 @@ import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetric}
import org.apache.spark.sql.types.DataType
import org.apache.spark.util.ThreadUtils

/**
* The base class for physical operators.
@@ -112,16 +115,58 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
final def execute(): RDD[InternalRow] = {
RDDOperationScope.withScope(sparkContext, nodeName, false, true) {
prepare()
waitForSubqueries()
doExecute()
}
}

// All the subqueries and their Future of results.
@transient private val queryResults = ArrayBuffer[(ScalarSubquery, Future[Array[InternalRow]])]()

/**
* Collects all the subqueries and creates a Future to take the first two rows of each.
*/
protected def prepareSubqueries(): Unit = {
val allSubqueries = expressions.flatMap(_.collect { case e: ScalarSubquery => e })
Contributor: We could move this into QueryPlan; see my previous comment in QueryPlan.

Contributor (author): It is slightly different here, so I'd rather duplicate it.


allSubqueries.asInstanceOf[Seq[ScalarSubquery]].foreach { e =>
val futureResult = Future {
// We only need the first row, but take two rows so we can throw an exception if
// more than one row is returned.
e.executedPlan.executeTake(2)
}(SparkPlan.subqueryExecutionContext)
queryResults += e -> futureResult
}
}

/**
* Waits for all the subqueries to finish and updates the results.
*/
protected def waitForSubqueries(): Unit = {
// Fill in the results of the subqueries.
queryResults.foreach {
case (e, futureResult) =>
val rows = Await.result(futureResult, Duration.Inf)
if (rows.length > 1) {
sys.error(s"more than one row returned by a subquery used as an expression:\n${e.plan}")
}
if (rows.length == 1) {
assert(rows(0).numFields == 1, "Analyzer should make sure this only returns one column")
Contributor: The analyzer checks this, right?

Contributor: Never mind.


e.updateResult(rows(0).get(0, e.dataType))
Contributor: Why don't we replace the ScalarSubqueries with Literals in the expression tree? That way we don't need state in ScalarSubquery, and it makes codegen easier...

Contributor (author): The ScalarSubqueries could be class members of the current plan, and those fields could be immutable, so we could not replace them.


} else {
// No rows were returned, so the result should be null.
e.updateResult(null)
}
}
queryResults.clear()
}
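
To summarize the lifecycle encoded above: prepare() fires every subquery's executeTake(2) on a Future inside prepareSubqueries(), and execute() (or produce(), for whole-stage codegen) later blocks in waitForSubqueries() and pushes each single value into its ScalarSubquery. A self-contained sketch of the same fire-early/await-late pattern, with toy values in place of Spark types:

    import scala.concurrent.{Await, Future}
    import scala.concurrent.ExecutionContext.Implicits.global
    import scala.concurrent.duration.Duration

    // Stand-in for executeTake(2): at most two "rows" per subquery.
    val pending = Seq("subquery#1" -> Future { Array[Any](42) })
    pending.foreach { case (name, f) =>
      val rows = Await.result(f, Duration.Inf)
      if (rows.length > 1) sys.error(s"more than one row returned by $name")
      val result = rows.headOption.orNull // an empty subquery yields NULL
      println(s"$name -> $result")
    }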

/**
* Prepare a SparkPlan for execution. It's idempotent.
*/
final def prepare(): Unit = {
if (prepareCalled.compareAndSet(false, true)) {
doPrepare()
prepareSubqueries()
children.foreach(_.prepare())
}
}
@@ -231,6 +276,11 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
}
}

object SparkPlan {
private[execution] val subqueryExecutionContext = ExecutionContext.fromExecutorService(
Contributor: What threadpool are broadcasts done on? Should it be the same?

Contributor (author): This could be refactored later to use the same thread pool for all of them.

Contributor: BroadcastHashJoin defines a ThreadPool for broadcasting. I am moving that into exchange.scala as part of #11083. We could use that one.

ThreadUtils.newDaemonCachedThreadPool("subquery", 16))
}

private[sql] trait LeafNode extends SparkPlan {
override def children: Seq[SparkPlan] = Nil
override def producedAttributes: AttributeSet = outputSet
@@ -73,9 +73,10 @@ trait CodegenSupport extends SparkPlan {
/**
* Returns Java source code to process the rows from upstream.
*/
def produce(ctx: CodegenContext, parent: CodegenSupport): String = {
final def produce(ctx: CodegenContext, parent: CodegenSupport): String = {
this.parent = parent
ctx.freshNamePrefix = variablePrefix
waitForSubqueries()
Contributor: why is this needed? Shouldn't SparkPlan.execute already call waitForSubqueries?

Contributor (author): This is needed for whole-stage codegen; those operators will not call execute().

Contributor: ok, got it. This is fairly hacky...

doProduce(ctx)
}

@@ -101,7 +102,7 @@
/**
* Consume the columns generated from current SparkPlan, call it's parent.
*/
def consume(ctx: CodegenContext, input: Seq[ExprCode], row: String = null): String = {
final def consume(ctx: CodegenContext, input: Seq[ExprCode], row: String = null): String = {
if (input != null) {
assert(input.length == output.length)
}
@@ -343,3 +343,18 @@ case class OutputFaker(output: Seq[Attribute], child: SparkPlan) extends SparkPl

protected override def doExecute(): RDD[InternalRow] = child.execute()
}

/**
* A plan used as a subquery.
*
* This is used to generate the tree string for SparkScalarSubquery.
*/
case class Subquery(name: String, child: SparkPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output
override def outputPartitioning: Partitioning = child.outputPartitioning
override def outputOrdering: Seq[SortOrder] = child.outputOrdering

protected override def doExecute(): RDD[InternalRow] = {
throw new UnsupportedOperationException
}
}
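
The execution-side scalar subquery expression used in SparkPlan above (the one with executedPlan and updateResult) is not shown in this diff; presumably its plan method wraps executedPlan in this Subquery node so the physical explain names the subtree. A hypothetical sketch, with all names and signatures assumed rather than taken from the PR:

    // Hypothetical sketch of the execution-side expression.
    case class ScalarSubquery(
        executedPlan: SparkPlan,
        exprId: ExprId)
      extends SubqueryExpression {
      override def query: LogicalPlan = throw new UnsupportedOperationException
      override def withNewPlan(plan: LogicalPlan): SubqueryExpression =
        throw new UnsupportedOperationException
      override def plan: SparkPlan = Subquery(toString, executedPlan)
      override def dataType: DataType = executedPlan.schema.fields.head.dataType
      override def nullable: Boolean = true
      override def toString: String = s"subquery#${exprId.id}"

      // Filled in by waitForSubqueries() before evaluation.
      private var result: Any = null
      def updateResult(v: Any): Unit = { result = v }
      override def eval(input: InternalRow): Any = result
    }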