diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 976fa57cb98d5..100480724626d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -131,20 +131,23 @@ trait HiveTypeCoercion {
       // Don't propagate types from unresolved children.
       case q: LogicalPlan if !q.childrenResolved => q
 
-      case q: LogicalPlan => q transformExpressions {
-        case a: AttributeReference =>
-          q.inputSet.find(_.exprId == a.exprId) match {
-            // This can happen when a Attribute reference is born in a non-leaf node, for example
-            // due to a call to an external script like in the Transform operator.
-            // TODO: Perhaps those should actually be aliases?
-            case None => a
-            // Leave the same if the dataTypes match.
-            case Some(newType) if a.dataType == newType.dataType => a
-            case Some(newType) =>
-              logDebug(s"Promoting $a to $newType in ${q.simpleString}}")
-              newType
-          }
-      }
+      case q: LogicalPlan =>
+        // Build the input-attribute lookup once per plan node, outside the per-expression rewrite.
+        val inputMap = q.inputSet.toAttributeMap(a => a)
+        q transformExpressions {
+          case a: AttributeReference =>
+            inputMap.get(a) match {
+              // This can happen when an Attribute reference is born in a non-leaf node, for
+              // example due to a call to an external script like in the Transform operator.
+              // TODO: Perhaps those should actually be aliases?
+              case None => a
+              // Leave the same if the dataTypes match.
+              case Some(newType) if a.dataType == newType.dataType => a
+              case Some(newType) =>
+                logDebug(s"Promoting $a to $newType in ${q.simpleString}")
+                newType
+            }
+        }
     }
   }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
index 5345696570b41..6be12fd4090a2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
@@ -111,6 +111,14 @@ class AttributeSet private (val baseSet: Set[AttributeEquals])
   def intersect(other: AttributeSet): AttributeSet =
     new AttributeSet(baseSet.intersect(other.baseSet))
 
+  /**
+   * Returns a new [[AttributeMap]] keyed by [[Attribute.exprId]] that maps every attribute in
+   * this set to the value of type `A` computed for it by the function `f`.
+   */
+  def toAttributeMap[A](f: (Attribute) => A): AttributeMap[A] = {
+    AttributeMap(this.toSeq.map(a => (a, f(a))))
+  }
+
   override def foreach[U](f: (Attribute) => U): Unit = baseSet.map(_.a).foreach(f)
 
   // We must force toSeq to not be strict otherwise we end up with a [[Stream]] that captures all