diff --git a/datafusion/core/src/dataframe.rs b/datafusion/core/src/dataframe.rs index 32c9299e9677..2c99f5be14b6 100644 --- a/datafusion/core/src/dataframe.rs +++ b/datafusion/core/src/dataframe.rs @@ -669,6 +669,52 @@ impl DataFrame { &project_plan, ))) } + + /// Rename one column by applying a new projection. This is a no-op if the column to be + /// renamed does not exist. + /// + /// ``` + /// # use datafusion::prelude::*; + /// # use datafusion::error::Result; + /// # #[tokio::main] + /// # async fn main() -> Result<()> { + /// let ctx = SessionContext::new(); + /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; + /// let df = df.with_column_renamed("ab_sum", "total")?; + /// # Ok(()) + /// # } + /// ``` + pub fn with_column_renamed( + &self, + old_name: &str, + new_name: &str, + ) -> Result> { + let mut projection = vec![]; + let mut rename_applied = false; + for field in self.plan.schema().fields() { + let field_name = field.qualified_name(); + if old_name == field_name { + projection.push(col(&field_name).alias(new_name)); + rename_applied = true; + } else { + projection.push(col(&field_name)); + } + } + if rename_applied { + let project_plan = LogicalPlanBuilder::from(self.plan.clone()) + .project(projection)? + .build()?; + Ok(Arc::new(DataFrame::new( + self.session_state.clone(), + &project_plan, + ))) + } else { + Ok(Arc::new(DataFrame::new( + self.session_state.clone(), + &self.plan, + ))) + } + } } // TODO: This will introduce a ref cycle (#2659) @@ -1130,4 +1176,94 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn with_column_renamed() -> Result<()> { + let df = test_table() + .await? + .select_columns(&["c1", "c2", "c3"])? + .filter(col("c2").eq(lit(3)).and(col("c1").eq(lit("a"))))? + .limit(None, Some(1))? + .sort(vec![ + // make the test deterministic + col("c1").sort(true, true), + col("c2").sort(true, true), + col("c3").sort(true, true), + ])? + .with_column("sum", col("c2") + col("c3"))?; + + let df_sum_renamed = df.with_column_renamed("sum", "total")?.collect().await?; + + assert_batches_sorted_eq!( + vec![ + "+----+----+----+-------+", + "| c1 | c2 | c3 | total |", + "+----+----+----+-------+", + "| a | 3 | 13 | 16 |", + "+----+----+----+-------+", + ], + &df_sum_renamed + ); + + Ok(()) + } + + #[tokio::test] + async fn with_column_renamed_join() -> Result<()> { + let df = test_table().await?.select_columns(&["c1", "c2", "c3"])?; + let ctx = SessionContext::new(); + ctx.register_table("t1", df.clone())?; + ctx.register_table("t2", df)?; + let df = ctx + .table("t1")? + .join(ctx.table("t2")?, JoinType::Inner, &["c1"], &["c1"], None)? + .sort(vec![ + // make the test deterministic + col("t1.c1").sort(true, true), + col("t1.c2").sort(true, true), + col("t1.c3").sort(true, true), + col("t2.c1").sort(true, true), + col("t2.c2").sort(true, true), + col("t2.c3").sort(true, true), + ])? + .limit(None, Some(1))?; + + let df_results = df.collect().await?; + assert_batches_sorted_eq!( + vec![ + "+----+----+-----+----+----+-----+", + "| c1 | c2 | c3 | c1 | c2 | c3 |", + "+----+----+-----+----+----+-----+", + "| a | 1 | -85 | a | 1 | -85 |", + "+----+----+-----+----+----+-----+", + ], + &df_results + ); + + let df_renamed = df.with_column_renamed("t1.c1", "AAA")?; + assert_eq!("\ + Projection: #t1.c1 AS AAA, #t1.c2, #t1.c3, #t2.c1, #t2.c2, #t2.c3\ + \n Limit: skip=None, fetch=1\ + \n Sort: #t1.c1 ASC NULLS FIRST, #t1.c2 ASC NULLS FIRST, #t1.c3 ASC NULLS FIRST, #t2.c1 ASC NULLS FIRST, #t2.c2 ASC NULLS FIRST, #t2.c3 ASC NULLS FIRST\ + \n Inner Join: #t1.c1 = #t2.c1\ + \n TableScan: t1 projection=[c1, c2, c3]\ + \n TableScan: t2 projection=[c1, c2, c3]", + format!("{:?}", df_renamed.to_logical_plan()?) + ); + + let df_results = df_renamed.collect().await?; + + assert_batches_sorted_eq!( + vec![ + "+-----+----+-----+----+----+-----+", + "| AAA | c2 | c3 | c1 | c2 | c3 |", + "+-----+----+-----+----+----+-----+", + "| a | 1 | -85 | a | 1 | -85 |", + "+-----+----+-----+----+----+-----+", + ], + &df_results + ); + + Ok(()) + } }