diff --git a/common/src/main/java/org/apache/sedona/common/raster/RasterConstructors.java b/common/src/main/java/org/apache/sedona/common/raster/RasterConstructors.java index 688c6387db..c7635c70bd 100644 --- a/common/src/main/java/org/apache/sedona/common/raster/RasterConstructors.java +++ b/common/src/main/java/org/apache/sedona/common/raster/RasterConstructors.java @@ -52,9 +52,10 @@ import java.awt.image.RenderedImage; import java.awt.image.WritableRaster; import java.io.IOException; +import java.util.Arrays; +import java.util.Map; import java.util.ArrayList; import java.util.List; -import java.util.Map; public class RasterConstructors { @@ -399,6 +400,39 @@ public static GridCoverage2D makeNonEmptyRaster(int numBands, String bandDataTyp return RasterUtils.create(raster, gridGeometry, null); } + /** + * Make a non-empty raster from a reference raster and a set of values. The constructed raster will have the same CRS, + * geo-reference metadata, width and height as the reference raster. The number of bands of the reference raster is + * determined by the size of values. The size of values should be multiple of width * height of the reference raster. + * @param ref the reference raster + * @param bandDataType the data type of the band + * @param values the values to set + * @return the constructed raster + */ + public static GridCoverage2D makeNonEmptyRaster(GridCoverage2D ref, String bandDataType, double[] values) { + CoordinateReferenceSystem crs = ref.getCoordinateReferenceSystem(); + int widthInPixel = ref.getRenderedImage().getWidth(); + int heightInPixel = ref.getRenderedImage().getHeight(); + int valuesPerBand = widthInPixel * heightInPixel; + if (values.length == 0) { + throw new IllegalArgumentException("The size of values should be greater than 0"); + } + if (values.length % valuesPerBand != 0) { + throw new IllegalArgumentException("The size of values should be multiple of width * height of the reference raster"); + } + int numBands = values.length / valuesPerBand; + WritableRaster raster = RasterFactory.createBandedRaster(RasterUtils.getDataTypeCode(bandDataType), widthInPixel, heightInPixel, numBands, null); + for (int i = 0; i < numBands; i++) { + double[] bandValues = Arrays.copyOfRange(values, i * valuesPerBand, (i + 1) * valuesPerBand); + raster.setSamples(0, 0, widthInPixel, heightInPixel, i, bandValues); + } + MathTransform transform = ref.getGridGeometry().getGridToCRS(PixelInCell.CELL_CENTER); + GridGeometry2D gridGeometry = new GridGeometry2D( + new GridEnvelope2D(0, 0, widthInPixel, heightInPixel), + PixelInCell.CELL_CENTER, + transform, crs, null); + return RasterUtils.create(raster, gridGeometry, null); + } public static class Tile { private final int tileX; diff --git a/common/src/test/java/org/apache/sedona/common/raster/RasterConstructorsTest.java b/common/src/test/java/org/apache/sedona/common/raster/RasterConstructorsTest.java index e6713b625d..56427f567a 100644 --- a/common/src/test/java/org/apache/sedona/common/raster/RasterConstructorsTest.java +++ b/common/src/test/java/org/apache/sedona/common/raster/RasterConstructorsTest.java @@ -21,12 +21,14 @@ import org.junit.Assert; import org.junit.Test; import org.locationtech.jts.geom.Geometry; -import org.locationtech.jts.io.ParseException; +import org.opengis.coverage.SampleDimensionType; import org.opengis.geometry.DirectPosition; +import org.locationtech.jts.io.ParseException; import org.opengis.referencing.FactoryException; import org.opengis.referencing.operation.TransformException; import java.awt.image.DataBuffer; +import java.awt.image.Raster; import java.awt.image.RenderedImage; import java.io.IOException; import java.nio.charset.StandardCharsets; @@ -246,6 +248,65 @@ public void makeEmptyRaster() throws FactoryException { assertEquals("SIGNED_32BITS", gridCoverage2D.getSampleDimension(0).getSampleDimensionType().name()); } + @Test + public void testMakeNonEmptyRaster(){ + double[] bandData = new double[10000]; + for (int i = 0; i < bandData.length; i++) { + bandData[i] = i; + } + GridCoverage2D ref = RasterConstructors.makeNonEmptyRaster(1, "d", 100, 100, + 100, 80, 10, -10, 0, 0, 4326, + new double[][] {bandData}); + + // Test with empty band data + Assert.assertThrows(IllegalArgumentException.class, () -> RasterConstructors.makeNonEmptyRaster(ref, "D", new double[0])); + + // Test with single band + double[] values = new double[10000]; + for (int i = 0; i < values.length; i++) { + values[i] = i * i; + } + GridCoverage2D raster = RasterConstructors.makeNonEmptyRaster(ref, "D", values); + assertEquals(100, raster.getRenderedImage().getWidth()); + assertEquals(100, raster.getRenderedImage().getHeight()); + assertEquals(1, raster.getNumSampleDimensions()); + assertEquals(ref.getGridGeometry(), raster.getGridGeometry()); + assertEquals(ref.getCoordinateReferenceSystem(), raster.getCoordinateReferenceSystem()); + Raster r = RasterUtils.getRaster(raster.getRenderedImage()); + for (int i = 0; i < values.length; i++) { + assertEquals(values[i], r.getSampleDouble(i % 100, i / 100, 0), 0.001); + } + + // Test with multi band + values = new double[20000]; + for (int i = 0; i < values.length; i++) { + values[i] = i * i; + } + raster = RasterConstructors.makeNonEmptyRaster(ref, "D", values); + assertEquals(100, raster.getRenderedImage().getWidth()); + assertEquals(100, raster.getRenderedImage().getHeight()); + assertEquals(2, raster.getNumSampleDimensions()); + assertEquals(SampleDimensionType.REAL_64BITS, raster.getSampleDimension(0).getSampleDimensionType()); + assertEquals(ref.getGridGeometry(), raster.getGridGeometry()); + assertEquals(ref.getCoordinateReferenceSystem(), raster.getCoordinateReferenceSystem()); + r = RasterUtils.getRaster(raster.getRenderedImage()); + for (int i = 0; i < values.length; i++) { + assertEquals(values[i], r.getSampleDouble(i % 100, (i / 100) % 100, i / 10000), 0.001); + } + + // Test with integer data type + values = new double[10000]; + for (int i = 0; i < values.length; i++) { + values[i] = 10.0 + i; + } + raster = RasterConstructors.makeNonEmptyRaster(ref, "US", values); + assertEquals(SampleDimensionType.UNSIGNED_16BITS, raster.getSampleDimension(0).getSampleDimensionType()); + r = RasterUtils.getRaster(raster.getRenderedImage()); + for (int i = 0; i < values.length; i++) { + assertEquals(values[i], r.getSampleDouble(i % 100, i / 100, 0), 0.001); + } + } + @Test public void testInDbTileWithoutPadding() { GridCoverage2D raster = createRandomRaster(DataBuffer.TYPE_BYTE, 100, 100, 1000, 1010, 10, 1, "EPSG:3857"); diff --git a/docs/api/sql/Raster-loader.md b/docs/api/sql/Raster-loader.md index 4f5a5137f8..c0cdafb036 100644 --- a/docs/api/sql/Raster-loader.md +++ b/docs/api/sql/Raster-loader.md @@ -165,6 +165,37 @@ Output: +------------------------------------------------------------------+ ``` +### RS_MakeRaster + +Introduction: Creates a raster from the given array of pixel values. The width, height, geo-reference information, and +the CRS will be taken from the given reference raster. The data type of the resulting raster will be DOUBLE and the +number of bands of the resulting raster will be `data.length / (refRaster.width * refRaster.height)`. + +Since: `v1.6.0` + +Format: `RS_MakeRaster(refRaster: Raster, bandDataType: String, data: ARRAY[Double])` + +* refRaster: The reference raster from which the width, height, geo-reference information, and the CRS will be taken. +* bandDataType: The data type of the bands in the resulting raster. Please refer to the `RS_MakeEmptyRaster` function for the accepted values. +* data: The array of pixel values. The size of the array cannot be 0, and should be multiple of width * height of the reference raster. + +SQL example: + +```sql +WITH r AS (SELECT RS_MakeEmptyRaster(2, 3, 2, 0.0, 0.0, 1.0, -1.0, 0.0, 0.0, 4326) AS rast) +SELECT RS_AsMatrix(RS_MakeRaster(rast, 'D', ARRAY(1, 2, 3, 4, 5, 6))) FROM r +``` + +Output: + +``` ++------------------------------------------------------------+ +|rs_asmatrix(rs_makeraster(rast, D, array(1, 2, 3, 4, 5, 6)))| ++------------------------------------------------------------+ +||1.0 2.0 3.0|\n|4.0 5.0 6.0|\n | ++------------------------------------------------------------+ +``` + ### RS_FromNetCDF Introduction: Returns a raster geometry representing the given record variable short name from a NetCDF file. diff --git a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala index 3f688b1e2c..7b7f4862cd 100644 --- a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala +++ b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala @@ -211,6 +211,7 @@ object Catalog { function[RS_FromArcInfoAsciiGrid](), function[RS_FromGeoTiff](), function[RS_MakeEmptyRaster](), + function[RS_MakeRaster](), function[RS_Tile](), function[RS_TileExplode](), function[RS_Envelope](), diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterConstructors.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterConstructors.scala index ae6c9e103d..c977f09131 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterConstructors.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterConstructors.scala @@ -69,6 +69,13 @@ case class RS_MakeEmptyRaster(inputExpressions: Seq[Expression]) } } +case class RS_MakeRaster(inputExpressions: Seq[Expression]) + extends InferredExpression(inferrableFunction3(RasterConstructors.makeNonEmptyRaster)) { + protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { + copy(inputExpressions = newChildren) + } +} + case class RS_Tile(inputExpressions: Seq[Expression]) extends InferredExpression( nullTolerantInferrableFunction3(RasterConstructors.rsTile), diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/rasteralgebraTest.scala b/spark/common/src/test/scala/org/apache/sedona/sql/rasteralgebraTest.scala index 2ace2741cb..fa72438d4c 100644 --- a/spark/common/src/test/scala/org/apache/sedona/sql/rasteralgebraTest.scala +++ b/spark/common/src/test/scala/org/apache/sedona/sql/rasteralgebraTest.scala @@ -22,7 +22,7 @@ import org.apache.sedona.common.raster.MapAlgebra import org.apache.sedona.common.utils.RasterUtils import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.{Row, SaveMode} -import org.apache.spark.sql.functions.{col, collect_list, expr, row_number} +import org.apache.spark.sql.functions.{col, collect_list, expr, lit, row_number} import org.geotools.coverage.grid.GridCoverage2D import org.junit.Assert.{assertEquals, assertNotNull, assertNull, assertTrue} import org.locationtech.jts.geom.{Coordinate, Geometry} @@ -826,6 +826,26 @@ class rasteralgebraTest extends TestBaseScala with BeforeAndAfter with GivenWhen assertEquals(numBands, result(9), 0.001) } + it("Passed RS_MakeRaster") { + val df = sparkSession.read.format("binaryFile").load(resourceFolder + "raster/test1.tiff") + .withColumn("rast", expr("RS_FromGeoTiff(content)")) + val metadata = df.selectExpr("RS_Metadata(rast)").first().getSeq(0) + val width = metadata(2).asInstanceOf[Double].toInt + val height = metadata(3).asInstanceOf[Double].toInt + val values = Array.tabulate(width * height) { i => i * i } + + // Attach values as a new column to the dataframe + val dfWithValues = df.withColumn("values", lit(values)) + + // Call RS_MakeRaster to create a new raster with the values + val result = dfWithValues.selectExpr("RS_MakeRaster(rast, 'D', values) AS rast").first().get(0) + val rast = result.asInstanceOf[GridCoverage2D].getRenderedImage + val r = RasterUtils.getRaster(rast) + for (i <- values.indices) { + assertEquals(values(i), r.getSampleDouble(i % width, i / width, 0), 0.001) + } + } + it("Passed RS_BandAsArray") { val df = sparkSession.read.format("binaryFile").load(resourceFolder + "raster/test1.tiff") val metadata = df.selectExpr("RS_Metadata(RS_FromGeoTiff(content))").first().getSeq(0)