Skip to content

Commit

Permalink
Merge pull request #4 from Hillside-Labs/ionrock/support-sslmode
Browse files Browse the repository at this point in the history
Support sslmode in the dsn
  • Loading branch information
ionrock committed Nov 29, 2023
2 parents 3504aba + 5a9e508 commit ad8787e
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 2 deletions.
11 changes: 9 additions & 2 deletions models/db.go
Expand Up @@ -81,13 +81,14 @@ func Connect(fn string) (*gorm.DB, error) {
return nil, err
}
err = db.AutoMigrate(&Record{}, &LocalConfig{})

return db, err
}

func (dsn *DSN) String() string {
return fmt.Sprintf(
"host=%s user=%s password=%s dbname=%s port=%d sslmode=disable",
dsn.Host, dsn.User, dsn.Password, dsn.DBName, dsn.Port,
"host=%s user=%s password=%s dbname=%s port=%d sslmode=%s",
dsn.Host, dsn.User, dsn.Password, dsn.DBName, dsn.Port, dsn.SSLMode,
)
}

Expand All @@ -114,11 +115,17 @@ func NewDSN(dburi string) (*DSN, error) {

dbname := path[0]

sslmode := url.Query().Get("sslmode")
if sslmode == "" {
sslmode = "disable"
}

return &DSN{
Host: url.Hostname(),
User: url.User.Username(),
Password: pw,
DBName: dbname,
Port: port,
SSLMode: sslmode,
}, nil
}
51 changes: 51 additions & 0 deletions models/db_test.go
@@ -0,0 +1,51 @@
package models

import (
"fmt"
"strings"
"testing"
)

func TestDSNWithSSLEnabled(t *testing.T) {
dburi := "postgresql://user:pw@pgpool.xample.com/chetapp?sslmode=require"
dsn, err := NewDSN(dburi)
if err != nil {
t.Fatal(err)
}

fields := strings.Fields(dsn.String())

for _, field := range fields {
f := strings.SplitN(field, "=", 2)
k := f[0]
v := f[1]

switch k {
case "host":
if v != dsn.Host {
t.Errorf("dsn host: expected %s got %s", dsn.Host, v)
}
case "user":
if v != dsn.User {
t.Errorf("dsn user: expected %s got %s", dsn.User, v)
}
case "password":
if v != dsn.Password {
t.Errorf("dsn password: expected %s got %s", dsn.Password, v)
}
case "dbname":
if v != dsn.DBName {
t.Errorf("dsn dbname: expected %s got %s", dsn.DBName, v)
}
case "port":
if v != fmt.Sprintf("%d", dsn.Port) {
t.Errorf("dsn port: expected %d got %s", dsn.Port, v)
}
case "sslmode":
if v != dsn.SSLMode {
t.Errorf("dsn sslmode: expected %s got %s", dsn.SSLMode, v)
}
}

}
}

0 comments on commit ad8787e

Please sign in to comment.