diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index f858650df410d..cee440773592a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -541,7 +541,10 @@ case class Round(child: Expression, scale: Expression) extends Expression with E case t => t } - override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegralType) + override def inputTypes: Seq[AbstractDataType] = Seq( + //rely on precedence to implicit cast String into Double + TypeCollection(DoubleType, FloatType, LongType, IntegerType, ShortType, ByteType), + TypeCollection(LongType, IntegerType, ShortType, ByteType)) override def checkInputDataTypes(): TypeCheckResult = { child.dataType match {