Skip to content

Commit

Permalink
add sql context
Browse files Browse the repository at this point in the history
  • Loading branch information
Blank committed Sep 4, 2020
1 parent 5d970a6 commit 046769f
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 21 deletions.
40 changes: 25 additions & 15 deletions adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package sqlxadapter

import (
"bytes"
"context"
"database/sql"
"errors"
"fmt"
Expand Down Expand Up @@ -47,6 +48,7 @@ type CasbinRule struct {
// It can load policy lines or save policy lines from sqlx connected database.
type Adapter struct {
db *sqlx.DB
ctx context.Context
tableName string

isFiltered bool
Expand Down Expand Up @@ -80,6 +82,13 @@ type Filter struct {
// db should connected to database and controlled by user.
// If tableName == "", the Adapter will automatically create a table named 'CASBIN_RULE'.
func NewAdapter(db *sqlx.DB, tableName string) (*Adapter, error) {
return NewAdapterContext(context.Background(), db, tableName)
}

// NewAdapterContext the constructor for Adapter.
// db should connected to database and controlled by user.
// If tableName == "", the Adapter will automatically create a table named 'CASBIN_RULE'.
func NewAdapterContext(ctx context.Context, db *sqlx.DB, tableName string) (*Adapter, error) {
if db == nil {
return nil, errors.New("db is nil")
}
Expand All @@ -91,7 +100,7 @@ func NewAdapter(db *sqlx.DB, tableName string) (*Adapter, error) {
}

// check db connecting
err := db.Ping()
err := db.PingContext(ctx)
if err != nil {
return nil, err
}
Expand All @@ -102,6 +111,7 @@ func NewAdapter(db *sqlx.DB, tableName string) (*Adapter, error) {

adapter := Adapter{
db: db,
ctx: ctx,
tableName: tableName,
}

Expand Down Expand Up @@ -174,7 +184,7 @@ func (p *Adapter) genParams() {
// createTable create a not exists table.
func (p *Adapter) createTable() (err error) {
for _, query := range p.sqlCreateTable {
if _, err = p.db.Exec(query); err != nil {
if _, err = p.db.ExecContext(p.ctx, query); err != nil {
return
}
}
Expand All @@ -184,21 +194,21 @@ func (p *Adapter) createTable() (err error) {

// truncateTable clear the table.
func (p *Adapter) truncateTable() error {
_, err := p.db.Exec(p.sqlTruncateTable)
_, err := p.db.ExecContext(p.ctx, p.sqlTruncateTable)

return err
}

// isTableExist check the table exists.
func (p *Adapter) isTableExist() bool {
_, err := p.db.Query(p.sqlIsTableExist)
_, err := p.db.QueryContext(p.ctx, p.sqlIsTableExist)

return err == nil
}

// deleteRows delete eligible data.
func (p *Adapter) deleteRows(query string, args ...interface{}) error {
_, err := p.db.Exec(query, args...)
_, err := p.db.ExecContext(p.ctx, query, args...)

return err
}
Expand All @@ -209,19 +219,18 @@ func (p *Adapter) truncateAndInsertRows(args [][]interface{}) (err error) {
return
}

tx, err := p.db.Beginx()
tx, err := p.db.BeginTx(p.ctx, nil)
if err != nil {
return
}

var action string
// if _, err = tx.Exec(p.sqlDeleteAll); err != nil {
var sqlBuf bytes.Buffer
// if _, err = tx.ExecContext(p.ctx, p.sqlDeleteAll); err != nil {
// action = "delete all"
// goto ROLLBACK
// }

var sqlBuf bytes.Buffer

for _, arg := range args {
l := len(arg)
if l == 0 {
Expand All @@ -234,8 +243,8 @@ func (p *Adapter) truncateAndInsertRows(args [][]interface{}) (err error) {
sqlBuf.WriteString(" VALUES ")
sqlBuf.Write(p.placeholders[l-1])

if _, err = tx.Exec(sqlBuf.String(), arg...); err != nil {
action = "exec"
if _, err = tx.ExecContext(p.ctx, sqlBuf.String(), arg...); err != nil {
action = "exec context"
goto ROLLBACK
}

Expand All @@ -261,8 +270,9 @@ ROLLBACK:
// selectRows select all data from the table.
func (p *Adapter) selectRows(query string, args ...interface{}) (lines []CasbinRule, err error) {
// make a slice with capacity
lines = make([]CasbinRule, 0, 128)
err = p.db.Select(&lines, query, args...)
lines = make([]CasbinRule, 0, 64)

err = p.db.SelectContext(p.ctx, &lines, query, args...)

return lines, nil
}
Expand Down Expand Up @@ -344,7 +354,7 @@ func (p *Adapter) LoadPolicy(model model.Model) error {

// SavePolicy save policy rules to the storage.
func (p *Adapter) SavePolicy(model model.Model) error {
args := make([][]interface{}, 0, 32)
args := make([][]interface{}, 0, 64)

for ptype, ast := range model["p"] {
for _, rule := range ast.Policy {
Expand Down Expand Up @@ -376,7 +386,7 @@ func (p *Adapter) AddPolicy(sec string, ptype string, rule []string) error {
sqlBuf.WriteString(" VALUES ")
sqlBuf.Write(p.placeholders[idx])

_, err := p.db.Exec(sqlBuf.String(), args...)
_, err := p.db.ExecContext(p.ctx, sqlBuf.String(), args...)

return err
}
Expand Down
22 changes: 16 additions & 6 deletions adapter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@ package sqlxadapter
import (
"database/sql"
"flag"
"strings"
"testing"

"github.com/casbin/casbin/v2"
"github.com/casbin/casbin/v2/util"
"github.com/jmoiron/sqlx"

_ "github.com/mattn/go-oci8"
Expand Down Expand Up @@ -148,7 +148,7 @@ func testSQL(t *testing.T, db *sqlx.DB, tableName string) {
var err error
logSQLErr := func(action string) {
if err != nil {
t.Fatalf("%s test failed, err: %v", action, err)
t.Errorf("%s test failed, err: %v", action, err)
}
}

Expand Down Expand Up @@ -288,7 +288,7 @@ func testAutoSave(t *testing.T, db *sqlx.DB, tableName string) {
var err error
logErr := func(action string) {
if err != nil {
t.Fatalf("%s test failed, err: %v", action, err)
t.Errorf("%s test failed, err: %v", action, err)
}
}

Expand Down Expand Up @@ -348,7 +348,7 @@ func testFilteredPolicy(t *testing.T, db *sqlx.DB, tableName string) {
var err error
logErr := func(action string) {
if err != nil {
t.Fatalf("%s test failed, err: %v", action, err)
t.Errorf("%s test failed, err: %v", action, err)
}
}

Expand Down Expand Up @@ -385,7 +385,17 @@ func testGetPolicy(t *testing.T, e *casbin.Enforcer, res [][]string) {
myRes := e.GetPolicy()
t.Log("Policy: ", myRes)

if !util.Array2DEquals(res, myRes) {
t.Error("Policy: ", myRes, ", supposed to be ", res)
m := make(map[string]struct{}, len(myRes))
for _, record := range myRes {
key := strings.Join(record, ",")
m[key] = struct{}{}
}

for _, record := range res {
key := strings.Join(record, ",")
if _, ok := m[key]; !ok {
t.Error("Policy: ", myRes, ", supposed to be ", res)
break
}
}
}

0 comments on commit 046769f

Please sign in to comment.