Skip to content

Commit

Permalink
specify table name and not to set unique index (#5)
Browse files Browse the repository at this point in the history
  • Loading branch information
JunNishimura committed Mar 30, 2024
1 parent 2e3deed commit fef1a94
Showing 1 changed file with 10 additions and 58 deletions.
68 changes: 10 additions & 58 deletions adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,6 @@ import (
"github.com/uptrace/bun/extra/bundebug"
)

const (
defaultTableName = "casbin_policy"
)

var (
// check if the bunAdapter implements the Adapter interface
_ persist.Adapter = (*bunAdapter)(nil)
Expand All @@ -34,18 +30,11 @@ var (

type bunAdapter struct {
db *bun.DB
tableName string
debugMode bool
}

type adapterOption func(*bunAdapter)

func WithTableName(tableName string) adapterOption {
return func(a *bunAdapter) {
a.tableName = tableName
}
}

func WithDebugMode() adapterOption {
return func(a *bunAdapter) {
a.debugMode = true
Expand Down Expand Up @@ -80,8 +69,7 @@ func newAdapter(driverName, dataSourceName string) (*bunAdapter, error) {
}

return &bunAdapter{
db: db,
tableName: defaultTableName,
db: db,
}, nil
}

Expand Down Expand Up @@ -114,41 +102,20 @@ func connectDB(driverName, dataSourceName string) (*bun.DB, error) {
}

func (a *bunAdapter) createTable() error {
_, err := a.db.NewCreateTable().
if _, err := a.db.NewCreateTable().
Model((*CasbinPolicy)(nil)).
ModelTableExpr(a.tableName).
IfNotExists().
Exec(context.Background())
tableNameForHook = a.tableName // pass the tableName field to the hook function
return err
}

var (
_ bun.AfterCreateTableHook = (*CasbinPolicy)(nil)
// TODO: find a better way to pass the tableName field to the hook function
// Originally, we want to use the value of the tableName field of the bunAdapter
// but hook function cannot access the field of the struct `bunAdapter`
// so we use a global variable to store the value of the tableName field
tableNameForHook = defaultTableName
)

func (*CasbinPolicy) AfterCreateTable(ctx context.Context, query *bun.CreateTableQuery) error {
_, err := query.DB().NewCreateIndex().
Model((*CasbinPolicy)(nil)).
ModelTableExpr(tableNameForHook).
Unique().
Index("idx_ptype_v0_v1_v2_v3_v4_v5").
Column("ptype", "v0", "v1", "v2", "v3", "v4", "v5").
Exec(ctx)
return err
Exec(context.Background()); err != nil {
return err
}
return nil
}

// LoadPolicy loads all policy rules from the storage.
func (a *bunAdapter) LoadPolicy(model model.Model) error {
var policies []CasbinPolicy
err := a.db.NewSelect().
Model(&policies).
ModelTableExpr(a.tableName).
Scan(context.Background())
if err != nil {
return err
Expand Down Expand Up @@ -207,26 +174,21 @@ func (a *bunAdapter) savePolicyRecords(policies []CasbinPolicy) error {
// bulk insert new policies
if _, err := a.db.NewInsert().
Model(&policies).
ModelTableExpr(a.tableName).
Exec(context.Background()); err != nil {
return err
}

return nil
}

// drop and recreate the table
// truncate tables
func (a *bunAdapter) refreshTable() error {
// just truncate the table could be a better choice
// but NewTruncateTable() does not support ModelTableExpr
// so we drop and recreate the table instead
if _, err := a.db.NewDropTable().
ModelTableExpr(a.tableName).
IfExists().
if _, err := a.db.NewTruncateTable().
Model((*CasbinPolicy)(nil)).
Exec(context.Background()); err != nil {
return err
}
return a.createTable()
return nil
}

// AddPolicy adds a policy rule to the storage.
Expand All @@ -235,7 +197,6 @@ func (a *bunAdapter) AddPolicy(sec string, ptype string, rule []string) error {
newPolicy := newCasbinPolicy(ptype, rule)
if _, err := a.db.NewInsert().
Model(&newPolicy).
ModelTableExpr(a.tableName).
Exec(context.Background()); err != nil {
return err
}
Expand All @@ -251,7 +212,6 @@ func (a *bunAdapter) AddPolicies(sec string, ptype string, rules [][]string) err
}
if _, err := a.db.NewInsert().
Model(&policies).
ModelTableExpr(a.tableName).
Exec(context.Background()); err != nil {
return err
}
Expand Down Expand Up @@ -284,7 +244,6 @@ func (a *bunAdapter) RemovePolicies(sec string, ptype string, rules [][]string)

func (a *bunAdapter) deleteRecord(existingPolicy CasbinPolicy) error {
query := a.db.NewDelete().
ModelTableExpr(a.tableName).
Where("ptype = ?", existingPolicy.PType)

values := existingPolicy.filterValuesWithKey()
Expand All @@ -294,7 +253,6 @@ func (a *bunAdapter) deleteRecord(existingPolicy CasbinPolicy) error {

func (a *bunAdapter) deleteRecordInTx(tx bun.Tx, existingPolicy CasbinPolicy) error {
query := tx.NewDelete().
ModelTableExpr(a.tableName).
Where("ptype = ?", existingPolicy.PType)

values := existingPolicy.filterValuesWithKey()
Expand Down Expand Up @@ -327,7 +285,6 @@ func (a *bunAdapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex i

func (a *bunAdapter) deleteFilteredPolicy(ptype string, fieldIndex int, fieldValues ...string) error {
query := a.db.NewDelete().
ModelTableExpr(a.tableName).
Where("ptype = ?", ptype)

// Note that empty string in fieldValues could be any word.
Expand Down Expand Up @@ -398,7 +355,6 @@ func (a *bunAdapter) UpdatePolicy(sec string, ptype string, oldRule, newRule []s
func (a *bunAdapter) updateRecord(oldPolicy, newPolicy CasbinPolicy) error {
query := a.db.NewUpdate().
Model(&newPolicy).
ModelTableExpr(a.tableName).
Where("ptype = ?", oldPolicy.PType)

values := oldPolicy.filterValuesWithKey()
Expand All @@ -409,7 +365,6 @@ func (a *bunAdapter) updateRecord(oldPolicy, newPolicy CasbinPolicy) error {
func (a *bunAdapter) updateRecordInTx(tx bun.Tx, oldPolicy, newPolicy CasbinPolicy) error {
query := tx.NewUpdate().
Model(&newPolicy).
ModelTableExpr(a.tableName).
Where("ptype = ?", oldPolicy.PType)

values := oldPolicy.filterValuesWithKey()
Expand Down Expand Up @@ -463,10 +418,8 @@ func (a *bunAdapter) UpdateFilteredPolicies(sec string, ptype string, newRules [
}

selectQuery := tx.NewSelect().
ModelTableExpr(a.tableName).
Where("ptype = ?", ptype)
deleteQuery := tx.NewDelete().
ModelTableExpr(a.tableName).
Where("ptype = ?", ptype)

// Note that empty string in fieldValues could be any word.
Expand Down Expand Up @@ -551,7 +504,6 @@ func (a *bunAdapter) UpdateFilteredPolicies(sec string, ptype string, newRules [
// create new policies
if _, err := tx.NewInsert().
Model(&newPolicies).
ModelTableExpr(a.tableName).
Exec(context.Background()); err != nil {
if err := tx.Rollback(); err != nil {
return nil, err
Expand Down

0 comments on commit fef1a94

Please sign in to comment.