From 3b9c1848fe87aa74b2dfebf178b93356c0ba1f2c Mon Sep 17 00:00:00 2001 From: Nagato Yuzuru Date: Tue, 2 Jun 2026 01:20:01 +0800 Subject: [PATCH 1/2] feat: add DataFrame fill_nan --- datafusion/core/src/dataframe/mod.rs | 73 ++++++++++++++++++++++++ datafusion/core/tests/dataframe/mod.rs | 77 ++++++++++++++++++++++++++ 2 files changed, 150 insertions(+) diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 3d6b832aa6b27..2972501466f09 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -65,6 +65,7 @@ use datafusion_expr::{ utils::COUNT_STAR_EXPANSION, }; use datafusion_functions::core::coalesce; +use datafusion_functions::math::nanvl; use datafusion_functions_aggregate::expr_fn::{ avg, count, max, median, min, stddev, sum, }; @@ -2527,6 +2528,78 @@ impl DataFrame { .collect() } + /// Fill NaN values in the specified floating-point columns with a given value. + /// If no columns are specified (empty vector), applies to all floating-point columns; non-float columns are never modified. + /// A column is left unchanged when the value cannot be cast to the column's type. + /// + /// # Arguments + /// * `value` - Value to fill NaNs with + /// * `columns` - List of column names to fill. If empty, fills all floating-point columns. 
+ /// /// # Example + /// ``` + /// # use datafusion::prelude::*; + /// # use datafusion::error::Result; + /// # use datafusion_common::ScalarValue; + /// # #[tokio::main] + /// # async fn main() -> Result<()> { + /// let ctx = SessionContext::new(); + /// let df = ctx + /// .read_csv("tests/data/example.csv", CsvReadOptions::new()) + /// .await?; + /// // Fill NaN in only columns "a" and "c": + /// let df = df.fill_nan(ScalarValue::from(0.0), vec!["a".to_owned(), "c".to_owned()])?; + /// // Fill NaN across all floating-point columns: + /// let df = df.fill_nan(ScalarValue::from(0.0), vec![])?; + /// # Ok(()) + /// # } + /// ``` + #[expect(clippy::needless_pass_by_value)] + pub fn fill_nan( + &self, + value: ScalarValue, + columns: Vec<String>, + ) -> Result<DataFrame> { + let cols = if columns.is_empty() { + self.logical_plan() + .schema() + .fields() + .iter() + .map(Arc::clone) + .collect() + } else { + self.find_columns(&columns)? + }; + + let projections = self + .logical_plan() + .schema() + .fields() + .iter() + .map(|field| { + if cols.contains(field) && field.data_type().is_floating() { + // Try to cast fill value to column type. If the cast fails, fall back to the original column. 
+ match value.clone().cast_to(field.data_type()) { + Ok(fill_value) => Expr::Alias(Alias { + expr: Box::new(Expr::ScalarFunction(ScalarFunction { + func: nanvl(), + args: vec![col(field.name()), lit(fill_value)], + })), + relation: None, + name: field.name().to_string(), + metadata: None, + }), + Err(_) => col(field.name()), + } + } else { + col(field.name()) + } + }) + .collect::<Vec<_>>(); + + self.clone().select(projections) + } + /// Find qualified columns for this dataframe from names /// /// # Arguments diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index bc1ad4c4c6bb1..3198ae37956f7 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -6539,6 +6539,83 @@ async fn test_fill_null_all_columns() -> Result<()> { Ok(()) } +async fn create_nan_table() -> Result<DataFrame> { + // create a DataFrame with a NaN value in a float column "a" and a + // non-float column "b" that must stay untouched by fill_nan. + // "+-----+---+", + // "| a | b |", + // "+-----+---+", + // "| 1.0 | 1 |", + // "| NaN | 2 |", + // "| 3.0 | 3 |", + // "+-----+---+", + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Float64, true), + Field::new("b", DataType::Int32, true), + ])); + let a_values = Float64Array::from(vec![Some(1.0), Some(f64::NAN), Some(3.0)]); + let b_values = Int32Array::from(vec![Some(1), Some(2), Some(3)]); + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(a_values), Arc::new(b_values)], + )?; + + let ctx = SessionContext::new(); + let table = MemTable::try_new(schema.clone(), vec![vec![batch]])?; + ctx.register_table("t_nan", Arc::new(table))?; + let df = ctx.table("t_nan").await?; + Ok(df) +} + +#[tokio::test] +async fn test_fill_nan() -> Result<()> { + let df = create_nan_table().await?; + + // Fill NaNs in the float column "a" with 0.0. 
+ let df_filled = + df.fill_nan(ScalarValue::Float64(Some(0.0)), vec!["a".to_string()])?; + + let results = df_filled.collect().await?; + assert_snapshot!( + batches_to_sort_string(&results), + @r" + +-----+---+ + | a | b | + +-----+---+ + | 0.0 | 2 | + | 1.0 | 1 | + | 3.0 | 3 | + +-----+---+ + " + ); + + Ok(()) +} + +#[tokio::test] +async fn test_fill_nan_all_columns() -> Result<()> { + let df = create_nan_table().await?; + + // Fill NaNs across all columns. Only the float column "a" is affected; + // the non-float column "b" is left unchanged since NaN only exists for + // floating-point types. + let df_filled = df.fill_nan(ScalarValue::Float64(Some(0.0)), vec![])?; + + let results = df_filled.collect().await?; + assert_snapshot!( + batches_to_sort_string(&results), + @r" + +-----+---+ + | a | b | + +-----+---+ + | 0.0 | 2 | + | 1.0 | 1 | + | 3.0 | 3 | + +-----+---+ + " + ); + Ok(()) +} #[tokio::test] async fn test_insert_into_casting_support() -> Result<()> { // Testing case1: From be9e1d4eee8f4bff9d4b4f450ce3bee31c6bd6d8 Mon Sep 17 00:00:00 2001 From: Nagato Yuzuru Date: Tue, 2 Jun 2026 01:35:30 +0800 Subject: [PATCH 2/2] test: add tests for fill_nan with non-float and unknown columns --- datafusion/core/tests/dataframe/mod.rs | 67 ++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 3198ae37956f7..1d443c0b65910 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -6616,6 +6616,73 @@ async fn test_fill_nan_all_columns() -> Result<()> { ); Ok(()) } + +#[tokio::test] +async fn test_fill_nan_non_float_column() -> Result<()> { + let df = create_nan_table().await?; + + // Explicitly naming a non-float column is a no-op, not an error: NaN does + // not exist for Int32, so column "b" (and the un-targeted "a") are unchanged. 
+ let df_filled = + df.fill_nan(ScalarValue::Float64(Some(0.0)), vec!["b".to_string()])?; + + let results = df_filled.collect().await?; + assert_snapshot!( + batches_to_sort_string(&results), + @r" + +-----+---+ + | a | b | + +-----+---+ + | 1.0 | 1 | + | 3.0 | 3 | + | NaN | 2 | + +-----+---+ + " + ); + + Ok(()) +} + +#[tokio::test] +async fn test_fill_nan_unknown_column() -> Result<()> { + let df = create_nan_table().await?; + + // A column name that is not in the schema is propagated as an error. + let err = df + .fill_nan(ScalarValue::Float64(Some(0.0)), vec!["does_not_exist".to_string()]) + .unwrap_err(); + + assert_snapshot!(err.to_string(), @"Error during planning: Column 'does_not_exist' not found"); + + Ok(()) +} + +#[tokio::test] +async fn test_fill_nan_uncastable_value() -> Result<()> { + let df = create_nan_table().await?; + + // The float column "a" is targeted, but "abc" cannot be cast to Float64, so + // the fill is skipped and column "a" keeps its original NaN value. + let df_filled = + df.fill_nan(ScalarValue::Utf8(Some("abc".to_string())), vec!["a".to_string()])?; + + let results = df_filled.collect().await?; + assert_snapshot!( + batches_to_sort_string(&results), + @r" + +-----+---+ + | a | b | + +-----+---+ + | 1.0 | 1 | + | 3.0 | 3 | + | NaN | 2 | + +-----+---+ + " + ); + + Ok(()) +} + #[tokio::test] async fn test_insert_into_casting_support() -> Result<()> { // Testing case1: