Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SEDONA-549] Make RS_Union_aggr support combining all bands of multi-band rasters #1375

Merged
merged 15 commits into from
Apr 29, 2024
6 changes: 3 additions & 3 deletions docs/api/sql/Raster-aggregate-function.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
## RS_Union_Aggr
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you provide some examples of using this function to stack rasters, using at least the example you are discussing in slack for two RGB rasters? I'm unclear on how the index column controls stacking.


Introduction: Returns a raster containing bands by specified indexes from all rasters in the provided column. Extracts the first bands from each raster and combines them into the output raster based on the input index values.
Introduction: Returns a raster containing bands by specified indexes from all rasters in the provided column. Combines all bands from each raster into the output raster. The order of bands in the resultant raster are based on the input index order.

!!!Note
RS_Union_Aggr can take multiple banded rasters as input, but it would only extract the first band to the resulting raster. RS_Union_Aggr expects the following input, if not satisfied then will throw an IllegalArgumentException:
RS_Union_Aggr expects the following input, if not satisfied then will throw an IllegalArgumentException:

- Indexes to be in an arithmetic sequence without any gaps.
- Indexes to be unique and not repeated.
Expand Down Expand Up @@ -35,7 +35,7 @@ SELECT RS_Union_Aggr(raster, index) FROM raster_table

Output:

This output raster contains the first band of each raster in the `raster_table` at specified index.
This output raster contains all bands of each raster in the `raster_table`.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we show a groupby example as well? Grouping by time and geography, similar to what we are doing with the segmentation dataset prep?


```
GridCoverage2D["geotiff_coverage", GeneralEnvel...
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,10 @@ import javax.media.jai.RasterFactory
import scala.collection.mutable.ArrayBuffer

case class BandData(
var bandInt: Array[Int],
var bandDouble: Array[Double],
var bandsData: Array[Array[Double]],
var index: Int,
var isIntegral: Boolean,
var serializedRaster: Array[Byte],
var serializedSampleDimension: Array[Byte]
var serializedSampleDimensions: Array[Array[Byte]]
)


Expand All @@ -50,49 +48,35 @@ class RS_Union_Aggr extends Aggregator[(GridCoverage2D, Int), ArrayBuffer[BandDa
def zero: ArrayBuffer[BandData] = ArrayBuffer[BandData]()

def reduce(buffer: ArrayBuffer[BandData], input: (GridCoverage2D, Int)): ArrayBuffer[BandData] = {
val raster = input._1
val rasterData = RasterUtils.getRaster(raster.getRenderedImage)
val isIntegral = RasterUtils.isDataTypeIntegral(rasterData.getDataBuffer.getDataType)

// Serializing GridSampleDimension
val serializedBytes = Serde.serializeGridSampleDimension(raster.getSampleDimension(0))

// Check and set dimensions based on the first raster in the buffer
if (buffer.isEmpty) {
val width = RasterAccessors.getWidth(raster)
val height = RasterAccessors.getHeight(raster)
val referenceSerializedRaster = Serde.serialize(raster)

buffer += BandData(
if (isIntegral) rasterData.getSamples(0, 0, width, height, 0, null.asInstanceOf[Array[Int]]) else null,
if (!isIntegral) rasterData.getSamples(0, 0, width, height, 0, null.asInstanceOf[Array[Double]]) else null,
input._2,
isIntegral,
referenceSerializedRaster,
serializedBytes
)
} else {
val referenceRaster = Serde.deserialize(buffer.head.serializedRaster)
val width = RasterAccessors.getWidth(referenceRaster)
val height = RasterAccessors.getHeight(referenceRaster)

if (width != RasterAccessors.getWidth(raster) || height != RasterAccessors.getHeight(raster)) {
throw new IllegalArgumentException("All rasters must have the same dimensions")
val raster = input._1
val renderedImage = raster.getRenderedImage
val numBands = renderedImage.getSampleModel.getNumBands
val width = renderedImage.getWidth
val height = renderedImage.getHeight

// First check if this is the first raster to set dimensions or validate against existing dimensions
if (buffer.nonEmpty) {
val referenceRaster = Serde.deserialize(buffer.head.serializedRaster)
val refWidth = RasterAccessors.getWidth(referenceRaster)
val refHeight = RasterAccessors.getHeight(referenceRaster)
if (width != refWidth || height != refHeight) {
throw new IllegalArgumentException("All rasters must have the same dimensions")
}
}

buffer += BandData(
if (isIntegral) rasterData.getSamples(0, 0, width, height, 0, null.asInstanceOf[Array[Int]]) else null,
if (!isIntegral) rasterData.getSamples(0, 0, width, height, 0, null.asInstanceOf[Array[Double]]) else null,
input._2,
isIntegral,
Serde.serialize(raster),
serializedBytes
)
}
// Extract data for each band
val rasterData = renderedImage.getData
val bandsData = Array.ofDim[Double](numBands, width * height)
val serializedSampleDimensions = new Array[Array[Byte]](numBands)

buffer
}
for (band <- 0 until numBands) {
bandsData(band) = rasterData.getSamples(0, 0, width, height, band, new Array[Double](width * height))
serializedSampleDimensions(band) = Serde.serializeGridSampleDimension(raster.getSampleDimension(band))
}

buffer += BandData(bandsData, input._2, Serde.serialize(raster), serializedSampleDimensions)
buffer
}

def merge(buffer1: ArrayBuffer[BandData], buffer2: ArrayBuffer[BandData]): ArrayBuffer[BandData] = {
val combined = ArrayBuffer.concat(buffer1, buffer2)
Expand All @@ -102,7 +86,6 @@ class RS_Union_Aggr extends Aggregator[(GridCoverage2D, Int), ArrayBuffer[BandDa
combined
}


def finish(merged: ArrayBuffer[BandData]): GridCoverage2D = {
val sortedMerged = merged.sortBy(_.index)
if (sortedMerged.zipWithIndex.exists { case (band, idx) =>
Expand All @@ -112,22 +95,20 @@ class RS_Union_Aggr extends Aggregator[(GridCoverage2D, Int), ArrayBuffer[BandDa
throw new IllegalArgumentException("Index should be in an arithmetic sequence.")
}

val numBands = sortedMerged.length
val totalBands = sortedMerged.map(_.bandsData.length).sum
val referenceRaster = Serde.deserialize(sortedMerged.head.serializedRaster)
val width = RasterAccessors.getWidth(referenceRaster)
val height = RasterAccessors.getHeight(referenceRaster)
val dataTypeCode = RasterUtils.getRaster(referenceRaster.getRenderedImage).getDataBuffer.getDataType
val resultRaster: WritableRaster = RasterFactory.createBandedRaster(dataTypeCode, width, height, numBands, null)
val gridSampleDimensions: Array[GridSampleDimension] = new Array[GridSampleDimension](numBands)

for ((bandData, idx) <- sortedMerged.zipWithIndex) {
// Deserializing GridSampleDimension
gridSampleDimensions(idx) = Serde.deserializeGridSampleDimension(bandData.serializedSampleDimension)

if(bandData.isIntegral)
resultRaster.setSamples(0, 0, width, height, idx, bandData.bandInt)
else
resultRaster.setSamples(0, 0, width, height, idx, bandData.bandDouble)
val resultRaster: WritableRaster = RasterFactory.createBandedRaster(dataTypeCode, width, height, totalBands, null)
val gridSampleDimensions = sortedMerged.flatMap(_.serializedSampleDimensions.map(Serde.deserializeGridSampleDimension)).toArray

var currentBand = 0
sortedMerged.foreach { bandData =>
bandData.bandsData.foreach { band =>
resultRaster.setSamples(0, 0, width, height, currentBand, band)
currentBand += 1
}
}

val noDataValue = RasterBandAccessors.getBandNoDataValue(referenceRaster)
Expand Down
Binary file added spark/common/src/test/resources/raster/test4.tiff
Binary file not shown.
Binary file added spark/common/src/test/resources/raster/test5.tiff
Binary file not shown.
Binary file added spark/common/src/test/resources/raster/test6.tiff
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import org.apache.sedona.common.raster.MapAlgebra
import org.apache.sedona.common.utils.RasterUtils
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.{Row, SaveMode}
import org.apache.spark.sql.functions.{col, collect_list, expr, lit, row_number}
import org.apache.spark.sql.functions.{col, collect_list, expr, lit, monotonically_increasing_id, row_number}
import org.geotools.coverage.grid.GridCoverage2D
import org.junit.Assert.{assertEquals, assertNotNull, assertNull, assertTrue}
import org.locationtech.jts.geom.{Coordinate, Geometry}
Expand Down Expand Up @@ -1027,6 +1027,82 @@ class rasteralgebraTest extends TestBaseScala with BeforeAndAfter with GivenWhen
assertTrue(expectedMetadata.equals(actualMetadata))
}

it("Passed multi-band RS_Union_Aggr") {
var df = sparkSession.read.format("binaryFile")
.load(resourceFolder + "raster/test4.tiff")
.withColumn("raster", expr("RS_FromGeoTiff(content)"))
.withColumn("group", lit(1))
.withColumn("index", lit(1))
.select("raster", "group", "index")
.union(
sparkSession.read.format("binaryFile")
.load(resourceFolder + "raster/test5.tiff")
.withColumn("raster", expr("RS_FromGeoTiff(content)"))
.withColumn("group", lit(1))
.withColumn("index", lit(2))
.select("raster", "group", "index")
)
.union(
sparkSession.read.format("binaryFile")
.load(resourceFolder + "raster/test4.tiff")
.withColumn("raster", expr("RS_FromGeoTiff(content)"))
.withColumn("group", lit(2))
.withColumn("index", lit(1))
.select("raster", "group", "index")
)
.union(
sparkSession.read.format("binaryFile")
.load(resourceFolder + "raster/test6.tiff")
.withColumn("raster", expr("RS_FromGeoTiff(content)"))
.withColumn("group", lit(2))
.withColumn("index", lit(2))
.select("raster", "group", "index")
)
df = df.withColumn("meta", expr("RS_MetaData(raster)"))
df = df.withColumn("summary", expr("RS_SummaryStatsAll(raster)"))

// Aggregate rasters based on their indexes to create two separate 2-banded rasters
var aggregatedDF1 = df.groupBy("group")
.agg(expr("RS_Union_Aggr(raster, index) as multi_band_raster"))
.orderBy("group")
aggregatedDF1 = aggregatedDF1.withColumn("meta", expr("RS_MetaData(multi_band_raster)"))
aggregatedDF1 = aggregatedDF1.withColumn("summary1", expr("RS_SummaryStatsAll(multi_band_raster, 1)"))
.withColumn("summary2", expr("RS_SummaryStatsAll(multi_band_raster, 2)"))

// Aggregate rasters based on their group to create one 4-banded raster
var aggregatedDF2 = aggregatedDF1.selectExpr("RS_Union_Aggr(multi_band_raster, group) as raster")
aggregatedDF2 = aggregatedDF2.withColumn("meta", expr("RS_MetaData(raster)"))
aggregatedDF2 = aggregatedDF2.withColumn("summary1", expr("RS_SummaryStatsALl(raster, 1)"))
.withColumn("summary2", expr("RS_SummaryStatsALl(raster, 2)"))
.withColumn("summary3", expr("RS_SummaryStatsALl(raster, 3)"))
.withColumn("summary4", expr("RS_SummaryStatsALl(raster, 4)"))

val rowsExpected = df.selectExpr("summary").collect()
val rowsActual = df.selectExpr("summary").collect()

val expectedMetadata = df.selectExpr("meta").first().getSeq(0).slice(0, 9)
val expectedSummary1 = rowsExpected(0).getSeq(0).slice(0, 6)
val expectedSummary2 = rowsExpected(1).getSeq(0).slice(0, 6)
val expectedSummary3 = rowsExpected(2).getSeq(0).slice(0, 6)
val expectedSummary4 = rowsExpected(3).getSeq(0).slice(0, 6)

val expectedNumBands = mutable.WrappedArray.make(Array(4.0))

val actualMetadata = aggregatedDF2.selectExpr("meta").first().getSeq(0).slice(0, 9)
val actualNumBands = aggregatedDF2.selectExpr("meta").first().getSeq(0).slice(9, 10)
val actualSummary1 = rowsActual(0).getSeq(0).slice(0, 6)
val actualSummary2 = rowsActual(1).getSeq(0).slice(0, 6)
val actualSummary3 = rowsActual(2).getSeq(0).slice(0, 6)
val actualSummary4 = rowsActual(3).getSeq(0).slice(0, 6)

assertTrue(expectedMetadata.equals(actualMetadata))
assertTrue(actualNumBands == expectedNumBands)
assertTrue(expectedSummary1.equals(actualSummary1))
assertTrue(expectedSummary2.equals(actualSummary2))
assertTrue(expectedSummary3.equals(actualSummary3))
assertTrue(expectedSummary4.equals(actualSummary4))
}

it("Passed RS_ZonalStats") {
var df = sparkSession.read.format("binaryFile").load(resourceFolder + "raster_geotiff_color/FAA_UTM18N_NAD83.tif")
df = df.selectExpr("RS_FromGeoTiff(content) as raster", "ST_GeomFromWKT('POLYGON ((236722 4204770, 243900 4204770, 243900 4197590, 221170 4197590, 236722 4204770))', 26918) as geom")
Expand Down
Loading