From 40bc21aafb91a11e2ac3b75c83bb6480e47ab26d Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 15 Jul 2022 11:10:28 -0600 Subject: [PATCH 1/4] Add DataFrame::with_column_renamed --- datafusion/core/src/dataframe.rs | 113 +++++++++++++++++++++++++++++++ 1 file changed, 113 insertions(+) diff --git a/datafusion/core/src/dataframe.rs b/datafusion/core/src/dataframe.rs index 32c9299e9677..334841a3700d 100644 --- a/datafusion/core/src/dataframe.rs +++ b/datafusion/core/src/dataframe.rs @@ -669,6 +669,53 @@ impl DataFrame { &project_plan, ))) } + + /// Rename one column by applying a new projection. This is a no-op if the column to be + /// renamed does not exist. + /// + /// ``` + /// # use datafusion::prelude::*; + /// # use datafusion::error::Result; + /// # #[tokio::main] + /// # async fn main() -> Result<()> { + /// let ctx = SessionContext::new(); + /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; + /// let df = df.with_column_renamed("ab_sum", "total")?; + /// # Ok(()) + /// # } + /// ``` + pub fn with_column_renamed( + &self, + old_name: &str, + new_name: &str, + ) -> Result> { + let mut projection = vec![]; + let mut rename_applied = false; + for field in self.plan.schema().fields() { + let field_name = field.qualified_name(); + println!("{:?}", field_name); + if old_name == field_name { + projection.push(col(&field_name).alias(new_name)); + rename_applied = true; + } else { + projection.push(col(&field_name)); + } + } + if rename_applied { + let project_plan = LogicalPlanBuilder::from(self.plan.clone()) + .project(projection)? + .build()?; + Ok(Arc::new(DataFrame::new( + self.session_state.clone(), + &project_plan, + ))) + } else { + Ok(Arc::new(DataFrame::new( + self.session_state.clone(), + &self.plan, + ))) + } + } } // TODO: This will introduce a ref cycle (#2659) @@ -1130,4 +1177,70 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn with_column_renamed() -> Result<()> { + let df = test_table() + .await? + .select_columns(&["c1", "c2", "c3"])? + .filter(col("c2").eq(lit(3)).and(col("c1").eq(lit("a"))))? + .limit(None, Some(1))? + .sort(vec![col("c1").sort(true, true)])? + .with_column("sum", col("c2") + col("c3"))?; + + let df_sum_renamed = df.with_column_renamed("sum", "total")?.collect().await?; + + assert_batches_sorted_eq!( + vec![ + "+----+----+----+-------+", + "| c1 | c2 | c3 | total |", + "+----+----+----+-------+", + "| a | 3 | 13 | 16 |", + "+----+----+----+-------+", + ], + &df_sum_renamed + ); + + Ok(()) + } + + #[tokio::test] + async fn with_column_renamed_join() -> Result<()> { + let df = test_table().await?.select_columns(&["c1", "c2", "c3"])?; + let ctx = SessionContext::new(); + ctx.register_table("t1", df.clone())?; + ctx.register_table("t2", df)?; + let df = ctx + .table("t1")? + .join(ctx.table("t2")?, JoinType::Inner, &["c1"], &["c1"], None)? + .sort(vec![col("c1").sort(true, true)])? + .limit(None, Some(1))?; + + let df_results = &df.collect().await?; + assert_batches_sorted_eq!( + vec![ + "+----+----+-----+----+----+-----+", + "| c1 | c2 | c3 | c1 | c2 | c3 |", + "+----+----+-----+----+----+-----+", + "| a | 1 | -85 | a | 1 | -85 |", + "+----+----+-----+----+----+-----+", + ], + &df_results + ); + + let df_results = &df.with_column_renamed("t1.c1", "AAA")?.collect().await?; + + assert_batches_sorted_eq!( + vec![ + "+-----+----+-----+----+----+-----+", + "| AAA | c2 | c3 | c1 | c2 | c3 |", + "+-----+----+-----+----+----+-----+", + "| a | 1 | -85 | a | 1 | -85 |", + "+-----+----+-----+----+----+-----+", + ], + &df_results + ); + + Ok(()) + } } From 9926e7d9a10e5a39757e9e294df30022fea5fdae Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 15 Jul 2022 11:15:04 -0600 Subject: [PATCH 2/4] show plan in test --- datafusion/core/src/dataframe.rs | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/datafusion/core/src/dataframe.rs b/datafusion/core/src/dataframe.rs index 334841a3700d..50e40c009b43 100644 --- a/datafusion/core/src/dataframe.rs +++ b/datafusion/core/src/dataframe.rs @@ -693,7 +693,6 @@ impl DataFrame { let mut rename_applied = false; for field in self.plan.schema().fields() { let field_name = field.qualified_name(); - println!("{:?}", field_name); if old_name == field_name { projection.push(col(&field_name).alias(new_name)); rename_applied = true; @@ -1228,7 +1227,20 @@ mod tests { &df_results ); - let df_results = &df.with_column_renamed("t1.c1", "AAA")?.collect().await?; + let df_renamed = df.with_column_renamed("t1.c1", "AAA")?; + assert_eq!( + "\ + Projection: #t1.c1 AS AAA, #t1.c2, #t1.c3, #t2.c1, #t2.c2, #t2.c3\ + \n Limit: skip=None, fetch=1\ + \n Projection: #t1.c1, #t1.c2, #t1.c3, #t2.c1, #t2.c2, #t2.c3\ + \n Sort: #t1.c1 ASC NULLS FIRST\ + \n Inner Join: #t1.c1 = #t2.c1\ + \n TableScan: t1 projection=[c1, c2, c3]\ + \n TableScan: t2 projection=[c1, c2, c3]", + format!("{:?}", df_renamed.to_logical_plan()?) + ); + + let df_results = &df_renamed.collect().await?; assert_batches_sorted_eq!( vec![ From a49210840257ff10b7cd90bf966ae71c042fb535 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 15 Jul 2022 13:22:14 -0600 Subject: [PATCH 3/4] make the tests deterministic --- datafusion/core/src/dataframe.rs | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/datafusion/core/src/dataframe.rs b/datafusion/core/src/dataframe.rs index 50e40c009b43..1c0e679b186f 100644 --- a/datafusion/core/src/dataframe.rs +++ b/datafusion/core/src/dataframe.rs @@ -1184,7 +1184,12 @@ mod tests { .select_columns(&["c1", "c2", "c3"])? .filter(col("c2").eq(lit(3)).and(col("c1").eq(lit("a"))))? .limit(None, Some(1))? - .sort(vec![col("c1").sort(true, true)])? + .sort(vec![ + // make the test deterministic + col("c1").sort(true, true), + col("c2").sort(true, true), + col("c3").sort(true, true), + ])? .with_column("sum", col("c2") + col("c3"))?; let df_sum_renamed = df.with_column_renamed("sum", "total")?.collect().await?; @@ -1212,7 +1217,15 @@ mod tests { let df = ctx .table("t1")? .join(ctx.table("t2")?, JoinType::Inner, &["c1"], &["c1"], None)? - .sort(vec![col("c1").sort(true, true)])? + .sort(vec![ + // make the test deterministic + col("t1.c1").sort(true, true), + col("t1.c2").sort(true, true), + col("t1.c3").sort(true, true), + col("t2.c1").sort(true, true), + col("t2.c2").sort(true, true), + col("t2.c3").sort(true, true), + ])? .limit(None, Some(1))?; let df_results = &df.collect().await?; @@ -1228,15 +1241,13 @@ mod tests { ); let df_renamed = df.with_column_renamed("t1.c1", "AAA")?; - assert_eq!( - "\ + assert_eq!("\ Projection: #t1.c1 AS AAA, #t1.c2, #t1.c3, #t2.c1, #t2.c2, #t2.c3\ \n Limit: skip=None, fetch=1\ - \n Projection: #t1.c1, #t1.c2, #t1.c3, #t2.c1, #t2.c2, #t2.c3\ - \n Sort: #t1.c1 ASC NULLS FIRST\ - \n Inner Join: #t1.c1 = #t2.c1\ - \n TableScan: t1 projection=[c1, c2, c3]\ - \n TableScan: t2 projection=[c1, c2, c3]", + \n Sort: #t1.c1 ASC NULLS FIRST, #t1.c2 ASC NULLS FIRST, #t1.c3 ASC NULLS FIRST, #t2.c1 ASC NULLS FIRST, #t2.c2 ASC NULLS FIRST, #t2.c3 ASC NULLS FIRST\ + \n Inner Join: #t1.c1 = #t2.c1\ + \n TableScan: t1 projection=[c1, c2, c3]\ + \n TableScan: t2 projection=[c1, c2, c3]", format!("{:?}", df_renamed.to_logical_plan()?) ); From 5577ab8f725b497442e26bed73d078e1aab0b1a9 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 15 Jul 2022 14:10:13 -0600 Subject: [PATCH 4/4] clippy --- datafusion/core/src/dataframe.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/core/src/dataframe.rs b/datafusion/core/src/dataframe.rs index 1c0e679b186f..2c99f5be14b6 100644 --- a/datafusion/core/src/dataframe.rs +++ b/datafusion/core/src/dataframe.rs @@ -1228,7 +1228,7 @@ mod tests { ])? .limit(None, Some(1))?; - let df_results = &df.collect().await?; + let df_results = df.collect().await?; assert_batches_sorted_eq!( vec![ "+----+----+-----+----+----+-----+", @@ -1251,7 +1251,7 @@ mod tests { format!("{:?}", df_renamed.to_logical_plan()?) ); - let df_results = &df_renamed.collect().await?; + let df_results = df_renamed.collect().await?; assert_batches_sorted_eq!( vec![