diff --git a/datafusion/core/src/dataframe.rs b/datafusion/core/src/dataframe.rs
index 82f8deb3c20d..1fbf19c330d6 100644
--- a/datafusion/core/src/dataframe.rs
+++ b/datafusion/core/src/dataframe.rs
@@ -20,7 +20,7 @@
 use std::any::Any;
 use std::sync::Arc;
 
-use arrow::array::{ArrayRef, Int64Array, StringArray};
+use arrow::array::{Array, ArrayRef, Int64Array, StringArray};
 use arrow::compute::{cast, concat};
 use arrow::datatypes::{DataType, Field};
 use async_trait::async_trait;
@@ -329,10 +329,10 @@ impl DataFrame {
         let supported_describe_functions =
             vec!["count", "null_count", "mean", "std", "min", "max", "median"];
 
-        let fields_iter = self.schema().fields().iter();
+        let original_schema_fields = self.schema().fields().iter();
 
         //define describe column
-        let mut describe_schemas = fields_iter
+        let mut describe_schemas = original_schema_fields
            .clone()
            .map(|field| {
                if field.data_type().is_numeric() {
@@ -344,24 +344,38 @@ impl DataFrame {
            .collect::<Vec<_>>();
         describe_schemas.insert(0, Field::new("describe", DataType::Utf8, false));
 
+        //count aggregation
+        let cnt = self.clone().aggregate(
+            vec![],
+            original_schema_fields
+                .clone()
+                .map(|f| count(col(f.name())))
+                .collect::<Vec<_>>(),
+        )?;
+        // The optimization of AggregateStatistics will rewrite the physical plan
+        // for the count function and ignore alias functions,
+        // as shown in https://github.com/apache/arrow-datafusion/issues/5444.
+        // This logic should be removed when #5444 is fixed.
+        let cnt = cnt.clone().select(
+            cnt.schema()
+                .fields()
+                .iter()
+                .zip(original_schema_fields.clone())
+                .map(|(count_field, origin_field)| {
+                    col(count_field.name()).alias(origin_field.name())
+                })
+                .collect::<Vec<_>>(),
+        )?;
+        //should be removed when #5444 is fixed
         //collect recordBatch
         let describe_record_batch = vec![
             // count aggregation
-            self.clone()
-                .aggregate(
-                    vec![],
-                    fields_iter
-                        .clone()
-                        .map(|f| count(col(f.name())).alias(f.name()))
-                        .collect::<Vec<_>>(),
-                )?
-                .collect()
-                .await?,
+            cnt.collect().await?,
             // null_count aggregation
             self.clone()
                 .aggregate(
                     vec![],
-                    fields_iter
+                    original_schema_fields
                         .clone()
                         .map(|f| count(is_null(col(f.name()))).alias(f.name()))
                         .collect::<Vec<_>>(),
@@ -372,7 +386,7 @@ impl DataFrame {
             self.clone()
                 .aggregate(
                     vec![],
-                    fields_iter
+                    original_schema_fields
                         .clone()
                         .filter(|f| f.data_type().is_numeric())
                         .map(|f| avg(col(f.name())).alias(f.name()))
@@ -384,7 +398,7 @@ impl DataFrame {
             self.clone()
                 .aggregate(
                     vec![],
-                    fields_iter
+                    original_schema_fields
                         .clone()
                         .filter(|f| f.data_type().is_numeric())
                         .map(|f| stddev(col(f.name())).alias(f.name()))
@@ -396,7 +410,7 @@ impl DataFrame {
             self.clone()
                 .aggregate(
                     vec![],
-                    fields_iter
+                    original_schema_fields
                         .clone()
                         .filter(|f| {
                             !matches!(f.data_type(), DataType::Binary | DataType::Boolean)
@@ -410,7 +424,7 @@ impl DataFrame {
             self.clone()
                 .aggregate(
                     vec![],
-                    fields_iter
+                    original_schema_fields
                         .clone()
                         .filter(|f| {
                             !matches!(f.data_type(), DataType::Binary | DataType::Boolean)
@@ -424,7 +438,7 @@ impl DataFrame {
             self.clone()
                 .aggregate(
                     vec![],
-                    fields_iter
+                    original_schema_fields
                         .clone()
                         .filter(|f| f.data_type().is_numeric())
                         .map(|f| median(col(f.name())).alias(f.name()))
@@ -435,7 +449,7 @@ impl DataFrame {
         ];
 
         let mut array_ref_vec: Vec<ArrayRef> = vec![];
-        for field in fields_iter {
+        for field in original_schema_fields {
             let mut array_datas = vec![];
             for record_batch in describe_record_batch.iter() {
                 let column = record_batch.get(0).unwrap().column_by_name(field.name());
diff --git a/datafusion/core/tests/dataframe.rs b/datafusion/core/tests/dataframe.rs
index ede74b2272ce..453b8f5cb76b 100644
--- a/datafusion/core/tests/dataframe.rs
+++ b/datafusion/core/tests/dataframe.rs
@@ -40,26 +40,26 @@ async fn describe() -> Result<()> {
     let ctx = SessionContext::new();
     let testdata = datafusion::test_util::parquet_test_data();
 
-    let filename = &format!("{testdata}/alltypes_plain.parquet");
-
     let df = ctx
-        .read_parquet(filename, ParquetReadOptions::default())
+        .read_parquet(
+            &format!("{testdata}/alltypes_tiny_pages.parquet"),
+            ParquetReadOptions::default(),
+        )
         .await?;
-
     let describe_record_batch = df.describe().await.unwrap().collect().await.unwrap();
     #[rustfmt::skip]
     let expected = vec![
-        "+------------+--------------------+----------+--------------------+--------------------+--------------------+--------------------+--------------------+-------------------+-----------------+------------+---------------------+",
-        "| describe   | id                 | bool_col | tinyint_col        | smallint_col       | int_col            | bigint_col         | float_col          | double_col        | date_string_col | string_col | timestamp_col       |",
-        "+------------+--------------------+----------+--------------------+--------------------+--------------------+--------------------+--------------------+-------------------+-----------------+------------+---------------------+",
-        "| count      | 8.0                | 8        | 8.0                | 8.0                | 8.0                | 8.0                | 8.0                | 8.0               | 8               | 8          | 8                   |",
-        "| null_count | 8.0                | 8        | 8.0                | 8.0                | 8.0                | 8.0                | 8.0                | 8.0               | 8               | 8          | 8                   |",
-        "| mean       | 3.5                | null     | 0.5                | 0.5                | 0.5                | 5.0                | 0.550000011920929  | 5.05              | null            | null       | null                |",
-        "| std        | 2.4494897427831783 | null     | 0.5345224838248488 | 0.5345224838248488 | 0.5345224838248488 | 5.3452248382484875 | 0.5879747449513427 | 5.398677086630973 | null            | null       | null                |",
-        "| min        | 0.0                | null     | 0.0                | 0.0                | 0.0                | 0.0                | 0.0                | 0.0               | null            | null       | 2009-01-01T00:00:00 |",
-        "| max        | 7.0                | null     | 1.0                | 1.0                | 1.0                | 10.0               | 1.100000023841858  | 10.1              | null            | null       | 2009-04-01T00:01:00 |",
-        "| median     | 3.0                | null     | 0.0                | 0.0                | 0.0                | 5.0                | 0.550000011920929  | 5.05              | null            | null       | null                |",
-        "+------------+--------------------+----------+--------------------+--------------------+--------------------+--------------------+--------------------+-------------------+-----------------+------------+---------------------+",
+        "+------------+-------------------+----------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+-----------------+------------+-------------------------+--------------------+-------------------+",
+        "| describe   | id                | bool_col | tinyint_col        | smallint_col       | int_col            | bigint_col         | float_col          | double_col         | date_string_col | string_col | timestamp_col           | year               | month             |",
+        "+------------+-------------------+----------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+-----------------+------------+-------------------------+--------------------+-------------------+",
+        "| count      | 7300.0            | 7300     | 7300.0             | 7300.0             | 7300.0             | 7300.0             | 7300.0             | 7300.0             | 7300            | 7300       | 7300                    | 7300.0             | 7300.0            |",
+        "| null_count | 7300.0            | 7300     | 7300.0             | 7300.0             | 7300.0             | 7300.0             | 7300.0             | 7300.0             | 7300            | 7300       | 7300                    | 7300.0             | 7300.0            |",
+        "| mean       | 3649.5            | null     | 4.5                | 4.5                | 4.5                | 45.0               | 4.949999964237213  | 45.45000000000001  | null            | null       | null                    | 2009.5             | 6.526027397260274 |",
+        "| std        | 2107.472815166704 | null     | 2.8724780750809518 | 2.8724780750809518 | 2.8724780750809518 | 28.724780750809533 | 3.1597258182544645 | 29.012028558317645 | null            | null       | null                    | 0.5000342500942125 | 3.44808750051728  |",
+        "| min        | 0.0               | null     | 0.0                | 0.0                | 0.0                | 0.0                | 0.0                | 0.0                | 01/01/09        | 0          | 2008-12-31T23:00:00     | 2009.0             | 1.0               |",
+        "| max        | 7299.0            | null     | 9.0                | 9.0                | 9.0                | 90.0               | 9.899999618530273  | 90.89999999999999  | 12/31/10        | 9          | 2010-12-31T04:09:13.860 | 2010.0             | 12.0              |",
+        "| median     | 3649.0            | null     | 4.0                | 4.0                | 4.0                | 45.0               | 4.949999809265137  | 45.45              | null            | null       | null                    | 2009.0             | 7.0               |",
+        "+------------+-------------------+----------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+-----------------+------------+-------------------------+--------------------+-------------------+",
     ];
 
     assert_batches_eq!(expected, &describe_record_batch);
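
Note on the workaround in the dataframe.rs hunks above: the AggregateStatistics rule rewrites count aggregates at the physical-plan level and drops .alias(...) on them (https://github.com/apache/arrow-datafusion/issues/5444), so the patch counts without aliases first and restores the column names in a separate select. The following is a minimal standalone sketch of that pattern, not part of the patch; the "data.csv" input, the tokio runtime, and all variable names are assumptions for illustration.

use datafusion::error::Result;
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    // Placeholder input; any table registered with the context works.
    let df = ctx.read_csv("data.csv", CsvReadOptions::new()).await?;

    // Capture the original column names up front.
    let names: Vec<String> = df
        .schema()
        .fields()
        .iter()
        .map(|f| f.name().clone())
        .collect();

    // Step 1: bare count() per column, deliberately without .alias(),
    // so the AggregateStatistics rewrite sees plain count expressions.
    let counts = df.aggregate(
        vec![],
        names.iter().map(|n| count(col(n.as_str()))).collect(),
    )?;

    // Step 2: restore the original column names in a separate projection,
    // which the physical-plan rewrite leaves untouched.
    let counts = counts.clone().select(
        counts
            .schema()
            .fields()
            .iter()
            .zip(names.iter())
            .map(|(count_field, name)| col(count_field.name()).alias(name.as_str()))
            .collect(),
    )?;

    counts.show().await?;
    Ok(())
}

Once #5444 is fixed upstream, step 2 becomes unnecessary and the alias can move back into the aggregate call itself, which is exactly the shape of the code the patch removes.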