diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 0f8282d3b2f1f..368a70a384044 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -106,11 +106,11 @@ object RowEncoder { inputObject :: Nil) case d: DecimalType => - StaticInvoke( + CheckOverflow(StaticInvoke( Decimal.getClass, d, "fromDecimal", - inputObject :: Nil) + inputObject :: Nil), d) case StringType => StaticInvoke( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 683fe4a329365..dbaf62825a76d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1238,6 +1238,15 @@ class DatasetSuite extends QueryTest with SharedSQLContext { checkDataset(ds, SpecialCharClass("1", "2")) } } + + test("SPARK-26233: serializer should enforce decimal precision and scale") { + val s = StructType(Seq(StructField("a", StringType), StructField("b", DecimalType(38, 8)))) + val encoder = RowEncoder(s) + implicit val uEnc = encoder + val df = spark.range(2).map(l => Row(l.toString, BigDecimal.valueOf(l + 0.1111))) + checkAnswer(df.groupBy(col("a")).agg(first(col("b"))), + Seq(Row("0", BigDecimal.valueOf(0.1111)), Row("1", BigDecimal.valueOf(1.1111)))) + } } case class WithImmutableMap(id: String, map_test: scala.collection.immutable.Map[Long, String])