Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions rust/sedona-raster-functions/src/rs_band_accessors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,64 @@ mod tests {
assert!(float_array.is_null(0));
}

#[test]
fn udf_bandpixeltype_multi_band() {
let udf: ScalarUDF = rs_bandpixeltype_udf().into();
let tester = ScalarUdfTester::new(udf, vec![RASTER, SedonaType::Arrow(DataType::Int32)]);

let rasters = sedona_testing::rasters::generate_multi_band_raster();

// Band 1: UInt8
let result = tester
.invoke_array_scalar(Arc::new(rasters.clone()), 1_i32)
.unwrap();
let arr = result.as_any().downcast_ref::<StringArray>().unwrap();
assert_eq!(arr.value(0), "UNSIGNED_8BITS");

// Band 2: UInt16
let result = tester
.invoke_array_scalar(Arc::new(rasters.clone()), 2_i32)
.unwrap();
let arr = result.as_any().downcast_ref::<StringArray>().unwrap();
assert_eq!(arr.value(0), "UNSIGNED_16BITS");

// Band 3: Float32
let result = tester
.invoke_array_scalar(Arc::new(rasters), 3_i32)
.unwrap();
let arr = result.as_any().downcast_ref::<StringArray>().unwrap();
assert_eq!(arr.value(0), "REAL_32BITS");
}

#[test]
fn udf_bandnodatavalue_multi_band() {
let udf: ScalarUDF = rs_bandnodatavalue_udf().into();
let tester = ScalarUdfTester::new(udf, vec![RASTER, SedonaType::Arrow(DataType::Int32)]);

let rasters = sedona_testing::rasters::generate_multi_band_raster();

// Band 1: nodata=255 (UInt8)
let result = tester
.invoke_array_scalar(Arc::new(rasters.clone()), 1_i32)
.unwrap();
let arr = result.as_any().downcast_ref::<Float64Array>().unwrap();
assert_eq!(arr.value(0), 255.0);

// Band 2: nodata=0 (UInt16)
let result = tester
.invoke_array_scalar(Arc::new(rasters.clone()), 2_i32)
.unwrap();
let arr = result.as_any().downcast_ref::<Float64Array>().unwrap();
assert_eq!(arr.value(0), 0.0);

// Band 3: no nodata (Float32)
let result = tester
.invoke_array_scalar(Arc::new(rasters), 3_i32)
.unwrap();
let arr = result.as_any().downcast_ref::<Float64Array>().unwrap();
assert!(arr.is_null(0));
}

#[test]
fn udf_bandnodatavalue_non_existing_band() {
let udf: ScalarUDF = rs_bandnodatavalue_udf().into();
Expand Down
11 changes: 11 additions & 0 deletions rust/sedona-raster-functions/src/rs_numbands.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,4 +108,15 @@ mod tests {
let result = tester.invoke_scalar(ScalarValue::Null).unwrap();
tester.assert_scalar_result_equals(result, ScalarValue::UInt32(None));
}

#[test]
fn udf_numbands_multi_band() {
let udf: ScalarUDF = rs_numbands_udf().into();
let tester = ScalarUdfTester::new(udf, vec![RASTER]);

let rasters = sedona_testing::rasters::generate_multi_band_raster();
let expected: Arc<dyn arrow_array::Array> = Arc::new(UInt32Array::from(vec![Some(3)]));
let result = tester.invoke_array(Arc::new(rasters)).unwrap();
assert_array_equal(&result, &expected);
}
}
104 changes: 104 additions & 0 deletions rust/sedona-testing/src/rasters.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,77 @@ pub fn raster_from_single_band(
builder.finish().expect("finish")
}

/// Builds a single raster with 3 bands of different types for testing multi-band operations.
/// Band 1: UInt8 (nodata=255), Band 2: UInt16 (nodata=0), Band 3: Float32 (no nodata).
/// Each band is 2x2 pixels.
pub fn generate_multi_band_raster() -> StructArray {
let mut builder = RasterBuilder::new(1);
let crs = lnglat().unwrap().to_crs_string();
let metadata = RasterMetadata {
width: 2,
height: 2,
upperleft_x: 10.0,
upperleft_y: 20.0,
scale_x: 0.5,
scale_y: -0.5,
skew_x: 0.0,
skew_y: 0.0,
};
builder.start_raster(&metadata, Some(&crs)).unwrap();

// Band 1: UInt8, nodata=255
builder
.start_band(BandMetadata {
datatype: BandDataType::UInt8,
nodata_value: Some(vec![255u8]),
storage_type: StorageType::InDb,
outdb_url: None,
outdb_band_id: None,
})
.unwrap();
builder
.band_data_writer()
.append_value([1u8, 2u8, 3u8, 4u8]);
builder.finish_band().unwrap();

// Band 2: UInt16, nodata=0
builder
.start_band(BandMetadata {
datatype: BandDataType::UInt16,
nodata_value: Some(vec![0u8, 0u8]),
storage_type: StorageType::InDb,
outdb_url: None,
outdb_band_id: None,
})
.unwrap();
let band2_data: Vec<u8> = [100u16, 200u16, 300u16, 400u16]
.iter()
.flat_map(|v| v.to_le_bytes())
.collect();
builder.band_data_writer().append_value(&band2_data);
builder.finish_band().unwrap();

// Band 3: Float32, no nodata
builder
.start_band(BandMetadata {
datatype: BandDataType::Float32,
nodata_value: None,
storage_type: StorageType::InDb,
outdb_url: None,
outdb_band_id: None,
})
.unwrap();
let band3_data: Vec<u8> = [1.5f32, 2.5f32, 3.5f32, 4.5f32]
.iter()
.flat_map(|v| v.to_le_bytes())
.collect();
builder.band_data_writer().append_value(&band3_data);
builder.finish_band().unwrap();

builder.finish_raster().unwrap();
builder.finish().unwrap()
}

/// Determine if this tile contains a corner of the overall grid and return its position
/// Returns Some(position) if this tile contains a corner, None otherwise
fn get_corner_position(
Expand Down Expand Up @@ -526,6 +597,39 @@ mod tests {
assert_raster_equal(&raster1, &raster2);
}

#[test]
fn test_generate_multi_band_raster() {
let struct_array = generate_multi_band_raster();
let raster_array = RasterStructArray::new(&struct_array);
assert_eq!(raster_array.len(), 1);

let raster = raster_array.get(0).unwrap();
let metadata = raster.metadata();
assert_eq!(metadata.width(), 2);
assert_eq!(metadata.height(), 2);
assert_eq!(metadata.upper_left_x(), 10.0);
assert_eq!(metadata.upper_left_y(), 20.0);

let bands = raster.bands();
assert_eq!(bands.len(), 3);

// Band 1: UInt8, nodata=255
let b1 = bands.band(1).unwrap();
assert_eq!(b1.metadata().data_type().unwrap(), BandDataType::UInt8);
assert_eq!(b1.metadata().nodata_value(), Some(&[255u8][..]));
assert_eq!(b1.data(), &[1u8, 2, 3, 4]);

// Band 2: UInt16, nodata=0
let b2 = bands.band(2).unwrap();
assert_eq!(b2.metadata().data_type().unwrap(), BandDataType::UInt16);
assert_eq!(b2.metadata().nodata_value(), Some(&[0u8, 0][..]));

// Band 3: Float32, no nodata
let b3 = bands.band(3).unwrap();
assert_eq!(b3.metadata().data_type().unwrap(), BandDataType::Float32);
assert_eq!(b3.metadata().nodata_value(), None);
}

#[test]
#[should_panic = "Raster upper left x does not match"]
fn test_raster_different_metadata() {
Expand Down
Loading