diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index b9cbb63376..a4efd768b4 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -22,6 +22,9 @@ on: env: MAVEN_OPTS: -Dmaven.wagon.httpconnectionManager.ttlSeconds=60 + JAI_CORE_VERSION: "1.1.3" + JAI_CODEC_VERSION: "1.1.3" + JAI_IMAGEIO_VERSION: "1.1" jobs: build: @@ -111,11 +114,15 @@ jobs: - env: SPARK_VERSION: ${{ matrix.spark }} HADOOP_VERSION: ${{ matrix.hadoop }} - run: wget https://archive.apache.org/dist/spark/spark-${SPARK_VERSION}/spark-${SPARK_VERSION}-bin-hadoop${HADOOP_VERSION}.tgz - - env: - SPARK_VERSION: ${{ matrix.spark }} - HADOOP_VERSION: ${{ matrix.hadoop }} - run: tar -xzf spark-${SPARK_VERSION}-bin-hadoop${HADOOP_VERSION}.tgz + run: | + wget https://archive.apache.org/dist/spark/spark-${SPARK_VERSION}/spark-${SPARK_VERSION}-bin-hadoop${HADOOP_VERSION}.tgz + wget https://repo.osgeo.org/repository/release/javax/media/jai_core/${JAI_CORE_VERSION}/jai_core-${JAI_CORE_VERSION}.jar + wget https://repo.osgeo.org/repository/release/javax/media/jai_codec/${JAI_CODEC_VERSION}/jai_codec-${JAI_CODEC_VERSION}.jar + wget https://repo.osgeo.org/repository/release/javax/media/jai_imageio/${JAI_IMAGEIO_VERSION}/jai_imageio-${JAI_IMAGEIO_VERSION}.jar + tar -xzf spark-${SPARK_VERSION}-bin-hadoop${HADOOP_VERSION}.tgz + mv -v jai_core-${JAI_CORE_VERSION}.jar spark-${SPARK_VERSION}-bin-hadoop${HADOOP_VERSION}/jars/ + mv -v jai_codec-${JAI_CODEC_VERSION}.jar spark-${SPARK_VERSION}-bin-hadoop${HADOOP_VERSION}/jars/ + mv -v jai_imageio-${JAI_IMAGEIO_VERSION}.jar spark-${SPARK_VERSION}-bin-hadoop${HADOOP_VERSION}/jars/ - run: sudo apt-get -y install python3-pip python-dev-is-python3 - run: sudo pip3 install -U setuptools - run: sudo pip3 install -U wheel diff --git a/common/src/main/java/org/apache/sedona/common/raster/DeepCopiedRenderedImage.java b/common/src/main/java/org/apache/sedona/common/raster/DeepCopiedRenderedImage.java index 971b168ab4..0df154076f 100644 --- a/common/src/main/java/org/apache/sedona/common/raster/DeepCopiedRenderedImage.java +++ b/common/src/main/java/org/apache/sedona/common/raster/DeepCopiedRenderedImage.java @@ -13,8 +13,16 @@ */ package org.apache.sedona.common.raster; +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.KryoSerializable; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; +import com.esotericsoftware.kryo.serializers.JavaSerializer; +import com.sun.media.jai.rmi.ColorModelState; import com.sun.media.jai.util.ImageUtil; import it.geosolutions.jaiext.range.NoDataContainer; +import org.apache.sedona.common.raster.serde.AWTRasterSerializer; +import org.apache.sedona.common.raster.serde.KryoUtil; import org.apache.sedona.common.utils.RasterUtils; import javax.media.jai.JAI; @@ -48,7 +56,7 @@ * object is being disposed, it tries to connect to the remote server. However, there is no remote server in deep-copy * mode, so the dispose() method throws a java.net.SocketException. */ -public final class DeepCopiedRenderedImage implements RenderedImage, Serializable { +public final class DeepCopiedRenderedImage implements RenderedImage, Serializable, KryoSerializable { private transient RenderedImage source; private int minX; private int minY; @@ -69,7 +77,7 @@ public final class DeepCopiedRenderedImage implements RenderedImage, Serializabl private Rectangle imageBounds; private transient Raster imageRaster; - DeepCopiedRenderedImage() { + public DeepCopiedRenderedImage() { this.sampleModel = null; this.colorModel = null; this.sources = null; @@ -87,57 +95,54 @@ private DeepCopiedRenderedImage(RenderedImage source, boolean checkDataBuffer) { this.properties = null; if (source == null) { throw new IllegalArgumentException("source cannot be null"); - } else { - SampleModel sm = source.getSampleModel(); - if (sm != null && SerializerFactory.getSerializer(sm.getClass()) == null) { - throw new IllegalArgumentException("sample model object is not serializable"); - } else { - ColorModel cm = source.getColorModel(); - if (cm != null && SerializerFactory.getSerializer(cm.getClass()) == null) { - throw new IllegalArgumentException("color model object is not serializable"); - } else { - if (checkDataBuffer) { - Raster ras = source.getTile(source.getMinTileX(), source.getMinTileY()); - if (ras != null) { - DataBuffer db = ras.getDataBuffer(); - if (db != null && SerializerFactory.getSerializer(db.getClass()) == null) { - throw new IllegalArgumentException("data buffer object is not serializable"); - } - } - } - - this.source = source; - if (source instanceof RemoteImage) { - throw new IllegalArgumentException("RemoteImage is not supported"); - } - this.minX = source.getMinX(); - this.minY = source.getMinY(); - this.width = source.getWidth(); - this.height = source.getHeight(); - this.minTileX = source.getMinTileX(); - this.minTileY = source.getMinTileY(); - this.numXTiles = source.getNumXTiles(); - this.numYTiles = source.getNumYTiles(); - this.tileWidth = source.getTileWidth(); - this.tileHeight = source.getTileHeight(); - this.tileGridXOffset = source.getTileGridXOffset(); - this.tileGridYOffset = source.getTileGridYOffset(); - this.sampleModel = source.getSampleModel(); - this.colorModel = source.getColorModel(); - this.sources = new Vector<>(); - this.sources.add(source); - this.properties = new Hashtable<>(); - String[] propertyNames = source.getPropertyNames(); - if (propertyNames != null) { - for (String propertyName : propertyNames) { - this.properties.put(propertyName, source.getProperty(propertyName)); - } - } - - this.imageBounds = new Rectangle(this.minX, this.minY, this.width, this.height); + } + SampleModel sm = source.getSampleModel(); + if (sm != null && SerializerFactory.getSerializer(sm.getClass()) == null) { + throw new IllegalArgumentException("sample model object is not serializable"); + } + ColorModel cm = source.getColorModel(); + if (cm != null && SerializerFactory.getSerializer(cm.getClass()) == null) { + throw new IllegalArgumentException("color model object is not serializable"); + } + if (checkDataBuffer) { + Raster ras = source.getTile(source.getMinTileX(), source.getMinTileY()); + if (ras != null) { + DataBuffer db = ras.getDataBuffer(); + if (db != null && SerializerFactory.getSerializer(db.getClass()) == null) { + throw new IllegalArgumentException("data buffer object is not serializable"); } } } + + this.source = source; + if (source instanceof RemoteImage) { + throw new IllegalArgumentException("RemoteImage is not supported"); + } + this.minX = source.getMinX(); + this.minY = source.getMinY(); + this.width = source.getWidth(); + this.height = source.getHeight(); + this.minTileX = source.getMinTileX(); + this.minTileY = source.getMinTileY(); + this.numXTiles = source.getNumXTiles(); + this.numYTiles = source.getNumYTiles(); + this.tileWidth = source.getTileWidth(); + this.tileHeight = source.getTileHeight(); + this.tileGridXOffset = source.getTileGridXOffset(); + this.tileGridYOffset = source.getTileGridYOffset(); + this.sampleModel = source.getSampleModel(); + this.colorModel = source.getColorModel(); + this.sources = new Vector<>(); + this.sources.add(source); + this.properties = new Hashtable<>(); + String[] propertyNames = source.getPropertyNames(); + if (propertyNames != null) { + for (String propertyName : propertyNames) { + this.properties.put(propertyName, source.getProperty(propertyName)); + } + } + + this.imageBounds = new Rectangle(this.minX, this.minY, this.width, this.height); } @Override @@ -325,10 +330,54 @@ public int getWidth() { return this.width; } - @SuppressWarnings("unchecked") private void writeObject(ObjectOutputStream out) throws IOException { out.defaultWriteObject(); + Hashtable propertyTable = getSerializableProperties(); + out.writeObject(SerializerFactory.getState(this.colorModel, null)); + out.writeObject(propertyTable); + if (this.source != null) { + Raster serializedRaster = RasterUtils.getRaster(this.source); + out.writeObject(SerializerFactory.getState(serializedRaster, null)); + } else { + out.writeObject(SerializerFactory.getState(imageRaster, null)); + } + } + + @SuppressWarnings("unchecked") + private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException { + this.source = null; + in.defaultReadObject(); + + SerializableState cmState = (SerializableState)in.readObject(); + this.colorModel = (ColorModel)cmState.getObject(); + this.properties = (Hashtable)in.readObject(); + for (String key : this.properties.keySet()) { + Object value = this.properties.get(key); + // Restore the value of GC_NODATA property as a NoDataContainer object. + if (value instanceof SingleValueNoDataContainer) { + SingleValueNoDataContainer noDataContainer = (SingleValueNoDataContainer) value; + this.properties.put(key, new NoDataContainer(noDataContainer.singleValue)); + } + } + SerializableState rasState = (SerializableState)in.readObject(); + this.imageRaster = (Raster)rasState.getObject(); + + // The deserialized rendered image contains only one tile (imageRaster). We need to update + // the sample model and tile properties to reflect this. + this.sampleModel = this.imageRaster.getSampleModel(); + this.tileWidth = this.width; + this.tileHeight = this.height; + this.numXTiles = 1; + this.numYTiles = 1; + this.minTileX = 0; + this.minTileY = 0; + this.tileGridXOffset = minX; + this.tileGridYOffset = minY; + } + + @SuppressWarnings("unchecked") + private Hashtable getSerializableProperties() { // Prepare serialize properties. non-serializable properties won't be serialized. Hashtable propertyTable = this.properties; boolean propertiesCloned = false; @@ -350,25 +399,54 @@ private void writeObject(ObjectOutputStream out) throws IOException { } } } + return propertyTable; + } - out.writeObject(SerializerFactory.getState(this.colorModel, null)); - out.writeObject(propertyTable); + public static void registerKryo(Kryo kryo) { + kryo.register(ColorModelState.class, new JavaSerializer()); + } + + private static final AWTRasterSerializer awtRasterSerializer = new AWTRasterSerializer(); + + @Override + public void write(Kryo kryo, Output output) { + // write basic properties + output.writeInt(minX); + output.writeInt(minY); + output.writeInt(width); + output.writeInt(height); + + // write properties + Hashtable propertyTable = getSerializableProperties(); + KryoUtil.writeObjectWithLength(kryo, output, propertyTable); + + // write color model + SerializableState colorModelState = SerializerFactory.getState(this.colorModel, null); + KryoUtil.writeObjectWithLength(kryo, output, colorModelState); + + // write raster + Raster serializedRaster; if (this.source != null) { - Raster serializedRaster = RasterUtils.getRaster(this.source); - out.writeObject(SerializerFactory.getState(serializedRaster, null)); + serializedRaster = RasterUtils.getRaster(this.source); } else { - out.writeObject(SerializerFactory.getState(imageRaster, null)); + serializedRaster = imageRaster; } + awtRasterSerializer.write(kryo, output, serializedRaster); } @SuppressWarnings("unchecked") - private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException { - this.source = null; - in.defaultReadObject(); - - SerializableState cmState = (SerializableState)in.readObject(); - this.colorModel = (ColorModel)cmState.getObject(); - this.properties = (Hashtable)in.readObject(); + @Override + public void read(Kryo kryo, Input input) { + // read basic properties + minX = input.readInt(); + minY = input.readInt(); + width = input.readInt(); + height = input.readInt(); + imageBounds = new Rectangle(minX, minY, width, height); + + // read properties + input.readInt(); // skip the length of the property table + properties = kryo.readObject(input, Hashtable.class); for (String key : this.properties.keySet()) { Object value = this.properties.get(key); // Restore the value of GC_NODATA property as a NoDataContainer object. @@ -377,8 +455,14 @@ private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundE this.properties.put(key, new NoDataContainer(noDataContainer.singleValue)); } } - SerializableState rasState = (SerializableState)in.readObject(); - this.imageRaster = (Raster)rasState.getObject(); + + // read color model + input.readInt(); // skip the length of the color model state + ColorModelState cmState = kryo.readObject(input, ColorModelState.class); + this.colorModel = (ColorModel) cmState.getObject(); + + // read raster + this.imageRaster = awtRasterSerializer.read(kryo, input, Raster.class); // The deserialized rendered image contains only one tile (imageRaster). We need to update // the sample model and tile properties to reflect this. @@ -387,6 +471,10 @@ private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundE this.tileHeight = this.height; this.numXTiles = 1; this.numYTiles = 1; + this.minTileX = 0; + this.minTileY = 0; + this.tileGridXOffset = minX; + this.tileGridYOffset = minY; } /** diff --git a/common/src/main/java/org/apache/sedona/common/raster/RasterConstructorsForTesting.java b/common/src/main/java/org/apache/sedona/common/raster/RasterConstructorsForTesting.java new file mode 100644 index 0000000000..47667b9b80 --- /dev/null +++ b/common/src/main/java/org/apache/sedona/common/raster/RasterConstructorsForTesting.java @@ -0,0 +1,199 @@ +/** + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * http://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.sedona.common.raster; + +import com.sun.media.imageioimpl.common.BogusColorSpace; +import org.apache.commons.lang3.ArrayUtils; +import org.apache.sedona.common.FunctionsGeoTools; +import org.apache.sedona.common.utils.RasterUtils; +import org.geotools.coverage.CoverageFactoryFinder; +import org.geotools.coverage.grid.GridCoverage2D; +import org.geotools.coverage.grid.GridCoverageFactory; +import org.geotools.coverage.grid.GridEnvelope2D; +import org.geotools.coverage.grid.GridGeometry2D; +import org.geotools.referencing.crs.DefaultEngineeringCRS; +import org.geotools.referencing.operation.transform.AffineTransform2D; +import org.opengis.referencing.crs.CoordinateReferenceSystem; +import org.opengis.referencing.datum.PixelInCell; +import org.opengis.referencing.operation.MathTransform; + +import javax.media.jai.RasterFactory; +import java.awt.Transparency; +import java.awt.color.ColorSpace; +import java.awt.image.BufferedImage; +import java.awt.image.ColorModel; +import java.awt.image.ComponentColorModel; +import java.awt.image.ComponentSampleModel; +import java.awt.image.DataBuffer; +import java.awt.image.DirectColorModel; +import java.awt.image.IndexColorModel; +import java.awt.image.MultiPixelPackedSampleModel; +import java.awt.image.PixelInterleavedSampleModel; +import java.awt.image.RenderedImage; +import java.awt.image.SampleModel; +import java.awt.image.SinglePixelPackedSampleModel; +import java.awt.image.WritableRaster; +import java.util.Arrays; + +/** + * Raster constructor for testing the Python implementation of raster deserializer. + */ +public class RasterConstructorsForTesting { + private RasterConstructorsForTesting() {} + + public static GridCoverage2D makeRasterForTesting( + int numBand, String bandDataType, String sampleModelType, + int widthInPixel, int heightInPixel, double upperLeftX, double upperLeftY, + double scaleX, double scaleY, double skewX, double skewY, + int srid) { + CoordinateReferenceSystem crs; + if (srid == 0) { + crs = DefaultEngineeringCRS.GENERIC_2D; + } else { + // Create the CRS from the srid + // Longitude first, Latitude second + crs = FunctionsGeoTools.sridToCRS(srid); + } + + // Create a new raster with certain pixel values + WritableRaster raster = createRasterWithSampleModel(sampleModelType, bandDataType, widthInPixel, heightInPixel, numBand); + for (int k = 0; k < numBand; k++) { + for (int y = 0; y < heightInPixel; y++) { + for (int x = 0; x < widthInPixel; x++) { + double value = k + y * widthInPixel + x; + raster.setSample(x, y, k, value); + } + } + } + + MathTransform transform = new AffineTransform2D(scaleX, skewY, skewX, scaleY, upperLeftX, upperLeftY); + GridGeometry2D gridGeometry = new GridGeometry2D( + new GridEnvelope2D(0, 0, widthInPixel, heightInPixel), + PixelInCell.CELL_CORNER, + transform, crs, null); + + int rasterDataType = raster.getDataBuffer().getDataType(); + ColorModel colorModel; + if (!sampleModelType.contains("Packed")) { + final ColorSpace cs = new BogusColorSpace(numBand); + final int[] nBits = new int[numBand]; + Arrays.fill(nBits, DataBuffer.getDataTypeSize(rasterDataType)); + colorModel = + new ComponentColorModel(cs, nBits, false, true, Transparency.OPAQUE, rasterDataType); + } else if (sampleModelType.equals("SinglePixelPackedSampleModel")) { + colorModel = new DirectColorModel(32, + 0x0F, + (0x0F) << 4, + (0x0F) << 8, + (0x0F) << 12); + } else if (sampleModelType.equals("MultiPixelPackedSampleModel")) { + byte[] arr = new byte[16]; + for (int k = 0; k < 16; k++) { + arr[k] = (byte) (k * 16); + } + colorModel = new IndexColorModel(4, arr.length, arr, arr, arr); + } else { + throw new IllegalArgumentException("Unknown sample model type: " + sampleModelType); + } + + final RenderedImage image = new BufferedImage(colorModel, raster, false, null); + GridCoverageFactory gridCoverageFactory = CoverageFactoryFinder.getGridCoverageFactory(null); + return gridCoverageFactory.create("genericCoverage", image, gridGeometry, null, null, null); + } + + private static WritableRaster createRasterWithSampleModel(String sampleModelType, String bandDataType, int widthInPixel, int heightInPixel, int numBand) { + int dataType = RasterUtils.getDataTypeCode(bandDataType); + + // Create raster according to sample model type + WritableRaster raster; + switch (sampleModelType) { + case "BandedSampleModel": + raster = RasterFactory.createBandedRaster(dataType, widthInPixel, heightInPixel, numBand, null); + break; + case "PixelInterleavedSampleModel": { + int scanlineStride = widthInPixel * numBand; + int[] bandOffsets = new int[numBand]; + for (int i = 0; i < numBand; i++) { + bandOffsets[i] = i; + } + SampleModel sm = new PixelInterleavedSampleModel(dataType, widthInPixel, heightInPixel, numBand, scanlineStride, bandOffsets); + raster = RasterFactory.createWritableRaster(sm, null); + break; + } + case "PixelInterleavedSampleModelComplex": { + int pixelStride = numBand + 2; + int scanlineStride = widthInPixel * pixelStride + 5; + int[] bandOffsets = new int[numBand]; + for (int i = 0; i < numBand; i++) { + bandOffsets[i] = i; + } + ArrayUtils.shuffle(bandOffsets); + SampleModel sm = new PixelInterleavedSampleModel(dataType, widthInPixel, heightInPixel, pixelStride, scanlineStride, bandOffsets); + raster = RasterFactory.createWritableRaster(sm, null); + break; + } + case "ComponentSampleModel": { + int pixelStride = numBand + 1; + int scanlineStride = widthInPixel * pixelStride + 5; + int[] bankIndices = new int[numBand]; + for (int i = 0; i < numBand; i++) { + bankIndices[i] = i; + } + ArrayUtils.shuffle(bankIndices); + int[] bandOffsets = new int[numBand]; + for (int i = 0; i < numBand; i++) { + bandOffsets[i] = (int)(Math.random() * widthInPixel); + } + SampleModel sm = new ComponentSampleModel(dataType, widthInPixel, heightInPixel, pixelStride, scanlineStride, bankIndices, bandOffsets); + raster = RasterFactory.createWritableRaster(sm, null); + break; + } + case "SinglePixelPackedSampleModel": { + if (dataType != DataBuffer.TYPE_INT) { + throw new IllegalArgumentException("only supports creating SinglePixelPackedSampleModel with int data type"); + } + if (numBand != 4) { + throw new IllegalArgumentException("only supports creating SinglePixelPackedSampleModel with 4 bands"); + } + int bitsPerBand = 4; + int scanlineStride = widthInPixel + 5; + int[] bitMasks = new int[numBand]; + int baseMask = (1 << bitsPerBand) - 1; + for (int i = 0; i < numBand; i++) { + bitMasks[i] = baseMask << (i * bitsPerBand); + } + SampleModel sm = new SinglePixelPackedSampleModel(dataType, widthInPixel, heightInPixel, scanlineStride, bitMasks); + raster = RasterFactory.createWritableRaster(sm, null); + break; + } + case "MultiPixelPackedSampleModel": { + if (dataType != DataBuffer.TYPE_BYTE) { + throw new IllegalArgumentException("only supports creating MultiPixelPackedSampleModel with byte data type"); + } + if (numBand != 1) { + throw new IllegalArgumentException("only supports creating MultiPixelPackedSampleModel with 1 band"); + } + int numberOfBits = 4; + int scanlineStride = widthInPixel * numberOfBits / 8 + 2; + SampleModel sm = new MultiPixelPackedSampleModel(dataType, widthInPixel, heightInPixel, numberOfBits, scanlineStride, 80); + raster = RasterFactory.createWritableRaster(sm, null); + break; + } + default: + throw new IllegalArgumentException("Unknown sample model type: " + sampleModelType); + } + + return raster; + } +} diff --git a/common/src/main/java/org/apache/sedona/common/raster/Serde.java b/common/src/main/java/org/apache/sedona/common/raster/Serde.java deleted file mode 100644 index 7f34d840d8..0000000000 --- a/common/src/main/java/org/apache/sedona/common/raster/Serde.java +++ /dev/null @@ -1,103 +0,0 @@ -/** - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - *

- * http://www.apache.org/licenses/LICENSE-2.0 - *

- * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.sedona.common.raster; - -import org.geotools.coverage.GridSampleDimension; -import org.geotools.coverage.grid.GridCoverage2D; -import org.geotools.coverage.grid.GridCoverageFactory; -import org.geotools.coverage.grid.GridEnvelope2D; -import org.geotools.coverage.grid.GridGeometry2D; -import org.opengis.referencing.operation.MathTransform; - -import javax.media.jai.RenderedImageAdapter; -import java.awt.image.RenderedImage; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.ObjectInputStream; -import java.io.ObjectOutputStream; -import java.io.Serializable; -import java.lang.reflect.Field; - -public class Serde { - - static final Field field; - - static { - try { - field = GridCoverage2D.class.getDeclaredField("serializedImage"); - field.setAccessible(true); - } catch (NoSuchFieldException e) { - throw new RuntimeException(e); - } - } - - private static class SerializableState implements Serializable { - public CharSequence name; - - // The following three components are used to construct a GridGeometry2D object. - // We serialize CRS separately because the default serializer is pretty slow, we use a - // cached serializer to speed up the serialization and reuse CRS on deserialization. - public GridEnvelope2D gridEnvelope2D; - public MathTransform gridToCRS; - public byte[] serializedCRS; - - public GridSampleDimension[] bands; - public DeepCopiedRenderedImage image; - - public GridCoverage2D restore() { - GridGeometry2D gridGeometry = new GridGeometry2D(gridEnvelope2D, gridToCRS, CRSSerializer.deserialize(serializedCRS)); - return new GridCoverageFactory().create(name, image, gridGeometry, bands, null, null); - } - } - - public static byte[] serialize(GridCoverage2D raster) throws IOException { - // GridCoverage2D created by GridCoverage2DReaders contain references that are not serializable. - // Wrap the RenderedImage in DeepCopiedRenderedImage to make it serializable. - DeepCopiedRenderedImage deepCopiedRenderedImage = null; - RenderedImage renderedImage = raster.getRenderedImage(); - while (renderedImage instanceof RenderedImageAdapter) { - renderedImage = ((RenderedImageAdapter) renderedImage).getWrappedImage(); - } - if (renderedImage instanceof DeepCopiedRenderedImage) { - deepCopiedRenderedImage = (DeepCopiedRenderedImage) renderedImage; - } else { - deepCopiedRenderedImage = new DeepCopiedRenderedImage(renderedImage); - } - - SerializableState state = new SerializableState(); - GridGeometry2D gridGeometry = raster.getGridGeometry(); - state.name = raster.getName(); - state.gridEnvelope2D = gridGeometry.getGridRange2D(); - state.gridToCRS = gridGeometry.getGridToCRS2D(); - state.serializedCRS = CRSSerializer.serialize(gridGeometry.getCoordinateReferenceSystem()); - state.bands = raster.getSampleDimensions(); - state.image = deepCopiedRenderedImage; - try (ByteArrayOutputStream bos = new ByteArrayOutputStream()) { - try (ObjectOutputStream oos = new ObjectOutputStream(bos)) { - oos.writeObject(state); - return bos.toByteArray(); - } - } - } - - public static GridCoverage2D deserialize(byte[] bytes) throws IOException, ClassNotFoundException { - try (ByteArrayInputStream bis = new ByteArrayInputStream(bytes)) { - try (ObjectInputStream ois = new ObjectInputStream(bis)) { - SerializableState state = (SerializableState) ois.readObject(); - return state.restore(); - } - } - } -} diff --git a/common/src/main/java/org/apache/sedona/common/raster/serde/AWTRasterSerializer.java b/common/src/main/java/org/apache/sedona/common/raster/serde/AWTRasterSerializer.java new file mode 100644 index 0000000000..5de5744c6e --- /dev/null +++ b/common/src/main/java/org/apache/sedona/common/raster/serde/AWTRasterSerializer.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.common.raster.serde; + +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.Serializer; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; + +import java.awt.Point; +import java.awt.image.DataBuffer; +import java.awt.image.Raster; +import java.awt.image.SampleModel; +import java.awt.image.WritableRaster; + +public class AWTRasterSerializer extends Serializer { + private static final SampleModelSerializer sampleModelSerializer = new SampleModelSerializer(); + private static final DataBufferSerializer dataBufferSerializer = new DataBufferSerializer(); + + @Override + public void write(Kryo kryo, Output output, Raster raster) { + Raster r; + if (raster.getParent() != null) { + r = raster.createCompatibleWritableRaster(raster.getBounds()); + ((WritableRaster) r).setRect(raster); + } else { + r = raster; + } + + output.writeInt(r.getMinX()); + output.writeInt(r.getMinY()); + sampleModelSerializer.write(kryo, output, r.getSampleModel()); + dataBufferSerializer.write(kryo, output, r.getDataBuffer()); + } + + @Override + public Raster read(Kryo kryo, Input input, Class type) { + int minX = input.readInt(); + int minY = input.readInt(); + Point location = new Point(minX, minY); + SampleModel sampleModel = sampleModelSerializer.read(kryo, input, SampleModel.class); + DataBuffer dataBuffer = dataBufferSerializer.read(kryo, input, DataBuffer.class); + return Raster.createRaster(sampleModel, dataBuffer, location); + } +} diff --git a/common/src/main/java/org/apache/sedona/common/raster/serde/AffineTransform2DSerializer.java b/common/src/main/java/org/apache/sedona/common/raster/serde/AffineTransform2DSerializer.java new file mode 100644 index 0000000000..ed0a2d5763 --- /dev/null +++ b/common/src/main/java/org/apache/sedona/common/raster/serde/AffineTransform2DSerializer.java @@ -0,0 +1,47 @@ +/** + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * http://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.sedona.common.raster.serde; + +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.Serializer; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; +import org.geotools.referencing.operation.transform.AffineTransform2D; + +/** + * AffineTransform2D cannot be correctly deserialized by the default serializer of Kryo, so we need to provide a + * custom serializer. + */ +public class AffineTransform2DSerializer extends Serializer { + @Override + public void write(Kryo kryo, Output output, AffineTransform2D affineTransform2D) { + output.writeDouble(affineTransform2D.getScaleX()); + output.writeDouble(affineTransform2D.getShearY()); + output.writeDouble(affineTransform2D.getShearX()); + output.writeDouble(affineTransform2D.getScaleY()); + output.writeDouble(affineTransform2D.getTranslateX()); + output.writeDouble(affineTransform2D.getTranslateY()); + } + + @Override + public AffineTransform2D read(Kryo kryo, Input input, Class aClass) { + double scaleX = input.readDouble(); + double skewY = input.readDouble(); + double skewX = input.readDouble(); + double scaleY = input.readDouble(); + double upperLeftX = input.readDouble(); + double upperLeftY = input.readDouble(); + return new AffineTransform2D(scaleX, skewY, skewX, scaleY, upperLeftX, upperLeftY); + } +} diff --git a/common/src/main/java/org/apache/sedona/common/raster/CRSSerializer.java b/common/src/main/java/org/apache/sedona/common/raster/serde/CRSSerializer.java similarity index 70% rename from common/src/main/java/org/apache/sedona/common/raster/CRSSerializer.java rename to common/src/main/java/org/apache/sedona/common/raster/serde/CRSSerializer.java index 2161310100..9c427e34fe 100644 --- a/common/src/main/java/org/apache/sedona/common/raster/CRSSerializer.java +++ b/common/src/main/java/org/apache/sedona/common/raster/serde/CRSSerializer.java @@ -16,19 +16,24 @@ * specific language governing permissions and limitations * under the License. */ -package org.apache.sedona.common.raster; +package org.apache.sedona.common.raster.serde; import com.github.benmanes.caffeine.cache.Caffeine; import com.github.benmanes.caffeine.cache.LoadingCache; +import org.apache.commons.io.IOUtils; import org.geotools.referencing.CRS; +import org.geotools.referencing.wkt.Formattable; +import org.opengis.referencing.FactoryException; import org.opengis.referencing.crs.CoordinateReferenceSystem; +import si.uom.NonSI; +import si.uom.SI; +import javax.measure.IncommensurableException; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; -import java.io.ObjectInputStream; -import java.io.ObjectOutputStream; import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; import java.util.zip.DeflaterOutputStream; import java.util.zip.InflaterInputStream; @@ -42,6 +47,20 @@ public class CRSSerializer { private CRSSerializer() {} + static { + try { + // HACK: This is for warming up the piCache in tech.units.indriya.function.Calculus. + // Otherwise, concurrent calls to CoordinateReferenceSystem.toWKT() will cause a + // ConcurrentModificationException. This is a bug of tech.units.indriya, which was fixed + // in 2.1.4 by https://github.com/unitsofmeasurement/indriya/commit/fc370465 + // However, 2.1.4 is not compatible with the GeoTools version we use. That's the reason + // why we have this workaround here. + NonSI.DEGREE_ANGLE.getConverterToAny(SI.RADIAN).convert(1); + } catch (IncommensurableException e) { + throw new RuntimeException(e); + } + } + private static class CRSKey { private final CoordinateReferenceSystem crs; private final int hashCode; @@ -85,10 +104,15 @@ public static CoordinateReferenceSystem deserialize(byte[] bytes) { private static byte[] doSerializeCRS(CRSKey crsKey) throws IOException { CoordinateReferenceSystem crs = crsKey.crs; try (ByteArrayOutputStream bos = new ByteArrayOutputStream(); - DeflaterOutputStream dos = new DeflaterOutputStream(bos); - ObjectOutputStream oos = new ObjectOutputStream(dos)) { - oos.writeObject(crs); - oos.flush(); + DeflaterOutputStream dos = new DeflaterOutputStream(bos)) { + String wktString; + if (crs instanceof Formattable) { + // Can specify "strict" as false to get rid of serialization errors in trade of correctness + wktString = ((Formattable) crs).toWKT(2, false); + } else { + wktString = crs.toWKT(); + } + dos.write(wktString.getBytes(StandardCharsets.UTF_8)); dos.finish(); byte[] res = bos.toByteArray(); crsDeserializationCache.put(ByteBuffer.wrap(res), crs); @@ -96,12 +120,13 @@ private static byte[] doSerializeCRS(CRSKey crsKey) throws IOException { } } - private static CoordinateReferenceSystem doDeserializeCRS(ByteBuffer byteBuffer) throws IOException, ClassNotFoundException { + private static CoordinateReferenceSystem doDeserializeCRS(ByteBuffer byteBuffer) throws IOException, FactoryException { byte[] bytes = byteBuffer.array(); try (ByteArrayInputStream bis = new ByteArrayInputStream(bytes); - InflaterInputStream dis = new InflaterInputStream(bis); - ObjectInputStream ois = new ObjectInputStream(dis)) { - CoordinateReferenceSystem crs = (CoordinateReferenceSystem) ois.readObject(); + InflaterInputStream dis = new InflaterInputStream(bis)) { + byte[] wktBytes = IOUtils.toByteArray(dis); + String wktString = new String(wktBytes, StandardCharsets.UTF_8); + CoordinateReferenceSystem crs = CRS.parseWKT(wktString); crsSerializationCache.put(new CRSKey(crs), bytes); return crs; } diff --git a/common/src/main/java/org/apache/sedona/common/raster/serde/DataBufferSerializer.java b/common/src/main/java/org/apache/sedona/common/raster/serde/DataBufferSerializer.java new file mode 100644 index 0000000000..4d886ab437 --- /dev/null +++ b/common/src/main/java/org/apache/sedona/common/raster/serde/DataBufferSerializer.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.common.raster.serde; + +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.Serializer; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; +import com.sun.media.jai.util.DataBufferUtils; + +import java.awt.image.DataBuffer; +import java.awt.image.DataBufferByte; +import java.awt.image.DataBufferInt; +import java.awt.image.DataBufferShort; +import java.awt.image.DataBufferUShort; + +public class DataBufferSerializer extends Serializer { + @Override + public void write(Kryo kryo, Output output, DataBuffer dataBuffer) { + int dataType = dataBuffer.getDataType(); + output.writeInt(dataType); + KryoUtil.writeIntArray(output, dataBuffer.getOffsets()); + output.writeInt(dataBuffer.getSize()); + switch (dataType) { + case DataBuffer.TYPE_BYTE: + byte[][] byteDataArray = ((DataBufferByte) dataBuffer).getBankData(); + KryoUtil.writeByteArrays(output, byteDataArray); + break; + case DataBuffer.TYPE_USHORT: + short[][] uShortDataArray = ((DataBufferUShort) dataBuffer).getBankData(); + KryoUtil.writeShortArrays(output, uShortDataArray); + break; + case DataBuffer.TYPE_SHORT: + short[][] shortDataArray = ((DataBufferShort) dataBuffer).getBankData(); + KryoUtil.writeShortArrays(output, shortDataArray); + break; + case DataBuffer.TYPE_INT: + int[][] intDataArray = ((DataBufferInt) dataBuffer).getBankData(); + KryoUtil.writeIntArrays(output, intDataArray); + break; + case DataBuffer.TYPE_FLOAT: + float[][] floatDataArray = DataBufferUtils.getBankDataFloat(dataBuffer); + KryoUtil.writeFloatArrays(output, floatDataArray); + break; + case DataBuffer.TYPE_DOUBLE: + double[][] doubleDataArray = DataBufferUtils.getBankDataDouble(dataBuffer); + KryoUtil.writeDoubleArrays(output, doubleDataArray); + break; + default: + throw new RuntimeException("Unknown data type: " + dataType); + } + } + + @Override + public DataBuffer read(Kryo kryo, Input input, Class type) { + int dataType = input.readInt(); + int[] offsets = KryoUtil.readIntArray(input); + int size = input.readInt(); + DataBuffer dataBuffer; + switch (dataType) { + case DataBuffer.TYPE_BYTE: + byte[][] byteDataArray = KryoUtil.readByteArrays(input); + dataBuffer = new DataBufferByte(byteDataArray, size, offsets); + break; + case DataBuffer.TYPE_USHORT: + short[][] uShortDataArray = KryoUtil.readShortArrays(input); + dataBuffer = new DataBufferUShort(uShortDataArray, size, offsets); + break; + case DataBuffer.TYPE_SHORT: + short[][] shortDataArray = KryoUtil.readShortArrays(input); + dataBuffer = new DataBufferShort(shortDataArray, size, offsets); + break; + case DataBuffer.TYPE_INT: + int[][] intDataArray = KryoUtil.readIntArrays(input); + dataBuffer = new DataBufferInt(intDataArray, size, offsets); + break; + case DataBuffer.TYPE_FLOAT: + float[][] floatDataArray = KryoUtil.readFloatArrays(input); + dataBuffer = DataBufferUtils.createDataBufferFloat(floatDataArray, size, offsets); + break; + case DataBuffer.TYPE_DOUBLE: + double[][] doubleDataArray = KryoUtil.readDoubleArrays(input); + dataBuffer = DataBufferUtils.createDataBufferDouble(doubleDataArray, size, offsets); + break; + default: + throw new RuntimeException("Unknown data type: " + dataType); + } + return dataBuffer; + } +} diff --git a/common/src/main/java/org/apache/sedona/common/raster/serde/GridEnvelopeSerializer.java b/common/src/main/java/org/apache/sedona/common/raster/serde/GridEnvelopeSerializer.java new file mode 100644 index 0000000000..024d8c8bc3 --- /dev/null +++ b/common/src/main/java/org/apache/sedona/common/raster/serde/GridEnvelopeSerializer.java @@ -0,0 +1,39 @@ +/** + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * http://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.sedona.common.raster.serde; + +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.Serializer; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; +import org.geotools.coverage.grid.GridEnvelope2D; + +public class GridEnvelopeSerializer extends Serializer { + @Override + public void write(Kryo kryo, Output output, GridEnvelope2D object) { + output.writeInt(object.width); + output.writeInt(object.height); + output.writeInt(object.x); + output.writeInt(object.y); + } + + @Override + public GridEnvelope2D read(Kryo kryo, Input input, Class type) { + int width = input.readInt(); + int height = input.readInt(); + int x = input.readInt(); + int y = input.readInt(); + return new GridEnvelope2D(x, y, width, height); + } +} diff --git a/common/src/main/java/org/apache/sedona/common/raster/serde/GridSampleDimensionSerializer.java b/common/src/main/java/org/apache/sedona/common/raster/serde/GridSampleDimensionSerializer.java new file mode 100644 index 0000000000..f4ca504b58 --- /dev/null +++ b/common/src/main/java/org/apache/sedona/common/raster/serde/GridSampleDimensionSerializer.java @@ -0,0 +1,54 @@ +/** + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * http://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.sedona.common.raster.serde; + +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.Serializer; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; +import org.apache.sedona.common.utils.RasterUtils; +import org.geotools.coverage.Category; +import org.geotools.coverage.GridSampleDimension; + +import java.util.List; + +/** + * GridSampleDimension and RenderedSampleDimension are not serializable. We need to provide a custom serializer + */ +public class GridSampleDimensionSerializer extends Serializer { + @Override + public void write(Kryo kryo, Output output, GridSampleDimension sampleDimension) { + String description = sampleDimension.getDescription().toString(); + List categories = sampleDimension.getCategories(); + double offset = sampleDimension.getOffset(); + double scale = sampleDimension.getScale(); + double noDataValue = RasterUtils.getNoDataValue(sampleDimension); + KryoUtil.writeUTF8String(output, description); + output.writeDouble(offset); + output.writeDouble(scale); + output.writeDouble(noDataValue); // for interoperability with Python RasterType. + KryoUtil.writeObjectWithLength(kryo, output, categories.toArray()); + } + + @Override + public GridSampleDimension read(Kryo kryo, Input input, Class aClass) { + String description = KryoUtil.readUTF8String(input); + double offset = input.readDouble(); + double scale = input.readDouble(); + input.readDouble(); // noDataValue is included in categories, so we just skip it + input.readInt(); // skip the length of the next object + Category[] categories = kryo.readObject(input, Category[].class); + return new GridSampleDimension(description, categories, offset, scale); + } +} diff --git a/common/src/main/java/org/apache/sedona/common/raster/serde/KryoUtil.java b/common/src/main/java/org/apache/sedona/common/raster/serde/KryoUtil.java new file mode 100644 index 0000000000..8292e7b550 --- /dev/null +++ b/common/src/main/java/org/apache/sedona/common/raster/serde/KryoUtil.java @@ -0,0 +1,297 @@ +/** + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * http://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.sedona.common.raster.serde; + +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; + +import java.nio.charset.StandardCharsets; +import java.util.HashMap; +import java.util.Map; + +/** + * Utility methods for serializing objects with Kryo. The serialization formats are well-defined and independent + * of the Kryo version. This allows us to exchange serialized data with other tech stack, such as Python. + */ +public class KryoUtil { + private KryoUtil() {} + + /** + * Write the length of the next serialized object, followed by the serialized object + * @param kryo the kryo instance + * @param output the output stream + * @param object the object to serialize + */ + public static void writeObjectWithLength(Kryo kryo, Output output, Object object) { + int lengthOffset = output.position(); + output.writeInt(0); // placeholder, will be overwritten later + + // Write the object + int start = output.position(); + kryo.writeObject(output, object); + int end = output.position(); + + // Rewrite the length + int length = end - start; + output.setPosition(lengthOffset); + output.writeInt(length); + output.setPosition(end); + } + + /** + * Write string as UTF-8 byte sequence + * @param output the output stream + * @param value the string to write + */ + public static void writeUTF8String(Output output, String value) { + byte[] utf8 = value.getBytes(StandardCharsets.UTF_8); + output.writeInt(utf8.length); + output.writeBytes(utf8); + } + + /** + * Read UTF-8 byte sequence as string + * @param input the input stream + * @return the string + */ + public static String readUTF8String(Input input) { + int length = input.readInt(); + byte[] utf8 = new byte[length]; + input.readBytes(utf8); + return new String(utf8, StandardCharsets.UTF_8); + } + + /** + * Write an array of integers + * @param output the output stream + * @param array the array to write + */ + public static void writeIntArray(Output output, int[] array) { + output.writeInt(array.length); + output.writeInts(array); + } + + /** + * Read an array of integers + * @param input the input stream + * @return the array + */ + public static int[] readIntArray(Input input) { + int length = input.readInt(); + return input.readInts(length); + } + + /** + * Write a 2-d array of ints + * @param output the output stream + * @param arrays the array to write + */ + public static void writeIntArrays(Output output, int[][] arrays) { + output.writeInt(arrays.length); + for (int[] array : arrays) { + writeIntArray(output, array); + } + } + + /** + * Read a 2-d array of ints + * @param input the input stream + * @return the array + */ + public static int[][] readIntArrays(Input input) { + int length = input.readInt(); + int[][] arrays = new int[length][]; + for (int i = 0; i < length; i++) { + arrays[i] = readIntArray(input); + } + return arrays; + } + + /** + * Write a 2-d array of bytes + * @param output the output stream + * @param arrays the array to write + */ + public static void writeByteArrays(Output output, byte[][] arrays) { + output.writeInt(arrays.length); + for (byte[] array : arrays) { + output.writeInt(array.length); + output.writeBytes(array); + } + } + + /** + * Read a 2-d array of bytes + * @param input the input stream + * @return the array + */ + public static byte[][] readByteArrays(Input input) { + int length = input.readInt(); + byte[][] arrays = new byte[length][]; + for (int i = 0; i < length; i++) { + int arrayLength = input.readInt(); + arrays[i] = input.readBytes(arrayLength); + } + return arrays; + } + + /** + * Write a 2-d array of doubles + * @param output the output stream + * @param arrays the array to write + */ + public static void writeDoubleArrays(Output output, double[][] arrays) { + output.writeInt(arrays.length); + for (double[] array : arrays) { + output.writeInt(array.length); + output.writeDoubles(array); + } + } + + /** + * Read a 2-d array of doubles + * @param input the input stream + * @return the array + */ + public static double[][] readDoubleArrays(Input input) { + int length = input.readInt(); + double[][] arrays = new double[length][]; + for (int i = 0; i < length; i++) { + int arrayLength = input.readInt(); + arrays[i] = input.readDoubles(arrayLength); + } + return arrays; + } + + /** + * Write a 2-d array of longs + * @param output the output stream + * @param arrays the array to write + */ + public static void writeLongArrays(Output output, long[][] arrays) { + output.writeInt(arrays.length); + for (long[] array : arrays) { + output.writeInt(array.length); + output.writeLongs(array); + } + } + + /** + * Read a 2-d array of longs + * @param input the input stream + * @return the array + */ + public static long[][] readLongArrays(Input input) { + int length = input.readInt(); + long[][] arrays = new long[length][]; + for (int i = 0; i < length; i++) { + int arrayLength = input.readInt(); + arrays[i] = input.readLongs(arrayLength); + } + return arrays; + } + + /** + * Write a 2-d array of floats + * @param output the output stream + * @param arrays the array to write + */ + public static void writeFloatArrays(Output output, float[][] arrays) { + output.writeInt(arrays.length); + for (float[] array : arrays) { + output.writeInt(array.length); + output.writeFloats(array); + } + } + + /** + * Read a 2-d array of floats + * @param input the input stream + * @return the array + */ + public static float[][] readFloatArrays(Input input) { + int length = input.readInt(); + float[][] arrays = new float[length][]; + for (int i = 0; i < length; i++) { + int arrayLength = input.readInt(); + arrays[i] = input.readFloats(arrayLength); + } + return arrays; + } + + /** + * Write a 2-d array of shorts + * @param output the output stream + * @param arrays the array to write + */ + public static void writeShortArrays(Output output, short[][] arrays) { + output.writeInt(arrays.length); + for (short[] array : arrays) { + output.writeInt(array.length); + output.writeShorts(array); + } + } + + /** + * Read a 2-d array of shorts + * @param input the input stream + * @return the array + */ + public static short[][] readShortArrays(Input input) { + int length = input.readInt(); + short[][] arrays = new short[length][]; + for (int i = 0; i < length; i++) { + int arrayLength = input.readInt(); + arrays[i] = input.readShorts(arrayLength); + } + return arrays; + } + + /** + * Write a {@code Map} object to the output stream + * @param output the output stream + * @param map the map to write + */ + public static void writeUTF8StringMap(Output output, Map map) { + if (map == null) { + output.writeInt(-1); + return; + } + output.writeInt(map.size()); + for (Map.Entry entry : map.entrySet()) { + writeUTF8String(output, entry.getKey()); + writeUTF8String(output, entry.getValue()); + } + } + + /** + * Read a {@code Map} object from the input stream + * @param input the input stream + * @return the map + */ + public static Map readUTF8StringMap(Input input) { + int size = input.readInt(); + if (size == -1) { + return null; + } + Map params = new HashMap<>(size); + for (int i = 0; i < size; i++) { + String key = readUTF8String(input); + String value = readUTF8String(input); + params.put(key, value); + } + return params; + } +} diff --git a/common/src/main/java/org/apache/sedona/common/raster/serde/SampleModelSerializer.java b/common/src/main/java/org/apache/sedona/common/raster/serde/SampleModelSerializer.java new file mode 100644 index 0000000000..1dce0b9beb --- /dev/null +++ b/common/src/main/java/org/apache/sedona/common/raster/serde/SampleModelSerializer.java @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.common.raster.serde; + +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.Serializer; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; + +import javax.media.jai.ComponentSampleModelJAI; +import javax.media.jai.RasterFactory; +import java.awt.image.BandedSampleModel; +import java.awt.image.ComponentSampleModel; +import java.awt.image.MultiPixelPackedSampleModel; +import java.awt.image.PixelInterleavedSampleModel; +import java.awt.image.SampleModel; +import java.awt.image.SinglePixelPackedSampleModel; + +/** + * Serializer for SampleModelState using Kryo. This is translated from the original JAI implementation + * of writeObject and readObject. + */ +public class SampleModelSerializer extends Serializer { + + // These constants are taken from SampleModelState + private static final int TYPE_BANDED = 1; + private static final int TYPE_PIXEL_INTERLEAVED = 2; + private static final int TYPE_SINGLE_PIXEL_PACKED = 3; + private static final int TYPE_MULTI_PIXEL_PACKED = 4; + private static final int TYPE_COMPONENT_JAI = 5; + private static final int TYPE_COMPONENT = 6; + + private static int sampleModelTypeOf(SampleModel sampleModel) { + if (sampleModel instanceof ComponentSampleModel) { + if (sampleModel instanceof PixelInterleavedSampleModel) { + return TYPE_PIXEL_INTERLEAVED; + } else if (sampleModel instanceof BandedSampleModel) { + return TYPE_BANDED; + } else if (sampleModel instanceof ComponentSampleModelJAI) { + return TYPE_COMPONENT_JAI; + } else { + return TYPE_COMPONENT; + } + } else if (sampleModel instanceof SinglePixelPackedSampleModel) { + return TYPE_SINGLE_PIXEL_PACKED; + } else if (sampleModel instanceof MultiPixelPackedSampleModel) { + return TYPE_MULTI_PIXEL_PACKED; + } else { + throw new UnsupportedOperationException("Unsupported SampleModel type: " + sampleModel.getClass().getName()); + } + } + + @Override + public void write(Kryo kryo, Output output, SampleModel sampleModel) { + int sampleModelType = sampleModelTypeOf(sampleModel); + output.writeInt(sampleModelType); + output.writeInt(sampleModel.getTransferType()); + output.writeInt(sampleModel.getWidth()); + output.writeInt(sampleModel.getHeight()); + + switch (sampleModelType) { + case TYPE_BANDED: { + BandedSampleModel sm = (BandedSampleModel)sampleModel; + KryoUtil.writeIntArray(output, sm.getBankIndices()); + KryoUtil.writeIntArray(output, sm.getBandOffsets()); + break; + } + + case TYPE_PIXEL_INTERLEAVED: { + PixelInterleavedSampleModel sm = (PixelInterleavedSampleModel)sampleModel; + output.writeInt(sm.getPixelStride()); + output.writeInt(sm.getScanlineStride()); + KryoUtil.writeIntArray(output, sm.getBandOffsets()); + break; + } + + case TYPE_COMPONENT: + case TYPE_COMPONENT_JAI: { + ComponentSampleModel sm = (ComponentSampleModel)sampleModel; + output.writeInt(sm.getPixelStride()); + output.writeInt(sm.getScanlineStride()); + KryoUtil.writeIntArray(output, sm.getBankIndices()); + KryoUtil.writeIntArray(output, sm.getBandOffsets()); + break; + } + + case TYPE_SINGLE_PIXEL_PACKED: { + SinglePixelPackedSampleModel sm = (SinglePixelPackedSampleModel)sampleModel; + output.writeInt(sm.getScanlineStride()); + KryoUtil.writeIntArray(output, sm.getBitMasks()); + break; + } + + case TYPE_MULTI_PIXEL_PACKED: { + MultiPixelPackedSampleModel sm = (MultiPixelPackedSampleModel)sampleModel; + output.writeInt(sm.getPixelBitStride()); + output.writeInt(sm.getScanlineStride()); + output.writeInt(sm.getDataBitOffset()); + break; + } + + default: + throw new UnsupportedOperationException("Unknown SampleModel type: " + sampleModel.getClass().getName()); + } + } + + @Override + public SampleModel read(Kryo kryo, Input input, Class type) { + int sampleModelType = input.readInt(); + int transferType = input.readInt(); + int width = input.readInt(); + int height = input.readInt(); + + switch (sampleModelType) { + case TYPE_BANDED: { + int[] bankIndices = KryoUtil.readIntArray(input); + int[] bandOffsets = KryoUtil.readIntArray(input); + return RasterFactory.createBandedSampleModel(transferType, width, height, bankIndices.length, bankIndices, bandOffsets); + } + + case TYPE_PIXEL_INTERLEAVED: { + int pixelStride = input.readInt(); + int scanLineStride = input.readInt(); + int[] bandOffsets = KryoUtil.readIntArray(input); + return RasterFactory.createPixelInterleavedSampleModel(transferType, width, height, pixelStride, scanLineStride, bandOffsets); + } + + case TYPE_COMPONENT_JAI: + case TYPE_COMPONENT: { + int pixelStride = input.readInt(); + int scanLineStride = input.readInt(); + int[] bankIndices = KryoUtil.readIntArray(input); + int[] bandOffsets = KryoUtil.readIntArray(input); + if (sampleModelType == TYPE_COMPONENT_JAI) { + return new ComponentSampleModelJAI(transferType, width, height, pixelStride, scanLineStride, bankIndices, bandOffsets); + } else { + return new ComponentSampleModel(transferType, width, height, pixelStride, scanLineStride, bankIndices, bandOffsets); + } + } + + case TYPE_SINGLE_PIXEL_PACKED: { + int scanLineStride = input.readInt(); + int[] bitMasks = KryoUtil.readIntArray(input); + return new SinglePixelPackedSampleModel(transferType, width, height, scanLineStride, bitMasks); + } + + case TYPE_MULTI_PIXEL_PACKED: { + int pixelStride = input.readInt(); + int scanLineStride = input.readInt(); + int dataBitOffset = input.readInt(); + return new MultiPixelPackedSampleModel(transferType, width, height, pixelStride, scanLineStride, dataBitOffset); + } + + default: + throw new UnsupportedOperationException("Unsupported SampleModel type: " + sampleModelType); + } + } +} diff --git a/common/src/main/java/org/apache/sedona/common/raster/serde/Serde.java b/common/src/main/java/org/apache/sedona/common/raster/serde/Serde.java new file mode 100644 index 0000000000..616ded015e --- /dev/null +++ b/common/src/main/java/org/apache/sedona/common/raster/serde/Serde.java @@ -0,0 +1,179 @@ +/** + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * http://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.sedona.common.raster.serde; + +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.KryoSerializable; +import com.esotericsoftware.kryo.Serializer; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; +import com.esotericsoftware.kryo.io.UnsafeInput; +import com.esotericsoftware.kryo.io.UnsafeOutput; +import org.apache.sedona.common.raster.DeepCopiedRenderedImage; +import org.geotools.coverage.GridSampleDimension; +import org.geotools.coverage.grid.GridCoverage2D; +import org.geotools.coverage.grid.GridCoverageFactory; +import org.geotools.coverage.grid.GridEnvelope2D; +import org.geotools.coverage.grid.GridGeometry2D; +import org.geotools.referencing.operation.transform.AffineTransform2D; +import org.objenesis.strategy.StdInstantiatorStrategy; +import org.opengis.referencing.operation.MathTransform; + +import javax.media.jai.RenderedImageAdapter; +import java.awt.image.RenderedImage; +import java.io.IOException; +import java.io.Serializable; +import java.net.URI; + +public class Serde { + private Serde() {} + + /** + * URIs are not serializable. We need to provide a custom serializer + */ + private static class URISerializer extends Serializer { + public URISerializer() { + setImmutable(true); + } + + @Override + public void write(final Kryo kryo, final Output output, final URI uri) { + KryoUtil.writeUTF8String(output, uri.toString()); + } + + @Override + public URI read(final Kryo kryo, final Input input, final Class uriClass) { + return URI.create(KryoUtil.readUTF8String(input)); + } + } + + private static final ThreadLocal kryos = ThreadLocal.withInitial(() -> { + Kryo kryo = new Kryo(); + kryo.setInstantiatorStrategy(new Kryo.DefaultInstantiatorStrategy(new StdInstantiatorStrategy())); + kryo.register(AffineTransform2D.class, new AffineTransform2DSerializer()); + kryo.register(GridSampleDimension.class, new GridSampleDimensionSerializer()); + kryo.register(URI.class, new URISerializer()); + DeepCopiedRenderedImage.registerKryo(kryo); + try { + kryo.register(Class.forName("org.geotools.coverage.grid.RenderedSampleDimension"), + new GridSampleDimensionSerializer()); + } catch (ClassNotFoundException e) { + throw new RuntimeException("Cannot register kryo serializer for class RenderedSampleDimension", e); + } + kryo.setClassLoader(Thread.currentThread().getContextClassLoader()); + return kryo; + }); + + private static class SerializableState implements Serializable, KryoSerializable { + public CharSequence name; + + // The following three components are used to construct a GridGeometry2D object. + // We serialize CRS separately because the default serializer is pretty slow, we use a + // cached serializer to speed up the serialization and reuse CRS on deserialization. + public GridEnvelope2D gridEnvelope2D; + public MathTransform gridToCRS; + public byte[] serializedCRS; + + public GridSampleDimension[] bands; + public DeepCopiedRenderedImage image; + + public GridCoverage2D restore() { + GridGeometry2D gridGeometry = new GridGeometry2D(gridEnvelope2D, gridToCRS, CRSSerializer.deserialize(serializedCRS)); + return new GridCoverageFactory().create(name, image, gridGeometry, bands, null, null); + } + + private static final GridEnvelopeSerializer gridEnvelopeSerializer = new GridEnvelopeSerializer(); + private static final AffineTransform2DSerializer affineTransform2DSerializer = new AffineTransform2DSerializer(); + private static final GridSampleDimensionSerializer gridSampleDimensionSerializer = new GridSampleDimensionSerializer(); + + @Override + public void write(Kryo kryo, Output output) { + KryoUtil.writeUTF8String(output, name.toString()); + gridEnvelopeSerializer.write(kryo, output, gridEnvelope2D); + if (!(gridToCRS instanceof AffineTransform2D)) { + throw new UnsupportedOperationException("Only AffineTransform2D is supported"); + } + affineTransform2DSerializer.write(kryo, output, (AffineTransform2D) gridToCRS); + output.writeInt(serializedCRS.length); + output.writeBytes(serializedCRS); + output.writeInt(bands.length); + for (GridSampleDimension band : bands) { + gridSampleDimensionSerializer.write(kryo, output, band); + } + image.write(kryo, output); + } + + @Override + public void read(Kryo kryo, Input input) { + name = KryoUtil.readUTF8String(input); + gridEnvelope2D = gridEnvelopeSerializer.read(kryo, input, GridEnvelope2D.class); + gridToCRS = affineTransform2DSerializer.read(kryo, input, AffineTransform2D.class); + int serializedCRSLength = input.readInt(); + serializedCRS = input.readBytes(serializedCRSLength); + int bandCount = input.readInt(); + bands = new GridSampleDimension[bandCount]; + for (int i = 0; i < bandCount; i++) { + bands[i] = gridSampleDimensionSerializer.read(kryo, input, GridSampleDimension.class); + } + image = new DeepCopiedRenderedImage(); + image.read(kryo, input); + } + } + + // A byte reserved for supporting rasters with other storage schemes + private static final int IN_DB = 0; + + public static byte[] serialize(GridCoverage2D raster) throws IOException { + Kryo kryo = kryos.get(); + // GridCoverage2D created by GridCoverage2DReaders contain references that are not serializable. + // Wrap the RenderedImage in DeepCopiedRenderedImage to make it serializable. + DeepCopiedRenderedImage deepCopiedRenderedImage = null; + RenderedImage renderedImage = raster.getRenderedImage(); + while (renderedImage instanceof RenderedImageAdapter) { + renderedImage = ((RenderedImageAdapter) renderedImage).getWrappedImage(); + } + if (renderedImage instanceof DeepCopiedRenderedImage) { + deepCopiedRenderedImage = (DeepCopiedRenderedImage) renderedImage; + } else { + deepCopiedRenderedImage = new DeepCopiedRenderedImage(renderedImage); + } + + SerializableState state = new SerializableState(); + GridGeometry2D gridGeometry = raster.getGridGeometry(); + state.name = raster.getName(); + state.gridEnvelope2D = gridGeometry.getGridRange2D(); + state.gridToCRS = gridGeometry.getGridToCRS2D(); + state.serializedCRS = CRSSerializer.serialize(gridGeometry.getCoordinateReferenceSystem()); + state.bands = raster.getSampleDimensions(); + state.image = deepCopiedRenderedImage; + try (UnsafeOutput out = new UnsafeOutput(4096, -1)) { + out.writeByte(IN_DB); + state.write(kryo, out); + return out.toBytes(); + } + } + + public static GridCoverage2D deserialize(byte[] bytes) throws IOException, ClassNotFoundException { + Kryo kryo = kryos.get(); + try (UnsafeInput in = new UnsafeInput(bytes)) { + int rasterType = in.readByte(); + if (rasterType != IN_DB) { + throw new IllegalArgumentException("Unsupported raster type: " + rasterType); + } + SerializableState state = new SerializableState(); + state.read(kryo, in); + return state.restore(); + } + } +} diff --git a/common/src/test/java/org/apache/sedona/common/raster/RasterBandEditorsTest.java b/common/src/test/java/org/apache/sedona/common/raster/RasterBandEditorsTest.java index 78af992c03..19b932ad77 100644 --- a/common/src/test/java/org/apache/sedona/common/raster/RasterBandEditorsTest.java +++ b/common/src/test/java/org/apache/sedona/common/raster/RasterBandEditorsTest.java @@ -19,6 +19,7 @@ package org.apache.sedona.common.raster; import org.apache.sedona.common.Constructors; +import org.apache.sedona.common.raster.serde.Serde; import org.geotools.coverage.grid.GridCoverage2D; import org.junit.Test; import org.locationtech.jts.geom.Geometry; @@ -189,7 +190,7 @@ public void testClipWithGeometryTransform() throws FactoryException, IOException } @Test - public void testClip() throws IOException, FactoryException, TransformException, ParseException { + public void testClip() throws IOException, FactoryException, TransformException, ParseException, ClassNotFoundException { GridCoverage2D raster = rasterFromGeoTiff(resourceFolder + "raster_geotiff_color/FAA_UTM18N_NAD83.tif"); String polygon = "POLYGON ((236722 4204770, 243900 4204770, 243900 4197590, 221170 4197590, 236722 4204770))"; Geometry geom = Constructors.geomFromWKT(polygon, RasterAccessors.srid(raster)); @@ -216,6 +217,8 @@ public void testClip() throws IOException, FactoryException, TransformException, GridCoverage2D croppedRaster = RasterBandEditors.clip(raster, 1, geom, 200, true); assertEquals(0, croppedRaster.getRenderedImage().getMinX()); assertEquals(0, croppedRaster.getRenderedImage().getMinY()); + GridCoverage2D croppedRaster2 = Serde.deserialize(Serde.serialize(croppedRaster)); + assertSameCoverage(croppedRaster, croppedRaster2); points = new ArrayList<>(); points.add(Constructors.geomFromWKT("POINT(236842 4.20465e+06)", 26918)); points.add(Constructors.geomFromWKT("POINT(236961 4.20453e+06)", 26918)); diff --git a/common/src/test/java/org/apache/sedona/common/raster/RasterConstructorsForTestingTest.java b/common/src/test/java/org/apache/sedona/common/raster/RasterConstructorsForTestingTest.java new file mode 100644 index 0000000000..0dde7811d1 --- /dev/null +++ b/common/src/test/java/org/apache/sedona/common/raster/RasterConstructorsForTestingTest.java @@ -0,0 +1,111 @@ +/** + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * http://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.sedona.common.raster; + +import org.apache.sedona.common.raster.serde.Serde; +import org.apache.sedona.common.utils.RasterUtils; +import org.geotools.coverage.grid.GridCoverage2D; +import org.junit.Test; + +import java.awt.image.ComponentSampleModel; +import java.awt.image.MultiPixelPackedSampleModel; +import java.awt.image.PixelInterleavedSampleModel; +import java.awt.image.Raster; +import java.awt.image.SinglePixelPackedSampleModel; +import java.io.IOException; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +public class RasterConstructorsForTestingTest extends RasterTestBase { + @Test + public void testBandedRaster() { + GridCoverage2D raster = makeRasterWithFallbackParams(4, "I", "BandedSampleModel", 4, 3); + assertTrue(raster.getRenderedImage().getSampleModel() instanceof ComponentSampleModel); + testSerde(raster); + } + + @Test + public void testPixelInterleavedRaster() { + GridCoverage2D raster = makeRasterWithFallbackParams(4, "I", "PixelInterleavedSampleModel", 4, 3); + assertTrue(raster.getRenderedImage().getSampleModel() instanceof PixelInterleavedSampleModel); + testSerde(raster); + raster = makeRasterWithFallbackParams(4, "I", "PixelInterleavedSampleModelComplex", 4, 3); + assertTrue(raster.getRenderedImage().getSampleModel() instanceof PixelInterleavedSampleModel); + testSerde(raster); + } + + @Test + public void testComponentSampleModel() { + GridCoverage2D raster = makeRasterWithFallbackParams(4, "I", "ComponentSampleModel", 4, 3); + assertTrue(raster.getRenderedImage().getSampleModel() instanceof ComponentSampleModel); + testSerde(raster); + } + + @Test + public void testSinglePixelPackedSampleModel() { + GridCoverage2D raster = makeRasterWithFallbackParams(4, "I", "SinglePixelPackedSampleModel", 4, 3); + assertTrue(raster.getRenderedImage().getSampleModel() instanceof SinglePixelPackedSampleModel); + testSerde(raster); + } + + @Test + public void testMultiPixelPackedSampleModel() { + GridCoverage2D raster = makeRasterWithFallbackParams(1, "B", "MultiPixelPackedSampleModel", 4, 3); + assertTrue(raster.getRenderedImage().getSampleModel() instanceof MultiPixelPackedSampleModel); + testSerde(raster); + + raster = makeRasterWithFallbackParams(1, "B", "MultiPixelPackedSampleModel", 21, 8); + Raster r = RasterUtils.getRaster(raster.getRenderedImage()); + assertTrue(r.getSampleModel() instanceof MultiPixelPackedSampleModel); + assertEquals(21, r.getWidth()); + assertEquals(8, r.getHeight()); + for (int y = 0; y < r.getHeight(); y++) { + for (int x = 0; x < r.getWidth(); x++) { + assertEquals((x + y * 21) % 16, r.getSample(x, y, 0)); + } + } + } + + private static GridCoverage2D makeRasterWithFallbackParams(int numBand, String bandDataType, String sampleModelType, int width, int height) { + return RasterConstructorsForTesting.makeRasterForTesting(numBand, bandDataType, sampleModelType, width, height, + 0.5, -0.5, 1, -1, 0, 0, 3857); + } + + private static void testSerde(GridCoverage2D raster) { + try { + byte[] bytes = Serde.serialize(raster); + GridCoverage2D roundTripRaster = Serde.deserialize(bytes); + assertNotNull(roundTripRaster); + assertEquals(raster.getNumSampleDimensions(), roundTripRaster.getNumSampleDimensions()); + + assertEquals(raster.getGridGeometry(), roundTripRaster.getGridGeometry()); + int width = raster.getRenderedImage().getWidth(); + int height = raster.getRenderedImage().getHeight(); + Raster r = RasterUtils.getRaster(raster.getRenderedImage()); + for (int b = 0; b < raster.getNumSampleDimensions(); b++) { + for (int y = 0; y < height; y++) { + for (int x = 0; x < width; x++) { + double value = b + y * width + x; + assertEquals(value, r.getSampleDouble(x, y, b), 0.0001); + } + } + } + + } catch (IOException | ClassNotFoundException e) { + throw new RuntimeException(e); + } + } +} diff --git a/common/src/test/java/org/apache/sedona/common/raster/RasterTestBase.java b/common/src/test/java/org/apache/sedona/common/raster/RasterTestBase.java index e1f1a9ee95..d531c77f6c 100644 --- a/common/src/test/java/org/apache/sedona/common/raster/RasterTestBase.java +++ b/common/src/test/java/org/apache/sedona/common/raster/RasterTestBase.java @@ -49,8 +49,8 @@ public class RasterTestBase { protected static final double FP_TOLERANCE = 1E-4; - GridCoverage2D oneBandRaster; - GridCoverage2D multiBandRaster; + protected GridCoverage2D oneBandRaster; + protected GridCoverage2D multiBandRaster; byte[] geoTiff; byte[] testNc; String ncFile = resourceFolder + "raster/netcdf/test.nc"; @@ -121,6 +121,10 @@ GridCoverage2D createMultibandRaster() throws IOException { return factory.create("test", image, new Envelope2D(DefaultGeographicCRS.WGS84, 0, 0, 10, 10)); } + protected void assertSameCoverage(GridCoverage2D expected, GridCoverage2D actual) { + assertSameCoverage(expected, actual, 10); + } + protected void assertSameCoverage(GridCoverage2D expected, GridCoverage2D actual, int density) { Assert.assertEquals(expected.getNumSampleDimensions(), actual.getNumSampleDimensions()); Envelope expectedEnvelope = expected.getEnvelope(); diff --git a/common/src/test/java/org/apache/sedona/common/raster/CRSSerializerTest.java b/common/src/test/java/org/apache/sedona/common/raster/serde/CRSSerializerTest.java similarity index 97% rename from common/src/test/java/org/apache/sedona/common/raster/CRSSerializerTest.java rename to common/src/test/java/org/apache/sedona/common/raster/serde/CRSSerializerTest.java index fd749bc067..bb5399fd0c 100644 --- a/common/src/test/java/org/apache/sedona/common/raster/CRSSerializerTest.java +++ b/common/src/test/java/org/apache/sedona/common/raster/serde/CRSSerializerTest.java @@ -16,7 +16,7 @@ * specific language governing permissions and limitations * under the License. */ -package org.apache.sedona.common.raster; +package org.apache.sedona.common.raster.serde; import org.geotools.referencing.CRS; import org.junit.Assert; diff --git a/common/src/test/java/org/apache/sedona/common/raster/serde/DataBufferSerializerTest.java b/common/src/test/java/org/apache/sedona/common/raster/serde/DataBufferSerializerTest.java new file mode 100644 index 0000000000..b3c6de1797 --- /dev/null +++ b/common/src/test/java/org/apache/sedona/common/raster/serde/DataBufferSerializerTest.java @@ -0,0 +1,153 @@ +/** + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * http://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.sedona.common.raster.serde; + +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; +import org.junit.Assert; +import org.junit.Test; + +import java.awt.image.DataBuffer; +import java.awt.image.DataBufferByte; +import java.awt.image.DataBufferDouble; +import java.awt.image.DataBufferFloat; +import java.awt.image.DataBufferInt; +import java.awt.image.DataBufferShort; +import java.awt.image.DataBufferUShort; + +public class DataBufferSerializerTest extends KryoSerializerTestBase { + private static final DataBufferSerializer serializer = new DataBufferSerializer(); + + private static void assertEquals(DataBuffer expected, DataBuffer actual) { + Assert.assertEquals(expected.getDataType(), actual.getDataType()); + Assert.assertEquals(expected.getNumBanks(), actual.getNumBanks()); + Assert.assertEquals(expected.getSize(), actual.getSize()); + Assert.assertArrayEquals(expected.getOffsets(), actual.getOffsets()); + for (int bank = 0; bank < expected.getNumBanks(); bank++) { + for (int k = 0; k < expected.getSize(); k++) { + Assert.assertEquals(expected.getElemDouble(bank, k), actual.getElemDouble(bank, k), 1e-6); + } + } + } + + @Test + public void serializeByteBuffer() { + byte[][] dataArray = { + {1, 2, 3, 4, 5}, + {6, 7, 8, 9, 0} + }; + int size = 5; + int[] offsets = {0, 0}; + DataBufferByte dataBufferByte = new DataBufferByte(dataArray, size, offsets); + try (Output out = createOutput()) { + serializer.write(kryo, out, dataBufferByte); + try (Input in = createInput(out)) { + DataBuffer dataBufferByte1 = serializer.read(kryo, in, DataBuffer.class); + assertEquals(dataBufferByte, dataBufferByte1); + } + } + } + + @Test + public void serializeShortBuffer() { + short[][] dataArray = { + {1, 2, 3, 4, 5}, + {6, 7, 8, 9, 0} + }; + int size = 5; + int[] offsets = {0, 0}; + DataBuffer dataBufferShort = new DataBufferShort(dataArray, size, offsets); + try (Output out = createOutput()) { + serializer.write(kryo, out, dataBufferShort); + try (Input in = createInput(out)) { + DataBuffer dataBufferShort1 = serializer.read(kryo, in, DataBuffer.class); + Assert.assertTrue(dataBufferShort1 instanceof DataBufferShort); + assertEquals(dataBufferShort, dataBufferShort1); + } + } + } + + @Test + public void serializeUShortBuffer() { + short[][] dataArray = { + {1, 2, 3, 4, 5}, + {6, 7, 8, 9, 0} + }; + int size = 5; + int[] offsets = {0, 0}; + DataBuffer dataBufferShort = new DataBufferUShort(dataArray, size, offsets); + try (Output out = createOutput()) { + serializer.write(kryo, out, dataBufferShort); + try (Input in = createInput(out)) { + DataBuffer dataBufferShort1 = serializer.read(kryo, in, DataBuffer.class); + Assert.assertTrue(dataBufferShort1 instanceof DataBufferUShort); + assertEquals(dataBufferShort, dataBufferShort1); + } + } + } + + @Test + public void serializeIntBuffer() { + int[][] dataArray = { + {1, 2, 3, 4, 5}, + {6, 7, 8, 9, 0} + }; + int size = 5; + int[] offsets = {0, 0}; + DataBuffer dataBufferInt = new DataBufferInt(dataArray, size, offsets); + try (Output out = createOutput()) { + serializer.write(kryo, out, dataBufferInt); + try (Input in = createInput(out)) { + DataBuffer dataBufferInt1 = serializer.read(kryo, in, DataBuffer.class); + assertEquals(dataBufferInt, dataBufferInt1); + } + } + } + + @Test + public void serializeFloatBuffer() { + float[][] dataArray = { + {1, 2, 3, 4, 5}, + {6, 7, 8, 9, 0} + }; + int size = 5; + int[] offsets = {0, 0}; + DataBuffer dataBufferFloat = new DataBufferFloat(dataArray, size, offsets); + try (Output out = createOutput()) { + serializer.write(kryo, out, dataBufferFloat); + try (Input in = createInput(out)) { + DataBuffer dataBufferFloat1 = serializer.read(kryo, in, DataBuffer.class); + assertEquals(dataBufferFloat, dataBufferFloat1); + } + } + } + + @Test + public void serializeDoubleBuffer() { + double[][] dataArray = { + {1, 2, 3, 4, 5}, + {6, 7, 8, 9, 0} + }; + int size = 5; + int[] offsets = {0, 0}; + DataBuffer dataBufferDouble = new DataBufferDouble(dataArray, size, offsets); + try (Output out = createOutput()) { + serializer.write(kryo, out, dataBufferDouble); + try (Input in = createInput(out)) { + DataBuffer dataBufferDouble1 = serializer.read(kryo, in, DataBuffer.class); + assertEquals(dataBufferDouble, dataBufferDouble1); + } + } + } +} diff --git a/common/src/test/java/org/apache/sedona/common/raster/serde/KryoSerializerTestBase.java b/common/src/test/java/org/apache/sedona/common/raster/serde/KryoSerializerTestBase.java new file mode 100644 index 0000000000..af088a3cc5 --- /dev/null +++ b/common/src/test/java/org/apache/sedona/common/raster/serde/KryoSerializerTestBase.java @@ -0,0 +1,34 @@ +/** + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * http://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.sedona.common.raster.serde; + +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; +import com.esotericsoftware.kryo.io.UnsafeInput; +import com.esotericsoftware.kryo.io.UnsafeOutput; + +public class KryoSerializerTestBase { + protected static final Kryo kryo = new Kryo(); + + protected static Output createOutput() { + return new UnsafeOutput(4096, -1); + } + + protected static Input createInput(Output out) { + out.flush(); + byte[] bytes = out.toBytes(); + return new UnsafeInput(bytes); + } +} diff --git a/common/src/test/java/org/apache/sedona/common/raster/serde/KryoUtilTest.java b/common/src/test/java/org/apache/sedona/common/raster/serde/KryoUtilTest.java new file mode 100644 index 0000000000..359f54f877 --- /dev/null +++ b/common/src/test/java/org/apache/sedona/common/raster/serde/KryoUtilTest.java @@ -0,0 +1,234 @@ +/** + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * http://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.sedona.common.raster.serde; + +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; +import org.junit.Assert; +import org.junit.Test; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; + +public class KryoUtilTest extends KryoSerializerTestBase { + + private static class TestClass { + private int a; + private String b; + private double[] c; + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + TestClass testClass = (TestClass) o; + return a == testClass.a && Objects.equals(b, testClass.b) && Arrays.equals(c, testClass.c); + } + + @Override + public int hashCode() { + return Objects.hash(a, b, Arrays.hashCode(c)); + } + } + + @Test + public void writeObjectWithLength() { + TestClass obj = new TestClass(); + obj.a = 1; + obj.b = "test"; + obj.c = new double[]{1.0, 2.0, 3.0}; + try (Output out = createOutput()) { + KryoUtil.writeObjectWithLength(kryo, out, obj); + try (Input in = createInput(out)) { + in.readInt(); // skip serialized object length + TestClass obj2 = kryo.readObject(in, TestClass.class); + assertEquals(obj, obj2); + } + } + } + + @Test + public void serializeUTF8String() { + String str = + "Hello - English\n" + + "Hola - Spanish\n" + + "Bonjour - French\n" + + "Hallo - German\n" + + "Ciao - Italian\n" + + "你好 - Chinese\n" + + "こんにちは - Japanese\n" + + "안녕하세요 - Korean\n" + + "Здравствуйте - Russian\n" + + "नमस्ते - Hindi\n" + + "مرحبا - Arabic\n" + + "שלום - Hebrew\n" + + "สวัสดี - Thai\n" + + "Merhaba - Turkish\n" + + "Γεια σας - Greek"; + try (Output out = createOutput()) { + KryoUtil.writeUTF8String(out, str); + try (Input in = createInput(out)) { + String str2 = KryoUtil.readUTF8String(in); + assertEquals(str, str2); + } + } + } + + @Test + public void serializeIntArray() { + int[] arr = new int[]{1, 2, 3, 4, 5}; + try (Output out = createOutput()) { + KryoUtil.writeIntArray(out, arr); + try (Input in = createInput(out)) { + int[] arr2 = KryoUtil.readIntArray(in); + assertArrayEquals(arr, arr2); + } + } + } + + @Test + public void serializeIntArrays() { + int[][] arrs = new int[][]{ + new int[]{1, 2, 3, 4, 5}, + new int[]{6, 7, 8, 9, 10} + }; + try (Output out = createOutput()) { + KryoUtil.writeIntArrays(out, arrs); + try (Input in = createInput(out)) { + int[][] arrs2 = KryoUtil.readIntArrays(in); + assertArrayEquals(arrs, arrs2); + } + } + } + + @Test + public void serializeByteArrays() { + byte[][] arrs = new byte[][]{ + new byte[]{1, 2, 3, 4, 5}, + new byte[]{6, 7, 8, 9, 10} + }; + try (Output out = createOutput()) { + KryoUtil.writeByteArrays(out, arrs); + try (Input in = createInput(out)) { + byte[][] arrs2 = KryoUtil.readByteArrays(in); + assertArrayEquals(arrs, arrs2); + } + } + } + + @Test + public void serializeDoubleArrays() { + double[][] arrs = new double[][]{ + new double[]{1.0, 2.0, 3.0, 4.0, 5.0}, + new double[]{6.0, 7.0, 8.0, 9.0, 10.0} + }; + try (Output out = createOutput()) { + KryoUtil.writeDoubleArrays(out, arrs); + try (Input in = createInput(out)) { + double[][] arrs2 = KryoUtil.readDoubleArrays(in); + assertArrayEquals(arrs, arrs2); + } + } + } + + @Test + public void serializeLongArrays() { + long[][] arrs = new long[][]{ + new long[]{1L, 2L, 3L, 4L, 5L}, + new long[]{6L, 7L, 8L, 9L, 10L} + }; + try (Output out = createOutput()) { + KryoUtil.writeLongArrays(out, arrs); + try (Input in = createInput(out)) { + long[][] arrs2 = KryoUtil.readLongArrays(in); + assertArrayEquals(arrs, arrs2); + } + } + } + + @Test + public void serializeFloatArrays() { + float[][] arrs = new float[][]{ + new float[]{1.0f, 2.0f, 3.0f, 4.0f, 5.0f}, + new float[]{6.0f, 7.0f, 8.0f, 9.0f, 10.0f} + }; + try (Output out = createOutput()) { + KryoUtil.writeFloatArrays(out, arrs); + try (Input in = createInput(out)) { + float[][] arrs2 = KryoUtil.readFloatArrays(in); + assertArrayEquals(arrs, arrs2); + } + } + } + + @Test + public void serializeShortArrays() { + short[][] arrs = new short[][]{ + new short[]{1, 2, 3, 4, 5}, + new short[]{6, 7, 8, 9, 10} + }; + try (Output out = createOutput()) { + KryoUtil.writeShortArrays(out, arrs); + try (Input in = createInput(out)) { + short[][] arrs2 = KryoUtil.readShortArrays(in); + assertArrayEquals(arrs, arrs2); + } + } + } + + @Test + public void serializeNullUTF8StringMap() { + try (Output out = createOutput()) { + KryoUtil.writeUTF8StringMap(out, null); + try (Input in = createInput(out)) { + Map map = KryoUtil.readUTF8StringMap(in); + Assert.assertNull(map); + } + } + } + + @Test + public void serializeEmptyUTF8StringMap() { + try (Output out = createOutput()) { + KryoUtil.writeUTF8StringMap(out, Collections.emptyMap()); + try (Input in = createInput(out)) { + Map map = KryoUtil.readUTF8StringMap(in); + Assert.assertNotNull(map); + Assert.assertTrue(map.isEmpty()); + } + } + } + + @Test + public void serializeUTF8StringMap() { + Map map = new HashMap<>(); + map.put("key1", "value1"); + map.put("key2", "value2"); + + try (Output out = createOutput()) { + KryoUtil.writeUTF8StringMap(out, map); + try (Input in = createInput(out)) { + Map map2 = KryoUtil.readUTF8StringMap(in); + Assert.assertNotNull(map2); + Assert.assertEquals(map, map2); + } + } + } +} diff --git a/common/src/test/java/org/apache/sedona/common/raster/serde/SampleModelSerializerTest.java b/common/src/test/java/org/apache/sedona/common/raster/serde/SampleModelSerializerTest.java new file mode 100644 index 0000000000..5bac27f938 --- /dev/null +++ b/common/src/test/java/org/apache/sedona/common/raster/serde/SampleModelSerializerTest.java @@ -0,0 +1,112 @@ +/** + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * http://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.sedona.common.raster.serde; + +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; +import org.junit.Assert; +import org.junit.Test; + +import javax.media.jai.ComponentSampleModelJAI; +import java.awt.image.BandedSampleModel; +import java.awt.image.ComponentSampleModel; +import java.awt.image.DataBuffer; +import java.awt.image.MultiPixelPackedSampleModel; +import java.awt.image.PixelInterleavedSampleModel; +import java.awt.image.SampleModel; +import java.awt.image.SinglePixelPackedSampleModel; + +public class SampleModelSerializerTest extends KryoSerializerTestBase { + private static final SampleModelSerializer serializer = new SampleModelSerializer(); + + @Test + public void serializeBandedSampleModel() { + int[] bankIndices = {2, 0, 1}; + int[] bandOffsets = {4, 8, 12}; + SampleModel sm = new BandedSampleModel(DataBuffer.TYPE_INT, 100, 80, 100, bankIndices, bandOffsets); + try (Output out = createOutput()) { + serializer.write(kryo, out, sm); + try (Input in = createInput(out)) { + SampleModel sm1 = serializer.read(kryo, in, SampleModel.class); + Assert.assertEquals(sm, sm1); + } + } + } + + @Test + public void serializePixelInterleavedSampleModel() { + int[] bandOffsets = {0, 1, 2}; + SampleModel sm = new PixelInterleavedSampleModel(DataBuffer.TYPE_INT, 100, 80, 3, 300, bandOffsets); + try (Output out = createOutput()) { + serializer.write(kryo, out, sm); + try (Input in = createInput(out)) { + SampleModel sm1 = serializer.read(kryo, in, SampleModel.class); + Assert.assertEquals(sm, sm1); + } + } + } + + @Test + public void serializeComponentSampleModel() { + int[] bankIndices = {1, 0}; + int[] bandOffsets = {0, 10000}; + SampleModel sm = new ComponentSampleModel(DataBuffer.TYPE_INT, 100, 80, 1, 100, bankIndices, bandOffsets); + try (Output out = createOutput()) { + serializer.write(kryo, out, sm); + try (Input in = createInput(out)) { + SampleModel sm1 = serializer.read(kryo, in, SampleModel.class); + Assert.assertEquals(sm, sm1); + } + } + } + + @Test + public void serializeComponentSampleModelJAI() { + int[] bankIndices = {1, 0}; + int[] bandOffsets = {0, 10000}; + SampleModel sm = new ComponentSampleModelJAI(DataBuffer.TYPE_INT, 100, 80, 1, 100, bankIndices, bandOffsets); + try (Output out = createOutput()) { + serializer.write(kryo, out, sm); + try (Input in = createInput(out)) { + SampleModel sm1 = serializer.read(kryo, in, SampleModel.class); + Assert.assertEquals(sm, sm1); + } + } + } + + @Test + public void serializeSinglePixelPackedSampleModel() { + int[] bitMasks = {0x000000ff, 0x0000ff00, 0x00ff0000}; + SampleModel sm = new SinglePixelPackedSampleModel(DataBuffer.TYPE_INT, 100, 80, 100, bitMasks); + try (Output out = createOutput()) { + serializer.write(kryo, out, sm); + try (Input in = createInput(out)) { + SampleModel sm1 = serializer.read(kryo, in, SampleModel.class); + Assert.assertEquals(sm, sm1); + } + } + } + + @Test + public void serializedMultiPixelPackedSampleModel() { + SampleModel sm = new MultiPixelPackedSampleModel(DataBuffer.TYPE_BYTE, 100, 80, 4); + try (Output out = createOutput()) { + serializer.write(kryo, out, sm); + try (Input in = createInput(out)) { + SampleModel sm1 = serializer.read(kryo, in, SampleModel.class); + Assert.assertEquals(sm, sm1); + } + } + } +} diff --git a/common/src/test/java/org/apache/sedona/common/raster/SerdeTest.java b/common/src/test/java/org/apache/sedona/common/raster/serde/SerdeTest.java similarity index 73% rename from common/src/test/java/org/apache/sedona/common/raster/SerdeTest.java rename to common/src/test/java/org/apache/sedona/common/raster/serde/SerdeTest.java index 7a29aaa45b..844c3d34fa 100644 --- a/common/src/test/java/org/apache/sedona/common/raster/SerdeTest.java +++ b/common/src/test/java/org/apache/sedona/common/raster/serde/SerdeTest.java @@ -16,11 +16,14 @@ * specific language governing permissions and limitations * under the License. */ -package org.apache.sedona.common.raster; +package org.apache.sedona.common.raster.serde; +import org.apache.sedona.common.raster.RasterConstructors; +import org.apache.sedona.common.raster.RasterTestBase; import org.geotools.coverage.grid.GridCoverage2D; import org.geotools.gce.geotiff.GeoTiffReader; import org.junit.Test; +import org.opengis.referencing.FactoryException; import java.io.File; import java.io.IOException; @@ -37,12 +40,12 @@ public class SerdeTest extends RasterTestBase { }; @Test - public void testRoundtripSerdeSingelbandRaster() throws IOException, ClassNotFoundException { + public void testRoundTripSerdeSingleBandRaster() throws IOException, ClassNotFoundException { testRoundTrip(oneBandRaster); } @Test - public void testRoundtripSerdeMultibandRaster() throws IOException, ClassNotFoundException { + public void testRoundTripSerdeMultiBandRaster() throws IOException, ClassNotFoundException { testRoundTrip(multiBandRaster); } @@ -55,6 +58,20 @@ public void testInDbRaster() throws IOException, ClassNotFoundException { } } + @Test + public void testNorthPoleRaster() throws IOException, ClassNotFoundException, FactoryException { + // If we are not using non-strict mode to serializing CRS, this will raise an exception: + // org.geotools.referencing.wkt.UnformattableObjectException: This "AxisDirection" object is too complex for + // WKT syntax. + GridCoverage2D raster = RasterConstructors.makeEmptyRaster( + 1, "B", 256, 256, + -345000.000, 345000.000, + 2000, -2000, + 0, 0, + 3996); + testRoundTrip(raster); + } + private GridCoverage2D testRoundTrip(GridCoverage2D raster) throws IOException, ClassNotFoundException { return testRoundTrip(raster, 10); } diff --git a/docs/setup/compile.md b/docs/setup/compile.md index 417038c279..49e627775d 100644 --- a/docs/setup/compile.md +++ b/docs/setup/compile.md @@ -73,11 +73,20 @@ For example, export SPARK_HOME=$PWD/spark-3.0.1-bin-hadoop2.7 export PYTHONPATH=$SPARK_HOME/python ``` -2. Compile the Sedona Scala and Java code with `-Dgeotools` and then copy the ==sedona-spark-shaded-{{ sedona.current_version }}.jar== to ==SPARK_HOME/jars/== folder. +2. Put JAI jars to ==SPARK_HOME/jars/== folder. +``` +export JAI_CORE_VERSION="1.1.3" +export JAI_CODEC_VERSION="1.1.3" +export JAI_IMAGEIO_VERSION="1.1" +wget -P $SPARK_HOME/jars/ https://repo.osgeo.org/repository/release/javax/media/jai_core/${JAI_CORE_VERSION}/jai_core-${JAI_CORE_VERSION}.jar +wget -P $SPARK_HOME/jars/ https://repo.osgeo.org/repository/release/javax/media/jai_codec/${JAI_CODEC_VERSION}/jai_codec-${JAI_CODEC_VERSION}.jar +wget -P $SPARK_HOME/jars/ https://repo.osgeo.org/repository/release/javax/media/jai_imageio/${JAI_IMAGEIO_VERSION}/jai_imageio-${JAI_IMAGEIO_VERSION}.jar +``` +3. Compile the Sedona Scala and Java code with `-Dgeotools` and then copy the ==sedona-spark-shaded-{{ sedona.current_version }}.jar== to ==SPARK_HOME/jars/== folder. ``` cp spark-shaded/target/sedona-spark-shaded-xxx.jar $SPARK_HOME/jars/ ``` -3. Install the following libraries +4. Install the following libraries ``` sudo apt-get -y install python3-pip python-dev libgeos-dev sudo pip3 install -U setuptools @@ -86,12 +95,12 @@ sudo pip3 install -U virtualenvwrapper sudo pip3 install -U pipenv ``` Homebrew can be used to install libgeos-dev in macOS: `brew install geos` -4. Set up pipenv to the desired Python version: 3.7, 3.8, or 3.9 +5. Set up pipenv to the desired Python version: 3.7, 3.8, or 3.9 ``` cd python pipenv --python 3.7 ``` -5. Install the PySpark version and the other dependency +6. Install the PySpark version and the other dependency ``` cd python pipenv install pyspark @@ -99,7 +108,7 @@ pipenv install --dev ``` `pipenv install pyspark` installs the latest version of pyspark. In order to remain consistent with the installed spark version, use `pipenv install pyspark==` -6. Run the Python tests +7. Run the Python tests ``` cd python pipenv run python setup.py build_ext --inplace diff --git a/docs/tutorial/raster.md b/docs/tutorial/raster.md index ee4a922e2f..cfc64bc9c6 100644 --- a/docs/tutorial/raster.md +++ b/docs/tutorial/raster.md @@ -583,6 +583,92 @@ SELECT RS_AsPNG(raster) Please refer to [Raster writer docs](../../api/sql/Raster-writer) for more details. +## Collecting raster Dataframes and working with them locally in Python + +Sedona allows collecting Dataframes with raster columns and working with them locally in Python since `v1.6.0`. +The raster objects are represented as `SedonaRaster` objects in Python, which can be used to perform raster operations. + +```python +df_raster = sedona.read.format("binaryFile").load("/path/to/raster.tif").selectExpr("RS_FromGeoTiff(content) as rast") +rows = df_raster.collect() +raster = rows[0].rast +raster # +``` + +You can retrieve the metadata of the raster by accessing the properties of the `SedonaRaster` object. + +```python +raster.width # width of the raster +raster.height # height of the raster +raster.affine_trans # affine transformation matrix +raster.crs_wkt # coordinate reference system as WKT +``` + +You can get a numpy array containing the band data of the raster using the `as_numpy` or `as_numpy_masked` method. The +band data is organized in CHW order. + +```python +raster.as_numpy() # numpy array of the raster +raster.as_numpy_masked() # numpy array with nodata values masked as nan +``` + +If you want to work with the raster data using `rasterio`, you can retrieve a `rasterio.DatasetReader` object using the +`as_rasterio` method. + +```python +ds = raster.as_rasterio() # rasterio.DatasetReader object +# Work with the raster using rasterio +band1 = ds.read(1) # read the first band +``` + +## Writing Python UDF to work with raster data + +You can write Python UDFs to work with raster data in Python. The UDFs can take `SedonaRaster` objects as input and +return any Spark data type as output. This is an example of a Python UDF that calculates the mean of the raster data. + +```python +from pyspark.sql.types import DoubleType + +def mean_udf(raster): + return float(raster.as_numpy().mean()) + +sedona.udf.register("mean_udf", mean_udf, DoubleType()) +df_raster.withColumn("mean", expr("mean_udf(rast)")).show() +``` + +``` ++--------------------+------------------+ +| rast| mean| ++--------------------+------------------+ +|GridCoverage2D["g...|1542.8092886117788| ++--------------------+------------------+ +``` + +It is much trickier to write an UDF that returns a raster object, since Sedona does not support serializing Python raster +objects yet. However, you can write a UDF that returns the band data as an array and then construct the raster object using +`RS_MakeRaster`. This is an example of a Python UDF that creates a mask raster based on the first band of the input raster. + +```python +from pyspark.sql.types import ArrayType, DoubleType +import numpy as np + +def mask_udf(raster): + band1 = raster.as_numpy()[0,:,:] + mask = (band1 < 1400).astype(np.float64) + return mask.flatten().tolist() + +sedona.udf.register("mask_udf", band_udf, ArrayType(DoubleType())) +df_raster.withColumn("mask", expr("mask_udf(rast)")).withColumn("mask_rast", expr("RS_MakeRaster(rast, 'I', mask)")).show() +``` + +``` ++--------------------+--------------------+--------------------+ +| rast| mask| mask_rast| ++--------------------+--------------------+--------------------+ +|GridCoverage2D["g...|[0.0, 0.0, 0.0, 0...|GridCoverage2D["g...| ++--------------------+--------------------+--------------------+ +``` + ## Performance optimization When working with large raster datasets, refer to the [documentation on storing raster geometries in Parquet format](../storing-blobs-in-parquet) for recommendations to optimize performance. diff --git a/python/Pipfile b/python/Pipfile index 110203363b..cd7fdee21c 100644 --- a/python/Pipfile +++ b/python/Pipfile @@ -19,6 +19,7 @@ attrs="*" pyarrow="*" keplergl = "==0.3.2" pydeck = "===0.8.0" +rasterio = ">=1.2.10" [requires] python_version = "3.7" diff --git a/python/sedona/raster/__init__.py b/python/sedona/raster/__init__.py new file mode 100644 index 0000000000..a67d5ea255 --- /dev/null +++ b/python/sedona/raster/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/python/sedona/raster/awt_raster.py b/python/sedona/raster/awt_raster.py new file mode 100644 index 0000000000..d951359436 --- /dev/null +++ b/python/sedona/raster/awt_raster.py @@ -0,0 +1,41 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from .data_buffer import DataBuffer +from .sample_model import SampleModel + + +class AWTRaster: + """Raster data structure of Java AWT Raster used by GeoTools GridCoverage2D. + + """ + min_x: int + min_y: int + width: int + height: int + sample_model: SampleModel + data_buffer: DataBuffer + + def __init__(self, min_x, min_y, width, height, sample_model: SampleModel, data_buffer: DataBuffer): + if sample_model.width != width or sample_model.height != height: + raise RuntimeError("Size of the image does not match with the sample model") + self.min_x = min_x + self.min_y = min_y + self.width = width + self.height = height + self.sample_model = sample_model + self.data_buffer = data_buffer diff --git a/python/sedona/raster/data_buffer.py b/python/sedona/raster/data_buffer.py new file mode 100644 index 0000000000..8826e26bdc --- /dev/null +++ b/python/sedona/raster/data_buffer.py @@ -0,0 +1,39 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import List, Any +import numpy as np + + +class DataBuffer: + TYPE_BYTE = 0 + TYPE_USHORT = 1 + TYPE_SHORT = 2 + TYPE_INT = 3 + TYPE_FLOAT = 4 + TYPE_DOUBLE = 5 + + data_type: int + bank_data: List[np.ndarray] + size: int + offsets: List[int] + + def __init__(self, data_type: int, bank_data: List[np.ndarray], size: int, offsets: List[int]): + self.data_type = data_type + self.bank_data = bank_data + self.size = size + self.offsets = offsets diff --git a/python/sedona/raster/meta.py b/python/sedona/raster/meta.py new file mode 100644 index 0000000000..b0013359dd --- /dev/null +++ b/python/sedona/raster/meta.py @@ -0,0 +1,112 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from enum import Enum +from typing import List, Dict, Optional + + +class PixelAnchor(Enum): + """Anchor of the pixel cell. GeoTools anchors the coordinates at the center + of pixels, while GDAL anchors the coordinates at the upper-left corner of + the pixels. This difference requires us to convert the affine + transformation between these conventions. + + """ + CENTER = 1 + UPPER_LEFT = 2 + + +class AffineTransform: + scale_x: float + skew_y: float + skew_x: float + scale_y: float + ip_x: float + ip_y: float + pixel_anchor: PixelAnchor + + def __init__(self, scale_x, skew_y, skew_x, scale_y, ip_x, ip_y, pixel_anchor: PixelAnchor): + self.scale_x = scale_x + self.skew_y = skew_y + self.skew_x = skew_x + self.scale_y = scale_y + self.ip_x = ip_x + self.ip_y = ip_y + self.pixel_anchor = pixel_anchor + + def with_anchor(self, pixel_anchor: PixelAnchor): + if pixel_anchor == self.pixel_anchor: + return self + return self._do_change_pixel_anchor(self.pixel_anchor, pixel_anchor) + + def translate(self, offset_x: float, offset_y: float): + new_ipx = self.ip_x + offset_x * self.scale_x + offset_y * self.skew_x + new_ipy = self.ip_y + offset_x * self.skew_y + offset_y * self.scale_y + return AffineTransform(self.scale_x, self.skew_y, self.skew_x, self.scale_y, + new_ipx, new_ipy, self.pixel_anchor) + + def _do_change_pixel_anchor(self, from_anchor: PixelAnchor, to_anchor: PixelAnchor): + assert from_anchor != to_anchor + if from_anchor == PixelAnchor.CENTER: + m00 = 1.0 + m10 = 0.0 + m01 = 0.0 + m11 = 1.0 + m02 = -0.5 + m12 = -0.5 + else: + m00 = 1.0 + m10 = 0.0 + m01 = 0.0 + m11 = 1.0 + m02 = 0.5 + m12 = 0.5 + + old_m00 = self.scale_x + old_m10 = self.skew_y + old_m01 = self.skew_x + old_m11 = self.scale_y + old_m02 = self.ip_x + old_m12 = self.ip_y + new_m00 = old_m00 * m00 + old_m01 * m10 + new_m01 = old_m00 * m01 + old_m01 * m11 + new_m02 = old_m00 * m02 + old_m01 * m12 + old_m02 + new_m10 = old_m10 * m00 + old_m11 * m10 + new_m11 = old_m10 * m01 + old_m11 * m11 + new_m12 = old_m10 * m02 + old_m11 * m12 + old_m12 + return AffineTransform(new_m00, new_m10, new_m01, new_m11, new_m02, new_m12, to_anchor) + + def __repr__(self): + return ("[ {} {} {}\n".format(self.scale_x, self.skew_x, self.ip_x) + + " {} {} {}\n".format(self.skew_y, self.scale_y, self.ip_y) + + " 0 0 1 ]") + + +class SampleDimension: + """Raster band metadata. + + """ + description: str + offset: float + scale: float + nodata: float + + def __init__(self, description, offset, scale, nodata): + self.description = description + self.offset = offset + self.scale = scale + self.nodata = nodata diff --git a/python/sedona/raster/raster_serde.py b/python/sedona/raster/raster_serde.py new file mode 100644 index 0000000000..63b740c5a3 --- /dev/null +++ b/python/sedona/raster/raster_serde.py @@ -0,0 +1,180 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Optional, Union, Tuple, List, Dict +from io import BytesIO +import struct +import zlib +import numpy as np + +from .sample_model import SampleModel, ComponentSampleModel, PixelInterleavedSampleModel, MultiPixelPackedSampleModel, SinglePixelPackedSampleModel +from .data_buffer import DataBuffer +from .awt_raster import AWTRaster +from .meta import AffineTransform, PixelAnchor, SampleDimension +from .sedona_raster import SedonaRaster, InDbSedonaRaster + + +class RasterTypes: + IN_DB = 0 + + +def deserialize(buf: Union[bytearray, bytes]) -> Optional[SedonaRaster]: + if buf is None: + return None + + bio = BytesIO(buf) + raster_type = int(bio.read(1)[0]) + return _deserialize(bio, raster_type) + + +def _deserialize(bio: BytesIO, raster_type: int) -> SedonaRaster: + name = _read_utf8_string(bio) + width, height, x, y = _read_grid_envelope(bio) + affine_trans = _read_affine_transformation(bio) + affine_trans = affine_trans.translate(x, y) + affine_trans = affine_trans.with_anchor(PixelAnchor.UPPER_LEFT) + crs_wkt = _read_crs_wkt(bio) + bands_meta = _read_sample_dimensions(bio) + if raster_type == RasterTypes.IN_DB: + # In-DB raster + awt_raster = _read_awt_raster(bio) + return InDbSedonaRaster(width, height, bands_meta, affine_trans, crs_wkt, awt_raster) + else: + raise ValueError("unsupported raster_type: {}".format(raster_type)) + + +def _read_grid_envelope(bio: BytesIO) -> Tuple[int, int, int, int]: + width, height, x, y = struct.unpack("=iiii", bio.read(4 * 4)) + return (width, height, x, y) + + +def _read_affine_transformation(bio: BytesIO) -> AffineTransform: + scale_x, skew_y, skew_x, scale_y, ip_x, ip_y = struct.unpack("=dddddd", bio.read(8 * 6)) + return AffineTransform(scale_x, skew_y, skew_x, scale_y, ip_x, ip_y, PixelAnchor.CENTER) + + +def _read_crs_wkt(bio: BytesIO) -> str: + size, = struct.unpack("=i", bio.read(4)) + compressed_wkt = bio.read(size) + crs_wkt = zlib.decompress(compressed_wkt) + return crs_wkt.decode('utf-8') + + +def _read_sample_dimensions(bio: BytesIO) -> List[SampleDimension]: + num_bands, = struct.unpack("=i", bio.read(4)) + bands_meta = [] + for i in range(num_bands): + description = _read_utf8_string(bio) + offset, scale, nodata = struct.unpack("=ddd", bio.read(8 * 3)) + _ignore_java_object(bio) + bands_meta.append(SampleDimension(description, offset, scale, nodata)) + return bands_meta + + +def _read_awt_raster(bio: BytesIO) -> AWTRaster: + min_x, min_y, width, height = struct.unpack("=iiii", bio.read(4 * 4)) + _ignore_java_object(bio) # image properties + _ignore_java_object(bio) # color model + min_x_1, min_y_1 = struct.unpack("=ii", bio.read(4 * 2)) + if min_x_1 != min_x or min_y_1 != min_y: + raise RuntimeError("malformed serialized raster: minx/miny of the image cannot match with minx/miny of the AWT raster") + sample_model = _read_sample_model(bio) + data_buffer = _read_data_buffer(bio) + return AWTRaster(min_x, min_y, width, height, sample_model, data_buffer) + + +def _read_sample_model(bio: BytesIO) -> SampleModel: + sample_model_type, data_type, width, height = struct.unpack("=iiii", bio.read(4 * 4)) + if sample_model_type == SampleModel.TYPE_BANDED: + bank_indices = _read_int_array(bio) + band_offsets = _read_int_array(bio) + return ComponentSampleModel(data_type, width, height, 1, width, bank_indices, band_offsets) + elif sample_model_type == SampleModel.TYPE_PIXEL_INTERLEAVED: + pixel_stride, scanline_stride = struct.unpack("=ii", bio.read(4 * 2)) + band_offsets = _read_int_array(bio) + return PixelInterleavedSampleModel(data_type, width, height, pixel_stride, scanline_stride, band_offsets) + elif sample_model_type in [SampleModel.TYPE_COMPONENT, SampleModel.TYPE_COMPONENT_JAI]: + pixel_stride, scanline_stride = struct.unpack("=ii", bio.read(4 * 2)) + bank_indices = _read_int_array(bio) + band_offsets = _read_int_array(bio) + return ComponentSampleModel(data_type, width, height, pixel_stride, scanline_stride, bank_indices, band_offsets) + elif sample_model_type == SampleModel.TYPE_SINGLE_PIXEL_PACKED: + scanline_stride, = struct.unpack("=i", bio.read(4)) + bit_masks = _read_int_array(bio) + return SinglePixelPackedSampleModel(data_type, width, height, scanline_stride, bit_masks) + elif sample_model_type == SampleModel.TYPE_MULTI_PIXEL_PACKED: + num_bits, scanline_stride, data_bit_offset = struct.unpack("=iii", bio.read(4 * 3)) + return MultiPixelPackedSampleModel(data_type, width, height, num_bits, scanline_stride, data_bit_offset) + else: + raise RuntimeError(f"Unsupported SampleModel type: {sample_model_type}") + + +def _read_data_buffer(bio: BytesIO) -> DataBuffer: + data_type, = struct.unpack("=i", bio.read(4)) + offsets = _read_int_array(bio) + size, = struct.unpack("=i", bio.read(4)) + + num_banks, = struct.unpack("=i", bio.read(4)) + banks = [] + for i in range(num_banks): + bank_size, = struct.unpack("=i", bio.read(4)) + if data_type == DataBuffer.TYPE_BYTE: + np_array = np.frombuffer(bio.read(bank_size), dtype=np.uint8) + elif data_type == DataBuffer.TYPE_SHORT: + np_array = np.frombuffer(bio.read(2 * bank_size), dtype=np.int16) + elif data_type == DataBuffer.TYPE_USHORT: + np_array = np.frombuffer(bio.read(2 * bank_size), dtype=np.uint16) + elif data_type == DataBuffer.TYPE_INT: + np_array = np.frombuffer(bio.read(4 * bank_size), dtype=np.int32) + elif data_type == DataBuffer.TYPE_FLOAT: + np_array = np.frombuffer(bio.read(4 * bank_size), dtype=np.float32) + elif data_type == DataBuffer.TYPE_DOUBLE: + np_array = np.frombuffer(bio.read(8 * bank_size), dtype=np.float64) + else: + raise ValueError("unknown data_type {}".format(data_type)) + + banks.append(np_array) + + return DataBuffer(data_type, banks, size, offsets) + + +def _read_utf8_string(bio: BytesIO) -> str: + size, = struct.unpack("=i", bio.read(4)) + utf8_bytes = bio.read(size) + return utf8_bytes.decode('utf-8') + + +def _ignore_java_object(bio: BytesIO): + size, = struct.unpack("=i", bio.read(4)) + bio.read(size) + + +def _read_int_array(bio: BytesIO) -> List[int]: + length, = struct.unpack("=i", bio.read(4)) + return [struct.unpack("=i", bio.read(4))[0] for _ in range(length)] + + +def _read_utf8_string_map(bio: BytesIO) -> Optional[Dict[str, str]]: + size, = struct.unpack("=i", bio.read(4)) + if size == -1: + return None + params = {} + for _ in range(size): + key = _read_utf8_string(bio) + value = _read_utf8_string(bio) + params[key] = value + return params diff --git a/python/sedona/raster/sample_model.py b/python/sedona/raster/sample_model.py new file mode 100644 index 0000000000..4c5ac193e7 --- /dev/null +++ b/python/sedona/raster/sample_model.py @@ -0,0 +1,193 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import List +from abc import ABC, abstractmethod +import numpy as np + +from .data_buffer import DataBuffer + + +class SampleModel(ABC): + """The SampleModel class and its subclasses are defined according to the data structure of + SampleModel class in Java AWT. + + """ + TYPE_BANDED = 1 + TYPE_PIXEL_INTERLEAVED = 2 + TYPE_SINGLE_PIXEL_PACKED = 3 + TYPE_MULTI_PIXEL_PACKED = 4 + TYPE_COMPONENT_JAI = 5 + TYPE_COMPONENT = 6 + + sample_model_type: int + data_type: int + width: int + height: int + + def __init__(self, sample_model_type, data_type, width, height): + self.sample_model_type = sample_model_type + self.data_type = data_type + self.width = width + self.height = height + + @abstractmethod + def as_numpy(self, data_buffer: DataBuffer) -> np.ndarray: + raise NotImplementedError("Abstract method as_numpy was not implemented by subclass") + + +class ComponentSampleModel(SampleModel): + pixel_stride: int + scanline_stride: int + bank_indices: List[int] + band_offsets: List[int] + + def __init__(self, data_type, width, height, pixel_stride, scanline_stride, bank_indices, band_offsets): + super().__init__(SampleModel.TYPE_COMPONENT, data_type, width, height) + self.pixel_stride = pixel_stride + self.scanline_stride = scanline_stride + self.bank_indices = bank_indices + self.band_offsets = band_offsets + + def as_numpy(self, data_buffer: DataBuffer) -> np.ndarray: + if self.scanline_stride == self.width and self.pixel_stride == 1: + # Fast path: no gaps between pixels + band_arrs = [] + for bank_index in self.bank_indices: + bank_data = data_buffer.bank_data[bank_index] + offset = self.band_offsets[bank_index] + if offset != 0: + bank_data = bank_data[offset:(offset + self.width * self.height)] + band_arr = bank_data.reshape(self.height, self.width) + band_arrs.append(band_arr) + return np.array(band_arrs) + else: + # Slow path + band_arrs = [] + for k in range(len(self.bank_indices)): + bank_index = self.bank_indices[k] + bank_data = data_buffer.bank_data[bank_index] + offset = self.band_offsets[k] + band_pixel_data = [] + for y in range(self.height): + for x in range(self.width): + pos = offset + y * self.scanline_stride + x * self.pixel_stride + band_pixel_data.append(bank_data[pos]) + arr = np.array(band_pixel_data).reshape(self.height, self.width) + band_arrs.append(arr) + + return np.array(band_arrs) + + +class PixelInterleavedSampleModel(SampleModel): + pixel_stride: int + scanline_stride: int + band_offsets: List[int] + + def __init__(self, data_type, width, height, pixel_stride, scanline_stride, band_offsets): + super().__init__(SampleModel.TYPE_PIXEL_INTERLEAVED, data_type, width, height) + self.pixel_stride = pixel_stride + self.scanline_stride = scanline_stride + self.band_offsets = band_offsets + + def as_numpy(self, data_buffer: DataBuffer) -> np.ndarray: + num_bands = len(self.band_offsets) + bank_data = data_buffer.bank_data[0] + if self.pixel_stride == num_bands and \ + self.scanline_stride == self.width * num_bands and \ + self.band_offsets == list(range(0, num_bands)): + # Fast path: no gapping in between band data, no band reordering + arr = bank_data.reshape(self.height, self.width, num_bands) + return np.transpose(arr, [2, 0, 1]) + else: + # Slow path + pixel_data = [] + for y in range(self.height): + for x in range(self.width): + begin = y * self.scanline_stride + x * self.pixel_stride + end = begin + num_bands + pixel = bank_data[begin:end][self.band_offsets] + pixel_data.append(pixel) + arr = np.array(pixel_data).reshape(self.height, self.width, num_bands) + return np.transpose(arr, [2, 0, 1]) + + +class SinglePixelPackedSampleModel(SampleModel): + scanline_stride: int + bit_masks: List[int] + bit_offsets: List[int] + + def __init__(self, data_type, width, height, scanline_stride, bit_masks): + super().__init__(SampleModel.TYPE_SINGLE_PIXEL_PACKED, data_type, width, height) + self.scanline_stride = scanline_stride + self.bit_masks = bit_masks + self.bit_offsets = [] + for v in self.bit_masks: + self.bit_offsets.append((v & -v).bit_length() - 1) + + def as_numpy(self, data_buffer: DataBuffer) -> np.ndarray: + num_bands = len(self.bit_masks) + bank_data = data_buffer.bank_data[0] + pixel_data = [] + for y in range(self.height): + for x in range(self.width): + pos = y * self.scanline_stride + x + value = bank_data[pos] + pixel = [] + for mask, bit_offset in zip(self.bit_masks, self.bit_offsets): + pixel.append((value & mask) >> bit_offset) + pixel_data.append(pixel) + arr = np.array(pixel_data, dtype=bank_data.dtype).reshape(self.height, self.width, num_bands) + return np.transpose(arr, [2, 0, 1]) + + +class MultiPixelPackedSampleModel(SampleModel): + num_bits: int + scanline_stride: int + data_bit_offset: int + + def __init__(self, data_type, width, height, num_bits, scanline_stride, data_bit_offset): + super().__init__(SampleModel.TYPE_MULTI_PIXEL_PACKED, data_type, width, height) + self.num_bits = num_bits + self.scanline_stride = scanline_stride + self.data_bit_offset = data_bit_offset + + def as_numpy(self, data_buffer: DataBuffer) -> np.ndarray: + bank_data = data_buffer.bank_data[0] + bits_per_value = bank_data.dtype.itemsize * 8 + pixel_per_value = bits_per_value / self.num_bits + shift_right = bits_per_value - self.num_bits + mask = ((1 << self.num_bits) - 1) << shift_right + + band_data = [] + for y in range(self.height): + pos = y * self.scanline_stride + self.data_bit_offset // bits_per_value + value = bank_data[pos] + shift = self.data_bit_offset % bits_per_value + value = (value << shift) + pixels: List[int] = [] + while len(pixels) < self.width: + while shift < bits_per_value and len(pixels) < self.width: + pixels.append((value & mask) >> shift_right) + value = (value << self.num_bits) + shift += self.num_bits + pos += 1 + value = bank_data[pos] + shift = 0 + band_data.append(np.array(pixels, dtype=bank_data.dtype)) + + return np.array(band_data).reshape(1, self.height, self.width) diff --git a/python/sedona/raster/sedona_raster.py b/python/sedona/raster/sedona_raster.py new file mode 100644 index 0000000000..e5ecb3723b --- /dev/null +++ b/python/sedona/raster/sedona_raster.py @@ -0,0 +1,261 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import List, Dict, Optional +from abc import ABC, abstractmethod +from xml.etree.ElementTree import Element, SubElement, tostring + +import numpy as np +import rasterio # type: ignore +import rasterio.env # type: ignore +from rasterio.transform import Affine # type: ignore +from rasterio.io import MemoryFile # type: ignore +from rasterio.io import DatasetReader # type: ignore + +try: + # for rasterio >= 1.3.0 + from rasterio._path import _parse_path as parse_path # type: ignore +except: + # for rasterio >= 1.2.0 + from rasterio.path import parse_path # type: ignore + +from .awt_raster import AWTRaster +from .data_buffer import DataBuffer +from .meta import AffineTransform, PixelAnchor +from .meta import SampleDimension + + +def _rasterio_open(fp, driver=None): + """A variant of rasterio.open. This function skip setting up a new GDAL env + when there is already an environment. This saves us lots of overhead + introduced by GDAL env initialization. + + """ + if rasterio.env.hasenv(): + # There is already an env, so we can get rid of the overhead of + # GDAL env initialization in rasterio.open(). + return DatasetReader(parse_path(fp), driver=driver) + else: + return rasterio.open(fp, mode="r", driver=driver) + + +def _generate_vrt_xml(src_path, data_type, width, height, geo_transform, crs_wkt, off_x, off_y, band_indices) -> bytes: + # Create root element + root = Element('VRTDataset') + root.set('rasterXSize', str(width)) + root.set('rasterYSize', str(height)) + + # Add CRS + if crs_wkt is not None and crs_wkt != '': + srs = SubElement(root, 'SRS') + srs.text = crs_wkt + + # Add GeoTransform + gt = SubElement(root, 'GeoTransform') + gt.text = geo_transform + + # Add bands + for i, band_index in enumerate(band_indices, start=1): + band = SubElement(root, 'VRTRasterBand') + band.set('dataType', data_type) + band.set('band', str(i)) + + # Add source + source = SubElement(band, 'SimpleSource') + src_prop = SubElement(source, 'SourceFilename') + src_prop.text = src_path + + # Set source properties + SubElement(source, 'SourceBand').text = str(band_index + 1) + SubElement(source, 'SrcRect', {'xOff': str(off_x), 'yOff': str(off_y), 'xSize': str(width), 'ySize': str(height)}) + SubElement(source, 'DstRect', {'xOff': '0', 'yOff': '0', 'xSize': str(width), 'ySize': str(height)}) + + # Generate pretty XML + xml_bytes = tostring(root, encoding='utf-8') + return xml_bytes + + +class SedonaRaster(ABC): + _width: int + _height: int + _bands_meta: List[SampleDimension] + _affine_trans: AffineTransform + _crs_wkt: str + + def __init__(self, width: int, height: int, bands_meta: List[SampleDimension], + affine_trans: AffineTransform, crs_wkt: str): + self._width = width + self._height = height + self._bands_meta = bands_meta + self._affine_trans = affine_trans + self._crs_wkt = crs_wkt + + @property + def width(self) -> int: + """Width of the raster in pixel""" + return self._width + + @property + def height(self) -> int: + """Height of the raster in pixel""" + return self._height + + @property + def crs_wkt(self) -> str: + """CRS of the raster as a WKT string""" + return self._crs_wkt + + @property + def bands_meta(self) -> List[SampleDimension]: + """Metadata of bands, including nodata value for each band""" + return self._bands_meta + + @property + def affine_trans(self) -> AffineTransform: + """Geo transform of the raster""" + return self._affine_trans + + @abstractmethod + def as_numpy(self) -> np.ndarray: + """Get the bands data as an numpy array in CHW layout + + """ + raise NotImplementedError() + + def as_numpy_masked(self) -> np.ndarray: + """Get the bands data as an numpy array in CHW layout, with nodata + values masked as nan. + + """ + arr = self.as_numpy() + nodata_values = np.array([bm.nodata for bm in self._bands_meta]) + nodata_values_reshaped = nodata_values[:, None, None] + mask = arr == nodata_values_reshaped + masked_arr = np.where(mask, np.nan, arr) + return masked_arr + + @abstractmethod + def as_rasterio(self) -> DatasetReader: + """Retrieve the raster as an rasterio DatasetReader + + """ + raise NotImplementedError() + + @abstractmethod + def close(self): + """Release all resources allocated for this sedona raster. The rasterio + DatasetReader returned by as_rasterio() will also be closed. + + """ + raise NotImplementedError() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + def __del__(self): + self.close() + + +class InDbSedonaRaster(SedonaRaster): + awt_raster: AWTRaster + rasterio_memfile: Optional[MemoryFile] + rasterio_dataset_reader: Optional[DatasetReader] + + def __init__(self, width: int, height: int, bands_meta: List[SampleDimension], + affine_trans: AffineTransform, crs_wkt: str, + awt_raster: AWTRaster): + super().__init__(width, height, bands_meta, affine_trans, crs_wkt) + self.awt_raster = awt_raster + self.rasterio_memfile = None + self.rasterio_dataset_reader = None + + def as_numpy(self) -> np.ndarray: + sm = self.awt_raster.sample_model + return sm.as_numpy(self.awt_raster.data_buffer) + + def as_rasterio(self) -> DatasetReader: + if self.rasterio_dataset_reader is not None: + return self.rasterio_dataset_reader + + affine = Affine.from_gdal( + self._affine_trans.ip_x, self._affine_trans.scale_x, self._affine_trans.skew_x, + self._affine_trans.ip_y, self._affine_trans.skew_y, self._affine_trans.scale_y) + num_bands = len(self._bands_meta) + + data_array = np.ascontiguousarray(self.as_numpy()) + + dtype = data_array.dtype + if dtype == np.uint8: + data_type = 'Byte' + elif dtype == np.int8: + data_type = 'Int8' + elif dtype == np.uint16: + data_type = 'Uint16' + elif dtype == np.int16: + data_type = 'Int16' + elif dtype == np.uint32: + data_type = 'UInt32' + elif dtype == np.int32: + data_type = 'Int32' + elif dtype == np.float32: + data_type = 'Float32' + elif dtype == np.float64: + data_type = 'Float64' + elif dtype == np.int64: + data_type = 'Int64' + elif dtype == np.uint64: + data_type = 'Uint64' + else: + raise RuntimeError("unknown dtype: " + str(dtype)) + + arr_if = data_array.__array_interface__ + data_pointer = arr_if['data'][0] + geotransform = (f"{self._affine_trans.ip_x}/{self._affine_trans.scale_x}/{self._affine_trans.skew_x}/" + + f"{self._affine_trans.ip_y}/{self._affine_trans.skew_y}/{self._affine_trans.scale_y}") + # FIXME: GDAL 3.6 shipped with rasterio does not support + # SPATIALREFERENCE parameter, so we have to workaround this issue in a + # hacky way. If newer versions of rasterio bundle GDAL 3.7 then this + # won't be a problem. See https://gdal.org/drivers/raster/mem.html + desc = (f"MEM:::DATAPOINTER={data_pointer},PIXELS={self._width},LINES={self._height},BANDS={num_bands}," + + f"DATATYPE={data_type},GEOTRANSFORM={geotransform}") + + # construct a VRT to wrap this MEM dataset, with SRS set up properly + vrt_xml = _generate_vrt_xml( + desc, data_type, self._width, self._height, geotransform.replace('/', ','), self._crs_wkt, + 0, 0, list(range(num_bands))) + + # dataset = _rasterio_open(desc, driver="MEM") + self.rasterio_memfile = MemoryFile(vrt_xml, ext='.vrt') + dataset = self.rasterio_memfile.open(driver='VRT') + + # XXX: dataset does not copy the data held by data_array, so we set + # data_array as a property of dataset to make sure that the lifetime of + # data_array is as long as dataset, otherwise we may see band data + # corruption. + dataset.mem_data_array = data_array + return dataset + + def close(self): + if self.rasterio_dataset_reader is not None: + self.rasterio_dataset_reader.close() + self.rasterio_dataset_reader = None + if self.rasterio_memfile is not None: + self.rasterio_memfile.close() + self.rasterio_memfile = None diff --git a/python/sedona/sql/types.py b/python/sedona/sql/types.py index 36e22e17f4..239f19df8e 100644 --- a/python/sedona/sql/types.py +++ b/python/sedona/sql/types.py @@ -18,6 +18,8 @@ from pyspark.sql.types import UserDefinedType, BinaryType from ..utils import geometry_serde +from ..raster import raster_serde +from ..raster.sedona_raster import SedonaRaster class GeometryType(UserDefinedType): @@ -55,7 +57,7 @@ def serialize(self, obj): raise NotImplementedError("RasterType.serialize is not implemented yet") def deserialize(self, datum): - raise NotImplementedError("RasterType.deserialize is not implemented yet") + return raster_serde.deserialize(datum) @classmethod def module(cls): @@ -67,3 +69,6 @@ def needConversion(self): @classmethod def scalaUDT(cls): return "org.apache.spark.sql.sedona_sql.UDT.RasterUDT" + + +SedonaRaster.__UDT__ = RasterType() diff --git a/python/setup.py b/python/setup.py index 7576957d0d..6429499cbe 100644 --- a/python/setup.py +++ b/python/setup.py @@ -52,7 +52,7 @@ long_description=long_description, long_description_content_type="text/markdown", python_requires='>=3.6', - install_requires=['attrs', "shapely>=1.7.0"], + install_requires=['attrs', "shapely>=1.7.0", "rasterio>=1.2.10"], extras_require={ "spark": ["pyspark>=2.3.0"], "pydeck-map": ["pandas<=1.3.5", "geopandas<=0.10.2", "pydeck==0.8.0"], diff --git a/python/tests/raster/__init__.py b/python/tests/raster/__init__.py new file mode 100644 index 0000000000..a67d5ea255 --- /dev/null +++ b/python/tests/raster/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/python/tests/raster/test_meta.py b/python/tests/raster/test_meta.py new file mode 100644 index 0000000000..68135ba25d --- /dev/null +++ b/python/tests/raster/test_meta.py @@ -0,0 +1,59 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest +from pytest import approx + +from sedona.raster.meta import AffineTransform +from sedona.raster.meta import PixelAnchor + + +class TestAffineTransform: + + def test_change_anchor_to_upper_left(self): + scale_x = 10.0 + skew_y = 1.0 + skew_x = 2.0 + scale_y = -8.0 + ip_x = 100 + ip_y = 200 + + trans = AffineTransform(scale_x, skew_y, skew_x, scale_y, ip_x, ip_y, PixelAnchor.CENTER) + trans_gdal = trans.with_anchor(PixelAnchor.UPPER_LEFT) + assert trans_gdal.scale_x == approx(scale_x) + assert trans_gdal.scale_y == approx(scale_y) + assert trans_gdal.skew_x == approx(skew_x) + assert trans_gdal.skew_y == approx(skew_y) + assert trans_gdal.ip_x == approx(94.0) + assert trans_gdal.ip_y == approx(203.5) + + def test_change_anchor_to_center(self): + scale_x = 10.0 + skew_y = 1.0 + skew_x = 2.0 + scale_y = -8.0 + ip_x = 100 + ip_y = 200 + + trans_gdal = AffineTransform(scale_x, skew_y, skew_x, scale_y, ip_x, ip_y, PixelAnchor.UPPER_LEFT) + trans = trans_gdal.with_anchor(PixelAnchor.CENTER) + assert trans.scale_x == approx(scale_x) + assert trans.scale_y == approx(scale_y) + assert trans.skew_x == approx(skew_x) + assert trans.skew_y == approx(skew_y) + assert trans.ip_x == approx(106.0) + assert trans.ip_y == approx(196.5) diff --git a/python/tests/raster/test_pandas_udf.py b/python/tests/raster/test_pandas_udf.py new file mode 100644 index 0000000000..8e7304941f --- /dev/null +++ b/python/tests/raster/test_pandas_udf.py @@ -0,0 +1,76 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest + +from tests.test_base import TestBase +from pyspark.sql.functions import expr, pandas_udf +from pyspark.sql.types import IntegerType +import pyspark +import pandas as pd +import numpy as np +import rasterio + +from tests import world_map_raster_input_location + +class TestRasterPandasUDF(TestBase): + @pytest.mark.skipif(pyspark.__version__ < '3.4', reason="requires Spark 3.4 or higher") + def test_raster_as_param(self): + spark = TestRasterPandasUDF.spark + df = spark.range(10).withColumn("rast", expr("RS_MakeRasterForTesting(1, 'I', 'PixelInterleavedSampleModel', 4, 3, 100, 100, 10, -10, 0, 0, 3857)")) + + # A Python Pandas UDF that takes a raster as input + @pandas_udf(IntegerType()) + def pandas_udf_raster_as_param(s: pd.Series) -> pd.Series: + from sedona.raster import raster_serde + + def func(x): + with raster_serde.deserialize(x) as raster: + arr = raster.as_numpy() + return int(np.sum(arr)) + + return s.apply(func) + + # A Python Pandas UDF that takes a raster as input + @pandas_udf(IntegerType()) + def pandas_udf_raster_as_param_2(s: pd.Series) -> pd.Series: + from sedona.raster import raster_serde + + def func(x): + with raster_serde.deserialize(x) as raster: + ds = raster.as_rasterio() + return int(np.sum(ds.read(1))) + + # wrap s.apply() with a rasterio env to get rid of the overhead of repeated + # env initialization in as_rasterio() + with rasterio.Env(): + return s.apply(func) + + spark.udf.register("pandas_udf_raster_as_param", pandas_udf_raster_as_param) + spark.udf.register("pandas_udf_raster_as_param_2", pandas_udf_raster_as_param_2) + + df_result = df.selectExpr("pandas_udf_raster_as_param(rast) as res") + rows = df_result.collect() + assert len(rows) == 10 + for row in rows: + assert row['res'] == 66 + + df_result = df.selectExpr("pandas_udf_raster_as_param_2(rast) as res") + rows = df_result.collect() + assert len(rows) == 10 + for row in rows: + assert row['res'] == 66 diff --git a/python/tests/raster/test_serde.py b/python/tests/raster/test_serde.py new file mode 100644 index 0000000000..dc94b01099 --- /dev/null +++ b/python/tests/raster/test_serde.py @@ -0,0 +1,121 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest +import rasterio +import numpy as np + +from tests.test_base import TestBase +from pyspark.sql.functions import expr +from sedona.sql.types import RasterType + +from tests import world_map_raster_input_location + +class TestRasterSerde(TestBase): + def test_empty_raster(self): + df = TestRasterSerde.spark.sql("SELECT RS_MakeEmptyRaster(2, 100, 200, 1000, 2000, 1) as raster") + raster = df.first()[0] + assert raster.width == 100 and raster.height == 200 and len(raster.bands_meta) == 2 + assert raster.affine_trans.ip_x == 1000 and raster.affine_trans.ip_y == 2000 + assert raster.affine_trans.scale_x == 1 and raster.affine_trans.scale_y == -1 + + def test_banded_sample_model(self): + df = TestRasterSerde.spark.sql("SELECT RS_MakeRasterForTesting(3, 'I', 'BandedSampleModel', 10, 8, 100, 100, 10, -10, 0, 0, 3857) as raster") + raster = df.first()[0] + assert raster.width == 10 and raster.height == 8 and len(raster.bands_meta) == 3 + self.validate_test_raster(raster) + + def test_pixel_interleaved_sample_model(self): + df = TestRasterSerde.spark.sql("SELECT RS_MakeRasterForTesting(3, 'I', 'PixelInterleavedSampleModel', 10, 10, 100, 100, 10, -10, 0, 0, 3857) as raster") + raster = df.first()[0] + assert raster.width == 10 and raster.height == 10 and len(raster.bands_meta) == 3 + self.validate_test_raster(raster) + df = TestRasterSerde.spark.sql("SELECT RS_MakeRasterForTesting(4, 'I', 'PixelInterleavedSampleModelComplex', 8, 10, 100, 100, 10, -10, 0, 0, 3857) as raster") + raster = df.first()[0] + assert raster.width == 8 and raster.height == 10 and len(raster.bands_meta) == 4 + self.validate_test_raster(raster) + + def test_component_sample_model(self): + for pixel_type in ['B', 'S', 'US', 'I', 'F', 'D']: + df = TestRasterSerde.spark.sql("SELECT RS_MakeRasterForTesting(4, '{}', 'ComponentSampleModel', 10, 10, 100, 100, 10, -10, 0, 0, 3857) as raster".format(pixel_type)) + raster = df.first()[0] + assert raster.width == 10 and raster.height == 10 and len(raster.bands_meta) == 4 + self.validate_test_raster(raster) + + def test_multi_pixel_packed_sample_model(self): + df = TestRasterSerde.spark.sql("SELECT RS_MakeRasterForTesting(1, 'B', 'MultiPixelPackedSampleModel', 10, 10, 100, 100, 10, -10, 0, 0, 3857) as raster") + raster = df.first()[0] + assert raster.width == 10 and raster.height == 10 and len(raster.bands_meta) == 1 + self.validate_test_raster(raster, packed=True) + + def test_single_pixel_packed_sample_model(self): + df = TestRasterSerde.spark.sql("SELECT RS_MakeRasterForTesting(4, 'I', 'SinglePixelPackedSampleModel', 10, 10, 100, 100, 10, -10, 0, 0, 3857) as raster") + raster = df.first()[0] + assert raster.width == 10 and raster.height == 10 and len(raster.bands_meta) == 4 + self.validate_test_raster(raster, packed=True) + + def test_raster_read_from_geotiff(self): + raster_path = world_map_raster_input_location + r_orig = rasterio.open(raster_path) + band = r_orig.read(1) + band_masked = np.where(band == 0, np.nan, band) + df = TestRasterSerde.spark.read.format("binaryFile").load(raster_path).selectExpr("RS_FromGeoTiff(content) as raster") + raster = df.first()[0] + assert raster.width == r_orig.width + assert raster.height == r_orig.height + assert raster.bands_meta[0].nodata == 0 + + # test as_rasterio + assert (band == raster.as_numpy()[0, :, :]).all() + ds = raster.as_rasterio() + assert ds.crs is not None + band_actual = ds.read(1) + assert (band == band_actual).all() + + # test as_numpy + arr = raster.as_numpy() + assert (arr[0, :, :] == band).all() + + # test as_numpy_masked + arr = raster.as_numpy_masked()[0, :, :] + assert np.array_equal(arr, band_masked) or np.array_equal(np.isnan(arr), np.isnan(band_masked)) + + raster.close() + r_orig.close() + + def test_to_pandas(self): + spark = TestRasterSerde.spark + df = spark.sql("SELECT RS_MakeRasterForTesting(3, 'I', 'BandedSampleModel', 10, 8, 100, 100, 10, -10, 0, 0, 3857) as raster") + pandas_df = df.toPandas() + raster = pandas_df.iloc[0]['raster'] + self.validate_test_raster(raster) + + def validate_test_raster(self, raster, packed = False): + arr = raster.as_numpy() + ds = raster.as_rasterio() + bands, height, width = arr.shape + assert bands > 0 and width > 0 and height > 0 + assert ds.crs is not None + for b in range(bands): + band = ds.read(b + 1) + for y in range(height): + for x in range(width): + expected = b + y * width + x + if packed: + expected = expected % 16 + assert arr[b, y, x] == expected + assert band[y, x] == expected diff --git a/spark-shaded/pom.xml b/spark-shaded/pom.xml index 064e543a2f..b9855e8766 100644 --- a/spark-shaded/pom.xml +++ b/spark-shaded/pom.xml @@ -63,30 +63,128 @@ org.geotools gt-main + + + javax.media + jai_core + + + javax.media + jai_codec + + + javax.media + jai_imageio + + org.geotools gt-referencing + + + javax.media + jai_core + + + javax.media + jai_codec + + + javax.media + jai_imageio + + org.geotools gt-epsg-hsql + + + javax.media + jai_core + + + javax.media + jai_codec + + + javax.media + jai_imageio + + org.geotools gt-geotiff + + + javax.media + jai_core + + + javax.media + jai_codec + + + javax.media + jai_imageio + + org.geotools gt-process-feature + + + javax.media + jai_core + + + javax.media + jai_codec + + + javax.media + jai_imageio + + org.geotools gt-arcgrid + + + javax.media + jai_core + + + javax.media + jai_codec + + + javax.media + jai_imageio + + org.geotools gt-coverage + + + javax.media + jai_core + + + javax.media + jai_codec + + + javax.media + jai_imageio + + diff --git a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala index 3d06f84e34..f12976606e 100644 --- a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala +++ b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala @@ -212,6 +212,7 @@ object Catalog { function[RS_FromGeoTiff](), function[RS_MakeEmptyRaster](), function[RS_MakeRaster](), + function[RS_MakeRasterForTesting](), function[RS_Tile](), function[RS_TileExplode](), function[RS_Envelope](), diff --git a/spark/common/src/main/scala/org/apache/sedona/sql/utils/RasterSerializer.scala b/spark/common/src/main/scala/org/apache/sedona/sql/utils/RasterSerializer.scala index 3ffc41c6c2..0164753077 100644 --- a/spark/common/src/main/scala/org/apache/sedona/sql/utils/RasterSerializer.scala +++ b/spark/common/src/main/scala/org/apache/sedona/sql/utils/RasterSerializer.scala @@ -19,7 +19,7 @@ package org.apache.sedona.sql.utils -import org.apache.sedona.common.raster.Serde +import org.apache.sedona.common.raster.serde.Serde import org.geotools.coverage.grid.GridCoverage2D object RasterSerializer { diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/RasterUDT.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/RasterUDT.scala index 7db42c3530..f88d61ccd8 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/RasterUDT.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/RasterUDT.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.sedona_sql.UDT -import org.apache.sedona.common.raster.Serde +import org.apache.sedona.common.raster.serde.Serde import org.apache.spark.sql.types.{BinaryType, DataType, UserDefinedType} import org.geotools.coverage.grid.GridCoverage2D diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterConstructors.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterConstructors.scala index c977f09131..1b77516894 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterConstructors.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterConstructors.scala @@ -18,7 +18,8 @@ */ package org.apache.spark.sql.sedona_sql.expressions.raster -import org.apache.sedona.common.raster.RasterConstructors +import org.apache.sedona.common.raster.{RasterConstructors, RasterConstructorsForTesting} +import org.apache.sedona.sql.utils.RasterSerializer import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.{CreateArray, Expression, Generator, Literal} @@ -76,6 +77,13 @@ case class RS_MakeRaster(inputExpressions: Seq[Expression]) } } +case class RS_MakeRasterForTesting(inputExpressions: Seq[Expression]) + extends InferredExpression(RasterConstructorsForTesting.makeRasterForTesting _) { + protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { + copy(inputExpressions = newChildren) + } +} + case class RS_Tile(inputExpressions: Seq[Expression]) extends InferredExpression( nullTolerantInferrableFunction3(RasterConstructors.rsTile), diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/implicits.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/implicits.scala index f1e1c6bf61..a26f8cd9f6 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/implicits.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/implicits.scala @@ -18,7 +18,7 @@ */ package org.apache.spark.sql.sedona_sql.expressions.raster -import org.apache.sedona.common.raster.Serde +import org.apache.sedona.common.raster.serde.Serde import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.sedona_sql.expressions.SerdeAware diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/rasteralgebraTest.scala b/spark/common/src/test/scala/org/apache/sedona/sql/rasteralgebraTest.scala index 0acc2ab5df..cc49cdacd0 100644 --- a/spark/common/src/test/scala/org/apache/sedona/sql/rasteralgebraTest.scala +++ b/spark/common/src/test/scala/org/apache/sedona/sql/rasteralgebraTest.scala @@ -28,7 +28,7 @@ import org.junit.Assert.{assertEquals, assertNotNull, assertNull, assertTrue} import org.locationtech.jts.geom.{Coordinate, Geometry} import org.scalatest.{BeforeAndAfter, GivenWhenThen} -import java.awt.image.DataBuffer +import java.awt.image.{DataBuffer, SinglePixelPackedSampleModel} import java.io.File import java.net.URLConnection import scala.collection.mutable @@ -846,6 +846,13 @@ class rasteralgebraTest extends TestBaseScala with BeforeAndAfter with GivenWhen } } + it("Passed RS_MakeRasterForTesting") { + val result = sparkSession.sql("SELECT RS_MakeRasterForTesting(4, 'I', 'SinglePixelPackedSampleModel', 10, 10, 100, 100, 10, -10, 0, 0, 3857) as raster").first().get(0) + assert(result.isInstanceOf[GridCoverage2D]) + val gridCoverage2D = result.asInstanceOf[GridCoverage2D] + assert(gridCoverage2D.getRenderedImage.getSampleModel.isInstanceOf[SinglePixelPackedSampleModel]) + } + it("Passed RS_BandAsArray") { val df = sparkSession.read.format("binaryFile").load(resourceFolder + "raster/test1.tiff") val metadata = df.selectExpr("RS_Metadata(RS_FromGeoTiff(content))").first().getSeq(0) diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/serdeAwareTest.scala b/spark/common/src/test/scala/org/apache/sedona/sql/serdeAwareTest.scala index 786af27469..53a753c057 100644 --- a/spark/common/src/test/scala/org/apache/sedona/sql/serdeAwareTest.scala +++ b/spark/common/src/test/scala/org/apache/sedona/sql/serdeAwareTest.scala @@ -21,7 +21,7 @@ package org.apache.sedona.sql import org.apache.sedona.common.geometrySerde.GeometrySerializer import org.apache.sedona.common.raster.RasterConstructors.fromArcInfoAsciiGrid -import org.apache.sedona.common.raster.Serde +import org.apache.sedona.common.raster.serde.Serde import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.sedona_sql.expressions.{ST_Buffer, ST_GeomFromText, ST_Point, ST_Union} import org.apache.spark.sql.sedona_sql.expressions.raster.{RS_FromArcInfoAsciiGrid, RS_NumBands}