diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs
index 76acb5c1da34..02c2c81ad54b 100644
--- a/datafusion/core/src/dataframe/mod.rs
+++ b/datafusion/core/src/dataframe/mod.rs
@@ -43,7 +43,7 @@ use crate::physical_plan::{
 use crate::prelude::SessionContext;
 use std::any::Any;
 use std::borrow::Cow;
-use std::collections::HashMap;
+use std::collections::{HashMap, HashSet};
 use std::sync::Arc;
 
 use arrow::array::{Array, ArrayRef, Int64Array, StringArray};
@@ -2023,31 +2023,38 @@ impl DataFrame {
     pub fn with_column(self, name: &str, expr: Expr) -> Result<DataFrame> {
         let window_func_exprs = find_window_exprs([&expr]);
 
-        let (window_fn_str, plan) = if window_func_exprs.is_empty() {
-            (None, self.plan)
+        let original_names: HashSet<String> = self
+            .plan
+            .schema()
+            .iter()
+            .map(|(_, f)| f.name().clone())
+            .collect();
+
+        // Maybe build window plan
+        let plan = if window_func_exprs.is_empty() {
+            self.plan
         } else {
-            (
-                Some(window_func_exprs[0].to_string()),
-                LogicalPlanBuilder::window_plan(self.plan, window_func_exprs)?,
-            )
+            LogicalPlanBuilder::window_plan(self.plan, window_func_exprs)?
         };
 
-        let mut col_exists = false;
         let new_column = expr.alias(name);
+        let mut col_exists = false;
+
         let mut fields: Vec<(Expr, bool)> = plan
             .schema()
             .iter()
             .filter_map(|(qualifier, field)| {
+                // Skip new fields introduced by window_plan
+                if !original_names.contains(field.name()) {
+                    return None;
+                }
+
                 if field.name() == name {
                     col_exists = true;
                     Some((new_column.clone(), true))
                 } else {
                     let e = col(Column::from((qualifier, field)));
-                    window_fn_str
-                        .as_ref()
-                        .filter(|s| *s == &e.to_string())
-                        .is_none()
-                        .then_some((e, self.projection_requires_validation))
+                    Some((e, self.projection_requires_validation))
                 }
             })
             .collect();
diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs
index c37ad52bbbbe..2a15b2967dca 100644
--- a/datafusion/core/tests/dataframe/mod.rs
+++ b/datafusion/core/tests/dataframe/mod.rs
@@ -39,7 +39,7 @@ use datafusion_functions_aggregate::expr_fn::{
     sum_distinct,
 };
 use datafusion_functions_nested::make_array::make_array_udf;
-use datafusion_functions_window::expr_fn::{first_value, row_number};
+use datafusion_functions_window::expr_fn::{first_value, lead, row_number};
 use insta::assert_snapshot;
 use object_store::local::LocalFileSystem;
 use std::collections::HashMap;
@@ -92,6 +92,9 @@ use datafusion_physical_plan::aggregates::{
 use datafusion_physical_plan::empty::EmptyExec;
 use datafusion_physical_plan::{displayable, ExecutionPlan, ExecutionPlanProperties};
 
+use datafusion::error::Result as DataFusionResult;
+use datafusion_functions_window::expr_fn::lag;
+
 // Get string representation of the plan
 async fn physical_plan_to_string(df: &DataFrame) -> String {
     let physical_plan = df
@@ -158,6 +161,46 @@ async fn test_array_agg_ord_schema() -> Result<()> {
     Ok(())
 }
 
+type WindowFnCase = (fn() -> Expr, &'static str);
+
+#[tokio::test]
+async fn with_column_window_functions() -> DataFusionResult<()> {
+    let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
+
+    let batch = RecordBatch::try_new(
+        Arc::new(schema.clone()),
+        vec![Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]))],
+    )?;
+
+    let ctx = SessionContext::new();
+
+    let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch]])?;
+    ctx.register_table("t", Arc::new(provider))?;
+
+    // Define test cases: (expr builder, alias name)
+    let test_cases: Vec<WindowFnCase> = vec![
+        (|| lag(col("a"), Some(1), None), "lag_val"),
+        (|| lead(col("a"), Some(1), None), "lead_val"),
+        (row_number, "row_num"),
+    ];
+
+    for (make_expr, alias) in test_cases {
+        let df = ctx.table("t").await?;
+        let expr = make_expr();
+        let df_with = df.with_column(alias, expr)?;
+        let df_schema = df_with.schema().clone();
+
+        assert!(
+            df_schema.has_column_with_unqualified_name(alias),
+            "Schema does not contain expected column {alias}",
+        );
+
+        assert_eq!(2, df_schema.columns().len());
+    }
+
+    Ok(())
+}
+
 #[tokio::test]
 async fn test_coalesce_schema() -> Result<()> {
     let ctx = SessionContext::new();