Skip to content

Commit

Permalink
[SEDONA-506] Lenient mode for RS_ZonalStats and RS_ZonalStatsAll (#1257)
Browse files Browse the repository at this point in the history
* Fix handling of geometries with SRID=0 in some raster functions

* Add a lenient mode for RS_ZonalStats and RS_ZonalStatsAll

* Change the default value of lenient to true for RS_ZonalStats and RS_ZonalStatsAll
  • Loading branch information
Kontinuation authored and jiayuasu committed Apr 28, 2024
1 parent d52dabb commit 385d4ba
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
*/
package org.apache.sedona.common.raster;

import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.math3.stat.StatUtils;
import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics;
import org.apache.commons.math3.stat.descriptive.moment.StandardDeviation;
Expand All @@ -36,6 +35,7 @@
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Objects;

public class RasterBandAccessors {

Expand Down Expand Up @@ -94,11 +94,15 @@ public static long getCount(GridCoverage2D raster, int band) {
* @param roi Geometry to define the region of interest
* @param band Band to be used for computation
* @param excludeNoData Specifies whether to exclude no-data value or not
* @param lenient Return null if the raster and roi do not intersect when set to true, otherwise will throw an exception
* @return An array with all the stats for the region
* @throws FactoryException
*/
public static double[] getZonalStatsAll(GridCoverage2D raster, Geometry roi, int band, boolean excludeNoData) throws FactoryException {
List<Object> objects = getStatObjects(raster, roi, band, excludeNoData);
public static double[] getZonalStatsAll(GridCoverage2D raster, Geometry roi, int band, boolean excludeNoData, boolean lenient) throws FactoryException {
List<Object> objects = getStatObjects(raster, roi, band, excludeNoData, lenient);
if (objects == null) {
return null;
}
DescriptiveStatistics stats = (DescriptiveStatistics) objects.get(0);
double[] pixelData = (double[]) objects.get(1);

Expand All @@ -118,6 +122,18 @@ public static double[] getZonalStatsAll(GridCoverage2D raster, Geometry roi, int
return result;
}

/**
* @param raster Raster to use for computing stats
* @param roi Geometry to define the region of interest
* @param band Band to be used for computation
* @param excludeNoData Specifies whether to exclude no-data value or not
* @return An array with all the stats for the region
* @throws FactoryException
*/
public static double[] getZonalStatsAll(GridCoverage2D raster, Geometry roi, int band, boolean excludeNoData) throws FactoryException {
return getZonalStatsAll(raster, roi, band, excludeNoData, true);
}

/**
* @param raster Raster to use for computing stats
* @param roi Geometry to define the region of interest
Expand Down Expand Up @@ -145,12 +161,15 @@ public static double[] getZonalStatsAll(GridCoverage2D raster, Geometry roi) thr
* @param band Band to be used for computation
* @param statType Define the statistic to be computed
* @param excludeNoData Specifies whether to exclude no-data value or not
* @param lenient Return null if the raster and roi do not intersect when set to true, otherwise will throw an exception
* @return A double precision floating point number representing the requested statistic calculated over the specified region.
* @throws FactoryException
*/
public static double getZonalStats(GridCoverage2D raster, Geometry roi, int band, String statType, boolean excludeNoData) throws FactoryException {

List<Object> objects = getStatObjects(raster, roi, band, excludeNoData);
public static Double getZonalStats(GridCoverage2D raster, Geometry roi, int band, String statType, boolean excludeNoData, boolean lenient) throws FactoryException {
List<Object> objects = getStatObjects(raster, roi, band, excludeNoData, lenient);
if (objects == null) {
return null;
}
DescriptiveStatistics stats = (DescriptiveStatistics) objects.get(0);
double[] pixelData = (double[]) objects.get(1);

Expand All @@ -162,7 +181,7 @@ public static double getZonalStats(GridCoverage2D raster, Geometry roi, int band
case "mean":
return stats.getMean();
case "count":
return stats.getN();
return (double) stats.getN();
case "max":
return stats.getMax();
case "min":
Expand All @@ -181,6 +200,10 @@ public static double getZonalStats(GridCoverage2D raster, Geometry roi, int band
}
}

public static Double getZonalStats(GridCoverage2D raster, Geometry roi, int band, String statType, boolean excludeNoData) throws FactoryException {
return getZonalStats(raster, roi, band, statType, excludeNoData, true);
}

/**
* @param raster Raster to use for computing stats
* @param roi Geometry to define the region of interest
Expand All @@ -189,7 +212,7 @@ public static double getZonalStats(GridCoverage2D raster, Geometry roi, int band
* @return A double precision floating point number representing the requested statistic calculated over the specified region. The excludeNoData is set to true.
* @throws FactoryException
*/
public static double getZonalStats(GridCoverage2D raster, Geometry roi, int band, String statType) throws FactoryException {
public static Double getZonalStats(GridCoverage2D raster, Geometry roi, int band, String statType) throws FactoryException {
return getZonalStats(raster, roi, band, statType, true);
}

Expand All @@ -200,7 +223,7 @@ public static double getZonalStats(GridCoverage2D raster, Geometry roi, int band
* @return A double precision floating point number representing the requested statistic calculated over the specified region. The excludeNoData is set to true and band is set to 1.
* @throws FactoryException
*/
public static double getZonalStats(GridCoverage2D raster, Geometry roi, String statType) throws FactoryException {
public static Double getZonalStats(GridCoverage2D raster, Geometry roi, String statType) throws FactoryException {
return getZonalStats(raster, roi, 1, statType, true);
}

Expand All @@ -219,10 +242,11 @@ private static double zonalMode(double[] pixelData) {
* @param roi Geometry to define the region of interest
* @param band Band to be used for computation
* @param excludeNoData Specifies whether to exclude no-data value or not
* @param lenient Return null if the raster and roi do not intersect when set to true, otherwise will throw an exception
* @return an object of DescriptiveStatistics and an array of double with pixel data.
* @throws FactoryException
*/
private static List<Object> getStatObjects(GridCoverage2D raster, Geometry roi, int band, boolean excludeNoData) throws FactoryException {
private static List<Object> getStatObjects(GridCoverage2D raster, Geometry roi, int band, boolean excludeNoData, boolean lenient) throws FactoryException {
RasterUtils.ensureBand(raster, band);

if(RasterAccessors.srid(raster) != roi.getSRID()) {
Expand All @@ -234,7 +258,11 @@ private static List<Object> getStatObjects(GridCoverage2D raster, Geometry roi,

// checking if the raster contains the geometry
if (!RasterPredicates.rsIntersects(raster, roi)) {
throw new IllegalArgumentException("The provided geometry is not intersecting the raster. Please provide a geometry that is in the raster's extent.");
if (lenient) {
return null;
} else {
throw new IllegalArgumentException("The provided geometry is not intersecting the raster. Please provide a geometry that is in the raster's extent.");
}
}

Raster rasterData = RasterUtils.getRaster(raster.getRenderedImage());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,15 @@ public void testZonalStats() throws FactoryException, ParseException, IOExceptio
actual = RasterBandAccessors.getZonalStats(raster, geom, 1, "sd", false);
expected = 92.1327;
assertEquals(expected, actual, FP_TOLERANCE);

geom = Constructors.geomFromWKT("POLYGON ((-77.96672569800863073 37.91971182746296876, -77.9688630154902711 37.89620133516485367, -77.93936803424354309 37.90517806858776595, -77.96672569800863073 37.91971182746296876))", 0);
Double statValue = RasterBandAccessors.getZonalStats(raster, geom, 1, "sum", false, true);
assertNotNull(statValue);

Geometry nonIntersectingGeom = Constructors.geomFromWKT("POLYGON ((-78.22106647832458748 37.76411511479908967, -78.20183062098976734 37.72863564460374874, -78.18088490966962922 37.76753482276972562, -78.22106647832458748 37.76411511479908967))", 0);
statValue = RasterBandAccessors.getZonalStats(raster, nonIntersectingGeom, 1, "sum", false, true);
assertNull(statValue);
assertThrows(IllegalArgumentException.class, () -> RasterBandAccessors.getZonalStats(raster, nonIntersectingGeom, 1, "sum", false, false));
}

@Test
Expand Down Expand Up @@ -161,6 +170,15 @@ public void testZonalStatsAll() throws IOException, FactoryException, ParseExcep
double[] actual = RasterBandAccessors.getZonalStatsAll(raster, geom, 1, false);
double[] expected = new double[] {184792.0, 1.0690406E7, 57.8510, 0.0, 0.0, 92.1327, 8488.4480, 0.0, 255.0};
assertArrayEquals(expected, actual, FP_TOLERANCE);

geom = Constructors.geomFromWKT("POLYGON ((-77.96672569800863073 37.91971182746296876, -77.9688630154902711 37.89620133516485367, -77.93936803424354309 37.90517806858776595, -77.96672569800863073 37.91971182746296876))", 0);
actual = RasterBandAccessors.getZonalStatsAll(raster, geom, 1, false);
assertNotNull(actual);

Geometry nonIntersectingGeom = Constructors.geomFromWKT("POLYGON ((-78.22106647832458748 37.76411511479908967, -78.20183062098976734 37.72863564460374874, -78.18088490966962922 37.76753482276972562, -78.22106647832458748 37.76411511479908967))", 0);
actual = RasterBandAccessors.getZonalStatsAll(raster, nonIntersectingGeom, 1, false, true);
assertNull(actual);
assertThrows(IllegalArgumentException.class, () -> RasterBandAccessors.getZonalStatsAll(raster, nonIntersectingGeom, 1, false, false));
}

@Test
Expand Down
16 changes: 14 additions & 2 deletions docs/api/sql/Raster-operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -1095,11 +1095,17 @@ Introduction: This returns a statistic value specified by `statType` over the re

The following conditions will throw an `IllegalArgumentException` if they are not met:

- The provided `raster` and `zone` geometry should intersect.
- The provided `raster` and `zone` geometry should intersect when `lenient` parameter is set to `false`.
- The option provided to `statType` should be valid.

`lenient` parameter is set to `true` by default. The function will return `null` if the `raster` and `zone` geometry do not intersect.

Format:

```
RS_ZonalStats(raster: Raster, zone: Geometry, band: Integer, statType: String, excludeNoData: Boolean, lenient: Boolean)
```

```
RS_ZonalStats(raster: Raster, zone: Geometry, band: Integer, statType: String, excludeNoData: Boolean)
```
Expand Down Expand Up @@ -1157,11 +1163,17 @@ Introduction: Returns an array of statistic values, where each statistic is comp

The following conditions will throw an `IllegalArgumentException` if they are not met:

- The provided `raster` and `zone` geometry should intersect.
- The provided `raster` and `zone` geometry should intersect when `lenient` parameter is set to `false`.
- The option provided to `statType` should be valid.

`lenient` parameter is set to `true` by default. The function will return `null` if the `raster` and `zone` geometry do not intersect.

Format:

```
RS_ZonalStatsAll(raster: Raster, zone: Geometry, band: Integer, excludeNodata: Boolean, lenient: Boolean)
```

```
RS_ZonalStatsAll(raster: Raster, zone: Geometry, band: Integer, excludeNodata: Boolean)
```
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ case class RS_Count(inputExpressions: Seq[Expression]) extends InferredExpressio
}

case class RS_ZonalStats(inputExpressions: Seq[Expression]) extends InferredExpression(
inferrableFunction5(RasterBandAccessors.getZonalStats), inferrableFunction4(RasterBandAccessors.getZonalStats),
inferrableFunction6(RasterBandAccessors.getZonalStats),
inferrableFunction5(RasterBandAccessors.getZonalStats),
inferrableFunction4(RasterBandAccessors.getZonalStats),
inferrableFunction3(RasterBandAccessors.getZonalStats)
) {
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
Expand All @@ -54,7 +56,7 @@ case class RS_ZonalStats(inputExpressions: Seq[Expression]) extends InferredExpr

case class RS_ZonalStatsAll(inputExpressions: Seq[Expression]) extends InferredExpression(
inferrableFunction2(RasterBandAccessors.getZonalStatsAll), inferrableFunction4(RasterBandAccessors.getZonalStatsAll),
inferrableFunction3(RasterBandAccessors.getZonalStatsAll)
inferrableFunction3(RasterBandAccessors.getZonalStatsAll), inferrableFunction5(RasterBandAccessors.getZonalStatsAll)
) {
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
copy(inputExpressions = newChildren)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -979,6 +979,10 @@ class rasteralgebraTest extends TestBaseScala with BeforeAndAfter with GivenWhen
// Test with a polygon in EPSG:4326
actual = df.selectExpr("RS_ZonalStats(raster, ST_GeomFromWKT('POLYGON ((-77.96672569800863073 37.91971182746296876, -77.9688630154902711 37.89620133516485367, -77.93936803424354309 37.90517806858776595, -77.96672569800863073 37.91971182746296876))'), 1, 'mean', false)").first().get(0)
assertNotNull(actual)

// Test with a polygon that does not intersect the raster in lenient mode
actual = df.selectExpr("RS_ZonalStats(raster, ST_GeomFromWKT('POLYGON ((-78.22106647832458748 37.76411511479908967, -78.20183062098976734 37.72863564460374874, -78.18088490966962922 37.76753482276972562, -78.22106647832458748 37.76411511479908967))'), 1, 'mean', false)").first().get(0)
assertNull(actual)
}

it("Passed RS_ZonalStats - Raster with no data") {
Expand Down Expand Up @@ -1006,6 +1010,10 @@ class rasteralgebraTest extends TestBaseScala with BeforeAndAfter with GivenWhen
val actual = df.selectExpr("RS_ZonalStatsAll(raster, geom, 1, true)").first().get(0)
val expected = Seq(184792.0, 1.0690406E7, 57.851021689230684, 0.0, 0.0, 92.13277429243035, 8488.448098819916, 0.0, 255.0)
assertTrue(expected.equals(actual))

// Test with a polygon that does not intersect the raster in lenient mode
val actual2 = df.selectExpr("RS_ZonalStatsAll(raster, ST_GeomFromWKT('POLYGON ((-78.22106647832458748 37.76411511479908967, -78.20183062098976734 37.72863564460374874, -78.18088490966962922 37.76753482276972562, -78.22106647832458748 37.76411511479908967))'), 1, false)").first().get(0)
assertNull(actual2)
}

it("Passed RS_ZonalStatsAll - Raster with no data") {
Expand Down

0 comments on commit 385d4ba

Please sign in to comment.