support simple group plan which group by items equal with order by items (#259)

Co-authored-by: Dong Jianhui <dongjianhui03@meituan.com>
Mulavar and Dong Jianhui committed Jul 9, 2022
1 parent 53b86bc commit c2fd6bf
Showing 9 changed files with 260 additions and 19 deletions.
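In short: when a statement has a GROUP BY, optimizeSelect now returns the new GroupPlan, which pipes the merged shard results through a group-reduce stage. Per the TODO in group_plan.go this only covers the simple case where the GROUP BY items equal the ORDER BY items, e.g. `select uid, max(score) from student group by uid order by uid`; statements without GROUP BY keep the existing AggregatePlan path.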
2 changes: 1 addition & 1 deletion pkg/dataset/chain.go
@@ -46,7 +46,7 @@ func Map(generateFields FieldsFunc, transform TransformFunc) Option {
}
}

-func GroupReduce(groups []string, generateFields FieldsFunc, reducer func() Reducer) Option {
+func GroupReduce(groups []OrderByItem, generateFields FieldsFunc, reducer func() Reducer) Option {
return func(option *pipeOption) {
*option = append(*option, func(dataset proto.Dataset) proto.Dataset {
return &GroupDataset{
70 changes: 67 additions & 3 deletions pkg/dataset/group_reduce.go
@@ -30,6 +30,8 @@ import (
)

import (
"github.com/arana-db/arana/pkg/merge"
"github.com/arana-db/arana/pkg/mysql/rows"
"github.com/arana-db/arana/pkg/proto"
"github.com/arana-db/arana/pkg/util/log"
)
@@ -44,9 +46,71 @@ type Reducer interface {
Row() proto.Row
}

+type AggregateItem struct {
+	agg merge.Aggregator
+	idx int
+}
+
+type AggregateReducer struct {
+	AggItems   map[int]merge.Aggregator
+	currentRow proto.Row
+	Fields     []proto.Field
+}
+
+func NewGroupReducer(aggFuncMap map[int]func() merge.Aggregator, fields []proto.Field) *AggregateReducer {
+	aggItems := make(map[int]merge.Aggregator)
+	for idx, f := range aggFuncMap {
+		aggItems[idx] = f()
+	}
+	return &AggregateReducer{
+		AggItems:   aggItems,
+		currentRow: nil,
+		Fields:     fields,
+	}
+}
+
+func (gr *AggregateReducer) Reduce(next proto.Row) error {
+	var (
+		values = make([]proto.Value, len(gr.Fields))
+		result = make([]proto.Value, len(gr.Fields))
+	)
+	err := next.Scan(values)
+	if err != nil {
+		return err
+	}
+
+	for idx, aggregator := range gr.AggItems {
+		aggregator.Aggregate([]interface{}{values[idx]})
+	}
+
+	for i := 0; i < len(values); i++ {
+		if gr.AggItems[i] == nil {
+			result[i] = values[i]
+		} else {
+			aggResult, ok := gr.AggItems[i].GetResult()
+			if !ok {
+				return errors.New("can not aggregate value")
+			}
+			result[i] = aggResult
+		}
+	}
+
+	if next.IsBinary() {
+		gr.currentRow = rows.NewBinaryVirtualRow(gr.Fields, result)
+	} else {
+		gr.currentRow = rows.NewTextVirtualRow(gr.Fields, result)
+	}
+	return nil
+}
+
+func (gr *AggregateReducer) Row() proto.Row {
+	return gr.currentRow
+}

type GroupDataset struct {
// Should be an orderedDataset
proto.Dataset
-	keys []string
+	keys []OrderByItem

fieldFunc FieldsFunc
actualFieldsOnce sync.Once
@@ -247,13 +311,13 @@ func (gd *GroupDataset) getKeyIndexes() ([]int, error) {
for _, key := range gd.keys {
idx := -1
for i := 0; i < len(fields); i++ {
-			if fields[i].Name() == key {
+			if fields[i].Name() == key.Column {
idx = i
break
}
}
if idx == -1 {
-			gd.keyIndexesFailure = fmt.Errorf("cannot find group field '%s'", key)
+			gd.keyIndexesFailure = fmt.Errorf("cannot find group field '%+v'", key)
return
}
gd.keyIndexes = append(gd.keyIndexes, idx)
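For orientation, a minimal sketch (mine, not part of the commit) of how the new reducer is driven once the rows of one group are at hand; reduceGroup, its arguments, and the choice of COUNT on column 1 are illustrative only:

package example

import (
	"github.com/arana-db/arana/pkg/dataset"
	"github.com/arana-db/arana/pkg/merge"
	"github.com/arana-db/arana/pkg/merge/aggregator"
	"github.com/arana-db/arana/pkg/proto"
)

// reduceGroup folds the already-collected rows of one group into a single row.
// Column 1 is assumed to carry an aggregate (COUNT); other columns pass through.
func reduceGroup(fields []proto.Field, groupRows []proto.Row) (proto.Row, error) {
	aggs := map[int]func() merge.Aggregator{
		1: aggregator.GetAggFromName("COUNT"),
	}
	reducer := dataset.NewGroupReducer(aggs, fields)
	for _, row := range groupRows {
		if err := reducer.Reduce(row); err != nil {
			return nil, err
		}
	}
	return reducer.Row(), nil // one merged row per group
}

In the real pipeline GroupDataset does this incrementally over the ordered stream, drawing a fresh reducer from the factory for each group.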
2 changes: 1 addition & 1 deletion pkg/dataset/group_reduce_test.go
@@ -108,7 +108,7 @@ func TestGroupReduce(t *testing.T) {
}

// Simulate: SELECT gender,COUNT(*) AS amount FROM xxx WHERE ... GROUP BY gender
-	groups := []string{"gender"}
+	groups := []OrderByItem{{"gender", true}}
p := Pipe(&origin,
GroupReduce(
groups,
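Aside (not in the diff itself): the OrderByItem composite literal is positional, so {{"gender", true}} fills Column = "gender" and Desc = true, i.e. the same fields the optimizer sets by name in optimizer.go below.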
42 changes: 42 additions & 0 deletions pkg/merge/aggregator/init.go
@@ -0,0 +1,42 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package aggregator

import (
"fmt"
)

import (
"github.com/arana-db/arana/pkg/merge"
)

var aggregatorMap = make(map[string]func() merge.Aggregator)

func init() {
aggregatorMap["MAX"] = func() merge.Aggregator { return &MaxAggregator{} }
aggregatorMap["MIN"] = func() merge.Aggregator { return &MinAggregator{} }
aggregatorMap["COUNT"] = func() merge.Aggregator { return &AddAggregator{} }
aggregatorMap["SUM"] = func() merge.Aggregator { return &AddAggregator{} }
}

func GetAggFromName(name string) func() merge.Aggregator {
if agg, ok := aggregatorMap[name]; ok {
return agg
}
	panic(fmt.Errorf("aggregator %s not supported yet", name))
}
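A quick sketch (mine, not from the commit) of the registry's contract; any name without an entry, e.g. AVG, currently makes GetAggFromName panic:

package example

import (
	"fmt"
)

import (
	"github.com/arana-db/arana/pkg/merge/aggregator"
)

func sumExample() {
	newSum := aggregator.GetAggFromName("SUM") // factory lookup; panics on unregistered names
	agg := newSum()
	agg.Aggregate([]interface{}{3})
	agg.Aggregate([]interface{}{4})
	if v, ok := agg.GetResult(); ok {
		fmt.Println(v) // 7, assuming AddAggregator handles plain numeric inputs
	}
}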
44 changes: 44 additions & 0 deletions pkg/merge/aggregator/load_agg.go
@@ -0,0 +1,44 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package aggregator

import (
"github.com/arana-db/arana/pkg/merge"
"github.com/arana-db/arana/pkg/runtime/ast"
)

func LoadAggs(fields []ast.SelectElement) map[int]func() merge.Aggregator {
var aggMap = make(map[int]func() merge.Aggregator)
enter := func(i int, n *ast.AggrFunction) {
if n == nil {
return
}
aggMap[i] = GetAggFromName(n.Name())
}

for i, field := range fields {
if field == nil {
continue
}
if f, ok := field.(*ast.SelectElementFunction); ok {
enter(i, f.Function().(*ast.AggrFunction))
}
}

return aggMap
}
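For the test's query `SELECT gender, COUNT(*) AS amount FROM ... GROUP BY gender`, LoadAggs should therefore return a single-entry map {1: COUNT factory}; plain columns such as gender get no entry, which is exactly the gr.AggItems[i] == nil pass-through case in AggregateReducer.Reduce.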
35 changes: 28 additions & 7 deletions pkg/runtime/optimize/optimizer.go
@@ -31,6 +31,7 @@ import (

import (
"github.com/arana-db/arana/pkg/dataset"
"github.com/arana-db/arana/pkg/merge/aggregator"
"github.com/arana-db/arana/pkg/proto"
"github.com/arana-db/arana/pkg/proto/rule"
"github.com/arana-db/arana/pkg/proto/schema_manager"
@@ -490,14 +491,34 @@ func (o optimizer) optimizeSelect(ctx context.Context, conn proto.VConn, stmt *r
}
}

-	// TODO: order/groupBy/aggregate
-	aggregate := &plan.AggregatePlan{
-		Plan:       tmpPlan,
-		Combiner:   transformer.NewCombinerManager(),
-		AggrLoader: transformer.LoadAggrs(stmt.Select),
+	convertOrderByItems := func(origins []*rast.OrderByItem) []dataset.OrderByItem {
+		var result = make([]dataset.OrderByItem, 0, len(origins))
+		for _, origin := range origins {
+			var columnName string
+			if cn, ok := origin.Expr.(rast.ColumnNameExpressionAtom); ok {
+				columnName = cn.Suffix()
+			}
+			result = append(result, dataset.OrderByItem{
+				Column: columnName,
+				Desc:   origin.Desc,
+			})
+		}
+		return result
+	}
+	if stmt.GroupBy != nil {
+		return &plan.GroupPlan{
+			Plan:         tmpPlan,
+			AggItems:     aggregator.LoadAggs(stmt.Select),
+			OrderByItems: convertOrderByItems(stmt.OrderBy),
+		}, nil
+	} else {
+		// TODO: refactor groupby/orderby/aggregate plan to a unified plan
+		return &plan.AggregatePlan{
+			Plan:       tmpPlan,
+			Combiner:   transformer.NewCombinerManager(),
+			AggrLoader: transformer.LoadAggrs(stmt.Select),
+		}, nil
+	}

-	return aggregate, nil
}
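To make the conversion concrete, here is what convertOrderByItems should yield for `... ORDER BY uid DESC, name`, assuming both expressions are plain column atoms (anything else falls through with an empty Column); a sketch, not code from the commit:

package example

import (
	"github.com/arana-db/arana/pkg/dataset"
)

var orderByItems = []dataset.OrderByItem{
	{Column: "uid", Desc: true},
	{Column: "name", Desc: false},
}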

// optimizeJoin only supports a join b in one db
71 changes: 71 additions & 0 deletions pkg/runtime/plan/group_plan.go
@@ -0,0 +1,71 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package plan

import (
"context"
)

import (
"github.com/pkg/errors"
)

import (
"github.com/arana-db/arana/pkg/dataset"
"github.com/arana-db/arana/pkg/merge"
"github.com/arana-db/arana/pkg/proto"
"github.com/arana-db/arana/pkg/resultx"
)

// GroupPlan TODO: for now this only supports statements whose GROUP BY items equal the ORDER BY items, such as
// `select uid, max(score) from student group by uid order by uid`
type GroupPlan struct {
Plan proto.Plan
AggItems map[int]func() merge.Aggregator
OrderByItems []dataset.OrderByItem
}

func (g *GroupPlan) Type() proto.PlanType {
return proto.PlanTypeQuery
}

func (g *GroupPlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Result, error) {
res, err := g.Plan.ExecIn(ctx, conn)
if err != nil {
return nil, errors.WithStack(err)
}

ds, err := res.Dataset()
if err != nil {
return nil, errors.WithStack(err)
}
fields, err := ds.Fields()
if err != nil {
return nil, errors.WithStack(err)
}

return resultx.New(resultx.WithDataset(dataset.Pipe(ds, dataset.GroupReduce(
g.OrderByItems,
func(fields []proto.Field) []proto.Field {
return fields
},
func() dataset.Reducer {
return dataset.NewGroupReducer(g.AggItems, fields)
},
)))), nil
}
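Note that ExecIn leans on the doc comment above GroupDataset ("Should be an orderedDataset"): because shard results arrive already ordered by the group columns, GroupReduce only has to compare adjacent rows against OrderByItems to find group boundaries and fold each group with a fresh reducer from the factory.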
9 changes: 4 additions & 5 deletions pkg/transformer/aggr_loader.go
@@ -22,7 +22,7 @@ import (
)

type AggrLoader struct {
-	Aggrs [][]string
+	Aggrs []string
Alias []string
Name []string
}
@@ -37,9 +37,8 @@ func LoadAggrs(fields []ast2.SelectElement) *AggrLoader {
if n == nil {
return
}
-		fieldAggr := make([]string, 0, 10)
-		fieldAggr = append(fieldAggr, n.Name())
-		aggrLoader.Aggrs = append(aggrLoader.Aggrs, fieldAggr)
+
+		aggrLoader.Aggrs = append(aggrLoader.Aggrs, n.Name())
for _, arg := range n.Args() {
switch arg.Value().(type) {
case ast2.ColumnNameExpressionAtom:
@@ -52,11 +51,11 @@
}

for i, field := range fields {
+		aggrLoader.Alias[i] = field.Alias()
if field == nil {
continue
}
if f, ok := field.(*ast2.SelectElementFunction); ok {
-			aggrLoader.Alias[i] = field.Alias()
enter(f.Function().(*ast2.AggrFunction))
}
}
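Flattening Aggrs from [][]string to []string is safe because each inner slice only ever held the single aggregator name appended in enter; the index changes in combiner.go below follow mechanically.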
4 changes: 2 additions & 2 deletions pkg/transformer/combiner.go
@@ -63,7 +63,7 @@ func (c combinerManager) Merge(result proto.Result, loader *AggrLoader) (proto.R

mergeVals := make([]proto.Value, 0, len(loader.Aggrs))
for i := 0; i < len(loader.Aggrs); i++ {
-		switch loader.Aggrs[i][0] {
+		switch loader.Aggrs[i] {
case ast.AggrAvg:
mergeVals = append(mergeVals, gxbig.NewDecFromInt(0))
case ast.AggrMin:
@@ -118,7 +118,7 @@ func (c combinerManager) Merge(result proto.Result, loader *AggrLoader) (proto.R
return nil, errors.WithStack(err)
}

-			switch loader.Aggrs[aggrIdx][0] {
+			switch loader.Aggrs[aggrIdx] {
case ast.AggrMax:
if dummyVal.Compare(floatDecimal) < 0 {
dummyVal = floatDecimal
