From 1a20d218aa688b9bde95dfd17eed4831816dd239 Mon Sep 17 00:00:00 2001 From: Weizhen Wang Date: Tue, 14 May 2024 16:28:12 +0800 Subject: [PATCH] planner: fix wrong behavior for = all() (#52801) close pingcap/tidb#52755 --- pkg/planner/core/casetest/BUILD.bazel | 2 +- pkg/planner/core/casetest/plan_test.go | 22 +++++++++++++++++++++ pkg/planner/core/expression_rewriter.go | 26 +++++++++---------------- tests/integrationtest/r/select.result | 2 +- 4 files changed, 33 insertions(+), 19 deletions(-) diff --git a/pkg/planner/core/casetest/BUILD.bazel b/pkg/planner/core/casetest/BUILD.bazel index bf8d00caf5079..6c62171aef90d 100644 --- a/pkg/planner/core/casetest/BUILD.bazel +++ b/pkg/planner/core/casetest/BUILD.bazel @@ -12,7 +12,7 @@ go_test( ], data = glob(["testdata/**"]), flaky = True, - shard_count = 23, + shard_count = 24, deps = [ "//pkg/domain", "//pkg/parser", diff --git a/pkg/planner/core/casetest/plan_test.go b/pkg/planner/core/casetest/plan_test.go index 46347dc353c01..b1b4d3f7eeab5 100644 --- a/pkg/planner/core/casetest/plan_test.go +++ b/pkg/planner/core/casetest/plan_test.go @@ -308,3 +308,25 @@ func TestJSONPlanInExplain(t *testing.T) { } } } + +func TestHandleEQAll(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("CREATE TABLE t1 (c1 int, c2 int, UNIQUE i1 (c1, c2));") + tk.MustExec("INSERT INTO t1 VALUES (7, null),(5,1);") + tk.MustQuery("SELECT c1 FROM t1 WHERE ('m' = ALL (SELECT /*+ IGNORE_INDEX(t1, i1) */ c2 FROM t1)) IS NOT UNKNOWN; ").Check(testkit.Rows("5", "7")) + tk.MustQuery("SELECT c1 FROM t1 WHERE ('m' = ALL (SELECT /*+ use_INDEX(t1, i1) */ c2 FROM t1)) IS NOT UNKNOWN; ").Check(testkit.Rows("5", "7")) + tk.MustQuery("select (null = ALL (SELECT /*+ NO_INDEX() */ c2 FROM t1)) IS NOT UNKNOWN").Check(testkit.Rows("0")) + tk.MustExec("CREATE TABLE t2 (c1 int, c2 int, UNIQUE i1 (c1, c2));") + tk.MustExec("INSERT INTO t2 VALUES (7, null),(5,null);") + tk.MustQuery("select (null = ALL (SELECT /*+ NO_INDEX() */ c2 FROM t2)) IS NOT UNKNOWN").Check(testkit.Rows("0")) + tk.MustQuery("SELECT c1 FROM t2 WHERE ('m' = ALL (SELECT /*+ IGNORE_INDEX(t2, i1) */ c2 FROM t2)) IS NOT UNKNOWN; ").Check(testkit.Rows()) + tk.MustQuery("SELECT c1 FROM t2 WHERE ('m' = ALL (SELECT /*+ use_INDEX(t2, i1) */ c2 FROM t2)) IS NOT UNKNOWN; ").Check(testkit.Rows()) + tk.MustExec("truncate table t2") + tk.MustExec("INSERT INTO t2 VALUES (7, null),(7,null);") + tk.MustQuery("select c1 from t2 where (c1 = all (select /*+ IGNORE_INDEX(t2, i1) */ c1 from t2))").Check(testkit.Rows("7", "7")) + tk.MustQuery("select c1 from t2 where (c1 = all (select /*+ use_INDEX(t2, i1) */ c1 from t2))").Check(testkit.Rows("7", "7")) + tk.MustQuery("select c2 from t2 where (c2 = all (select /*+ IGNORE_INDEX(t2, i1) */ c2 from t2))").Check(testkit.Rows()) + tk.MustQuery("select c2 from t2 where (c2 = all (select /*+ use_INDEX(t2, i1) */ c2 from t2))").Check(testkit.Rows()) +} diff --git a/pkg/planner/core/expression_rewriter.go b/pkg/planner/core/expression_rewriter.go index e6a40ffe149a3..ea6ee0d773811 100644 --- a/pkg/planner/core/expression_rewriter.go +++ b/pkg/planner/core/expression_rewriter.go @@ -953,7 +953,9 @@ func (er *expressionRewriter) handleEQAll(planCtx *exprRewriterPlanCtx, lexpr, r intest.AssertNotNil(planCtx) sctx := planCtx.builder.ctx exprCtx := sctx.GetExprCtx() - firstRowFunc, err := aggregation.NewAggFuncDesc(exprCtx, ast.AggFuncFirstRow, []expression.Expression{rexpr}, false) + // If there is NULL in s.id column, s.id should be the value that isn't null in condition t.id == s.id. + // So use function max to filter NULL. + maxFunc, err := aggregation.NewAggFuncDesc(exprCtx, ast.AggFuncMax, []expression.Expression{rexpr}, false) if err != nil { er.err = err return @@ -964,7 +966,7 @@ func (er *expressionRewriter) handleEQAll(planCtx *exprRewriterPlanCtx, lexpr, r return } plan4Agg := LogicalAggregation{ - AggFuncs: []*aggregation.AggFuncDesc{firstRowFunc, countFunc}, + AggFuncs: []*aggregation.AggFuncDesc{maxFunc, countFunc}, }.Init(sctx, planCtx.builder.getSelectOffset()) if hintinfo := planCtx.builder.TableHints(); hintinfo != nil { plan4Agg.PreferAggType = hintinfo.PreferAggType @@ -973,29 +975,19 @@ func (er *expressionRewriter) handleEQAll(planCtx *exprRewriterPlanCtx, lexpr, r plan4Agg.SetChildren(np) plan4Agg.names = append(plan4Agg.names, types.EmptyName) - // Currently, firstrow agg function is treated like the exact representation of aggregate group key, - // so the data type is the same with group key, even if the group key is not null. - // However, the return type of firstrow should be nullable, we clear the null flag here instead of - // during invoking NewAggFuncDesc, in order to keep compatibility with the existing presumption - // that the return type firstrow does not change nullability, whatsoever. - // Cloning it because the return type is the same object with argument's data type. - newRetTp := firstRowFunc.RetTp.Clone() - newRetTp.DelFlag(mysql.NotNullFlag) - firstRowFunc.RetTp = newRetTp - - firstRowResultCol := &expression.Column{ + maxResultCol := &expression.Column{ UniqueID: sctx.GetSessionVars().AllocPlanColumnID(), - RetType: firstRowFunc.RetTp, + RetType: maxFunc.RetTp, } - firstRowResultCol.SetCoercibility(rexpr.Coercibility()) + maxResultCol.SetCoercibility(rexpr.Coercibility()) plan4Agg.names = append(plan4Agg.names, types.EmptyName) count := &expression.Column{ UniqueID: sctx.GetSessionVars().AllocPlanColumnID(), RetType: countFunc.RetTp, } - plan4Agg.SetSchema(expression.NewSchema(firstRowResultCol, count)) + plan4Agg.SetSchema(expression.NewSchema(maxResultCol, count)) leFunc := expression.NewFunctionInternal(er.sctx, ast.LE, types.NewFieldType(mysql.TypeTiny), count, expression.NewOne()) - eqCond := expression.NewFunctionInternal(er.sctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), lexpr, firstRowResultCol) + eqCond := expression.NewFunctionInternal(er.sctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), lexpr, maxResultCol) cond := expression.ComposeCNFCondition(er.sctx, leFunc, eqCond) er.buildQuantifierPlan(planCtx, plan4Agg, cond, lexpr, rexpr, true, markNoDecorrelate) } diff --git a/tests/integrationtest/r/select.result b/tests/integrationtest/r/select.result index 15800e40dcb52..b2062c732d8a0 100644 --- a/tests/integrationtest/r/select.result +++ b/tests/integrationtest/r/select.result @@ -414,7 +414,7 @@ explain format = 'brief' select a = all (select a from t t2) from t t1; id estRows task access object operator info Projection 10000.00 root or(and(and(le(Column#11, 1), eq(select.t.a, Column#10)), if(ne(Column#12, 0), , 1)), or(eq(Column#13, 0), if(isnull(select.t.a), , 0)))->Column#14 └─HashJoin 10000.00 root CARTESIAN inner join - ├─StreamAgg(Build) 1.00 root funcs:firstrow(Column#16)->Column#10, funcs:count(distinct Column#17)->Column#11, funcs:sum(Column#18)->Column#12, funcs:count(1)->Column#13 + ├─StreamAgg(Build) 1.00 root funcs:max(Column#16)->Column#10, funcs:count(distinct Column#17)->Column#11, funcs:sum(Column#18)->Column#12, funcs:count(1)->Column#13 │ └─Projection 10000.00 root select.t.a->Column#16, select.t.a->Column#17, cast(isnull(select.t.a), decimal(20,0) BINARY)->Column#18 │ └─TableReader 10000.00 root data:TableFullScan │ └─TableFullScan 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo