diff --git a/datafusion/core/src/dataframe.rs b/datafusion/core/src/dataframe.rs index d7ce056a339e..380f722a42ba 100644 --- a/datafusion/core/src/dataframe.rs +++ b/datafusion/core/src/dataframe.rs @@ -24,7 +24,7 @@ use arrow::array::{Array, ArrayRef, Int64Array, StringArray}; use arrow::compute::{cast, concat}; use arrow::datatypes::{DataType, Field}; use async_trait::async_trait; -use datafusion_common::DataFusionError; +use datafusion_common::{DataFusionError, SchemaError}; use parquet::file::properties::WriterProperties; use datafusion_common::from_slice::FromSlice; @@ -1007,28 +1007,36 @@ impl DataFrame { /// ``` pub fn with_column_renamed( self, - old_name: &str, + old_name: impl Into, new_name: &str, ) -> Result { - let mut projection = vec![]; - let mut rename_applied = false; - for field in self.plan.schema().fields() { - let field_name = field.qualified_name(); - if old_name == field_name { - projection.push(col(&field_name).alias(new_name)); - rename_applied = true; - } else { - projection.push(col(&field_name)); + let old_name: Column = old_name.into(); + + let field_to_rename = match self.plan.schema().field_from_column(&old_name) { + Ok(field) => field, + // no-op if field not found + Err(DataFusionError::SchemaError(SchemaError::FieldNotFound { .. })) => { + return Ok(self) } - } - if rename_applied { - let project_plan = LogicalPlanBuilder::from(self.plan) - .project(projection)? - .build()?; - Ok(DataFrame::new(self.session_state, project_plan)) - } else { - Ok(DataFrame::new(self.session_state, self.plan)) - } + Err(err) => return Err(err), + }; + let projection = self + .plan + .schema() + .fields() + .iter() + .map(|f| { + if f == field_to_rename { + col(f.qualified_column()).alias(new_name) + } else { + col(f.qualified_column()) + } + }) + .collect::>(); + let project_plan = LogicalPlanBuilder::from(self.plan) + .project(projection)? + .build()?; + Ok(DataFrame::new(self.session_state, project_plan)) } /// Convert a prepare logical plan into its inner logical plan with all params replaced with their corresponding values @@ -1681,15 +1689,24 @@ mod tests { ])? .with_column("sum", col("c2") + col("c3"))?; - let df_sum_renamed = df.with_column_renamed("sum", "total")?.collect().await?; + let df_sum_renamed = df + .with_column_renamed("sum", "total")? + // table qualifier optional + .with_column_renamed("c1", "one")? + // accepts table qualifier + .with_column_renamed("aggregate_test_100.c2", "two")? + // no-op for missing column + .with_column_renamed("c4", "boom")? + .collect() + .await?; assert_batches_sorted_eq!( vec![ - "+----+----+----+-------+", - "| c1 | c2 | c3 | total |", - "+----+----+----+-------+", - "| a | 3 | 13 | 16 |", - "+----+----+----+-------+", + "+-----+-----+----+-------+", + "| one | two | c3 | total |", + "+-----+-----+----+-------+", + "| a | 3 | 13 | 16 |", + "+-----+-----+----+-------+", ], &df_sum_renamed ); @@ -1697,6 +1714,34 @@ mod tests { Ok(()) } + #[tokio::test] + async fn with_column_renamed_ambiguous() -> Result<()> { + let df = test_table().await?.select_columns(&["c1", "c2", "c3"])?; + let ctx = SessionContext::new(); + + let table = df.into_view(); + ctx.register_table("t1", table.clone())?; + ctx.register_table("t2", table)?; + + let actual_err = ctx + .table("t1") + .await? + .join( + ctx.table("t2").await?, + JoinType::Inner, + &["c1"], + &["c1"], + None, + )? + // can be t1.c2 or t2.c2 + .with_column_renamed("c2", "AAA") + .unwrap_err(); + let expected_err = "Schema error: Ambiguous reference to unqualified field c2"; + assert_eq!(actual_err.to_string(), expected_err); + + Ok(()) + } + #[tokio::test] async fn with_column_renamed_join() -> Result<()> { let df = test_table().await?.select_columns(&["c1", "c2", "c3"])?;