Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SEDONA-543] Fixes RS_Union_aggr throwing referenceRaster is null error when run on cluster #1364

Merged
merged 6 commits into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@

import javax.media.jai.RenderedImageAdapter;
import java.awt.image.RenderedImage;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.Serializable;
import java.net.URI;
Expand Down Expand Up @@ -176,4 +178,19 @@ public static GridCoverage2D deserialize(byte[] bytes) throws IOException, Class
return state.restore();
}
}

public static byte[] serializeGridSampleDimension(GridSampleDimension sampleDimension) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was already implemented in Sedona. Why implement it again?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This follows Kristin's recommendation to implement new serialization/deserialization methods for GridSampleDimension that obtain their Kryo objects via kryos.get(), which avoids repeatedly initializing new Kryo instances.
I should have used the existing GridSampleDimensionSerializer for reading from and writing to the buffers; that has now been added.

Kryo kryo = kryos.get();
ByteArrayOutputStream baos = new ByteArrayOutputStream();
Output output = new Output(baos);
kryo.writeClassAndObject(output, sampleDimension);
output.close();
return baos.toByteArray();
}

/**
 * Reads a {@link GridSampleDimension} back from a byte array.
 *
 * The bytes are consumed with {@code readClassAndObject}, the counterpart of the
 * {@code writeClassAndObject} call used by {@code serializeGridSampleDimension}.
 *
 * @param data bytes to read the object from
 * @return the object read from {@code data}, cast to {@link GridSampleDimension}
 */
public static GridSampleDimension deserializeGridSampleDimension(byte[] data) {
    // Obtain the Kryo instance first, exactly as the write path does.
    Kryo serializer = kryos.get();
    ByteArrayInputStream byteStream = new ByteArrayInputStream(data);
    Input kryoInput = new Input(byteStream);
    Object restored = serializer.readClassAndObject(kryoInput);
    return (GridSampleDimension) restored;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

package org.apache.spark.sql.sedona_sql.expressions.raster

import org.apache.sedona.common.raster.serde.Serde
import org.apache.sedona.common.raster.{RasterAccessors, RasterBandAccessors}
import org.apache.sedona.common.utils.RasterUtils
import org.apache.spark.sql.Encoder
Expand All @@ -29,93 +30,106 @@ import org.geotools.coverage.grid.GridCoverage2D

import java.awt.image.WritableRaster
import javax.media.jai.RasterFactory
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

case class BandData(var bandInt: Array[Int], var bandDouble: Array[Double], var index: Int, var isIntegral: Boolean)
/**
 * One buffered band plus the serialized metadata that travels with it in the aggregation buffer.
 *
 * @param bandInt samples read as Int values; reduce fills this only when isIntegral is true, otherwise null
 * @param bandDouble samples read as Double values; reduce fills this only when isIntegral is false, otherwise null
 * @param index band index supplied alongside the raster (second element of the input tuple)
 * @param isIntegral result of RasterUtils.isDataTypeIntegral for the raster's data buffer type
 * @param serializedRaster bytes produced by Serde.serialize for the source raster
 * @param serializedSampleDimension bytes produced by Serde.serializeGridSampleDimension for sample dimension 0 of the source raster
 */
case class BandData(
var bandInt: Array[Int],
var bandDouble: Array[Double],
var index: Int,
var isIntegral: Boolean,
var serializedRaster: Array[Byte],
var serializedSampleDimension: Array[Byte]
)


/**
* Return a raster containing bands at given indexes from all rasters in a given column
*/
class RS_Union_Aggr extends Aggregator[(GridCoverage2D, Int), ArrayBuffer[BandData], GridCoverage2D] {

// NOTE(review): the four vars below are state held on the aggregator instance itself, not in the
// ArrayBuffer[BandData] buffer that reduce/merge/finish pass around. Presumably this is what the
// "referenceRaster is null" cluster error in the PR title refers to — confirm.

// Width taken from the first raster checked by checkRasterShape; -1 until then.
var width: Int = -1

// Height taken from the first raster checked by checkRasterShape; -1 until then.
var height: Int = -1

// The first raster checked by checkRasterShape; unset (null) until then.
var referenceRaster: GridCoverage2D = _

// Sample dimension 0 of each raster, keyed by the band index supplied with that raster.
var gridSampleDimension: mutable.Map[Int, GridSampleDimension] = new mutable.HashMap()

/**
 * Supplies the initial aggregation buffer: an ArrayBuffer holding no BandData entries.
 */
def zero: ArrayBuffer[BandData] = ArrayBuffer.empty[BandData]

/**
 * Checks that a raster has the same width and height as the first raster that was checked.
 *
 * While width and height both still hold -1, the given raster is treated as the first one:
 * its width and height are recorded, it is stored as the reference raster, and the check passes.
 *
 * @param raster raster whose shape is checked
 * @return true for the first raster; afterwards, whether the raster's width and height both
 *         equal the recorded values
 */
def checkRasterShape(raster: GridCoverage2D): Boolean = {
  if (width != -1 || height != -1) {
    // A shape has already been recorded: compare the candidate against it.
    val candidateWidth = RasterAccessors.getWidth(raster)
    val candidateHeight = RasterAccessors.getHeight(raster)
    candidateWidth == width && candidateHeight == height
  } else {
    // Nothing recorded yet: this raster defines the expected shape.
    width = RasterAccessors.getWidth(raster)
    height = RasterAccessors.getHeight(raster)
    referenceRaster = raster
    true
  }
}

/**
 * Adds band 0 of one input raster, tagged with the supplied band index, to the buffer.
 *
 * NOTE(review): this listing comes from a rendered diff, so lines of the previous implementation
 * (checkRasterShape, gridSampleDimension, the four-argument BandData and `bandData`) appear
 * alongside the lines that replace them; braces do not balance as shown. Presumably only the
 * buffer-based lines survive in the merged file — confirm before relying on this text.
 *
 * @param buffer entries accumulated so far
 * @param input the raster and the band index it should occupy
 * @return the same buffer with one more entry appended
 * @throws IllegalArgumentException if the raster's width or height differs from the first raster's
 */
def reduce(buffer: ArrayBuffer[BandData], input: (GridCoverage2D, Int)): ArrayBuffer[BandData] = {
val raster = input._1
// Shape and repeated-index checks that rely on the aggregator's instance fields.
if (!checkRasterShape(raster)) {
throw new IllegalArgumentException("Rasters provides should be of the same shape.")
}
if (gridSampleDimension.contains(input._2)) {
throw new IllegalArgumentException("Indexes shouldn't be repeated. Index should be in an arithmetic sequence.")
}

val rasterData = RasterUtils.getRaster(raster.getRenderedImage)
val isIntegral = RasterUtils.isDataTypeIntegral(rasterData.getDataBuffer.getDataType)

val bandData = if (isIntegral) {
val band = rasterData.getSamples(0, 0, width, height, 0, null.asInstanceOf[Array[Int]])
BandData(band, null, input._2, isIntegral)
// Serializing GridSampleDimension
val serializedBytes = Serde.serializeGridSampleDimension(raster.getSampleDimension(0))

// Check and set dimensions based on the first raster in the buffer
if (buffer.isEmpty) {
val width = RasterAccessors.getWidth(raster)
val height = RasterAccessors.getHeight(raster)
// The serialized raster is stored in the entry; finish later deserializes the head entry's copy.
val referenceSerializedRaster = Serde.serialize(raster)

// Only the sample array matching isIntegral is filled; the other one is null.
buffer += BandData(
if (isIntegral) rasterData.getSamples(0, 0, width, height, 0, null.asInstanceOf[Array[Int]]) else null,
if (!isIntegral) rasterData.getSamples(0, 0, width, height, 0, null.asInstanceOf[Array[Double]]) else null,
input._2,
isIntegral,
referenceSerializedRaster,
serializedBytes
)
} else {
val band = rasterData.getSamples(0, 0, width, height, 0, null.asInstanceOf[Array[Double]])
BandData(null, band, input._2, isIntegral)
// Expected width/height are recovered from the first buffered entry's serialized raster.
val referenceRaster = Serde.deserialize(buffer.head.serializedRaster)
val width = RasterAccessors.getWidth(referenceRaster)
val height = RasterAccessors.getHeight(referenceRaster)

if (width != RasterAccessors.getWidth(raster) || height != RasterAccessors.getHeight(raster)) {
throw new IllegalArgumentException("All rasters must have the same dimensions")
}

buffer += BandData(
if (isIntegral) rasterData.getSamples(0, 0, width, height, 0, null.asInstanceOf[Array[Int]]) else null,
if (!isIntegral) rasterData.getSamples(0, 0, width, height, 0, null.asInstanceOf[Array[Double]]) else null,
input._2,
isIntegral,
Serde.serialize(raster),
serializedBytes
)
}
gridSampleDimension = gridSampleDimension + (input._2 -> raster.getSampleDimension(0))

buffer += bandData
buffer
}


/**
 * Concatenates two partial buffers and rejects the result if any band index occurs more than once.
 *
 * @param buffer1 first partial buffer
 * @param buffer2 second partial buffer
 * @return a new buffer holding the entries of buffer1 followed by those of buffer2
 * @throws IllegalArgumentException if the combined entries contain a repeated index
 */
def merge(buffer1: ArrayBuffer[BandData], buffer2: ArrayBuffer[BandData]): ArrayBuffer[BandData] = {
// NOTE(review): the result of this first concat is discarded; it looks like the pre-change
// line of the rendered diff rather than part of the new body — confirm.
ArrayBuffer.concat(buffer1, buffer2)
val combined = ArrayBuffer.concat(buffer1, buffer2)
// Fewer distinct indexes than entries means at least one index is repeated.
if (combined.map(_.index).distinct.length != combined.length) {
throw new IllegalArgumentException("Indexes shouldn't be repeated.")
}
combined
}


/**
 * Builds the output raster from the merged buffer.
 *
 * Entries are sorted by index, the index spacing is validated, a banded raster is created with
 * one band per entry, each entry's samples are written into its band, and the result is produced
 * by RasterUtils.clone using the grid geometry of the raster deserialized from the first sorted entry.
 *
 * NOTE(review): this listing comes from a rendered diff, so lines of the previous implementation
 * (the instance-field referenceRaster lookup, indexCheck and the first for loop) appear alongside
 * the lines that replace them; braces do not balance as shown. Presumably only the
 * deserialization-based lines survive in the merged file — confirm.
 *
 * @param merged all buffered entries after merging
 * @return the raster returned by RasterUtils.clone for the assembled bands
 * @throws IllegalArgumentException if consecutive sorted indexes are not evenly spaced
 */
def finish(merged: ArrayBuffer[BandData]): GridCoverage2D = {
val sortedMerged = merged.sortBy(_.index)
// Every gap between consecutive indexes must equal the gap between the first two entries.
if (sortedMerged.zipWithIndex.exists { case (band, idx) =>
if (idx > 0) (band.index - sortedMerged(idx - 1).index) != (sortedMerged(1).index - sortedMerged(0).index)
else false
}) {
throw new IllegalArgumentException("Index should be in an arithmetic sequence.")
}

val numBands = sortedMerged.length
val rasterData = RasterUtils.getRaster(referenceRaster.getRenderedImage)
val dataTypeCode = rasterData.getDataBuffer.getDataType
// NOTE(review): .head assumes at least one buffered entry.
val referenceRaster = Serde.deserialize(sortedMerged.head.serializedRaster)
val width = RasterAccessors.getWidth(referenceRaster)
val height = RasterAccessors.getHeight(referenceRaster)
val dataTypeCode = RasterUtils.getRaster(referenceRaster.getRenderedImage).getDataBuffer.getDataType
val resultRaster: WritableRaster = RasterFactory.createBandedRaster(dataTypeCode, width, height, numBands, null)
val gridSampleDimensions: Array[GridSampleDimension] = new Array[GridSampleDimension](numBands)
var indexCheck = 1

for (bandData: BandData <- sortedMerged) {
if (bandData.index != indexCheck) {
throw new IllegalArgumentException("Indexes should be in a valid arithmetic sequence.")
}
indexCheck += 1
gridSampleDimensions(bandData.index - 1) = gridSampleDimension(bandData.index)
if(RasterUtils.isDataTypeIntegral(dataTypeCode))
resultRaster.setSamples(0, 0, width, height, (bandData.index - 1), bandData.bandInt)
else
resultRaster.setSamples(0, 0, width, height, bandData.index - 1, bandData.bandDouble)
// Output band position is the entry's position in sorted order, not its index value.
for ((bandData, idx) <- sortedMerged.zipWithIndex) {
// Deserializing GridSampleDimension
gridSampleDimensions(idx) = Serde.deserializeGridSampleDimension(bandData.serializedSampleDimension)

if(bandData.isIntegral)
resultRaster.setSamples(0, 0, width, height, idx, bandData.bandInt)
else
resultRaster.setSamples(0, 0, width, height, idx, bandData.bandDouble)
}

val noDataValue = RasterBandAccessors.getBandNoDataValue(referenceRaster)
RasterUtils.clone(resultRaster, referenceRaster.getGridGeometry, gridSampleDimensions, referenceRaster, noDataValue, true)
}
Expand Down