/
connection_postgres.go
executable file
·131 lines (106 loc) · 2.72 KB
/
connection_postgres.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
package database
import (
"flag"
"fmt"
"log"
"os"
"github.com/agubarev/hometown/pkg/util"
"github.com/jackc/pgx"
"github.com/jackc/pgx/log/zapadapter"
"github.com/pkg/errors"
"go.uber.org/zap"
)
// postgresConn is the package-wide PostgreSQL connection returned by
// PostgreSQLConnection. It is opened lazily on first use and is replaced on
// every PostgreSQLForTesting call.
//
// NOTE(review): access is not synchronized; if PostgreSQLConnection can be
// called concurrently before the connection exists, several connections may be
// opened — confirm callers initialize it from a single goroutine.
var postgresConn *pgx.Conn

// PostgreSQLConnection returns the package-wide PostgreSQL connection, opening
// it on the first call (see openPostgreSQLConn for how the DSN is chosen and
// when the process is terminated). Subsequent calls return the same
// connection, so logger only takes effect on the call that opens it.
func PostgreSQLConnection(logger *zap.Logger) *pgx.Conn {
	if postgresConn == nil {
		postgresConn = openPostgreSQLConn(logger)
	}

	return postgresConn
}

// PostgreSQLForTesting opens a fresh PostgreSQL connection, stores it as the
// package-wide connection returned by PostgreSQLConnection, truncates the
// application tables listed below, and returns the connection.
//
// It terminates the process via log.Fatal/log.Fatalf when called outside of
// test mode, when the connection cannot be opened, or when truncation fails,
// so tests never run against a partially cleaned database.
//
// NOTE(review): any previously stored connection is overwritten without being
// closed — confirm whether callers still hold it before closing it here.
func PostgreSQLForTesting(logger *zap.Logger) (conn *pgx.Conn) {
	if !util.IsTestMode() {
		log.Fatal("PostgreSQLForTesting() can only be called during testing")
	}

	conn = openPostgreSQLConn(logger)
	postgresConn = conn

	tables := []string{
		"group",
		"group_assets",
		"accesspolicy",
		"accesspolicy_roster",
		"password",
		"token",
		"user",
		"user_email",
		"user_phone",
		"user_profile",
		"auth_session",
		"auth_refresh_token",
		"auth_code_exchange",
	}

	if err := truncatePostgreSQLTables(conn, tables); err != nil {
		log.Fatalf("failed to truncate test database: %s", err)
	}

	return conn
}

// openPostgreSQLConn opens a new PostgreSQL connection and terminates the
// process via log.Fatalf if the DSN cannot be parsed or the connection fails.
//
// The DSN is read from HOMETOWN_DATABASE, or from HOMETOWN_TEST_DATABASE when
// the "test.v" flag is registered (i.e. under `go test`), so tests never use
// the regular database by accident. A non-nil logger is attached to the
// connection at debug level.
func openPostgreSQLConn(logger *zap.Logger) *pgx.Conn {
	dsn := os.Getenv("HOMETOWN_DATABASE")
	if flag.Lookup("test.v") != nil {
		// better safe than sorry: never let a test hit the regular database
		dsn = os.Getenv("HOMETOWN_TEST_DATABASE")
	}

	conf, err := pgx.ParseDSN(dsn)
	if err != nil {
		log.Fatalf("failed to parse DSN: %s", err)
	}

	if logger != nil {
		conf.Logger = zapadapter.NewLogger(logger)
		conf.LogLevel = pgx.LogLevelDebug
	}

	conn, err := pgx.Connect(conf)
	if err != nil {
		log.Fatalf("failed to connect to data: %s", err)
	}

	return conn
}

// truncatePostgreSQLTables runs `TRUNCATE TABLE "<name>" RESTART IDENTITY
// CASCADE` for each table, in order, inside a single transaction.
//
// Names are double-quoted because some of them ("user", "group") are reserved
// words. They are interpolated into the SQL, so they must be trusted
// identifiers, never user input. On any failure the transaction is rolled back
// and a wrapped error is returned; a rollback error is reported alongside the
// original cause rather than replacing it.
func truncatePostgreSQLTables(conn *pgx.Conn, tables []string) error {
	tx, err := conn.Begin()
	if err != nil {
		return errors.Wrap(err, "failed to begin transaction")
	}

	for _, tableName := range tables {
		query := fmt.Sprintf(`TRUNCATE TABLE "%s" RESTART IDENTITY CASCADE`, tableName)
		if _, err := tx.Exec(query); err != nil {
			if rbErr := tx.Rollback(); rbErr != nil {
				return errors.Wrapf(err, "failed to truncate table %q (rollback failed: %s)", tableName, rbErr)
			}
			return errors.Wrapf(err, "failed to truncate table %q", tableName)
		}
	}

	if err := tx.Commit(); err != nil {
		return errors.Wrap(err, "failed to commit transaction")
	}

	return nil
}