diff --git a/common/src/main/java/org/apache/sedona/common/raster/RasterEditors.java b/common/src/main/java/org/apache/sedona/common/raster/RasterEditors.java index 27b77f529b..eb0db1eab4 100644 --- a/common/src/main/java/org/apache/sedona/common/raster/RasterEditors.java +++ b/common/src/main/java/org/apache/sedona/common/raster/RasterEditors.java @@ -21,6 +21,7 @@ import org.apache.sedona.common.FunctionsGeoTools; import org.apache.sedona.common.utils.RasterUtils; import org.geotools.coverage.CoverageFactoryFinder; +import org.geotools.coverage.GridSampleDimension; import org.geotools.coverage.grid.GridCoverage2D; import org.geotools.coverage.grid.GridCoverageFactory; import org.geotools.coverage.grid.GridEnvelope2D; @@ -29,7 +30,9 @@ import org.geotools.geometry.Envelope2D; import org.geotools.referencing.crs.DefaultEngineeringCRS; import org.geotools.referencing.operation.transform.AffineTransform2D; +import org.opengis.coverage.SampleDimensionType; import org.opengis.coverage.grid.GridCoverage; +import org.opengis.geometry.Envelope; import org.opengis.metadata.spatial.PixelOrientation; import org.opengis.referencing.FactoryException; import org.opengis.referencing.crs.CoordinateReferenceSystem; @@ -39,18 +42,57 @@ import org.opengis.referencing.operation.TransformException; import javax.media.jai.Interpolation; +import javax.media.jai.RasterFactory; +import java.awt.*; import java.awt.geom.Point2D; -import java.awt.image.DataBuffer; -import java.awt.image.RenderedImage; +import java.awt.image.*; import java.util.Arrays; import java.util.Map; import java.util.Objects; +import static java.lang.Double.NaN; import static org.apache.sedona.common.raster.MapAlgebra.addBandFromArray; import static org.apache.sedona.common.raster.MapAlgebra.bandAsArray; public class RasterEditors { + + /** + * Changes the band pixel type of a specific band of a raster. + * + * @param raster The input raster. + * @param dataType The desired data type of the pixel. + * @return The modified raster with updated pixel type. + */ + + public static GridCoverage2D setPixelType(GridCoverage2D raster, String dataType) { + int newDataType = RasterUtils.getDataTypeCode(dataType); + + // Extracting the original data + RenderedImage originalImage = raster.getRenderedImage(); + Raster originalData = RasterUtils.getRaster(originalImage); + + int width = originalImage.getWidth(); + int height = originalImage.getHeight(); + int numBands = originalImage.getSampleModel().getNumBands(); + + // Create a new writable raster with the specified data type + WritableRaster modifiedRaster = RasterFactory.createBandedRaster(newDataType, width, height, numBands, null); + + // Populate modified raster and recreate sample dimensions + GridSampleDimension[] sampleDimensions = raster.getSampleDimensions(); + for (int band = 0; band < numBands; band++) { + double[] samples = originalData.getSamples(0, 0, width, height, band, (double[]) null); + modifiedRaster.setSamples(0, 0, width, height, band, samples); + if (!Double.isNaN(RasterUtils.getNoDataValue(sampleDimensions[band]))) { + sampleDimensions[band] = RasterUtils.createSampleDimensionWithNoDataValue(sampleDimensions[band], castRasterDataType(RasterUtils.getNoDataValue(sampleDimensions[band]), newDataType)); + } + } + + // Clone the original GridCoverage2D with the modified raster + return RasterUtils.clone(modifiedRaster, raster.getGridGeometry(), sampleDimensions, raster, null, true); + } + public static GridCoverage2D setSrid(GridCoverage2D raster, int srid) { CoordinateReferenceSystem crs; @@ -345,7 +387,10 @@ public static GridCoverage2D normalizeAll(GridCoverage2D rasterGeom, double minL private static double castRasterDataType(double value, int dataType) { switch (dataType) { case DataBuffer.TYPE_BYTE: - return (byte) value; + // Cast to unsigned byte (0-255) + double remainder = value%256; + double v = (remainder < 0) ? remainder+256 : remainder; + return (int) v; case DataBuffer.TYPE_SHORT: return (short) value; case DataBuffer.TYPE_INT: diff --git a/common/src/test/java/org/apache/sedona/common/raster/RasterEditorsTest.java b/common/src/test/java/org/apache/sedona/common/raster/RasterEditorsTest.java index 2c3b99a4f4..895127525f 100644 --- a/common/src/test/java/org/apache/sedona/common/raster/RasterEditorsTest.java +++ b/common/src/test/java/org/apache/sedona/common/raster/RasterEditorsTest.java @@ -32,6 +32,76 @@ import static org.junit.Assert.assertThrows; public class RasterEditorsTest extends RasterTestBase { + + @Test + public void testSetBandPixelType() throws FactoryException { + GridCoverage2D testRaster = RasterConstructors.makeEmptyRaster(4, "F", 4, 4, 0, 0, 1); + double[] bandValues1 = {1.1,2.1,3.1,4.1,5.1,6.1,7.1,8.1,9.1,10.1,11.1,12.1,13.1,14.1,15.1,99.2}; + double[] bandValues2 = {17.9, 18.9, 19.9, 20.9, 21.9, 22.9, 23.9, 24.9, 25.9, 26.9, 27.9, 28.9, 29.9, 30.9, 31.9, 32.9}; + double[] bandValues3 = {-10, -9, -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5}; + double[] bandValues4 = {65535, 65536, 65537, 65538, 65539, 65540, 65541, 65542, 65543, 65544, 65545, 65546, 65547, 65548, 65549, -9999}; + + testRaster = MapAlgebra.addBandFromArray(testRaster, bandValues1, 1); + testRaster = MapAlgebra.addBandFromArray(testRaster, bandValues2, 2); + testRaster = MapAlgebra.addBandFromArray(testRaster, bandValues3, 3); + testRaster = MapAlgebra.addBandFromArray(testRaster, bandValues4, 4); + testRaster = RasterBandEditors.setBandNoDataValue(testRaster, 1, 99.2); + testRaster = RasterBandEditors.setBandNoDataValue(testRaster, 4, -9999.0); + + GridCoverage2D modifiedRaster = RasterEditors.setPixelType(testRaster, "D"); + + assertEquals(DataBuffer.TYPE_DOUBLE, modifiedRaster.getRenderedImage().getSampleModel().getDataType()); + assertEquals(99.19999694824219, RasterUtils.getNoDataValue(modifiedRaster.getSampleDimension(0)), 0.01); + assertEquals(-9999, RasterUtils.getNoDataValue(modifiedRaster.getSampleDimension(3)), 0.01); + assertEquals(testRaster.getRenderedImage().getWidth(), modifiedRaster.getRenderedImage().getWidth()); + assertEquals(testRaster.getRenderedImage().getHeight(), modifiedRaster.getRenderedImage().getHeight()); + + modifiedRaster = RasterEditors.setPixelType(testRaster, "F"); + + assertEquals(DataBuffer.TYPE_FLOAT, modifiedRaster.getRenderedImage().getSampleModel().getDataType()); + assertEquals(99.19999694824219, RasterUtils.getNoDataValue(modifiedRaster.getSampleDimension(0)), 0.01); + assertEquals(-9999, RasterUtils.getNoDataValue(modifiedRaster.getSampleDimension(3)), 0.01); + assertEquals(testRaster.getRenderedImage().getWidth(), modifiedRaster.getRenderedImage().getWidth()); + assertEquals(testRaster.getRenderedImage().getHeight(), modifiedRaster.getRenderedImage().getHeight()); + + + modifiedRaster = RasterEditors.setPixelType(testRaster, "I"); + + assertEquals(DataBuffer.TYPE_INT, modifiedRaster.getRenderedImage().getSampleModel().getDataType()); + assertEquals(99, RasterUtils.getNoDataValue(modifiedRaster.getSampleDimension(0)), 0.01); + assertEquals(-9999, RasterUtils.getNoDataValue(modifiedRaster.getSampleDimension(3)), 0.01); + assertEquals(testRaster.getRenderedImage().getWidth(), modifiedRaster.getRenderedImage().getWidth()); + assertEquals(testRaster.getRenderedImage().getHeight(), modifiedRaster.getRenderedImage().getHeight()); + + modifiedRaster = RasterEditors.setPixelType(testRaster, "S"); + double[] expected = {-1.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, -9999.0}; + + assertEquals(DataBuffer.TYPE_SHORT, modifiedRaster.getRenderedImage().getSampleModel().getDataType()); + assertEquals(99, RasterUtils.getNoDataValue(modifiedRaster.getSampleDimension(0)), 0.01); + assertEquals(-9999, RasterUtils.getNoDataValue(modifiedRaster.getSampleDimension(3)), 0.01); + assertEquals(Arrays.toString(expected), Arrays.toString(MapAlgebra.bandAsArray(modifiedRaster, 4))); + assertEquals(testRaster.getRenderedImage().getWidth(), modifiedRaster.getRenderedImage().getWidth()); + assertEquals(testRaster.getRenderedImage().getHeight(), modifiedRaster.getRenderedImage().getHeight()); + + modifiedRaster = RasterEditors.setPixelType(testRaster, "US"); + expected = new double[]{65526.0, 65527.0, 65528.0, 65529.0, 65530.0, 65531.0, 65532.0, 65533.0, 65534.0, 65535.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0}; + + assertEquals(DataBuffer.TYPE_USHORT, modifiedRaster.getRenderedImage().getSampleModel().getDataType()); + assertEquals(99, RasterUtils.getNoDataValue(modifiedRaster.getSampleDimension(0)), 0.01); + assertEquals(55537, RasterUtils.getNoDataValue(modifiedRaster.getSampleDimension(3)), 0.01); + assertEquals(Arrays.toString(expected), Arrays.toString(MapAlgebra.bandAsArray(modifiedRaster, 3))); + assertEquals(testRaster.getRenderedImage().getWidth(), modifiedRaster.getRenderedImage().getWidth()); + assertEquals(testRaster.getRenderedImage().getHeight(), modifiedRaster.getRenderedImage().getHeight()); + + modifiedRaster = RasterEditors.setPixelType(testRaster, "B"); + + assertEquals(DataBuffer.TYPE_BYTE, modifiedRaster.getRenderedImage().getSampleModel().getDataType()); + assertEquals(99, RasterUtils.getNoDataValue(modifiedRaster.getSampleDimension(0)), 0.01); + assertEquals(241, RasterUtils.getNoDataValue(modifiedRaster.getSampleDimension(3)), 0.01); + assertEquals(testRaster.getRenderedImage().getWidth(), modifiedRaster.getRenderedImage().getWidth()); + assertEquals(testRaster.getRenderedImage().getHeight(), modifiedRaster.getRenderedImage().getHeight()); + } + @Test public void testSetGeoReferenceWithRaster() throws IOException { GridCoverage2D raster = rasterFromGeoTiff(resourceFolder + "raster/test1.tiff"); diff --git a/docs/api/sql/Raster-loader.md b/docs/api/sql/Raster-loader.md index ef648838e1..4f5a5137f8 100644 --- a/docs/api/sql/Raster-loader.md +++ b/docs/api/sql/Raster-loader.md @@ -65,7 +65,7 @@ Accepts one of: 3. "I" - 32 bits signed Integer 4. "S" - 16 bits signed Short 5. "US" - 16 bits unsigned Short - 6. "B" - 8 bits Byte + 6. "B" - 8 bits unsigned Byte * Width: The width of the raster in pixels. * Height: The height of the raster in pixels. * UpperleftX: The X coordinate of the upper left corner of the raster, in terms of the CRS units. diff --git a/docs/api/sql/Raster-operators.md b/docs/api/sql/Raster-operators.md index a3b75dbd86..27899f2a9f 100644 --- a/docs/api/sql/Raster-operators.md +++ b/docs/api/sql/Raster-operators.md @@ -1767,6 +1767,35 @@ Output: -3.000000 ``` +### RS_SetPixelType + +Introduction: Returns a modified raster with the desired pixel data type. + +The `dataType` parameter accepts one of the following strings. + +- "D" - 64 bits Double +- "F" - 32 bits Float +- "I" - 32 bits signed Integer +- "S" - 16 bits signed Short +- "US" - 16 bits unsigned Short +- "B" - 8 bits unsigned Byte + +!!!note + If the specified `dataType` is narrower than the original data type, the function will truncate the pixel values to fit the new data type range. + +Format: +``` +RS_SetPixelType(raster: Raster, dataType: String) +``` + +Since: `v1.6.0` + +SQL Example: + +```sql +RS_SetPixelType(raster, "I") +``` + ### RS_SetValue Introduction: Returns a raster by replacing the value of pixel specified by `colX` and `rowY`. diff --git a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala index db9fe9d099..e0adb53938 100644 --- a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala +++ b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala @@ -217,6 +217,7 @@ object Catalog { function[RS_SetSRID](), function[RS_SetGeoReference](), function[RS_SetBandNoDataValue](), + function[RS_SetPixelType](), function[RS_SetValues](), function[RS_SetValue](), function[RS_SRID](), diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterEditors.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterEditors.scala index 5ccf34ccb5..673e59911a 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterEditors.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterEditors.scala @@ -37,6 +37,12 @@ case class RS_SetGeoReference(inputExpressions: Seq[Expression]) extends Inferre } } +case class RS_SetPixelType(inputExpressions: Seq[Expression]) extends InferredExpression(RasterEditors.setPixelType _) { + protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { + copy(inputExpressions = newChildren) + } +} + case class RS_Resample(inputExpressions: Seq[Expression]) extends InferredExpression( nullTolerantInferrableFunction4(RasterEditors.resample), nullTolerantInferrableFunction5(RasterEditors.resample), nullTolerantInferrableFunction7(RasterEditors.resample)) { diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/rasteralgebraTest.scala b/spark/common/src/test/scala/org/apache/sedona/sql/rasteralgebraTest.scala index 6d7a9cf512..4fc41a67b1 100644 --- a/spark/common/src/test/scala/org/apache/sedona/sql/rasteralgebraTest.scala +++ b/spark/common/src/test/scala/org/apache/sedona/sql/rasteralgebraTest.scala @@ -229,6 +229,35 @@ class rasteralgebraTest extends TestBaseScala with BeforeAndAfter with GivenWhen assert(result4.isInstanceOf[GridCoverage2D]) } + it("should pass RS_SetPixelType") { + var df = sparkSession.read.format("binaryFile").load(resourceFolder + "raster/test1.tiff") + df = df.selectExpr("RS_FromGeoTiff(content) as raster") + + val df1 = df.selectExpr("RS_SetPixelType(raster, 'D') as modifiedRaster") + val result1 = df1.selectExpr("RS_BandPixelType(modifiedRaster)").first().get(0).toString + assert(result1 == "REAL_64BITS") + + val df2 = df.selectExpr("RS_SetPixelType(raster, 'F') as modifiedRaster") + val result2 = df2.selectExpr("RS_BandPixelType(modifiedRaster)").first().get(0).toString + assert(result2 == "REAL_32BITS") + + val df3 = df.selectExpr("RS_SetPixelType(raster, 'I') as modifiedRaster") + val result3 = df3.selectExpr("RS_BandPixelType(modifiedRaster)").first().get(0).toString + assert(result3 == "SIGNED_32BITS") + + val df4 = df.selectExpr("RS_SetPixelType(raster, 'S') as modifiedRaster") + val result4 = df4.selectExpr("RS_BandPixelType(modifiedRaster)").first().get(0).toString + assert(result4 == "SIGNED_16BITS") + + val df5 = df.selectExpr("RS_SetPixelType(raster, 'US') as modifiedRaster") + val result5 = df5.selectExpr("RS_BandPixelType(modifiedRaster)").first().get(0).toString + assert(result5 == "UNSIGNED_16BITS") + + val df6 = df.selectExpr("RS_SetPixelType(raster, 'B') as modifiedRaster") + val result6 = df6.selectExpr("RS_BandPixelType(modifiedRaster)").first().get(0).toString + assert(result6 == "UNSIGNED_8BITS") + } + it("should pass RS_Array") { val df = sparkSession.sql("SELECT RS_Array(6, 1e-6) as band") val result = df.first().getAs[mutable.WrappedArray[Double]](0)