Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SEDONA-509] Add Single Statistic version of RS_SummaryStats #1266

Merged
merged 6 commits into from
Mar 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,35 @@ private static List<Object> getStatObjects(GridCoverage2D raster, Geometry roi,
return statObjects;
}

public static double[] getSummaryStats(GridCoverage2D rasterGeom, int band, boolean excludeNoDataValue) {
public static double getSummaryStats(GridCoverage2D rasterGeom, String statType, int band, boolean excludeNoDataValue) {
double[] stats = getSummaryStatsAll(rasterGeom,band,excludeNoDataValue);

if ("count".equalsIgnoreCase(statType)) {
return stats[0];
} else if ("sum".equalsIgnoreCase(statType)) {
return stats[1];
} else if ("mean".equalsIgnoreCase(statType)) {
return stats[2];
} else if ("stddev".equalsIgnoreCase(statType)) {
return stats[3];
} else if ("min".equalsIgnoreCase(statType)) {
return stats[4];
} else if ("max".equalsIgnoreCase(statType)) {
return stats[5];
} else {
throw new IllegalArgumentException("Invalid 'statType': '" + statType + "'. Expected one of: 'count', 'sum', 'mean', 'stddev', 'min', 'max'.");
}
}

public static double getSummaryStats(GridCoverage2D rasterGeom, String statType, int band) {
return getSummaryStats(rasterGeom, statType, band, true);
}

public static double getSummaryStats(GridCoverage2D rasterGeom, String statType) {
return getSummaryStats(rasterGeom, statType, 1, true);
}

public static double[] getSummaryStatsAll(GridCoverage2D rasterGeom, int band, boolean excludeNoDataValue) {
RasterUtils.ensureBand(rasterGeom, band);
Raster raster = RasterUtils.getRaster(rasterGeom.getRenderedImage());
int height = RasterAccessors.getHeight(rasterGeom), width = RasterAccessors.getWidth(rasterGeom);
Expand All @@ -312,7 +340,7 @@ public static double[] getSummaryStats(GridCoverage2D rasterGeom, int band, bool
if (excludeNoDataValue) {
pixelData = new ArrayList<>();
Double noDataValue = RasterBandAccessors.getBandNoDataValue(rasterGeom, band);
for (double pixel: pixels) {
for (double pixel : pixels) {
if (noDataValue == null || pixel != noDataValue) {
pixelData.add(pixel);
}
Expand Down Expand Up @@ -340,12 +368,12 @@ public static double[] getSummaryStats(GridCoverage2D rasterGeom, int band, bool
return new double[]{count, sum, mean, stddev, min, max};
}

public static double[] getSummaryStats(GridCoverage2D raster, int band) {
return getSummaryStats(raster, band, true);
public static double[] getSummaryStatsAll(GridCoverage2D raster, int band) {
return getSummaryStatsAll(raster, band, true);
}

public static double[] getSummaryStats(GridCoverage2D raster) {
return getSummaryStats(raster, 1, true);
public static double[] getSummaryStatsAll(GridCoverage2D raster) {
return getSummaryStatsAll(raster, 1, true);
}

// Adding the function signature when InferredExpression supports function with same arity but different argument types
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@
package org.apache.sedona.common.raster;

import org.apache.sedona.common.Constructors;
import org.apache.sedona.common.Functions;
import org.apache.sedona.common.FunctionsGeoTools;
import org.geotools.coverage.grid.GridCoverage2D;
import org.junit.Test;
import org.locationtech.jts.geom.Geometry;
Expand Down Expand Up @@ -207,47 +205,102 @@ public void testZonalStatsAllWithEmptyRaster() throws FactoryException, ParseExc
}

@Test
public void testSummaryStatsWithAllNoData() throws FactoryException {
public void testSummaryStatsAllWithAllNoData() throws FactoryException {
GridCoverage2D emptyRaster = RasterConstructors.makeEmptyRaster(1, 5, 5, 0, 0, 1, -1, 0, 0, 0);
double[] values = new double[] {0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0};
emptyRaster = MapAlgebra.addBandFromArray(emptyRaster, values, 1, 0d);
double[] actual = RasterBandAccessors.getSummaryStats(emptyRaster);
double[] actual = RasterBandAccessors.getSummaryStatsAll(emptyRaster);
double[] expected = {0.0, 0.0, Double.NaN, Double.NaN, Double.NaN, Double.NaN};
assertArrayEquals(expected, actual, FP_TOLERANCE);
}

@Test
public void testSummaryStatsWithEmptyRaster() throws FactoryException {
public void testSummaryStats() throws FactoryException, IOException {
GridCoverage2D emptyRaster = RasterConstructors.makeEmptyRaster(2, 5, 5, 0, 0, 1, -1, 0, 0, 0);
double[] values1 = new double[] {1,2,0,0,0,0,7,8,0,10,11,0,0,0,0,16,17,0,19,20,21,0,23,24,25};
double[] values2 = new double[] {0,0,28,29,0,0,0,33,34,35,36,37,38,0,0,0,0,43,44,45,46,47,48,49,50};
emptyRaster = MapAlgebra.addBandFromArray(emptyRaster, values1, 1, 0d);
emptyRaster = MapAlgebra.addBandFromArray(emptyRaster, values2, 2, 0d);
double[] actual = RasterBandAccessors.getSummaryStats(emptyRaster, 1, false);

GridCoverage2D raster = rasterFromGeoTiff(resourceFolder + "raster/raster_with_no_data/test5.tiff");

double actual = RasterBandAccessors.getSummaryStats(emptyRaster, "count", 2, true);
double expected = 16.0;
assertEquals(expected, actual, FP_TOLERANCE);

actual = RasterBandAccessors.getSummaryStats(emptyRaster, "sum", 2, true);
expected = 642.0;
assertEquals(expected, actual, FP_TOLERANCE);

actual = RasterBandAccessors.getSummaryStats(emptyRaster, "mean", 2, true);
expected = 40.125;
assertEquals(expected, actual, FP_TOLERANCE);

actual = RasterBandAccessors.getSummaryStats(emptyRaster, "stddev", 2, true);
expected = 6.9988838395847095;
assertEquals(expected, actual, FP_TOLERANCE);

actual = RasterBandAccessors.getSummaryStats(emptyRaster, "min", 2, true);
expected = 28.0;
assertEquals(expected, actual, FP_TOLERANCE);

actual = RasterBandAccessors.getSummaryStats(emptyRaster, "max", 2, true);
expected = 50.0;
assertEquals(expected, actual, FP_TOLERANCE);

actual = RasterBandAccessors.getSummaryStats(raster, "count", 1, false);
expected = 1036800.0;
assertEquals(expected, actual, FP_TOLERANCE);

actual = RasterBandAccessors.getSummaryStats(raster, "sum", 1, false);
expected = 2.06233487E8;
assertEquals(expected, actual, FP_TOLERANCE);

actual = RasterBandAccessors.getSummaryStats(raster, "mean", 1, false);
expected = 198.91347125792052;
assertEquals(expected, actual, FP_TOLERANCE);

actual = RasterBandAccessors.getSummaryStats(raster, "stddev", 1, false);
expected = 95.09054096111336;
assertEquals(expected, actual, FP_TOLERANCE);

actual = RasterBandAccessors.getSummaryStats(raster, "min", 1, false);
expected = 0.0;
assertEquals(expected, actual, FP_TOLERANCE);
}

@Test
public void testSummaryStatsAllWithEmptyRaster() throws FactoryException {
GridCoverage2D emptyRaster = RasterConstructors.makeEmptyRaster(2, 5, 5, 0, 0, 1, -1, 0, 0, 0);
double[] values1 = new double[] {1,2,0,0,0,0,7,8,0,10,11,0,0,0,0,16,17,0,19,20,21,0,23,24,25};
double[] values2 = new double[] {0,0,28,29,0,0,0,33,34,35,36,37,38,0,0,0,0,43,44,45,46,47,48,49,50};
emptyRaster = MapAlgebra.addBandFromArray(emptyRaster, values1, 1, 0d);
emptyRaster = MapAlgebra.addBandFromArray(emptyRaster, values2, 2, 0d);
double[] actual = RasterBandAccessors.getSummaryStatsAll(emptyRaster, 1, false);
double[] expected = {25.0, 204.0, 8.1600, 9.2765, 0.0, 25.0};
assertArrayEquals(expected, actual, FP_TOLERANCE);

actual = RasterBandAccessors.getSummaryStats(emptyRaster, 2);
actual = RasterBandAccessors.getSummaryStatsAll(emptyRaster, 2);
expected = new double[]{16.0, 642.0, 40.125, 6.9988838395847095, 28.0, 50.0};
assertArrayEquals(expected, actual, FP_TOLERANCE);

actual = RasterBandAccessors.getSummaryStats(emptyRaster);
actual = RasterBandAccessors.getSummaryStatsAll(emptyRaster);
expected = new double[] {14.0, 204.0, 14.5714, 7.7617, 1.0, 25.0};
assertArrayEquals(expected, actual, FP_TOLERANCE);
}

@Test
public void testSummaryStatsWithRaster() throws IOException {
public void testSummaryStatsAllWithRaster() throws IOException {
GridCoverage2D raster = rasterFromGeoTiff(resourceFolder + "raster/raster_with_no_data/test5.tiff");
double[] actual = RasterBandAccessors.getSummaryStats(raster, 1, false);
double[] actual = RasterBandAccessors.getSummaryStatsAll(raster, 1, false);
double[] expected = {1036800.0, 2.06233487E8, 198.9134, 95.0905, 0.0, 255.0};
assertArrayEquals(expected, actual, FP_TOLERANCE);

actual = RasterBandAccessors.getSummaryStats(raster, 1);
actual = RasterBandAccessors.getSummaryStatsAll(raster, 1);
expected = new double[]{928192.0, 2.06233487E8, 222.1883, 70.2055, 1.0, 255.0};
assertArrayEquals(expected, actual, FP_TOLERANCE);

actual = RasterBandAccessors.getSummaryStats(raster);
actual = RasterBandAccessors.getSummaryStatsAll(raster);
assertArrayEquals(expected, actual, FP_TOLERANCE);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,40 @@ public void testSetBandNoDataValueWithNull() throws IOException {
assertEquals(expected, actual);
}

@Test
public void testGetSummaryStats() throws IOException {
GridCoverage2D raster = rasterFromGeoTiff(resourceFolder + "raster/raster_with_no_data/test5.tiff");
raster = RasterBandEditors.setBandNoDataValue(raster, 1, 10.0, true);

// Test single output
double resultSummary = RasterBandAccessors.getSummaryStats(raster, "count", 1, false);
assertEquals(1036800, (int) resultSummary);

resultSummary = RasterBandAccessors.getSummaryStats(raster, "sum", 1, false);
assertEquals(207319567, (int) resultSummary);

resultSummary = RasterBandAccessors.getSummaryStats(raster, "mean", 1, false);
assertEquals(199, (int) resultSummary);

resultSummary = RasterBandAccessors.getSummaryStats(raster, "stddev", 1, false);
assertEquals(92, (int) resultSummary);

resultSummary = RasterBandAccessors.getSummaryStats(raster, "min", 1, false);
assertEquals(1, (int) resultSummary);

resultSummary = RasterBandAccessors.getSummaryStats(raster, "max", 1, false);
assertEquals(255, (int) resultSummary);
}

@Test
public void testSetBandNoDataValueWithReplaceOptionRaster() throws IOException {
GridCoverage2D raster = rasterFromGeoTiff(resourceFolder + "raster/raster_with_no_data/test5.tiff");
double[] originalSummary = RasterBandAccessors.getSummaryStats(raster, 1, false);
double[] originalSummary = RasterBandAccessors.getSummaryStatsAll(raster, 1, false);
int sumOG = (int) originalSummary[1];

assertEquals(206233487, sumOG);
GridCoverage2D resultRaster = RasterBandEditors.setBandNoDataValue(raster, 1, 10.0, true);
double[] resultSummary = RasterBandAccessors.getSummaryStats(resultRaster, 1, false);
double[] resultSummary = RasterBandAccessors.getSummaryStatsAll(resultRaster, 1, false);
int sumActual = (int) resultSummary[1];

// 108608 is the total no-data values in the raster
Expand All @@ -82,7 +107,7 @@ public void testSetBandNoDataValueWithReplaceOptionRaster() throws IOException {

// Not replacing previous no-data value
resultRaster = RasterBandEditors.setBandNoDataValue(raster, 1, 10.0);
resultSummary = RasterBandAccessors.getSummaryStats(resultRaster, 1, false);
resultSummary = RasterBandAccessors.getSummaryStatsAll(resultRaster, 1, false);
sumActual = (int) resultSummary[1];
assertEquals(sumOG, sumActual);
}
Expand Down
51 changes: 46 additions & 5 deletions docs/api/sql/Raster-operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -1035,6 +1035,45 @@ Output:

### RS_SummaryStats

Introduction: Returns summary statistic for a particular band based on the `statType` parameter. The function defaults to band index of `1` when `band` is not specified and excludes noDataValue if `excludeNoDataValue` is not specified.

`statType` parameter takes the following strings:

- `count`: Total count of all pixels in the specified band
- `sum`: Sum of all pixel values in the specified band
- `mean`: Mean value of all pixel values in the specified band
- `stddev`: Standard deviation of all pixels in the specified band
- `min`: Minimum pixel value in the specified band
- `max`: Maximum pixel value in the specified band

!!!Note
If excludeNoDataValue is set `true` then it will only count pixels with value not equal to the nodata value of the raster.
Set excludeNoDataValue to `false` to get count of all pixels in raster.

Formats:

`RS_SummaryStats(raster: Raster, statType: String, band: Integer = 1, excludeNoDataValue: Boolean = true)`

`RS_SummaryStats(raster: Raster, statType: String, band: Integer = 1)`

`RS_SummaryStats(raster: Raster, statType: String)`

Since: `v1.6.0`

SQL Example

```sql
SELECT RS_SummaryStats(RS_MakeEmptyRaster(2, 5, 5, 0, 0, 1, -1, 0, 0, 0), "stddev", 1, false)
```

Output:

```
9.4678403028357
```

### RS_SummaryStatsAll

Introduction: Returns summary stats consisting of count, sum, mean, stddev, min, max for a given band in raster. If band is not specified then it defaults to `1`.

!!!Note
Expand All @@ -1044,18 +1083,20 @@ Introduction: Returns summary stats consisting of count, sum, mean, stddev, min,
!!!Note
If the mentioned band index doesn't exist, this will throw an `IllegalArgumentException`.

`RS_SummaryStats(raster: Raster, band: Integer = 1, excludeNoDataValue: Boolean = true)`
Formats:

`RS_SummaryStatsAll(raster: Raster, band: Integer = 1, excludeNoDataValue: Boolean = true)`

`RS_SummaryStats(raster: Raster, band: Integer = 1)`
`RS_SummaryStatsAll(raster: Raster, band: Integer = 1)`

`RS_SummaryStats(raster: Raster)`
`RS_SummaryStatsAll(raster: Raster)`

Since: `v1.5.0`

SQL Example

```sql
SELECT RS_SummaryStats(RS_MakeEmptyRaster(2, 5, 5, 0, 0, 1, -1, 0, 0, 0), 1, false)
SELECT RS_SummaryStatsAll(RS_MakeEmptyRaster(2, 5, 5, 0, 0, 1, -1, 0, 0, 0), 1, false)
```

Output:
Expand All @@ -1067,7 +1108,7 @@ Output:
SQL Example

```sql
SELECT RS_SummaryStats(RS_MakeEmptyRaster(2, 5, 5, 0, 0, 1, -1, 0, 0, 0), 1)
SELECT RS_SummaryStatsAll(RS_MakeEmptyRaster(2, 5, 5, 0, 0, 1, -1, 0, 0, 0), 1)
```

Output:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ object Catalog {
function[RS_Clip](),
function[RS_Band](),
function[RS_AddBand](),
function[RS_SummaryStatsAll](),
function[RS_SummaryStats](),
function[RS_BandIsNoData](),
function[RS_ConvexHull](),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,16 @@ case class RS_ZonalStatsAll(inputExpressions: Seq[Expression]) extends InferredE
}

case class RS_SummaryStats(inputExpressions: Seq[Expression]) extends InferredExpression(
inferrableFunction1(RasterBandAccessors.getSummaryStats), inferrableFunction2(RasterBandAccessors.getSummaryStats),
inferrableFunction3(RasterBandAccessors.getSummaryStats)) {
inferrableFunction2(RasterBandAccessors.getSummaryStats), inferrableFunction3(RasterBandAccessors.getSummaryStats),
inferrableFunction4(RasterBandAccessors.getSummaryStats)) {
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
copy(inputExpressions = newChildren)
}
}

case class RS_SummaryStatsAll(inputExpressions: Seq[Expression]) extends InferredExpression(
inferrableFunction1(RasterBandAccessors.getSummaryStatsAll), inferrableFunction2(RasterBandAccessors.getSummaryStatsAll),
inferrableFunction3(RasterBandAccessors.getSummaryStatsAll)) {
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
copy(inputExpressions = newChildren)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1040,23 +1040,46 @@ class rasteralgebraTest extends TestBaseScala with BeforeAndAfter with GivenWhen
it("Passed RS_SummaryStats with raster") {
var df = sparkSession.read.format("binaryFile").load(resourceFolder + "raster/raster_with_no_data/test5.tiff")
df = df.selectExpr("RS_FromGeoTiff(content) as raster")
var actual = df.selectExpr("RS_SummaryStats(raster, 1, false)").first().getSeq(0)

var actual = df.selectExpr("RS_SummaryStats(raster, 'count')").first().getDouble(0)
assertEquals(928192.0, actual, 0.1d)

actual = df.selectExpr("RS_SummaryStats(raster, 'sum', 1)").first().getDouble(0)
assertEquals(2.06233487E8, actual, 0.1d)

actual = df.selectExpr("RS_SummaryStats(raster, 'mean', 1, false)").first().getDouble(0)
assertEquals(198.91347125771605, actual, 0.1d)

actual = df.selectExpr("RS_SummaryStats(raster, 'stddev', 1, false)").first().getDouble(0)
assertEquals(95.09054096106192, actual, 0.1d)

actual = df.selectExpr("RS_SummaryStats(raster, 'min', 1, false)").first().getDouble(0)
assertEquals(0.0, actual, 0.1d)

actual = df.selectExpr("RS_SummaryStats(raster, 'max', 1, false)").first().getDouble(0)
assertEquals(255.0, actual, 0.1d)
}

it("Passed RS_SummaryStatsAll with raster") {
var df = sparkSession.read.format("binaryFile").load(resourceFolder + "raster/raster_with_no_data/test5.tiff")
df = df.selectExpr("RS_FromGeoTiff(content) as raster")
var actual = df.selectExpr("RS_SummaryStatsAll(raster, 1, false)").first().getSeq(0)
assertEquals(1036800.0, actual.head, 0.1d)
assertEquals(2.06233487E8, actual(1), 0.1d)
assertEquals(198.91347125771605, actual(2), 1e-6d)
assertEquals(95.09054096106192, actual(3), 1e-6d)
assertEquals(0.0, actual(4), 0.1d)
assertEquals(255.0, actual(5), 0.1d)

actual = df.selectExpr("RS_SummaryStats(raster, 1)").first().getSeq(0)
actual = df.selectExpr("RS_SummaryStatsAll(raster, 1)").first().getSeq(0)
assertEquals(928192.0, actual.head, 0.1d)
assertEquals(2.06233487E8, actual(1), 0.1d)
assertEquals(222.18839097945252, actual(2), 1e-6d)
assertEquals(70.20559521132097, actual(3), 1e-6d)
assertEquals(1.0, actual(4), 0.1d)
assertEquals(255.0, actual(5), 0.1d)

actual = df.selectExpr("RS_SummaryStats(raster)").first().getSeq(0)
actual = df.selectExpr("RS_SummaryStatsAll(raster)").first().getSeq(0)
assertEquals(928192.0, actual.head, 0.1d)
assertEquals(2.06233487E8, actual(1), 0.1d)
assertEquals(222.18839097945252, actual(2), 1e-6d)
Expand Down
Loading