Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-18394][SQL] Make an AttributeSet.toSeq output order consistent #18959

Closed
wants to merge 4 commits into from
Closed
Changes from all commits
Commits
File filter...
Filter file types
Jump to…
Jump to file or symbol
Failed to load files and symbols.

Always

Just for now

@@ -121,7 +121,12 @@ class AttributeSet private (val baseSet: Set[AttributeEquals])

// Returns the attributes in a deterministic order: sorted by attribute name, with ties on
// the name broken by expression id. A stable order matters because it affects variable
// names emitted in generated code (e.g., `GenerateColumnAccessor`); a hash-set iteration
// order would make generated code (and plans) nondeterministic. See SPARK-18394 for details.
// Note: `sortBy` produces a strict Seq, so no lazy [[Stream]] can capture state in a closure.
override def toSeq: Seq[Attribute] = {
  baseSet.map(_.a).toSeq.sortBy { a => (a.name, a.exprId.id) }
}

override def toString: String = "{" + baseSet.map(_.a).mkString(", ") + "}"

@@ -78,4 +78,44 @@ class AttributeSetSuite extends SparkFunSuite {
assert(aSet == aSet)
assert(aSet == AttributeSet(aUpper :: Nil))
}

test("SPARK-18394 keep a deterministic output order along with attribute names and exprIds") {
  // Builds two AttributeSets from (name, exprId) pairs, unions them, and returns the
  // attribute names in the order produced by `toSeq`.
  def unionNames(first: Seq[(String, Long)], second: Seq[(String, Long)]): Seq[String] = {
    def buildSet(pairs: Seq[(String, Long)]): AttributeSet = {
      AttributeSet(pairs.map { case (name, id) =>
        AttributeReference(name, IntegerType)(exprId = ExprId(id))
      })
    }
    (buildSet(first) ++ buildSet(second)).toSeq.map(_.name)
  }

  // Simple case: the name order must not depend on the (arbitrary) exprId values.
  val namesA = unionNames(
    Seq("c1" -> 1098L, "c2" -> 107L, "c3" -> 838L),
    Seq("c4" -> 389L, "c5" -> 89329L))
  val namesB = unionNames(
    Seq("c1" -> 392L, "c2" -> 92L, "c3" -> 87L),
    Seq("c4" -> 9023920L, "c5" -> 522L))

  assert(namesA === namesB)

  // Same column name with different exprIds: ties on the name are broken by exprId,
  // so the expected order is 107 < 389 < 1098.
  val attrA = AttributeReference("c", IntegerType)(exprId = ExprId(1098))
  val attrB = AttributeReference("c", IntegerType)(exprId = ExprId(107))
  val setAB = AttributeSet(attrA :: attrB :: Nil)
  val attrC = AttributeReference("c", IntegerType)(exprId = ExprId(389))
  val setC = AttributeSet(attrC :: Nil)

  assert((setAB ++ setC).toSeq === attrB :: attrC :: attrA :: Nil)
}
}
@@ -162,7 +162,12 @@ class PruningSuite extends HiveComparisonTest with BeforeAndAfter {
}.head

assert(actualOutputColumns === expectedOutputColumns, "Output columns mismatch")
assert(actualScannedColumns === expectedScannedColumns, "Scanned columns mismatch")

// Scanned columns in `HiveTableScanExec` are generated by the `pruneFilterProject` method
// in `SparkPlanner`. This method internally uses `AttributeSet.toSeq`, in which
// the returned output columns are sorted by the names and expression ids.
assert(actualScannedColumns.sorted === expectedScannedColumns.sorted,
"Scanned columns mismatch")

val actualPartitions = actualPartValues.map(_.asScala.mkString(",")).sorted
val expectedPartitions = expectedPartValues.map(_.mkString(",")).sorted
ProTip! Use n and p to navigate between commits in a pull request.
You can’t perform that action at this time.