From c886be69808df5394cc614ff1bc2cd999de59b09 Mon Sep 17 00:00:00 2001
From: jameswillis
Date: Tue, 10 Mar 2026 10:13:44 -0700
Subject: [PATCH] fix: disable TransformNestedUDTParquet in Spark versions >= 4.1

SPARK-48942: nested UDTs crash the vectorized Parquet reader on
Spark < 4.1. SPARK-52651 fixes this in Spark 4.1+ by recursively
stripping UDTs in ColumnVector, making the TransformNestedUDTParquet
workaround unnecessary.

Only register the TransformNestedUDTParquet optimizer rule when running
on Spark < 4.1. Version parsing is defensive (Try/getOrElse and .lift),
so malformed version strings fall back to 0 instead of throwing.
---
 .../apache/sedona/spark/SedonaContext.scala   | 20 +++++++++++++++++---
 1 file changed, 17 insertions(+), 3 deletions(-)

diff --git a/spark/common/src/main/scala/org/apache/sedona/spark/SedonaContext.scala b/spark/common/src/main/scala/org/apache/sedona/spark/SedonaContext.scala
index b0e46cf6e9e..b13f93594c4 100644
--- a/spark/common/src/main/scala/org/apache/sedona/spark/SedonaContext.scala
+++ b/spark/common/src/main/scala/org/apache/sedona/spark/SedonaContext.scala
@@ -42,12 +42,26 @@ class InternalApi(
 
 object SedonaContext {
 
-  private def customOptimizationsWithSession(sparkSession: SparkSession) =
-    Seq(
-      new TransformNestedUDTParquet(sparkSession),
+  private def customOptimizationsWithSession(sparkSession: SparkSession) = {
+    val optimizations = Seq(
       new SpatialFilterPushDownForGeoParquet(sparkSession),
       new SpatialTemporalFilterPushDownForStacScan(sparkSession))
 
+    val versionParts =
+      sparkSession.version
+        .split('.')
+        .map(s => scala.util.Try(s.takeWhile(_.isDigit).toInt).getOrElse(0))
+    val major = versionParts.lift(0).getOrElse(0)
+    val minor = versionParts.lift(1).getOrElse(0)
+    if (major < 4 || (major == 4 && minor < 1)) {
+      // SPARK-48942: nested UDTs crash the vectorized Parquet reader on Spark < 4.1.
+      // SPARK-52651 fixes this in Spark 4.1+ by recursively stripping UDTs in ColumnVector.
+      new TransformNestedUDTParquet(sparkSession) +: optimizations
+    } else {
+      optimizations
+    }
+  }
+
   def create(sqlContext: SQLContext): SQLContext = {
     create(sqlContext.sparkSession)
     sqlContext
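
Reviewer note (illustration only, not part of the patch): a minimal
standalone sketch of the version gate above, so the parsing behavior can
be checked against sample version strings. The names VersionGateSketch
and shouldRegisterWorkaround are hypothetical, invented for this example;
the body mirrors the logic in customOptimizationsWithSession.

object VersionGateSketch {
  // Returns true when the workaround rule should be registered,
  // i.e. for any version below 4.1.
  def shouldRegisterWorkaround(version: String): Boolean = {
    // Same defensive parsing as the patch: strip non-digit suffixes
    // (e.g. "0-SNAPSHOT"), fall back to 0 on anything unparseable.
    val parts = version
      .split('.')
      .map(s => scala.util.Try(s.takeWhile(_.isDigit).toInt).getOrElse(0))
    val major = parts.lift(0).getOrElse(0)
    val minor = parts.lift(1).getOrElse(0)
    major < 4 || (major == 4 && minor < 1)
  }

  def main(args: Array[String]): Unit = {
    // Sample inputs are hypothetical version strings, not taken from Spark.
    Seq("3.5.1", "4.0.0", "4.1.0-SNAPSHOT", "garbage")
      .foreach(v => println(s"$v -> ${shouldRegisterWorkaround(v)}"))
  }
}

Running it prints true for 3.5.1 and 4.0.0 and false for 4.1.0-SNAPSHOT.
An unparseable string parses as 0.0 and prints true, so the workaround
stays enabled whenever the version cannot be read, which errs on the
side of the pre-4.1 behavior.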