From 385d4ba832f195279417c5c4321f4ac04c817310 Mon Sep 17 00:00:00 2001 From: Kristin Cowalcijk Date: Wed, 28 Feb 2024 11:38:15 +0800 Subject: [PATCH] [SEDONA-506] Lenient mode for RS_ZonalStats and RS_ZonalStatsAll (#1257) * Fix handling of geometries with SRID=0 in some raster functions * Add a lenient mode for RS_ZonalStats and RS_ZonalStatsAll * Change the default value of lenient to true for RS_ZonalStats and RS_ZonalStatsAll --- .../common/raster/RasterBandAccessors.java | 50 +++++++++++++++---- .../raster/RasterBandAccessorsTest.java | 18 +++++++ docs/api/sql/Raster-operators.md | 16 +++++- .../raster/RasterBandAccessors.scala | 6 ++- .../apache/sedona/sql/rasteralgebraTest.scala | 8 +++ 5 files changed, 83 insertions(+), 15 deletions(-) diff --git a/common/src/main/java/org/apache/sedona/common/raster/RasterBandAccessors.java b/common/src/main/java/org/apache/sedona/common/raster/RasterBandAccessors.java index dc5bd38324..880ca803ef 100644 --- a/common/src/main/java/org/apache/sedona/common/raster/RasterBandAccessors.java +++ b/common/src/main/java/org/apache/sedona/common/raster/RasterBandAccessors.java @@ -18,7 +18,6 @@ */ package org.apache.sedona.common.raster; -import org.apache.commons.lang3.tuple.Pair; import org.apache.commons.math3.stat.StatUtils; import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics; import org.apache.commons.math3.stat.descriptive.moment.StandardDeviation; @@ -36,6 +35,7 @@ import java.util.Arrays; import java.util.HashMap; import java.util.List; +import java.util.Objects; public class RasterBandAccessors { @@ -94,11 +94,15 @@ public static long getCount(GridCoverage2D raster, int band) { * @param roi Geometry to define the region of interest * @param band Band to be used for computation * @param excludeNoData Specifies whether to exclude no-data value or not + * @param lenient Return null if the raster and roi do not intersect when set to true, otherwise will throw an exception * @return An array with all the stats for the region * @throws FactoryException */ - public static double[] getZonalStatsAll(GridCoverage2D raster, Geometry roi, int band, boolean excludeNoData) throws FactoryException { - List objects = getStatObjects(raster, roi, band, excludeNoData); + public static double[] getZonalStatsAll(GridCoverage2D raster, Geometry roi, int band, boolean excludeNoData, boolean lenient) throws FactoryException { + List objects = getStatObjects(raster, roi, band, excludeNoData, lenient); + if (objects == null) { + return null; + } DescriptiveStatistics stats = (DescriptiveStatistics) objects.get(0); double[] pixelData = (double[]) objects.get(1); @@ -118,6 +122,18 @@ public static double[] getZonalStatsAll(GridCoverage2D raster, Geometry roi, int return result; } + /** + * @param raster Raster to use for computing stats + * @param roi Geometry to define the region of interest + * @param band Band to be used for computation + * @param excludeNoData Specifies whether to exclude no-data value or not + * @return An array with all the stats for the region + * @throws FactoryException + */ + public static double[] getZonalStatsAll(GridCoverage2D raster, Geometry roi, int band, boolean excludeNoData) throws FactoryException { + return getZonalStatsAll(raster, roi, band, excludeNoData, true); + } + /** * @param raster Raster to use for computing stats * @param roi Geometry to define the region of interest @@ -145,12 +161,15 @@ public static double[] getZonalStatsAll(GridCoverage2D raster, Geometry roi) thr * @param band Band to be used for computation * @param statType Define the statistic to be computed * @param excludeNoData Specifies whether to exclude no-data value or not + * @param lenient Return null if the raster and roi do not intersect when set to true, otherwise will throw an exception * @return A double precision floating point number representing the requested statistic calculated over the specified region. * @throws FactoryException */ - public static double getZonalStats(GridCoverage2D raster, Geometry roi, int band, String statType, boolean excludeNoData) throws FactoryException { - - List objects = getStatObjects(raster, roi, band, excludeNoData); + public static Double getZonalStats(GridCoverage2D raster, Geometry roi, int band, String statType, boolean excludeNoData, boolean lenient) throws FactoryException { + List objects = getStatObjects(raster, roi, band, excludeNoData, lenient); + if (objects == null) { + return null; + } DescriptiveStatistics stats = (DescriptiveStatistics) objects.get(0); double[] pixelData = (double[]) objects.get(1); @@ -162,7 +181,7 @@ public static double getZonalStats(GridCoverage2D raster, Geometry roi, int band case "mean": return stats.getMean(); case "count": - return stats.getN(); + return (double) stats.getN(); case "max": return stats.getMax(); case "min": @@ -181,6 +200,10 @@ public static double getZonalStats(GridCoverage2D raster, Geometry roi, int band } } + public static Double getZonalStats(GridCoverage2D raster, Geometry roi, int band, String statType, boolean excludeNoData) throws FactoryException { + return getZonalStats(raster, roi, band, statType, excludeNoData, true); + } + /** * @param raster Raster to use for computing stats * @param roi Geometry to define the region of interest @@ -189,7 +212,7 @@ public static double getZonalStats(GridCoverage2D raster, Geometry roi, int band * @return A double precision floating point number representing the requested statistic calculated over the specified region. The excludeNoData is set to true. * @throws FactoryException */ - public static double getZonalStats(GridCoverage2D raster, Geometry roi, int band, String statType) throws FactoryException { + public static Double getZonalStats(GridCoverage2D raster, Geometry roi, int band, String statType) throws FactoryException { return getZonalStats(raster, roi, band, statType, true); } @@ -200,7 +223,7 @@ public static double getZonalStats(GridCoverage2D raster, Geometry roi, int band * @return A double precision floating point number representing the requested statistic calculated over the specified region. The excludeNoData is set to true and band is set to 1. * @throws FactoryException */ - public static double getZonalStats(GridCoverage2D raster, Geometry roi, String statType) throws FactoryException { + public static Double getZonalStats(GridCoverage2D raster, Geometry roi, String statType) throws FactoryException { return getZonalStats(raster, roi, 1, statType, true); } @@ -219,10 +242,11 @@ private static double zonalMode(double[] pixelData) { * @param roi Geometry to define the region of interest * @param band Band to be used for computation * @param excludeNoData Specifies whether to exclude no-data value or not + * @param lenient Return null if the raster and roi do not intersect when set to true, otherwise will throw an exception * @return an object of DescriptiveStatistics and an array of double with pixel data. * @throws FactoryException */ - private static List getStatObjects(GridCoverage2D raster, Geometry roi, int band, boolean excludeNoData) throws FactoryException { + private static List getStatObjects(GridCoverage2D raster, Geometry roi, int band, boolean excludeNoData, boolean lenient) throws FactoryException { RasterUtils.ensureBand(raster, band); if(RasterAccessors.srid(raster) != roi.getSRID()) { @@ -234,7 +258,11 @@ private static List getStatObjects(GridCoverage2D raster, Geometry roi, // checking if the raster contains the geometry if (!RasterPredicates.rsIntersects(raster, roi)) { - throw new IllegalArgumentException("The provided geometry is not intersecting the raster. Please provide a geometry that is in the raster's extent."); + if (lenient) { + return null; + } else { + throw new IllegalArgumentException("The provided geometry is not intersecting the raster. Please provide a geometry that is in the raster's extent."); + } } Raster rasterData = RasterUtils.getRaster(raster.getRenderedImage()); diff --git a/common/src/test/java/org/apache/sedona/common/raster/RasterBandAccessorsTest.java b/common/src/test/java/org/apache/sedona/common/raster/RasterBandAccessorsTest.java index 1a39862300..17fedd854c 100644 --- a/common/src/test/java/org/apache/sedona/common/raster/RasterBandAccessorsTest.java +++ b/common/src/test/java/org/apache/sedona/common/raster/RasterBandAccessorsTest.java @@ -114,6 +114,15 @@ public void testZonalStats() throws FactoryException, ParseException, IOExceptio actual = RasterBandAccessors.getZonalStats(raster, geom, 1, "sd", false); expected = 92.1327; assertEquals(expected, actual, FP_TOLERANCE); + + geom = Constructors.geomFromWKT("POLYGON ((-77.96672569800863073 37.91971182746296876, -77.9688630154902711 37.89620133516485367, -77.93936803424354309 37.90517806858776595, -77.96672569800863073 37.91971182746296876))", 0); + Double statValue = RasterBandAccessors.getZonalStats(raster, geom, 1, "sum", false, true); + assertNotNull(statValue); + + Geometry nonIntersectingGeom = Constructors.geomFromWKT("POLYGON ((-78.22106647832458748 37.76411511479908967, -78.20183062098976734 37.72863564460374874, -78.18088490966962922 37.76753482276972562, -78.22106647832458748 37.76411511479908967))", 0); + statValue = RasterBandAccessors.getZonalStats(raster, nonIntersectingGeom, 1, "sum", false, true); + assertNull(statValue); + assertThrows(IllegalArgumentException.class, () -> RasterBandAccessors.getZonalStats(raster, nonIntersectingGeom, 1, "sum", false, false)); } @Test @@ -161,6 +170,15 @@ public void testZonalStatsAll() throws IOException, FactoryException, ParseExcep double[] actual = RasterBandAccessors.getZonalStatsAll(raster, geom, 1, false); double[] expected = new double[] {184792.0, 1.0690406E7, 57.8510, 0.0, 0.0, 92.1327, 8488.4480, 0.0, 255.0}; assertArrayEquals(expected, actual, FP_TOLERANCE); + + geom = Constructors.geomFromWKT("POLYGON ((-77.96672569800863073 37.91971182746296876, -77.9688630154902711 37.89620133516485367, -77.93936803424354309 37.90517806858776595, -77.96672569800863073 37.91971182746296876))", 0); + actual = RasterBandAccessors.getZonalStatsAll(raster, geom, 1, false); + assertNotNull(actual); + + Geometry nonIntersectingGeom = Constructors.geomFromWKT("POLYGON ((-78.22106647832458748 37.76411511479908967, -78.20183062098976734 37.72863564460374874, -78.18088490966962922 37.76753482276972562, -78.22106647832458748 37.76411511479908967))", 0); + actual = RasterBandAccessors.getZonalStatsAll(raster, nonIntersectingGeom, 1, false, true); + assertNull(actual); + assertThrows(IllegalArgumentException.class, () -> RasterBandAccessors.getZonalStatsAll(raster, nonIntersectingGeom, 1, false, false)); } @Test diff --git a/docs/api/sql/Raster-operators.md b/docs/api/sql/Raster-operators.md index 6f78a3e8e8..4e1d57e8ce 100644 --- a/docs/api/sql/Raster-operators.md +++ b/docs/api/sql/Raster-operators.md @@ -1095,11 +1095,17 @@ Introduction: This returns a statistic value specified by `statType` over the re The following conditions will throw an `IllegalArgumentException` if they are not met: - - The provided `raster` and `zone` geometry should intersect. + - The provided `raster` and `zone` geometry should intersect when `lenient` parameter is set to `false`. - The option provided to `statType` should be valid. + `lenient` parameter is set to `true` by default. The function will return `null` if the `raster` and `zone` geometry do not intersect. + Format: +``` +RS_ZonalStats(raster: Raster, zone: Geometry, band: Integer, statType: String, excludeNoData: Boolean, lenient: Boolean) +``` + ``` RS_ZonalStats(raster: Raster, zone: Geometry, band: Integer, statType: String, excludeNoData: Boolean) ``` @@ -1157,11 +1163,17 @@ Introduction: Returns an array of statistic values, where each statistic is comp The following conditions will throw an `IllegalArgumentException` if they are not met: - - The provided `raster` and `zone` geometry should intersect. + - The provided `raster` and `zone` geometry should intersect when `lenient` parameter is set to `false`. - The option provided to `statType` should be valid. + `lenient` parameter is set to `true` by default. The function will return `null` if the `raster` and `zone` geometry do not intersect. + Format: +``` +RS_ZonalStatsAll(raster: Raster, zone: Geometry, band: Integer, excludeNodata: Boolean, lenient: Boolean) +``` + ``` RS_ZonalStatsAll(raster: Raster, zone: Geometry, band: Integer, excludeNodata: Boolean) ``` diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterBandAccessors.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterBandAccessors.scala index 6405e841cf..11b4152401 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterBandAccessors.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterBandAccessors.scala @@ -44,7 +44,9 @@ case class RS_Count(inputExpressions: Seq[Expression]) extends InferredExpressio } case class RS_ZonalStats(inputExpressions: Seq[Expression]) extends InferredExpression( - inferrableFunction5(RasterBandAccessors.getZonalStats), inferrableFunction4(RasterBandAccessors.getZonalStats), + inferrableFunction6(RasterBandAccessors.getZonalStats), + inferrableFunction5(RasterBandAccessors.getZonalStats), + inferrableFunction4(RasterBandAccessors.getZonalStats), inferrableFunction3(RasterBandAccessors.getZonalStats) ) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { @@ -54,7 +56,7 @@ case class RS_ZonalStats(inputExpressions: Seq[Expression]) extends InferredExpr case class RS_ZonalStatsAll(inputExpressions: Seq[Expression]) extends InferredExpression( inferrableFunction2(RasterBandAccessors.getZonalStatsAll), inferrableFunction4(RasterBandAccessors.getZonalStatsAll), - inferrableFunction3(RasterBandAccessors.getZonalStatsAll) + inferrableFunction3(RasterBandAccessors.getZonalStatsAll), inferrableFunction5(RasterBandAccessors.getZonalStatsAll) ) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/rasteralgebraTest.scala b/spark/common/src/test/scala/org/apache/sedona/sql/rasteralgebraTest.scala index c0b1ba56a3..e33d83626d 100644 --- a/spark/common/src/test/scala/org/apache/sedona/sql/rasteralgebraTest.scala +++ b/spark/common/src/test/scala/org/apache/sedona/sql/rasteralgebraTest.scala @@ -979,6 +979,10 @@ class rasteralgebraTest extends TestBaseScala with BeforeAndAfter with GivenWhen // Test with a polygon in EPSG:4326 actual = df.selectExpr("RS_ZonalStats(raster, ST_GeomFromWKT('POLYGON ((-77.96672569800863073 37.91971182746296876, -77.9688630154902711 37.89620133516485367, -77.93936803424354309 37.90517806858776595, -77.96672569800863073 37.91971182746296876))'), 1, 'mean', false)").first().get(0) assertNotNull(actual) + + // Test with a polygon that does not intersect the raster in lenient mode + actual = df.selectExpr("RS_ZonalStats(raster, ST_GeomFromWKT('POLYGON ((-78.22106647832458748 37.76411511479908967, -78.20183062098976734 37.72863564460374874, -78.18088490966962922 37.76753482276972562, -78.22106647832458748 37.76411511479908967))'), 1, 'mean', false)").first().get(0) + assertNull(actual) } it("Passed RS_ZonalStats - Raster with no data") { @@ -1006,6 +1010,10 @@ class rasteralgebraTest extends TestBaseScala with BeforeAndAfter with GivenWhen val actual = df.selectExpr("RS_ZonalStatsAll(raster, geom, 1, true)").first().get(0) val expected = Seq(184792.0, 1.0690406E7, 57.851021689230684, 0.0, 0.0, 92.13277429243035, 8488.448098819916, 0.0, 255.0) assertTrue(expected.equals(actual)) + + // Test with a polygon that does not intersect the raster in lenient mode + val actual2 = df.selectExpr("RS_ZonalStatsAll(raster, ST_GeomFromWKT('POLYGON ((-78.22106647832458748 37.76411511479908967, -78.20183062098976734 37.72863564460374874, -78.18088490966962922 37.76753482276972562, -78.22106647832458748 37.76411511479908967))'), 1, false)").first().get(0) + assertNull(actual2) } it("Passed RS_ZonalStatsAll - Raster with no data") {