Skip to content

Bug: unnecessary columns projected and redundant filters pushed down #18816

@niebayes

Description

@niebayes

I manually composed a simple plan: TableScan -> Projection -> Filter -> Sort -> Extension.
However, the DataFusion optimizer turns it into a much more complicated plan, and the result exhibits two bugs:

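  • Unnecessary columns projected: `t.b` is carried through every intermediate projection even though nothing above the scan uses it.
  • Redundant filters pushed down: the scan's `partial_filters` contain both the original `CAST(t.ts ...)` predicates and their rewritten, un-cast `t.ts` equivalents.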

A quick look:

// Original Plan.
DummyPlan
  Sort: t.a ASC NULLS FIRST, t.ts ASC NULLS FIRST
    Filter: t.ts > TimestampMillisecond(1000, Some("UTC")) AND t.ts < TimestampMillisecond(2000, Some("UTC"))
      Projection: t.a, CAST(t.ts AS Timestamp(ms, "UTC")) AS ts
        TableScan: t

// Optimized Plan
DummyPlan
  Sort: t.a ASC NULLS FIRST, t.ts ASC NULLS FIRST
    Projection: t.a, CAST(t.ts AS Timestamp(ms, "UTC")) AS ts
      Projection: t.a, t.b, t.ts
        Projection: CAST(t.ts AS Timestamp(ms, "UTC")) AS __common_expr_1, t.a, t.b, t.ts
          Projection: t.a, t.b, t.ts
            Projection: CAST(t.ts AS Timestamp(ms, "UTC")) AS __common_expr_2, t.a, t.b, t.ts
              Projection: t.a, t.b, t.ts
                Filter: __common_expr_3 > TimestampMillisecond(1000, Some("UTC")) AND __common_expr_3 < TimestampMillisecond(2000, Some("UTC"))
                  Projection: CAST(t.ts AS Timestamp(ms, "UTC")) AS __common_expr_3, t.a, t.b, t.ts
                    TableScan: t, partial_filters=[t.ts > TimestampNanosecond(1000000000, None), t.ts < TimestampNanosecond(2000000000, None), CAST(t.ts AS Timestamp(ms, "UTC")) > TimestampMillisecond(1000, Some("UTC")), CAST(t.ts AS Timestamp(ms, "UTC")) < TimestampMillisecond(2000, Some("UTC"))]

DataFusion revision: 2dd17b9

To reproduce:

#[cfg(test)]
mod tests {
    use std::any::Any;
    use std::sync::Arc;

    use arrow_schema::{DataType, Field, Schema, SchemaRef, TimeUnit};
    use async_trait::async_trait;
    use datafusion::catalog::{Session, TableProvider};
    use datafusion::datasource::provider_as_source;
    use datafusion::physical_plan::ExecutionPlan;
    use datafusion::prelude::SessionContext;
    use datafusion_common::{DFSchemaRef, Result, ScalarValue, ToDFSchema};
    use datafusion_expr::{
        col, Expr, ExprSchemable, Extension, LogicalPlan, LogicalPlanBuilder, SortExpr,
        TableProviderFilterPushDown, TableType, UserDefinedLogicalNodeCore,
    };

    /// Builds a literal `TimestampMillisecond` expression in the `"UTC"` timezone.
    ///
    /// `value` is the raw millisecond count; the literal carries no field metadata.
    fn timestamp_ms(value: i64) -> Expr {
        let scalar = ScalarValue::TimestampMillisecond(Some(value), Some("UTC".into()));
        Expr::Literal(scalar, None)
    }

    /// Table provider that only exposes a schema; it cannot actually be scanned.
    ///
    /// It is only used to give the logical plan a `TableScan` source.
    #[derive(Debug)]
    struct DummyTable {
        /// Schema returned from [`TableProvider::schema`].
        schema: SchemaRef,
    }

    #[async_trait]
    impl TableProvider for DummyTable {
        fn as_any(&self) -> &dyn Any {
            self
        }

        /// Returns a shared handle to the stored schema.
        fn schema(&self) -> SchemaRef {
            Arc::clone(&self.schema)
        }

        fn table_type(&self) -> TableType {
            TableType::Base
        }

        /// Not exercised: the test only optimizes the logical plan and never
        /// executes it. Panics if invoked.
        async fn scan(
            &self,
            _state: &dyn Session,
            _projection: Option<&Vec<usize>>,
            _filters: &[Expr],
            _limit: Option<usize>,
        ) -> Result<Arc<dyn ExecutionPlan>> {
            unimplemented!()
        }

        /// Reports `Inexact` support for every offered filter, one entry per
        /// filter and in the same order.
        fn supports_filters_pushdown(
            &self,
            filters: &[&Expr],
        ) -> Result<Vec<TableProviderFilterPushDown>> {
            let verdicts: Vec<TableProviderFilterPushDown> = filters
                .iter()
                .map(|_| TableProviderFilterPushDown::Inexact)
                .collect();
            Ok(verdicts)
        }
    }

    #[derive(Debug, Hash, PartialEq, Eq)]
    pub struct DummyPlan {
        input: Arc<LogicalPlan>,
        schema: DFSchemaRef,
    }

    impl PartialOrd for DummyPlan {
        fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
            self.input.partial_cmp(&other.input)
        }
    }

    impl UserDefinedLogicalNodeCore for DummyPlan {
        fn name(&self) -> &str {
            "DummyPlan"
        }

        fn inputs(&self) -> Vec<&LogicalPlan> {
            vec![&self.input]
        }

        fn schema(&self) -> &DFSchemaRef {
            &self.schema
        }

        fn expressions(&self) -> Vec<Expr> {
            vec![]
        }

        fn with_exprs_and_inputs(
            &self,
            _exprs: Vec<Expr>,
            mut inputs: Vec<LogicalPlan>,
        ) -> Result<Self> {
            Ok(Self {
                input: inputs.pop().unwrap().into(),
                schema: self.schema.clone(),
            })
        }

        fn fmt_for_explain(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
            write!(f, "{}", self.name())
        }
    }

    #[tokio::test]
    async fn test_optimize() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Int32, true),
            Field::new("ts", DataType::Timestamp(TimeUnit::Nanosecond, None), true),
        ]));
        let table = Arc::new(DummyTable {
            schema: schema.clone(),
        });

        let context = SessionContext::new();
        context.register_table("t", table.clone())?;

        let plan = LogicalPlanBuilder::scan("t", provider_as_source(table), None)?
            .project(vec![
                col("a"),
                col("ts")
                    .cast_to(
                        &DataType::Timestamp(TimeUnit::Millisecond, Some("UTC".into())),
                        &schema.clone().to_dfschema()?,
                    )?
                    .alias_qualified(Some("t"), "ts"),
            ])?
            .filter(
                col("ts")
                    .gt(timestamp_ms(1000))
                    .and(col("ts").lt(timestamp_ms(2000))),
            )?
            .sort(vec![
                SortExpr::new(col("a"), true, true),
                SortExpr::new(col("ts"), true, true),
            ])?
            .build()?;

        let plan = LogicalPlan::Extension(Extension {
            node: Arc::new(DummyPlan {
                input: plan.into(),
                schema: schema.to_dfschema_ref()?,
            }),
        });

        assert_eq!(
            plan.display_indent().to_string(),
            r#"DummyPlan
  Sort: t.a ASC NULLS FIRST, t.ts ASC NULLS FIRST
    Filter: t.ts > TimestampMillisecond(1000, Some("UTC")) AND t.ts < TimestampMillisecond(2000, Some("UTC"))
      Projection: t.a, CAST(t.ts AS Timestamp(ms, "UTC")) AS ts
        TableScan: t"#
        );

        let optimized_plan = context.state().optimize(&plan)?;
        assert_eq!(
            optimized_plan.display_indent().to_string(),
            r#"DummyPlan
  Sort: t.a ASC NULLS FIRST, t.ts ASC NULLS FIRST
    Projection: t.a, CAST(t.ts AS Timestamp(ms, "UTC")) AS ts
      Projection: t.a, t.b, t.ts
        Projection: CAST(t.ts AS Timestamp(ms, "UTC")) AS __common_expr_1, t.a, t.b, t.ts
          Projection: t.a, t.b, t.ts
            Projection: CAST(t.ts AS Timestamp(ms, "UTC")) AS __common_expr_2, t.a, t.b, t.ts
              Projection: t.a, t.b, t.ts
                Filter: __common_expr_3 > TimestampMillisecond(1000, Some("UTC")) AND __common_expr_3 < TimestampMillisecond(2000, Some("UTC"))
                  Projection: CAST(t.ts AS Timestamp(ms, "UTC")) AS __common_expr_3, t.a, t.b, t.ts
                    TableScan: t, partial_filters=[t.ts > TimestampNanosecond(1000000000, None), t.ts < TimestampNanosecond(2000000000, None), CAST(t.ts AS Timestamp(ms, "UTC")) > TimestampMillisecond(1000, Some("UTC")), CAST(t.ts AS Timestamp(ms, "UTC")) < TimestampMillisecond(2000, Some("UTC"))]"#
        );

        Ok(())
    }
}

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions