-
Notifications
You must be signed in to change notification settings - Fork 28
/
abstract.go
334 lines (297 loc) · 10.2 KB
/
abstract.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
/*
* Copyright 2018 The Service Manager Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package postgres
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"reflect"
"strings"
"github.com/Peripli/service-manager/pkg/query"
"github.com/jmoiron/sqlx"
sqlxtypes "github.com/jmoiron/sqlx/types"
"github.com/Peripli/service-manager/pkg/log"
"github.com/Peripli/service-manager/pkg/util"
"github.com/fatih/structs"
"github.com/lib/pq"
)
// prepareNamedContext is implemented by values that can prepare a statement
// whose parameters are written as named placeholders (":name", as produced by
// create and updateQuery in this file).
type prepareNamedContext interface {
	PrepareNamedContext(ctx context.Context, query string) (*sqlx.NamedStmt, error)
}

// namedExecerContext is implemented by values that can execute a statement
// with named placeholders, binding their values from arg. update depends only
// on this narrow interface.
type namedExecerContext interface {
	NamedExecContext(ctx context.Context, query string, arg interface{}) (sql.Result, error)
}

// namedQuerierContext is implemented by values that can run a query with
// named placeholders bound from arg and return the resulting rows.
// NOTE(review): despite the "Context" suffix, NamedQuery takes no
// context.Context parameter.
type namedQuerierContext interface {
	NamedQuery(query string, arg interface{}) (*sqlx.Rows, error)
}

// selecterContext is implemented by values offering SelectContext, which
// loads the results of query into dest (presumably sqlx semantics: every row
// into a slice; confirm against sqlx docs).
type selecterContext interface {
	SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error
}

// getterContext is implemented by values offering GetContext, which loads the
// result of query into dest (presumably a single row per sqlx semantics).
type getterContext interface {
	GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error
}
//go:generate counterfeiter . pgDB

// pgDB represents a PG database API. It is the union of
// prepareNamedContext, namedExecerContext, namedQuerierContext,
// selecterContext and getterContext together with sqlx.ExtContext, which
// provides the QueryxContext method used by the list and delete helpers.
//
// listWithLabelsByCriteria type-asserts a pgDB to *sqlx.Tx to decide whether
// the selected rows should be locked, so a transaction is expected to satisfy
// this interface.
type pgDB interface {
	prepareNamedContext
	namedExecerContext
	namedQuerierContext
	selecterContext
	getterContext
	sqlx.ExtContext
}
// create inserts dto as a new row of table and returns the stored value of the
// row's ID column when dto has a field named "ID" (otherwise "").
//
// The inserted columns are the db tags reported by getDBTags, excluding tags
// marked "auto_increment" (those values are computed by the database). Values
// are bound through named placeholders (":column"). When dto has an ID field,
// the statement is prepared with a RETURNING clause on that field's db tag and
// the generated ID is read back.
//
// Errors: a "No fields to insert" error when dto yields no columns; unique
// violations are mapped by checkUniqueViolation and other integrity/syntax
// violations by checkIntegrityViolation; any other error is returned as-is.
func create(ctx context.Context, db pgDB, table string, dto interface{}) (string, error) {
	var lastInsertID string
	setTagType := getDBTags(dto, isAutoIncrementable)
	dbTags := make([]string, 0, len(setTagType))
	for _, tag := range setTagType {
		dbTags = append(dbTags, tag.Tag)
	}
	if len(dbTags) == 0 {
		return lastInsertID, fmt.Errorf("%s insert: No fields to insert", table)
	}
	sqlQuery := fmt.Sprintf(
		"INSERT INTO %s (%s) VALUES(:%s)",
		table,
		strings.Join(dbTags, ", "),
		strings.Join(dbTags, ", :"),
	)

	id, ok := structs.New(dto).FieldOk("ID")
	if ok {
		queryReturningID := fmt.Sprintf("%s Returning %s", sqlQuery, id.Tag("db"))
		log.C(ctx).Debugf("Executing query %s", queryReturningID)
		stmt, err := db.PrepareNamedContext(ctx, queryReturningID)
		if err != nil {
			return "", err
		}
		// A prepared statement holds driver/server resources until it is
		// closed; it is used exactly once, so release it when we return.
		defer func() {
			if closeErr := stmt.Close(); closeErr != nil {
				log.C(ctx).WithError(closeErr).Errorf("Could not close prepared statement")
			}
		}()
		err = stmt.GetContext(ctx, &lastInsertID, dto)
		return lastInsertID, checkIntegrityViolation(ctx, checkUniqueViolation(ctx, err))
	}

	log.C(ctx).Debugf("Executing query %s", sqlQuery)
	_, err := db.NamedExecContext(ctx, sqlQuery, dto)
	return lastInsertID, checkIntegrityViolation(ctx, checkUniqueViolation(ctx, err))
}
// listWithLabelsByCriteria selects rows of baseTableName matching criteria,
// ordered by created_at, and, when label is non-nil, LEFT JOINs the label
// table so each row carries its labels.
//
// Field-query criteria are validated against the db tags of baseEntity before
// any SQL is built. When db is a *sqlx.Tx, the query is suffixed with
// "FOR SHARE of <baseTableName>" so the rows stay locked for the rest of the
// transaction. The caller owns the returned rows and must close them.
func listWithLabelsByCriteria(ctx context.Context, db pgDB, baseEntity interface{}, label PostgresLabel, baseTableName string, criteria []query.Criterion) (*sqlx.Rows, error) {
	// Reflect over the entity once and reuse the result for both validation
	// and query building.
	entityTags := getDBTags(baseEntity, nil)
	if err := validateFieldQueryParams(entityTags, criteria); err != nil {
		return nil, err
	}
	var baseQuery string
	if label == nil {
		baseQuery = constructBaseQueryForEntity(baseTableName)
	} else {
		baseQuery = constructBaseQueryForLabelable(label, baseTableName)
	}
	sqlQuery, queryParams, err := buildQueryWithParams(db, baseQuery, baseTableName, label, criteria, entityTags, " ORDER BY created_at")
	if err != nil {
		return nil, err
	}
	// Lock the rows if we are in transaction so that update operations on those rows can rely on unchanged data
	// This allows us to handle concurrent updates on the same rows by executing them sequentially as
	// before updating we have to anyway select the rows and can therefore lock them
	if _, ok := db.(*sqlx.Tx); ok {
		// Strip the statement terminator (if any) rather than blindly dropping
		// the last byte, so a query without a trailing ';' is not truncated.
		sqlQuery = strings.TrimSuffix(sqlQuery, ";") + fmt.Sprintf(" FOR SHARE of %s;", baseTableName)
	}
	log.C(ctx).Debugf("Executing query %s", sqlQuery)
	return db.QueryxContext(ctx, sqlQuery, queryParams...)
}
// listByFieldCriteria selects all rows of table matching criteria, ordered by
// created_at.
//
// Unlike listWithLabelsByCriteria, no entity is supplied, so criteria are not
// validated against a column list here (buildQueryWithParams receives nil
// tags). The caller owns the returned rows and must close them.
func listByFieldCriteria(ctx context.Context, db pgDB, table string, criteria []query.Criterion) (*sqlx.Rows, error) {
	baseQuery := constructBaseQueryForEntity(table)
	sqlQuery, queryParams, err := buildQueryWithParams(db, baseQuery, table, nil, criteria, nil, " ORDER BY created_at")
	if err != nil {
		return nil, err
	}
	// Log the statement like every other query helper in this file does.
	log.C(ctx).Debugf("Executing query %s", sqlQuery)
	return db.QueryxContext(ctx, sqlQuery, queryParams...)
}
// deleteAllByFieldCriteria deletes every row of table matching criteria and
// returns the deleted rows (via "RETURNING *").
//
// Only field-query criteria are accepted; any other criterion type yields a
// *util.UnsupportedQueryError, as does a field key that is not one of dto's
// db-tagged columns. The caller owns the returned rows and must close them.
func deleteAllByFieldCriteria(ctx context.Context, extContext sqlx.ExtContext, table string, dto interface{}, criteria []query.Criterion) (*sqlx.Rows, error) {
	for _, criterion := range criteria {
		if criterion.Type != query.FieldQuery {
			return nil, &util.UnsupportedQueryError{Message: "conditional delete is only supported for field queries"}
		}
	}
	if err := validateFieldQueryParams(getDBTags(dto, nil), criteria); err != nil {
		return nil, err
	}
	baseQuery := fmt.Sprintf("DELETE FROM %s", table)
	sqlQuery, queryParams, err := buildQueryWithParams(extContext, baseQuery, table, nil, criteria, getDBTags(dto, isAutoIncrementable))
	if err != nil {
		return nil, err
	}
	// Remove the statement terminator (if present) instead of unconditionally
	// chopping the final byte, then append the RETURNING clause.
	sqlQuery = strings.TrimSuffix(sqlQuery, ";") + " RETURNING *;"
	log.C(ctx).Debugf("Executing query %s", sqlQuery)
	return extContext.QueryxContext(ctx, sqlQuery, queryParams...)
}
// validateFieldQueryParams rejects the first field-query criterion whose key
// (LeftOp) does not name a known column. The known columns are taken from
// tags, using the portion of each tag before its first comma. Criteria of any
// other type pass through unchecked.
func validateFieldQueryParams(tags []tagType, criteria []query.Criterion) error {
	knownColumns := make(map[string]bool, len(tags))
	for i := range tags {
		column := strings.SplitN(tags[i].Tag, ",", 2)[0]
		knownColumns[column] = true
	}

	for _, c := range criteria {
		if c.Type != query.FieldQuery {
			continue
		}
		if knownColumns[c.LeftOp] {
			continue
		}
		return &util.UnsupportedQueryError{Message: fmt.Sprintf("unsupported field query key: %s", c.LeftOp)}
	}
	return nil
}
// constructBaseQueryForEntity returns an unfiltered "SELECT * FROM <table>"
// statement for tableName.
func constructBaseQueryForEntity(tableName string) string {
	// Sprint inserts no separator between two string operands.
	return fmt.Sprint("SELECT * FROM ", tableName)
}
// constructBaseQueryForLabelable builds the SELECT used to list a labelable
// entity together with its labels:
//
//	SELECT <base>.*, <labels>.<col> "<labels>.<col>", ... FROM <base>
//	LEFT JOIN <labels> ON <base>.<LabelsPrimaryColumn> = <labels>.<ReferenceColumn>
//
// The label columns are the db tags of labelsEntity, excluding tags marked
// auto_increment. Each one is aliased with its table-qualified name,
// presumably so it cannot collide with a base-table column of the same name.
//
// Identifiers are concatenated directly rather than spliced into a fmt format
// string, so a '%' in a table or column name cannot be misread as a verb.
func constructBaseQueryForLabelable(labelsEntity PostgresLabel, baseTableName string) string {
	labelsTableName := labelsEntity.LabelsTableName()
	labelTags := getDBTags(labelsEntity, isAutoIncrementable)

	columns := make([]string, 0, len(labelTags)+1)
	columns = append(columns, baseTableName+".*")
	for _, dbTag := range labelTags {
		qualified := labelsTableName + "." + dbTag.Tag
		columns = append(columns, qualified+` "`+qualified+`"`)
	}

	return "SELECT " + strings.Join(columns, ", ") +
		" FROM " + baseTableName +
		" LEFT JOIN " + labelsTableName +
		" ON " + baseTableName + "." + labelsEntity.LabelsPrimaryColumn() +
		" = " + labelsTableName + "." + labelsEntity.ReferenceColumn()
}
// update writes dto's columns to the row of table identified by dto's id
// (see updateQuery). If dto has no updatable columns, nothing is executed and
// nil is returned. Database errors are passed through checkUniqueViolation and
// checkIntegrityViolation; a statement that touches no row yields the error
// from checkRowsAffected.
func update(ctx context.Context, db namedExecerContext, table string, dto interface{}) error {
	stmt := updateQuery(table, dto)
	if len(stmt) == 0 {
		log.C(ctx).Debugf("%s update: Nothing to update", table)
		return nil
	}

	log.C(ctx).Debugf("Executing query %s", stmt)
	res, execErr := db.NamedExecContext(ctx, stmt, dto)
	if err := checkIntegrityViolation(ctx, checkUniqueViolation(ctx, execErr)); err != nil {
		return err
	}
	return checkRowsAffected(ctx, res)
}
func isAutoIncrementable(tagValue string) bool {
// auto_increment states that the value will be calculated in the DB
return strings.Contains(tagValue, "auto_increment")
}
// tagType describes one column derived from a struct field by getTags.
type tagType struct {
	// Tag is the field's `db` tag value, or the lower-cased field name when
	// the tag is empty. It may carry comma-separated options;
	// validateFieldQueryParams uses only the part before the first comma.
	Tag string
	// Type is the reflect.Type of the value the field held when inspected.
	Type reflect.Type
}
// getDBTags returns one tagType per column-bearing field of structure,
// flattening embedded structs (see getTags). predicate, when non-nil, is
// called with each db tag and excludes the field when it returns true; a nil
// predicate excludes nothing.
func getDBTags(structure interface{}, predicate func(string) bool) []tagType {
	skip := predicate
	if skip == nil {
		skip = func(string) bool { return false }
	}

	fields := structs.New(structure).Fields()
	result := make([]tagType, 0, len(fields))
	getTags(fields, &result, skip)
	return result
}
// getTags appends to *set one tagType per field in fields, in order, recursing
// into embedded fields so their fields are flattened into the same list.
//
// A field is skipped when it is a nil pointer, when its db tag is "-", or when
// predicate(dbTag) returns true. A field without a db tag is reported under
// its lower-cased Go name.
func getTags(fields []*structs.Field, set *[]tagType, predicate func(string) bool) {
	for _, field := range fields {
		// A nil pointer holds no value to read or write, so it contributes no column.
		if field.Kind() == reflect.Ptr && field.IsZero() {
			continue
		}
		if field.IsEmbedded() {
			// Embedded fields contribute their own fields directly, in order.
			getTags(field.Fields(), set, predicate)
			continue
		}
		dbTag := field.Tag("db")
		if dbTag == "-" || predicate(dbTag) {
			continue
		}
		if dbTag == "" {
			dbTag = strings.ToLower(field.Name())
		}
		*set = append(*set, tagType{
			Tag: dbTag,
			// reflect.TypeOf yields nil for a nil interface value, whereas
			// reflect.ValueOf(nil).Type() panics.
			Type: reflect.TypeOf(field.Value()),
		})
	}
}
// updateQuery builds "UPDATE <table> SET col = :col, ... WHERE id = :id" for
// every db-tagged column of structure except those marked auto_increment. It
// returns "" when there is nothing to set.
//
// NOTE(review): the WHERE clause binds ":id", so structure is assumed to
// expose a column named "id"; confirm for every caller.
func updateQuery(tableName string, structure interface{}) string {
	dbTags := getDBTags(structure, isAutoIncrementable)
	if len(dbTags) == 0 {
		return ""
	}
	set := make([]string, 0, len(dbTags))
	for _, dbTag := range dbTags {
		set = append(set, fmt.Sprintf("%s = :%s", dbTag.Tag, dbTag.Tag))
	}
	// Pass tableName as an argument instead of concatenating it into the
	// format string, so a '%' in it cannot be interpreted as a verb.
	return fmt.Sprintf("UPDATE %s SET %s WHERE id = :id", tableName, strings.Join(set, ", "))
}
func checkUniqueViolation(ctx context.Context, err error) error {
if err == nil {
return nil
}
sqlErr, ok := err.(*pq.Error)
if ok && sqlErr.Code.Name() == "unique_violation" {
log.C(ctx).Debug(sqlErr)
return util.ErrAlreadyExistsInStorage
}
return err
}
// checkIntegrityViolation wraps a *pq.Error in *util.ErrBadRequestStorage when
// its SQLSTATE class is 23 (integrity constraint violation), 42 (syntax error
// or access rule violation) or 44 (WITH CHECK OPTION violation), logging it at
// debug level. nil and every other error are returned unchanged.
func checkIntegrityViolation(ctx context.Context, err error) error {
	pqErr, isPQ := err.(*pq.Error)
	if !isPQ {
		return err
	}
	switch pqErr.Code.Class() {
	case "23", "42", "44":
		log.C(ctx).Debug(pqErr)
		return &util.ErrBadRequestStorage{Cause: err}
	}
	return err
}
// closeRows closes rows if non-nil and logs, rather than returns, any error
// from Close, which makes it suitable for use in a defer.
func closeRows(ctx context.Context, rows *sqlx.Rows) {
	if rows == nil {
		return
	}
	closeErr := rows.Close()
	if closeErr == nil {
		return
	}
	log.C(ctx).WithError(closeErr).Errorf("Could not release connection")
}
// checkRowsAffected returns util.ErrNotFoundInStorage when result reports no
// affected rows, propagates any error from RowsAffected, and otherwise logs
// the count at debug level and returns nil. result must be non-nil.
func checkRowsAffected(ctx context.Context, result sql.Result) error {
	affected, err := result.RowsAffected()
	switch {
	case err != nil:
		return err
	case affected < 1:
		return util.ErrNotFoundInStorage
	}
	log.C(ctx).Debugf("Operation affected %d rows", affected)
	return nil
}
// checkSQLNoRows translates sql.ErrNoRows into util.ErrNotFoundInStorage and
// returns every other error (including nil) untouched.
func checkSQLNoRows(err error) error {
	if err != sql.ErrNoRows {
		return err
	}
	return util.ErrNotFoundInStorage
}
func toNullString(s string) sql.NullString {
return sql.NullString{String: s, Valid: s != ""}
}
// getJSONText converts item to sqlxtypes.JSONText, replacing the literal JSON
// null with an empty object "{}". Any other input, including nil or empty,
// is passed through as-is.
func getJSONText(item json.RawMessage) sqlxtypes.JSONText {
	if string(item) != "null" {
		return sqlxtypes.JSONText(item)
	}
	return sqlxtypes.JSONText("{}")
}
// getJSONRawMessage converts a stored sqlxtypes.JSONText back to
// json.RawMessage, returning nil for an empty object "{}" or a JSON null so
// those values are treated as absent. Any other content is returned unchanged.
func getJSONRawMessage(item sqlxtypes.JSONText) json.RawMessage {
	switch string(item) {
	case "{}", "null":
		return nil
	default:
		return json.RawMessage(item)
	}
}