forked from dexidp/dex
/
postgres.go
154 lines (124 loc) · 3.32 KB
/
postgres.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
package ent
import (
"context"
"crypto/sha256"
"database/sql"
"fmt"
"net"
"regexp"
"strconv"
"strings"
"time"
entSQL "entgo.io/ent/dialect/sql"
_ "github.com/lib/pq" // Register postgres driver.
"github.com/dexidp/dex/pkg/log"
"github.com/dexidp/dex/storage"
"github.com/dexidp/dex/storage/ent/client"
"github.com/dexidp/dex/storage/ent/db"
)
// PostgreSQL sslmode values that may be written into the connection string
// built by dsn. Only pgSSLVerifyFull is referenced in this file (it is the
// fallback when SSL.Mode is empty), which is presumably why the block carries
// a nolint directive for the otherwise-unused constants.
//
//nolint
const (
	// pgSSLDisable is the libpq sslmode value "disable".
	pgSSLDisable = "disable"

	// pgSSLRequire is the libpq sslmode value "require".
	pgSSLRequire = "require"

	// pgSSLVerifyCA is the libpq sslmode value "verify-ca".
	pgSSLVerifyCA = "verify-ca"

	// pgSSLVerifyFull is the libpq sslmode value "verify-full"; dsn uses it
	// as the default when no mode is configured.
	pgSSLVerifyFull = "verify-full"
)
// Postgres options for creating an SQL db.
//
// The embedded NetworkDB supplies the connection settings that dsn and
// driver read through field promotion: Host, Port, Database, User, Password,
// ConnectionTimeout, and the pool tunables ConnMaxLifetime, MaxIdleConns and
// MaxOpenConns.
type Postgres struct {
	NetworkDB

	// SSL carries the sslmode and certificate file paths that are copied
	// into the connection string; an empty Mode is treated as "verify-full".
	SSL SSL `json:"ssl"`
}
// Open connects to PostgreSQL and returns a new ent-backed storage.
//
// Every call opens a fresh connection pool via p.driver() and then runs ent's
// schema creation (Schema().Create) against it. If schema creation fails, the
// pool is closed before the error is returned so that a failed Open does not
// leak database connections.
func (p *Postgres) Open(logger log.Logger) (storage.Storage, error) {
	logger.Debug("experimental ent-based storage driver is enabled")

	drv, err := p.driver()
	if err != nil {
		return nil, err
	}

	databaseClient := client.NewDatabase(
		client.WithClient(db.NewClient(db.Driver(drv))),
		client.WithHasher(sha256.New),
		// The default behavior for Postgres transactions is consistent reads, not consistent writes.
		// For each transaction opened, ensure it has the correct isolation level.
		//
		// See: https://www.postgresql.org/docs/9.3/static/sql-set-transaction.html
		client.WithTxIsolationLevel(sql.LevelSerializable),
	)

	if err := databaseClient.Schema().Create(context.TODO()); err != nil {
		// Release the pool opened above; nobody else holds a reference to it.
		// The close error is intentionally dropped so the caller sees the
		// schema error that actually caused the failure.
		_ = drv.Close()
		return nil, err
	}

	return databaseClient, nil
}
// driver opens a "postgres" connection pool through ent's SQL dialect using
// the connection string from dsn, then applies the database/sql pool tunables.
//
// ConnMaxLifetime is interpreted as seconds and is only applied when non-zero.
// MaxIdleConns and MaxOpenConns fall back to 5 when left at zero.
func (p *Postgres) driver() (*entSQL.Driver, error) {
	drv, err := entSQL.Open("postgres", p.dsn())
	if err != nil {
		return nil, err
	}

	pool := drv.DB()

	if lifetime := p.ConnMaxLifetime; lifetime != 0 {
		pool.SetConnMaxLifetime(time.Duration(lifetime) * time.Second)
	}

	maxIdle := p.MaxIdleConns
	if maxIdle == 0 {
		maxIdle = 5
	}
	pool.SetMaxIdleConns(maxIdle)

	maxOpen := p.MaxOpenConns
	if maxOpen == 0 {
		maxOpen = 5
	}
	pool.SetMaxOpenConns(maxOpen)

	return drv, nil
}
// dsn renders the configuration as a space-separated key=value connection
// string for the lib/pq driver.
//
// For backwards compatibility Host may itself be "host:port"; in that case the
// embedded port wins and p.Port is ignored. connect_timeout and sslmode are
// always emitted (sslmode defaults to verify-full), the port is written bare,
// and every other non-empty value is single-quoted via dataSourceStr.
func (p *Postgres) dsn() string {
	host, port, splitErr := net.SplitHostPort(p.Host)
	if splitErr != nil {
		// Not host:port: a unix socket path or a bare address.
		host = p.Host
		if p.Port != 0 {
			port = strconv.Itoa(int(p.Port))
		}
	}

	sslMode := p.SSL.Mode
	if sslMode == "" {
		// Assume the strictest mode if unspecified.
		sslMode = pgSSLVerifyFull
	}

	var out strings.Builder
	write := func(key, val string) {
		if out.Len() > 0 {
			out.WriteByte(' ')
		}
		fmt.Fprintf(&out, "%s=%s", key, val)
	}
	writeQuoted := func(key, val string) {
		if val != "" {
			write(key, dataSourceStr(val))
		}
	}

	write("connect_timeout", strconv.Itoa(p.ConnectionTimeout))
	writeQuoted("host", host)
	if port != "" {
		write("port", port)
	}
	writeQuoted("user", p.User)
	writeQuoted("password", p.Password)
	writeQuoted("dbname", p.Database)
	write("sslmode", dataSourceStr(sslMode))
	writeQuoted("sslrootcert", p.SSL.CAFile)
	writeQuoted("sslcert", p.SSL.CertFile)
	writeQuoted("sslkey", p.SSL.KeyFile)

	return out.String()
}
// strEsc matches the two characters that must be backslash-escaped inside a
// single-quoted connection-string value: a backslash or a single quote.
var strEsc = regexp.MustCompile(`([\\'])`)

// dataSourceStr wraps str in single quotes, prefixing every backslash and
// single quote with a backslash, so the value can contain spaces, '=' or
// quotes without breaking the key=value connection string.
func dataSourceStr(str string) string {
	escaped := strEsc.ReplaceAllStringFunc(str, func(match string) string {
		return `\` + match
	})
	return "'" + escaped + "'"
}