Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: count_wildcard_to_time_index_rule doesn't handle table reference properly #3847

Merged
merged 3 commits into from
Apr 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 70 additions & 9 deletions src/query/src/optimizer/count_wildcard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@ use datafusion::datasource::DefaultTableSource;
use datafusion_common::tree_node::{
Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeVisitor,
};
use datafusion_common::Result as DataFusionResult;
use datafusion_common::{Column, Result as DataFusionResult};
use datafusion_expr::expr::{AggregateFunction, AggregateFunctionDefinition, WindowFunction};
use datafusion_expr::utils::COUNT_STAR_EXPANSION;
use datafusion_expr::{col, lit, Expr, LogicalPlan, WindowFunctionDefinition};
use datafusion_optimizer::utils::NamePreserver;
use datafusion_optimizer::AnalyzerRule;
use datafusion_sql::TableReference;
use table::table::adapter::DfTableProviderAdapter;

/// A replacement to DataFusion's [`CountWildcardRule`]. This rule
Expand Down Expand Up @@ -77,11 +78,27 @@ impl CountWildcardToTimeIndexRule {
})
}

fn try_find_time_index_col(plan: &LogicalPlan) -> Option<String> {
fn try_find_time_index_col(plan: &LogicalPlan) -> Option<Column> {
let mut finder = TimeIndexFinder::default();
// Safety: `TimeIndexFinder` won't throw error.
plan.visit(&mut finder).unwrap();
finder.time_index
let col = finder.into_column();

// check if the time index is a valid column as for current plan
if let Some(col) = &col {
let mut is_valid = false;
for input in plan.inputs() {
if input.schema().has_column(col) {
is_valid = true;
break;
}
}
if !is_valid {
return None;
}
}

col
}
}

Expand Down Expand Up @@ -114,16 +131,16 @@ impl CountWildcardToTimeIndexRule {

#[derive(Default)]
struct TimeIndexFinder {
time_index: Option<String>,
table_alias: Option<String>,
time_index_col: Option<String>,
table_alias: Option<TableReference>,
}

impl TreeNodeVisitor for TimeIndexFinder {
type Node = LogicalPlan;

fn f_down(&mut self, node: &Self::Node) -> DataFusionResult<TreeNodeRecursion> {
if let LogicalPlan::SubqueryAlias(subquery_alias) = node {
self.table_alias = Some(subquery_alias.alias.to_string());
self.table_alias = Some(subquery_alias.alias.clone());
}

if let LogicalPlan::TableScan(table_scan) = &node {
Expand All @@ -138,9 +155,13 @@ impl TreeNodeVisitor for TimeIndexFinder {
.downcast_ref::<DfTableProviderAdapter>()
{
let table_info = adapter.table().table_info();
let col_name = table_info.meta.schema.timestamp_column().map(|c| &c.name);
let table_name = self.table_alias.as_ref().unwrap_or(&table_info.name);
self.time_index = col_name.map(|s| format!("{}.{}", table_name, s));
self.table_alias
.get_or_insert(TableReference::bare(table_info.name.clone()));
self.time_index_col = table_info
.meta
.schema
.timestamp_column()
.map(|c| c.name.clone());

return Ok(TreeNodeRecursion::Stop);
}
Expand All @@ -154,3 +175,43 @@ impl TreeNodeVisitor for TimeIndexFinder {
Ok(TreeNodeRecursion::Stop)
}
}

impl TimeIndexFinder {
fn into_column(self) -> Option<Column> {
self.time_index_col
.map(|c| Column::new(self.table_alias, c))
}
}

#[cfg(test)]
mod test {
use std::sync::Arc;

use datafusion_expr::{count, wildcard, LogicalPlanBuilder};
use table::table::numbers::NumbersTable;

use super::*;

#[test]
fn uppercase_table_name() {
let numbers_table = NumbersTable::table_with_name(0, "AbCdE".to_string());
let table_source = Arc::new(DefaultTableSource::new(Arc::new(
DfTableProviderAdapter::new(numbers_table),
)));

let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![])
.unwrap()
.aggregate(Vec::<Expr>::new(), vec![count(wildcard())])
.unwrap()
.alias(r#""FgHiJ""#)
.unwrap()
.build()
.unwrap();

let mut finder = TimeIndexFinder::default();
plan.visit(&mut finder).unwrap();

assert_eq!(finder.table_alias, Some(TableReference::bare("FgHiJ")));
assert!(finder.time_index_col.is_none());
}
}
56 changes: 56 additions & 0 deletions tests/cases/standalone/common/aggregate/count.result
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
create table "HelloWorld" (a string, b timestamp time index);

Affected Rows: 0

insert into "HelloWorld" values ("a", 1) ,("b", 2);

Affected Rows: 2

select count(*) from "HelloWorld";

+----------+
| COUNT(*) |
+----------+
| 2 |
+----------+

create table test (a string, "BbB" timestamp time index);

Affected Rows: 0

insert into test values ("c", 1) ;

Affected Rows: 1

select count(*) from test;

+----------+
| COUNT(*) |
+----------+
| 1 |
+----------+

select count(*) from (select count(*) from test where a = 'a');

+----------+
| COUNT(*) |
+----------+
| 1 |
+----------+

select count(*) from (select * from test cross join "HelloWorld");

+----------+
| COUNT(*) |
+----------+
| 2 |
+----------+

drop table "HelloWorld";

Affected Rows: 0

drop table test;

Affected Rows: 0

19 changes: 19 additions & 0 deletions tests/cases/standalone/common/aggregate/count.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
create table "HelloWorld" (a string, b timestamp time index);

insert into "HelloWorld" values ("a", 1) ,("b", 2);

select count(*) from "HelloWorld";

create table test (a string, "BbB" timestamp time index);

insert into test values ("c", 1) ;

select count(*) from test;

select count(*) from (select count(*) from test where a = 'a');

select count(*) from (select * from test cross join "HelloWorld");

drop table "HelloWorld";

drop table test;
Loading