1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default.

11 changes: 5 additions & 6 deletions datafusion-examples/examples/custom_file_casts.rs
@@ -22,7 +22,7 @@ use arrow::datatypes::{DataType, Field, FieldRef, Schema, SchemaRef};

use datafusion::assert_batches_eq;
use datafusion::common::not_impl_err;
use datafusion::common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion::common::tree_node::{Transformed, TransformedResult};
use datafusion::common::{Result, ScalarValue};
use datafusion::datasource::listing::{
ListingTable, ListingTableConfig, ListingTableConfigExt, ListingTableUrl,
@@ -31,7 +31,7 @@ use datafusion::execution::context::SessionContext;
use datafusion::execution::object_store::ObjectStoreUrl;
use datafusion::parquet::arrow::ArrowWriter;
use datafusion::physical_expr::expressions::CastExpr;
use datafusion::physical_expr::PhysicalExpr;
use datafusion::physical_expr::{PhysicalExpr, PhysicalExprExt};
use datafusion::prelude::SessionConfig;
use datafusion_physical_expr_adapter::{
DefaultPhysicalExprAdapterFactory, PhysicalExprAdapter, PhysicalExprAdapterFactory,
@@ -181,11 +181,10 @@ impl PhysicalExprAdapter for CustomCastsPhysicalExprAdapter {
expr = self.inner.rewrite(expr)?;
// Now we can apply custom casting rules or even swap out all CastExprs for a custom cast kernel / expression
// For example, [DataFusion Comet](https://github.com/apache/datafusion-comet) has a [custom cast kernel](https://github.com/apache/datafusion-comet/blob/b4ac876ab420ed403ac7fc8e1b29f42f1f442566/native/spark-expr/src/conversion_funcs/cast.rs#L133-L138).
expr.transform(|expr| {
expr.transform_with_schema(&self.physical_file_schema, |expr, schema| {
if let Some(cast) = expr.as_any().downcast_ref::<CastExpr>() {
let input_data_type =
cast.expr().data_type(&self.physical_file_schema)?;
let output_data_type = cast.data_type(&self.physical_file_schema)?;
let input_data_type = cast.expr().data_type(schema)?;
let output_data_type = cast.data_type(schema)?;
if !cast.is_bigger_cast(&input_data_type) {
return not_impl_err!(
"Unsupported CAST from {input_data_type} to {output_data_type}"
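
The closure above now receives the schema from the traversal instead of always reading `self.physical_file_schema`. The trait that provides `transform_with_schema` is not shown in this diff, so the sketch below is only an inference from the call site (in particular, the `SchemaRef` argument and closure parameter types are assumptions), not the actual `PhysicalExprExt` definition:

use std::sync::Arc;

use arrow::datatypes::SchemaRef;
use datafusion::common::tree_node::Transformed;
use datafusion::common::Result;
use datafusion::physical_expr::PhysicalExpr;

// Hypothetical sketch only: inferred from how `transform_with_schema` is called above.
trait TransformWithSchemaSketch {
    fn transform_with_schema<F>(
        self,
        schema: &SchemaRef,
        f: F,
    ) -> Result<Transformed<Arc<dyn PhysicalExpr>>>
    where
        F: FnMut(
            Arc<dyn PhysicalExpr>,
            &SchemaRef,
        ) -> Result<Transformed<Arc<dyn PhysicalExpr>>>;
}
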
14 changes: 9 additions & 5 deletions datafusion-examples/examples/default_column_values.rs
@@ -26,8 +26,8 @@ use async_trait::async_trait;
use datafusion::assert_batches_eq;
use datafusion::catalog::memory::DataSourceExec;
use datafusion::catalog::{Session, TableProvider};
use datafusion::common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion::common::DFSchema;
use datafusion::common::tree_node::{Transformed, TransformedResult};
use datafusion::common::{DFSchema, HashSet};
use datafusion::common::{Result, ScalarValue};
use datafusion::datasource::listing::PartitionedFile;
use datafusion::datasource::physical_plan::{FileScanConfigBuilder, ParquetSource};
@@ -38,7 +38,7 @@ use datafusion::logical_expr::{Expr, TableProviderFilterPushDown, TableType};
use datafusion::parquet::arrow::ArrowWriter;
use datafusion::parquet::file::properties::WriterProperties;
use datafusion::physical_expr::expressions::{CastExpr, Column, Literal};
use datafusion::physical_expr::PhysicalExpr;
use datafusion::physical_expr::{PhysicalExpr, PhysicalExprExt};
use datafusion::physical_plan::ExecutionPlan;
use datafusion::prelude::{lit, SessionConfig};
use datafusion_physical_expr_adapter::{
@@ -308,11 +308,12 @@ impl PhysicalExprAdapter for DefaultValuePhysicalExprAdapter {
fn rewrite(&self, expr: Arc<dyn PhysicalExpr>) -> Result<Arc<dyn PhysicalExpr>> {
// First try our custom default value injection for missing columns
let rewritten = expr
.transform(|expr| {
.transform_with_lambdas_params(|expr, lambdas_params| {
self.inject_default_values(
expr,
&self.logical_file_schema,
&self.physical_file_schema,
lambdas_params,
)
})
.data()?;
@@ -348,12 +349,15 @@ impl DefaultValuePhysicalExprAdapter {
expr: Arc<dyn PhysicalExpr>,
logical_file_schema: &Schema,
physical_file_schema: &Schema,
lambdas_params: &HashSet<String>,
) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
if let Some(column) = expr.as_any().downcast_ref::<Column>() {
let column_name = column.name();

// Check if this column exists in the physical schema
if physical_file_schema.index_of(column_name).is_err() {
if !lambdas_params.contains(column_name)
&& physical_file_schema.index_of(column_name).is_err()
{
// Column is missing from physical schema, check if logical schema has a default
if let Ok(logical_field) =
logical_file_schema.field_with_name(column_name)
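
This example leans on the new `transform_with_lambdas_params`, which passes the rewrite closure the set of names that are lambda parameters at the current node, so they are never mistaken for missing table columns. The snippet below is a self-contained model of that mechanism, not DataFusion's implementation: the `ToyExpr` type, the traversal, and the `default_value` stand-in are invented for illustration.

use std::collections::HashSet;

#[derive(Debug, Clone)]
enum ToyExpr {
    Column(String),
    Lambda { params: Vec<String>, body: Box<ToyExpr> },
    Call { name: String, args: Vec<ToyExpr> },
}

fn rewrite_missing_columns(
    expr: ToyExpr,
    table_columns: &HashSet<String>,
    lambdas_params: &mut HashSet<String>,
) -> ToyExpr {
    match expr {
        ToyExpr::Column(name) => {
            // Only treat the column as "missing" when it is not bound by an enclosing lambda.
            if !lambdas_params.contains(&name) && !table_columns.contains(&name) {
                // Stand-in for "inject the default value" from the example above.
                ToyExpr::Call {
                    name: "default_value".to_string(),
                    args: vec![ToyExpr::Column(name)],
                }
            } else {
                ToyExpr::Column(name)
            }
        }
        ToyExpr::Lambda { params, body } => {
            // Parameters come into scope for the body and go back out of scope afterwards.
            let mut newly_bound = Vec::new();
            for p in &params {
                if lambdas_params.insert(p.clone()) {
                    newly_bound.push(p.clone());
                }
            }
            let body = Box::new(rewrite_missing_columns(*body, table_columns, lambdas_params));
            for p in &newly_bound {
                lambdas_params.remove(p);
            }
            ToyExpr::Lambda { params, body }
        }
        ToyExpr::Call { name, args } => ToyExpr::Call {
            name,
            args: args
                .into_iter()
                .map(|arg| rewrite_missing_columns(arg, table_columns, lambdas_params))
                .collect(),
        },
    }
}

Rewriting something like `array_exists(data, x -> x = missing_col)` (an invented example) against `table_columns = {"data"}` would wrap `missing_col` but leave the lambda parameter `x` untouched, which is the behavior the `lambdas_params` check above restores for the default-value adapter.
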
8 changes: 4 additions & 4 deletions datafusion-examples/examples/expr_api.rs
@@ -23,7 +23,7 @@ use arrow::record_batch::RecordBatch;

use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit};
use datafusion::common::stats::Precision;
use datafusion::common::tree_node::{Transformed, TreeNode};
use datafusion::common::tree_node::Transformed;
use datafusion::common::{ColumnStatistics, DFSchema};
use datafusion::common::{ScalarValue, ToDFSchema};
use datafusion::error::Result;
@@ -556,7 +556,7 @@ fn type_coercion_demo() -> Result<()> {
// 3. Type coercion with `TypeCoercionRewriter`.
let coerced_expr = expr
.clone()
.rewrite(&mut TypeCoercionRewriter::new(&df_schema))?
.rewrite_with_schema(&df_schema, &mut TypeCoercionRewriter::new(&df_schema))?
.data;
let physical_expr = datafusion::physical_expr::create_physical_expr(
&coerced_expr,
@@ -567,15 +567,15 @@

// 4. Apply explicit type coercion by manually rewriting the expression
let coerced_expr = expr
.transform(|e| {
.transform_with_schema(&df_schema, |e, df_schema| {
// Only type coerces binary expressions.
let Expr::BinaryExpr(e) = e else {
return Ok(Transformed::no(e));
};
if let Expr::Column(ref col_expr) = *e.left {
let field = df_schema.field_with_name(None, col_expr.name())?;
let cast_to_type = field.data_type();
let coerced_right = e.right.cast_to(cast_to_type, &df_schema)?;
let coerced_right = e.right.cast_to(cast_to_type, df_schema)?;
Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr::new(
e.left,
e.op,
14 changes: 8 additions & 6 deletions datafusion-examples/examples/json_shredding.rs
@@ -22,10 +22,8 @@ use arrow::array::{RecordBatch, StringArray};
use arrow::datatypes::{DataType, Field, FieldRef, Schema, SchemaRef};

use datafusion::assert_batches_eq;
use datafusion::common::tree_node::{
Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
};
use datafusion::common::{assert_contains, exec_datafusion_err, Result};
use datafusion::common::tree_node::{Transformed, TransformedResult, TreeNodeRecursion};
use datafusion::common::{assert_contains, exec_datafusion_err, HashSet, Result};
use datafusion::datasource::listing::{
ListingTable, ListingTableConfig, ListingTableConfigExt, ListingTableUrl,
};
@@ -36,8 +34,8 @@ use datafusion::logical_expr::{
};
use datafusion::parquet::arrow::ArrowWriter;
use datafusion::parquet::file::properties::WriterProperties;
use datafusion::physical_expr::PhysicalExpr;
use datafusion::physical_expr::{expressions, ScalarFunctionExpr};
use datafusion::physical_expr::{PhysicalExpr, PhysicalExprExt};
use datafusion::prelude::SessionConfig;
use datafusion::scalar::ScalarValue;
use datafusion_physical_expr_adapter::{
@@ -302,7 +300,9 @@ impl PhysicalExprAdapter for ShreddedJsonRewriter {
fn rewrite(&self, expr: Arc<dyn PhysicalExpr>) -> Result<Arc<dyn PhysicalExpr>> {
// First try our custom JSON shredding rewrite
let rewritten = expr
.transform(|expr| self.rewrite_impl(expr, &self.physical_file_schema))
.transform_with_lambdas_params(|expr, lambdas_params| {
self.rewrite_impl(expr, &self.physical_file_schema, lambdas_params)
})
.data()?;

// Then apply the default adapter as a fallback to handle standard schema differences
@@ -335,6 +335,7 @@ impl ShreddedJsonRewriter {
&self,
expr: Arc<dyn PhysicalExpr>,
physical_file_schema: &Schema,
lambdas_params: &HashSet<String>,
) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
if let Some(func) = expr.as_any().downcast_ref::<ScalarFunctionExpr>() {
if func.name() == "json_get_str" && func.args().len() == 2 {
@@ -348,6 +349,7 @@
if let Some(column) = func.args()[1]
.as_any()
.downcast_ref::<expressions::Column>()
.filter(|col| !lambdas_params.contains(col.name()))
{
let column_name = column.name();
// Check if there's a flat column with underscore prefix
9 changes: 5 additions & 4 deletions datafusion/catalog-listing/src/helpers.rs
@@ -52,9 +52,9 @@ use object_store::{ObjectMeta, ObjectStore};
/// was performed
pub fn expr_applicable_for_cols(col_names: &[&str], expr: &Expr) -> bool {
let mut is_applicable = true;
expr.apply(|expr| match expr {
Expr::Column(Column { ref name, .. }) => {
is_applicable &= col_names.contains(&name.as_str());
expr.apply_with_lambdas_params(|expr, lambdas_params| match expr {
Expr::Column(col) => {
is_applicable &= col_names.contains(&col.name())
|| col.is_lambda_parameter(lambdas_params);
if is_applicable {
Ok(TreeNodeRecursion::Jump)
} else {
@@ -86,7 +86,8 @@ pub fn expr_applicable_for_cols(col_names: &[&str], expr: &Expr) -> bool {
| Expr::InSubquery(_)
| Expr::ScalarSubquery(_)
| Expr::GroupingSet(_)
| Expr::Case(_) => Ok(TreeNodeRecursion::Continue),
| Expr::Case(_)
| Expr::Lambda(_) => Ok(TreeNodeRecursion::Continue),

Expr::ScalarFunction(scalar_function) => {
match scalar_function.func.signature().volatility {
6 changes: 6 additions & 0 deletions datafusion/common/src/column.rs
@@ -22,6 +22,7 @@ use crate::utils::parse_identifiers_normalized;
use crate::utils::quote_identifier;
use crate::{DFSchema, Diagnostic, Result, SchemaError, Spans, TableReference};
use arrow::datatypes::{Field, FieldRef};
use std::borrow::Borrow;
use std::collections::HashSet;
use std::fmt;

@@ -325,6 +326,11 @@ impl Column {
..self.clone()
}
}

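/// Returns true if this column is an unqualified reference to one of the given lambda parameters.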
pub fn is_lambda_parameter(
&self,
lambdas_params: &crate::HashSet<impl Borrow<str> + Eq + std::hash::Hash>,
) -> bool {
// currently, references to lambda parameters are always unqualified
self.relation.is_none() && lambdas_params.contains(self.name())
}
}

impl From<&str> for Column {
22 changes: 16 additions & 6 deletions datafusion/common/src/cse.rs
@@ -178,6 +178,14 @@ pub trait CSEController {
/// if all are always evaluated.
fn conditional_children(node: &Self::Node) -> Option<ChildrenList<&Self::Node>>;

// A helper method called on each node, before `is_ignored`, during the top-down pass of the
// first, visiting traversal of CSE.
fn visit_f_down(&mut self, _node: &Self::Node) {}

// A helper method called on each node, after `is_ignored`, during the bottom-up pass of the
// first, visiting traversal of CSE.
fn visit_f_up(&mut self, _node: &Self::Node) {}

// Returns true if a node is valid. If a node is invalid then it can't be eliminated.
// Validity is propagated up which means no subtree can be eliminated that contains
// an invalid node.
@@ -274,7 +282,7 @@
/// thus can not be extracted as a common [`TreeNode`].
conditional: bool,

controller: &'a C,
controller: &'a mut C,
}

/// Record item that used when traversing a [`TreeNode`] tree.
@@ -352,6 +360,7 @@
self.visit_stack
.push(VisitRecord::EnterMark(self.down_index));
self.down_index += 1;
self.controller.visit_f_down(node);

// If a node can short-circuit then some of its children might not be executed so
// count the occurrence either normal or conditional.
@@ -414,6 +423,7 @@
self.visit_stack
.push(VisitRecord::NodeItem(node_id, is_valid));
self.up_index += 1;
self.controller.visit_f_up(node);

Ok(TreeNodeRecursion::Continue)
}
@@ -532,7 +542,7 @@

/// Add an identifier to `id_array` for every [`TreeNode`] in this tree.
fn node_to_id_array<'n>(
&self,
&mut self,
node: &'n N,
node_stats: &mut NodeStats<'n, N>,
id_array: &mut IdArray<'n, N>,
@@ -546,7 +556,7 @@
random_state: &self.random_state,
found_common: false,
conditional: false,
controller: &self.controller,
controller: &mut self.controller,
};
node.visit(&mut visitor)?;

@@ -561,7 +571,7 @@
/// Each element is itself the result of [`CSE::node_to_id_array`] for that node
/// (e.g. the identifiers for each node in the tree)
fn to_arrays<'n>(
&self,
&mut self,
nodes: &'n [N],
node_stats: &mut NodeStats<'n, N>,
) -> Result<(bool, Vec<IdArray<'n, N>>)> {
@@ -761,7 +771,7 @@ mod test {
#[test]
fn id_array_visitor() -> Result<()> {
let alias_generator = AliasGenerator::new();
let eliminator = CSE::new(TestTreeNodeCSEController::new(
let mut eliminator = CSE::new(TestTreeNodeCSEController::new(
&alias_generator,
TestTreeNodeMask::Normal,
));
@@ -853,7 +863,7 @@
assert_eq!(expected, id_array);

// include aggregates
let eliminator = CSE::new(TestTreeNodeCSEController::new(
let mut eliminator = CSE::new(TestTreeNodeCSEController::new(
&alias_generator,
TestTreeNodeMask::NormalAndAggregates,
));
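
The new `visit_f_down` / `visit_f_up` hooks, together with the switch to `&mut` access, let the controller maintain state while the visiting traversal walks each node on the way down and back up. The sketch below is a self-contained model of that idea; it is not DataFusion's CSE code, and the lambda-scope bookkeeping is an inferred use case rather than something this diff shows directly.

#[derive(Debug)]
enum ToyNode {
    Column(String),
    Lambda { params: Vec<String>, body: Box<ToyNode> },
    Call(Vec<ToyNode>),
}

#[derive(Default)]
struct ScopeTrackingController {
    // Stack of lambda scopes entered on the way down and left on the way up.
    scopes: Vec<Vec<String>>,
    // Columns that referred to a lambda parameter; a real controller could use this to keep
    // such subtrees from being extracted as common subexpressions.
    bound_references: Vec<String>,
}

impl ScopeTrackingController {
    fn in_scope(&self, name: &str) -> bool {
        self.scopes.iter().flatten().any(|p| p == name)
    }

    // Analogous to `CSEController::visit_f_down`: called before a node's children are visited.
    fn visit_f_down(&mut self, node: &ToyNode) {
        match node {
            ToyNode::Lambda { params, .. } => self.scopes.push(params.clone()),
            ToyNode::Column(name) if self.in_scope(name) => {
                self.bound_references.push(name.clone());
            }
            _ => {}
        }
    }

    // Analogous to `CSEController::visit_f_up`: called after a node's children were visited.
    fn visit_f_up(&mut self, node: &ToyNode) {
        if let ToyNode::Lambda { .. } = node {
            self.scopes.pop();
        }
    }
}

fn visit(node: &ToyNode, controller: &mut ScopeTrackingController) {
    controller.visit_f_down(node);
    match node {
        ToyNode::Column(_) => {}
        ToyNode::Lambda { body, .. } => visit(body, controller),
        ToyNode::Call(args) => {
            for arg in args {
                visit(arg, controller);
            }
        }
    }
    controller.visit_f_up(node);
}

Because the two hooks bracket every node, nested lambdas simply push and pop their own scope, which is why the traversal now takes the controller by mutable reference.
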
16 changes: 13 additions & 3 deletions datafusion/common/src/dfschema.rs
@@ -314,8 +314,10 @@ impl DFSchema {
return;
}

let self_fields: HashSet<(Option<&TableReference>, &FieldRef)> =
self.iter().collect();
let self_fields: HashSet<(Option<&TableReference>, &str)> = self
.iter()
.map(|(qualifier, field)| (qualifier, field.name().as_str()))
.collect();
let self_unqualified_names: HashSet<&str> = self
.inner
.fields
@@ -328,7 +330,10 @@
for (qualifier, field) in other_schema.iter() {
// skip duplicate columns
let duplicated_field = match qualifier {
Some(q) => self_fields.contains(&(Some(q), field)),
Some(q) => {
self_fields.contains(&(Some(q), field.name().as_str()))
|| self_fields.contains(&(None, field.name().as_str()))
}
// for unqualified columns, check as unqualified name
None => self_unqualified_names.contains(field.name().as_str()),
};
@@ -867,6 +872,11 @@
&self.functional_dependencies
}

/// Get the qualifier (if any) of each field in the schema
pub fn field_qualifiers(&self) -> &[Option<TableReference>] {
&self.field_qualifiers
}

/// Iterate over the qualifiers and fields in the DFSchema
pub fn iter(&self) -> impl Iterator<Item = (Option<&TableReference>, &FieldRef)> {
self.field_qualifiers
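
The `merge` change above keys the duplicate check on `(qualifier, field name)` instead of the full `FieldRef`, and additionally treats a qualified field as a duplicate when an unqualified field of the same name already exists. A small self-contained illustration of that rule (the table names, field names, and types here are made up for the example):

use std::collections::HashSet;

fn main() {
    // Fields already present in the schema, keyed the way the new check keys them.
    let existing: HashSet<(Option<&str>, &str)> =
        [(Some("t1"), "id"), (None, "value")].into_iter().collect();

    // Candidate fields from the schema being merged in: (qualifier, name, data type).
    let candidates = [
        (Some("t1"), "id", "Int64"),      // same qualifier + name: skipped even if the type differs
        (Some("t2"), "value", "Utf8"),    // an unqualified "value" already exists: skipped
        (Some("t2"), "score", "Float64"), // genuinely new: merged
    ];

    for (qualifier, name, data_type) in candidates {
        let duplicated =
            existing.contains(&(qualifier, name)) || existing.contains(&(None, name));
        if !duplicated {
            println!("merging {qualifier:?}.{name}: {data_type}");
        }
    }
}
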
2 changes: 2 additions & 0 deletions datafusion/common/src/lib.rs
@@ -117,6 +117,8 @@ pub mod hash_set {
pub use hashbrown::hash_set::Entry;
}

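/// Re-export of the [`hashbrown`] crate that backs DataFusion's hash collections.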
pub use hashbrown;

/// Downcast an Arrow Array to a concrete type, return an `DataFusionError::Internal` if the cast is
/// not possible. In normal usage of DataFusion the downcast should always succeed.
///