From 7bb6ed16e2ac58efe35b5bf7a273766183c2c953 Mon Sep 17 00:00:00 2001 From: Emil Ejbyfeldt Date: Wed, 1 Oct 2025 19:23:17 -0400 Subject: [PATCH] [SPARK-52614][SQL][4.0] Support RowEncoder inside Product Encoder This fixes support for using a RowEncoder inside a ProductEncoder. The current code does a dataType check on a path when constructing the RowEncoder deserializer. But this is not safe and if the RowEncoder is used inside a ProductEncoder, it will throw because the path Expression is unresolved. The check was introduced in https://github.com/apache/spark/pull/49785 Yes, it makes it possible to use RowEncoder in more cases. Existing and new unit tests. No Closes #51319 from eejbyfeldt/SPARK-52614. Authored-by: Emil Ejbyfeldt Signed-off-by: Herman van Hovell --- .../catalyst/DeserializerBuildHelper.scala | 28 ++++++++----------- .../encoders/ExpressionEncoderSuite.scala | 16 +++++++++++ 2 files changed, 27 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala index 5d1bbb024074c..9dcaba8c2bc46 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala @@ -241,19 +241,12 @@ object DeserializerBuildHelper { val walkedTypePath = WalkedTypePath().recordRoot(enc.clsTag.runtimeClass.getName) // Assumes we are deserializing the first column of a row. 
val input = GetColumnByOrdinal(0, enc.dataType) - enc match { - case AgnosticEncoders.RowEncoder(fields) => - val children = fields.zipWithIndex.map { case (f, i) => - createDeserializer(f.enc, GetStructField(input, i), walkedTypePath) - } - CreateExternalRow(children, enc.schema) - case _ => - val deserializer = createDeserializer( - enc, - upCastToExpectedType(input, enc.dataType, walkedTypePath), - walkedTypePath) - expressionWithNullSafety(deserializer, enc.nullable, walkedTypePath) - } + val deserializer = createDeserializer( + enc, + upCastToExpectedType(input, enc.dataType, walkedTypePath), + walkedTypePath, + isTopLevel = true) + expressionWithNullSafety(deserializer, enc.nullable, walkedTypePath) } /** @@ -265,11 +258,13 @@ object DeserializerBuildHelper { * external representation. * @param path The expression which can be used to extract serialized value. * @param walkedTypePath The paths from top to bottom to access current field when deserializing. + * @param isTopLevel true if we are creating a deserializer for the top level value. 
*/ private def createDeserializer( enc: AgnosticEncoder[_], path: Expression, - walkedTypePath: WalkedTypePath): Expression = enc match { + walkedTypePath: WalkedTypePath, + isTopLevel: Boolean = false): Expression = enc match { case ae: AgnosticExpressionPathEncoder[_] => ae.fromCatalyst(path) case _ if isNativeEncoder(enc) => @@ -408,13 +403,12 @@ object DeserializerBuildHelper { NewInstance(cls, arguments, Nil, propagateNull = false, dt, outerPointerGetter)) case AgnosticEncoders.RowEncoder(fields) => - val isExternalRow = !path.dataType.isInstanceOf[StructType] val convertedFields = fields.zipWithIndex.map { case (f, i) => val newTypePath = walkedTypePath.recordField( f.enc.clsTag.runtimeClass.getName, f.name) val deserializer = createDeserializer(f.enc, GetStructField(path, i), newTypePath) - if (isExternalRow) { + if (!isTopLevel) { exprs.If( Invoke(path, "isNullAt", BooleanType, exprs.Literal(i) :: Nil), exprs.Literal.create(null, externalDataTypeFor(f.enc)), @@ -460,7 +454,7 @@ object DeserializerBuildHelper { Literal.create(provider(), ObjectType(classOf[Codec[_, _]])), "decode", dataTypeForClass(tag.runtimeClass), - createDeserializer(encoder, path, walkedTypePath) :: Nil) + createDeserializer(encoder, path, walkedTypePath, isTopLevel) :: Nil) } private def deserializeArray( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index 1b5f1b109c45e..3d738fe985dd1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -659,6 +659,22 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes assert(fromRow(toRow(new Wrapper(Row(9L, "x")))) == new Wrapper(Row(9L, "x"))) } + test("SPARK-52614: transforming encoder row encoder in 
product encoder") { + val schema = new StructType().add("a", LongType).add("b", StringType) + val wrapperEncoder = TransformingEncoder( + classTag[Wrapper[Row]], + RowEncoder.encoderFor(schema), + new WrapperCodecProvider[Row]) + val encoder = ExpressionEncoder(ProductEncoder( + classTag[V[Wrapper[Row]]], + Seq(EncoderField("v", wrapperEncoder, nullable = false, Metadata.empty)), + None)) + .resolveAndBind() + val toRow = encoder.createSerializer() + val fromRow = encoder.createDeserializer() + assert(fromRow(toRow(V(new Wrapper(Row(9L, "x"))))) == V(new Wrapper(Row(9L, "x")))) + } + // below tests are related to SPARK-49960 and TransformingEncoder usage test("""Encoder with OptionEncoder of transformation""".stripMargin) { type T = Option[V[V[Int]]]