Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -538,5 +538,10 @@ object Catalog extends AbstractCatalog with Logging {
// are only constructed when registerAll is called and Spark is set up. This lets the
// categorization invariant test access `Catalog.expressions` without bootstrapping Spark.
// Aggregators are only constructed lazily (when registerAll runs and Spark is up); this
// lets the categorization invariant test read `Catalog.expressions` without bootstrapping
// Spark. NOTE: the diff residue kept both the old single-line Seq and the new multi-line
// Seq; this is the merged (post-change) definition only.
lazy val aggregateExpressions: Seq[Aggregator[Geometry, _, _]] =
  Seq(
    new ST_Envelope_Aggr,
    new ST_Extent,
    new ST_Intersection_Aggr,
    new ST_Union_Aggr(),
    new ST_Collect_Agg())
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
package org.apache.spark.sql.sedona_sql.expressions

import org.apache.sedona.common.Functions
import org.apache.sedona.common.geometryObjects.Box2D
import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator
Expand Down Expand Up @@ -164,6 +165,51 @@ private[apache] class ST_Envelope_Aggr
def zero: Option[EnvelopeBuffer] = None
}

/**
 * Return the planar bounding box (Box2D) of all geometries in the given column. Returns NULL when
 * the input contains no rows or all rows are null/empty geometries. Mirrors PostGIS ST_Extent.
 *
 * ST_Envelope_Aggr is left untouched (returns a polygon Geometry) for backwards compatibility.
 */
private[apache] class ST_Extent extends Aggregator[Geometry, Option[EnvelopeBuffer], Box2D] {

  // Encoder for the Box2D result, built once and reused by `outputEncoder`.
  val outputSerde: ExpressionEncoder[Box2D] = ExpressionEncoder[Box2D]()

  /**
   * Fold one input geometry into the running envelope.
   *
   * Null and empty geometries contribute nothing and leave the buffer untouched; that is
   * what makes an all-null / all-empty column finish as NULL rather than an empty box.
   */
  def reduce(buffer: Option[EnvelopeBuffer], input: Geometry): Option[EnvelopeBuffer] = {
    // Expressed as if/else rather than a non-idiomatic early `return`.
    if (input == null || input.isEmpty) {
      buffer
    } else {
      val env = input.getEnvelopeInternal
      val envBuffer = EnvelopeBuffer(env.getMinX, env.getMaxX, env.getMinY, env.getMaxY)
      // First geometry seeds the buffer; later ones are merged into it.
      Some(buffer.fold(envBuffer)(_.merge(envBuffer)))
    }
  }

  /** Combine two partial envelopes; an absent side is ignored. */
  def merge(
      buffer1: Option[EnvelopeBuffer],
      buffer2: Option[EnvelopeBuffer]): Option[EnvelopeBuffer] = {
    (buffer1, buffer2) match {
      case (Some(b1), Some(b2)) => Some(b1.merge(b2))
      // At most one side is defined: keep whichever (if any) is present.
      case _ => buffer1.orElse(buffer2)
    }
  }

  /** Convert the final envelope to a Box2D, or NULL when no non-empty geometry was seen. */
  def finish(reduction: Option[EnvelopeBuffer]): Box2D =
    reduction.map(b => new Box2D(b.minX, b.minY, b.maxX, b.maxY)).orNull

  def bufferEncoder: Encoder[Option[EnvelopeBuffer]] = Encoders.product[Option[EnvelopeBuffer]]

  def outputEncoder: ExpressionEncoder[Box2D] = outputSerde

  // Empty aggregation state: no geometry observed yet.
  def zero: Option[EnvelopeBuffer] = None
}

/**
* Return the polygon intersection of all Polygon in the given column
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
*/
package org.apache.sedona.sql

import org.apache.sedona.common.geometryObjects.Box2D
import org.apache.spark.sql.DataFrame
import org.locationtech.jts.geom.{Coordinate, Geometry, GeometryFactory, Polygon}

Expand Down Expand Up @@ -73,6 +74,48 @@ class aggregateFunctionTestScala extends TestBaseScala {
assert(env.getMaxY == 4.0)
}

it("Passed ST_Extent") {
  val geomDf = sparkSession.sql(
    "SELECT ST_GeomFromWKT(wkt) AS geom FROM VALUES ('POINT (1 2)'), ('POINT (4 5)'), ('LINESTRING (-3 0, 0 0)') AS t(wkt)")
  geomDf.createOrReplaceTempView("t")
  // Bounding box over all three geometries: x in [-3, 4], y in [0, 5].
  val extent =
    sparkSession.sql("SELECT ST_Extent(geom) AS bbox FROM t").head().getAs[Box2D](0)
  assert(extent.getXMin == -3.0)
  assert(extent.getYMin == 0.0)
  assert(extent.getXMax == 4.0)
  assert(extent.getYMax == 5.0)
}

it("ST_Extent returns null over zero rows") {
  // The WHERE clause filters out the only row, producing an empty input relation.
  val zeroRowsDf = sparkSession.sql(
    "SELECT ST_GeomFromWKT(wkt) AS geom FROM VALUES (NULL) AS t(wkt) WHERE wkt IS NOT NULL")
  zeroRowsDf.createOrReplaceTempView("empty_extent")
  // A global aggregate over zero rows still yields one row; its value must be NULL.
  val extentRow = sparkSession.sql("SELECT ST_Extent(geom) FROM empty_extent").head()
  assert(extentRow.isNullAt(0))
}

it("ST_Extent returns null when all inputs are null or empty") {
  // Every row is either a SQL NULL or an empty geometry, so nothing seeds the envelope.
  val allSkippedDf = sparkSession.sql(
    "SELECT ST_GeomFromWKT(wkt) AS geom FROM VALUES (CAST(NULL AS STRING)), ('POINT EMPTY'), ('POLYGON EMPTY') AS t(wkt)")
  allSkippedDf.createOrReplaceTempView("null_extent")
  val extentRow = sparkSession.sql("SELECT ST_Extent(geom) FROM null_extent").head()
  assert(extentRow.isNullAt(0))
}

it("ST_Extent ignores null and empty rows mixed with valid geometries") {
  val withSkippedRowsDf = sparkSession.sql(
    "SELECT ST_GeomFromWKT(wkt) AS geom FROM VALUES (CAST(NULL AS STRING)), ('POINT EMPTY'), ('POINT (10 20)'), ('POINT (-5 -5)') AS t(wkt)")
  withSkippedRowsDf.createOrReplaceTempView("mixed_extent")
  val extentRow = sparkSession.sql("SELECT ST_Extent(geom) FROM mixed_extent").head()
  val extent = extentRow.getAs[Box2D](0)
  // Only the two real points contribute to the bounding box.
  assert(extent.getXMin == -5.0)
  assert(extent.getYMin == -5.0)
  assert(extent.getXMax == 10.0)
  assert(extent.getYMax == 20.0)
}

it("Passed ST_Union_aggr") {

var polygonCsvDf = sparkSession.read
Expand Down
Loading