Skip to content

Commit

Permalink
feat: audit log filter (#155)
Browse files Browse the repository at this point in the history
  • Loading branch information
dk-lockdown committed Jun 23, 2022
1 parent be37d88 commit e7f819d
Show file tree
Hide file tree
Showing 16 changed files with 310 additions and 16 deletions.
1 change: 1 addition & 0 deletions cmd/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import (
"github.com/cectc/dbpack/pkg/dt/storage/etcd"
"github.com/cectc/dbpack/pkg/executor"
"github.com/cectc/dbpack/pkg/filter"
_ "github.com/cectc/dbpack/pkg/filter/audit_log"
_ "github.com/cectc/dbpack/pkg/filter/dt"
_ "github.com/cectc/dbpack/pkg/filter/metrics"
dbpackHttp "github.com/cectc/dbpack/pkg/http"
Expand Down
18 changes: 16 additions & 2 deletions docker/conf/config_sdb.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,30 @@ data_source_cluster:
filters:
- metricFilter
- mysqlDTFilter
- auditLogFilter

filters:
- name: metricFilter
kind: ConnectionMetricFilter
- name: mysqlDTFilter
kind: MysqlDistributedTransaction
conf:
appid: svc
lock_retry_interval: 50ms
lock_retry_times: 30
- name: metricFilter
kind: ConnectionMetricFilter
- name: auditLogFilter
kind: AuditLogFilter
conf:
audit_log_dir: /var/log/dbpack/
# unit MB
max_size: 300
# unit Day
max_age: 28
# maximum number of old log files to retain
max_backups: 1
# determines if the rotated log files should be compressed using gzip
compress: true
record_before: true

distributed_transaction:
appid: svc
Expand Down
1 change: 1 addition & 0 deletions docker/docker-compose-sdb.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ services:
volumes:
- ./conf/config_sdb.yaml:/config.yaml
- ./scripts/wait-for-mysql.sh:/wait-for-mysql.sh
- /root/:/var/log/:rw
depends_on:
- etcd
- mysql
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -57,5 +57,6 @@ require (
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a // indirect
golang.org/x/tools v0.1.10 // indirect
google.golang.org/protobuf v1.27.1
gopkg.in/natefinch/lumberjack.v2 v2.0.0
k8s.io/apimachinery v0.23.5
)
4 changes: 3 additions & 1 deletion pkg/driver/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,9 @@ func (conn *BackendConnection) ReadQueryResult(wantFields bool) (result *mysql.R
}

result = &mysql.Result{
Fields: make([]*mysql.Field, colNumber),
AffectedRows: affectedRows,
InsertId: lastInsertID,
Fields: make([]*mysql.Field, colNumber),
}

// Read column headers. One packet per column.
Expand Down
2 changes: 1 addition & 1 deletion pkg/executor/single_db.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ func (executor *SingleDBExecutor) ExecutorComQuery(ctx context.Context, sql stri

func (executor *SingleDBExecutor) ExecutorComStmtExecute(ctx context.Context, stmt *proto.Stmt) (proto.Result, uint16, error) {
connectionID := proto.ConnectionID(ctx)
log.Debugf("connectionID: %d, prepare: %s", connectionID, stmt.PrepareStmt)
log.Debugf("connectionID: %d, prepare: %s", connectionID, stmt.SqlText)
txi, ok := executor.localTransactionMap.Load(connectionID)
if ok {
tx := txi.(proto.Tx)
Expand Down
241 changes: 241 additions & 0 deletions pkg/filter/audit_log/filter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
/*
* Copyright 2022 CECTC, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package audit_log

import (
"context"
"encoding/json"
"fmt"
"path/filepath"
"strings"

"github.com/golang-module/carbon"
"github.com/pkg/errors"
"gopkg.in/natefinch/lumberjack.v2"

"github.com/cectc/dbpack/pkg/constant"
"github.com/cectc/dbpack/pkg/filter"
"github.com/cectc/dbpack/pkg/log"
"github.com/cectc/dbpack/pkg/proto"
"github.com/cectc/dbpack/third_party/parser/ast"
)

const (
auditLogFilter = "AuditLogFilter"
defaultMaxSize = 500
defaultMaxBackups = 1
defaultMaxAge = 30
)

type _factory struct {
}

func (factory *_factory) NewFilter(config map[string]interface{}) (proto.Filter, error) {
var (
err error
content []byte
filterConfig *AuditLogFilterConfig
)

if content, err = json.Marshal(config); err != nil {
return nil, errors.Wrap(err, "marshal audit log filter config failed.")
}
if err = json.Unmarshal(content, &filterConfig); err != nil {
log.Errorf("unmarshal audit log filter failed, %s", err)
return nil, err
}
if filterConfig.MaxSize == 0 {
filterConfig.MaxSize = defaultMaxSize
}
if filterConfig.MaxBackups == 0 {
filterConfig.MaxBackups = defaultMaxBackups
}
if filterConfig.MaxAge == 0 {
filterConfig.MaxAge = defaultMaxAge
}
logger := &lumberjack.Logger{
Filename: auditLogFile(filterConfig.AuditLogDir),
MaxSize: filterConfig.MaxSize,
MaxBackups: filterConfig.MaxBackups,
MaxAge: filterConfig.MaxAge,
Compress: filterConfig.Compress,
}
return &_filter{recordBefore: filterConfig.RecordBefore, log: logger}, nil
}

type AuditLogFilterConfig struct {
AuditLogDir string `json:"audit_log_dir" yaml:"audit_log_dir"`
// MaxSize is the maximum size in megabytes of the log file before it gets rotated
MaxSize int `json:"max_size" yaml:"max_size"`
// MaxAge is the maximum number of days to retain old log files
MaxAge int `json:"max_age" yaml:"max_age"`
// MaxBackups maximum number of old log files to retain
MaxBackups int `json:"max_backups" yaml:"max_backups"`
// Compress determines if the rotated log files should be compressed using gzip
Compress bool `json:"compress" yaml:"compress"`
// RecordBefore define whether to log before or after sql execution
RecordBefore bool `json:"record_before" yaml:"record_before"`
}

type _filter struct {
recordBefore bool
log *lumberjack.Logger
}

func (f *_filter) GetKind() string {
return auditLogFilter
}

func (f *_filter) PreHandle(ctx context.Context, conn proto.Connection) error {
if !f.recordBefore {
return nil
}
userName := proto.UserName(ctx)
remoteAddr := proto.RemoteAddr(ctx)
connectionID := proto.ConnectionID(ctx)
commandType := proto.CommandType(ctx)
sqlText := proto.SqlText(ctx)

var (
commandTypeStr string
args strings.Builder
stmt ast.Node
)
args.WriteByte('[')
switch commandType {
case constant.ComQuery:
commandTypeStr = "COM_QUERY"
stmt = proto.QueryStmt(ctx)
case constant.ComStmtExecute:
commandTypeStr = "COM_STMT_EXECUTE"
statement := proto.PrepareStmt(ctx)
stmt = statement.StmtNode
for i := 0; i < len(statement.BindVars); i++ {
parameterID := fmt.Sprintf("v%d", i+1)
param := statement.BindVars[parameterID]
switch arg := param.(type) {
case []byte, string:
args.WriteString(fmt.Sprintf("'%s'", arg))
case nil:
args.WriteString("NULL")
default:
args.WriteString(fmt.Sprintf("'%v'", arg))
}
if i < len(statement.BindVars)-1 {
args.WriteByte(' ')
}
}
default:
return nil
}
args.WriteByte(']')

var command string
switch stmt.(type) {
case *ast.DeleteStmt:
command = "DELETE"
case *ast.InsertStmt:
command = "INSERT"
case *ast.UpdateStmt:
command = "UPDATE"
case *ast.SelectStmt:
command = "SELECT"
default:
}

if _, err := f.log.Write([]byte(fmt.Sprintf("%s,%s,%s,%v,%s,%s,%s,%s,0\n", carbon.Now(), userName, remoteAddr, connectionID,
commandTypeStr, command, sqlText, args.String()))); err != nil {
return err
}
return nil
}

func (f *_filter) PostHandle(ctx context.Context, result proto.Result, conn proto.Connection) error {
if f.recordBefore {
return nil
}
userName := proto.UserName(ctx)
remoteAddr := proto.RemoteAddr(ctx)
connectionID := proto.ConnectionID(ctx)
commandType := proto.CommandType(ctx)
sqlText := proto.SqlText(ctx)

var (
commandTypeStr string
args strings.Builder
stmt ast.Node
)
args.WriteByte('[')
switch commandType {
case constant.ComQuery:
commandTypeStr = "COM_QUERY"
stmt = proto.QueryStmt(ctx)
case constant.ComStmtExecute:
commandTypeStr = "COM_STMT_EXECUTE"
statement := proto.PrepareStmt(ctx)
stmt = statement.StmtNode
for i := 0; i < len(statement.BindVars); i++ {
parameterID := fmt.Sprintf("v%d", i+1)
param := statement.BindVars[parameterID]
switch arg := param.(type) {
case []byte, string:
args.WriteString(fmt.Sprintf("'%s'", arg))
case nil:
args.WriteString("NULL")
default:
args.WriteString(fmt.Sprintf("'%v'", arg))
}
if i < len(statement.BindVars)-1 {
args.WriteByte(' ')
}
}
default:
return nil
}
args.WriteByte(']')

var command string
switch stmt.(type) {
case *ast.DeleteStmt:
command = "DELETE"
case *ast.InsertStmt:
command = "INSERT"
case *ast.UpdateStmt:
command = "UPDATE"
case *ast.SelectStmt:
command = "SELECT"
default:
}

affected, err := result.RowsAffected()
if err != nil {
return err
}
if _, err := f.log.Write([]byte(fmt.Sprintf("%s,%s,%s,%v,%s,%s,%s,%s,%v\n", carbon.Now(), userName, remoteAddr, connectionID,
commandTypeStr, command, sqlText, args.String(), affected))); err != nil {
return err
}
return nil
}

func auditLogFile(dir string) string {
return filepath.Join(dir, "audit.log")
}

func init() {
filter.RegistryFilterFactory(auditLogFilter, &_factory{})
}
2 changes: 1 addition & 1 deletion pkg/filter/dt/exec/prepare_global_lock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ func TestPrepareGlobalLock(t *testing.T) {
ctx := proto.WithCommandType(context.Background(), constant.ComStmtExecute)
protoStmt := &proto.Stmt{
StatementID: 1,
PrepareStmt: c.sql,
SqlText: c.sql,
ParamsCount: 1,
ParamData: nil,
ParamsType: nil,
Expand Down
2 changes: 1 addition & 1 deletion pkg/filter/dt/exec/prepare_select_for_update_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func TestPrepareSelectForUpdate(t *testing.T) {
ctx := proto.WithCommandType(context.Background(), constant.ComStmtExecute)
protoStmt := &proto.Stmt{
StatementID: 1,
PrepareStmt: c.sql,
SqlText: c.sql,
ParamsCount: 1,
ParamData: nil,
ParamsType: nil,
Expand Down
2 changes: 1 addition & 1 deletion pkg/filter/dt/exec/query_global_lock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ func TestQueryGlobalLock(t *testing.T) {
ctx := proto.WithCommandType(context.Background(), constant.ComStmtExecute)
protoStmt := &proto.Stmt{
StatementID: 1,
PrepareStmt: c.sql,
SqlText: c.sql,
ParamsCount: 1,
ParamData: nil,
ParamsType: nil,
Expand Down
2 changes: 1 addition & 1 deletion pkg/filter/dt/exec/query_select_for_update_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func TestQuerySelectForUpdate(t *testing.T) {
ctx := proto.WithCommandType(context.Background(), constant.ComStmtExecute)
protoStmt := &proto.Stmt{
StatementID: 1,
PrepareStmt: c.sql,
SqlText: c.sql,
ParamsCount: 1,
ParamData: nil,
ParamsType: nil,
Expand Down
7 changes: 5 additions & 2 deletions pkg/listener/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ func (l *MysqlListener) handle(conn net.Conn, connectionID uint32) {
ctx := proto.WithVariableMap(context.Background())
ctx = proto.WithConnectionID(ctx, connectionID)
ctx = proto.WithUserName(ctx, c.UserName())
ctx = proto.WithRemoteAddr(ctx, c.RemoteAddr().String())
ctx = proto.WithSchema(ctx, l.schemaName)
newCtx, span := tracing.GetTraceSpan(ctx, "mysql_handle")
err = l.ExecuteCommand(newCtx, c, content)
Expand Down Expand Up @@ -571,6 +572,7 @@ func (l *MysqlListener) ExecuteCommand(ctx context.Context, c *mysql.Conn, data

ctx = proto.WithCommandType(newCtx, commandType)
ctx = proto.WithQueryStmt(ctx, stmt)
ctx = proto.WithSqlText(ctx, query)
result, warn, err := l.executor.ExecutorComQuery(ctx, query)
if err != nil {
if writeErr := c.WriteErrorPacketFromError(err); writeErr != nil {
Expand Down Expand Up @@ -680,10 +682,10 @@ func (l *MysqlListener) ExecuteCommand(ctx context.Context, c *mysql.Conn, data
l.statementID.Inc()
stmt := &proto.Stmt{
StatementID: l.statementID.Load(),
PrepareStmt: query,
SqlText: query,
}
p := parser.New()
act, err := p.ParseOneStmt(stmt.PrepareStmt, "", "")
act, err := p.ParseOneStmt(stmt.SqlText, "", "")
if err != nil {
log.Errorf("Conn %v: Error parsing prepared statement: %v", c, err)
if writeErr := c.WriteErrorPacketFromError(err); writeErr != nil {
Expand Down Expand Up @@ -748,6 +750,7 @@ func (l *MysqlListener) ExecuteCommand(ctx context.Context, c *mysql.Conn, data

ctx = proto.WithCommandType(newCtx, commandType)
ctx = proto.WithPrepareStmt(ctx, stmt)
ctx = proto.WithSqlText(ctx, stmt.SqlText)
result, warn, err := l.executor.ExecutorComStmtExecute(ctx, stmt)
if err != nil {
if writeErr := c.WriteErrorPacketFromError(err); writeErr != nil {
Expand Down

0 comments on commit e7f819d

Please sign in to comment.