Skip to content

Commit

Permalink
opt: prepare to support multi-argument aggregate functions
Browse files Browse the repository at this point in the history
PR cockroachdb#44628 is attempting to add the corr() aggregate function, which
will be the first true multi-argument aggregate in CRDB (besides
string_agg, which is special-cased). That PR runs into problems
when dealing with aggregate DISTINCT and FILTER clauses. The main
issue is that the optimizer currently adds the AggDistinct and
AggFilter operators as *inputs* to the aggregate function. This
doesn't make sense when there are multiple inputs, since it's now
ambiguous which input to use for that purpose.

This commit flips the situation by wrapping the aggregate function
with AggDistinct and AggFilter rather than having them wrap the
input to the aggregate function:

  (AggFilter (AggDistinct (Sum (Variable 1))))

instead of:

  (Sum (AggFilter (AggDistinct (Variable 1))))

As part of this change, it was also convenient to refactor the
EliminateAggDistinctForKeys rule and also to consolidate the
ReplaceScalarMinWithLimit and ReplaceScalarMaxWithLimit rules
into one rule.

Release note: None
  • Loading branch information
andy-kimball committed Feb 8, 2020
1 parent f6bb936 commit 079bc47
Show file tree
Hide file tree
Showing 24 changed files with 636 additions and 355 deletions.
67 changes: 32 additions & 35 deletions pkg/sql/opt/exec/execbuilder/relational.go
Original file line number Diff line number Diff line change
Expand Up @@ -913,43 +913,51 @@ func (b *Builder) buildGroupBy(groupBy memo.RelExpr) (execPlan, error) {
aggInfos := make([]exec.AggInfo, len(aggregations))
for i := range aggregations {
item := &aggregations[i]
name, overload := memo.FindAggregateOverload(item.Agg)
agg := item.Agg

distinct := false
var argIdx []exec.ColumnOrdinal
var filterOrd exec.ColumnOrdinal = -1
if aggFilter, ok := agg.(*memo.AggFilterExpr); ok {
filter, ok := aggFilter.Filter.(*memo.VariableExpr)
if !ok {
return execPlan{}, errors.AssertionFailedf("only VariableOp args supported")
}
filterOrd = input.getColumnOrdinal(filter.Col)
agg = aggFilter.Input
}

if item.Agg.ChildCount() > 0 {
child := item.Agg.Child(0)
distinct := false
if aggDistinct, ok := agg.(*memo.AggDistinctExpr); ok {
distinct = true
agg = aggDistinct.Input
}

if aggFilter, ok := child.(*memo.AggFilterExpr); ok {
filter, ok := aggFilter.Filter.(*memo.VariableExpr)
if !ok {
return execPlan{}, errors.Errorf("only VariableOp args supported")
name, overload := memo.FindAggregateOverload(agg)

// Accumulate variable arguments in argCols and constant arguments in
// constArgs. Constant arguments must follow variable arguments.
var argCols []exec.ColumnOrdinal
var constArgs tree.Datums
for j, n := 0, agg.ChildCount(); j < n; j++ {
child := agg.Child(j)
if variable, ok := child.(*memo.VariableExpr); ok {
if len(constArgs) != 0 {
return execPlan{}, errors.Errorf("constant args must come after variable args")
}
filterOrd = input.getColumnOrdinal(filter.Col)
child = aggFilter.Input
}

if aggDistinct, ok := child.(*memo.AggDistinctExpr); ok {
distinct = true
child = aggDistinct.Input
}
v, ok := child.(*memo.VariableExpr)
if !ok {
return execPlan{}, errors.Errorf("only VariableOp args supported")
argCols = append(argCols, input.getColumnOrdinal(variable.Col))
} else {
if len(argCols) == 0 {
return execPlan{}, errors.Errorf("a constant arg requires at least one variable arg")
}
constArgs = append(constArgs, memo.ExtractConstDatum(child))
}
argIdx = []exec.ColumnOrdinal{input.getColumnOrdinal(v.Col)}
}

constArgs := b.extractAggregateConstArgs(item.Agg)

aggInfos[i] = exec.AggInfo{
FuncName: name,
Builtin: overload,
Distinct: distinct,
ResultType: item.Agg.DataType(),
ArgCols: argIdx,
ArgCols: argCols,
ConstArgs: constArgs,
Filter: filterOrd,
}
Expand All @@ -974,17 +982,6 @@ func (b *Builder) buildGroupBy(groupBy memo.RelExpr) (execPlan, error) {
return ep, nil
}

// extractAggregateConstArgs collects the constant arguments that belong to the
// given aggregate expression. Only the StringAggOp operator is recognized: its
// second child is extracted as a constant datum and returned as a one-element
// list. Every other operator yields nil.
func (b *Builder) extractAggregateConstArgs(agg opt.ScalarExpr) tree.Datums {
	if agg.Op() != opt.StringAggOp {
		return nil
	}
	// Child(1) is the argument following the aggregated input (Child(0)).
	return tree.Datums{memo.ExtractConstDatum(agg.Child(1))}
}

func (b *Builder) buildDistinct(distinct *memo.DistinctOnExpr) (execPlan, error) {
if distinct.GroupingCols.Empty() {
// A DistinctOn with no grouping columns should have been converted to a
Expand Down
69 changes: 38 additions & 31 deletions pkg/sql/opt/memo/extract.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ package memo
import (
"github.com/cockroachdb/cockroach/pkg/sql/opt"
"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
"github.com/cockroachdb/cockroach/pkg/util/log"
"github.com/cockroachdb/errors"
)

Expand Down Expand Up @@ -98,54 +97,62 @@ func ExtractConstDatum(e opt.Expr) tree.Datum {
panic(errors.AssertionFailedf("non-const expression: %+v", e))
}

// ExtractAggSingleInputColumn returns the input ColumnID of an aggregate
// operator that has a single input.
func ExtractAggSingleInputColumn(e opt.ScalarExpr) opt.ColumnID {
// ExtractAggFunc digs down into the given aggregate expression and returns the
// aggregate function, skipping past any AggFilter or AggDistinct operators.
func ExtractAggFunc(e opt.ScalarExpr) opt.ScalarExpr {
if filter, ok := e.(*AggFilterExpr); ok {
e = filter.Input
}

if distinct, ok := e.(*AggDistinctExpr); ok {
e = distinct.Input
}

if !opt.IsAggregateOp(e) {
panic(errors.AssertionFailedf("not an Aggregate"))
}
return ExtractVarFromAggInput(e.Child(0).(opt.ScalarExpr)).Col

return e
}

// ExtractAggInputColumns returns the set of columns the aggregate depends on.
func ExtractAggInputColumns(e opt.ScalarExpr) opt.ColSet {
if !opt.IsAggregateOp(e) {
panic(errors.AssertionFailedf("not an Aggregate"))
var res opt.ColSet
if filter, ok := e.(*AggFilterExpr); ok {
res.Add(filter.Filter.(*VariableExpr).Col)
e = filter.Input
}

if e.ChildCount() == 0 {
return opt.ColSet{}
if distinct, ok := e.(*AggDistinctExpr); ok {
e = distinct.Input
}

arg := e.Child(0)
var res opt.ColSet
if filter, ok := arg.(*AggFilterExpr); ok {
res.Add(filter.Filter.(*VariableExpr).Col)
arg = filter.Input
}
if distinct, ok := arg.(*AggDistinctExpr); ok {
arg = distinct.Input
if !opt.IsAggregateOp(e) {
panic(errors.AssertionFailedf("not an Aggregate"))
}
if variable, ok := arg.(*VariableExpr); ok {
res.Add(variable.Col)
return res

for i, n := 0, e.ChildCount(); i < n; i++ {
if variable, ok := e.Child(i).(*VariableExpr); ok {
res.Add(variable.Col)
}
}
panic(errors.AssertionFailedf("unhandled aggregate input %T", log.Safe(arg)))

return res
}

// ExtractVarFromAggInput is given an argument to an Aggregate and returns the
// inner Variable expression, stripping out modifiers like AggDistinct.
func ExtractVarFromAggInput(arg opt.ScalarExpr) *VariableExpr {
if filter, ok := arg.(*AggFilterExpr); ok {
arg = filter.Input
}
if distinct, ok := arg.(*AggDistinctExpr); ok {
arg = distinct.Input
// ExtractAggFirstVar is given an aggregate expression and returns the Variable
// expression for the first argument, skipping past modifiers like AggDistinct.
func ExtractAggFirstVar(e opt.ScalarExpr) *VariableExpr {
e = ExtractAggFunc(e)
if e.ChildCount() == 0 {
panic(errors.AssertionFailedf("aggregate does not have any arguments"))
}
if variable, ok := arg.(*VariableExpr); ok {

if variable, ok := e.Child(0).(*VariableExpr); ok {
return variable
}
panic(errors.AssertionFailedf("aggregate input not a Variable"))

panic(errors.AssertionFailedf("first aggregate input is not a Variable"))
}

// ExtractJoinEqualityColumns returns pairs of columns (one from the left side,
Expand Down
4 changes: 2 additions & 2 deletions pkg/sql/opt/memo/testdata/stats_quality/tpcc
Original file line number Diff line number Diff line change
Expand Up @@ -636,8 +636,8 @@ scalar-group-by
│ ├── variable: s_quantity [type=int]
│ └── const: 15 [type=int]
└── aggregations
└── count [type=int, outer=(11)]
└── agg-distinct [type=int]
└── agg-distinct [type=int, outer=(11)]
└── count [type=int]
└── variable: s_i_id [type=int]

stats table=stock_level_02_scan_3
Expand Down
4 changes: 2 additions & 2 deletions pkg/sql/opt/memo/testdata/stats_quality/tpch/q16
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,8 @@ sort
│ └── filters
│ └── p_partkey = ps_partkey [type=bool, outer=(1,6), constraints=(/1: (/NULL - ]; /6: (/NULL - ]), fd=(1)==(6), (6)==(1)]
└── aggregations
└── count [type=int, outer=(2)]
└── agg-distinct [type=int]
└── agg-distinct [type=int, outer=(2)]
└── count [type=int]
└── variable: ps_suppkey [type=int]

stats table=q16_sort_1
Expand Down
31 changes: 25 additions & 6 deletions pkg/sql/opt/norm/custom_funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -290,12 +290,21 @@ func (c *CustomFuncs) CanHaveZeroRows(input memo.RelExpr) bool {
return input.Relational().Cardinality.CanBeZero()
}

// ColsAreKey returns true if the given columns form a strict key for the given
// input expression. A strict key means that any two rows will have unique key
// column values. Nulls are treated as equal to one another (i.e. no duplicate
// nulls allowed). Having a strict key means that the set of key column values
// uniquely determine the values of all other columns in the relation.
func (c *CustomFuncs) ColsAreKey(cols opt.ColSet, input memo.RelExpr) bool {
// HasStrictKey reports whether the functional dependencies of the input
// expression contain a strict key, i.e. whether one or more of its columns
// form a strict key (see comment for ColsAreStrictKey for definition).
func (c *CustomFuncs) HasStrictKey(input memo.RelExpr) bool {
	// Only the presence of a key matters here; the key columns are discarded.
	_, found := input.Relational().FuncDeps.StrictKey()
	return found
}

// ColsAreStrictKey reports whether the given columns form a strict key for the
// given input expression. With a strict key, no two rows share the same key
// column values; for this purpose nulls compare equal to one another, so
// duplicate nulls are not allowed. As a consequence, the key column values
// uniquely determine the values of every other column in the relation.
func (c *CustomFuncs) ColsAreStrictKey(cols opt.ColSet, input memo.RelExpr) bool {
	// The check is delegated to the input's functional dependencies.
	fds := &input.Relational().FuncDeps
	return fds.ColsAreStrictKey(cols)
}

Expand Down Expand Up @@ -357,6 +366,16 @@ func (c *CustomFuncs) DifferenceCols(left, right opt.ColSet) opt.ColSet {
return left.Difference(right)
}

// AddColToSet returns a column set holding every member of the given set plus
// the given column. When the column is already a member, the given set is
// returned as-is; otherwise the column is added to a copy of the set and the
// copy is returned.
func (c *CustomFuncs) AddColToSet(set opt.ColSet, col opt.ColumnID) opt.ColSet {
	if !set.Contains(col) {
		// Add to a copy rather than to the set that was passed in.
		set = set.Copy()
		set.Add(col)
	}
	return set
}

// sharedProps returns the shared logical properties for the given expression.
// Only relational expressions and certain scalar list items (e.g. FiltersItem,
// ProjectionsItem, AggregationsItem) have shared properties.
Expand Down
107 changes: 24 additions & 83 deletions pkg/sql/opt/norm/groupby.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ package norm
import (
"github.com/cockroachdb/cockroach/pkg/sql/opt"
"github.com/cockroachdb/cockroach/pkg/sql/opt/memo"
"github.com/cockroachdb/cockroach/pkg/sql/opt/props"
"github.com/cockroachdb/cockroach/pkg/sql/opt/props/physical"
"github.com/cockroachdb/cockroach/pkg/util/log"
"github.com/cockroachdb/errors"
Expand Down Expand Up @@ -131,97 +130,39 @@ func (c *CustomFuncs) makeAggCols(
}
}

// CanRemoveAggDistinctForKeys returns true if the given aggregations contain an
// aggregation with AggDistinct where the input column together with the
// grouping columns form a key. In this case, the respective AggDistinct can be
// removed.
// CanRemoveAggDistinctForKeys returns true if the input column of the given
// aggregate function, together with the grouping columns, forms a key. In
// this case, the wrapper AggDistinct can be removed.
func (c *CustomFuncs) CanRemoveAggDistinctForKeys(
aggs memo.AggregationsExpr, private *memo.GroupingPrivate, input memo.RelExpr,
input memo.RelExpr, private *memo.GroupingPrivate, agg opt.ScalarExpr,
) bool {
inputFDs := &input.Relational().FuncDeps
if _, hasKey := inputFDs.StrictKey(); !hasKey {
// Fast-path for the case when the input has no keys.
if agg.ChildCount() == 0 {
return false
}

for i := range aggs {
if ok, _ := c.hasRemovableAggDistinct(aggs[i].Agg, private.GroupingCols, inputFDs); ok {
return true
}
}
return false
inputFDs := &input.Relational().FuncDeps
variable := agg.Child(0).(*memo.VariableExpr)
cols := c.AddColToSet(private.GroupingCols, variable.Col)
return inputFDs.ColsAreStrictKey(cols)
}

// RemoveAggDistinctForKeys rewrites aggregations to remove AggDistinct when
// the input column together with the grouping columns form a key. Returns the
// new Aggregation expression.
func (c *CustomFuncs) RemoveAggDistinctForKeys(
aggs memo.AggregationsExpr, private *memo.GroupingPrivate, input memo.RelExpr,
// ReplaceAggregationsItem returns a new list that is a copy of the given list,
// except that the given search item has been replaced by the given replace
// item. If the list contains the search item multiple times, then only the
// first instance is replaced. If the list does not contain the item, then the
// method panics.
func (c *CustomFuncs) ReplaceAggregationsItem(
aggs memo.AggregationsExpr, search *memo.AggregationsItem, replace opt.ScalarExpr,
) memo.AggregationsExpr {
inputFDs := &input.Relational().FuncDeps

newAggs := make(memo.AggregationsExpr, len(aggs))
newAggs := make([]memo.AggregationsItem, len(aggs))
for i := range aggs {
item := &aggs[i]
if ok, v := c.hasRemovableAggDistinct(item.Agg, private.GroupingCols, inputFDs); ok {
// Remove AggDistinct. We rely on the fact that AggDistinct must be
// directly "under" the Aggregate.
// TODO(radu): this will need to be revisited when we add more modifiers.
newAggs[i] = c.f.ConstructAggregationsItem(
c.replaceAggInputVar(item.Agg, v),
aggs[i].Col,
)
} else {
newAggs[i] = *item
if search == &aggs[i] {
copy(newAggs, aggs[:i])
newAggs[i] = c.f.ConstructAggregationsItem(replace, search.Col)
copy(newAggs[i+1:], aggs[i+1:])
return newAggs
}
}

return newAggs
}

// replaceAggInputVar rebuilds the given aggregate with v as its first
// argument. A one-argument aggregate is reconstructed from v alone; a
// two-argument aggregate (like string_agg) keeps its existing second argument
// unchanged. Any other number of children triggers an assertion failure panic.
func (c *CustomFuncs) replaceAggInputVar(agg opt.ScalarExpr, v opt.ScalarExpr) opt.ScalarExpr {
	numChildren := agg.ChildCount()
	if numChildren == 1 {
		return c.f.DynamicConstruct(agg.Op(), v).(opt.ScalarExpr)
	}
	if numChildren == 2 {
		// Carry the second argument over as-is; only the first is swapped.
		return c.f.DynamicConstruct(agg.Op(), v, agg.Child(1)).(opt.ScalarExpr)
	}
	panic(errors.AssertionFailedf("unhandled number of aggregate children"))
}

// hasRemovableAggDistinct is called with an aggregation element and returns
// true if the aggregation has AggDistinct and the grouping columns along with
// the aggregation input column form a key in the input (in which case
// AggDistinct can be elided).
// On success, the input expression to AggDistinct is also returned.
func (c *CustomFuncs) hasRemovableAggDistinct(
agg opt.ScalarExpr, groupingCols opt.ColSet, inputFDs *props.FuncDepSet,
) (ok bool, aggDistinctVar *memo.VariableExpr) {
if agg.ChildCount() == 0 {
return false, nil
}

distinct, ok := agg.Child(0).(*memo.AggDistinctExpr)
if !ok {
return false, nil
}

v, ok := distinct.Input.(*memo.VariableExpr)
if !ok {
return false, nil
}

cols := groupingCols.Copy()
cols.Add(v.Col)
if !inputFDs.ColsAreStrictKey(cols) {
return false, nil
}

return true, v
panic(errors.AssertionFailedf("item to replace is not in the list: %v", search))
}

// HasNoGroupingCols returns true if the GroupingCols in the private are empty.
Expand All @@ -243,7 +184,7 @@ func (c *CustomFuncs) ConstructProjectionFromDistinctOn(
var passthrough opt.ColSet
var projections memo.ProjectionsExpr
for i := range aggs {
varExpr := memo.ExtractVarFromAggInput(aggs[i].Agg.Child(0).(opt.ScalarExpr))
varExpr := memo.ExtractAggFirstVar(aggs[i].Agg)
inputCol := varExpr.Col
outputCol := aggs[i].Col
if inputCol == outputCol {
Expand Down
Loading

0 comments on commit 079bc47

Please sign in to comment.