Resolve timezone-aware expressions with the session-local time zone when resolving inline tables.
viirya committed Mar 1, 2017
1 parent 0fe8020 commit b635db5
Showing 4 changed files with 47 additions and 22 deletions.
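
For context, the end-to-end case this commit fixes is an inline table (VALUES ...) whose rows contain a timestamp literal: casting the string to TimestampType needs a time zone, which ResolveInlineTables now takes from the session configuration. A rough sketch of that scenario, mirroring the new SQLQuerySuite test added below (assuming an active SparkSession named spark):

// Sketch only: spark is assumed to be an existing SparkSession.
spark.sql(
  """
    |CREATE TEMPORARY VIEW table_4(timestamp_col_3)
    |AS VALUES TIMESTAMP('1991-12-06 00:00:00.0')
  """.stripMargin)

// The timestamp literal is now evaluated with the session-local time zone
// during analysis, so the round trip returns the expected value.
val result = spark.sql("SELECT timestamp_col_3 FROM table_4").head()
// expected: java.sql.Timestamp.valueOf("1991-12-06 00:00:00")
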
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -146,7 +146,7 @@ class Analyzer(
       GlobalAggregates ::
       ResolveAggregateFunctions ::
       TimeWindowing ::
-      ResolveInlineTables ::
+      ResolveInlineTables(conf) ::
       TypeCoercion.typeCoercionRules ++
       extendedResolutionRules : _*),
     Batch("Post-Hoc Resolution", Once, postHocResolutionRules: _*),
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala
@@ -19,16 +19,16 @@ package org.apache.spark.sql.catalyst.analysis
 
 import scala.util.control.NonFatal
 
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.Cast
+import org.apache.spark.sql.catalyst.{CatalystConf, InternalRow}
+import org.apache.spark.sql.catalyst.expressions.{Cast, TimeZoneAwareExpression}
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.types.{StructField, StructType}
 
 /**
  * An analyzer rule that replaces [[UnresolvedInlineTable]] with [[LocalRelation]].
  */
-object ResolveInlineTables extends Rule[LogicalPlan] {
+case class ResolveInlineTables(conf: CatalystConf) extends Rule[LogicalPlan] {
   override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
     case table: UnresolvedInlineTable if table.expressionsResolved =>
       validateInputDimension(table)
@@ -95,10 +95,14 @@ object ResolveInlineTables extends Rule[LogicalPlan] {
       InternalRow.fromSeq(row.zipWithIndex.map { case (e, ci) =>
         val targetType = fields(ci).dataType
         try {
-          if (e.dataType.sameType(targetType)) {
-            e.eval()
+          val castedExpr = if (e.dataType.sameType(targetType)) {
+            e
           } else {
-            Cast(e, targetType).eval()
+            Cast(e, targetType)
           }
+          castedExpr match {
+            case te: TimeZoneAwareExpression => te.withTimeZone(conf.sessionLocalTimeZone).eval()
+            case _ => castedExpr.eval()
+          }
         } catch {
           case NonFatal(ex) =>
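
The core of the change is the evaluation path above: instead of calling eval() directly on the (possibly cast) expression, the rule first binds the session-local time zone to any TimeZoneAwareExpression. A minimal standalone sketch of the same pattern, outside the analyzer (the zone id below is an arbitrary example; the rule uses conf.sessionLocalTimeZone):

import org.apache.spark.sql.catalyst.expressions.{Cast, Literal, TimeZoneAwareExpression}
import org.apache.spark.sql.types.TimestampType

// Example zone id; in the rule this comes from conf.sessionLocalTimeZone.
val sessionLocalTimeZone = "America/Los_Angeles"

// A string-to-timestamp Cast is a TimeZoneAwareExpression: it cannot be
// evaluated until a time zone id has been attached to it.
val expr = Cast(Literal("1991-12-06 00:00:00.0"), TimestampType)

val value = expr match {
  case te: TimeZoneAwareExpression => te.withTimeZone(sessionLocalTimeZone).eval()
  case other => other.eval()
}
// value holds Spark's internal timestamp representation: microseconds since the epoch (a Long).
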
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala
@@ -20,82 +20,92 @@ package org.apache.spark.sql.catalyst.analysis
 import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.expressions.{Literal, Rand}
+import org.apache.spark.sql.catalyst.expressions.{Cast, Literal, Rand}
 import org.apache.spark.sql.catalyst.expressions.aggregate.Count
-import org.apache.spark.sql.catalyst.plans.PlanTest
-import org.apache.spark.sql.types.{LongType, NullType}
+import org.apache.spark.sql.types.{LongType, NullType, TimestampType}
 
 /**
  * Unit tests for [[ResolveInlineTables]]. Note that there are also test cases defined in
  * end-to-end tests (in sql/core module) for verifying the correct error messages are shown
  * in negative cases.
  */
-class ResolveInlineTablesSuite extends PlanTest with BeforeAndAfter {
+class ResolveInlineTablesSuite extends AnalysisTest with BeforeAndAfter {
 
   private def lit(v: Any): Literal = Literal(v)
 
   test("validate inputs are foldable") {
-    ResolveInlineTables.validateInputEvaluable(
+    ResolveInlineTables(conf).validateInputEvaluable(
       UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(lit(1)))))
 
     // nondeterministic (rand) should not work
     intercept[AnalysisException] {
-      ResolveInlineTables.validateInputEvaluable(
+      ResolveInlineTables(conf).validateInputEvaluable(
         UnresolvedInlineTable(Seq("c1"), Seq(Seq(Rand(1)))))
     }
 
     // aggregate should not work
     intercept[AnalysisException] {
-      ResolveInlineTables.validateInputEvaluable(
+      ResolveInlineTables(conf).validateInputEvaluable(
         UnresolvedInlineTable(Seq("c1"), Seq(Seq(Count(lit(1))))))
     }
 
     // unresolved attribute should not work
     intercept[AnalysisException] {
-      ResolveInlineTables.validateInputEvaluable(
+      ResolveInlineTables(conf).validateInputEvaluable(
         UnresolvedInlineTable(Seq("c1"), Seq(Seq(UnresolvedAttribute("A")))))
     }
   }
 
   test("validate input dimensions") {
-    ResolveInlineTables.validateInputDimension(
+    ResolveInlineTables(conf).validateInputDimension(
       UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2)))))
 
     // num alias != data dimension
     intercept[AnalysisException] {
-      ResolveInlineTables.validateInputDimension(
+      ResolveInlineTables(conf).validateInputDimension(
        UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(lit(1)), Seq(lit(2)))))
     }
 
     // num alias == data dimension, but data themselves are inconsistent
     intercept[AnalysisException] {
-      ResolveInlineTables.validateInputDimension(
+      ResolveInlineTables(conf).validateInputDimension(
        UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(21), lit(22)))))
     }
   }
 
   test("do not fire the rule if not all expressions are resolved") {
     val table = UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(UnresolvedAttribute("A"))))
-    assert(ResolveInlineTables(table) == table)
+    assert(ResolveInlineTables(conf)(table) == table)
   }
 
   test("convert") {
     val table = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2L))))
-    val converted = ResolveInlineTables.convert(table)
+    val converted = ResolveInlineTables(conf).convert(table)
 
     assert(converted.output.map(_.dataType) == Seq(LongType))
     assert(converted.data.size == 2)
     assert(converted.data(0).getLong(0) == 1L)
     assert(converted.data(1).getLong(0) == 2L)
   }
 
+  test("convert TimeZoneAwareExpression") {
+    val table = UnresolvedInlineTable(Seq("c1"),
+      Seq(Seq(Cast(lit("1991-12-06 00:00:00.0"), TimestampType))))
+    val converted = ResolveInlineTables(conf).convert(table)
+    val correct = Cast(lit("1991-12-06 00:00:00.0"), TimestampType)
+      .withTimeZone(conf.sessionLocalTimeZone).eval().asInstanceOf[Long]
+    assert(converted.output.map(_.dataType) == Seq(TimestampType))
+    assert(converted.data.size == 1)
+    assert(converted.data(0).getLong(0) == correct)
+  }
+
   test("nullability inference in convert") {
     val table1 = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2L))))
-    val converted1 = ResolveInlineTables.convert(table1)
+    val converted1 = ResolveInlineTables(conf).convert(table1)
     assert(!converted1.schema.fields(0).nullable)
 
     val table2 = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(Literal(null, NullType))))
-    val converted2 = ResolveInlineTables.convert(table2)
+    val converted2 = ResolveInlineTables(conf).convert(table2)
     assert(converted2.schema.fields(0).nullable)
   }
 }
sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala (11 additions, 0 deletions)
@@ -2586,4 +2586,15 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
     }
     assert(!jobStarted.get(), "Command should not trigger a Spark job.")
   }
+
+  test("string to timestamp in inline table definition") {
+    sql(
+      """
+        |CREATE TEMPORARY VIEW table_4(timestamp_col_3)
+        |AS VALUES TIMESTAMP('1991-12-06 00:00:00.0')
+      """.stripMargin)
+    checkAnswer(
+      sql("SELECT timestamp_col_3 FROM table_4"),
+      Row(java.sql.Timestamp.valueOf("1991-12-06 00:00:00")) :: Nil)
+  }
 }
