Skip to content

Commit

Permalink
Slightly improve PropagateTypes.
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Jun 29, 2015
1 parent 660c6ce commit 913f6ad
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -131,20 +131,22 @@ trait HiveTypeCoercion {
// Don't propagate types from unresolved children.
case q: LogicalPlan if !q.childrenResolved => q

case q: LogicalPlan => q transformExpressions {
case a: AttributeReference =>
q.inputSet.find(_.exprId == a.exprId) match {
// This can happen when a Attribute reference is born in a non-leaf node, for example
// due to a call to an external script like in the Transform operator.
// TODO: Perhaps those should actually be aliases?
case None => a
// Leave the same if the dataTypes match.
case Some(newType) if a.dataType == newType.dataType => a
case Some(newType) =>
logDebug(s"Promoting $a to $newType in ${q.simpleString}}")
newType
}
}
case q: LogicalPlan =>
val inputMap = q.inputSet.toAttributeMap(a => a)
q transformExpressions {
case a: AttributeReference =>
inputMap.get(a) match {
// This can happen when a Attribute reference is born in a non-leaf node, for example
// due to a call to an external script like in the Transform operator.
// TODO: Perhaps those should actually be aliases?
case None => a
// Leave the same if the dataTypes match.
case Some(newType) if a.dataType == newType.dataType => a
case Some(newType) =>
logDebug(s"Promoting $a to $newType in ${q.simpleString}}")
newType
}
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,14 @@ class AttributeSet private (val baseSet: Set[AttributeEquals])
def intersect(other: AttributeSet): AttributeSet =
new AttributeSet(baseSet.intersect(other.baseSet))

/**
* Returns a new [[AttributeMap]] that uses [[Attribute.exprId]] as key. The value of this map is
* [[(Attribute, A)]] where type [[A]] is given by the parameter function [[f]].
*/
def toAttributeMap[A](f: (Attribute) => A): AttributeMap[A] = {
AttributeMap(this.toSeq.map(a => (a, f(a))))
}

override def foreach[U](f: (Attribute) => U): Unit = baseSet.map(_.a).foreach(f)

// We must force toSeq to not be strict otherwise we end up with a [[Stream]] that captures all
Expand Down

0 comments on commit 913f6ad

Please sign in to comment.