forked from stellar/go
-
Notifications
You must be signed in to change notification settings - Fork 0
/
helpers.go
188 lines (165 loc) · 6.14 KB
/
helpers.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
package tickerdb
import (
"errors"
"fmt"
"reflect"
"strings"
"github.com/aiblocks/go/services/ticker/internal/utils"
)
// getDBFieldTags returns the double-quoted "db" struct tags of model, in field
// declaration order.
//
// model may be a struct or a pointer to a struct. Fields tagged `db:"-"` are
// always skipped, and the field tagged `db:"id"` is skipped when excludeID is
// true. A field without a "db" tag contributes an empty tag value, exactly as
// getDBFieldValues contributes its value, so both functions stay positionally
// aligned.
//
// The function panics if model is nil or is not a struct / pointer to struct.
func getDBFieldTags(model interface{}, excludeID bool) (fields []string) {
	t := reflect.TypeOf(model)
	// Only type information is needed here, so a pointer (even a nil one)
	// can be resolved to its element type instead of panicking in NumField.
	if t.Kind() == reflect.Ptr {
		t = t.Elem()
	}

	n := t.NumField()
	fields = make([]string, 0, n)
	for i := 0; i < n; i++ {
		tag := t.Field(i).Tag.Get("db")
		if tag == "-" || (excludeID && tag == "id") {
			continue
		}
		fields = append(fields, tag)
	}
	return sanitizeFieldNames(fields)
}
// sanitizeFieldNames wraps every entry of fieldNames in double quotes and
// returns the result as a new slice; the input slice is left untouched.
// An empty or nil input yields a nil slice. Quotes already present inside a
// name are not escaped.
func sanitizeFieldNames(fieldNames []string) (sanitizedFields []string) {
	if len(fieldNames) == 0 {
		return nil
	}
	sanitizedFields = make([]string, len(fieldNames))
	for idx := range fieldNames {
		sanitizedFields[idx] = fmt.Sprintf("\"%s\"", fieldNames[idx])
	}
	return sanitizedFields
}
// getDBFieldValues returns the values of model's fields in declaration order,
// applying the same filter as getDBFieldTags so that values and column names
// line up one-to-one.
//
// model may be a struct or a non-nil pointer to a struct. Fields tagged
// `db:"-"` are always skipped, and the field tagged `db:"id"` is skipped when
// excludeID is true. The result is nil when no field is selected.
//
// The function panics if model is nil, a nil pointer, not a struct, or if a
// selected (non-skipped) field is unexported.
func getDBFieldValues(model interface{}, excludeID bool) (values []interface{}) {
	// Indirect lets callers pass either a struct or a pointer to one.
	r := reflect.Indirect(reflect.ValueOf(model))
	t := r.Type()
	for i := 0; i < t.NumField(); i++ {
		tag := t.Field(i).Tag.Get("db")
		if tag == "-" || (excludeID && tag == "id") {
			continue
		}
		// Read the value only after the skip check: Interface() panics on
		// unexported fields, which must be tolerated when they are ignored.
		values = append(values, r.Field(i).Interface())
	}
	return values
}
// createOnConflictFragment generates an ON CONFLICT SQL clause for the given
// constraint.
//
// Each entry of fields is overwritten with the incoming row's value
// ("field = EXCLUDED.field"), with assignments separated by commas. Columns
// not listed in fields are left as they are in the existing row.
//
// When fields is empty there is nothing to overwrite, so the clause becomes
// "ON CONFLICT ON CONSTRAINT <constraint> DO NOTHING"; a "DO UPDATE SET" with
// no assignments would be invalid SQL.
//
// constraint and fields are interpolated into the SQL text verbatim, so
// callers must pass trusted, already-quoted identifiers.
func createOnConflictFragment(constraint string, fields []string) (fragment string) {
	if len(fields) == 0 {
		return fmt.Sprintf("ON CONFLICT ON CONSTRAINT %s DO NOTHING", constraint)
	}

	fragment = fmt.Sprintf("ON CONFLICT ON CONSTRAINT %s DO UPDATE SET ", constraint)
	for i, field := range fields {
		if i > 0 {
			fragment += ","
		}
		fragment += fmt.Sprintf("%s = EXCLUDED.%s", field, field)
	}
	return fragment
}
// generatePlaceholders returns one "?" per entry of fields, separated by
// ", " (for example "?, ?, ?" for three entries). The result carries no
// surrounding parentheses and is the empty string when fields is empty.
// Only the length of fields matters; its contents are never inspected.
func generatePlaceholders(fields []interface{}) (placeholder string) {
	count := len(fields)
	if count == 0 {
		return ""
	}

	// One byte for the first "?" plus three bytes (", ?") for each further one.
	buf := make([]byte, 0, 3*count-2)
	buf = append(buf, '?')
	for k := 1; k < count; k++ {
		buf = append(buf, ", ?"...)
	}
	return string(buf)
}
// optionalVar represents a query variable that should only be used in a
// statement if its value is not nil. The WHERE-clause builders in this file
// emit a "<name> = ?" condition for every optionalVar whose val is set and
// skip the ones whose val is nil.
type optionalVar struct {
// name is the left-hand side of the generated "name = ?" condition. It is
// written into the SQL text as-is.
name string
// val is the value returned for the matching "?" placeholder; nil means
// the condition is left out of the clause entirely.
val *string
}
// generateWhereClause builds a clause of the form
// "WHERE x = ? AND y = ? AND ..." containing one condition for every entry of
// optVars whose val is not nil, in the order given. The dereferenced values
// are returned in args, in the same order as their placeholders, so that the
// caller can hand them to the database driver as bind parameters instead of
// splicing them into the SQL text.
//
// If no entry has a value, clause is the empty string and args is nil.
// Note that ov.name is written into the clause verbatim.
func generateWhereClause(optVars []optionalVar) (clause string, args []string) {
	for _, ov := range optVars {
		if ov.val == nil {
			continue
		}

		switch clause {
		case "":
			// First active condition opens the clause.
			clause = fmt.Sprintf("WHERE %s = ?", ov.name)
		default:
			clause += fmt.Sprintf(" AND %s = ?", ov.name)
		}
		args = append(args, *ov.val)
	}
	return clause, args
}
// generateWhereClauseWithOrs builds a clause of the form
// "WHERE (a = ? AND b = ? ... OR x = ? AND y = ? ...)".
//
// Every inner list of optVarLists becomes one group of AND-ed conditions
// (one condition per entry whose val is not nil), and the groups are joined
// with OR inside a single pair of parentheses. Individual groups are not
// parenthesised; the grouping relies on AND binding more tightly than OR.
//
// Groups that end up with no condition (an empty list, or a list whose vals
// are all nil) are skipped, since they would otherwise render as an empty
// operand such as "( OR x = ?)". If no group has any condition, clause is the
// empty string and args is nil.
//
// The dereferenced values are returned in args in placeholder order so the
// caller can pass them to the database driver as bind parameters. ov.name is
// written into the clause verbatim.
func generateWhereClauseWithOrs(optVarLists [][]optionalVar) (clause string, args []string) {
	groups := make([]string, 0, len(optVarLists))
	for _, ovl := range optVarLists {
		conditions := make([]string, 0, len(ovl))
		for _, ov := range ovl {
			if ov.val == nil {
				continue
			}
			conditions = append(conditions, fmt.Sprintf("%s = ?", ov.name))
			args = append(args, *ov.val)
		}

		// A group without active conditions contributes nothing to the
		// query; emitting it would produce malformed SQL.
		if len(conditions) == 0 {
			continue
		}
		groups = append(groups, strings.Join(conditions, " AND "))
	}

	if len(groups) == 0 {
		return "", args
	}
	clause = fmt.Sprintf("WHERE (%s)", strings.Join(groups, " OR "))
	return clause, args
}
// getBaseAndCounterCodes parses an asset pair name such as "DLO_BTC" and
// returns its two asset codes as (base, counter).
//
// The name must consist of exactly two parts separated by a single "_";
// otherwise an error is returned. The parts are ordered as follows:
//  1. If either part is DLO, DLO is the base asset.
//  2. Otherwise the parts are sorted in ascending string order.
func getBaseAndCounterCodes(pairName string) (string, string, error) {
	parts := strings.Split(pairName, "_")
	if len(parts) != 2 {
		return "", "", errors.New("invalid asset pair name")
	}

	first, second := parts[0], parts[1]
	switch {
	case second == "DLO":
		// DLO was given as counter: flip so it becomes the base.
		return second, first, nil
	case first == "DLO":
		return first, second, nil
	case first > second:
		return second, first, nil
	default:
		return first, second, nil
	}
}
// orderBaseAndCounter takes the user-provided base and counter asset codes
// and issuers and returns them as (baseCode, baseIssuer, counterCode,
// counterIssuer), ordered according to the following rules:
//  1. If either code is DLO, DLO is the base asset.
//  2. Otherwise the asset with the smaller code (string comparison) is the base.
//
// Each issuer always travels with its code. If baseCode or counterCode is
// nil, the arguments are returned unchanged.
func orderBaseAndCounter(
	baseCode *string,
	baseIssuer *string,
	counterCode *string,
	counterIssuer *string,
) (*string, *string, *string, *string) {
	if baseCode == nil || counterCode == nil {
		return baseCode, baseIssuer, counterCode, counterIssuer
	}

	var swap bool
	switch {
	case *counterCode == "DLO":
		swap = true
	case *baseCode == "DLO":
		swap = false
	default:
		swap = *baseCode > *counterCode
	}

	if swap {
		return counterCode, counterIssuer, baseCode, baseIssuer
	}
	return baseCode, baseIssuer, counterCode, counterIssuer
}
// performUpsertQuery introspects dbStruct via its "db" struct tags and runs
// an "INSERT INTO <tableName> (...) VALUES (...) ON CONFLICT ..." statement
// for it through s.ExecRaw.
//
// The column tagged "id" is left out of the insert. When conflictConstraint
// is violated, the existing row is updated instead; the columns named in
// preserveFields are excluded from that update and keep their old values.
// Field values are passed to ExecRaw as arguments for the "?" placeholders,
// whereas tableName, conflictConstraint and the column names are written
// into the SQL text directly.
//
// It returns the error reported by ExecRaw, if any.
func (s *TickerSession) performUpsertQuery(dbStruct interface{}, tableName string, conflictConstraint string, preserveFields []string) error {
	columns := getDBFieldTags(dbStruct, true)
	values := getDBFieldValues(dbStruct, true)

	// Column names are quoted by getDBFieldTags, so preserveFields must be
	// quoted the same way before the two lists are compared.
	// NOTE(review): presumably SliceDiff returns the entries of columns that
	// are absent from the preserved list; confirm against utils.SliceDiff.
	updatable := utils.SliceDiff(columns, sanitizeFieldNames(preserveFields))

	query := fmt.Sprintf("INSERT INTO %s (", tableName) + strings.Join(columns, ", ") + ")" +
		" VALUES (" + generatePlaceholders(values) + ")" +
		" " + createOnConflictFragment(conflictConstraint, updatable) + ";"

	_, err := s.ExecRaw(query, values...)
	return err
}