diff --git a/rust/sedona-raster-functions/src/rs_band_accessors.rs b/rust/sedona-raster-functions/src/rs_band_accessors.rs index a22883636..ee1a308e1 100644 --- a/rust/sedona-raster-functions/src/rs_band_accessors.rs +++ b/rust/sedona-raster-functions/src/rs_band_accessors.rs @@ -431,6 +431,64 @@ mod tests { assert!(float_array.is_null(0)); } + #[test] + fn udf_bandpixeltype_multi_band() { + let udf: ScalarUDF = rs_bandpixeltype_udf().into(); + let tester = ScalarUdfTester::new(udf, vec![RASTER, SedonaType::Arrow(DataType::Int32)]); + + let rasters = sedona_testing::rasters::generate_multi_band_raster(); + + // Band 1: UInt8 + let result = tester + .invoke_array_scalar(Arc::new(rasters.clone()), 1_i32) + .unwrap(); + let arr = result.as_any().downcast_ref::().unwrap(); + assert_eq!(arr.value(0), "UNSIGNED_8BITS"); + + // Band 2: UInt16 + let result = tester + .invoke_array_scalar(Arc::new(rasters.clone()), 2_i32) + .unwrap(); + let arr = result.as_any().downcast_ref::().unwrap(); + assert_eq!(arr.value(0), "UNSIGNED_16BITS"); + + // Band 3: Float32 + let result = tester + .invoke_array_scalar(Arc::new(rasters), 3_i32) + .unwrap(); + let arr = result.as_any().downcast_ref::().unwrap(); + assert_eq!(arr.value(0), "REAL_32BITS"); + } + + #[test] + fn udf_bandnodatavalue_multi_band() { + let udf: ScalarUDF = rs_bandnodatavalue_udf().into(); + let tester = ScalarUdfTester::new(udf, vec![RASTER, SedonaType::Arrow(DataType::Int32)]); + + let rasters = sedona_testing::rasters::generate_multi_band_raster(); + + // Band 1: nodata=255 (UInt8) + let result = tester + .invoke_array_scalar(Arc::new(rasters.clone()), 1_i32) + .unwrap(); + let arr = result.as_any().downcast_ref::().unwrap(); + assert_eq!(arr.value(0), 255.0); + + // Band 2: nodata=0 (UInt16) + let result = tester + .invoke_array_scalar(Arc::new(rasters.clone()), 2_i32) + .unwrap(); + let arr = result.as_any().downcast_ref::().unwrap(); + assert_eq!(arr.value(0), 0.0); + + // Band 3: no nodata (Float32) + let result = tester + .invoke_array_scalar(Arc::new(rasters), 3_i32) + .unwrap(); + let arr = result.as_any().downcast_ref::().unwrap(); + assert!(arr.is_null(0)); + } + #[test] fn udf_bandnodatavalue_non_existing_band() { let udf: ScalarUDF = rs_bandnodatavalue_udf().into(); diff --git a/rust/sedona-raster-functions/src/rs_numbands.rs b/rust/sedona-raster-functions/src/rs_numbands.rs index d3389f1cd..f25c4df47 100644 --- a/rust/sedona-raster-functions/src/rs_numbands.rs +++ b/rust/sedona-raster-functions/src/rs_numbands.rs @@ -108,4 +108,15 @@ mod tests { let result = tester.invoke_scalar(ScalarValue::Null).unwrap(); tester.assert_scalar_result_equals(result, ScalarValue::UInt32(None)); } + + #[test] + fn udf_numbands_multi_band() { + let udf: ScalarUDF = rs_numbands_udf().into(); + let tester = ScalarUdfTester::new(udf, vec![RASTER]); + + let rasters = sedona_testing::rasters::generate_multi_band_raster(); + let expected: Arc = Arc::new(UInt32Array::from(vec![Some(3)])); + let result = tester.invoke_array(Arc::new(rasters)).unwrap(); + assert_array_equal(&result, &expected); + } } diff --git a/rust/sedona-testing/src/rasters.rs b/rust/sedona-testing/src/rasters.rs index 2aab413cc..d30940473 100644 --- a/rust/sedona-testing/src/rasters.rs +++ b/rust/sedona-testing/src/rasters.rs @@ -221,6 +221,77 @@ pub fn raster_from_single_band( builder.finish().expect("finish") } +/// Builds a single raster with 3 bands of different types for testing multi-band operations. +/// Band 1: UInt8 (nodata=255), Band 2: UInt16 (nodata=0), Band 3: Float32 (no nodata). +/// Each band is 2x2 pixels. +pub fn generate_multi_band_raster() -> StructArray { + let mut builder = RasterBuilder::new(1); + let crs = lnglat().unwrap().to_crs_string(); + let metadata = RasterMetadata { + width: 2, + height: 2, + upperleft_x: 10.0, + upperleft_y: 20.0, + scale_x: 0.5, + scale_y: -0.5, + skew_x: 0.0, + skew_y: 0.0, + }; + builder.start_raster(&metadata, Some(&crs)).unwrap(); + + // Band 1: UInt8, nodata=255 + builder + .start_band(BandMetadata { + datatype: BandDataType::UInt8, + nodata_value: Some(vec![255u8]), + storage_type: StorageType::InDb, + outdb_url: None, + outdb_band_id: None, + }) + .unwrap(); + builder + .band_data_writer() + .append_value([1u8, 2u8, 3u8, 4u8]); + builder.finish_band().unwrap(); + + // Band 2: UInt16, nodata=0 + builder + .start_band(BandMetadata { + datatype: BandDataType::UInt16, + nodata_value: Some(vec![0u8, 0u8]), + storage_type: StorageType::InDb, + outdb_url: None, + outdb_band_id: None, + }) + .unwrap(); + let band2_data: Vec = [100u16, 200u16, 300u16, 400u16] + .iter() + .flat_map(|v| v.to_le_bytes()) + .collect(); + builder.band_data_writer().append_value(&band2_data); + builder.finish_band().unwrap(); + + // Band 3: Float32, no nodata + builder + .start_band(BandMetadata { + datatype: BandDataType::Float32, + nodata_value: None, + storage_type: StorageType::InDb, + outdb_url: None, + outdb_band_id: None, + }) + .unwrap(); + let band3_data: Vec = [1.5f32, 2.5f32, 3.5f32, 4.5f32] + .iter() + .flat_map(|v| v.to_le_bytes()) + .collect(); + builder.band_data_writer().append_value(&band3_data); + builder.finish_band().unwrap(); + + builder.finish_raster().unwrap(); + builder.finish().unwrap() +} + /// Determine if this tile contains a corner of the overall grid and return its position /// Returns Some(position) if this tile contains a corner, None otherwise fn get_corner_position( @@ -526,6 +597,39 @@ mod tests { assert_raster_equal(&raster1, &raster2); } + #[test] + fn test_generate_multi_band_raster() { + let struct_array = generate_multi_band_raster(); + let raster_array = RasterStructArray::new(&struct_array); + assert_eq!(raster_array.len(), 1); + + let raster = raster_array.get(0).unwrap(); + let metadata = raster.metadata(); + assert_eq!(metadata.width(), 2); + assert_eq!(metadata.height(), 2); + assert_eq!(metadata.upper_left_x(), 10.0); + assert_eq!(metadata.upper_left_y(), 20.0); + + let bands = raster.bands(); + assert_eq!(bands.len(), 3); + + // Band 1: UInt8, nodata=255 + let b1 = bands.band(1).unwrap(); + assert_eq!(b1.metadata().data_type().unwrap(), BandDataType::UInt8); + assert_eq!(b1.metadata().nodata_value(), Some(&[255u8][..])); + assert_eq!(b1.data(), &[1u8, 2, 3, 4]); + + // Band 2: UInt16, nodata=0 + let b2 = bands.band(2).unwrap(); + assert_eq!(b2.metadata().data_type().unwrap(), BandDataType::UInt16); + assert_eq!(b2.metadata().nodata_value(), Some(&[0u8, 0][..])); + + // Band 3: Float32, no nodata + let b3 = bands.band(3).unwrap(); + assert_eq!(b3.metadata().data_type().unwrap(), BandDataType::Float32); + assert_eq!(b3.metadata().nodata_value(), None); + } + #[test] #[should_panic = "Raster upper left x does not match"] fn test_raster_different_metadata() {