/
sql_connection_registrar.go
176 lines (136 loc) · 6 KB
/
sql_connection_registrar.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
package connection_repository
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"time"
"github.com/RedHatInsights/cloud-connector/internal/config"
"github.com/RedHatInsights/cloud-connector/internal/domain"
"github.com/RedHatInsights/cloud-connector/internal/platform/logger"
"github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus"
)
// SqlConnectionRegistrar stores connector client state in the SQL
// "connections" table. Its methods derive a per-call deadline from
// queryTimeout via context.WithTimeout.
type SqlConnectionRegistrar struct {
	database     *sql.DB       // connection pool used for every query
	queryTimeout time.Duration // per-call timeout layered onto the caller's context
}

// NewSqlConnectionRegistrar returns a registrar that talks to database and
// takes its query timeout from cfg.ConnectionDatabaseQueryTimeout.
// The returned error is currently always nil.
func NewSqlConnectionRegistrar(cfg *config.Config, database *sql.DB) (*SqlConnectionRegistrar, error) {
	timeout := cfg.ConnectionDatabaseQueryTimeout
	registrar := &SqlConnectionRegistrar{queryTimeout: timeout}
	registrar.database = database
	return registrar, nil
}
// Register upserts the connections row for rhcClient.
//
// A single CTE statement first runs an UPDATE of dispatchers, tags, updated_at,
// message_id and message_sent matched on (account, client_id). If that UPDATE
// returns no rows, the INSERT branch (guarded by NOT EXISTS) creates a new row
// that also carries org_id and canonical_facts. Both statement preparation and
// execution are bounded by scm.queryTimeout.
//
// JSON marshalling failures are returned as-is; database failures are returned
// wrapped in FatalError.
func (scm *SqlConnectionRegistrar) Register(ctx context.Context, rhcClient domain.ConnectorClientState) error {
	callDurationTimer := prometheus.NewTimer(metrics.sqlConnectionRegistrationDuration)
	defer callDurationTimer.ObserveDuration()

	account := rhcClient.Account
	orgID := rhcClient.OrgID
	clientID := rhcClient.ClientID

	ctx, cancel := context.WithTimeout(ctx, scm.queryTimeout)
	defer cancel()

	log := logger.Log.WithFields(logrus.Fields{"account": account, "org_id": orgID, "client_id": clientID})

	// Serialize the JSON columns before touching the database so a marshalling
	// failure does not cost a prepare round trip.
	dispatchersString, err := json.Marshal(rhcClient.Dispatchers)
	if err != nil {
		log.WithFields(logrus.Fields{"error": err, "dispatchers": rhcClient.Dispatchers}).Error("Unable to marshal dispatchers")
		return err
	}
	canonicalFactsString, err := json.Marshal(rhcClient.CanonicalFacts)
	if err != nil {
		log.WithFields(logrus.Fields{"error": err, "canonical_facts": rhcClient.CanonicalFacts}).Error("Unable to marshal canonicalfacts")
		return err
	}
	tagsString, err := json.Marshal(rhcClient.Tags)
	if err != nil {
		// Log the value that actually failed to marshal (previously CanonicalFacts was logged here).
		log.WithFields(logrus.Fields{"error": err, "tags": rhcClient.Tags}).Error("Unable to marshal tags")
		return err
	}

	// NOTE(review): the UPDATE branch matches on account (not org_id) and does not
	// refresh org_id or canonical_facts — confirm that is intended.
	// NOTE(review): two concurrent Register calls for a new client can both take the
	// INSERT branch; presumably a unique constraint guards this — verify.
	update := "UPDATE connections SET dispatchers=$1, tags = $2, updated_at = NOW(), message_id = $3, message_sent = $4 WHERE account=$5 AND client_id=$6"
	insert := "INSERT INTO connections (account, org_id, client_id, dispatchers, canonical_facts, tags, message_id, message_sent) SELECT $7, $8, $9, $10, $11, $12, $13, $14"
	insertOrUpdate := fmt.Sprintf("WITH upsert AS (%s RETURNING *) %s WHERE NOT EXISTS (SELECT * FROM upsert)", update, insert)

	// PrepareContext (unlike Prepare) makes the preparation honor the query timeout.
	statement, err := scm.database.PrepareContext(ctx, insertOrUpdate)
	if err != nil {
		log.WithFields(logrus.Fields{"error": err}).Error("Prepare failed")
		return FatalError{err}
	}
	defer statement.Close()

	latestMessageID := rhcClient.MessageMetadata.LatestMessageID
	latestTimestamp := rhcClient.MessageMetadata.LatestTimestamp
	_, err = statement.ExecContext(ctx,
		// UPDATE parameters $1-$6
		dispatchersString, tagsString, latestMessageID, latestTimestamp, account, clientID,
		// INSERT parameters $7-$14
		account, orgID, clientID, dispatchersString, canonicalFactsString, tagsString, latestMessageID, latestTimestamp)
	if err != nil {
		log.WithFields(logrus.Fields{"error": err}).Error("Insert/update failed")
		return FatalError{err}
	}

	log.Debug("Registered a connection")
	return nil
}
// Unregister deletes the connections rows whose client_id equals clientID.
// Deleting a client that has no row is not treated as an error (the affected
// row count is not inspected). Preparation and execution are both bounded by
// scm.queryTimeout; database failures are returned wrapped in FatalError.
func (scm *SqlConnectionRegistrar) Unregister(ctx context.Context, clientID domain.ClientID) error {
	callDurationTimer := prometheus.NewTimer(metrics.sqlConnectionUnregistrationDuration)
	defer callDurationTimer.ObserveDuration()

	ctx, cancel := context.WithTimeout(ctx, scm.queryTimeout)
	defer cancel()

	log := logger.Log.WithFields(logrus.Fields{"client_id": clientID})

	// PrepareContext (unlike Prepare) makes the preparation honor the query timeout.
	statement, err := scm.database.PrepareContext(ctx, "DELETE FROM connections WHERE client_id = $1")
	if err != nil {
		log.WithFields(logrus.Fields{"error": err}).Error("Prepare failed")
		return FatalError{err}
	}
	defer statement.Close()

	if _, err = statement.ExecContext(ctx, clientID); err != nil {
		log.WithFields(logrus.Fields{"error": err}).Error("Delete failed")
		return FatalError{err}
	}

	log.Debug("Unregistered a connection")
	return nil
}
// FindConnectionByClientID loads the connections row whose client_id equals
// clientID.
//
// It returns NotFoundError when no row matches and a FatalError for any other
// database failure; in both error cases the returned state is the zero value.
// A NULL account or message_id column leaves the corresponding field empty.
// The JSON columns are decoded by the deserializeCanonicalFacts,
// deserializeDispatchers and deserializeTags helpers.
// Preparation and the query are both bounded by scm.queryTimeout.
func (scm *SqlConnectionRegistrar) FindConnectionByClientID(ctx context.Context, clientID domain.ClientID) (domain.ConnectorClientState, error) {
	callDurationTimer := prometheus.NewTimer(metrics.sqlConnectionLookupByClientIDDuration)
	defer callDurationTimer.ObserveDuration()

	ctx, cancel := context.WithTimeout(ctx, scm.queryTimeout)
	defer cancel()

	log := logger.Log.WithFields(logrus.Fields{"client_id": clientID})

	// PrepareContext (unlike Prepare) makes the preparation honor the query timeout.
	statement, err := scm.database.PrepareContext(ctx, "SELECT account, org_id, client_id, dispatchers, canonical_facts, tags, message_id, message_sent FROM connections WHERE client_id = $1")
	if err != nil {
		log.WithFields(logrus.Fields{"error": err}).Error("SQL prepare failed")
		return domain.ConnectorClientState{}, FatalError{err}
	}
	defer statement.Close()

	var (
		connectorClient          domain.ConnectorClientState
		account                  sql.NullString
		orgID                    domain.OrgID // NOTE(review): not nullable; a NULL org_id would make Scan fail — confirm the column is NOT NULL
		serializedCanonicalFacts sql.NullString
		serializedDispatchers    sql.NullString
		serializedTags           sql.NullString
		latestMessageID          sql.NullString
	)

	err = statement.QueryRowContext(ctx, clientID).Scan(&account,
		&orgID,
		&connectorClient.ClientID,
		&serializedDispatchers,
		&serializedCanonicalFacts,
		&serializedTags,
		&latestMessageID,
		&connectorClient.MessageMetadata.LatestTimestamp)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		log.Debug("No connection found!")
		return domain.ConnectorClientState{}, NotFoundError
	case err != nil:
		log.WithFields(logrus.Fields{"error": err}).Error("SQL query failed")
		// Scan may have filled some fields before failing; don't leak a partial result.
		return domain.ConnectorClientState{}, FatalError{err}
	}

	connectorClient.OrgID = orgID
	if account.Valid {
		connectorClient.Account = domain.AccountID(account.String)
	}

	log = log.WithFields(logrus.Fields{"account": connectorClient.Account, "org_id": connectorClient.OrgID})

	connectorClient.CanonicalFacts = deserializeCanonicalFacts(log, serializedCanonicalFacts)
	connectorClient.Dispatchers = deserializeDispatchers(log, serializedDispatchers)
	connectorClient.Tags = deserializeTags(log, serializedTags)

	if latestMessageID.Valid {
		connectorClient.MessageMetadata.LatestMessageID = latestMessageID.String
	}

	return connectorClient, nil
}