Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SEDONA-514] Add RS_SetPixelType #1276

Merged
merged 8 commits into from
Mar 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.apache.sedona.common.FunctionsGeoTools;
import org.apache.sedona.common.utils.RasterUtils;
import org.geotools.coverage.CoverageFactoryFinder;
import org.geotools.coverage.GridSampleDimension;
import org.geotools.coverage.grid.GridCoverage2D;
import org.geotools.coverage.grid.GridCoverageFactory;
import org.geotools.coverage.grid.GridEnvelope2D;
Expand All @@ -29,7 +30,9 @@
import org.geotools.geometry.Envelope2D;
import org.geotools.referencing.crs.DefaultEngineeringCRS;
import org.geotools.referencing.operation.transform.AffineTransform2D;
import org.opengis.coverage.SampleDimensionType;
import org.opengis.coverage.grid.GridCoverage;
import org.opengis.geometry.Envelope;
import org.opengis.metadata.spatial.PixelOrientation;
import org.opengis.referencing.FactoryException;
import org.opengis.referencing.crs.CoordinateReferenceSystem;
Expand All @@ -39,18 +42,57 @@
import org.opengis.referencing.operation.TransformException;

import javax.media.jai.Interpolation;
import javax.media.jai.RasterFactory;
import java.awt.*;
import java.awt.geom.Point2D;
import java.awt.image.DataBuffer;
import java.awt.image.RenderedImage;
import java.awt.image.*;
import java.util.Arrays;
import java.util.Map;
import java.util.Objects;

import static java.lang.Double.NaN;
import static org.apache.sedona.common.raster.MapAlgebra.addBandFromArray;
import static org.apache.sedona.common.raster.MapAlgebra.bandAsArray;

public class RasterEditors
{

/**
* Changes the band pixel type of a specific band of a raster.
*
* @param raster The input raster.
* @param dataType The desired data type of the pixel.
* @return The modified raster with updated pixel type.
*/

public static GridCoverage2D setPixelType(GridCoverage2D raster, String dataType) {
int newDataType = RasterUtils.getDataTypeCode(dataType);

// Extracting the original data
RenderedImage originalImage = raster.getRenderedImage();
Raster originalData = RasterUtils.getRaster(originalImage);

int width = originalImage.getWidth();
int height = originalImage.getHeight();
int numBands = originalImage.getSampleModel().getNumBands();

// Create a new writable raster with the specified data type
WritableRaster modifiedRaster = RasterFactory.createBandedRaster(newDataType, width, height, numBands, null);

// Populate modified raster and recreate sample dimensions
GridSampleDimension[] sampleDimensions = raster.getSampleDimensions();
for (int band = 0; band < numBands; band++) {
double[] samples = originalData.getSamples(0, 0, width, height, band, (double[]) null);
modifiedRaster.setSamples(0, 0, width, height, band, samples);
if (!Double.isNaN(RasterUtils.getNoDataValue(sampleDimensions[band]))) {
sampleDimensions[band] = RasterUtils.createSampleDimensionWithNoDataValue(sampleDimensions[band], castRasterDataType(RasterUtils.getNoDataValue(sampleDimensions[band]), newDataType));
}
}

// Clone the original GridCoverage2D with the modified raster
return RasterUtils.clone(modifiedRaster, raster.getGridGeometry(), sampleDimensions, raster, null, true);
}

public static GridCoverage2D setSrid(GridCoverage2D raster, int srid)
{
CoordinateReferenceSystem crs;
Expand Down Expand Up @@ -345,7 +387,10 @@ public static GridCoverage2D normalizeAll(GridCoverage2D rasterGeom, double minL
private static double castRasterDataType(double value, int dataType) {
switch (dataType) {
case DataBuffer.TYPE_BYTE:
return (byte) value;
// Cast to unsigned byte (0-255)
double remainder = value%256;
double v = (remainder < 0) ? remainder+256 : remainder;
return (int) v;
case DataBuffer.TYPE_SHORT:
return (short) value;
case DataBuffer.TYPE_INT:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,76 @@
import static org.junit.Assert.assertThrows;

public class RasterEditorsTest extends RasterTestBase {

@Test
public void testSetBandPixelType() throws FactoryException {
GridCoverage2D testRaster = RasterConstructors.makeEmptyRaster(4, "F", 4, 4, 0, 0, 1);
double[] bandValues1 = {1.1,2.1,3.1,4.1,5.1,6.1,7.1,8.1,9.1,10.1,11.1,12.1,13.1,14.1,15.1,99.2};
double[] bandValues2 = {17.9, 18.9, 19.9, 20.9, 21.9, 22.9, 23.9, 24.9, 25.9, 26.9, 27.9, 28.9, 29.9, 30.9, 31.9, 32.9};
double[] bandValues3 = {-10, -9, -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5};
double[] bandValues4 = {65535, 65536, 65537, 65538, 65539, 65540, 65541, 65542, 65543, 65544, 65545, 65546, 65547, 65548, 65549, -9999};

testRaster = MapAlgebra.addBandFromArray(testRaster, bandValues1, 1);
testRaster = MapAlgebra.addBandFromArray(testRaster, bandValues2, 2);
testRaster = MapAlgebra.addBandFromArray(testRaster, bandValues3, 3);
testRaster = MapAlgebra.addBandFromArray(testRaster, bandValues4, 4);
testRaster = RasterBandEditors.setBandNoDataValue(testRaster, 1, 99.2);
testRaster = RasterBandEditors.setBandNoDataValue(testRaster, 4, -9999.0);

GridCoverage2D modifiedRaster = RasterEditors.setPixelType(testRaster, "D");

assertEquals(DataBuffer.TYPE_DOUBLE, modifiedRaster.getRenderedImage().getSampleModel().getDataType());
assertEquals(99.19999694824219, RasterUtils.getNoDataValue(modifiedRaster.getSampleDimension(0)), 0.01);
assertEquals(-9999, RasterUtils.getNoDataValue(modifiedRaster.getSampleDimension(3)), 0.01);
assertEquals(testRaster.getRenderedImage().getWidth(), modifiedRaster.getRenderedImage().getWidth());
assertEquals(testRaster.getRenderedImage().getHeight(), modifiedRaster.getRenderedImage().getHeight());

modifiedRaster = RasterEditors.setPixelType(testRaster, "F");

assertEquals(DataBuffer.TYPE_FLOAT, modifiedRaster.getRenderedImage().getSampleModel().getDataType());
assertEquals(99.19999694824219, RasterUtils.getNoDataValue(modifiedRaster.getSampleDimension(0)), 0.01);
assertEquals(-9999, RasterUtils.getNoDataValue(modifiedRaster.getSampleDimension(3)), 0.01);
assertEquals(testRaster.getRenderedImage().getWidth(), modifiedRaster.getRenderedImage().getWidth());
assertEquals(testRaster.getRenderedImage().getHeight(), modifiedRaster.getRenderedImage().getHeight());


modifiedRaster = RasterEditors.setPixelType(testRaster, "I");

assertEquals(DataBuffer.TYPE_INT, modifiedRaster.getRenderedImage().getSampleModel().getDataType());
assertEquals(99, RasterUtils.getNoDataValue(modifiedRaster.getSampleDimension(0)), 0.01);
assertEquals(-9999, RasterUtils.getNoDataValue(modifiedRaster.getSampleDimension(3)), 0.01);
assertEquals(testRaster.getRenderedImage().getWidth(), modifiedRaster.getRenderedImage().getWidth());
assertEquals(testRaster.getRenderedImage().getHeight(), modifiedRaster.getRenderedImage().getHeight());

modifiedRaster = RasterEditors.setPixelType(testRaster, "S");
double[] expected = {-1.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, -9999.0};

assertEquals(DataBuffer.TYPE_SHORT, modifiedRaster.getRenderedImage().getSampleModel().getDataType());
assertEquals(99, RasterUtils.getNoDataValue(modifiedRaster.getSampleDimension(0)), 0.01);
assertEquals(-9999, RasterUtils.getNoDataValue(modifiedRaster.getSampleDimension(3)), 0.01);
assertEquals(Arrays.toString(expected), Arrays.toString(MapAlgebra.bandAsArray(modifiedRaster, 4)));
assertEquals(testRaster.getRenderedImage().getWidth(), modifiedRaster.getRenderedImage().getWidth());
assertEquals(testRaster.getRenderedImage().getHeight(), modifiedRaster.getRenderedImage().getHeight());

modifiedRaster = RasterEditors.setPixelType(testRaster, "US");
expected = new double[]{65526.0, 65527.0, 65528.0, 65529.0, 65530.0, 65531.0, 65532.0, 65533.0, 65534.0, 65535.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0};

assertEquals(DataBuffer.TYPE_USHORT, modifiedRaster.getRenderedImage().getSampleModel().getDataType());
assertEquals(99, RasterUtils.getNoDataValue(modifiedRaster.getSampleDimension(0)), 0.01);
assertEquals(55537, RasterUtils.getNoDataValue(modifiedRaster.getSampleDimension(3)), 0.01);
assertEquals(Arrays.toString(expected), Arrays.toString(MapAlgebra.bandAsArray(modifiedRaster, 3)));
assertEquals(testRaster.getRenderedImage().getWidth(), modifiedRaster.getRenderedImage().getWidth());
assertEquals(testRaster.getRenderedImage().getHeight(), modifiedRaster.getRenderedImage().getHeight());

modifiedRaster = RasterEditors.setPixelType(testRaster, "B");

assertEquals(DataBuffer.TYPE_BYTE, modifiedRaster.getRenderedImage().getSampleModel().getDataType());
assertEquals(99, RasterUtils.getNoDataValue(modifiedRaster.getSampleDimension(0)), 0.01);
assertEquals(241, RasterUtils.getNoDataValue(modifiedRaster.getSampleDimension(3)), 0.01);
assertEquals(testRaster.getRenderedImage().getWidth(), modifiedRaster.getRenderedImage().getWidth());
assertEquals(testRaster.getRenderedImage().getHeight(), modifiedRaster.getRenderedImage().getHeight());
}

@Test
public void testSetGeoReferenceWithRaster() throws IOException {
GridCoverage2D raster = rasterFromGeoTiff(resourceFolder + "raster/test1.tiff");
Expand Down
2 changes: 1 addition & 1 deletion docs/api/sql/Raster-loader.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ Accepts one of:
3. "I" - 32 bits signed Integer
4. "S" - 16 bits signed Short
5. "US" - 16 bits unsigned Short
6. "B" - 8 bits Byte
6. "B" - 8 bits unsigned Byte
* Width: The width of the raster in pixels.
* Height: The height of the raster in pixels.
* UpperleftX: The X coordinate of the upper left corner of the raster, in terms of the CRS units.
Expand Down
29 changes: 29 additions & 0 deletions docs/api/sql/Raster-operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -1767,6 +1767,35 @@ Output:
-3.000000
```

### RS_SetPixelType

Introduction: Returns a modified raster with the desired pixel data type.

The `dataType` parameter accepts one of the following strings.

- "D" - 64 bits Double
- "F" - 32 bits Float
- "I" - 32 bits signed Integer
- "S" - 16 bits signed Short
- "US" - 16 bits unsigned Short
- "B" - 8 bits unsigned Byte

!!!note
If the specified `dataType` is narrower than the original data type, the function will truncate the pixel values to fit the new data type range.

Format:
```
RS_SetPixelType(raster: Raster, dataType: String)
```

Since: `v1.6.0`

SQL Example:

```sql
RS_SetPixelType(raster, "I")
```

### RS_SetValue

Introduction: Returns a raster by replacing the value of pixel specified by `colX` and `rowY`.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ object Catalog {
function[RS_SetSRID](),
function[RS_SetGeoReference](),
function[RS_SetBandNoDataValue](),
function[RS_SetPixelType](),
function[RS_SetValues](),
function[RS_SetValue](),
function[RS_SRID](),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ case class RS_SetGeoReference(inputExpressions: Seq[Expression]) extends Inferre
}
}

case class RS_SetPixelType(inputExpressions: Seq[Expression]) extends InferredExpression(RasterEditors.setPixelType _) {
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
copy(inputExpressions = newChildren)
}
}

case class RS_Resample(inputExpressions: Seq[Expression]) extends InferredExpression(
nullTolerantInferrableFunction4(RasterEditors.resample), nullTolerantInferrableFunction5(RasterEditors.resample),
nullTolerantInferrableFunction7(RasterEditors.resample)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,35 @@ class rasteralgebraTest extends TestBaseScala with BeforeAndAfter with GivenWhen
assert(result4.isInstanceOf[GridCoverage2D])
}

it("should pass RS_SetPixelType") {
var df = sparkSession.read.format("binaryFile").load(resourceFolder + "raster/test1.tiff")
df = df.selectExpr("RS_FromGeoTiff(content) as raster")

val df1 = df.selectExpr("RS_SetPixelType(raster, 'D') as modifiedRaster")
val result1 = df1.selectExpr("RS_BandPixelType(modifiedRaster)").first().get(0).toString
assert(result1 == "REAL_64BITS")

val df2 = df.selectExpr("RS_SetPixelType(raster, 'F') as modifiedRaster")
val result2 = df2.selectExpr("RS_BandPixelType(modifiedRaster)").first().get(0).toString
assert(result2 == "REAL_32BITS")

val df3 = df.selectExpr("RS_SetPixelType(raster, 'I') as modifiedRaster")
val result3 = df3.selectExpr("RS_BandPixelType(modifiedRaster)").first().get(0).toString
assert(result3 == "SIGNED_32BITS")

val df4 = df.selectExpr("RS_SetPixelType(raster, 'S') as modifiedRaster")
val result4 = df4.selectExpr("RS_BandPixelType(modifiedRaster)").first().get(0).toString
assert(result4 == "SIGNED_16BITS")

val df5 = df.selectExpr("RS_SetPixelType(raster, 'US') as modifiedRaster")
val result5 = df5.selectExpr("RS_BandPixelType(modifiedRaster)").first().get(0).toString
assert(result5 == "UNSIGNED_16BITS")

val df6 = df.selectExpr("RS_SetPixelType(raster, 'B') as modifiedRaster")
val result6 = df6.selectExpr("RS_BandPixelType(modifiedRaster)").first().get(0).toString
assert(result6 == "UNSIGNED_8BITS")
}

it("should pass RS_Array") {
val df = sparkSession.sql("SELECT RS_Array(6, 1e-6) as band")
val result = df.first().getAs[mutable.WrappedArray[Double]](0)
Expand Down