Skip to content

Commit

Permalink
[SPARK-5817] [SQL] Fix bug of udtf with column names
Browse files Browse the repository at this point in the history
The bug occurs when running a query like:
```sql
select d from (select explode(array(1,1)) d from src limit 1) t
```
It throws an exception like:
```
org.apache.spark.sql.AnalysisException: cannot resolve 'd' given input columns _c0; line 1 pos 7
at org.apache.spark.sql.catalyst.analysis.package$AnalysisErrorAt.failAnalysis(package.scala:42)
at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$apply$3$$anonfun$apply$1.applyOrElse(CheckAnalysis.scala:48)
at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$apply$3$$anonfun$apply$1.applyOrElse(CheckAnalysis.scala:45)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:250)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:250)
at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:50)
at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:249)
at org.apache.spark.sql.catalyst.plans.QueryPlan.org$apache$spark$sql$catalyst$plans$QueryPlan$$transformExpressionUp$1(QueryPlan.scala:103)
at org.apache.spark.sql.catalyst.plans.QueryPlan$$anonfun$2$$anonfun$apply$2.apply(QueryPlan.scala:117)
at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:244)
at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:244)
at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:47)
at scala.collection.TraversableLike$class.map(TraversableLike.scala:244)
at scala.collection.AbstractTraversable.map(Traversable.scala:105)
at org.apache.spark.sql.catalyst.plans.QueryPlan$$anonfun$2.apply(QueryPlan.scala:116)
at scala.collection.Iterator$$anon$11.next(Iterator.scala:328)
```

Fixing the bug requires refactoring the UDTF code.
The major changes are:
* Simplifying UDTF development: a UDTF no longer manages its output attribute names; instead, `logical.Generate` handles them.
* The UDTF is asked for its output schema (data types) during logical plan analysis.

Author: Cheng Hao <hao.cheng@intel.com>

Closes #4602 from chenghao-intel/explode_bug and squashes the following commits:

c2a5132 [Cheng Hao] add back resolved for Alias
556e982 [Cheng Hao] revert the unncessary change
002c361 [Cheng Hao] change the rule of resolved for Generate
04ae500 [Cheng Hao] add qualifier only for generator output
5ee5d2c [Cheng Hao] prepend the new qualifier
d2e8b43 [Cheng Hao] Update the code as feedback
ca5e7f4 [Cheng Hao] shrink the commits
  • Loading branch information
chenghao-intel authored and marmbrus committed Apr 21, 2015
1 parent 2a24bf9 commit 7662ec2
Show file tree
Hide file tree
Showing 26 changed files with 207 additions and 145 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.util.collection.OpenHashSet
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
Expand Down Expand Up @@ -59,6 +58,7 @@ class Analyzer(
ResolveReferences ::
ResolveGroupingAnalytics ::
ResolveSortReferences ::
ResolveGenerate ::
ImplicitGenerate ::
ResolveFunctions ::
GlobalAggregates ::
Expand Down Expand Up @@ -474,8 +474,59 @@ class Analyzer(
*/
object ImplicitGenerate extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case Project(Seq(Alias(g: Generator, _)), child) =>
Generate(g, join = false, outer = false, None, child)
case Project(Seq(Alias(g: Generator, name)), child) =>
Generate(g, join = false, outer = false,
qualifier = None, UnresolvedAttribute(name) :: Nil, child)
case Project(Seq(MultiAlias(g: Generator, names)), child) =>
Generate(g, join = false, outer = false,
qualifier = None, names.map(UnresolvedAttribute(_)), child)
}
}

/**
 * Fills in the output attributes of a [[Generate]] operator once its child and its
 * generator are resolved. Names supplied on the operator are kept; when none were
 * supplied, default names `_c0`, `_c1`, ... `_cN` are produced (the same rule Hive uses).
 */
object ResolveGenerate extends Rule[LogicalPlan] {

  /**
   * Builds the resolved output attributes for `generator`.
   *
   * @param generator       the generator whose `elementTypes` supply data type and nullability
   * @param generatorOutput the attributes currently attached to the `Generate` node; may be
   *                        empty (no names given) or contain unresolved attributes
   * @return one attribute per generator element type
   * @throws AnalysisException if a non-zero number of names was supplied and it differs
   *                           from the number of element types
   */
  private def makeGeneratorOutput(
      generator: Generator,
      generatorOutput: Seq[Attribute]): Seq[Attribute] = {
    val elementTypes = generator.elementTypes
    val expected = elementTypes.length
    val supplied = generatorOutput.length

    if (supplied == expected) {
      // Names were given for every column: keep already-resolved attributes untouched and
      // turn the unresolved ones into references carrying the generator's type information.
      generatorOutput.zip(elementTypes).map { case (attr, (dataType, nullable)) =>
        if (attr.resolved) attr else AttributeReference(attr.name, dataType, nullable)()
      }
    } else if (supplied == 0) {
      // No names given: fall back to the Hive-style defaults _c0, _c1, ... _cN.
      elementTypes.zipWithIndex.map { case ((dataType, nullable), idx) =>
        AttributeReference(s"_c$idx", dataType, nullable)()
      }
    } else {
      // The closing `|` makes the message end in a bare newline regardless of indentation.
      throw new AnalysisException(
        s"""
           |The number of aliases supplied in the AS clause does not match
           |the number of columns output by the UDTF expected
           |${expected} aliases but got ${supplied}
           |""".stripMargin)
    }
  }

  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    // Only act when both the child and the generator are resolved but the Generate node
    // itself is not; every other node is left as it is.
    case g: Generate if g.child.resolved && g.generator.resolved && !g.resolved =>
      val resolvedOutput = makeGeneratorOutput(g.generator, g.generatorOutput)
      Generate(
        g.generator,
        join = g.join,
        outer = g.outer,
        g.qualifier,
        resolvedOutput,
        g.child)
  }
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ trait CheckAnalysis {
throw new AnalysisException(msg)
}

/**
 * Returns true when `collect` finds a [[Generator]] in at least one of `exprs`.
 *
 * NOTE(review): despite the name, a single generator is enough for this to return true
 * (the threshold is "one or more", not "more than one"). Confirm against the caller
 * before tightening it.
 */
def containsMultipleGenerators(exprs: Seq[Expression]): Boolean =
  exprs.exists { expr =>
    expr.collect { case _: Generator => true }.nonEmpty
  }

def checkAnalysis(plan: LogicalPlan): Unit = {
// We transform up and order the rules so as to catch the first possible failure instead
// of the result of cascading resolution failures.
Expand Down Expand Up @@ -110,6 +116,12 @@ trait CheckAnalysis {
failAnalysis(
s"unresolved operator ${operator.simpleString}")

case p @ Project(exprs, _) if containsMultipleGenerators(exprs) =>
failAnalysis(
s"""Only a single table generating function is allowed in a SELECT clause, found:
| ${exprs.map(_.prettyString).mkString(",")}""".stripMargin)


case _ => // Analysis successful!
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -284,12 +284,13 @@ package object dsl {
seed: Int = (math.random * 1000).toInt): LogicalPlan =
Sample(fraction, withReplacement, seed, logicalPlan)

// TODO specify the output column names
def generate(
generator: Generator,
join: Boolean = false,
outer: Boolean = false,
alias: Option[String] = None): LogicalPlan =
Generate(generator, join, outer, None, logicalPlan)
Generate(generator, join = join, outer = outer, alias, Nil, logicalPlan)

def insertInto(tableName: String, overwrite: Boolean = false): LogicalPlan =
InsertIntoTable(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,47 +42,30 @@ abstract class Generator extends Expression {

override type EvaluatedType = TraversableOnce[Row]

override lazy val dataType =
ArrayType(StructType(output.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))))
// TODO ideally we should return the type of ArrayType(StructType),
// however, we don't keep the output field names in the Generator.
override def dataType: DataType = throw new UnsupportedOperationException

override def nullable: Boolean = false

/**
* Should be overridden by specific generators. Called only once for each instance to ensure
* that rule application does not change the output schema of a generator.
* The output element data types in structure of Seq[(DataType, Nullable)]
* TODO we probably need to add more information like metadata etc.
*/
protected def makeOutput(): Seq[Attribute]

private var _output: Seq[Attribute] = null

def output: Seq[Attribute] = {
if (_output == null) {
_output = makeOutput()
}
_output
}
def elementTypes: Seq[(DataType, Boolean)]

/** Should be implemented by child classes to perform specific Generators. */
override def eval(input: Row): TraversableOnce[Row]

/** Overridden `makeCopy` also copies the attributes that are produced by this generator. */
override def makeCopy(newArgs: Array[AnyRef]): this.type = {
val copy = super.makeCopy(newArgs)
copy._output = _output
copy
}
}

/**
* A generator that produces its output using the provided lambda function.
*/
case class UserDefinedGenerator(
schema: Seq[Attribute],
elementTypes: Seq[(DataType, Boolean)],
function: Row => TraversableOnce[Row],
children: Seq[Expression])
extends Generator{

override protected def makeOutput(): Seq[Attribute] = schema
extends Generator {

override def eval(input: Row): TraversableOnce[Row] = {
// TODO(davies): improve this
Expand All @@ -98,30 +81,18 @@ case class UserDefinedGenerator(
/**
* Given an input array produces a sequence of rows for each value in the array.
*/
case class Explode(attributeNames: Seq[String], child: Expression)
case class Explode(child: Expression)
extends Generator with trees.UnaryNode[Expression] {

override lazy val resolved =
child.resolved &&
(child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType])

private lazy val elementTypes = child.dataType match {
override def elementTypes: Seq[(DataType, Boolean)] = child.dataType match {
case ArrayType(et, containsNull) => (et, containsNull) :: Nil
case MapType(kt, vt, valueContainsNull) => (kt, false) :: (vt, valueContainsNull) :: Nil
}

// TODO: Move this pattern into Generator.
protected def makeOutput() =
if (attributeNames.size == elementTypes.size) {
attributeNames.zip(elementTypes).map {
case (n, (t, nullable)) => AttributeReference(n, t, nullable)()
}
} else {
elementTypes.zipWithIndex.map {
case ((t, nullable), i) => AttributeReference(s"c_$i", t, nullable)()
}
}

override def eval(input: Row): TraversableOnce[Row] = {
child.dataType match {
case ArrayType(_, _) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ case class Alias(child: Expression, name: String)(
extends NamedExpression with trees.UnaryNode[Expression] {

override type EvaluatedType = Any
// Alias(Generator, xx) needs to be transformed into Generate(generator, ...),
// so an Alias wrapping a Generator is never considered resolved.
override lazy val resolved = childrenResolved && !child.isInstanceOf[Generator]

override def eval(input: Row): Any = child.eval(input)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -482,16 +482,16 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] {
object PushPredicateThroughGenerate extends Rule[LogicalPlan] with PredicateHelper {

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case filter @ Filter(condition,
generate @ Generate(generator, join, outer, alias, grandChild)) =>
case filter @ Filter(condition, g: Generate) =>
// Predicates that reference attributes produced by the `Generate` operator cannot
// be pushed below the operator.
val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition {
conjunct => conjunct.references subsetOf grandChild.outputSet
conjunct => conjunct.references subsetOf g.child.outputSet
}
if (pushDown.nonEmpty) {
val pushDownPredicate = pushDown.reduce(And)
val withPushdown = generate.copy(child = Filter(pushDownPredicate, grandChild))
val withPushdown = Generate(g.generator, join = g.join, outer = g.outer,
g.qualifier, g.generatorOutput, Filter(pushDownPredicate, g.child))
stayUp.reduceOption(And).map(Filter(_, withPushdown)).getOrElse(withPushdown)
} else {
filter
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,34 +40,43 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend
* output of each into a new stream of rows. This operation is similar to a `flatMap` in functional
* programming with one important additional feature, which allows the input rows to be joined with
* their output.
* @param generator the generator expression
* @param join when true, each output row is implicitly joined with the input tuple that produced
* it.
* @param outer when true, each input row will be output at least once, even if the output of the
* given `generator` is empty. `outer` has no effect when `join` is false.
* @param alias when set, this string is applied to the schema of the output of the transformation
* as a qualifier.
* @param qualifier Qualifier for the attributes of generator(UDTF)
* @param generatorOutput The output schema of the Generator.
* @param child Children logical plan node
*/
case class Generate(
generator: Generator,
join: Boolean,
outer: Boolean,
alias: Option[String],
qualifier: Option[String],
generatorOutput: Seq[Attribute],
child: LogicalPlan)
extends UnaryNode {

protected def generatorOutput: Seq[Attribute] = {
val output = alias
.map(a => generator.output.map(_.withQualifiers(a :: Nil)))
.getOrElse(generator.output)
if (join && outer) {
output.map(_.withNullability(true))
} else {
output
}
override lazy val resolved: Boolean = {
generator.resolved &&
childrenResolved &&
generator.elementTypes.length == generatorOutput.length &&
!generatorOutput.exists(!_.resolved)
}

override def output: Seq[Attribute] =
if (join) child.output ++ generatorOutput else generatorOutput
// we don't want the gOutput to be taken as part of the expressions
// as that will cause exceptions like unresolved attributes etc.
override def expressions: Seq[Expression] = generator :: Nil

def output: Seq[Attribute] = {
val qualified = qualifier.map(q =>
// prepend the new qualifier to the existing ones
generatorOutput.map(a => a.withQualifiers(q +: a.qualifiers))
).getOrElse(generatorOutput)

if (join) child.output ++ qualified else qualified
}
}

case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {

assert(!Project(Seq(UnresolvedAttribute("a")), testRelation).resolved)

val explode = Explode(Nil, AttributeReference("a", IntegerType, nullable = true)())
val explode = Explode(AttributeReference("a", IntegerType, nullable = true)())
assert(!Project(Seq(Alias(explode, "explode")()), testRelation).resolved)

assert(!Project(Seq(Alias(Count(Literal(1)), "count")()), testRelation).resolved)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -454,21 +454,21 @@ class FilterPushdownSuite extends PlanTest {
test("generate: predicate referenced no generated column") {
val originalQuery = {
testRelationWithArrayType
.generate(Explode(Seq("c"), 'c_arr), true, false, Some("arr"))
.generate(Explode('c_arr), true, false, Some("arr"))
.where(('b >= 5) && ('a > 6))
}
val optimized = Optimize(originalQuery.analyze)
val correctAnswer = {
testRelationWithArrayType
.where(('b >= 5) && ('a > 6))
.generate(Explode(Seq("c"), 'c_arr), true, false, Some("arr")).analyze
.generate(Explode('c_arr), true, false, Some("arr")).analyze
}

comparePlans(optimized, correctAnswer)
}

test("generate: part of conjuncts referenced generated column") {
val generator = Explode(Seq("c"), 'c_arr)
val generator = Explode('c_arr)
val originalQuery = {
testRelationWithArrayType
.generate(generator, true, false, Some("arr"))
Expand Down Expand Up @@ -499,7 +499,7 @@ class FilterPushdownSuite extends PlanTest {
test("generate: all conjuncts referenced generated column") {
val originalQuery = {
testRelationWithArrayType
.generate(Explode(Seq("c"), 'c_arr), true, false, Some("arr"))
.generate(Explode('c_arr), true, false, Some("arr"))
.where(('c > 6) || ('b > 5)).analyze
}
val optimized = Optimize(originalQuery)
Expand Down
21 changes: 15 additions & 6 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ import org.apache.spark.api.python.SerDeUtil
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedRelation, ResolvedStar}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, ResolvedStar}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.{JoinType, Inner}
import org.apache.spark.sql.catalyst.plans.logical._
Expand Down Expand Up @@ -711,12 +711,16 @@ class DataFrame private[sql](
*/
def explode[A <: Product : TypeTag](input: Column*)(f: Row => TraversableOnce[A]): DataFrame = {
val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType]
val attributes = schema.toAttributes

val elementTypes = schema.toAttributes.map { attr => (attr.dataType, attr.nullable) }
val names = schema.toAttributes.map(_.name)

val rowFunction =
f.andThen(_.map(CatalystTypeConverters.convertToCatalyst(_, schema).asInstanceOf[Row]))
val generator = UserDefinedGenerator(attributes, rowFunction, input.map(_.expr))
val generator = UserDefinedGenerator(elementTypes, rowFunction, input.map(_.expr))

Generate(generator, join = true, outer = false, None, logicalPlan)
Generate(generator, join = true, outer = false,
qualifier = None, names.map(UnresolvedAttribute(_)), logicalPlan)
}

/**
Expand All @@ -733,12 +737,17 @@ class DataFrame private[sql](
: DataFrame = {
val dataType = ScalaReflection.schemaFor[B].dataType
val attributes = AttributeReference(outputColumn, dataType)() :: Nil
// TODO handle the metadata?
val elementTypes = attributes.map { attr => (attr.dataType, attr.nullable) }
val names = attributes.map(_.name)

def rowFunction(row: Row): TraversableOnce[Row] = {
f(row(0).asInstanceOf[A]).map(o => Row(CatalystTypeConverters.convertToCatalyst(o, dataType)))
}
val generator = UserDefinedGenerator(attributes, rowFunction, apply(inputColumn).expr :: Nil)
val generator = UserDefinedGenerator(elementTypes, rowFunction, apply(inputColumn).expr :: Nil)

Generate(generator, join = true, outer = false, None, logicalPlan)
Generate(generator, join = true, outer = false,
qualifier = None, names.map(UnresolvedAttribute(_)), logicalPlan)
}

/////////////////////////////////////////////////////////////////////////////
Expand Down
Loading

0 comments on commit 7662ec2

Please sign in to comment.