diff --git a/common/src/main/java/org/apache/sedona/common/raster/MapAlgebra.java b/common/src/main/java/org/apache/sedona/common/raster/MapAlgebra.java index 564077a36e..b5fb469cd0 100644 --- a/common/src/main/java/org/apache/sedona/common/raster/MapAlgebra.java +++ b/common/src/main/java/org/apache/sedona/common/raster/MapAlgebra.java @@ -78,12 +78,11 @@ public static GridCoverage2D addBandFromArray(GridCoverage2D rasterGeom, double[ throw new IllegalArgumentException("Band index is out of bounds. Must be between 1 and " + (numBands + 1) + ")"); } - Double[] bandValuesClass = Arrays.stream(bandValues).boxed().toArray(Double[]::new); if (bandIndex == numBands + 1) { - return RasterUtils.copyRasterAndAppendBand(rasterGeom, bandValuesClass, noDataValue); + return RasterUtils.copyRasterAndAppendBand(rasterGeom, bandValues, noDataValue); } else { - return RasterUtils.copyRasterAndReplaceBand(rasterGeom, bandIndex, bandValuesClass, noDataValue, true); + return RasterUtils.copyRasterAndReplaceBand(rasterGeom, bandIndex, bandValues, noDataValue, true); } } @@ -94,12 +93,11 @@ public static GridCoverage2D addBandFromArray(GridCoverage2D rasterGeom, double[ throw new IllegalArgumentException("Band index is out of bounds. Must be between 1 and " + (numBands + 1) + ")"); } - Double[] bandValuesClass = Arrays.stream(bandValues).boxed().toArray(Double[]::new); if (bandIndex == numBands + 1) { - return RasterUtils.copyRasterAndAppendBand(rasterGeom, bandValuesClass); + return RasterUtils.copyRasterAndAppendBand(rasterGeom, bandValues); } else { - return RasterUtils.copyRasterAndReplaceBand(rasterGeom, bandIndex, bandValuesClass); + return RasterUtils.copyRasterAndReplaceBand(rasterGeom, bandIndex, bandValues); } } diff --git a/common/src/main/java/org/apache/sedona/common/raster/RasterBandEditors.java b/common/src/main/java/org/apache/sedona/common/raster/RasterBandEditors.java index 4ffac4d308..41b159a5d2 100644 --- a/common/src/main/java/org/apache/sedona/common/raster/RasterBandEditors.java +++ b/common/src/main/java/org/apache/sedona/common/raster/RasterBandEditors.java @@ -135,16 +135,16 @@ public static GridCoverage2D addBand(GridCoverage2D toRaster, GridCoverage2D fro if (RasterUtils.isDataTypeIntegral(dataTypeCode)) { int[] bandValues = rasterData.getSamples(0, 0, width, height, fromBand - 1, (int[]) null); if (numBands + 1 == toRasterIndex) { - return RasterUtils.copyRasterAndAppendBand(toRaster, Arrays.stream(bandValues).boxed().toArray(Integer[]::new), noDataValue); + return RasterUtils.copyRasterAndAppendBand(toRaster, bandValues, noDataValue); } else { - return RasterUtils.copyRasterAndReplaceBand(toRaster, fromBand, Arrays.stream(bandValues).boxed().toArray(Integer[]::new), noDataValue, false); + return RasterUtils.copyRasterAndReplaceBand(toRaster, fromBand, bandValues, noDataValue, false); } } else { double[] bandValues = rasterData.getSamples(0, 0, width, height, fromBand - 1, (double[]) null); if (numBands + 1 == toRasterIndex) { - return RasterUtils.copyRasterAndAppendBand(toRaster, Arrays.stream(bandValues).boxed().toArray(Double[]::new), noDataValue); + return RasterUtils.copyRasterAndAppendBand(toRaster, bandValues, noDataValue); } else { - return RasterUtils.copyRasterAndReplaceBand(toRaster, fromBand, Arrays.stream(bandValues).boxed().toArray(Double[]::new), noDataValue, false); + return RasterUtils.copyRasterAndReplaceBand(toRaster, fromBand, bandValues, noDataValue, false); } } } diff --git a/common/src/main/java/org/apache/sedona/common/raster/serde/Serde.java b/common/src/main/java/org/apache/sedona/common/raster/serde/Serde.java index e775ee46a6..7f67708b63 100644 --- a/common/src/main/java/org/apache/sedona/common/raster/serde/Serde.java +++ b/common/src/main/java/org/apache/sedona/common/raster/serde/Serde.java @@ -32,8 +32,6 @@ import javax.media.jai.RenderedImageAdapter; import java.awt.image.RenderedImage; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.Serializable; import java.net.URI; @@ -178,22 +176,4 @@ public static GridCoverage2D deserialize(byte[] bytes) throws IOException, Class return state.restore(); } } - - public static byte[] serializeGridSampleDimension(GridSampleDimension sampleDimension) { - Kryo kryo = kryos.get(); - ByteArrayOutputStream baos = new ByteArrayOutputStream(); - Output output = new Output(baos); - GridSampleDimensionSerializer serializer = new GridSampleDimensionSerializer(); - serializer.write(kryo, output, sampleDimension); - output.close(); - return baos.toByteArray(); - } - - public static GridSampleDimension deserializeGridSampleDimension(byte[] data) { - Kryo kryo = kryos.get(); - Input input = new Input(new ByteArrayInputStream(data)); - GridSampleDimensionSerializer serializer = new GridSampleDimensionSerializer(); - return serializer.read(kryo, input, GridSampleDimension.class); - } - } diff --git a/common/src/main/java/org/apache/sedona/common/utils/RasterUtils.java b/common/src/main/java/org/apache/sedona/common/utils/RasterUtils.java index 7cc6c50800..ce6c3b6046 100644 --- a/common/src/main/java/org/apache/sedona/common/utils/RasterUtils.java +++ b/common/src/main/java/org/apache/sedona/common/utils/RasterUtils.java @@ -556,7 +556,7 @@ public static boolean isDataTypeIntegral(int dataTypeCode) { * @param bandValues * @return */ - public static GridCoverage2D copyRasterAndAppendBand(GridCoverage2D gridCoverage2D, Number[] bandValues, Double noDataValue) { + public static GridCoverage2D copyRasterAndAppendBand(GridCoverage2D gridCoverage2D, Object bandValues, Double noDataValue) { // Get the original image and its properties RenderedImage originalImage = gridCoverage2D.getRenderedImage(); Raster raster = getRaster(originalImage); @@ -565,17 +565,19 @@ public static GridCoverage2D copyRasterAndAppendBand(GridCoverage2D gridCoverage // Copy the raster data and append the new band values for (int i = 0; i < raster.getWidth(); i++) { for (int j = 0; j < raster.getHeight(); j++) { - if (bandValues instanceof Double[]) { + if (bandValues instanceof double[]) { + double[] values = (double[]) bandValues; double[] pixels = raster.getPixel(i, j, (double[]) null); double[] copiedPixels = new double[pixels.length + 1]; System.arraycopy(pixels, 0, copiedPixels, 0, pixels.length); - copiedPixels[pixels.length] = (double) bandValues[j * raster.getWidth() + i]; + copiedPixels[pixels.length] = values[j * raster.getWidth() + i]; wr.setPixel(i, j, copiedPixels); - } else if (bandValues instanceof Integer[]) { + } else if (bandValues instanceof int[]) { + int[] values = (int[]) bandValues; int[] pixels = raster.getPixel(i, j, (int[]) null); int[] copiedPixels = new int[pixels.length + 1]; System.arraycopy(pixels, 0, copiedPixels, 0, pixels.length); - copiedPixels[pixels.length] = (int) bandValues[j * raster.getWidth() + i]; + copiedPixels[pixels.length] = values[j * raster.getWidth() + i]; wr.setPixel(i, j, copiedPixels); } } @@ -594,11 +596,11 @@ public static GridCoverage2D copyRasterAndAppendBand(GridCoverage2D gridCoverage return clone(wr, gridCoverage2D.getGridGeometry(), sampleDimensions, gridCoverage2D, null, true); } - public static GridCoverage2D copyRasterAndAppendBand(GridCoverage2D gridCoverage2D, Number[] bandValues) { + public static GridCoverage2D copyRasterAndAppendBand(GridCoverage2D gridCoverage2D, Object bandValues) { return copyRasterAndAppendBand(gridCoverage2D, bandValues, null); } - public static GridCoverage2D copyRasterAndReplaceBand(GridCoverage2D gridCoverage2D, int bandIndex, Number[] bandValues, Double noDataValue, boolean removeNoDataIfNull) { + public static GridCoverage2D copyRasterAndReplaceBand(GridCoverage2D gridCoverage2D, int bandIndex, Object bandValues, Double noDataValue, boolean removeNoDataIfNull) { // Do not allow the band index to be out of bounds ensureBand(gridCoverage2D, bandIndex); // Get the original image and its properties @@ -608,13 +610,15 @@ public static GridCoverage2D copyRasterAndReplaceBand(GridCoverage2D gridCoverag // Copy the raster data and replace the band values for (int i = 0; i < raster.getWidth(); i++) { for (int j = 0; j < raster.getHeight(); j++) { - if (bandValues instanceof Double[]) { + if (bandValues instanceof double[]) { + double[] values = (double[]) bandValues; double[] bands = raster.getPixel(i, j, (double[]) null); - bands[bandIndex - 1] = (double) bandValues[j * raster.getWidth() + i]; + bands[bandIndex - 1] = values[j * raster.getWidth() + i]; wr.setPixel(i, j, bands); - } else if (bandValues instanceof Integer[]) { + } else if (bandValues instanceof int[]) { + int[] values = (int[]) bandValues; int[] bands = raster.getPixel(i, j, (int[]) null); - bands[bandIndex - 1] = (int) bandValues[j * raster.getWidth() + i]; + bands[bandIndex - 1] = values[j * raster.getWidth() + i]; wr.setPixel(i, j, bands); } } @@ -629,7 +633,7 @@ public static GridCoverage2D copyRasterAndReplaceBand(GridCoverage2D gridCoverag return clone(wr, gridCoverage2D.getGridGeometry(), sampleDimensions, gridCoverage2D, null, true); } - public static GridCoverage2D copyRasterAndReplaceBand(GridCoverage2D gridCoverage2D, int bandIndex, Number[] bandValues) { + public static GridCoverage2D copyRasterAndReplaceBand(GridCoverage2D gridCoverage2D, int bandIndex, Object bandValues) { return copyRasterAndReplaceBand(gridCoverage2D, bandIndex, bandValues, null, false); } diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/AggregateFunctions.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/AggregateFunctions.scala index 3bf1326290..d8a6f4be46 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/AggregateFunctions.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/AggregateFunctions.scala @@ -19,13 +19,12 @@ package org.apache.spark.sql.sedona_sql.expressions.raster -import org.apache.sedona.common.raster.serde.Serde import org.apache.sedona.common.raster.{RasterAccessors, RasterBandAccessors} import org.apache.sedona.common.utils.RasterUtils +import org.apache.sedona.sql.utils.RasterSerializer import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.expressions.Aggregator -import org.geotools.coverage.GridSampleDimension import org.geotools.coverage.grid.GridCoverage2D import java.awt.image.WritableRaster @@ -33,12 +32,10 @@ import javax.media.jai.RasterFactory import scala.collection.mutable.ArrayBuffer case class BandData( - var bandsData: Array[Array[Double]], - var index: Int, - var serializedRaster: Array[Byte], - var serializedSampleDimensions: Array[Array[Byte]] - ) - + index: Int, + width: Int, + height: Int, + serializedRaster: Array[Byte]) /** * Return a raster containing bands at given indexes from all rasters in a given column @@ -48,37 +45,32 @@ class RS_Union_Aggr extends Aggregator[(GridCoverage2D, Int), ArrayBuffer[BandDa def zero: ArrayBuffer[BandData] = ArrayBuffer[BandData]() def reduce(buffer: ArrayBuffer[BandData], input: (GridCoverage2D, Int)): ArrayBuffer[BandData] = { - val raster = input._1 - val renderedImage = raster.getRenderedImage - val numBands = renderedImage.getSampleModel.getNumBands - val width = renderedImage.getWidth - val height = renderedImage.getHeight - - // First check if this is the first raster to set dimensions or validate against existing dimensions - if (buffer.nonEmpty) { - val referenceRaster = Serde.deserialize(buffer.head.serializedRaster) - val refWidth = RasterAccessors.getWidth(referenceRaster) - val refHeight = RasterAccessors.getHeight(referenceRaster) - if (width != refWidth || height != refHeight) { - throw new IllegalArgumentException("All rasters must have the same dimensions") - } + val (raster, index) = input + val renderedImage = raster.getRenderedImage + val width = renderedImage.getWidth + val height = renderedImage.getHeight + val serializedRaster = RasterSerializer.serialize(raster) + raster.dispose(true) + + // First check if this is the first raster to set dimensions or validate against existing dimensions + if (buffer.nonEmpty) { + val refWidth = buffer.head.width + val refHeight = buffer.head.height + if (width != refWidth || height != refHeight) { + throw new IllegalArgumentException("All rasters must have the same dimensions") } + } - // Extract data for each band - val rasterData = renderedImage.getData - val bandsData = Array.ofDim[Double](numBands, width * height) - val serializedSampleDimensions = new Array[Array[Byte]](numBands) + buffer += BandData(index, width, height, serializedRaster) + buffer + } - for (band <- 0 until numBands) { - bandsData(band) = rasterData.getSamples(0, 0, width, height, band, new Array[Double](width * height)) - serializedSampleDimensions(band) = Serde.serializeGridSampleDimension(raster.getSampleDimension(band)) + def merge(buffer1: ArrayBuffer[BandData], buffer2: ArrayBuffer[BandData]): ArrayBuffer[BandData] = { + if (buffer1.nonEmpty && buffer2.nonEmpty) { + if (buffer1.head.width != buffer2.head.width || buffer1.head.height != buffer2.head.height) { + throw new IllegalArgumentException("All rasters must have the same dimensions") } - - buffer += BandData(bandsData, input._2, Serde.serialize(raster), serializedSampleDimensions) - buffer } - - def merge(buffer1: ArrayBuffer[BandData], buffer2: ArrayBuffer[BandData]): ArrayBuffer[BandData] = { val combined = ArrayBuffer.concat(buffer1, buffer2) if (combined.map(_.index).distinct.length != combined.length) { throw new IllegalArgumentException("Indexes shouldn't be repeated.") @@ -95,24 +87,37 @@ class RS_Union_Aggr extends Aggregator[(GridCoverage2D, Int), ArrayBuffer[BandDa throw new IllegalArgumentException("Index should be in an arithmetic sequence.") } - val totalBands = sortedMerged.map(_.bandsData.length).sum - val referenceRaster = Serde.deserialize(sortedMerged.head.serializedRaster) - val width = RasterAccessors.getWidth(referenceRaster) - val height = RasterAccessors.getHeight(referenceRaster) - val dataTypeCode = RasterUtils.getRaster(referenceRaster.getRenderedImage).getDataBuffer.getDataType - val resultRaster: WritableRaster = RasterFactory.createBandedRaster(dataTypeCode, width, height, totalBands, null) - val gridSampleDimensions = sortedMerged.flatMap(_.serializedSampleDimensions.map(Serde.deserializeGridSampleDimension)).toArray - - var currentBand = 0 - sortedMerged.foreach { bandData => - bandData.bandsData.foreach { band => - resultRaster.setSamples(0, 0, width, height, currentBand, band) - currentBand += 1 + val rasters = sortedMerged.map(d => RasterSerializer.deserialize(d.serializedRaster)) + try { + val gridSampleDimensions = rasters.flatMap(_.getSampleDimensions).toArray + val totalBands = rasters.map(_.getNumSampleDimensions).sum + val referenceRaster = rasters.head + val width = RasterAccessors.getWidth(referenceRaster) + val height = RasterAccessors.getHeight(referenceRaster) + val dataTypeCode = RasterUtils.getRaster(referenceRaster.getRenderedImage).getDataBuffer.getDataType + val resultRaster: WritableRaster = RasterFactory.createBandedRaster(dataTypeCode, width, height, totalBands, null) + + var currentBand = 0 + rasters.foreach { raster => + var bandIndex = 0 + while (bandIndex < raster.getNumSampleDimensions) { + if (RasterUtils.isDataTypeIntegral(dataTypeCode)) { + val band = RasterUtils.getRaster(raster.getRenderedImage).getSamples(0, 0, width, height, bandIndex, new Array[Int](width * height)) + resultRaster.setSamples(0, 0, width, height, currentBand, band) + } else { + val band = RasterUtils.getRaster(raster.getRenderedImage).getSamples(0, 0, width, height, bandIndex, new Array[Double](width * height)) + resultRaster.setSamples(0, 0, width, height, currentBand, band) + } + currentBand += 1 + bandIndex += 1 + } } - } - val noDataValue = RasterBandAccessors.getBandNoDataValue(referenceRaster) - RasterUtils.clone(resultRaster, referenceRaster.getGridGeometry, gridSampleDimensions, referenceRaster, noDataValue, true) + val noDataValue = RasterBandAccessors.getBandNoDataValue(referenceRaster) + RasterUtils.clone(resultRaster, referenceRaster.getGridGeometry, gridSampleDimensions, referenceRaster, noDataValue, false) + } finally { + rasters.foreach(_.dispose(true)) + } } val serde = ExpressionEncoder[GridCoverage2D]