Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 136 additions & 0 deletions datafusion/core/src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -669,6 +669,52 @@ impl DataFrame {
&project_plan,
)))
}

/// Return a new `DataFrame` in which the column named `old_name` is exposed as `new_name`.
///
/// The rename is expressed as a projection over every field of the current schema, with
/// the matching field aliased to `new_name`. `old_name` is compared against each field's
/// qualified name. If no field matches, this is a no-op: the returned `DataFrame` wraps
/// the current plan unchanged.
///
/// ```
/// # use datafusion::prelude::*;
/// # use datafusion::error::Result;
/// # #[tokio::main]
/// # async fn main() -> Result<()> {
/// let ctx = SessionContext::new();
/// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?;
/// let df = df.with_column_renamed("ab_sum", "total")?;
/// # Ok(())
/// # }
/// ```
pub fn with_column_renamed(
    &self,
    old_name: &str,
    new_name: &str,
) -> Result<Arc<DataFrame>> {
    // Set as soon as any field's qualified name equals `old_name`.
    let mut found = false;
    let exprs: Vec<_> = self
        .plan
        .schema()
        .fields()
        .iter()
        .map(|field| {
            let qualified = field.qualified_name();
            let column = col(&qualified);
            if qualified == old_name {
                found = true;
                column.alias(new_name)
            } else {
                column
            }
        })
        .collect();

    // Nothing matched: hand back a frame over the untouched plan.
    if !found {
        return Ok(Arc::new(DataFrame::new(
            self.session_state.clone(),
            &self.plan,
        )));
    }

    let renamed_plan = LogicalPlanBuilder::from(self.plan.clone())
        .project(exprs)?
        .build()?;
    Ok(Arc::new(DataFrame::new(
        self.session_state.clone(),
        &renamed_plan,
    )))
}
}

// TODO: This will introduce a ref cycle (#2659)
Expand Down Expand Up @@ -1130,4 +1176,94 @@ mod tests {

Ok(())
}

/// Renaming a column previously added with `with_column` (`sum` -> `total`) must show up
/// in the header of the collected result.
#[tokio::test]
async fn with_column_renamed() -> Result<()> {
    let df = test_table()
        .await?
        .select_columns(&["c1", "c2", "c3"])?
        .filter(col("c2").eq(lit(3)).and(col("c1").eq(lit("a"))))?
        .limit(None, Some(1))?
        .sort(vec![
            // make the test deterministic
            col("c1").sort(true, true),
            col("c2").sort(true, true),
            col("c3").sort(true, true),
        ])?
        .with_column("sum", col("c2") + col("c3"))?;

    let df_sum_renamed = df.with_column_renamed("sum", "total")?.collect().await?;

    // Every expected line has the same width as the border: cells are left-aligned and
    // padded to the column width, otherwise the line-by-line comparison cannot match.
    assert_batches_sorted_eq!(
        vec![
            "+----+----+----+-------+",
            "| c1 | c2 | c3 | total |",
            "+----+----+----+-------+",
            "| a  | 3  | 13 | 16    |",
            "+----+----+----+-------+",
        ],
        &df_sum_renamed
    );

    Ok(())
}

/// Renaming a qualified column (`t1.c1` -> `AAA`) on top of a self-join must alias only
/// that column, both in the logical plan and in the collected output.
#[tokio::test]
async fn with_column_renamed_join() -> Result<()> {
    let df = test_table().await?.select_columns(&["c1", "c2", "c3"])?;
    let ctx = SessionContext::new();
    ctx.register_table("t1", df.clone())?;
    ctx.register_table("t2", df)?;
    let df = ctx
        .table("t1")?
        .join(ctx.table("t2")?, JoinType::Inner, &["c1"], &["c1"], None)?
        .sort(vec![
            // make the test deterministic
            col("t1.c1").sort(true, true),
            col("t1.c2").sort(true, true),
            col("t1.c3").sort(true, true),
            col("t2.c1").sort(true, true),
            col("t2.c2").sort(true, true),
            col("t2.c3").sort(true, true),
        ])?
        .limit(None, Some(1))?;

    // Baseline: before the rename both sides of the join expose `c1`.
    // Expected lines are padded so that each one is as wide as the border.
    let df_results = df.collect().await?;
    assert_batches_sorted_eq!(
        vec![
            "+----+----+-----+----+----+-----+",
            "| c1 | c2 | c3  | c1 | c2 | c3  |",
            "+----+----+-----+----+----+-----+",
            "| a  | 1  | -85 | a  | 1  | -85 |",
            "+----+----+-----+----+----+-----+",
        ],
        &df_results
    );

    let df_renamed = df.with_column_renamed("t1.c1", "AAA")?;
    // The rename adds a single projection on top of the existing plan. Each nesting
    // level of the indented plan display is two spaces deeper than its parent.
    assert_eq!(
        "\
        Projection: #t1.c1 AS AAA, #t1.c2, #t1.c3, #t2.c1, #t2.c2, #t2.c3\
        \n  Limit: skip=None, fetch=1\
        \n    Sort: #t1.c1 ASC NULLS FIRST, #t1.c2 ASC NULLS FIRST, #t1.c3 ASC NULLS FIRST, #t2.c1 ASC NULLS FIRST, #t2.c2 ASC NULLS FIRST, #t2.c3 ASC NULLS FIRST\
        \n      Inner Join: #t1.c1 = #t2.c1\
        \n        TableScan: t1 projection=[c1, c2, c3]\
        \n        TableScan: t2 projection=[c1, c2, c3]",
        format!("{:?}", df_renamed.to_logical_plan()?)
    );

    let df_results = df_renamed.collect().await?;

    // Only the first column's header changes; `t2.c1` is still displayed as `c1`.
    assert_batches_sorted_eq!(
        vec![
            "+-----+----+-----+----+----+-----+",
            "| AAA | c2 | c3  | c1 | c2 | c3  |",
            "+-----+----+-----+----+----+-----+",
            "| a   | 1  | -85 | a  | 1  | -85 |",
            "+-----+----+-----+----+----+-----+",
        ],
        &df_results
    );

    Ok(())
}
}