/
db.go
93 lines (81 loc) · 2.06 KB
/
db.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
package db
import (
"errors"
"fmt"
_ "github.com/go-sql-driver/mysql"
"github.com/jmoiron/sqlx"
_ "github.com/lib/pq"
"os"
"strings"
)
// Connector turns a ConnectionConfig into a driver-specific DSN and opens an
// sqlx database handle with it.
type Connector struct {
// config holds the credentials, address and database name used to build the DSN.
config *ConnectionConfig
// driver is the database/sql driver name; NewConnector fills it from DB_DRIVER.
// getConnectionString only understands "mysql" and "postgres".
driver string
// connection is the handle released by Close; Close returns an error while it is nil.
connection *sqlx.DB
}
// ConnectionConfig carries the raw string pieces needed to assemble a
// database DSN. NewConfig populates it from DB_* environment variables.
type ConnectionConfig struct {
	// Credentials used to authenticate against the server.
	user, password string
	// NOTE(review): presumably meant to hold DB_CONNECTION ("host:port"), but
	// nothing in this file assigns or reads it — confirm whether it is needed.
	connection string
	// Network address of the database server.
	host, port string
	// name is the database to select after connecting.
	name string
}
// NewConnector returns a Connector configured from the environment: the
// connection settings come from NewConfig and the driver name from the
// DB_DRIVER variable. It does not open a connection.
func NewConnector() *Connector {
	cfg := NewConfig()
	conn := new(Connector)
	conn.config = cfg
	conn.driver = os.Getenv("DB_DRIVER")
	return conn
}
// NewConfig assembles a ConnectionConfig from environment variables.
//
// DB_USER, DB_PASSWORD and DB_NAME are copied verbatim. Host and port come
// from DB_HOST and DB_PORT. Only when both of those are empty does it fall
// back to DB_CONNECTION, which must then contain exactly one ':' separating
// host from port (for example "localhost:5432"). Any other shape of
// DB_CONNECTION leaves host and port empty.
func NewConfig() *ConnectionConfig {
	cfg := &ConnectionConfig{
		user:     os.Getenv("DB_USER"),
		password: os.Getenv("DB_PASSWORD"),
		host:     os.Getenv("DB_HOST"),
		port:     os.Getenv("DB_PORT"),
		name:     os.Getenv("DB_NAME"),
	}
	// An explicitly configured host or port wins; no fallback needed.
	if cfg.host != "" || cfg.port != "" {
		return cfg
	}
	if parts := strings.Split(os.Getenv("DB_CONNECTION"), ":"); len(parts) == 2 {
		cfg.host = parts[0]
		cfg.port = parts[1]
	}
	return cfg
}
// Connect builds the DSN for c.driver from c.config and opens the database
// with sqlx.Connect. On success the returned handle is also stored in
// c.connection so that Close can release it later.
//
// Errors:
//   - "driver was not provided" when c.driver is empty;
//   - "unsupported driver ..." when c.driver is set but getConnectionString
//     does not recognise it;
//   - the sqlx.Connect error, wrapped with the driver name (use errors.Is /
//     errors.As to inspect the underlying cause).
//
// The driver stored on the Connector is used rather than re-reading
// DB_DRIVER, so the driver name always matches the DSN format that
// getConnectionString produced. Calling Connect again replaces c.connection
// without closing the previous handle; the caller still owns any *sqlx.DB it
// received earlier.
func (c *Connector) Connect() (*sqlx.DB, error) {
	connectionString, ok := c.getConnectionString()
	if !ok {
		if c.driver == "" {
			return nil, errors.New("driver was not provided")
		}
		return nil, fmt.Errorf("unsupported driver %q", c.driver)
	}
	db, err := sqlx.Connect(c.driver, connectionString)
	if err != nil {
		return nil, fmt.Errorf("connecting with driver %q: %w", c.driver, err)
	}
	// Keep the handle so Close has something to release.
	c.connection = db
	return db, nil
}
// Close releases the handle held in c.connection. When no handle is held it
// returns an error instead. Otherwise it returns whatever the handle's own
// Close reports.
func (c *Connector) Close() error {
	if c.connection == nil {
		return errors.New("connection not exist")
	}
	return c.connection.Close()
}
// ChangeDB replaces the database name stored in the Connector's config. It
// only updates that field. An already-open handle is not touched, and the new
// name is used the next time a connection string is built.
func (c *Connector) ChangeDB(dbName string) {
	cfg := c.config
	cfg.name = dbName
}
// getConnectionString returns the DSN for c.driver. The boolean result is
// false, and the DSN empty, when the driver is neither "mysql" nor
// "postgres" (including when it is empty).
func (c *Connector) getConnectionString() (dsn string, supported bool) {
	if c.driver == "mysql" {
		return c.getMysqlConnectionString(), true
	}
	if c.driver == "postgres" {
		return c.getPostgresConnectionString(), true
	}
	return "", false
}
// getMysqlConnectionString renders c.config as a go-sql-driver/mysql DSN of
// the form
//
//	user:password@tcp(host:port)/name?parseTime=true
//
// A host containing ':' (an IPv6 literal such as "::1") is wrapped in
// brackets, giving "[::1]:3306", the same way net.JoinHostPort does. Without
// the brackets the host/port separator is ambiguous. A host that is already
// bracketed is left alone, and every other host renders exactly as
// "host:port".
func (c *Connector) getMysqlConnectionString() string {
	cfg := c.config
	host := cfg.host
	alreadyBracketed := strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]")
	if strings.Contains(host, ":") && !alreadyBracketed {
		host = "[" + host + "]"
	}
	return fmt.Sprintf("%v:%v@tcp(%v:%v)/%v?parseTime=true",
		cfg.user, cfg.password, host, cfg.port, cfg.name)
}
// postgresValueEscaper backslash-escapes the two characters that are special
// inside a single-quoted libpq connection-string value.
var postgresValueEscaper = strings.NewReplacer(`\`, `\\`, `'`, `\'`)

// quotePostgresValue prepares v for use as a value in a libpq-style
// "key=value" connection string. A value that is empty, or that contains
// whitespace, a single quote or a backslash, is wrapped in single quotes with
// ' and \ escaped. Every other value is returned unchanged, so ordinary
// values render exactly as they would unquoted.
//
// Quoting matters because the parser skips whitespace after '='. An unquoted
// empty value ("password= dbname=x") would therefore take the next pair
// ("dbname=x") as its value. Unquoted values also end at the first space.
// NOTE(review): only ASCII and Latin-1 whitespace is detected here; other
// Unicode space characters in a value would still need quoting.
func quotePostgresValue(v string) string {
	if v != "" && !strings.ContainsAny(v, " \t\n\v\f\r\u0085\u00a0'\\") {
		return v
	}
	return "'" + postgresValueEscaper.Replace(v) + "'"
}

// getPostgresConnectionString renders c.config as a lib/pq key/value
// connection string with sslmode=disable. Each value goes through
// quotePostgresValue, so empty values and passwords containing spaces,
// quotes or backslashes are transmitted intact.
func (c *Connector) getPostgresConnectionString() string {
	cfg := c.config
	return fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=disable",
		quotePostgresValue(cfg.host),
		quotePostgresValue(cfg.port),
		quotePostgresValue(cfg.user),
		quotePostgresValue(cfg.password),
		quotePostgresValue(cfg.name),
	)
}