Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SEDONA-506] Lenient mode for RS_ZonalStats and RS_ZonalStatsAll #1257

Merged
merged 3 commits into from
Feb 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
*/
package org.apache.sedona.common.raster;

import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.math3.stat.StatUtils;
import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics;
import org.apache.commons.math3.stat.descriptive.moment.StandardDeviation;
Expand All @@ -36,6 +35,7 @@
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Objects;

public class RasterBandAccessors {

Expand Down Expand Up @@ -94,11 +94,15 @@ public static long getCount(GridCoverage2D raster, int band) {
* @param roi Geometry to define the region of interest
* @param band Band to be used for computation
* @param excludeNoData Specifies whether to exclude no-data value or not
* @param lenient Return null if the raster and roi do not intersect when set to true, otherwise will throw an exception
* @return An array with all the stats for the region
* @throws FactoryException
*/
public static double[] getZonalStatsAll(GridCoverage2D raster, Geometry roi, int band, boolean excludeNoData) throws FactoryException {
List<Object> objects = getStatObjects(raster, roi, band, excludeNoData);
public static double[] getZonalStatsAll(GridCoverage2D raster, Geometry roi, int band, boolean excludeNoData, boolean lenient) throws FactoryException {
List<Object> objects = getStatObjects(raster, roi, band, excludeNoData, lenient);
if (objects == null) {
return null;
}
DescriptiveStatistics stats = (DescriptiveStatistics) objects.get(0);
double[] pixelData = (double[]) objects.get(1);

Expand All @@ -118,6 +122,18 @@ public static double[] getZonalStatsAll(GridCoverage2D raster, Geometry roi, int
return result;
}

/**
* @param raster Raster to use for computing stats
* @param roi Geometry to define the region of interest
* @param band Band to be used for computation
* @param excludeNoData Specifies whether to exclude no-data value or not
* @return An array with all the stats for the region
* @throws FactoryException
*/
public static double[] getZonalStatsAll(GridCoverage2D raster, Geometry roi, int band, boolean excludeNoData) throws FactoryException {
return getZonalStatsAll(raster, roi, band, excludeNoData, true);
}

/**
* @param raster Raster to use for computing stats
* @param roi Geometry to define the region of interest
Expand Down Expand Up @@ -145,12 +161,15 @@ public static double[] getZonalStatsAll(GridCoverage2D raster, Geometry roi) thr
* @param band Band to be used for computation
* @param statType Define the statistic to be computed
* @param excludeNoData Specifies whether to exclude no-data value or not
* @param lenient Return null if the raster and roi do not intersect when set to true, otherwise will throw an exception
* @return A double precision floating point number representing the requested statistic calculated over the specified region.
* @throws FactoryException
*/
public static double getZonalStats(GridCoverage2D raster, Geometry roi, int band, String statType, boolean excludeNoData) throws FactoryException {

List<Object> objects = getStatObjects(raster, roi, band, excludeNoData);
public static Double getZonalStats(GridCoverage2D raster, Geometry roi, int band, String statType, boolean excludeNoData, boolean lenient) throws FactoryException {
List<Object> objects = getStatObjects(raster, roi, band, excludeNoData, lenient);
if (objects == null) {
return null;
}
DescriptiveStatistics stats = (DescriptiveStatistics) objects.get(0);
double[] pixelData = (double[]) objects.get(1);

Expand All @@ -162,7 +181,7 @@ public static double getZonalStats(GridCoverage2D raster, Geometry roi, int band
case "mean":
return stats.getMean();
case "count":
return stats.getN();
return (double) stats.getN();
case "max":
return stats.getMax();
case "min":
Expand All @@ -181,6 +200,10 @@ public static double getZonalStats(GridCoverage2D raster, Geometry roi, int band
}
}

public static Double getZonalStats(GridCoverage2D raster, Geometry roi, int band, String statType, boolean excludeNoData) throws FactoryException {
return getZonalStats(raster, roi, band, statType, excludeNoData, true);
}

/**
* @param raster Raster to use for computing stats
* @param roi Geometry to define the region of interest
Expand All @@ -189,7 +212,7 @@ public static double getZonalStats(GridCoverage2D raster, Geometry roi, int band
* @return A double precision floating point number representing the requested statistic calculated over the specified region. The excludeNoData is set to true.
* @throws FactoryException
*/
public static double getZonalStats(GridCoverage2D raster, Geometry roi, int band, String statType) throws FactoryException {
public static Double getZonalStats(GridCoverage2D raster, Geometry roi, int band, String statType) throws FactoryException {
return getZonalStats(raster, roi, band, statType, true);
}

Expand All @@ -200,7 +223,7 @@ public static double getZonalStats(GridCoverage2D raster, Geometry roi, int band
* @return A double precision floating point number representing the requested statistic calculated over the specified region. The excludeNoData is set to true and band is set to 1.
* @throws FactoryException
*/
public static double getZonalStats(GridCoverage2D raster, Geometry roi, String statType) throws FactoryException {
public static Double getZonalStats(GridCoverage2D raster, Geometry roi, String statType) throws FactoryException {
return getZonalStats(raster, roi, 1, statType, true);
}

Expand All @@ -219,10 +242,11 @@ private static double zonalMode(double[] pixelData) {
* @param roi Geometry to define the region of interest
* @param band Band to be used for computation
* @param excludeNoData Specifies whether to exclude no-data value or not
* @param lenient Return null if the raster and roi do not intersect when set to true, otherwise will throw an exception
* @return an object of DescriptiveStatistics and an array of double with pixel data.
* @throws FactoryException
*/
private static List<Object> getStatObjects(GridCoverage2D raster, Geometry roi, int band, boolean excludeNoData) throws FactoryException {
private static List<Object> getStatObjects(GridCoverage2D raster, Geometry roi, int band, boolean excludeNoData, boolean lenient) throws FactoryException {
RasterUtils.ensureBand(raster, band);

if(RasterAccessors.srid(raster) != roi.getSRID()) {
Expand All @@ -234,7 +258,11 @@ private static List<Object> getStatObjects(GridCoverage2D raster, Geometry roi,

// checking if the raster contains the geometry
if (!RasterPredicates.rsIntersects(raster, roi)) {
throw new IllegalArgumentException("The provided geometry is not intersecting the raster. Please provide a geometry that is in the raster's extent.");
if (lenient) {
return null;
} else {
throw new IllegalArgumentException("The provided geometry is not intersecting the raster. Please provide a geometry that is in the raster's extent.");
}
}

Raster rasterData = RasterUtils.getRaster(raster.getRenderedImage());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,15 @@ public void testZonalStats() throws FactoryException, ParseException, IOExceptio
actual = RasterBandAccessors.getZonalStats(raster, geom, 1, "sd", false);
expected = 92.1327;
assertEquals(expected, actual, FP_TOLERANCE);

geom = Constructors.geomFromWKT("POLYGON ((-77.96672569800863073 37.91971182746296876, -77.9688630154902711 37.89620133516485367, -77.93936803424354309 37.90517806858776595, -77.96672569800863073 37.91971182746296876))", 0);
Double statValue = RasterBandAccessors.getZonalStats(raster, geom, 1, "sum", false, true);
assertNotNull(statValue);

Geometry nonIntersectingGeom = Constructors.geomFromWKT("POLYGON ((-78.22106647832458748 37.76411511479908967, -78.20183062098976734 37.72863564460374874, -78.18088490966962922 37.76753482276972562, -78.22106647832458748 37.76411511479908967))", 0);
statValue = RasterBandAccessors.getZonalStats(raster, nonIntersectingGeom, 1, "sum", false, true);
assertNull(statValue);
assertThrows(IllegalArgumentException.class, () -> RasterBandAccessors.getZonalStats(raster, nonIntersectingGeom, 1, "sum", false, false));
}

@Test
Expand Down Expand Up @@ -161,6 +170,15 @@ public void testZonalStatsAll() throws IOException, FactoryException, ParseExcep
double[] actual = RasterBandAccessors.getZonalStatsAll(raster, geom, 1, false);
double[] expected = new double[] {184792.0, 1.0690406E7, 57.8510, 0.0, 0.0, 92.1327, 8488.4480, 0.0, 255.0};
assertArrayEquals(expected, actual, FP_TOLERANCE);

geom = Constructors.geomFromWKT("POLYGON ((-77.96672569800863073 37.91971182746296876, -77.9688630154902711 37.89620133516485367, -77.93936803424354309 37.90517806858776595, -77.96672569800863073 37.91971182746296876))", 0);
actual = RasterBandAccessors.getZonalStatsAll(raster, geom, 1, false);
assertNotNull(actual);

Geometry nonIntersectingGeom = Constructors.geomFromWKT("POLYGON ((-78.22106647832458748 37.76411511479908967, -78.20183062098976734 37.72863564460374874, -78.18088490966962922 37.76753482276972562, -78.22106647832458748 37.76411511479908967))", 0);
actual = RasterBandAccessors.getZonalStatsAll(raster, nonIntersectingGeom, 1, false, true);
assertNull(actual);
assertThrows(IllegalArgumentException.class, () -> RasterBandAccessors.getZonalStatsAll(raster, nonIntersectingGeom, 1, false, false));
}

@Test
Expand Down
16 changes: 14 additions & 2 deletions docs/api/sql/Raster-operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -1095,11 +1095,17 @@ Introduction: This returns a statistic value specified by `statType` over the re

The following conditions will throw an `IllegalArgumentException` if they are not met:

- The provided `raster` and `zone` geometry should intersect.
- The provided `raster` and `zone` geometry should intersect when `lenient` parameter is set to `false`.
- The option provided to `statType` should be valid.

`lenient` parameter is set to `true` by default. The function will return `null` if the `raster` and `zone` geometry do not intersect.

Format:

```
RS_ZonalStats(raster: Raster, zone: Geometry, band: Integer, statType: String, excludeNoData: Boolean, lenient: Boolean)
```

```
RS_ZonalStats(raster: Raster, zone: Geometry, band: Integer, statType: String, excludeNoData: Boolean)
```
Expand Down Expand Up @@ -1157,11 +1163,17 @@ Introduction: Returns an array of statistic values, where each statistic is comp

The following conditions will throw an `IllegalArgumentException` if they are not met:

- The provided `raster` and `zone` geometry should intersect.
- The provided `raster` and `zone` geometry should intersect when `lenient` parameter is set to `false`.
- The option provided to `statType` should be valid.

`lenient` parameter is set to `true` by default. The function will return `null` if the `raster` and `zone` geometry do not intersect.

Format:

```
RS_ZonalStatsAll(raster: Raster, zone: Geometry, band: Integer, excludeNodata: Boolean, lenient: Boolean)
```

```
RS_ZonalStatsAll(raster: Raster, zone: Geometry, band: Integer, excludeNodata: Boolean)
```
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ case class RS_Count(inputExpressions: Seq[Expression]) extends InferredExpressio
}

case class RS_ZonalStats(inputExpressions: Seq[Expression]) extends InferredExpression(
inferrableFunction5(RasterBandAccessors.getZonalStats), inferrableFunction4(RasterBandAccessors.getZonalStats),
inferrableFunction6(RasterBandAccessors.getZonalStats),
inferrableFunction5(RasterBandAccessors.getZonalStats),
inferrableFunction4(RasterBandAccessors.getZonalStats),
inferrableFunction3(RasterBandAccessors.getZonalStats)
) {
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
Expand All @@ -54,7 +56,7 @@ case class RS_ZonalStats(inputExpressions: Seq[Expression]) extends InferredExpr

case class RS_ZonalStatsAll(inputExpressions: Seq[Expression]) extends InferredExpression(
inferrableFunction2(RasterBandAccessors.getZonalStatsAll), inferrableFunction4(RasterBandAccessors.getZonalStatsAll),
inferrableFunction3(RasterBandAccessors.getZonalStatsAll)
inferrableFunction3(RasterBandAccessors.getZonalStatsAll), inferrableFunction5(RasterBandAccessors.getZonalStatsAll)
) {
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
copy(inputExpressions = newChildren)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -992,6 +992,10 @@ class rasteralgebraTest extends TestBaseScala with BeforeAndAfter with GivenWhen
// Test with a polygon in EPSG:4326
actual = df.selectExpr("RS_ZonalStats(raster, ST_GeomFromWKT('POLYGON ((-77.96672569800863073 37.91971182746296876, -77.9688630154902711 37.89620133516485367, -77.93936803424354309 37.90517806858776595, -77.96672569800863073 37.91971182746296876))'), 1, 'mean', false)").first().get(0)
assertNotNull(actual)

// Test with a polygon that does not intersect the raster in lenient mode
actual = df.selectExpr("RS_ZonalStats(raster, ST_GeomFromWKT('POLYGON ((-78.22106647832458748 37.76411511479908967, -78.20183062098976734 37.72863564460374874, -78.18088490966962922 37.76753482276972562, -78.22106647832458748 37.76411511479908967))'), 1, 'mean', false)").first().get(0)
assertNull(actual)
}

it("Passed RS_ZonalStats - Raster with no data") {
Expand Down Expand Up @@ -1019,6 +1023,10 @@ class rasteralgebraTest extends TestBaseScala with BeforeAndAfter with GivenWhen
val actual = df.selectExpr("RS_ZonalStatsAll(raster, geom, 1, true)").first().get(0)
val expected = Seq(184792.0, 1.0690406E7, 57.851021689230684, 0.0, 0.0, 92.13277429243035, 8488.448098819916, 0.0, 255.0)
assertTrue(expected.equals(actual))

// Test with a polygon that does not intersect the raster in lenient mode
val actual2 = df.selectExpr("RS_ZonalStatsAll(raster, ST_GeomFromWKT('POLYGON ((-78.22106647832458748 37.76411511479908967, -78.20183062098976734 37.72863564460374874, -78.18088490966962922 37.76753482276972562, -78.22106647832458748 37.76411511479908967))'), 1, false)").first().get(0)
assertNull(actual2)
}

it("Passed RS_ZonalStatsAll - Raster with no data") {
Expand Down
Loading