diff --git a/spark/common/src/main/scala/org/apache/sedona/spark/SedonaContext.scala b/spark/common/src/main/scala/org/apache/sedona/spark/SedonaContext.scala index 6b262ed162..e16c5b7665 100644 --- a/spark/common/src/main/scala/org/apache/sedona/spark/SedonaContext.scala +++ b/spark/common/src/main/scala/org/apache/sedona/spark/SedonaContext.scala @@ -20,6 +20,7 @@ package org.apache.sedona.spark import org.apache.sedona.common.utils.TelemetryCollector import org.apache.sedona.core.serde.SedonaKryoRegistrator +import org.apache.sedona.sql.RasterRegistrator import org.apache.sedona.sql.UDF.UdfRegistrator import org.apache.sedona.sql.UDT.UdtRegistrator import org.apache.spark.serializer.KryoSerializer @@ -57,6 +58,7 @@ object SedonaContext { sparkSession.experimental.extraOptimizations ++= Seq(new SpatialFilterPushDownForGeoParquet(sparkSession)) } addGeoParquetToSupportNestedFilterSources(sparkSession) + RasterRegistrator.registerAll(sparkSession) UdtRegistrator.registerAll() UdfRegistrator.registerAll(sparkSession) sparkSession diff --git a/spark/common/src/main/scala/org/apache/sedona/sql/RasterRegistrator.scala b/spark/common/src/main/scala/org/apache/sedona/sql/RasterRegistrator.scala new file mode 100644 index 0000000000..e3152e40d0 --- /dev/null +++ b/spark/common/src/main/scala/org/apache/sedona/sql/RasterRegistrator.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql + +import org.apache.sedona.sql.UDF.RasterUdafCatalog +import org.apache.spark.sql.catalyst.FunctionIdentifier +import org.apache.spark.sql.sedona_sql.UDT.RasterUdtRegistratorWrapper +import org.apache.spark.sql.{SparkSession, functions} +import org.slf4j.{Logger, LoggerFactory} + +object RasterRegistrator { + val logger: Logger = LoggerFactory.getLogger(getClass) + private val gridClassName = "org.geotools.coverage.grid.GridCoverage2D" + + // Helper method to check if GridCoverage2D is available + private def isGeoToolsAvailable: Boolean = { + try { + Class.forName(gridClassName, true, Thread.currentThread().getContextClassLoader) + true + } catch { + case _: ClassNotFoundException => + logger.warn("Geotools was not found on the classpath. Raster operations will not be available.") + false + } + } + + def registerAll(sparkSession: SparkSession): Unit = { + if (isGeoToolsAvailable) { + RasterUdtRegistratorWrapper.registerAll(gridClassName) + sparkSession.udf.register(RasterUdafCatalog.rasterAggregateExpression.getClass.getSimpleName, functions.udaf(RasterUdafCatalog.rasterAggregateExpression)) + } + } + + def dropAll(sparkSession: SparkSession): Unit = { + if (isGeoToolsAvailable) { + sparkSession.sessionState.functionRegistry.dropFunction(FunctionIdentifier(RasterUdafCatalog.rasterAggregateExpression.getClass.getSimpleName)) + } + } +} diff --git a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala index 011e89b079..25467f3cdd 100644 --- a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala +++ b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala @@ -22,14 +22,12 @@ import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionInfo, Literal} import org.apache.spark.sql.expressions.Aggregator +import org.apache.spark.sql.sedona_sql.expressions._ import org.apache.spark.sql.sedona_sql.expressions.collect.ST_Collect import org.apache.spark.sql.sedona_sql.expressions.raster._ -import org.apache.spark.sql.sedona_sql.expressions._ -import org.geotools.coverage.grid.GridCoverage2D import org.locationtech.jts.geom.Geometry import org.locationtech.jts.operation.buffer.BufferParameters -import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag object Catalog { @@ -285,8 +283,6 @@ object Catalog { function[RS_NetCDFInfo]() ) - val rasterAggregateExpression: Aggregator[(GridCoverage2D, Int), ArrayBuffer[BandData], GridCoverage2D] = new RS_Union_Aggr - val aggregateExpressions: Seq[Aggregator[Geometry, Geometry, Geometry]] = Seq( new ST_Union_Aggr, new ST_Envelope_Aggr, diff --git a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/RasterUdafCatalog.scala b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/RasterUdafCatalog.scala new file mode 100644 index 0000000000..a3deb50f5d --- /dev/null +++ b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/RasterUdafCatalog.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.UDF + +import org.apache.spark.sql.expressions.Aggregator +import org.apache.spark.sql.sedona_sql.expressions.raster.{BandData, RS_Union_Aggr} +import org.geotools.coverage.grid.GridCoverage2D + +import scala.collection.mutable.ArrayBuffer + +object RasterUdafCatalog { + val rasterAggregateExpression: Aggregator[(GridCoverage2D, Int), ArrayBuffer[BandData], GridCoverage2D] = new RS_Union_Aggr +} diff --git a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/UdfRegistrator.scala b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/UdfRegistrator.scala index c8d6590b30..5475568486 100644 --- a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/UdfRegistrator.scala +++ b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/UdfRegistrator.scala @@ -37,7 +37,6 @@ object UdfRegistrator { } Catalog.aggregateExpressions.foreach(f => sparkSession.udf.register(f.getClass.getSimpleName, functions.udaf(f))) // SPARK3 anchor //Catalog.aggregateExpressions_UDAF.foreach(f => sparkSession.udf.register(f.getClass.getSimpleName, f)) // SPARK2 anchor - sparkSession.udf.register(Catalog.rasterAggregateExpression.getClass.getSimpleName, functions.udaf(Catalog.rasterAggregateExpression)) } def dropAll(sparkSession: SparkSession): Unit = { @@ -46,6 +45,5 @@ Catalog.aggregateExpressions.foreach(f => sparkSession.udf.register(f.getClass.g } Catalog.aggregateExpressions.foreach(f => sparkSession.sessionState.functionRegistry.dropFunction(FunctionIdentifier(f.getClass.getSimpleName))) // SPARK3 anchor //Catalog.aggregateExpressions_UDAF.foreach(f => sparkSession.sessionState.functionRegistry.dropFunction(FunctionIdentifier(f.getClass.getSimpleName))) // SPARK2 anchor - sparkSession.sessionState.functionRegistry.dropFunction(FunctionIdentifier(Catalog.rasterAggregateExpression.getClass.getSimpleName)) } } diff --git a/spark/common/src/main/scala/org/apache/sedona/sql/utils/SedonaSQLRegistrator.scala b/spark/common/src/main/scala/org/apache/sedona/sql/utils/SedonaSQLRegistrator.scala index 91a712fedf..52f7ceb1cd 100644 --- a/spark/common/src/main/scala/org/apache/sedona/sql/utils/SedonaSQLRegistrator.scala +++ b/spark/common/src/main/scala/org/apache/sedona/sql/utils/SedonaSQLRegistrator.scala @@ -19,6 +19,7 @@ package org.apache.sedona.sql.utils import org.apache.sedona.spark.SedonaContext +import org.apache.sedona.sql.RasterRegistrator import org.apache.sedona.sql.UDF.UdfRegistrator import org.apache.spark.sql.{SQLContext, SparkSession} @@ -44,5 +45,6 @@ object SedonaSQLRegistrator { def dropAll(sparkSession: SparkSession): Unit = { UdfRegistrator.dropAll(sparkSession) + RasterRegistrator.dropAll(sparkSession) } } diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/RasterUdtRegistratorWrapper.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/RasterUdtRegistratorWrapper.scala new file mode 100644 index 0000000000..b4a4e258af --- /dev/null +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/RasterUdtRegistratorWrapper.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.sedona_sql.UDT + +import org.apache.spark.sql.types.UDTRegistration + +object RasterUdtRegistratorWrapper { + + def registerAll(gridClassName: String): Unit = { + UDTRegistration.register(gridClassName, classOf[RasterUDT].getName) + } +} diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/UdtRegistratorWrapper.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/UdtRegistratorWrapper.scala index 127205faf8..a96d15c008 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/UdtRegistratorWrapper.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/UdtRegistratorWrapper.scala @@ -18,26 +18,14 @@ */ package org.apache.spark.sql.sedona_sql.UDT -import org.slf4j.{Logger, LoggerFactory} import org.apache.spark.sql.types.UDTRegistration import org.locationtech.jts.geom.Geometry import org.locationtech.jts.index.SpatialIndex object UdtRegistratorWrapper { - val logger: Logger = LoggerFactory.getLogger(getClass) - def registerAll(): Unit = { UDTRegistration.register(classOf[Geometry].getName, classOf[GeometryUDT].getName) UDTRegistration.register(classOf[SpatialIndex].getName, classOf[IndexUDT].getName) - // Rasters requires geotools which is optional. - val gridClassName = "org.geotools.coverage.grid.GridCoverage2D" - try { - // Trigger an exception if geotools is not found. - java.lang.Class.forName(gridClassName, true, Thread.currentThread().getContextClassLoader) - UDTRegistration.register(gridClassName, classOf[RasterUDT].getName) - } catch { - case e: ClassNotFoundException => logger.warn("Geotools was not found on the classpath. Raster type will not be registered.") - } } }