Skip to content

Commit

Permalink
feat: support rate limit filter (#224)
Browse files Browse the repository at this point in the history
* feat: support rate limit filter

* fix: format error log
  • Loading branch information
dk-lockdown committed Aug 12, 2022
1 parent 606ebc7 commit 8d6877d
Show file tree
Hide file tree
Showing 9 changed files with 264 additions and 5 deletions.
1 change: 1 addition & 0 deletions cmd/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ import (
_ "github.com/cectc/dbpack/pkg/filter/crypto"
_ "github.com/cectc/dbpack/pkg/filter/dt"
_ "github.com/cectc/dbpack/pkg/filter/metrics"
_ "github.com/cectc/dbpack/pkg/filter/rate"
dbpackHttp "github.com/cectc/dbpack/pkg/http"
"github.com/cectc/dbpack/pkg/listener"
"github.com/cectc/dbpack/pkg/log"
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ require (
go.etcd.io/etcd/client/v3 v3.5.0-alpha.0
go.uber.org/atomic v1.9.0
go.uber.org/goleak v1.1.11
go.uber.org/ratelimit v0.2.1-0.20220713224938-b62b799bc9a5
go.uber.org/zap v1.21.0
golang.org/x/net v0.0.0-20220225172249-27dd8689420f
golang.org/x/text v0.3.7
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -1273,6 +1273,8 @@ go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKY
go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU=
go.uber.org/multierr v1.7.0 h1:zaiO/rmgFjbmCXdSYJWQcdvOCsthmdaHfr3Gm2Kx4Ec=
go.uber.org/multierr v1.7.0/go.mod h1:7EAYxJLBy9rStEaz58O2t4Uvip6FSURkq8/ppBp95ak=
go.uber.org/ratelimit v0.2.1-0.20220713224938-b62b799bc9a5 h1:ifl9jpZtYB7ZNbdfq0Ac+gwZQ2I+/dQUie3qye1/pIo=
go.uber.org/ratelimit v0.2.1-0.20220713224938-b62b799bc9a5/go.mod h1:So5LG7CV1zWpY1sHe+DXTJqQvOx+FFPFaAs2SnoyBaI=
go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9Ejo0C68/HhF8uaILCdgjnY+goOA=
go.uber.org/zap v1.8.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q=
go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q=
Expand Down
2 changes: 1 addition & 1 deletion pkg/filter/audit_log/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func (factory *_factory) NewFilter(_ string, config map[string]interface{}) (pro
return nil, errors.Wrap(err, "marshal audit log filter config failed.")
}
if err = json.Unmarshal(content, &filterConfig); err != nil {
log.Errorf("unmarshal audit log filter failed, %s", err)
log.Errorf("unmarshal audit log filter failed, %v", err)
return nil, err
}
if filterConfig.MaxSize == 0 {
Expand Down
4 changes: 2 additions & 2 deletions pkg/filter/crypto/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import (
"fmt"
"strings"

"github.com/pingcap/errors"
"github.com/pkg/errors"

"github.com/cectc/dbpack/pkg/constant"
"github.com/cectc/dbpack/pkg/filter"
Expand Down Expand Up @@ -55,7 +55,7 @@ func (factory *_factory) NewFilter(_ string, config map[string]interface{}) (pro
ColumnCryptoList []*ColumnCrypto `yaml:"column_crypto_list" json:"column_crypto_list"`
}{}
if err = json.Unmarshal(content, &v); err != nil {
log.Errorf("unmarshal crypto filter failed, %s", err)
log.Errorf("unmarshal crypto filter failed, %v", err)
return nil, err
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/filter/dt/filter_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ func (factory *_httpFactory) NewFilter(appid string, config map[string]interface
return nil, errors.Wrap(err, "marshal http distributed transaction filter config failed.")
}
if err = json.Unmarshal(content, &filterConfig); err != nil {
log.Errorf("unmarshal http distributed transaction filter failed, %s", err)
log.Errorf("unmarshal http distributed transaction filter failed, %v", err)
return nil, err
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/filter/dt/filter_mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func (factory *_mysqlFactory) NewFilter(appid string, config map[string]interfac
LockRetryTimes int `yaml:"lock_retry_times" json:"lock_retry_times"`
}{}
if err = json.Unmarshal(content, v); err != nil {
log.Errorf("unmarshal mysql distributed transaction filter config failed, %s", err)
log.Errorf("unmarshal mysql distributed transaction filter config failed, %v", err)
return nil, err
}
if v.LockRetryInterval, err = time.ParseDuration(v.LockRetryIntervalStr); err != nil {
Expand Down
151 changes: 151 additions & 0 deletions pkg/filter/rate/limiter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
/*
* Copyright 2022 CECTC, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package rate

import (
"context"
"encoding/json"

"github.com/pkg/errors"
"go.uber.org/ratelimit"

"github.com/cectc/dbpack/pkg/constant"
"github.com/cectc/dbpack/pkg/filter"
"github.com/cectc/dbpack/pkg/log"
"github.com/cectc/dbpack/pkg/proto"
"github.com/cectc/dbpack/third_party/parser/ast"
)

const (
rateLimiterFilter = "RateLimiterFilter"
)

type _factory struct{}

func (factory *_factory) NewFilter(config map[string]interface{}) (proto.Filter, error) {
var (
err error
content []byte
conf *LimiterFilterConfig
insertLimiter ratelimit.Limiter
updateLimiter ratelimit.Limiter
deleteLimiter ratelimit.Limiter
selectLimiter ratelimit.Limiter
)
if content, err = json.Marshal(config); err != nil {
return nil, errors.Wrap(err, "marshal rate limit filter config failed.")
}
if err = json.Unmarshal(content, &conf); err != nil {
log.Errorf("unmarshal rate limit filter failed, %v", err)
return nil, err
}
if conf.InsertLimit != 0 {
insertLimiter = ratelimit.New(conf.InsertLimit)
}
if conf.UpdateLimit != 0 {
updateLimiter = ratelimit.New(conf.UpdateLimit)
}
if conf.DeleteLimit != 0 {
deleteLimiter = ratelimit.New(conf.DeleteLimit)
}
if conf.SelectLimit != 0 {
selectLimiter = ratelimit.New(conf.SelectLimit)
}
return &_filter{
insertLimiter: insertLimiter,
updateLimiter: updateLimiter,
deleteLimiter: deleteLimiter,
selectLimiter: selectLimiter,
}, nil
}

type LimiterFilterConfig struct {
InsertLimit int `yaml:"insert_limit" json:"insert_limit"`
UpdateLimit int `yaml:"update_limit" json:"update_limit"`
DeleteLimit int `yaml:"delete_limit" json:"delete_limit"`
SelectLimit int `yaml:"select_limit" json:"select_limit"`
}

type _filter struct {
insertLimiter ratelimit.Limiter
updateLimiter ratelimit.Limiter
deleteLimiter ratelimit.Limiter
selectLimiter ratelimit.Limiter
}

func (f *_filter) GetKind() string {
return rateLimiterFilter
}

func (f *_filter) PreHandle(ctx context.Context) error {
commandType := proto.CommandType(ctx)
switch commandType {
case constant.ComQuery:
stmt := proto.QueryStmt(ctx)
switch stmt.(type) {
case *ast.InsertStmt:
if f.insertLimiter != nil {
f.insertLimiter.Take()
}
case *ast.UpdateStmt:
if f.updateLimiter != nil {
f.updateLimiter.Take()
}
case *ast.DeleteStmt:
if f.deleteLimiter != nil {
f.deleteLimiter.Take()
}
case *ast.SelectStmt:
if f.selectLimiter != nil {
f.selectLimiter.Take()
}
default:
return nil
}

case constant.ComStmtExecute:
stmt := proto.PrepareStmt(ctx)
if stmt == nil {
return errors.New("prepare stmt should not be nil")
}
switch stmt.StmtNode.(type) {
case *ast.InsertStmt:
if f.insertLimiter != nil {
f.insertLimiter.Take()
}
case *ast.UpdateStmt:
if f.updateLimiter != nil {
f.updateLimiter.Take()
}
case *ast.DeleteStmt:
if f.deleteLimiter != nil {
f.deleteLimiter.Take()
}
case *ast.SelectStmt:
if f.selectLimiter != nil {
f.selectLimiter.Take()
}
default:
return nil
}
}
return nil
}

func init() {
filter.RegistryFilterFactory(rateLimiterFilter, &_factory{})
}
104 changes: 104 additions & 0 deletions pkg/filter/rate/limiter_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/*
* Copyright 2022 CECTC, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package rate

import (
"context"
"strings"
"testing"
"time"

"github.com/stretchr/testify/assert"

"github.com/cectc/dbpack/pkg/constant"
"github.com/cectc/dbpack/pkg/proto"
"github.com/cectc/dbpack/pkg/visitor"
"github.com/cectc/dbpack/third_party/parser"
)

func TestRateLimiter(t *testing.T) {
testCases := []string{
"insert_limit",
"update_limit",
"delete_limit",
"select_limit",
}
testSQLs := []string{
"insert into student (id, name, age) values (1, 'scott', 28)",
"update student set age = 30 where id = 1",
"delete from student where id = 1",
"select id, name, age from student where id = 1",
}
for _, tc := range testCases {
for _, sql := range testSQLs {
t.Run(strings.Join([]string{tc, sql}, "_"), func(t *testing.T) {
p := parser.New()
stmt, err := p.ParseOneStmt(sql, "", "")
assert.Nil(t, err)
stmt.Accept(&visitor.ParamVisitor{})

ctx := proto.WithCommandType(context.Background(), constant.ComQuery)
ctx = proto.WithQueryStmt(ctx, stmt)

filter, err := (&_factory{}).NewFilter(map[string]interface{}{
tc: 1,
})
assert.Nil(t, err)
f := filter.(proto.DBPreFilter)

for i := 0; i < 10; i++ {
err := f.PreHandle(ctx)
assert.Nil(t, err)
t.Log(time.Now())
}
})
}
}

testSQLs = []string{
"insert into student (id, name, age) values (?, ?, ?)",
"update student set age = ? where id = ?",
"delete from student where id = ?",
"select id, name, age from student where id = ?",
}
for _, tc := range testCases {
for _, sql := range testSQLs {
t.Run(strings.Join([]string{tc, sql}, "_"), func(t *testing.T) {
p := parser.New()
stmt, err := p.ParseOneStmt(sql, "", "")
assert.Nil(t, err)
stmt.Accept(&visitor.ParamVisitor{})

protoStmt := &proto.Stmt{StmtNode: stmt}
ctx := proto.WithCommandType(context.Background(), constant.ComStmtExecute)
ctx = proto.WithPrepareStmt(ctx, protoStmt)

filter, err := (&_factory{}).NewFilter(map[string]interface{}{
tc: 1,
})
assert.Nil(t, err)
f := filter.(proto.DBPreFilter)

for i := 0; i < 10; i++ {
err := f.PreHandle(ctx)
assert.Nil(t, err)
t.Log(time.Now())
}
})
}
}
}

0 comments on commit 8d6877d

Please sign in to comment.