[SEDONA-478] Sedona 1.5.1 context initialization fails without GeoTools coverage #1394

Merged 2 commits on May 1, 2024
@@ -20,6 +20,7 @@ package org.apache.sedona.spark

import org.apache.sedona.common.utils.TelemetryCollector
import org.apache.sedona.core.serde.SedonaKryoRegistrator
import org.apache.sedona.sql.RasterRegistrator
import org.apache.sedona.sql.UDF.UdfRegistrator
import org.apache.sedona.sql.UDT.UdtRegistrator
import org.apache.spark.serializer.KryoSerializer
@@ -57,6 +58,7 @@ object SedonaContext {
sparkSession.experimental.extraOptimizations ++= Seq(new SpatialFilterPushDownForGeoParquet(sparkSession))
}
addGeoParquetToSupportNestedFilterSources(sparkSession)
RasterRegistrator.registerAll(sparkSession)
UdtRegistrator.registerAll()
UdfRegistrator.registerAll(sparkSession)
sparkSession
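The net effect of the hunk above is that raster registration becomes an explicit, guarded step inside SedonaContext.create, so creating a Sedona session should no longer fail when GeoTools is absent. A minimal sketch of the intended call path, assuming only Spark and sedona-spark are on the classpath (app name, master, and the query are illustrative, not part of the patch):

```scala
import org.apache.sedona.spark.SedonaContext
import org.apache.spark.sql.SparkSession

object CreateWithoutGeoTools {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("sedona-no-geotools")
      // Kryo settings as commonly recommended for Sedona; optional for this sketch.
      .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      .config("spark.kryo.registrator", "org.apache.sedona.core.serde.SedonaKryoRegistrator")
      .getOrCreate()

    // Vector types and UDFs register unconditionally; RasterRegistrator.registerAll
    // only logs a warning and skips raster support when GridCoverage2D cannot be loaded.
    val sedona = SedonaContext.create(spark)
    sedona.sql("SELECT ST_AsText(ST_Point(1.0, 2.0))").show()
  }
}
```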
@@ -0,0 +1,55 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sedona.sql

import org.apache.sedona.sql.UDF.RasterUdafCatalog
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.sedona_sql.UDT.RasterUdtRegistratorWrapper
import org.apache.spark.sql.{SparkSession, functions}
import org.slf4j.{Logger, LoggerFactory}

object RasterRegistrator {
val logger: Logger = LoggerFactory.getLogger(getClass)
private val gridClassName = "org.geotools.coverage.grid.GridCoverage2D"

// Helper method to check if GridCoverage2D is available
private def isGeoToolsAvailable: Boolean = {
try {
Class.forName(gridClassName, true, Thread.currentThread().getContextClassLoader)
true
} catch {
case _: ClassNotFoundException =>
logger.warn("Geotools was not found on the classpath. Raster operations will not be available.")
false
}
}

def registerAll(sparkSession: SparkSession): Unit = {
if (isGeoToolsAvailable) {
RasterUdtRegistratorWrapper.registerAll(gridClassName)
sparkSession.udf.register(RasterUdafCatalog.rasterAggregateExpression.getClass.getSimpleName, functions.udaf(RasterUdafCatalog.rasterAggregateExpression))
}
}

def dropAll(sparkSession: SparkSession): Unit = {
if (isGeoToolsAvailable) {
sparkSession.sessionState.functionRegistry.dropFunction(FunctionIdentifier(RasterUdafCatalog.rasterAggregateExpression.getClass.getSimpleName))
}
}
}
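For readers unfamiliar with the pattern, isGeoToolsAvailable is a standard reflective optional-dependency probe: resolve the class by name and treat ClassNotFoundException as "feature unavailable". A stand-alone sketch of the same idea (the object and method names below are illustrative, not part of the patch):

```scala
import org.slf4j.LoggerFactory

// Hypothetical helper mirroring the guard used in RasterRegistrator above.
object OptionalDependency {
  private val logger = LoggerFactory.getLogger(getClass)

  /** True only if `className` resolves against the context classloader. */
  def isAvailable(className: String): Boolean =
    try {
      Class.forName(className, true, Thread.currentThread().getContextClassLoader)
      true
    } catch {
      case _: ClassNotFoundException | _: NoClassDefFoundError =>
        logger.warn(s"$className not found on the classpath; dependent features are disabled.")
        false
    }
}

// Example probe for the same class the patch checks:
// OptionalDependency.isAvailable("org.geotools.coverage.grid.GridCoverage2D")
```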
@@ -22,14 +22,12 @@ import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionInfo, Literal}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.sedona_sql.expressions._
import org.apache.spark.sql.sedona_sql.expressions.collect.ST_Collect
import org.apache.spark.sql.sedona_sql.expressions.raster._
import org.apache.spark.sql.sedona_sql.expressions._
import org.geotools.coverage.grid.GridCoverage2D
import org.locationtech.jts.geom.Geometry
import org.locationtech.jts.operation.buffer.BufferParameters

import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag

object Catalog {
@@ -285,8 +283,6 @@
function[RS_NetCDFInfo]()
)

val rasterAggregateExpression: Aggregator[(GridCoverage2D, Int), ArrayBuffer[BandData], GridCoverage2D] = new RS_Union_Aggr

val aggregateExpressions: Seq[Aggregator[Geometry, Geometry, Geometry]] = Seq(
new ST_Union_Aggr,
new ST_Envelope_Aggr,
@@ -0,0 +1,29 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sedona.sql.UDF

import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.sedona_sql.expressions.raster.{BandData, RS_Union_Aggr}
import org.geotools.coverage.grid.GridCoverage2D

import scala.collection.mutable.ArrayBuffer

object RasterUdafCatalog {
val rasterAggregateExpression: Aggregator[(GridCoverage2D, Int), ArrayBuffer[BandData], GridCoverage2D] = new RS_Union_Aggr
}
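Moving the Aggregator into RasterUdafCatalog changes where the GeoTools-typed value lives, not how the aggregate is called: it is still registered under its class simple name, RS_Union_Aggr. A hedged usage sketch, assuming GeoTools is on the classpath and registration succeeded (the table and column names are made up):

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}

object RasterUnionSketch {
  // `rasters` is a hypothetical view with a raster column `rast` and an integer `band_index`.
  def unionAll(sedona: SparkSession): DataFrame =
    sedona.sql("SELECT RS_Union_Aggr(rast, band_index) AS merged FROM rasters")
}
```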
@@ -37,7 +37,6 @@ object UdfRegistrator {
}
Catalog.aggregateExpressions.foreach(f => sparkSession.udf.register(f.getClass.getSimpleName, functions.udaf(f))) // SPARK3 anchor
//Catalog.aggregateExpressions_UDAF.foreach(f => sparkSession.udf.register(f.getClass.getSimpleName, f)) // SPARK2 anchor
sparkSession.udf.register(Catalog.rasterAggregateExpression.getClass.getSimpleName, functions.udaf(Catalog.rasterAggregateExpression))
}

def dropAll(sparkSession: SparkSession): Unit = {
@@ -46,6 +45,5 @@ Catalog.aggregateExpressions.foreach(f => sparkSession.udf.register(f.getClass.g
}
Catalog.aggregateExpressions.foreach(f => sparkSession.sessionState.functionRegistry.dropFunction(FunctionIdentifier(f.getClass.getSimpleName))) // SPARK3 anchor
//Catalog.aggregateExpressions_UDAF.foreach(f => sparkSession.sessionState.functionRegistry.dropFunction(FunctionIdentifier(f.getClass.getSimpleName))) // SPARK2 anchor
sparkSession.sessionState.functionRegistry.dropFunction(FunctionIdentifier(Catalog.rasterAggregateExpression.getClass.getSimpleName))
}
}
@@ -19,6 +19,7 @@
package org.apache.sedona.sql.utils

import org.apache.sedona.spark.SedonaContext
import org.apache.sedona.sql.RasterRegistrator
import org.apache.sedona.sql.UDF.UdfRegistrator
import org.apache.spark.sql.{SQLContext, SparkSession}

@@ -44,5 +45,6 @@ object SedonaSQLRegistrator {

def dropAll(sparkSession: SparkSession): Unit = {
UdfRegistrator.dropAll(sparkSession)
RasterRegistrator.dropAll(sparkSession)
}
}
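With registerAll and dropAll both delegating to RasterRegistrator, registration and teardown stay symmetric whether or not GeoTools is present. A small sketch of the call pattern this enables (the helper name is hypothetical):

```scala
import org.apache.sedona.sql.utils.SedonaSQLRegistrator
import org.apache.spark.sql.SparkSession

object SedonaLifecycle {
  /** Registers Sedona SQL functions around `body`, then drops them again. */
  def withSedonaSql[T](spark: SparkSession)(body: SparkSession => T): T = {
    SedonaSQLRegistrator.registerAll(spark)     // raster pieces register only if GeoTools resolves
    try body(spark)
    finally SedonaSQLRegistrator.dropAll(spark) // now also drops the raster UDAF when it was registered
  }
}
```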
@@ -0,0 +1,28 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.spark.sql.sedona_sql.UDT

import org.apache.spark.sql.types.UDTRegistration

object RasterUdtRegistratorWrapper {

def registerAll(gridClassName: String): Unit = {
UDTRegistration.register(gridClassName, classOf[RasterUDT].getName)
}
}
@@ -18,26 +18,14 @@
*/
package org.apache.spark.sql.sedona_sql.UDT

import org.slf4j.{Logger, LoggerFactory}
import org.apache.spark.sql.types.UDTRegistration
import org.locationtech.jts.geom.Geometry
import org.locationtech.jts.index.SpatialIndex

object UdtRegistratorWrapper {

val logger: Logger = LoggerFactory.getLogger(getClass)

def registerAll(): Unit = {
UDTRegistration.register(classOf[Geometry].getName, classOf[GeometryUDT].getName)
UDTRegistration.register(classOf[SpatialIndex].getName, classOf[IndexUDT].getName)
// Rasters requires geotools which is optional.
val gridClassName = "org.geotools.coverage.grid.GridCoverage2D"
try {
// Trigger an exception if geotools is not found.
java.lang.Class.forName(gridClassName, true, Thread.currentThread().getContextClassLoader)
UDTRegistration.register(gridClassName, classOf[RasterUDT].getName)
} catch {
case e: ClassNotFoundException => logger.warn("Geotools was not found on the classpath. Raster type will not be registered.")
}
}
}