[SEDONA-549] Make RS_Union_aggr support combining all bands of multi-band rasters (#1375)

* Init: move class level members to data buffer

* Init: add multi-band support

* init: add tests

* Add index checks

* add tests

* fix test

* update docs

* update docs

* fix typo

* fix test

* update doc example

* add comments to doc example

* fix lint issue

* Update example to show group by geometry, year, quarter
prantogg committed Apr 29, 2024
1 parent 9980e51 commit 170871a
Showing 7 changed files with 167 additions and 74 deletions.
68 changes: 52 additions & 16 deletions docs/api/sql/Raster-aggregate-function.md
@@ -1,9 +1,9 @@
## RS_Union_Aggr

Introduction: Returns a raster containing bands by specified indexes from all rasters in the provided column. Extracts the first bands from each raster and combines them into the output raster based on the input index values.
Introduction: This function combines multiple rasters into a single multiband raster by stacking the bands of each input raster sequentially. The function arranges the bands in the output raster according to the order specified by the index column in the input. It is typically used in scenarios where rasters are grouped by certain criteria (e.g., time and/or location) and an aggregated raster output is desired.

!!!Note
RS_Union_Aggr can take multiple banded rasters as input, but it would only extract the first band to the resulting raster. RS_Union_Aggr expects the following input, if not satisfied then will throw an IllegalArgumentException:
RS_Union_Aggr expects the following of its input and throws an IllegalArgumentException if any condition is not met:

- Indexes to be in an arithmetic sequence without any gaps.
- Indexes to be unique and not repeated.
@@ -13,30 +13,66 @@ Format: `RS_Union_Aggr(A: rasterColumn, B: indexColumn)`

Since: `v1.5.1`
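
The index requirements listed in the note can be illustrated with a small standalone check. This is a minimal sketch, not Sedona's implementation: `IndexCheck` and `validateIndexes` are hypothetical names, and the sketch assumes that valid indexes start at 1 and increase by 1, as in the documentation example.

```scala
object IndexCheck {
  // Hypothetical helper, not part of Sedona.
  // Assumption: valid indexes are exactly 1, 2, ..., n in some order.
  def validateIndexes(indexes: Seq[Int]): Unit = {
    val sorted = indexes.sorted
    // After sorting, the value at position p must be p + 1; a gap or a repeat breaks this.
    val valid = sorted.zipWithIndex.forall { case (index, position) => index == position + 1 }
    if (!valid) {
      throw new IllegalArgumentException("Index should be in an arithmetic sequence.")
    }
  }
}
```

Under these assumptions, `IndexCheck.validateIndexes(Seq(3, 1, 2))` returns normally, while `Seq(1, 2, 4)` (a gap) and `Seq(1, 1, 2)` (a repeat) both throw.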

SQL Example
SQL Example:

Contents of `raster_table`.
First, we enrich the dataset with time-based grouping columns and index the rasters based on time intervals:

```
+------------------------------+-----+
| raster|index|
+------------------------------+-----+
|GridCoverage2D["geotiff_cov...| 1|
|GridCoverage2D["geotiff_cov...| 2|
|GridCoverage2D["geotiff_cov...| 3|
|GridCoverage2D["geotiff_cov...| 4|
|GridCoverage2D["geotiff_cov...| 5|
+------------------------------+-----+
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, quarter, row_number, year}

// Add yearly and quarterly time interval columns for grouping
val dfWithIntervals = df
  .withColumn("year", year(col("timestamp")))
  .withColumn("quarter", quarter(col("timestamp")))
// Define a window spec that orders rasters by timestamp within each geometry-year-quarter group
val windowSpecQuarter = Window.partitionBy("geometry", "year", "quarter").orderBy("timestamp")
val indexedDf = dfWithIntervals.withColumn("index", row_number().over(windowSpecQuarter))
indexedDf.show()
```

The indexed rasters will appear as follows, showing that each raster is tagged with a sequential index (ordered by timestamp) within its group (grouped by geometry, year and quarter).

```
+-------------------+-----------------------------+--------------+----+-------+-----+
|timestamp |raster |geometry |year|quarter|index|
+-------------------+-----------------------------+--------------+----+-------+-----+
|2021-01-10 00:00:00|GridCoverage2D["geotiff_co...|POINT (72 120)|2021|1 |1 |
|2021-01-25 00:00:00|GridCoverage2D["geotiff_co...|POINT (72 120)|2021|1 |2 |
|2021-02-15 00:00:00|GridCoverage2D["geotiff_co...|POINT (72 120)|2021|1 |3 |
|2021-03-15 00:00:00|GridCoverage2D["geotiff_co...|POINT (72 120)|2021|1 |4 |
|2021-03-25 00:00:00|GridCoverage2D["geotiff_co...|POINT (72 120)|2021|1 |5 |
|2021-04-10 00:00:00|GridCoverage2D["geotiff_co...|POINT (84 132)|2021|2 |1 |
|2021-04-22 00:00:00|GridCoverage2D["geotiff_co...|POINT (84 132)|2021|2 |2 |
|2021-05-15 00:00:00|GridCoverage2D["geotiff_co...|POINT (84 132)|2021|2 |3 |
|2021-05-20 00:00:00|GridCoverage2D["geotiff_co...|POINT (84 132)|2021|2 |4 |
|2021-05-29 00:00:00|GridCoverage2D["geotiff_co...|POINT (84 132)|2021|2 |5 |
|2021-06-10 00:00:00|GridCoverage2D["geotiff_co...|POINT (84 132)|2021|2 |6 |
+-------------------+-----------------------------+--------------+----+-------+-----+
```
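
The per-group numbering shown in the table can be mimicked without Spark. The following is a minimal sketch under stated assumptions: `GroupIndexing`, `Observation` and `indexWithinGroups` are hypothetical names, the group key stands in for the (geometry, year, quarter) combination, and timestamps are ISO-8601 strings so that text ordering matches chronological ordering.

```scala
object GroupIndexing {
  // Hypothetical record: a group key and an ISO-8601 timestamp string,
  // which sorts chronologically when compared as text.
  final case class Observation(group: String, timestamp: String)

  // Number the observations 1, 2, 3, ... by timestamp within each group.
  // Assumption: observations are distinct, so they can serve as map keys.
  def indexWithinGroups(observations: Seq[Observation]): Map[Observation, Int] =
    observations
      .groupBy(_.group)
      .values
      .flatMap { members =>
        members.sortBy(_.timestamp).zipWithIndex.map { case (observation, position) =>
          observation -> (position + 1)
        }
      }
      .toMap
}
```

Because numbering restarts in every group, each group receives a gap-free sequence starting at 1, which is the shape of index column the aggregate expects.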
SELECT RS_Union_Aggr(raster, index) FROM raster_table

To create one stacked raster per group, aggregate the indexed rasters by geometry, year and quarter:

```
indexedDf.createOrReplaceTempView("indexedDf")
sedona.sql("""
  SELECT geometry, year, quarter, RS_Union_Aggr(raster, index) AS aggregated_raster
  FROM indexedDf
  WHERE index <= 4
  GROUP BY geometry, year, quarter
""").show()
```

Output:

This output raster contains the first band of each raster in the `raster_table` at specified index.
The query yields one raster per geometry, year and quarter group. Each output raster stacks the first four time steps of its group into a single multiband raster; when the inputs are single-band, each output band represents one time step.

```
GridCoverage2D["geotiff_coverage", GeneralEnvel...
+--------------+----+-------+--------------------+---------+
| geometry|year|quarter| raster|Num_Bands|
+--------------+----+-------+--------------------+---------+
|POINT (72 120)|2021|1 |GridCoverage2D["g...| 4|
|POINT (84 132)|2021|2 |GridCoverage2D["g...| 4|
+--------------+----+-------+--------------------+---------+
```
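
The band order of the aggregated raster can be sketched in plain Scala. This is an illustration of the stacking rule described above, not Sedona's code: `BandStacking`, `IndexedBands` and `stack` are hypothetical names, and each band is reduced to a flat array of samples.

```scala
object BandStacking {
  // Hypothetical stand-in for a raster: its index plus one flat sample array per band.
  final case class IndexedBands(index: Int, bands: Vector[Array[Double]])

  // Sort the rasters by index, then append each raster's bands in their original order.
  def stack(rasters: Seq[IndexedBands]): Vector[Array[Double]] =
    rasters.sortBy(_.index).flatMap(_.bands).toVector
}
```

In this sketch, a two-band raster at index 1 and a one-band raster at index 2 produce three output bands: both bands of the first raster, followed by the single band of the second.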
@@ -33,12 +33,10 @@ import javax.media.jai.RasterFactory
import scala.collection.mutable.ArrayBuffer

case class BandData(
var bandInt: Array[Int],
var bandDouble: Array[Double],
var bandsData: Array[Array[Double]],
var index: Int,
var isIntegral: Boolean,
var serializedRaster: Array[Byte],
var serializedSampleDimension: Array[Byte]
var serializedSampleDimensions: Array[Array[Byte]]
)


@@ -50,49 +48,35 @@ class RS_Union_Aggr extends Aggregator[(GridCoverage2D, Int), ArrayBuffer[BandDa
def zero: ArrayBuffer[BandData] = ArrayBuffer[BandData]()

def reduce(buffer: ArrayBuffer[BandData], input: (GridCoverage2D, Int)): ArrayBuffer[BandData] = {
val raster = input._1
val rasterData = RasterUtils.getRaster(raster.getRenderedImage)
val isIntegral = RasterUtils.isDataTypeIntegral(rasterData.getDataBuffer.getDataType)

// Serializing GridSampleDimension
val serializedBytes = Serde.serializeGridSampleDimension(raster.getSampleDimension(0))

// Check and set dimensions based on the first raster in the buffer
if (buffer.isEmpty) {
val width = RasterAccessors.getWidth(raster)
val height = RasterAccessors.getHeight(raster)
val referenceSerializedRaster = Serde.serialize(raster)

buffer += BandData(
if (isIntegral) rasterData.getSamples(0, 0, width, height, 0, null.asInstanceOf[Array[Int]]) else null,
if (!isIntegral) rasterData.getSamples(0, 0, width, height, 0, null.asInstanceOf[Array[Double]]) else null,
input._2,
isIntegral,
referenceSerializedRaster,
serializedBytes
)
} else {
val referenceRaster = Serde.deserialize(buffer.head.serializedRaster)
val width = RasterAccessors.getWidth(referenceRaster)
val height = RasterAccessors.getHeight(referenceRaster)

if (width != RasterAccessors.getWidth(raster) || height != RasterAccessors.getHeight(raster)) {
throw new IllegalArgumentException("All rasters must have the same dimensions")
val raster = input._1
val renderedImage = raster.getRenderedImage
val numBands = renderedImage.getSampleModel.getNumBands
val width = renderedImage.getWidth
val height = renderedImage.getHeight

// First check if this is the first raster to set dimensions or validate against existing dimensions
if (buffer.nonEmpty) {
val referenceRaster = Serde.deserialize(buffer.head.serializedRaster)
val refWidth = RasterAccessors.getWidth(referenceRaster)
val refHeight = RasterAccessors.getHeight(referenceRaster)
if (width != refWidth || height != refHeight) {
throw new IllegalArgumentException("All rasters must have the same dimensions")
}
}

buffer += BandData(
if (isIntegral) rasterData.getSamples(0, 0, width, height, 0, null.asInstanceOf[Array[Int]]) else null,
if (!isIntegral) rasterData.getSamples(0, 0, width, height, 0, null.asInstanceOf[Array[Double]]) else null,
input._2,
isIntegral,
Serde.serialize(raster),
serializedBytes
)
}
// Extract data for each band
val rasterData = renderedImage.getData
val bandsData = Array.ofDim[Double](numBands, width * height)
val serializedSampleDimensions = new Array[Array[Byte]](numBands)

buffer
}
for (band <- 0 until numBands) {
bandsData(band) = rasterData.getSamples(0, 0, width, height, band, new Array[Double](width * height))
serializedSampleDimensions(band) = Serde.serializeGridSampleDimension(raster.getSampleDimension(band))
}

buffer += BandData(bandsData, input._2, Serde.serialize(raster), serializedSampleDimensions)
buffer
}

def merge(buffer1: ArrayBuffer[BandData], buffer2: ArrayBuffer[BandData]): ArrayBuffer[BandData] = {
val combined = ArrayBuffer.concat(buffer1, buffer2)
@@ -102,7 +86,6 @@ class RS_Union_Aggr extends Aggregator[(GridCoverage2D, Int), ArrayBuffer[BandDa
combined
}


def finish(merged: ArrayBuffer[BandData]): GridCoverage2D = {
val sortedMerged = merged.sortBy(_.index)
if (sortedMerged.zipWithIndex.exists { case (band, idx) =>
@@ -112,22 +95,20 @@ class RS_Union_Aggr extends Aggregator[(GridCoverage2D, Int), ArrayBuffer[BandDa
throw new IllegalArgumentException("Index should be in an arithmetic sequence.")
}

val numBands = sortedMerged.length
val totalBands = sortedMerged.map(_.bandsData.length).sum
val referenceRaster = Serde.deserialize(sortedMerged.head.serializedRaster)
val width = RasterAccessors.getWidth(referenceRaster)
val height = RasterAccessors.getHeight(referenceRaster)
val dataTypeCode = RasterUtils.getRaster(referenceRaster.getRenderedImage).getDataBuffer.getDataType
val resultRaster: WritableRaster = RasterFactory.createBandedRaster(dataTypeCode, width, height, numBands, null)
val gridSampleDimensions: Array[GridSampleDimension] = new Array[GridSampleDimension](numBands)

for ((bandData, idx) <- sortedMerged.zipWithIndex) {
// Deserializing GridSampleDimension
gridSampleDimensions(idx) = Serde.deserializeGridSampleDimension(bandData.serializedSampleDimension)

if(bandData.isIntegral)
resultRaster.setSamples(0, 0, width, height, idx, bandData.bandInt)
else
resultRaster.setSamples(0, 0, width, height, idx, bandData.bandDouble)
val resultRaster: WritableRaster = RasterFactory.createBandedRaster(dataTypeCode, width, height, totalBands, null)
val gridSampleDimensions = sortedMerged.flatMap(_.serializedSampleDimensions.map(Serde.deserializeGridSampleDimension)).toArray

var currentBand = 0
sortedMerged.foreach { bandData =>
bandData.bandsData.foreach { band =>
resultRaster.setSamples(0, 0, width, height, currentBand, band)
currentBand += 1
}
}

val noDataValue = RasterBandAccessors.getBandNoDataValue(referenceRaster)
Binary file added spark/common/src/test/resources/raster/test4.tiff
Binary file not shown.
Binary file added spark/common/src/test/resources/raster/test5.tiff
Binary file not shown.
Binary file added spark/common/src/test/resources/raster/test6.tiff
Binary file not shown.
@@ -68,7 +68,7 @@ class rasterIOTest extends TestBaseScala with BeforeAndAfter with GivenWhenThen
rasterDf.write.format("raster").mode(SaveMode.Overwrite).save(tempDir + "/raster-written")
df = sparkSession.read.format("binaryFile").load(tempDir + "/raster-written/*")
rasterDf = df.selectExpr("RS_FromGeoTiff(content)")
assert(rasterCount == 3)
assert(rasterCount == 6)
assert(rasterDf.count() == 0)
}

@@ -22,7 +22,7 @@ import org.apache.sedona.common.raster.MapAlgebra
import org.apache.sedona.common.utils.RasterUtils
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.{Row, SaveMode}
import org.apache.spark.sql.functions.{col, collect_list, expr, lit, row_number}
import org.apache.spark.sql.functions.{col, collect_list, expr, lit, monotonically_increasing_id, row_number}
import org.geotools.coverage.grid.GridCoverage2D
import org.junit.Assert.{assertEquals, assertNotNull, assertNull, assertTrue}
import org.locationtech.jts.geom.{Coordinate, Geometry}
@@ -1027,6 +1027,82 @@ class rasteralgebraTest extends TestBaseScala with BeforeAndAfter with GivenWhen
assertTrue(expectedMetadata.equals(actualMetadata))
}

it("Passed multi-band RS_Union_Aggr") {
var df = sparkSession.read.format("binaryFile")
.load(resourceFolder + "raster/test4.tiff")
.withColumn("raster", expr("RS_FromGeoTiff(content)"))
.withColumn("group", lit(1))
.withColumn("index", lit(1))
.select("raster", "group", "index")
.union(
sparkSession.read.format("binaryFile")
.load(resourceFolder + "raster/test5.tiff")
.withColumn("raster", expr("RS_FromGeoTiff(content)"))
.withColumn("group", lit(1))
.withColumn("index", lit(2))
.select("raster", "group", "index")
)
.union(
sparkSession.read.format("binaryFile")
.load(resourceFolder + "raster/test4.tiff")
.withColumn("raster", expr("RS_FromGeoTiff(content)"))
.withColumn("group", lit(2))
.withColumn("index", lit(1))
.select("raster", "group", "index")
)
.union(
sparkSession.read.format("binaryFile")
.load(resourceFolder + "raster/test6.tiff")
.withColumn("raster", expr("RS_FromGeoTiff(content)"))
.withColumn("group", lit(2))
.withColumn("index", lit(2))
.select("raster", "group", "index")
)
df = df.withColumn("meta", expr("RS_MetaData(raster)"))
df = df.withColumn("summary", expr("RS_SummaryStatsAll(raster)"))

// Aggregate rasters based on their indexes to create two separate 2-banded rasters
var aggregatedDF1 = df.groupBy("group")
.agg(expr("RS_Union_Aggr(raster, index) as multi_band_raster"))
.orderBy("group")
aggregatedDF1 = aggregatedDF1.withColumn("meta", expr("RS_MetaData(multi_band_raster)"))
aggregatedDF1 = aggregatedDF1.withColumn("summary1", expr("RS_SummaryStatsAll(multi_band_raster, 1)"))
.withColumn("summary2", expr("RS_SummaryStatsAll(multi_band_raster, 2)"))

// Aggregate rasters based on their group to create one 4-banded raster
var aggregatedDF2 = aggregatedDF1.selectExpr("RS_Union_Aggr(multi_band_raster, group) as raster")
aggregatedDF2 = aggregatedDF2.withColumn("meta", expr("RS_MetaData(raster)"))
aggregatedDF2 = aggregatedDF2.withColumn("summary1", expr("RS_SummaryStatsALl(raster, 1)"))
.withColumn("summary2", expr("RS_SummaryStatsALl(raster, 2)"))
.withColumn("summary3", expr("RS_SummaryStatsALl(raster, 3)"))
.withColumn("summary4", expr("RS_SummaryStatsALl(raster, 4)"))

val rowsExpected = df.selectExpr("summary").collect()
val rowsActual = df.selectExpr("summary").collect()

val expectedMetadata = df.selectExpr("meta").first().getSeq(0).slice(0, 9)
val expectedSummary1 = rowsExpected(0).getSeq(0).slice(0, 6)
val expectedSummary2 = rowsExpected(1).getSeq(0).slice(0, 6)
val expectedSummary3 = rowsExpected(2).getSeq(0).slice(0, 6)
val expectedSummary4 = rowsExpected(3).getSeq(0).slice(0, 6)

val expectedNumBands = mutable.WrappedArray.make(Array(4.0))

val actualMetadata = aggregatedDF2.selectExpr("meta").first().getSeq(0).slice(0, 9)
val actualNumBands = aggregatedDF2.selectExpr("meta").first().getSeq(0).slice(9, 10)
val actualSummary1 = rowsActual(0).getSeq(0).slice(0, 6)
val actualSummary2 = rowsActual(1).getSeq(0).slice(0, 6)
val actualSummary3 = rowsActual(2).getSeq(0).slice(0, 6)
val actualSummary4 = rowsActual(3).getSeq(0).slice(0, 6)

assertTrue(expectedMetadata.equals(actualMetadata))
assertTrue(actualNumBands == expectedNumBands)
assertTrue(expectedSummary1.equals(actualSummary1))
assertTrue(expectedSummary2.equals(actualSummary2))
assertTrue(expectedSummary3.equals(actualSummary3))
assertTrue(expectedSummary4.equals(actualSummary4))
}

it("Passed RS_ZonalStats") {
var df = sparkSession.read.format("binaryFile").load(resourceFolder + "raster_geotiff_color/FAA_UTM18N_NAD83.tif")
df = df.selectExpr("RS_FromGeoTiff(content) as raster", "ST_GeomFromWKT('POLYGON ((236722 4204770, 243900 4204770, 243900 4197590, 221170 4197590, 236722 4204770))', 26918) as geom")