diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs
index d23602c4bfd7..abff811c2ef7 100644
--- a/datafusion/core/src/datasource/file_format/parquet.rs
+++ b/datafusion/core/src/datasource/file_format/parquet.rs
@@ -525,8 +525,8 @@ mod tests {
     use crate::physical_plan::metrics::MetricValue;
     use crate::prelude::{SessionConfig, SessionContext};
     use arrow::array::{
-        ArrayRef, BinaryArray, BooleanArray, Float32Array, Float64Array, Int32Array,
-        StringArray, TimestampNanosecondArray,
+        Array, ArrayRef, BinaryArray, BooleanArray, Float32Array, Float64Array,
+        Int32Array, StringArray, TimestampNanosecondArray,
     };
     use arrow::record_batch::RecordBatch;
     use async_trait::async_trait;
@@ -1023,6 +1023,50 @@ mod tests {
         Ok(())
     }
 
+    #[tokio::test]
+    async fn read_decimal_parquet() -> Result<()> {
+        let session_ctx = SessionContext::new();
+        let task_ctx = session_ctx.task_ctx();
+
+        // Parquet uses INT32 as the physical type to store this decimal
+        let exec = get_exec("int32_decimal.parquet", None, None).await?;
+        let batches = collect(exec, task_ctx.clone()).await?;
+        assert_eq!(1, batches.len());
+        assert_eq!(1, batches[0].num_columns());
+        let column = batches[0].column(0);
+        assert_eq!(&DataType::Decimal(4, 2), column.data_type());
+
+        // Parquet uses INT64 as the physical type to store this decimal
+        let exec = get_exec("int64_decimal.parquet", None, None).await?;
+        let batches = collect(exec, task_ctx.clone()).await?;
+        assert_eq!(1, batches.len());
+        assert_eq!(1, batches[0].num_columns());
+        let column = batches[0].column(0);
+        assert_eq!(&DataType::Decimal(10, 2), column.data_type());
+
+        // Parquet uses FIXED_LEN_BYTE_ARRAY as the physical type to store this decimal
+        let exec = get_exec("fixed_length_decimal.parquet", None, None).await?;
+        let batches = collect(exec, task_ctx.clone()).await?;
+        assert_eq!(1, batches.len());
+        assert_eq!(1, batches[0].num_columns());
+        let column = batches[0].column(0);
+        assert_eq!(&DataType::Decimal(25, 2), column.data_type());
+
+        let exec = get_exec("fixed_length_decimal_legacy.parquet", None, None).await?;
+        let batches = collect(exec, task_ctx.clone()).await?;
+        assert_eq!(1, batches.len());
+        assert_eq!(1, batches[0].num_columns());
+        let column = batches[0].column(0);
+        assert_eq!(&DataType::Decimal(13, 2), column.data_type());
+
+        // Parquet can also use BYTE_ARRAY as the physical type to store a decimal.
+        // TODO: arrow-rs does not yet support converting the BYTE_ARRAY physical type to decimal
+        // https://github.com/apache/arrow-rs/pull/2160
+        // let exec = get_exec("byte_array_decimal.parquet", None, None).await?;
+
+        Ok(())
+    }
+
     fn assert_bytes_scanned(exec: Arc<dyn ExecutionPlan>, expected: usize) {
         let actual = exec
             .metrics()
diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs
index 4b0d04b54ee2..2265675b300b 100644
--- a/datafusion/core/src/physical_optimizer/pruning.rs
+++ b/datafusion/core/src/physical_optimizer/pruning.rs
@@ -800,10 +800,12 @@ mod tests {
     use crate::from_slice::FromSlice;
     use crate::logical_plan::{col, lit};
     use crate::{assert_batches_eq, physical_optimizer::pruning::StatisticsType};
+    use arrow::array::DecimalArray;
     use arrow::{
         array::{BinaryArray, Int32Array, Int64Array, StringArray},
         datatypes::{DataType, TimeUnit},
     };
+    use datafusion_common::ScalarValue;
     use std::collections::HashMap;
 
     #[derive(Debug)]
@@ -814,6 +816,38 @@ mod tests {
     }
 
     impl ContainerStats {
+        fn new_decimal128(
+            min: impl IntoIterator<Item = Option<i128>>,
+            max: impl IntoIterator<Item = Option<i128>>,
+            precision: usize,
+            scale: usize,
+        ) -> Self {
+            Self {
+                min: Arc::new(
+                    min.into_iter()
+                        .collect::<DecimalArray>()
+                        .with_precision_and_scale(precision, scale)
+                        .unwrap(),
+                ),
+                max: Arc::new(
+                    max.into_iter()
+                        .collect::<DecimalArray>()
+                        .with_precision_and_scale(precision, scale)
+                        .unwrap(),
+                ),
+            }
+        }
+
+        fn new_i64(
+            min: impl IntoIterator<Item = Option<i64>>,
+            max: impl IntoIterator<Item = Option<i64>>,
+        ) -> Self {
+            Self {
+                min: Arc::new(min.into_iter().collect::<Int64Array>()),
+                max: Arc::new(max.into_iter().collect::<Int64Array>()),
+            }
+        }
+
         fn new_i32(
             min: impl IntoIterator<Item = Option<i32>>,
             max: impl IntoIterator<Item = Option<i32>>,
@@ -1418,6 +1452,74 @@ mod tests {
         Ok(())
     }
 
+    #[test]
+    fn prune_decimal_data() {
+        // decimal(9,2)
+        let schema = Arc::new(Schema::new(vec![Field::new(
+            "s1",
+            DataType::Decimal(9, 2),
+            true,
+        )]));
+        // s1 > 5
+        let expr = col("s1").gt(lit(ScalarValue::Decimal128(Some(500), 9, 2)));
+        // If the data is written by Spark, the physical data type in the Parquet file is INT32,
+        // so we use INT32 statistics here.
+        let statistics = TestStatistics::new().with(
+            "s1",
+            ContainerStats::new_i32(
+                vec![Some(0), Some(4), None, Some(3)], // min
+                vec![Some(5), Some(6), Some(4), None], // max
+            ),
+        );
+        let p = PruningPredicate::try_new(expr, schema).unwrap();
+        let result = p.prune(&statistics).unwrap();
+        let expected = vec![false, true, false, true];
+        assert_eq!(result, expected);
+
+        // decimal(18,2)
+        let schema = Arc::new(Schema::new(vec![Field::new(
+            "s1",
+            DataType::Decimal(18, 2),
+            true,
+        )]));
+        // s1 > 5
+        let expr = col("s1").gt(lit(ScalarValue::Decimal128(Some(500), 18, 2)));
+        // If the data is written by Spark, the physical data type in the Parquet file is INT64,
+        // so we use INT64 statistics here.
+        let statistics = TestStatistics::new().with(
+            "s1",
+            ContainerStats::new_i64(
+                vec![Some(0), Some(4), None, Some(3)], // min
+                vec![Some(5), Some(6), Some(4), None], // max
+            ),
+        );
+        let p = PruningPredicate::try_new(expr, schema).unwrap();
+        let result = p.prune(&statistics).unwrap();
+        let expected = vec![false, true, false, true];
+        assert_eq!(result, expected);
+
+        // decimal(23,2)
+        let schema = Arc::new(Schema::new(vec![Field::new(
+            "s1",
+            DataType::Decimal(23, 2),
+            true,
+        )]));
+        // s1 > 5
+        let expr = col("s1").gt(lit(ScalarValue::Decimal128(Some(500), 23, 2)));
+        let statistics = TestStatistics::new().with(
+            "s1",
+            ContainerStats::new_decimal128(
+                vec![Some(0), Some(400), None, Some(300)], // min
+                vec![Some(500), Some(600), Some(400), None], // max
+                23,
+                2,
+            ),
+        );
+        let p = PruningPredicate::try_new(expr, schema).unwrap();
+        let result = p.prune(&statistics).unwrap();
+        let expected = vec![false, true, false, true];
+        assert_eq!(result, expected);
+    }
     #[test]
     fn prune_api() {
         let schema = Arc::new(Schema::new(vec![
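Note on what these tests exercise (not part of the diff itself): Parquet stores a decimal as an unscaled integer plus a (precision, scale) pair, and the physical type is chosen by precision, which is why `read_decimal_parquet` covers INT32-, INT64-, and FIXED_LEN_BYTE_ARRAY-backed files and why `prune_decimal_data` feeds INT32/INT64 statistics to a Decimal predicate. The sketch below illustrates the scaled-integer comparison under those assumptions; the helper names (`physical_type_for`, `may_contain_gt`) are illustrative, not DataFusion APIs — DataFusion does the equivalent by casting the statistics arrays to the Decimal type before evaluating the pruning predicate.

```rust
/// Smallest integer physical type that can hold a decimal of the given
/// precision, per the Parquet spec (INT32 for <= 9 digits, INT64 for <= 18,
/// FIXED_LEN_BYTE_ARRAY / BYTE_ARRAY beyond that).
fn physical_type_for(precision: u8) -> &'static str {
    match precision {
        1..=9 => "INT32",
        10..=18 => "INT64",
        _ => "FIXED_LEN_BYTE_ARRAY",
    }
}

/// Hypothetical pruning check for `col > literal` over one container: rows
/// can only match if the container's max exceeds the literal. A raw
/// INT32/INT64 statistic is widened to i128 and rescaled to the literal's
/// scale before comparing (5 at scale 2 becomes the unscaled value 500).
fn may_contain_gt(max_stat: Option<i64>, literal_unscaled: i128, scale: u32) -> bool {
    match max_stat {
        Some(max) => (max as i128) * 10i128.pow(scale) > literal_unscaled,
        None => true, // no statistics: cannot prune, keep the container
    }
}

fn main() {
    // The three precisions used by prune_decimal_data above.
    assert_eq!(physical_type_for(9), "INT32");
    assert_eq!(physical_type_for(18), "INT64");
    assert_eq!(physical_type_for(23), "FIXED_LEN_BYTE_ARRAY");

    // `s1 > 5` with decimal(9,2): the literal 5 is the unscaled integer 500.
    // The four containers carry the max statistics [5, 6, 4, None].
    let maxes = [Some(5i64), Some(6), Some(4), None];
    let keep: Vec<bool> = maxes
        .iter()
        .map(|m| may_contain_gt(*m, 500, 2))
        .collect();
    // Only the container whose max (6.00) exceeds 5.00 can match; a container
    // without statistics must be kept. Same vector the test expects.
    assert_eq!(keep, vec![false, true, false, true]);
}
```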