From 170871a40fa540fda44ecba468b7d27989a43a50 Mon Sep 17 00:00:00 2001
From: Pranav Toggi
Date: Mon, 29 Apr 2024 18:59:54 -0400
Subject: [PATCH] [SEDONA-549] Make RS_Union_aggr support combining all bands
 of multi-band rasters (#1375)

* Init: move class level members to data buffer
* Init: add multi-band support
* init: add tests
* Add index checks
* add tests
* fix test
* update docs
* update docs
* fix typo
* fix test
* update doc example
* add comments to doc example
* fix lint issue
* Update example to show group by geometry, year, quarter
---
 docs/api/sql/Raster-aggregate-function.md     | 68 ++++++++++---
 .../raster/AggregateFunctions.scala           | 93 +++++++-----------
 .../src/test/resources/raster/test4.tiff      | Bin 0 -> 467 bytes
 .../src/test/resources/raster/test5.tiff      | Bin 0 -> 467 bytes
 .../src/test/resources/raster/test6.tiff      | Bin 0 -> 467 bytes
 .../org/apache/sedona/sql/rasterIOTest.scala  |  2 +-
 .../apache/sedona/sql/rasteralgebraTest.scala | 78 ++++++++++++++-
 7 files changed, 167 insertions(+), 74 deletions(-)
 create mode 100644 spark/common/src/test/resources/raster/test4.tiff
 create mode 100644 spark/common/src/test/resources/raster/test5.tiff
 create mode 100644 spark/common/src/test/resources/raster/test6.tiff

diff --git a/docs/api/sql/Raster-aggregate-function.md b/docs/api/sql/Raster-aggregate-function.md
index dd72aaa13b..1c51421d1b 100644
--- a/docs/api/sql/Raster-aggregate-function.md
+++ b/docs/api/sql/Raster-aggregate-function.md
@@ -1,9 +1,9 @@
 ## RS_Union_Aggr
 
-Introduction: Returns a raster containing bands by specified indexes from all rasters in the provided column. Extracts the first bands from each raster and combines them into the output raster based on the input index values.
+Introduction: This function combines multiple rasters into a single multiband raster by stacking the bands of each input raster sequentially. The function arranges the bands in the output raster according to the order specified by the index column in the input. It is typically used in scenarios where rasters are grouped by certain criteria (e.g., time and/or location) and an aggregated raster output is desired.
 
 !!!Note
-    RS_Union_Aggr can take multiple banded rasters as input, but it would only extract the first band to the resulting raster. RS_Union_Aggr expects the following input, if not satisfied then will throw an IllegalArgumentException:
+    RS_Union_Aggr expects the following of its input; if these conditions are not satisfied, it throws an IllegalArgumentException:
 
     - Indexes to be in an arithmetic sequence without any gaps.
    - Indexes to be unique and not repeated.
@@ -13,30 +13,66 @@ Format: `RS_Union_Aggr(A: rasterColumn, B: indexColumn)`
 
 Since: `v1.5.1`
 
-SQL Example
+SQL Example:
 
-Contents of `raster_table`.
+First, we enrich the dataset with time-based grouping columns and index the rasters based on time intervals:

 ```
-+------------------------------+-----+
-|                        raster|index|
-+------------------------------+-----+
-|GridCoverage2D["geotiff_cov...|    1|
-|GridCoverage2D["geotiff_cov...|    2|
-|GridCoverage2D["geotiff_cov...|    3|
-|GridCoverage2D["geotiff_cov...|    4|
-|GridCoverage2D["geotiff_cov...|    5|
-+------------------------------+-----+
+// Add yearly and quarterly time interval columns for grouping
+df = df
+    .withColumn("year", year($"timestamp"))
+    .withColumn("quarter", quarter($"timestamp"))
+
+// Define window specs for quarterly indexing within each geometry-year group
+windowSpecQuarter = Window.partitionBy("geometry", "year", "quarter").orderBy("timestamp")
+
+indexedDf = df.withColumn("index", row_number().over(windowSpecQuarter))
+
+indexedDf.show()
 ```

+The indexed rasters will appear as follows, showing that each raster is tagged with a sequential index (ordered by timestamp) within its group (grouped by geometry, year and quarter).
+
+```
++-------------------+-----------------------------+--------------+----+-------+-----+
+|timestamp          |raster                       |geometry      |year|quarter|index|
++-------------------+-----------------------------+--------------+----+-------+-----+
+|2021-01-10 00:00:00|GridCoverage2D["geotiff_co...|POINT (72 120)|2021|1      |1    |
+|2021-01-25 00:00:00|GridCoverage2D["geotiff_co...|POINT (72 120)|2021|1      |2    |
+|2021-02-15 00:00:00|GridCoverage2D["geotiff_co...|POINT (72 120)|2021|1      |3    |
+|2021-03-15 00:00:00|GridCoverage2D["geotiff_co...|POINT (72 120)|2021|1      |4    |
+|2021-03-25 00:00:00|GridCoverage2D["geotiff_co...|POINT (72 120)|2021|1      |5    |
+|2021-04-10 00:00:00|GridCoverage2D["geotiff_co...|POINT (84 132)|2021|2      |1    |
+|2021-04-22 00:00:00|GridCoverage2D["geotiff_co...|POINT (84 132)|2021|2      |2    |
+|2021-05-15 00:00:00|GridCoverage2D["geotiff_co...|POINT (84 132)|2021|2      |3    |
+|2021-05-20 00:00:00|GridCoverage2D["geotiff_co...|POINT (84 132)|2021|2      |4    |
+|2021-05-29 00:00:00|GridCoverage2D["geotiff_co...|POINT (84 132)|2021|2      |5    |
+|2021-06-10 00:00:00|GridCoverage2D["geotiff_co...|POINT (84 132)|2021|2      |6    |
++-------------------+-----------------------------+--------------+----+-------+-----+
 ```
-SELECT RS_Union_Aggr(raster, index) FROM raster_table
+
+To create one stacked raster per group (geometry, year, and quarter), aggregate the rasters with RS_Union_Aggr:
+
+```
+indexedDf.createOrReplaceTempView("indexedDf")
+
+sedona.sql('''
+    SELECT geometry, year, quarter, RS_Union_Aggr(raster, index) AS aggregated_raster
+    FROM indexedDf
+    WHERE index <= 4
+    GROUP BY geometry, year, quarter
+''').show()
 ```

 Output:

-This output raster contains the first band of each raster in the `raster_table` at specified index.
+The query yields one raster per geometry, year, and quarter group, each combining its first four time steps into a single multiband raster in which each band represents one time step.

 ```
-GridCoverage2D["geotiff_coverage", GeneralEnvel...
++--------------+----+-------+--------------------+---------+
+|      geometry|year|quarter|   aggregated_raster|Num_Bands|
++--------------+----+-------+--------------------+---------+
+|POINT (72 120)|2021|1      |GridCoverage2D["g...|        4|
+|POINT (84 132)|2021|2      |GridCoverage2D["g...|        4|
++--------------+----+-------+--------------------+---------+
 ```
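The documentation example above shows a Num_Bands column in its output table, but the sample query itself only returns the aggregated raster. Below is a minimal sketch, not part of this patch, of how that band count could be obtained, assuming the indexedDf temp view registered in the example and a SedonaContext bound to a `sedona` variable; RS_NumBands is used purely for illustration.

```
// Illustrative only; assumes the indexedDf view from the documentation example above.
val aggregated = sedona.sql("""
  SELECT geometry, year, quarter,
         RS_Union_Aggr(raster, index) AS aggregated_raster
  FROM indexedDf
  WHERE index <= 4
  GROUP BY geometry, year, quarter
""")

// RS_NumBands reports how many bands the stacked raster holds
// (4 here, one per time step kept by the WHERE index <= 4 filter).
aggregated.selectExpr(
  "geometry", "year", "quarter",
  "RS_NumBands(aggregated_raster) AS Num_Bands"
).show()
```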
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/AggregateFunctions.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/AggregateFunctions.scala
index b768416383..3bf1326290 100644
--- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/AggregateFunctions.scala
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/AggregateFunctions.scala
@@ -33,12 +33,10 @@ import javax.media.jai.RasterFactory
 import scala.collection.mutable.ArrayBuffer

 case class BandData(
-  var bandInt: Array[Int],
-  var bandDouble: Array[Double],
+  var bandsData: Array[Array[Double]],
   var index: Int,
-  var isIntegral: Boolean,
   var serializedRaster: Array[Byte],
-  var serializedSampleDimension: Array[Byte]
+  var serializedSampleDimensions: Array[Array[Byte]]
 )

@@ -50,49 +48,35 @@ class RS_Union_Aggr extends Aggregator[(GridCoverage2D, Int), ArrayBuffer[BandDa
   def zero: ArrayBuffer[BandData] = ArrayBuffer[BandData]()

   def reduce(buffer: ArrayBuffer[BandData], input: (GridCoverage2D, Int)): ArrayBuffer[BandData] = {
-    val raster = input._1
-    val rasterData = RasterUtils.getRaster(raster.getRenderedImage)
-    val isIntegral = RasterUtils.isDataTypeIntegral(rasterData.getDataBuffer.getDataType)
-
-    // Serializing GridSampleDimension
-    val serializedBytes = Serde.serializeGridSampleDimension(raster.getSampleDimension(0))
-
-    // Check and set dimensions based on the first raster in the buffer
-    if (buffer.isEmpty) {
-      val width = RasterAccessors.getWidth(raster)
-      val height = RasterAccessors.getHeight(raster)
-      val referenceSerializedRaster = Serde.serialize(raster)
-
-      buffer += BandData(
-        if (isIntegral) rasterData.getSamples(0, 0, width, height, 0, null.asInstanceOf[Array[Int]]) else null,
-        if (!isIntegral) rasterData.getSamples(0, 0, width, height, 0, null.asInstanceOf[Array[Double]]) else null,
-        input._2,
-        isIntegral,
-        referenceSerializedRaster,
-        serializedBytes
-      )
-    } else {
-      val referenceRaster = Serde.deserialize(buffer.head.serializedRaster)
-      val width = RasterAccessors.getWidth(referenceRaster)
-      val height = RasterAccessors.getHeight(referenceRaster)
-
-      if (width != RasterAccessors.getWidth(raster) || height != RasterAccessors.getHeight(raster)) {
-        throw new IllegalArgumentException("All rasters must have the same dimensions")
+    val raster = input._1
+    val renderedImage = raster.getRenderedImage
+    val numBands = renderedImage.getSampleModel.getNumBands
+    val width = renderedImage.getWidth
+    val height = renderedImage.getHeight
+
+    // If the buffer already holds a raster, validate this raster against its dimensions
+    if (buffer.nonEmpty) {
+      val referenceRaster = Serde.deserialize(buffer.head.serializedRaster)
+      val refWidth = RasterAccessors.getWidth(referenceRaster)
+      val refHeight = RasterAccessors.getHeight(referenceRaster)
+      if (width != refWidth || height != refHeight) {
+        throw new IllegalArgumentException("All rasters must have the same dimensions")
+      }
       }
-      buffer += BandData(
-        if (isIntegral) rasterData.getSamples(0, 0, width, height, 0, null.asInstanceOf[Array[Int]]) else null,
-        if (!isIntegral) rasterData.getSamples(0, 0, width, height, 0, null.asInstanceOf[Array[Double]]) else null,
-        input._2,
-        isIntegral,
-        Serde.serialize(raster),
-        serializedBytes
-      )
-    }
+    // Extract data for each band
+    val rasterData = renderedImage.getData
+    val bandsData = Array.ofDim[Double](numBands, width * height)
+    val serializedSampleDimensions = new Array[Array[Byte]](numBands)
-    buffer
-  }
+    for (band <- 0 until numBands) {
+      bandsData(band) = rasterData.getSamples(0, 0, width, height, band, new Array[Double](width * height))
+      serializedSampleDimensions(band) = Serde.serializeGridSampleDimension(raster.getSampleDimension(band))
+    }
+    buffer += BandData(bandsData, input._2, Serde.serialize(raster), serializedSampleDimensions)
+    buffer
+  }

   def merge(buffer1: ArrayBuffer[BandData], buffer2: ArrayBuffer[BandData]): ArrayBuffer[BandData] = {
     val combined = ArrayBuffer.concat(buffer1, buffer2)
@@ -102,7 +86,6 @@ class RS_Union_Aggr extends Aggregator[(GridCoverage2D, Int), ArrayBuffer[BandDa
     combined
   }

-
   def finish(merged: ArrayBuffer[BandData]): GridCoverage2D = {
     val sortedMerged = merged.sortBy(_.index)
     if (sortedMerged.zipWithIndex.exists { case (band, idx) =>
@@ -112,22 +95,20 @@ class RS_Union_Aggr extends Aggregator[(GridCoverage2D, Int), ArrayBuffer[BandDa
       throw new IllegalArgumentException("Index should be in an arithmetic sequence.")
     }

-    val numBands = sortedMerged.length
+    val totalBands = sortedMerged.map(_.bandsData.length).sum
     val referenceRaster = Serde.deserialize(sortedMerged.head.serializedRaster)
     val width = RasterAccessors.getWidth(referenceRaster)
     val height = RasterAccessors.getHeight(referenceRaster)
     val dataTypeCode = RasterUtils.getRaster(referenceRaster.getRenderedImage).getDataBuffer.getDataType
-    val resultRaster: WritableRaster = RasterFactory.createBandedRaster(dataTypeCode, width, height, numBands, null)
-    val gridSampleDimensions: Array[GridSampleDimension] = new Array[GridSampleDimension](numBands)
-
-    for ((bandData, idx) <- sortedMerged.zipWithIndex) {
-      // Deserializing GridSampleDimension
-      gridSampleDimensions(idx) = Serde.deserializeGridSampleDimension(bandData.serializedSampleDimension)
-
-      if(bandData.isIntegral)
-        resultRaster.setSamples(0, 0, width, height, idx, bandData.bandInt)
-      else
-        resultRaster.setSamples(0, 0, width, height, idx, bandData.bandDouble)
+    val resultRaster: WritableRaster = RasterFactory.createBandedRaster(dataTypeCode, width, height, totalBands, null)
+    val gridSampleDimensions = sortedMerged.flatMap(_.serializedSampleDimensions.map(Serde.deserializeGridSampleDimension)).toArray
+
+    var currentBand = 0
+    sortedMerged.foreach { bandData =>
+      bandData.bandsData.foreach { band =>
+        resultRaster.setSamples(0, 0, width, height, currentBand, band)
+        currentBand += 1
+      }
     }

     val noDataValue = RasterBandAccessors.getBandNoDataValue(referenceRaster)
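The reworked reduce/finish pair above collects every band of every input raster and then copies each collected band into consecutive band slots of one banded WritableRaster. The following is a small self-contained sketch of that stacking step, using the same JAI and AWT calls the patch relies on; the sizes, values, and object name are made up for illustration.

```
import java.awt.image.{DataBuffer, WritableRaster}
import javax.media.jai.RasterFactory

// A stripped-down illustration of the band stacking performed in finish().
object BandStackSketch extends App {
  val width = 2
  val height = 2

  // Two single-band "rasters" worth of pixel values, in row-major order.
  val bands: Seq[Array[Double]] = Seq(
    Array(1.0, 2.0, 3.0, 4.0),
    Array(10.0, 20.0, 30.0, 40.0)
  )

  // Allocate one output raster with as many bands as were collected.
  val stacked: WritableRaster =
    RasterFactory.createBandedRaster(DataBuffer.TYPE_DOUBLE, width, height, bands.length, null)

  // Copy each input band into the next free band slot, mirroring the foreach in finish().
  var currentBand = 0
  bands.foreach { band =>
    stacked.setSamples(0, 0, width, height, currentBand, band)
    currentBand += 1
  }

  println(stacked.getSample(1, 1, 1)) // prints 40: band 1, pixel at x=1, y=1
}
```

The real aggregator additionally carries the serialized GridSampleDimension of every band, so the combined raster keeps per-band metadata alongside the pixel values.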
diff --git a/spark/common/src/test/resources/raster/test4.tiff b/spark/common/src/test/resources/raster/test4.tiff
new file mode 100644
index 0000000000000000000000000000000000000000..2130fc1aeecf96f62b31890496c99363ac271202
GIT binary patch
literal 467
zcmebD)MDUZU|-qGR53%@Aa!g=Y(YjAu-<&2gea1@7?ce%
zQyi)$1;~~`QWK2C=4<9*U1yUv7sF#&kd#-7&>xGrpYcT0-5KqEF{9QG%q_ZzdWyof&cnxkz35_$JsNk
zrC$43KZRlC58>p`r#D-O9E@M~(XEf6aZ&3cbHQf?^AG&{cW-k{)tyDxf^Q@U>O`MZ
viII}}5&T^#-fLp`f|uqR2hQ*vn{QAd+kNWglm;^^4w<)KcKmy(X;A_I#%)hA

literal 0
HcmV?d00001

diff --git a/spark/common/src/test/resources/raster/test5.tiff b/spark/common/src/test/resources/raster/test5.tiff
new file mode 100644
index 0000000000000000000000000000000000000000..7fab41acac66bc8c84c321866c13d1ae162e2c25
GIT binary patch
literal 467
zcmebD)MDUZU|-qGR53%@Aa!g=Y(YjAu-<&2gea1@7?ce%
zQyi)$1;~~`QWK2C=4<9*U1yUv7sF#&kd#-7&>xGrpYcT0-5KqEF{9QG%q_ZzdWyop}(hQ$x5TStOdu`
zI!Q3-PI#0&!=mCn*LjVDdKJ3r%a;``>nQL2pjY|r@eKERZnoP+`fpbWG>T>Qo&2XR
uS7%nx){xdUOZLdBi+2|-qGR53%@Aa!g=Y(YjAu-<&2gea1@7?ce%
zQyi)$1;~~`QWK2C=4<9*U1yUv7sF#&kd#-7&>xGrpYcT0-5KqEF{9QG%q_ZzdWyop~)+QN8)bUB#Y~A
ZZj;YSKHqbD?)~=FcaonxU$~?xz9auz+Vf2^0-qGU9Qyc{-~Vv^$u*g#8(FcNe}~O#
vz032%?C_73j;kD(#mfc=hpd^?s&+5(VDO9>qa-8NzdxtiPmF&1IwS%BP7hH~

literal 0
HcmV?d00001

diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/rasterIOTest.scala b/spark/common/src/test/scala/org/apache/sedona/sql/rasterIOTest.scala
index 3875439d25..2dba6679a2 100644
--- a/spark/common/src/test/scala/org/apache/sedona/sql/rasterIOTest.scala
+++ b/spark/common/src/test/scala/org/apache/sedona/sql/rasterIOTest.scala
@@ -68,7 +68,7 @@ class rasterIOTest extends TestBaseScala with BeforeAndAfter with GivenWhenThen
       rasterDf.write.format("raster").mode(SaveMode.Overwrite).save(tempDir + "/raster-written")
       df = sparkSession.read.format("binaryFile").load(tempDir + "/raster-written/*")
       rasterDf = df.selectExpr("RS_FromGeoTiff(content)")
-      assert(rasterCount == 3)
+      assert(rasterCount == 6)
       assert(rasterDf.count() == 0)
     }

diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/rasteralgebraTest.scala b/spark/common/src/test/scala/org/apache/sedona/sql/rasteralgebraTest.scala
index f738a1ba29..6fdec52a8a 100644
--- a/spark/common/src/test/scala/org/apache/sedona/sql/rasteralgebraTest.scala
+++ b/spark/common/src/test/scala/org/apache/sedona/sql/rasteralgebraTest.scala
@@ -22,7 +22,7 @@ import org.apache.sedona.common.raster.MapAlgebra
 import org.apache.sedona.common.utils.RasterUtils
 import org.apache.spark.sql.expressions.Window
 import org.apache.spark.sql.{Row, SaveMode}
-import org.apache.spark.sql.functions.{col, collect_list, expr, lit, row_number}
+import org.apache.spark.sql.functions.{col, collect_list, expr, lit, monotonically_increasing_id, row_number}
 import org.geotools.coverage.grid.GridCoverage2D
 import org.junit.Assert.{assertEquals, assertNotNull, assertNull, assertTrue}
 import org.locationtech.jts.geom.{Coordinate, Geometry}
@@ -1027,6 +1027,82 @@ class rasteralgebraTest extends TestBaseScala with BeforeAndAfter with GivenWhen
     assertTrue(expectedMetadata.equals(actualMetadata))
   }

+  it("Passed multi-band RS_Union_Aggr") {
+    var df = sparkSession.read.format("binaryFile")
+      .load(resourceFolder + "raster/test4.tiff")
+      .withColumn("raster", expr("RS_FromGeoTiff(content)"))
+      .withColumn("group", lit(1))
+      .withColumn("index", lit(1))
+      .select("raster", "group", "index")
+      .union(
+        sparkSession.read.format("binaryFile")
+          .load(resourceFolder + "raster/test5.tiff")
+          .withColumn("raster", expr("RS_FromGeoTiff(content)"))
+          .withColumn("group", lit(1))
+          .withColumn("index", lit(2))
+          .select("raster", "group", "index")
+      )
+      .union(
+        sparkSession.read.format("binaryFile")
+          .load(resourceFolder + "raster/test4.tiff")
+          .withColumn("raster", expr("RS_FromGeoTiff(content)"))
+          .withColumn("group", lit(2))
+          .withColumn("index", lit(1))
+          .select("raster", "group", "index")
+      )
+      .union(
+        sparkSession.read.format("binaryFile")
+          .load(resourceFolder + "raster/test6.tiff")
+          .withColumn("raster", expr("RS_FromGeoTiff(content)"))
+          .withColumn("group", lit(2))
+          .withColumn("index", lit(2))
+          .select("raster", "group", "index")
+      )
expr("RS_MetaData(raster)")) + df = df.withColumn("summary", expr("RS_SummaryStatsAll(raster)")) + + // Aggregate rasters based on their indexes to create two separate 2-banded rasters + var aggregatedDF1 = df.groupBy("group") + .agg(expr("RS_Union_Aggr(raster, index) as multi_band_raster")) + .orderBy("group") + aggregatedDF1 = aggregatedDF1.withColumn("meta", expr("RS_MetaData(multi_band_raster)")) + aggregatedDF1 = aggregatedDF1.withColumn("summary1", expr("RS_SummaryStatsAll(multi_band_raster, 1)")) + .withColumn("summary2", expr("RS_SummaryStatsAll(multi_band_raster, 2)")) + + // Aggregate rasters based on their group to create one 4-banded raster + var aggregatedDF2 = aggregatedDF1.selectExpr("RS_Union_Aggr(multi_band_raster, group) as raster") + aggregatedDF2 = aggregatedDF2.withColumn("meta", expr("RS_MetaData(raster)")) + aggregatedDF2 = aggregatedDF2.withColumn("summary1", expr("RS_SummaryStatsALl(raster, 1)")) + .withColumn("summary2", expr("RS_SummaryStatsALl(raster, 2)")) + .withColumn("summary3", expr("RS_SummaryStatsALl(raster, 3)")) + .withColumn("summary4", expr("RS_SummaryStatsALl(raster, 4)")) + + val rowsExpected = df.selectExpr("summary").collect() + val rowsActual = df.selectExpr("summary").collect() + + val expectedMetadata = df.selectExpr("meta").first().getSeq(0).slice(0, 9) + val expectedSummary1 = rowsExpected(0).getSeq(0).slice(0, 6) + val expectedSummary2 = rowsExpected(1).getSeq(0).slice(0, 6) + val expectedSummary3 = rowsExpected(2).getSeq(0).slice(0, 6) + val expectedSummary4 = rowsExpected(3).getSeq(0).slice(0, 6) + + val expectedNumBands = mutable.WrappedArray.make(Array(4.0)) + + val actualMetadata = aggregatedDF2.selectExpr("meta").first().getSeq(0).slice(0, 9) + val actualNumBands = aggregatedDF2.selectExpr("meta").first().getSeq(0).slice(9, 10) + val actualSummary1 = rowsActual(0).getSeq(0).slice(0, 6) + val actualSummary2 = rowsActual(1).getSeq(0).slice(0, 6) + val actualSummary3 = rowsActual(2).getSeq(0).slice(0, 6) + val actualSummary4 = rowsActual(3).getSeq(0).slice(0, 6) + + assertTrue(expectedMetadata.equals(actualMetadata)) + assertTrue(actualNumBands == expectedNumBands) + assertTrue(expectedSummary1.equals(actualSummary1)) + assertTrue(expectedSummary2.equals(actualSummary2)) + assertTrue(expectedSummary3.equals(actualSummary3)) + assertTrue(expectedSummary4.equals(actualSummary4)) + } + it("Passed RS_ZonalStats") { var df = sparkSession.read.format("binaryFile").load(resourceFolder + "raster_geotiff_color/FAA_UTM18N_NAD83.tif") df = df.selectExpr("RS_FromGeoTiff(content) as raster", "ST_GeomFromWKT('POLYGON ((236722 4204770, 243900 4204770, 243900 4197590, 221170 4197590, 236722 4204770))', 26918) as geom")