[SEDONA-549] Make RS_Union_aggr support combining all bands of multi-band rasters #1375

Merged
merged 15 commits on Apr 29, 2024
57 changes: 41 additions & 16 deletions docs/api/sql/Raster-aggregate-function.md
@@ -1,9 +1,9 @@
## RS_Union_Aggr
Contributor

Can you provide some examples of using this function to stack rasters, using at least the example you are discussing in Slack for two RGB rasters? I'm unclear on how the index column controls stacking.


Introduction: Returns a raster containing bands by specified indexes from all rasters in the provided column. Extracts the first bands from each raster and combines them into the output raster based on the input index values.
Introduction: This function combines multiple rasters into a single multiband raster by stacking the bands of each input raster sequentially. The function arranges the bands in the output raster according to the order specified by the index column in the input. It is typically used in scenarios where rasters are grouped by certain criteria (e.g., time and/or location) and an aggregated raster output is desired.
Member

Does this change the behavior of the old function? The old function only takes the first band of each input raster. Now it takes all bands of each raster? @prantogg

Contributor Author

Yes, it essentially stacks all bands of each raster in order of index.

Input dataframe:

idx  |   Raster
--------------------------------
1    |   raster1 (R1 | G1 | B1)
2    |   raster2 (R2 | G2 | B2)
3    |   raster3 (R3 | G3 | B3)

Applying `RS_Union_aggr(Raster, idx)` gives the resulting raster:

raster (R1 | G1 | B1 | R2 | G2 | B2 | R3 | G3 | B3)
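
A minimal SQL sketch of that usage (the table name `rgb_rasters` and its column names are hypothetical, for illustration only):

```
-- Hypothetical table rgb_rasters with columns idx (1..3) and Raster
-- (one RGB raster per row). The aggregate stacks the bands in index
-- order, producing a single 9-band raster:
-- R1 | G1 | B1 | R2 | G2 | B2 | R3 | G3 | B3
SELECT RS_Union_Aggr(Raster, idx) AS stacked_raster
FROM rgb_rasters
```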


!!!Note
RS_Union_Aggr can take multiple banded rasters as input, but it would only extract the first band to the resulting raster. RS_Union_Aggr expects the following input, if not satisfied then will throw an IllegalArgumentException:
RS_Union_Aggr expects the following input; if these conditions are not satisfied, it will throw an IllegalArgumentException:

- Indexes to be in an arithmetic sequence without any gaps.
- Indexes to be unique and not repeated.
@@ -13,30 +13,55 @@ Format: `RS_Union_Aggr(A: rasterColumn, B: indexColumn)`

Since: `v1.5.1`

SQL Example
SQL Example:

Contents of `raster_table`.
First, define a window specification that partitions by geographic location and orders by time. This will prepare your data by assigning an index to each raster based on its timestamp within each location group.

```
+------------------------------+-----+
| raster|index|
+------------------------------+-----+
|GridCoverage2D["geotiff_cov...| 1|
|GridCoverage2D["geotiff_cov...| 2|
|GridCoverage2D["geotiff_cov...| 3|
|GridCoverage2D["geotiff_cov...| 4|
|GridCoverage2D["geotiff_cov...| 5|
+------------------------------+-----+
val windowSpec = Window.partitionBy("geometry").orderBy("timestamp")
val indexedRasters = df.withColumn("index", row_number().over(windowSpec))

indexedRasters.show()
```

The indexed rasters will appear as follows, showing that each raster is tagged with a sequential index (ordered by timestamp) within its group (geometry).

```
+-------------------+------------------------------+--------------+-----+
| timestamp| raster| geometry|index|
+-------------------+------------------------------+--------------+-----+
|2021-01-01T00:00:00|GridCoverage2D["geotiff_cov...|POINT (72 120)| 1|
|2021-01-02T00:00:00|GridCoverage2D["geotiff_cov...|POINT (72 120)| 2|
|2021-01-03T00:00:00|GridCoverage2D["geotiff_cov...|POINT (72 120)| 3|
|2021-01-04T00:00:00|GridCoverage2D["geotiff_cov...|POINT (72 120)| 4|
|2021-01-05T00:00:00|GridCoverage2D["geotiff_cov...|POINT (72 120)| 5|
|2021-01-02T00:00:00|GridCoverage2D["geotiff_cov...|POINT (84 132)| 1|
|2021-01-03T00:00:00|GridCoverage2D["geotiff_cov...|POINT (84 132)| 2|
|2021-01-04T00:00:00|GridCoverage2D["geotiff_cov...|POINT (84 132)| 3|
|2021-01-05T00:00:00|GridCoverage2D["geotiff_cov...|POINT (84 132)| 4|
|2021-01-06T00:00:00|GridCoverage2D["geotiff_cov...|POINT (84 132)| 5|
|2021-01-07T00:00:00|GridCoverage2D["geotiff_cov...|POINT (84 132)| 6|
+-------------------+------------------------------+--------------+-----+
```

To create a stacked raster for each geometry group:

```
SELECT RS_Union_Aggr(raster, index) FROM raster_table
SELECT geometry, RS_Union_Aggr(raster, index) AS raster, RS_NumBands(raster) AS Num_Bands
FROM indexedRasters
WHERE index <= 4
GROUP BY geometry
```

Output:

This output raster contains the first band of each raster in the `raster_table` at specified index.
The query yields rasters grouped by geometry, each containing the first four time steps combined into a single multiband raster, where each band represents one time step.
Contributor

I think this is a great example and sufficient for merging, but can it be followed up by a PR showing how to group by different time intervals, instead of or in addition to taking the first four timesteps?

For example, we might want to group by a single year, then take the first 4 timesteps in that specific year.

Or we might want to sample scenes at a monthly or quarterly frequency. It would be good to show an example of grouping by one interval and sampling at another, matching the pandas pattern below:

pandas example:

import pandas as pd
df['date'] = pd.to_datetime(df['date'])
df.set_index('date', inplace=True)

# Group by year and take the first scene of every quarter
quarterly_sample = df.groupby(df.index.year).resample('QS').first()

# Group by year and take the first scene of every month
monthly_sample = df.groupby(df.index.year).resample('MS').first()

Contributor Author

Thanks Ryan! I see your point, this is definitely a more realistic example. I've updated the example to show a groupby over a quarterly time interval in addition to geometry.
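
For reference, a sketch of what such a quarterly grouping could look like in Spark SQL (table and column names are hypothetical, and the exact query in the updated docs may differ):

```
-- Hypothetical source table raster_table(timestamp, geometry, raster).
-- Re-index the rasters within each (geometry, year, quarter) group,
-- then stack each group into one multiband raster with RS_Union_Aggr.
WITH quarterly AS (
  SELECT
    geometry, raster,
    year(timestamp) AS yr,
    quarter(timestamp) AS qtr,
    row_number() OVER (
      PARTITION BY geometry, year(timestamp), quarter(timestamp)
      ORDER BY timestamp
    ) AS idx
  FROM raster_table
)
SELECT geometry, yr, qtr, RS_Union_Aggr(raster, idx) AS raster
FROM quarterly
GROUP BY geometry, yr, qtr
```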


```
GridCoverage2D["geotiff_coverage", GeneralEnvel...
+--------------+--------------------+---------+
| geometry| raster|Num_Bands|
+--------------+--------------------+---------+
|POINT (72 120)|GridCoverage2D["g...| 4|
|POINT (84 132)|GridCoverage2D["g...| 4|
+--------------+--------------------+---------+
```
@@ -33,12 +33,10 @@ import javax.media.jai.RasterFactory
import scala.collection.mutable.ArrayBuffer

case class BandData(
var bandInt: Array[Int],
var bandDouble: Array[Double],
var bandsData: Array[Array[Double]],
var index: Int,
var isIntegral: Boolean,
var serializedRaster: Array[Byte],
var serializedSampleDimension: Array[Byte]
var serializedSampleDimensions: Array[Array[Byte]]
)


@@ -50,49 +48,35 @@ class RS_Union_Aggr extends Aggregator[(GridCoverage2D, Int), ArrayBuffer[BandDa
def zero: ArrayBuffer[BandData] = ArrayBuffer[BandData]()

def reduce(buffer: ArrayBuffer[BandData], input: (GridCoverage2D, Int)): ArrayBuffer[BandData] = {
val raster = input._1
val rasterData = RasterUtils.getRaster(raster.getRenderedImage)
val isIntegral = RasterUtils.isDataTypeIntegral(rasterData.getDataBuffer.getDataType)

// Serializing GridSampleDimension
val serializedBytes = Serde.serializeGridSampleDimension(raster.getSampleDimension(0))

// Check and set dimensions based on the first raster in the buffer
if (buffer.isEmpty) {
val width = RasterAccessors.getWidth(raster)
val height = RasterAccessors.getHeight(raster)
val referenceSerializedRaster = Serde.serialize(raster)

buffer += BandData(
if (isIntegral) rasterData.getSamples(0, 0, width, height, 0, null.asInstanceOf[Array[Int]]) else null,
if (!isIntegral) rasterData.getSamples(0, 0, width, height, 0, null.asInstanceOf[Array[Double]]) else null,
input._2,
isIntegral,
referenceSerializedRaster,
serializedBytes
)
} else {
val referenceRaster = Serde.deserialize(buffer.head.serializedRaster)
val width = RasterAccessors.getWidth(referenceRaster)
val height = RasterAccessors.getHeight(referenceRaster)

if (width != RasterAccessors.getWidth(raster) || height != RasterAccessors.getHeight(raster)) {
throw new IllegalArgumentException("All rasters must have the same dimensions")
val raster = input._1
val renderedImage = raster.getRenderedImage
val numBands = renderedImage.getSampleModel.getNumBands
val width = renderedImage.getWidth
val height = renderedImage.getHeight

// First check if this is the first raster to set dimensions or validate against existing dimensions
if (buffer.nonEmpty) {
val referenceRaster = Serde.deserialize(buffer.head.serializedRaster)
val refWidth = RasterAccessors.getWidth(referenceRaster)
val refHeight = RasterAccessors.getHeight(referenceRaster)
if (width != refWidth || height != refHeight) {
throw new IllegalArgumentException("All rasters must have the same dimensions")
}
}

buffer += BandData(
if (isIntegral) rasterData.getSamples(0, 0, width, height, 0, null.asInstanceOf[Array[Int]]) else null,
if (!isIntegral) rasterData.getSamples(0, 0, width, height, 0, null.asInstanceOf[Array[Double]]) else null,
input._2,
isIntegral,
Serde.serialize(raster),
serializedBytes
)
}
// Extract data for each band
val rasterData = renderedImage.getData
val bandsData = Array.ofDim[Double](numBands, width * height)
val serializedSampleDimensions = new Array[Array[Byte]](numBands)

buffer
}
for (band <- 0 until numBands) {
bandsData(band) = rasterData.getSamples(0, 0, width, height, band, new Array[Double](width * height))
serializedSampleDimensions(band) = Serde.serializeGridSampleDimension(raster.getSampleDimension(band))
}

buffer += BandData(bandsData, input._2, Serde.serialize(raster), serializedSampleDimensions)
buffer
}

def merge(buffer1: ArrayBuffer[BandData], buffer2: ArrayBuffer[BandData]): ArrayBuffer[BandData] = {
val combined = ArrayBuffer.concat(buffer1, buffer2)
@@ -102,7 +86,6 @@ class RS_Union_Aggr extends Aggregator[(GridCoverage2D, Int), ArrayBuffer[BandDa
combined
}


def finish(merged: ArrayBuffer[BandData]): GridCoverage2D = {
val sortedMerged = merged.sortBy(_.index)
if (sortedMerged.zipWithIndex.exists { case (band, idx) =>
@@ -112,22 +95,20 @@ class RS_Union_Aggr extends Aggregator[(GridCoverage2D, Int), ArrayBuffer[BandDa
throw new IllegalArgumentException("Index should be in an arithmetic sequence.")
}

val numBands = sortedMerged.length
val totalBands = sortedMerged.map(_.bandsData.length).sum
val referenceRaster = Serde.deserialize(sortedMerged.head.serializedRaster)
val width = RasterAccessors.getWidth(referenceRaster)
val height = RasterAccessors.getHeight(referenceRaster)
val dataTypeCode = RasterUtils.getRaster(referenceRaster.getRenderedImage).getDataBuffer.getDataType
val resultRaster: WritableRaster = RasterFactory.createBandedRaster(dataTypeCode, width, height, numBands, null)
val gridSampleDimensions: Array[GridSampleDimension] = new Array[GridSampleDimension](numBands)

for ((bandData, idx) <- sortedMerged.zipWithIndex) {
// Deserializing GridSampleDimension
gridSampleDimensions(idx) = Serde.deserializeGridSampleDimension(bandData.serializedSampleDimension)

if(bandData.isIntegral)
resultRaster.setSamples(0, 0, width, height, idx, bandData.bandInt)
else
resultRaster.setSamples(0, 0, width, height, idx, bandData.bandDouble)
val resultRaster: WritableRaster = RasterFactory.createBandedRaster(dataTypeCode, width, height, totalBands, null)
val gridSampleDimensions = sortedMerged.flatMap(_.serializedSampleDimensions.map(Serde.deserializeGridSampleDimension)).toArray

var currentBand = 0
sortedMerged.foreach { bandData =>
bandData.bandsData.foreach { band =>
resultRaster.setSamples(0, 0, width, height, currentBand, band)
currentBand += 1
}
}

val noDataValue = RasterBandAccessors.getBandNoDataValue(referenceRaster)
Binary file added spark/common/src/test/resources/raster/test4.tiff
Binary file added spark/common/src/test/resources/raster/test5.tiff
Binary file added spark/common/src/test/resources/raster/test6.tiff
@@ -68,7 +68,7 @@ class rasterIOTest extends TestBaseScala with BeforeAndAfter with GivenWhenThen
rasterDf.write.format("raster").mode(SaveMode.Overwrite).save(tempDir + "/raster-written")
df = sparkSession.read.format("binaryFile").load(tempDir + "/raster-written/*")
rasterDf = df.selectExpr("RS_FromGeoTiff(content)")
assert(rasterCount == 3)
assert(rasterCount == 6)
assert(rasterDf.count() == 0)
}

@@ -22,7 +22,7 @@ import org.apache.sedona.common.raster.MapAlgebra
import org.apache.sedona.common.utils.RasterUtils
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.{Row, SaveMode}
import org.apache.spark.sql.functions.{col, collect_list, expr, lit, row_number}
import org.apache.spark.sql.functions.{col, collect_list, expr, lit, monotonically_increasing_id, row_number}
import org.geotools.coverage.grid.GridCoverage2D
import org.junit.Assert.{assertEquals, assertNotNull, assertNull, assertTrue}
import org.locationtech.jts.geom.{Coordinate, Geometry}
@@ -1027,6 +1027,82 @@ class rasteralgebraTest extends TestBaseScala with BeforeAndAfter with GivenWhen
assertTrue(expectedMetadata.equals(actualMetadata))
}

it("Passed multi-band RS_Union_Aggr") {
var df = sparkSession.read.format("binaryFile")
.load(resourceFolder + "raster/test4.tiff")
.withColumn("raster", expr("RS_FromGeoTiff(content)"))
.withColumn("group", lit(1))
.withColumn("index", lit(1))
.select("raster", "group", "index")
.union(
sparkSession.read.format("binaryFile")
.load(resourceFolder + "raster/test5.tiff")
.withColumn("raster", expr("RS_FromGeoTiff(content)"))
.withColumn("group", lit(1))
.withColumn("index", lit(2))
.select("raster", "group", "index")
)
.union(
sparkSession.read.format("binaryFile")
.load(resourceFolder + "raster/test4.tiff")
.withColumn("raster", expr("RS_FromGeoTiff(content)"))
.withColumn("group", lit(2))
.withColumn("index", lit(1))
.select("raster", "group", "index")
)
.union(
sparkSession.read.format("binaryFile")
.load(resourceFolder + "raster/test6.tiff")
.withColumn("raster", expr("RS_FromGeoTiff(content)"))
.withColumn("group", lit(2))
.withColumn("index", lit(2))
.select("raster", "group", "index")
)
df = df.withColumn("meta", expr("RS_MetaData(raster)"))
df = df.withColumn("summary", expr("RS_SummaryStatsAll(raster)"))

// Aggregate rasters based on their indexes to create two separate 2-banded rasters
var aggregatedDF1 = df.groupBy("group")
.agg(expr("RS_Union_Aggr(raster, index) as multi_band_raster"))
.orderBy("group")
aggregatedDF1 = aggregatedDF1.withColumn("meta", expr("RS_MetaData(multi_band_raster)"))
aggregatedDF1 = aggregatedDF1.withColumn("summary1", expr("RS_SummaryStatsAll(multi_band_raster, 1)"))
.withColumn("summary2", expr("RS_SummaryStatsAll(multi_band_raster, 2)"))

// Aggregate rasters based on their group to create one 4-banded raster
var aggregatedDF2 = aggregatedDF1.selectExpr("RS_Union_Aggr(multi_band_raster, group) as raster")
aggregatedDF2 = aggregatedDF2.withColumn("meta", expr("RS_MetaData(raster)"))
aggregatedDF2 = aggregatedDF2.withColumn("summary1", expr("RS_SummaryStatsALl(raster, 1)"))
.withColumn("summary2", expr("RS_SummaryStatsALl(raster, 2)"))
.withColumn("summary3", expr("RS_SummaryStatsALl(raster, 3)"))
.withColumn("summary4", expr("RS_SummaryStatsALl(raster, 4)"))

val rowsExpected = df.selectExpr("summary").collect()
val rowsActual = df.selectExpr("summary").collect()

val expectedMetadata = df.selectExpr("meta").first().getSeq(0).slice(0, 9)
val expectedSummary1 = rowsExpected(0).getSeq(0).slice(0, 6)
val expectedSummary2 = rowsExpected(1).getSeq(0).slice(0, 6)
val expectedSummary3 = rowsExpected(2).getSeq(0).slice(0, 6)
val expectedSummary4 = rowsExpected(3).getSeq(0).slice(0, 6)

val expectedNumBands = mutable.WrappedArray.make(Array(4.0))

val actualMetadata = aggregatedDF2.selectExpr("meta").first().getSeq(0).slice(0, 9)
val actualNumBands = aggregatedDF2.selectExpr("meta").first().getSeq(0).slice(9, 10)
val actualSummary1 = rowsActual(0).getSeq(0).slice(0, 6)
val actualSummary2 = rowsActual(1).getSeq(0).slice(0, 6)
val actualSummary3 = rowsActual(2).getSeq(0).slice(0, 6)
val actualSummary4 = rowsActual(3).getSeq(0).slice(0, 6)

assertTrue(expectedMetadata.equals(actualMetadata))
assertTrue(actualNumBands == expectedNumBands)
assertTrue(expectedSummary1.equals(actualSummary1))
assertTrue(expectedSummary2.equals(actualSummary2))
assertTrue(expectedSummary3.equals(actualSummary3))
assertTrue(expectedSummary4.equals(actualSummary4))
}

it("Passed RS_ZonalStats") {
var df = sparkSession.read.format("binaryFile").load(resourceFolder + "raster_geotiff_color/FAA_UTM18N_NAD83.tif")
df = df.selectExpr("RS_FromGeoTiff(content) as raster", "ST_GeomFromWKT('POLYGON ((236722 4204770, 243900 4204770, 243900 4197590, 221170 4197590, 236722 4204770))', 26918) as geom")