Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ use datafusion_expr::{
utils::COUNT_STAR_EXPANSION,
};
use datafusion_functions::core::coalesce;
use datafusion_functions::math::nanvl;
use datafusion_functions_aggregate::expr_fn::{
avg, count, max, median, min, stddev, sum,
};
Expand Down Expand Up @@ -2527,6 +2528,78 @@ impl DataFrame {
.collect()
}

/// Fill NaN values in specified columns with a given value
/// If no columns are specified (empty vector), applies to all columns
/// Only fills if the value can be cast to the column's type
///
/// # Arguments
/// * `value` - Value to fill NaNs with
/// * `columns` - List of column names to fill. If empty, fills all columns.
///
/// # Example
/// ```
/// # use datafusion::prelude::*;
/// # use datafusion::error::Result;
/// # use datafusion_common::ScalarValue;
/// # #[tokio::main]
/// # async fn main() -> Result<()> {
/// let ctx = SessionContext::new();
/// let df = ctx
/// .read_csv("tests/data/example.csv", CsvReadOptions::new())
/// .await?;
/// // Fill NaN in only columns "a" and "c":
/// let df = df.fill_nan(ScalarValue::from(0.0), vec!["a".to_owned(), "c".to_owned()])?;
/// // Fill NaN across all columns:
/// let df = df.fill_nan(ScalarValue::from(0.0), vec![])?;
/// # Ok(())
/// # }
/// ```
#[expect(clippy::needless_pass_by_value)]
pub fn fill_nan(
&self,
value: ScalarValue,
columns: Vec<String>,
) -> Result<DataFrame> {
let cols = if columns.is_empty() {
self.logical_plan()
.schema()
.fields()
.iter()
.map(Arc::clone)
.collect()
} else {
self.find_columns(&columns)?
};

let projections = self
.logical_plan()
.schema()
.fields()
.iter()
.map(|field| {
if cols.contains(field) && field.data_type().is_floating() {
// Try to cast fill value to column type. If the cast fails, fallback to the original column.
match value.clone().cast_to(field.data_type()) {
Ok(fill_value) => Expr::Alias(Alias {
expr: Box::new(Expr::ScalarFunction(ScalarFunction {
func: nanvl(),
args: vec![col(field.name()), lit(fill_value)],
})),
relation: None,
name: field.name().to_string(),
metadata: None,
}),
Err(_) => col(field.name()),
}
Comment on lines +2582 to +2593
} else {
col(field.name())
}
})
.collect::<Vec<_>>();

self.clone().select(projections)
}

/// Find qualified columns for this dataframe from names
///
/// # Arguments
Expand Down
144 changes: 144 additions & 0 deletions datafusion/core/tests/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6539,6 +6539,150 @@ async fn test_fill_null_all_columns() -> Result<()> {
Ok(())
}

async fn create_nan_table() -> Result<DataFrame> {
Comment thread
Nagato-Yuzuru marked this conversation as resolved.
// create a DataFrame with a NaN value in a float column "a" and a
// non-float column "b" that must stay untouched by fill_nan.
// "+-----+---+",
// "| a | b |",
// "+-----+---+",
// "| 1.0 | 1 |",
// "| NaN | 2 |",
// "| 3.0 | 3 |",
// "+-----+---+",
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Float64, true),
Field::new("b", DataType::Int32, true),
]));
let a_values = Float64Array::from(vec![Some(1.0), Some(f64::NAN), Some(3.0)]);
let b_values = Int32Array::from(vec![Some(1), Some(2), Some(3)]);
let batch = RecordBatch::try_new(
schema.clone(),
vec![Arc::new(a_values), Arc::new(b_values)],
)?;

let ctx = SessionContext::new();
let table = MemTable::try_new(schema.clone(), vec![vec![batch]])?;
ctx.register_table("t_nan", Arc::new(table))?;
let df = ctx.table("t_nan").await?;
Ok(df)
}

#[tokio::test]
async fn test_fill_nan() -> Result<()> {
let df = create_nan_table().await?;

// Fill NaNs in the float column "a" with 0.0.
let df_filled =
df.fill_nan(ScalarValue::Float64(Some(0.0)), vec!["a".to_string()])?;

let results = df_filled.collect().await?;
assert_snapshot!(
batches_to_sort_string(&results),
@r"
+-----+---+
| a | b |
+-----+---+
| 0.0 | 2 |
| 1.0 | 1 |
| 3.0 | 3 |
+-----+---+
"
);

Ok(())
}

#[tokio::test]
async fn test_fill_nan_all_columns() -> Result<()> {
let df = create_nan_table().await?;

// Fill NaNs across all columns. Only the float column "a" is affected;
// the non-float column "b" is left unchanged since NaN only exists for
// floating-point types.
let df_filled = df.fill_nan(ScalarValue::Float64(Some(0.0)), vec![])?;

let results = df_filled.collect().await?;
assert_snapshot!(
batches_to_sort_string(&results),
@r"
+-----+---+
| a | b |
+-----+---+
| 0.0 | 2 |
| 1.0 | 1 |
| 3.0 | 3 |
+-----+---+
"
);
Ok(())
}

#[tokio::test]
async fn test_fill_nan_non_float_column() -> Result<()> {
let df = create_nan_table().await?;

// Explicitly naming a non-float column is a no-op, not an error: NaN does
// not exist for Int32, so column "b" (and the un-targeted "a") are unchanged.
let df_filled =
df.fill_nan(ScalarValue::Float64(Some(0.0)), vec!["b".to_string()])?;

let results = df_filled.collect().await?;
assert_snapshot!(
batches_to_sort_string(&results),
@r"
+-----+---+
| a | b |
+-----+---+
| 1.0 | 1 |
| 3.0 | 3 |
| NaN | 2 |
+-----+---+
"
);

Ok(())
}

#[tokio::test]
async fn test_fill_nan_unknown_column() -> Result<()> {
let df = create_nan_table().await?;

// A column name that is not in the schema is propagated as an error.
let err = df
.fill_nan(ScalarValue::Float64(Some(0.0)), vec!["does_not_exist".to_string()])
.unwrap_err();

assert_snapshot!(err.to_string(), @"Error during planning: Column 'does_not_exist' not found");

Ok(())
}

#[tokio::test]
async fn test_fill_nan_uncastable_value() -> Result<()> {
let df = create_nan_table().await?;

// The float column "a" is targeted, but "abc" cannot be cast to Float64, so
// the fill is skipped and column "a" keeps its original NaN value.
let df_filled =
df.fill_nan(ScalarValue::Utf8(Some("abc".to_string())), vec!["a".to_string()])?;

let results = df_filled.collect().await?;
assert_snapshot!(
batches_to_sort_string(&results),
@r"
+-----+---+
| a | b |
+-----+---+
| 1.0 | 1 |
| 3.0 | 3 |
| NaN | 2 |
+-----+---+
"
);

Ok(())
}

#[tokio::test]
Comment thread
Nagato-Yuzuru marked this conversation as resolved.
async fn test_insert_into_casting_support() -> Result<()> {
// Testing case1:
Expand Down