Skip to content

Commit

Permalink
Merge pull request #22 from rxin/type
Browse files Browse the repository at this point in the history
HiveTypeCoercion.WidenTypes Void to Boolean support & unit tests
  • Loading branch information
marmbrus committed Jan 26, 2014
2 parents 9bb1979 + 5b11db0 commit 0b31176
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 15 deletions.
31 changes: 16 additions & 15 deletions src/main/scala/catalyst/analysis/HiveTypeCoercion.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import expressions._
import plans.logical._
import rules._
import types._
import catalyst.execution.{HiveUdf, HiveGenericUdf}

/**
* A collection of [[catalyst.rules.Rule Rules]] that can be used to coerce differing types that
Expand Down Expand Up @@ -38,7 +37,7 @@ trait HiveTypeCoercion {
// Leave the same if the dataTypes match.
case Some(newType) if a.dataType == newType.dataType => a
case Some(newType) =>
logger.debug(s"Promoting $a to ${newType} in ${q.simpleString}}")
logger.debug(s"Promoting $a to $newType in ${q.simpleString}}")
newType
}
}
Expand Down Expand Up @@ -89,18 +88,20 @@ trait HiveTypeCoercion {
* - BOOLEAN types cannot be converted to any other type.
*
* Additionally, all types when UNION-ed with strings will be promoted to strings.
* Other string conversions are handled by PromoteStrings
* Other string conversions are handled by PromoteStrings.
*/
object WidenTypes extends Rule[LogicalPlan] {
val integralPrecedence = Seq(NullType, ByteType, ShortType, IntegerType, LongType)
val toDouble = integralPrecedence ++ Seq(NullType, FloatType, DoubleType)
val toFloat = Seq(NullType, ByteType, ShortType, IntegerType) :+ FloatType
val allPromotions: Seq[Seq[DataType]] = integralPrecedence :: toDouble :: toFloat :: Nil
// See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types.
// The conversion for integral and floating point types have a linear widening hierarchy:
val numericPrecedence =
Seq(NullType, ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType)
// Boolean is only wider than Void
val booleanPrecedence = Seq(NullType, BooleanType)
val allPromotions: Seq[Seq[DataType]] = numericPrecedence :: booleanPrecedence :: Nil

def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = {
// Try and find a promotion rule that contains both types in question.
val applicableConversion =
allPromotions.find(p => p.contains(t1) && p.contains(t2))
val applicableConversion = allPromotions.find(p => p.contains(t1) && p.contains(t2))

// If found return the widest common type, otherwise None
applicableConversion.map(_.filter(t => t == t1 || t == t2).last)
Expand All @@ -110,12 +111,12 @@ trait HiveTypeCoercion {
case u @ Union(left, right) if u.childrenResolved && !u.resolved =>
val castedInput = left.output.zip(right.output).map {
// When a string is found on one side, make the other side a string too.
case (l,r) if l.dataType == StringType && r.dataType != StringType =>
case (l, r) if l.dataType == StringType && r.dataType != StringType =>
(l, Alias(Cast(r, StringType), r.name)())
case (l,r) if l.dataType != StringType && r.dataType == StringType =>
case (l, r) if l.dataType != StringType && r.dataType == StringType =>
(Alias(Cast(l, StringType), l.name)(), r)

case (l,r) if l.dataType != r.dataType =>
case (l, r) if l.dataType != r.dataType =>
logger.debug(s"Resolving mismatched union input ${l.dataType}, ${r.dataType}")
findTightestCommonType(l.dataType, r.dataType).map { widestType =>
val newLeft =
Expand All @@ -124,22 +125,22 @@ trait HiveTypeCoercion {
if (r.dataType == widestType) r else Alias(Cast(r, widestType), r.name)()

(newLeft, newRight)
}.getOrElse((l,r)) // If there is no applicable conversion, leave expression unchanged.
}.getOrElse((l, r)) // If there is no applicable conversion, leave expression unchanged.
case other => other
}

val (castedLeft, castedRight) = castedInput.unzip

val newLeft =
if(castedLeft.map(_.dataType) != left.output.map(_.dataType)) {
if (castedLeft.map(_.dataType) != left.output.map(_.dataType)) {
logger.debug(s"Widening numeric types in union $castedLeft ${left.output}")
Project(castedLeft, left)
} else {
left
}

val newRight =
if(castedRight.map(_.dataType) != right.output.map(_.dataType)) {
if (castedRight.map(_.dataType) != right.output.map(_.dataType)) {
logger.debug(s"Widening numeric types in union $castedRight ${right.output}")
Project(castedRight, right)
} else {
Expand Down
55 changes: 55 additions & 0 deletions src/test/scala/catalyst/analysis/HiveTypeCoercionSuite.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package catalyst.analysis

import org.scalatest.FunSuite

import catalyst.types._


class HiveTypeCoercionSuite extends FunSuite {

val rules = new HiveTypeCoercion { }
import rules._

test("tightest common bound for numeric and boolean types") {
def widenTest(t1: DataType, t2: DataType, tightestCommon: Option[DataType]) {
var found = WidenTypes.findTightestCommonType(t1, t2)
assert(found == tightestCommon,
s"Expected $tightestCommon as tightest common type for $t1 and $t2, found $found")
// Test both directions to make sure the widening is symmetric.
found = WidenTypes.findTightestCommonType(t2, t1)
assert(found == tightestCommon,
s"Expected $tightestCommon as tightest common type for $t2 and $t1, found $found")
}

// Boolean
widenTest(NullType, BooleanType, Some(BooleanType))
widenTest(BooleanType, BooleanType, Some(BooleanType))
widenTest(IntegerType, BooleanType, None)
widenTest(LongType, BooleanType, None)

// Integral
widenTest(NullType, ByteType, Some(ByteType))
widenTest(NullType, IntegerType, Some(IntegerType))
widenTest(NullType, LongType, Some(LongType))
widenTest(ShortType, IntegerType, Some(IntegerType))
widenTest(ShortType, LongType, Some(LongType))
widenTest(IntegerType, LongType, Some(LongType))
widenTest(LongType, LongType, Some(LongType))

// Floating point
widenTest(NullType, FloatType, Some(FloatType))
widenTest(NullType, DoubleType, Some(DoubleType))
widenTest(FloatType, DoubleType, Some(DoubleType))
widenTest(FloatType, FloatType, Some(FloatType))
widenTest(DoubleType, DoubleType, Some(DoubleType))

// Integral mixed with floating point.
widenTest(NullType, FloatType, Some(FloatType))
widenTest(NullType, DoubleType, Some(DoubleType))
widenTest(IntegerType, FloatType, Some(FloatType))
widenTest(IntegerType, DoubleType, Some(DoubleType))
widenTest(IntegerType, DoubleType, Some(DoubleType))
widenTest(LongType, FloatType, Some(FloatType))
widenTest(LongType, DoubleType, Some(DoubleType))
}
}

0 comments on commit 0b31176

Please sign in to comment.