[SEDONA-478] Sedona 1.5.1 context initialization fails without GeoTools coverage (#1394)

* Make Raster optional

* UDT registrator is not accessible in Spark 3.1 and before
jiayuasu committed May 1, 2024
1 parent 872ecf3 commit 883fd8e
Showing 8 changed files with 117 additions and 19 deletions.
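A minimal usage sketch of what this fix enables, assuming a local Spark session and no geotools-wrapper jar on the classpath (the session setup below is illustrative, not part of the commit):

import org.apache.sedona.spark.SedonaContext
import org.apache.spark.sql.SparkSession

// With raster registration now optional, creating the Sedona context no longer
// throws when the GeoTools coverage classes are missing; raster functions are skipped.
val spark = SparkSession.builder().master("local[*]").appName("sedona-no-geotools").getOrCreate()
val sedona = SedonaContext.create(spark)

// Vector functions remain registered and usable.
sedona.sql("SELECT ST_Point(1.0, 2.0)").show()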
@@ -20,6 +20,7 @@ package org.apache.sedona.spark

import org.apache.sedona.common.utils.TelemetryCollector
import org.apache.sedona.core.serde.SedonaKryoRegistrator
import org.apache.sedona.sql.RasterRegistrator
import org.apache.sedona.sql.UDF.UdfRegistrator
import org.apache.sedona.sql.UDT.UdtRegistrator
import org.apache.spark.serializer.KryoSerializer
@@ -57,6 +58,7 @@ object SedonaContext {
sparkSession.experimental.extraOptimizations ++= Seq(new SpatialFilterPushDownForGeoParquet(sparkSession))
}
addGeoParquetToSupportNestedFilterSources(sparkSession)
RasterRegistrator.registerAll(sparkSession)
UdtRegistrator.registerAll()
UdfRegistrator.registerAll(sparkSession)
sparkSession
@@ -0,0 +1,55 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sedona.sql

import org.apache.sedona.sql.UDF.RasterUdafCatalog
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.sedona_sql.UDT.RasterUdtRegistratorWrapper
import org.apache.spark.sql.{SparkSession, functions}
import org.slf4j.{Logger, LoggerFactory}

object RasterRegistrator {
val logger: Logger = LoggerFactory.getLogger(getClass)
private val gridClassName = "org.geotools.coverage.grid.GridCoverage2D"

// Helper method to check if GridCoverage2D is available
private def isGeoToolsAvailable: Boolean = {
try {
Class.forName(gridClassName, true, Thread.currentThread().getContextClassLoader)
true
} catch {
case _: ClassNotFoundException =>
logger.warn("Geotools was not found on the classpath. Raster operations will not be available.")
false
}
}

def registerAll(sparkSession: SparkSession): Unit = {
if (isGeoToolsAvailable) {
RasterUdtRegistratorWrapper.registerAll(gridClassName)
sparkSession.udf.register(RasterUdafCatalog.rasterAggregateExpression.getClass.getSimpleName, functions.udaf(RasterUdafCatalog.rasterAggregateExpression))
}
}

def dropAll(sparkSession: SparkSession): Unit = {
if (isGeoToolsAvailable) {
sparkSession.sessionState.functionRegistry.dropFunction(FunctionIdentifier(RasterUdafCatalog.rasterAggregateExpression.getClass.getSimpleName))
}
}
}
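Conversely, when geotools-wrapper is on the classpath, the guarded registration keeps the previous behavior. A hedged sketch (the table and column names are made up, and `sedona` is the session from the earlier sketch):

// GridCoverage2D is loadable, so RasterRegistrator registers the RS_Union_Aggr
// aggregate and raster SQL continues to work as in earlier releases.
sedona.sql(
  """
    |SELECT RS_Union_Aggr(rast, band) AS merged
    |FROM raster_tiles
    |""".stripMargin).show()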
@@ -22,14 +22,12 @@ import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionInfo, Literal}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.sedona_sql.expressions._
import org.apache.spark.sql.sedona_sql.expressions.collect.ST_Collect
import org.apache.spark.sql.sedona_sql.expressions.raster._
import org.apache.spark.sql.sedona_sql.expressions._
import org.geotools.coverage.grid.GridCoverage2D
import org.locationtech.jts.geom.Geometry
import org.locationtech.jts.operation.buffer.BufferParameters

import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag

object Catalog {
@@ -285,8 +283,6 @@ object Catalog {
function[RS_NetCDFInfo]()
)

val rasterAggregateExpression: Aggregator[(GridCoverage2D, Int), ArrayBuffer[BandData], GridCoverage2D] = new RS_Union_Aggr

val aggregateExpressions: Seq[Aggregator[Geometry, Geometry, Geometry]] = Seq(
new ST_Union_Aggr,
new ST_Envelope_Aggr,
@@ -0,0 +1,29 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sedona.sql.UDF

import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.sedona_sql.expressions.raster.{BandData, RS_Union_Aggr}
import org.geotools.coverage.grid.GridCoverage2D

import scala.collection.mutable.ArrayBuffer

object RasterUdafCatalog {
val rasterAggregateExpression: Aggregator[(GridCoverage2D, Int), ArrayBuffer[BandData], GridCoverage2D] = new RS_Union_Aggr
}
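Splitting the GridCoverage2D-typed aggregator out of Catalog into this small holder object is presumably what makes the dependency truly optional: a Scala object is initialized only when first referenced, and RasterUdafCatalog is now referenced only behind the isGeoToolsAvailable guard in RasterRegistrator. A self-contained sketch of that deferral principle (names are hypothetical, not from the commit):

object OptionalHolder {
  // The right-hand side runs only when OptionalHolder is first referenced.
  // If this value's type lived in a missing optional jar, the linkage error
  // would surface at first use rather than at application start-up.
  val expensive: java.math.BigInteger = java.math.BigInteger.valueOf(42L)
}

// Nothing in OptionalHolder is initialized until this line executes.
println(OptionalHolder.expensive)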
@@ -37,7 +37,6 @@ object UdfRegistrator {
}
Catalog.aggregateExpressions.foreach(f => sparkSession.udf.register(f.getClass.getSimpleName, functions.udaf(f))) // SPARK3 anchor
//Catalog.aggregateExpressions_UDAF.foreach(f => sparkSession.udf.register(f.getClass.getSimpleName, f)) // SPARK2 anchor
sparkSession.udf.register(Catalog.rasterAggregateExpression.getClass.getSimpleName, functions.udaf(Catalog.rasterAggregateExpression))
}

def dropAll(sparkSession: SparkSession): Unit = {
@@ -46,6 +45,5 @@ Catalog.aggregateExpressions.foreach(f => sparkSession.udf.register(f.getClass.g
}
Catalog.aggregateExpressions.foreach(f => sparkSession.sessionState.functionRegistry.dropFunction(FunctionIdentifier(f.getClass.getSimpleName))) // SPARK3 anchor
//Catalog.aggregateExpressions_UDAF.foreach(f => sparkSession.sessionState.functionRegistry.dropFunction(FunctionIdentifier(f.getClass.getSimpleName))) // SPARK2 anchor
sparkSession.sessionState.functionRegistry.dropFunction(FunctionIdentifier(Catalog.rasterAggregateExpression.getClass.getSimpleName))
}
}
@@ -19,6 +19,7 @@
package org.apache.sedona.sql.utils

import org.apache.sedona.spark.SedonaContext
import org.apache.sedona.sql.RasterRegistrator
import org.apache.sedona.sql.UDF.UdfRegistrator
import org.apache.spark.sql.{SQLContext, SparkSession}

@@ -44,5 +45,6 @@ object SedonaSQLRegistrator {

def dropAll(sparkSession: SparkSession): Unit = {
UdfRegistrator.dropAll(sparkSession)
RasterRegistrator.dropAll(sparkSession)
}
}
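A small teardown sketch (assuming the `spark` session from the first sketch): dropping all Sedona registrations now also removes the raster aggregate when GeoTools is present, and silently skips it otherwise.

import org.apache.sedona.sql.utils.SedonaSQLRegistrator

// RasterRegistrator.dropAll is a no-op when GridCoverage2D is not on the classpath.
SedonaSQLRegistrator.dropAll(spark)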
@@ -0,0 +1,28 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.spark.sql.sedona_sql.UDT

import org.apache.spark.sql.types.UDTRegistration

object RasterUdtRegistratorWrapper {

def registerAll(gridClassName: String): Unit = {
UDTRegistration.register(gridClassName, classOf[RasterUDT].getName)
}
}
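This wrapper sits under org.apache.spark.sql rather than org.apache.sedona, presumably because UDTRegistration is a Spark-internal API that Spark 3.1 and earlier do not expose publicly (the second bullet of the commit message); code compiled inside an org.apache.spark package can still reach it. An illustrative sketch of that pattern, with hypothetical package and object names:

package org.apache.spark.sql.myproject

import org.apache.spark.sql.types.UDTRegistration

// Living inside an org.apache.spark.* package lets this object call
// Spark APIs that are package-private on older Spark releases.
object InternalApiBridge {
  def register(userClass: String, udtClass: String): Unit =
    UDTRegistration.register(userClass, udtClass)
}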
@@ -18,26 +18,14 @@
*/
package org.apache.spark.sql.sedona_sql.UDT

import org.slf4j.{Logger, LoggerFactory}
import org.apache.spark.sql.types.UDTRegistration
import org.locationtech.jts.geom.Geometry
import org.locationtech.jts.index.SpatialIndex

object UdtRegistratorWrapper {

val logger: Logger = LoggerFactory.getLogger(getClass)

def registerAll(): Unit = {
UDTRegistration.register(classOf[Geometry].getName, classOf[GeometryUDT].getName)
UDTRegistration.register(classOf[SpatialIndex].getName, classOf[IndexUDT].getName)
// Rasters requires geotools which is optional.
val gridClassName = "org.geotools.coverage.grid.GridCoverage2D"
try {
// Trigger an exception if geotools is not found.
java.lang.Class.forName(gridClassName, true, Thread.currentThread().getContextClassLoader)
UDTRegistration.register(gridClassName, classOf[RasterUDT].getName)
} catch {
case e: ClassNotFoundException => logger.warn("Geotools was not found on the classpath. Raster type will not be registered.")
}
}
}
