Skip to content

Commit

Permalink
Exports Upsert attributes
Browse files Browse the repository at this point in the history
When dialects are in separate packages, they will need to access the Upsert attributes
  • Loading branch information
cdevienne committed Dec 20, 2016
1 parent 7fbe2ae commit b0d64ed
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 20 deletions.
6 changes: 3 additions & 3 deletions dialect_mysql.go
Expand Up @@ -85,14 +85,14 @@ func (MysqlCompiler) VisitUpsert(context *CompilerContext, upsert UpsertStmt) st
values []string
)

for k, v := range upsert.values {
for k, v := range upsert.ValuesMap {
colNames = append(colNames, context.Compiler.VisitLabel(context, k))
context.Binds = append(context.Binds, v)
values = append(values, "?")
}

updates := []string{}
for k, v := range upsert.values {
for k, v := range upsert.ValuesMap {
updates = append(updates, fmt.Sprintf(
"%s = %s",
context.Dialect.Escape(k),
Expand All @@ -103,7 +103,7 @@ func (MysqlCompiler) VisitUpsert(context *CompilerContext, upsert UpsertStmt) st

sql := fmt.Sprintf(
"INSERT INTO %s(%s)\nVALUES(%s)\nON DUPLICATE KEY UPDATE %s",
context.Dialect.Escape(upsert.table.Name),
context.Dialect.Escape(upsert.Table.Name),
strings.Join(colNames, ", "),
strings.Join(values, ", "),
strings.Join(updates, ", "),
Expand Down
12 changes: 6 additions & 6 deletions dialect_postgres.go
Expand Up @@ -97,14 +97,14 @@ func (PostgresCompiler) VisitUpsert(context *CompilerContext, upsert UpsertStmt)
colNames []string
values []string
)
for k, v := range upsert.values {
for k, v := range upsert.ValuesMap {
colNames = append(colNames, context.Compiler.VisitLabel(context, k))
context.Binds = append(context.Binds, v)
values = append(values, fmt.Sprintf("$%d", len(context.Binds)))
}

var updates []string
for k, v := range upsert.values {
for k, v := range upsert.ValuesMap {
context.Binds = append(context.Binds, v)
updates = append(updates, fmt.Sprintf(
"%s = %s",
Expand All @@ -114,23 +114,23 @@ func (PostgresCompiler) VisitUpsert(context *CompilerContext, upsert UpsertStmt)
}

var uniqueCols []string
for _, c := range upsert.table.PrimaryCols() {
for _, c := range upsert.Table.PrimaryCols() {
uniqueCols = append(uniqueCols, context.Compiler.VisitLabel(context, c.Name))
}

sql := fmt.Sprintf(
"INSERT INTO %s(%s)\nVALUES(%s)\nON CONFLICT (%s) DO UPDATE SET %s",
context.Compiler.VisitLabel(context, upsert.table.Name),
context.Compiler.VisitLabel(context, upsert.Table.Name),
strings.Join(colNames, ", "),
strings.Join(values, ", "),
strings.Join(uniqueCols, ", "),
strings.Join(updates, ", "))

var returning []string
for _, r := range upsert.returning {
for _, r := range upsert.ReturningCols {
returning = append(returning, context.Compiler.VisitLabel(context, r.Name))
}
if len(upsert.returning) > 0 {
if len(returning) > 0 {
sql += fmt.Sprintf(
"RETURNING %s",
strings.Join(returning, ", "),
Expand Down
4 changes: 2 additions & 2 deletions dialect_sqlite.go
Expand Up @@ -83,15 +83,15 @@ func (SqliteCompiler) VisitUpsert(context *CompilerContext, upsert UpsertStmt) s
colNames []string
values []string
)
for k, v := range upsert.values {
for k, v := range upsert.ValuesMap {
colNames = append(colNames, context.Compiler.VisitLabel(context, k))
context.Binds = append(context.Binds, v)
values = append(values, "?")
}

sql := fmt.Sprintf(
"REPLACE INTO %s(%s)\nVALUES(%s)",
context.Compiler.VisitLabel(context, upsert.table.Name),
context.Compiler.VisitLabel(context, upsert.Table.Name),
strings.Join(colNames, ", "),
strings.Join(values, ", "),
)
Expand Down
2 changes: 1 addition & 1 deletion table_test.go
Expand Up @@ -148,7 +148,7 @@ func (suite *TableTestSuite) TestTableStarters() {
assert.Contains(suite.T(), ins.Bindings(), "al@pacino.com")

ups := users.Upsert()
assert.Equal(suite.T(), users, ups.table)
assert.Equal(suite.T(), users, ups.Table)

upd := users.
Update().
Expand Down
16 changes: 8 additions & 8 deletions upsert.go
Expand Up @@ -3,23 +3,23 @@ package qb
// Upsert generates an insert ... on (duplicate key/conflict) update statement
func Upsert(table TableElem) UpsertStmt {
return UpsertStmt{
table: table,
values: map[string]interface{}{},
returning: []ColumnElem{},
Table: table,
ValuesMap: map[string]interface{}{},
ReturningCols: []ColumnElem{},
}
}

// UpsertStmt is the base struct for any insert ... on conflict/duplicate key ... update ... statements
type UpsertStmt struct {
table TableElem
values map[string]interface{}
returning []ColumnElem
Table TableElem
ValuesMap map[string]interface{}
ReturningCols []ColumnElem
}

// Values accepts map[string]interface{} and forms the values map of insert statement
func (s UpsertStmt) Values(values map[string]interface{}) UpsertStmt {
for k, v := range values {
s.values[k] = v
s.ValuesMap[k] = v
}
return s
}
Expand All @@ -28,7 +28,7 @@ func (s UpsertStmt) Values(values map[string]interface{}) UpsertStmt {
// NOTE: Please use it in only postgres dialect, otherwise it'll crash
func (s UpsertStmt) Returning(cols ...ColumnElem) UpsertStmt {
for _, c := range cols {
s.returning = append(s.returning, c)
s.ReturningCols = append(s.ReturningCols, c)
}
return s
}
Expand Down

0 comments on commit b0d64ed

Please sign in to comment.