Skip to content

Commit

Permalink
planner: fix wrong behavior for = all() (pingcap#52801)
Browse files Browse the repository at this point in the history
  • Loading branch information
hawkingrei authored and RidRisR committed May 23, 2024
1 parent 6a84de6 commit 1a20d21
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 19 deletions.
2 changes: 1 addition & 1 deletion pkg/planner/core/casetest/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ go_test(
],
data = glob(["testdata/**"]),
flaky = True,
shard_count = 23,
shard_count = 24,
deps = [
"//pkg/domain",
"//pkg/parser",
Expand Down
22 changes: 22 additions & 0 deletions pkg/planner/core/casetest/plan_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -308,3 +308,25 @@ func TestJSONPlanInExplain(t *testing.T) {
}
}
}

func TestHandleEQAll(t *testing.T) {
store := testkit.CreateMockStore(t)
tk := testkit.NewTestKit(t, store)
tk.MustExec("use test")
tk.MustExec("CREATE TABLE t1 (c1 int, c2 int, UNIQUE i1 (c1, c2));")
tk.MustExec("INSERT INTO t1 VALUES (7, null),(5,1);")
tk.MustQuery("SELECT c1 FROM t1 WHERE ('m' = ALL (SELECT /*+ IGNORE_INDEX(t1, i1) */ c2 FROM t1)) IS NOT UNKNOWN; ").Check(testkit.Rows("5", "7"))
tk.MustQuery("SELECT c1 FROM t1 WHERE ('m' = ALL (SELECT /*+ use_INDEX(t1, i1) */ c2 FROM t1)) IS NOT UNKNOWN; ").Check(testkit.Rows("5", "7"))
tk.MustQuery("select (null = ALL (SELECT /*+ NO_INDEX() */ c2 FROM t1)) IS NOT UNKNOWN").Check(testkit.Rows("0"))
tk.MustExec("CREATE TABLE t2 (c1 int, c2 int, UNIQUE i1 (c1, c2));")
tk.MustExec("INSERT INTO t2 VALUES (7, null),(5,null);")
tk.MustQuery("select (null = ALL (SELECT /*+ NO_INDEX() */ c2 FROM t2)) IS NOT UNKNOWN").Check(testkit.Rows("0"))
tk.MustQuery("SELECT c1 FROM t2 WHERE ('m' = ALL (SELECT /*+ IGNORE_INDEX(t2, i1) */ c2 FROM t2)) IS NOT UNKNOWN; ").Check(testkit.Rows())
tk.MustQuery("SELECT c1 FROM t2 WHERE ('m' = ALL (SELECT /*+ use_INDEX(t2, i1) */ c2 FROM t2)) IS NOT UNKNOWN; ").Check(testkit.Rows())
tk.MustExec("truncate table t2")
tk.MustExec("INSERT INTO t2 VALUES (7, null),(7,null);")
tk.MustQuery("select c1 from t2 where (c1 = all (select /*+ IGNORE_INDEX(t2, i1) */ c1 from t2))").Check(testkit.Rows("7", "7"))
tk.MustQuery("select c1 from t2 where (c1 = all (select /*+ use_INDEX(t2, i1) */ c1 from t2))").Check(testkit.Rows("7", "7"))
tk.MustQuery("select c2 from t2 where (c2 = all (select /*+ IGNORE_INDEX(t2, i1) */ c2 from t2))").Check(testkit.Rows())
tk.MustQuery("select c2 from t2 where (c2 = all (select /*+ use_INDEX(t2, i1) */ c2 from t2))").Check(testkit.Rows())
}
26 changes: 9 additions & 17 deletions pkg/planner/core/expression_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -953,7 +953,9 @@ func (er *expressionRewriter) handleEQAll(planCtx *exprRewriterPlanCtx, lexpr, r
intest.AssertNotNil(planCtx)
sctx := planCtx.builder.ctx
exprCtx := sctx.GetExprCtx()
firstRowFunc, err := aggregation.NewAggFuncDesc(exprCtx, ast.AggFuncFirstRow, []expression.Expression{rexpr}, false)
// If there is NULL in s.id column, s.id should be the value that isn't null in condition t.id == s.id.
// So use function max to filter NULL.
maxFunc, err := aggregation.NewAggFuncDesc(exprCtx, ast.AggFuncMax, []expression.Expression{rexpr}, false)
if err != nil {
er.err = err
return
Expand All @@ -964,7 +966,7 @@ func (er *expressionRewriter) handleEQAll(planCtx *exprRewriterPlanCtx, lexpr, r
return
}
plan4Agg := LogicalAggregation{
AggFuncs: []*aggregation.AggFuncDesc{firstRowFunc, countFunc},
AggFuncs: []*aggregation.AggFuncDesc{maxFunc, countFunc},
}.Init(sctx, planCtx.builder.getSelectOffset())
if hintinfo := planCtx.builder.TableHints(); hintinfo != nil {
plan4Agg.PreferAggType = hintinfo.PreferAggType
Expand All @@ -973,29 +975,19 @@ func (er *expressionRewriter) handleEQAll(planCtx *exprRewriterPlanCtx, lexpr, r
plan4Agg.SetChildren(np)
plan4Agg.names = append(plan4Agg.names, types.EmptyName)

// Currently, firstrow agg function is treated like the exact representation of aggregate group key,
// so the data type is the same with group key, even if the group key is not null.
// However, the return type of firstrow should be nullable, we clear the null flag here instead of
// during invoking NewAggFuncDesc, in order to keep compatibility with the existing presumption
// that the return type firstrow does not change nullability, whatsoever.
// Cloning it because the return type is the same object with argument's data type.
newRetTp := firstRowFunc.RetTp.Clone()
newRetTp.DelFlag(mysql.NotNullFlag)
firstRowFunc.RetTp = newRetTp

firstRowResultCol := &expression.Column{
maxResultCol := &expression.Column{
UniqueID: sctx.GetSessionVars().AllocPlanColumnID(),
RetType: firstRowFunc.RetTp,
RetType: maxFunc.RetTp,
}
firstRowResultCol.SetCoercibility(rexpr.Coercibility())
maxResultCol.SetCoercibility(rexpr.Coercibility())
plan4Agg.names = append(plan4Agg.names, types.EmptyName)
count := &expression.Column{
UniqueID: sctx.GetSessionVars().AllocPlanColumnID(),
RetType: countFunc.RetTp,
}
plan4Agg.SetSchema(expression.NewSchema(firstRowResultCol, count))
plan4Agg.SetSchema(expression.NewSchema(maxResultCol, count))
leFunc := expression.NewFunctionInternal(er.sctx, ast.LE, types.NewFieldType(mysql.TypeTiny), count, expression.NewOne())
eqCond := expression.NewFunctionInternal(er.sctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), lexpr, firstRowResultCol)
eqCond := expression.NewFunctionInternal(er.sctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), lexpr, maxResultCol)
cond := expression.ComposeCNFCondition(er.sctx, leFunc, eqCond)
er.buildQuantifierPlan(planCtx, plan4Agg, cond, lexpr, rexpr, true, markNoDecorrelate)
}
Expand Down
2 changes: 1 addition & 1 deletion tests/integrationtest/r/select.result
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ explain format = 'brief' select a = all (select a from t t2) from t t1;
id estRows task access object operator info
Projection 10000.00 root or(and(and(le(Column#11, 1), eq(select.t.a, Column#10)), if(ne(Column#12, 0), <nil>, 1)), or(eq(Column#13, 0), if(isnull(select.t.a), <nil>, 0)))->Column#14
└─HashJoin 10000.00 root CARTESIAN inner join
├─StreamAgg(Build) 1.00 root funcs:firstrow(Column#16)->Column#10, funcs:count(distinct Column#17)->Column#11, funcs:sum(Column#18)->Column#12, funcs:count(1)->Column#13
├─StreamAgg(Build) 1.00 root funcs:max(Column#16)->Column#10, funcs:count(distinct Column#17)->Column#11, funcs:sum(Column#18)->Column#12, funcs:count(1)->Column#13
│ └─Projection 10000.00 root select.t.a->Column#16, select.t.a->Column#17, cast(isnull(select.t.a), decimal(20,0) BINARY)->Column#18
│ └─TableReader 10000.00 root data:TableFullScan
│ └─TableFullScan 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo
Expand Down

0 comments on commit 1a20d21

Please sign in to comment.