forked from slicebit/qb
/
dialect_postgres.go
140 lines (120 loc) · 3.65 KB
/
dialect_postgres.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
package qb
import (
"fmt"
"strings"
)
// PostgresDialect is the dialect used with the postgres driver. It is handed
// out as a Dialect by NewPostgresDialect and carries the identifier-escaping
// switch consulted by Escape.
type PostgresDialect struct {
// bindingIndex is set to 0 by NewPostgresDialect; no method visible in this
// file reads or updates it afterwards.
bindingIndex int
// escaping controls whether Escape wraps identifiers in double quotes.
escaping bool
}
// NewPostgresDialect builds a fresh PostgresDialect and returns it as a
// Dialect. The new dialect starts with escaping disabled and a binding index
// of zero.
func NewPostgresDialect() Dialect {
	// The zero value already has escaping == false and bindingIndex == 0.
	return new(PostgresDialect)
}
// init hands the NewPostgresDialect constructor to RegisterDialect under the
// "postgres" name when the package is loaded.
func init() {
	const name = "postgres"
	RegisterDialect(name, NewPostgresDialect)
}
// CompileType renders the DDL for the type t. A type named "BLOB" is emitted
// as "bytea"; every other type is delegated to DefaultCompileType together
// with this dialect's SupportsUnsigned flag.
func (d *PostgresDialect) CompileType(t TypeElem) string {
	switch t.Name {
	case "BLOB":
		return "bytea"
	default:
		return DefaultCompileType(t, d.SupportsUnsigned())
	}
}
// Escape returns str as a double-quoted identifier when escaping is enabled
// on the dialect, and returns str unchanged otherwise.
//
// Any double quote already present in str is doubled, so the result stays a
// single well-formed quoted identifier instead of ending the quote early.
func (d *PostgresDialect) Escape(str string) string {
	if !d.escaping {
		return str
	}
	return `"` + strings.Replace(str, `"`, `""`, -1) + `"`
}
// EscapeAll escapes every element of strings. The work is delegated to the
// package-level escapeAll helper, which receives this dialect and the whole
// slice.
func (d *PostgresDialect) EscapeAll(strings []string) []string {
	// NOTE: the parameter name shadows the strings package inside this method.
	escaped := escapeAll(d, strings)
	return escaped
}
// SetEscaping turns identifier quoting on (true) or off (false); Escape
// consults this flag on every call.
func (d *PostgresDialect) SetEscaping(escaping bool) {
	d.escaping = escaping
}
// Escaping reports whether identifier quoting is currently enabled on the
// dialect.
func (d *PostgresDialect) Escaping() bool {
	return d.escaping
}
// AutoIncrement returns the column specification for an auto-incrementing
// column: "BIGSERIAL" when the column type is named "BIGINT", "SMALLSERIAL"
// for "SMALLINT" and "SERIAL" for anything else. " PRIMARY KEY" is appended
// when the column's InlinePrimaryKey option is set.
func (d *PostgresDialect) AutoIncrement(column *ColumnElem) string {
	serial := "SERIAL"
	switch column.Type.Name {
	case "BIGINT":
		serial = "BIGSERIAL"
	case "SMALLINT":
		serial = "SMALLSERIAL"
	}

	if !column.Options.InlinePrimaryKey {
		return serial
	}
	return serial + " PRIMARY KEY"
}
// SupportsUnsigned reports whether the driver supports unsigned type
// mappings. For this dialect the answer is always false.
func (d *PostgresDialect) SupportsUnsigned() bool {
	return false
}
// Driver returns the driver name of this dialect, which is always "postgres".
func (d *PostgresDialect) Driver() string {
	const driver = "postgres"
	return driver
}
// GetCompiler returns a PostgresCompiler whose embedded SQLCompiler is created
// for this dialect by NewSQLCompiler.
func (d *PostgresDialect) GetCompiler() Compiler {
	base := NewSQLCompiler(d)
	return PostgresCompiler{SQLCompiler: base}
}
// PostgresCompiler is a SQLCompiler specialised for PostgreSQL. It embeds
// SQLCompiler and declares its own VisitBind and VisitUpsert, both of which
// emit numbered "$n" placeholders.
type PostgresCompiler struct {
// SQLCompiler is embedded, so its methods are promoted onto PostgresCompiler
// unless PostgresCompiler declares a method of the same name.
SQLCompiler
}
// VisitBind appends bind.Value to context.Binds and returns the matching
// placeholder "$n", where n is the number of binds collected so far, counting
// this one.
func (PostgresCompiler) VisitBind(context *CompilerContext, bind BindClause) string {
	context.Binds = append(context.Binds, bind.Value)
	position := len(context.Binds)
	return fmt.Sprintf("$%d", position)
}
// VisitUpsert generates
//
//	INSERT INTO <table>(<columns>)
//	VALUES(<placeholders>)
//	ON CONFLICT (<primary key columns>) DO UPDATE SET <column> = <placeholder>, ...
//	RETURNING <columns>
//
// The RETURNING line is only emitted when upsert.ReturningCols is non-empty.
//
// Every value of upsert.ValuesMap is appended to context.Binds twice: once for
// the VALUES list and once for the DO UPDATE SET list. Each occurrence gets
// its own "$n" placeholder, numbered by its position in context.Binds.
//
// NOTE(review): if ValuesMap is a Go map (as its name suggests), column order
// follows map iteration and is not stable between calls. Placeholders and
// binds stay paired regardless, because each value is appended immediately
// before its placeholder is rendered.
func (PostgresCompiler) VisitUpsert(context *CompilerContext, upsert UpsertStmt) string {
	valueCount := len(upsert.ValuesMap)

	// INSERT column list and its VALUES placeholders.
	colNames := make([]string, 0, valueCount)
	values := make([]string, 0, valueCount)
	for k, v := range upsert.ValuesMap {
		colNames = append(colNames, context.Compiler.VisitLabel(context, k))
		context.Binds = append(context.Binds, v)
		values = append(values, fmt.Sprintf("$%d", len(context.Binds)))
	}

	// DO UPDATE SET assignments; the values are bound a second time so the
	// placeholders here are independent of the ones in VALUES.
	updates := make([]string, 0, valueCount)
	for k, v := range upsert.ValuesMap {
		context.Binds = append(context.Binds, v)
		updates = append(updates, fmt.Sprintf(
			"%s = $%d",
			context.Dialect.Escape(k),
			len(context.Binds),
		))
	}

	// The conflict target is the table's primary key columns.
	var uniqueCols []string
	for _, c := range upsert.Table.PrimaryCols() {
		uniqueCols = append(uniqueCols, context.Compiler.VisitLabel(context, c.Name))
	}

	sql := fmt.Sprintf(
		"INSERT INTO %s(%s)\nVALUES(%s)\nON CONFLICT (%s) DO UPDATE SET %s",
		context.Compiler.VisitLabel(context, upsert.Table.Name),
		strings.Join(colNames, ", "),
		strings.Join(values, ", "),
		strings.Join(uniqueCols, ", "),
		strings.Join(updates, ", "))

	var returning []string
	for _, r := range upsert.ReturningCols {
		returning = append(returning, context.Compiler.VisitLabel(context, r.Name))
	}
	if len(returning) > 0 {
		// Start RETURNING on its own line like the other clauses; without a
		// separator it would be glued onto the last "$n" of the SET list.
		sql += fmt.Sprintf(
			"\nRETURNING %s",
			strings.Join(returning, ", "),
		)
	}
	return sql
}