Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SEDONA-549] Fix memory bloat issue of RS_Union_Aggr when working with non-double band data #1402

Merged
merged 1 commit into from
May 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,11 @@ public static GridCoverage2D addBandFromArray(GridCoverage2D rasterGeom, double[
throw new IllegalArgumentException("Band index is out of bounds. Must be between 1 and " + (numBands + 1) + ")");
}

Double[] bandValuesClass = Arrays.stream(bandValues).boxed().toArray(Double[]::new);
if (bandIndex == numBands + 1) {
return RasterUtils.copyRasterAndAppendBand(rasterGeom, bandValuesClass, noDataValue);
return RasterUtils.copyRasterAndAppendBand(rasterGeom, bandValues, noDataValue);
}
else {
return RasterUtils.copyRasterAndReplaceBand(rasterGeom, bandIndex, bandValuesClass, noDataValue, true);
return RasterUtils.copyRasterAndReplaceBand(rasterGeom, bandIndex, bandValues, noDataValue, true);
}
}

Expand All @@ -94,12 +93,11 @@ public static GridCoverage2D addBandFromArray(GridCoverage2D rasterGeom, double[
throw new IllegalArgumentException("Band index is out of bounds. Must be between 1 and " + (numBands + 1) + ")");
}

Double[] bandValuesClass = Arrays.stream(bandValues).boxed().toArray(Double[]::new);
if (bandIndex == numBands + 1) {
return RasterUtils.copyRasterAndAppendBand(rasterGeom, bandValuesClass);
return RasterUtils.copyRasterAndAppendBand(rasterGeom, bandValues);
}
else {
return RasterUtils.copyRasterAndReplaceBand(rasterGeom, bandIndex, bandValuesClass);
return RasterUtils.copyRasterAndReplaceBand(rasterGeom, bandIndex, bandValues);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,16 +135,16 @@ public static GridCoverage2D addBand(GridCoverage2D toRaster, GridCoverage2D fro
if (RasterUtils.isDataTypeIntegral(dataTypeCode)) {
int[] bandValues = rasterData.getSamples(0, 0, width, height, fromBand - 1, (int[]) null);
if (numBands + 1 == toRasterIndex) {
return RasterUtils.copyRasterAndAppendBand(toRaster, Arrays.stream(bandValues).boxed().toArray(Integer[]::new), noDataValue);
return RasterUtils.copyRasterAndAppendBand(toRaster, bandValues, noDataValue);
} else {
return RasterUtils.copyRasterAndReplaceBand(toRaster, fromBand, Arrays.stream(bandValues).boxed().toArray(Integer[]::new), noDataValue, false);
return RasterUtils.copyRasterAndReplaceBand(toRaster, fromBand, bandValues, noDataValue, false);
}
} else {
double[] bandValues = rasterData.getSamples(0, 0, width, height, fromBand - 1, (double[]) null);
if (numBands + 1 == toRasterIndex) {
return RasterUtils.copyRasterAndAppendBand(toRaster, Arrays.stream(bandValues).boxed().toArray(Double[]::new), noDataValue);
return RasterUtils.copyRasterAndAppendBand(toRaster, bandValues, noDataValue);
} else {
return RasterUtils.copyRasterAndReplaceBand(toRaster, fromBand, Arrays.stream(bandValues).boxed().toArray(Double[]::new), noDataValue, false);
return RasterUtils.copyRasterAndReplaceBand(toRaster, fromBand, bandValues, noDataValue, false);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@

import javax.media.jai.RenderedImageAdapter;
import java.awt.image.RenderedImage;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.Serializable;
import java.net.URI;
Expand Down Expand Up @@ -178,22 +176,4 @@ public static GridCoverage2D deserialize(byte[] bytes) throws IOException, Class
return state.restore();
}
}

/**
 * Serializes a single {@link GridSampleDimension} to a byte array.
 *
 * <p>The bytes are produced by a freshly created {@code GridSampleDimensionSerializer},
 * using the Kryo instance returned by {@code kryos.get()}.
 *
 * @param sampleDimension the sample dimension to serialize
 * @return the serialized form of {@code sampleDimension}
 */
public static byte[] serializeGridSampleDimension(GridSampleDimension sampleDimension) {
    Kryo kryo = kryos.get();
    ByteArrayOutputStream baos = new ByteArrayOutputStream();
    GridSampleDimensionSerializer serializer = new GridSampleDimensionSerializer();
    // try-with-resources closes the Output even when write() throws; closing also
    // flushes Kryo's internal buffer into baos, so the bytes are read only afterwards.
    try (Output output = new Output(baos)) {
        serializer.write(kryo, output, sampleDimension);
    }
    return baos.toByteArray();
}

/**
 * Restores a {@link GridSampleDimension} from a byte array.
 *
 * <p>The bytes are read by a freshly created {@code GridSampleDimensionSerializer},
 * using the Kryo instance returned by {@code kryos.get()}.
 *
 * @param data the serialized sample dimension
 * @return the deserialized {@link GridSampleDimension}
 */
public static GridSampleDimension deserializeGridSampleDimension(byte[] data) {
    Kryo kryo = kryos.get();
    GridSampleDimensionSerializer serializer = new GridSampleDimensionSerializer();
    // Close the Kryo Input once reading finishes, whether it succeeds or throws.
    try (Input input = new Input(new ByteArrayInputStream(data))) {
        return serializer.read(kryo, input, GridSampleDimension.class);
    }
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,7 @@ public static boolean isDataTypeIntegral(int dataTypeCode) {
* @param bandValues
* @return
*/
public static GridCoverage2D copyRasterAndAppendBand(GridCoverage2D gridCoverage2D, Number[] bandValues, Double noDataValue) {
public static GridCoverage2D copyRasterAndAppendBand(GridCoverage2D gridCoverage2D, Object bandValues, Double noDataValue) {
// Get the original image and its properties
RenderedImage originalImage = gridCoverage2D.getRenderedImage();
Raster raster = getRaster(originalImage);
Expand All @@ -565,17 +565,19 @@ public static GridCoverage2D copyRasterAndAppendBand(GridCoverage2D gridCoverage
// Copy the raster data and append the new band values
for (int i = 0; i < raster.getWidth(); i++) {
for (int j = 0; j < raster.getHeight(); j++) {
if (bandValues instanceof Double[]) {
if (bandValues instanceof double[]) {
double[] values = (double[]) bandValues;
double[] pixels = raster.getPixel(i, j, (double[]) null);
double[] copiedPixels = new double[pixels.length + 1];
System.arraycopy(pixels, 0, copiedPixels, 0, pixels.length);
copiedPixels[pixels.length] = (double) bandValues[j * raster.getWidth() + i];
copiedPixels[pixels.length] = values[j * raster.getWidth() + i];
wr.setPixel(i, j, copiedPixels);
} else if (bandValues instanceof Integer[]) {
} else if (bandValues instanceof int[]) {
int[] values = (int[]) bandValues;
int[] pixels = raster.getPixel(i, j, (int[]) null);
int[] copiedPixels = new int[pixels.length + 1];
System.arraycopy(pixels, 0, copiedPixels, 0, pixels.length);
copiedPixels[pixels.length] = (int) bandValues[j * raster.getWidth() + i];
copiedPixels[pixels.length] = values[j * raster.getWidth() + i];
wr.setPixel(i, j, copiedPixels);
}
}
Expand All @@ -594,11 +596,11 @@ public static GridCoverage2D copyRasterAndAppendBand(GridCoverage2D gridCoverage
return clone(wr, gridCoverage2D.getGridGeometry(), sampleDimensions, gridCoverage2D, null, true);
}

public static GridCoverage2D copyRasterAndAppendBand(GridCoverage2D gridCoverage2D, Number[] bandValues) {
public static GridCoverage2D copyRasterAndAppendBand(GridCoverage2D gridCoverage2D, Object bandValues) {
return copyRasterAndAppendBand(gridCoverage2D, bandValues, null);
}

public static GridCoverage2D copyRasterAndReplaceBand(GridCoverage2D gridCoverage2D, int bandIndex, Number[] bandValues, Double noDataValue, boolean removeNoDataIfNull) {
public static GridCoverage2D copyRasterAndReplaceBand(GridCoverage2D gridCoverage2D, int bandIndex, Object bandValues, Double noDataValue, boolean removeNoDataIfNull) {
// Do not allow the band index to be out of bounds
ensureBand(gridCoverage2D, bandIndex);
// Get the original image and its properties
Expand All @@ -608,13 +610,15 @@ public static GridCoverage2D copyRasterAndReplaceBand(GridCoverage2D gridCoverag
// Copy the raster data and replace the band values
for (int i = 0; i < raster.getWidth(); i++) {
for (int j = 0; j < raster.getHeight(); j++) {
if (bandValues instanceof Double[]) {
if (bandValues instanceof double[]) {
double[] values = (double[]) bandValues;
double[] bands = raster.getPixel(i, j, (double[]) null);
bands[bandIndex - 1] = (double) bandValues[j * raster.getWidth() + i];
bands[bandIndex - 1] = values[j * raster.getWidth() + i];
wr.setPixel(i, j, bands);
} else if (bandValues instanceof Integer[]) {
} else if (bandValues instanceof int[]) {
int[] values = (int[]) bandValues;
int[] bands = raster.getPixel(i, j, (int[]) null);
bands[bandIndex - 1] = (int) bandValues[j * raster.getWidth() + i];
bands[bandIndex - 1] = values[j * raster.getWidth() + i];
wr.setPixel(i, j, bands);
}
}
Expand All @@ -629,7 +633,7 @@ public static GridCoverage2D copyRasterAndReplaceBand(GridCoverage2D gridCoverag
return clone(wr, gridCoverage2D.getGridGeometry(), sampleDimensions, gridCoverage2D, null, true);
}

public static GridCoverage2D copyRasterAndReplaceBand(GridCoverage2D gridCoverage2D, int bandIndex, Number[] bandValues) {
public static GridCoverage2D copyRasterAndReplaceBand(GridCoverage2D gridCoverage2D, int bandIndex, Object bandValues) {
return copyRasterAndReplaceBand(gridCoverage2D, bandIndex, bandValues, null, false);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,26 +19,23 @@

package org.apache.spark.sql.sedona_sql.expressions.raster

import org.apache.sedona.common.raster.serde.Serde
import org.apache.sedona.common.raster.{RasterAccessors, RasterBandAccessors}
import org.apache.sedona.common.utils.RasterUtils
import org.apache.sedona.sql.utils.RasterSerializer
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator
import org.geotools.coverage.GridSampleDimension
import org.geotools.coverage.grid.GridCoverage2D

import java.awt.image.WritableRaster
import javax.media.jai.RasterFactory
import scala.collection.mutable.ArrayBuffer

case class BandData(
var bandsData: Array[Array[Double]],
var index: Int,
var serializedRaster: Array[Byte],
var serializedSampleDimensions: Array[Array[Byte]]
)

index: Int,
width: Int,
height: Int,
serializedRaster: Array[Byte])

/**
* Return a raster containing bands at given indexes from all rasters in a given column
Expand All @@ -48,37 +45,32 @@ class RS_Union_Aggr extends Aggregator[(GridCoverage2D, Int), ArrayBuffer[BandDa
def zero: ArrayBuffer[BandData] = ArrayBuffer[BandData]()

def reduce(buffer: ArrayBuffer[BandData], input: (GridCoverage2D, Int)): ArrayBuffer[BandData] = {
val raster = input._1
val renderedImage = raster.getRenderedImage
val numBands = renderedImage.getSampleModel.getNumBands
val width = renderedImage.getWidth
val height = renderedImage.getHeight

// First check if this is the first raster to set dimensions or validate against existing dimensions
if (buffer.nonEmpty) {
val referenceRaster = Serde.deserialize(buffer.head.serializedRaster)
val refWidth = RasterAccessors.getWidth(referenceRaster)
val refHeight = RasterAccessors.getHeight(referenceRaster)
if (width != refWidth || height != refHeight) {
throw new IllegalArgumentException("All rasters must have the same dimensions")
}
val (raster, index) = input
val renderedImage = raster.getRenderedImage
val width = renderedImage.getWidth
val height = renderedImage.getHeight
val serializedRaster = RasterSerializer.serialize(raster)
raster.dispose(true)

// First check if this is the first raster to set dimensions or validate against existing dimensions
if (buffer.nonEmpty) {
val refWidth = buffer.head.width
val refHeight = buffer.head.height
if (width != refWidth || height != refHeight) {
throw new IllegalArgumentException("All rasters must have the same dimensions")
}
}

// Extract data for each band
val rasterData = renderedImage.getData
val bandsData = Array.ofDim[Double](numBands, width * height)
val serializedSampleDimensions = new Array[Array[Byte]](numBands)
buffer += BandData(index, width, height, serializedRaster)
buffer
}

for (band <- 0 until numBands) {
bandsData(band) = rasterData.getSamples(0, 0, width, height, band, new Array[Double](width * height))
serializedSampleDimensions(band) = Serde.serializeGridSampleDimension(raster.getSampleDimension(band))
def merge(buffer1: ArrayBuffer[BandData], buffer2: ArrayBuffer[BandData]): ArrayBuffer[BandData] = {
if (buffer1.nonEmpty && buffer2.nonEmpty) {
if (buffer1.head.width != buffer2.head.width || buffer1.head.height != buffer2.head.height) {
throw new IllegalArgumentException("All rasters must have the same dimensions")
}

buffer += BandData(bandsData, input._2, Serde.serialize(raster), serializedSampleDimensions)
buffer
}

def merge(buffer1: ArrayBuffer[BandData], buffer2: ArrayBuffer[BandData]): ArrayBuffer[BandData] = {
val combined = ArrayBuffer.concat(buffer1, buffer2)
if (combined.map(_.index).distinct.length != combined.length) {
throw new IllegalArgumentException("Indexes shouldn't be repeated.")
Expand All @@ -95,24 +87,37 @@ class RS_Union_Aggr extends Aggregator[(GridCoverage2D, Int), ArrayBuffer[BandDa
throw new IllegalArgumentException("Index should be in an arithmetic sequence.")
}

val totalBands = sortedMerged.map(_.bandsData.length).sum
val referenceRaster = Serde.deserialize(sortedMerged.head.serializedRaster)
val width = RasterAccessors.getWidth(referenceRaster)
val height = RasterAccessors.getHeight(referenceRaster)
val dataTypeCode = RasterUtils.getRaster(referenceRaster.getRenderedImage).getDataBuffer.getDataType
val resultRaster: WritableRaster = RasterFactory.createBandedRaster(dataTypeCode, width, height, totalBands, null)
val gridSampleDimensions = sortedMerged.flatMap(_.serializedSampleDimensions.map(Serde.deserializeGridSampleDimension)).toArray

var currentBand = 0
sortedMerged.foreach { bandData =>
bandData.bandsData.foreach { band =>
resultRaster.setSamples(0, 0, width, height, currentBand, band)
currentBand += 1
val rasters = sortedMerged.map(d => RasterSerializer.deserialize(d.serializedRaster))
try {
val gridSampleDimensions = rasters.flatMap(_.getSampleDimensions).toArray
val totalBands = rasters.map(_.getNumSampleDimensions).sum
val referenceRaster = rasters.head
val width = RasterAccessors.getWidth(referenceRaster)
val height = RasterAccessors.getHeight(referenceRaster)
val dataTypeCode = RasterUtils.getRaster(referenceRaster.getRenderedImage).getDataBuffer.getDataType
val resultRaster: WritableRaster = RasterFactory.createBandedRaster(dataTypeCode, width, height, totalBands, null)

var currentBand = 0
rasters.foreach { raster =>
var bandIndex = 0
while (bandIndex < raster.getNumSampleDimensions) {
if (RasterUtils.isDataTypeIntegral(dataTypeCode)) {
val band = RasterUtils.getRaster(raster.getRenderedImage).getSamples(0, 0, width, height, bandIndex, new Array[Int](width * height))
resultRaster.setSamples(0, 0, width, height, currentBand, band)
} else {
val band = RasterUtils.getRaster(raster.getRenderedImage).getSamples(0, 0, width, height, bandIndex, new Array[Double](width * height))
resultRaster.setSamples(0, 0, width, height, currentBand, band)
}
currentBand += 1
bandIndex += 1
}
}
}

val noDataValue = RasterBandAccessors.getBandNoDataValue(referenceRaster)
RasterUtils.clone(resultRaster, referenceRaster.getGridGeometry, gridSampleDimensions, referenceRaster, noDataValue, true)
val noDataValue = RasterBandAccessors.getBandNoDataValue(referenceRaster)
RasterUtils.clone(resultRaster, referenceRaster.getGridGeometry, gridSampleDimensions, referenceRaster, noDataValue, false)
} finally {
rasters.foreach(_.dispose(true))
}
}

val serde = ExpressionEncoder[GridCoverage2D]
Expand Down
Loading