/
connection.go
128 lines (114 loc) · 3.51 KB
/
connection.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
package dyndb
import (
"fmt"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/dynamodb"
"github.com/viant/dsc"
"github.com/viant/toolbox/cred"
"github.com/viant/toolbox/secret"
"net/url"
"strings"
)
// Names of configuration parameters looked up on dsc.Config.
// keyKey, secretKey and regionKey are read by getCredConfig to override the
// corresponding credential fields; dbnameKey is not referenced in this file.
const (
keyKey = "key"
secretKey = "secret"
regionKey = "region"
dbnameKey = "dbname"
)
// DbPointer is a typed nil *dynamodb.DynamoDB. It is passed to
// Connection.Unwrap as the target-type token that selects the DynamoDB client.
var DbPointer = (*dynamodb.DynamoDB)(nil)
// asDatabase extracts the DynamoDB client from a dsc connection.
//
// The connection is unwrapped with DbPointer as the target-type token. An
// error is returned (instead of panicking on a failed type assertion) when the
// connection is nil, when unwrapping yields something other than a
// *dynamodb.DynamoDB, or when the unwrapped client is nil.
func asDatabase(connection dsc.Connection) (*dynamodb.DynamoDB, error) {
	if connection == nil {
		return nil, fmt.Errorf("connection was nil")
	}
	// Comma-ok form: a connection from another driver must surface as an
	// error through the existing error return rather than as a panic.
	db, ok := connection.Unwrap(DbPointer).(*dynamodb.DynamoDB)
	if !ok {
		return nil, fmt.Errorf("unsupported connection type %T: expected a DynamoDB connection", connection)
	}
	if db == nil {
		return nil, fmt.Errorf("connection %T holds no DynamoDB client", connection)
	}
	return db, nil
}
// connection is the value returned as a dsc.Connection by
// connectionProvider.NewConnection. It embeds *dsc.AbstractConnection and
// carries the DynamoDB client that Unwrap hands out.
type connection struct {
*dsc.AbstractConnection
// db is the DynamoDB client created in NewConnection.
db *dynamodb.DynamoDB
}
// CloseNow performs no work and always returns nil.
// NOTE(review): presumably the DynamoDB client holds nothing that needs an
// explicit close — confirm against the SDK before relying on this.
func (c *connection) CloseNow() error {
return nil
}
// Unwrap returns the native client that backs this connection.
//
// The only supported targetType is DbPointer, for which the connection's
// *dynamodb.DynamoDB client is returned. Any other targetType panics, because
// the method signature offers no error return.
func (c *connection) Unwrap(targetType interface{}) interface{} {
	if targetType == DbPointer {
		return c.db
	}
	// %T instead of %v: target-type tokens are typically typed nil pointers,
	// which %v renders as "<nil>" and so hides which type was requested.
	panic(fmt.Sprintf("unsupported targetType type %T", targetType))
}
// connectionProvider creates DynamoDB-backed connections. It embeds
// *dsc.AbstractConnectionProvider and is returned as a dsc.ConnectionProvider
// by newConnectionProvider.
type connectionProvider struct {
*dsc.AbstractConnectionProvider
}
// NewConnection builds a new DynamoDB-backed dsc.Connection.
//
// Credentials and region are resolved from the provider's config via
// getCredConfig and getAWSConfig, descriptor options (such as a custom
// endpoint) are applied, and a DynamoDB client is created on a fresh AWS
// session.
//
// An error is returned when credentials cannot be resolved, when no region is
// configured, or when the AWS session cannot be created.
func (p *connectionProvider) NewConnection() (dsc.Connection, error) {
	config := p.ConnectionProvider.Config()
	credConfig, err := getCredConfig(config)
	if err != nil {
		return nil, err
	}
	awsConfig := getAWSConfig(credConfig)
	// Check the value, not just the pointer: a non-nil pointer to an empty
	// string is as unusable as nil and would otherwise only fail later, on
	// the first request.
	if aws.StringValue(awsConfig.Region) == "" {
		return nil, fmt.Errorf("region was empty")
	}
	awsConfig = p.applyOptions(awsConfig)
	// Return session errors to the caller instead of panicking via
	// session.Must; this is library code.
	sess, err := session.NewSession()
	if err != nil {
		return nil, fmt.Errorf("creating AWS session: %v", err)
	}
	conn := &connection{db: dynamodb.New(sess, awsConfig)}
	conn.AbstractConnection = dsc.NewAbstractConnection(config, p.ConnectionProvider.ConnectionPool(), conn)
	return conn, nil
}
// applyOptions applies options encoded in the config descriptor to awsConfig
// and returns the resulting config.
//
// The descriptor is parsed as a URL query string. Only the "endpoint" option
// is recognised. It is normalised before being set on awsConfig:
//   - an endpoint containing no ':' at all (a bare host) gets ":8000" appended;
//   - an endpoint without a "scheme://" prefix gets "http://" prepended.
//
// When no endpoint is supplied awsConfig is returned unchanged.
func (p *connectionProvider) applyOptions(awsConfig *aws.Config) *aws.Config {
	// Best effort: ParseQuery still returns every pair it could decode, so a
	// malformed pair elsewhere in the descriptor must not discard a valid
	// endpoint. The parse error is therefore deliberately ignored.
	params, _ := url.ParseQuery(p.Config().Descriptor)
	endpoint := params.Get("endpoint")
	if endpoint == "" {
		return awsConfig
	}
	if !strings.Contains(endpoint, ":") {
		endpoint += ":8000"
	}
	// Look for an explicit scheme separator rather than the substring "http",
	// which also matches host names such as "httpbin" and left them without
	// a scheme.
	if !strings.Contains(endpoint, "://") {
		endpoint = "http://" + endpoint
	}
	return awsConfig.WithEndpoint(endpoint)
}
// newConnectionProvider creates a connectionProvider for the supplied config
// and returns it as a dsc.ConnectionProvider.
//
// A MaxPoolSize of zero is replaced with 1 on the passed-in config itself. The
// connection pool is a buffered channel of that capacity.
func newConnectionProvider(config *dsc.Config) dsc.ConnectionProvider {
	if config.MaxPoolSize == 0 {
		config.MaxPoolSize = 1
	}
	provider := &connectionProvider{}
	// The abstract provider receives the concrete provider through its
	// interface value.
	var self dsc.ConnectionProvider = provider
	pool := make(chan dsc.Connection, config.MaxPoolSize)
	provider.AbstractConnectionProvider = dsc.NewAbstractConnectionProvider(config, pool, self)
	provider.AbstractConnectionProvider.ConnectionProvider = self
	return provider
}
// getCredConfig resolves the credential settings for the supplied config.
//
// When config.Credentials is set, the credentials are loaded through the
// secret service; a load failure is returned as is. Otherwise an empty
// cred.Config is used as the base. The "key", "secret" and "region" config
// parameters, when present, then override the corresponding fields.
func getCredConfig(config *dsc.Config) (*cred.Config, error) {
	result := &cred.Config{}
	if config.Credentials != "" {
		loaded, err := secret.New("", false).GetCredentials(config.Credentials)
		if err != nil {
			return nil, err
		}
		result = loaded
	}

	// Explicit config parameters take precedence over loaded credentials.
	overrides := []struct {
		name  string
		apply func(value string)
	}{
		{keyKey, func(value string) { result.Key = value }},
		{secretKey, func(value string) { result.Secret = value }},
		{regionKey, func(value string) { result.Region = value }},
	}
	for _, override := range overrides {
		if config.Has(override.name) {
			override.apply(config.Get(override.name))
		}
	}
	return result, nil
}
// getAWSConfig returns an *aws.Config for the provided credential config.
//
// The region is set only when credConfig.Region is non-empty, so callers can
// detect a missing region with a nil check on the result's Region field.
// Static credentials are attached only when credConfig.Key is non-empty;
// otherwise the credentials field is left unset.
func getAWSConfig(credConfig *cred.Config) *aws.Config {
	result := aws.NewConfig()
	if credConfig.Region != "" {
		// WithRegion takes the region by value, so the resulting config does
		// not alias credConfig's field.
		result.WithRegion(credConfig.Region)
	}
	if credConfig.Key != "" {
		result.WithCredentials(credentials.NewStaticCredentials(credConfig.Key, credConfig.Secret, ""))
	}
	return result
}