Permalink
Browse files

Escape query (#23)

* use shared config struct

* escape query for insert update and delete

* escape field on query builder

* wip aggregate function

* fix count on postgres

* test for query

* escape function argument and cache escaped field

* combine limit and offset

* added test for aggregate

* fix tests
  • Loading branch information...
Fs02 committed Jun 2, 2018
1 parent f1bb7a1 commit 77eb2daf614a99a9a98a396296070981e06d4edb
@@ -2,12 +2,12 @@ package grimoire

// Adapter interface
type Adapter interface {
Count(Query, ...Logger) (int, error)
All(Query, interface{}, ...Logger) (int, error)
Delete(Query, ...Logger) error
Aggregate(Query, interface{}, ...Logger) error
Insert(Query, map[string]interface{}, ...Logger) (interface{}, error)
InsertAll(Query, []string, []map[string]interface{}, ...Logger) ([]interface{}, error)
Update(Query, map[string]interface{}, ...Logger) error
Delete(Query, ...Logger) error

Begin() (Adapter, error)
Commit() error
@@ -34,7 +34,16 @@ var _ grimoire.Adapter = (*Adapter)(nil)
func Open(dsn string) (*Adapter, error) {
var err error

adapter := &Adapter{sql.New(errorFunc, incrementFunc, sql.Placeholder("?"))}
adapter := &Adapter{
Adapter: &sql.Adapter{
Config: &sql.Config{
Placeholder: "?",
EscapeChar: "`",
IncrementFunc: incrementFunc,
ErrorFunc: errorFunc,
},
},
}
adapter.DB, err = db.Open("mysql", dsn)

return adapter, err
@@ -61,7 +61,7 @@ func dsn() string {
return "root@(127.0.0.1:3306)/grimoire_test?charset=utf8&parseTime=True&loc=Local"
}

func TestAdapter__specs(t *testing.T) {
func TestAdapter_specs(t *testing.T) {
adapter, err := Open(dsn())
paranoid.Panic(err, "failed to open database connection")
defer adapter.Close()
@@ -76,8 +76,8 @@ func TestAdapter__specs(t *testing.T) {
// Preload specs
specs.Preload(t, repo)

// Count Specs
specs.Count(t, repo)
// Aggregate Specs
specs.Aggregate(t, repo)

// Insert Specs
specs.Insert(t, repo)
@@ -33,10 +33,16 @@ var _ grimoire.Adapter = (*Adapter)(nil)
func Open(dsn string) (*Adapter, error) {
var err error

adapter := &Adapter{sql.New(errorFunc, nil,
sql.Placeholder("$"),
sql.Ordinal(true),
sql.InsertDefaultValues(true)),
adapter := &Adapter{
Adapter: &sql.Adapter{
Config: &sql.Config{
Placeholder: "$",
EscapeChar: "\"",
Ordinal: true,
InsertDefaultValues: true,
ErrorFunc: errorFunc,
},
},
}
adapter.DB, err = db.Open("postgres", dsn)

@@ -45,9 +51,7 @@ func Open(dsn string) (*Adapter, error) {

// Insert inserts a record to database and returns its id.
func (adapter *Adapter) Insert(query grimoire.Query, changes map[string]interface{}, loggers ...grimoire.Logger) (interface{}, error) {
statement, args := sql.NewBuilder(adapter.Placeholder, adapter.Ordinal, adapter.InsertDefaultValues).
Returning("id").
Insert(query.Collection, changes)
statement, args := sql.NewBuilder(adapter.Config).Returning("id").Insert(query.Collection, changes)

var result struct {
ID int64
@@ -59,7 +63,7 @@ func (adapter *Adapter) Insert(query grimoire.Query, changes map[string]interfac

// InsertAll inserts multiple records to database and returns its ids.
func (adapter *Adapter) InsertAll(query grimoire.Query, fields []string, allchanges []map[string]interface{}, loggers ...grimoire.Logger) ([]interface{}, error) {
statement, args := sql.NewBuilder(adapter.Placeholder, adapter.Ordinal, adapter.InsertDefaultValues).Returning("id").InsertAll(query.Collection, fields, allchanges)
statement, args := sql.NewBuilder(adapter.Config).Returning("id").InsertAll(query.Collection, fields, allchanges)

var result []struct {
ID int64
@@ -81,11 +85,8 @@ func (adapter *Adapter) Begin() (grimoire.Adapter, error) {

return &Adapter{
&sql.Adapter{
Placeholder: adapter.Placeholder,
Ordinal: adapter.Ordinal,
IncrementFunc: adapter.IncrementFunc,
ErrorFunc: adapter.ErrorFunc,
Tx: Tx,
Config: adapter.Config,
Tx: Tx,
},
}, err
}
@@ -61,7 +61,7 @@ func dsn() string {
return "postgres://postgres@localhost/grimoire_test?sslmode=disable"
}

func TestAdapter__specs(t *testing.T) {
func TestAdapter_specs(t *testing.T) {
adapter, err := Open(dsn())
paranoid.Panic(err, "failed to open database connection")
defer adapter.Close()
@@ -76,8 +76,8 @@ func TestAdapter__specs(t *testing.T) {
// Preload specs
specs.Preload(t, repo)

// Count Specs
specs.Count(t, repo)
// Aggregate Specs
specs.Aggregate(t, repo)

// Insert Specs
specs.Insert(t, repo)
@@ -8,8 +8,8 @@ import (
"github.com/stretchr/testify/assert"
)

// Count tests count specifications.
func Count(t *testing.T, repo grimoire.Repo) {
// Aggregate tests count specifications.
func Aggregate(t *testing.T, repo grimoire.Repo) {
// preparte tests data
user := User{Name: "name1", Gender: "male", Age: 10}
repo.From(users).MustSave(&user)
@@ -36,12 +36,24 @@ func Count(t *testing.T, repo grimoire.Repo) {
repo.From(users).Where(c.NotLike(name, "noname%")),
repo.From(users).Where(c.Fragment("id > 0")),
repo.From(users).Where(c.Not(c.Eq(id, 1), c.Eq(name, "name1"), c.Eq(age, 10))),
repo.From(users).Group("gender"),
repo.From(users).Group("age").Having(c.Gt(age, 10)),
}

for _, query := range tests {
statement, _ := builder.Find(query.Select("COUNT(*) AS count"))
t.Run("Count|"+statement, func(t *testing.T) {
_, err := query.Count()
field := "*"
if len(query.GroupFields) != 0 {
field = query.GroupFields[0]
}

statement, _ := builder.Find(query.Select(field, "count("+field+") AS sum"))
t.Run("Aggregate|"+statement, func(t *testing.T) {
var out []struct {
Count int
}

err := query.Aggregate("count", field, &out)
assert.True(t, len(out) > 0)
assert.Nil(t, err)
})
}
@@ -55,7 +55,10 @@ const (
address = c.I("address")
)

var builder = sql.NewBuilder("?", false, false)
var builder = sql.NewBuilder(&sql.Config{
Placeholder: "?",
EscapeChar: "`",
})

func assertConstraint(t *testing.T, err error, kind errors.Kind, field string) {
assert.NotNil(t, err)
Oops, something went wrong.

0 comments on commit 77eb2da

Please sign in to comment.