-
Notifications
You must be signed in to change notification settings - Fork 3
/
postgresql.go
132 lines (111 loc) · 3.09 KB
/
postgresql.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
/*
Copyright 2020 Adevinta
*/
package postgresql
import (
"bytes"
"fmt"
"strings"
"text/template"
"github.com/adevinta/vulnerability-db-api/pkg/storage"
"github.com/jmoiron/sqlx"
"github.com/labstack/echo/v4"
"github.com/labstack/gommon/log"
_ "github.com/lib/pq" // Import the PostgreSQL driver.
)
// dateUpBoundOffset is the time-of-day suffix that turns a bare date into the
// last second of that day, so a date-only filter value can serve as an
// inclusive upper bound when compared against a timestamp field.
// NOTE(review): presumably appended to a "YYYY-MM-DD" string by callers in
// other files of this package — confirm.
const dateUpBoundOffset = " 23:59:59"
// DB groups the PostgreSQL connection pools and the logger used by this package.
type DB struct {
	// DB is the pool opened from the read connection settings.
	DB *sqlx.DB
	// DBRw is the pool opened from the read-write connection settings. NewDB
	// may make it the very same pool as DB instead of opening a second one.
	DBRw *sqlx.DB
	// Logger is the logger supplied to NewDB.
	Logger echo.Logger
}

// ConnStr holds the parameters needed to reach a PostgreSQL server. The toml
// struct tags name the key each field is read from when decoded from TOML.
type ConnStr struct {
	Host    string `toml:"host"`
	Port    string `toml:"port"`
	User    string `toml:"user"`
	Pass    string `toml:"pass"`
	DB      string `toml:"db"`
	SSLMode string `toml:"sslmode"` // NewDB substitutes "disable" when empty.
}
// NewDB opens the PostgreSQL connection pools described by cs (read-write)
// and csRead (read) and returns them, together with logger, wrapped in a DB.
//
// An empty SSLMode in either ConnStr defaults to "disable". When cs and csRead
// describe exactly the same connection, a single pool is opened and shared by
// DB.DB and DB.DBRw; otherwise a second pool is opened for DB.DBRw. If a
// connection cannot be established, every pool opened so far is closed and a
// wrapped error is returned.
func NewDB(cs ConnStr, csRead ConnStr, logger echo.Logger) (*DB, error) {
	if cs.SSLMode == "" {
		cs.SSLMode = "disable"
	}
	if csRead.SSLMode == "" {
		csRead.SSLMode = "disable"
	}

	dbRead, err := sqlx.Connect("postgres", pqConnString(csRead))
	if err != nil {
		return nil, fmt.Errorf("connecting to read database: %w", err)
	}

	// Reuse the read pool only when both configurations are identical.
	// Matching on host and port alone would silently route writes through the
	// read user, database or sslmode whenever only those differ.
	if cs == csRead {
		return &DB{DB: dbRead, DBRw: dbRead, Logger: logger}, nil
	}

	dbReadWrite, err := sqlx.Connect("postgres", pqConnString(cs))
	if err != nil {
		// Don't leak the read pool opened above. Closing is best effort; the
		// connect error is the one the caller needs to see.
		_ = dbRead.Close()
		return nil, fmt.Errorf("connecting to read-write database: %w", err)
	}
	return &DB{DB: dbRead, DBRw: dbReadWrite, Logger: logger}, nil
}

// pqConnString renders cs as a key/value connection string for the
// "postgres" (lib/pq) driver. Every value is wrapped in single quotes with
// backslashes and single quotes backslash-escaped, so values containing
// spaces, quotes or backslashes (typically passwords) are parsed intact.
// Empty fields are omitted rather than emitted as "key=", which the key/value
// syntax would otherwise read as taking the following "key=value" token as
// its value; omitted keys are left to the driver's defaults.
func pqConnString(cs ConnStr) string {
	escape := strings.NewReplacer(`\`, `\\`, `'`, `\'`)
	params := []struct{ key, value string }{
		{"host", cs.Host},
		{"port", cs.Port},
		{"user", cs.User},
		{"password", cs.Pass},
		{"dbname", cs.DB},
		{"sslmode", cs.SSLMode},
	}
	parts := make([]string, 0, len(params))
	for _, p := range params {
		if p.value == "" {
			continue
		}
		parts = append(parts, p.key+"='"+escape.Replace(p.value)+"'")
	}
	return strings.Join(parts, " ")
}
// filterTemplate renders queryTemplate, a text/template source, using filter
// as the template data, and returns the resulting SQL text. A parse or
// execution error is returned unchanged alongside an empty string.
//
// NOTE(review): text/template performs no SQL escaping. Templates should only
// select clauses and leave values to bind placeholders — confirm against the
// templates passed in by callers.
func filterTemplate(queryTemplate string, filter storage.Filter) (string, error) {
	var rendered bytes.Buffer
	tmpl, parseErr := template.New("query").Parse(queryTemplate)
	if parseErr != nil {
		return "", parseErr
	}
	if execErr := tmpl.Execute(&rendered, filter); execErr != nil {
		return "", execErr
	}
	return rendered.String(), nil
}
// logQuery emits query at debug level on logger, labelled with name. Before
// logging, tabs are stripped, newlines become spaces, and the bind arguments
// are inlined via buildQueryWithArgs. Nothing is computed unless the logger's
// level is exactly log.DEBUG.
func logQuery(logger echo.Logger, name, query string, args ...interface{}) {
	if logger.Level() == log.DEBUG {
		oneLine := strings.NewReplacer("\t", "", "\n", " ").Replace(query)
		logger.Debugf("%s query: %s", name, buildQueryWithArgs(oneLine, args))
	}
}
// buildQueryWithArgs returns query with its bind placeholders replaced by the
// %v text of the matching argument. It exists only to make debug log lines
// readable: values are neither quoted nor escaped, so the result is not valid
// SQL and must never be executed.
//
// When args[0] is a map[string]interface{}, named placeholders of the form
// :name (ASCII letters, digits and underscores) are looked up in that map.
// Otherwise PostgreSQL positional placeholders $1, $2, ... are replaced by
// args[0], args[1], ....
//
// Placeholders are matched as whole tokens in a single left-to-right pass.
// As a result:
//   - :id never clobbers the prefix of :id_user;
//   - $1 never clobbers the prefix of $10;
//   - every occurrence of a repeated placeholder is replaced;
//   - text coming from a substituted value is never rescanned;
//   - the output is deterministic.
//
// Placeholders without a matching argument, and doubled markers such as the
// "::" of a PostgreSQL cast, are copied through unchanged.
func buildQueryWithArgs(query string, args []interface{}) string {
	if len(args) == 0 {
		return query
	}
	if named, ok := args[0].(map[string]interface{}); ok {
		// args as map: ":name" placeholders.
		isNameByte := func(c byte) bool {
			return c == '_' ||
				('a' <= c && c <= 'z') ||
				('A' <= c && c <= 'Z') ||
				('0' <= c && c <= '9')
		}
		return replaceQueryPlaceholders(query, ':', isNameByte, func(name string) (interface{}, bool) {
			v, found := named[name]
			return v, found
		})
	}
	// args as variadic args list: 1-based "$N" placeholders.
	isDigit := func(c byte) bool { return '0' <= c && c <= '9' }
	return replaceQueryPlaceholders(query, '$', isDigit, func(digits string) (interface{}, bool) {
		n := 0
		for i := 0; i < len(digits); i++ {
			n = n*10 + int(digits[i]-'0')
			if n > len(args) {
				// Out of range; bailing out early also rules out int overflow.
				return nil, false
			}
		}
		if n == 0 {
			return nil, false
		}
		return args[n-1], true
	})
}

// replaceQueryPlaceholders scans query once, left to right. A placeholder is
// the marker byte followed by a non-empty maximal run of bytes accepted by
// isNameByte. When lookup resolves its name, the placeholder is replaced by
// the %v text of the resolved value; otherwise it is copied verbatim. A
// doubled marker ("::", "$$") is copied verbatim and never starts a
// placeholder.
func replaceQueryPlaceholders(query string, marker byte, isNameByte func(byte) bool, lookup func(name string) (interface{}, bool)) string {
	var b strings.Builder
	b.Grow(len(query))
	for i := 0; i < len(query); {
		if query[i] != marker {
			b.WriteByte(query[i])
			i++
			continue
		}
		if i+1 < len(query) && query[i+1] == marker {
			b.WriteString(query[i : i+2])
			i += 2
			continue
		}
		end := i + 1
		for end < len(query) && isNameByte(query[end]) {
			end++
		}
		if end > i+1 {
			if v, ok := lookup(query[i+1 : end]); ok {
				fmt.Fprintf(&b, "%v", v)
				i = end
				continue
			}
		}
		b.WriteString(query[i:end])
		i = end
	}
	return b.String()
}