diff --git a/main.go b/main.go
index 6851902..397af80 100644
--- a/main.go
+++ b/main.go
@@ -9,6 +9,7 @@ import (
"os"
"os/signal"
"syscall"
+ "time"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
@@ -18,7 +19,8 @@ import (
)
const (
- indexHTML = `
+ shutdownTimeout = 10 * time.Second
+ indexHTML = `
Postgresql Exporter
@@ -56,7 +58,9 @@ func main() {
if err := cfg.Load(); err != nil {
log.Fatalf("could not load config: %v", err)
}
- collector := pgcollector.New()
+ ctx, cancel := context.WithCancel(context.Background())
+
+ collector := pgcollector.New(ctx)
collector.LoadConfig(cfg)
if err := prometheus.Register(collector); err != nil {
@@ -98,7 +102,10 @@ loop:
log.Printf("received signal: %v", sig)
}
}
- if err := srv.Shutdown(context.Background()); err != nil {
+ cancel()
+
+ shutdownCtx, _ := context.WithTimeout(context.Background(), shutdownTimeout)
+ if err := srv.Shutdown(shutdownCtx); err != nil {
log.Printf("could not shutdown http server: %v", err)
}
diff --git a/pkg/db/db.go b/pkg/db/db.go
index edd22a8..5ed7969 100644
--- a/pkg/db/db.go
+++ b/pkg/db/db.go
@@ -15,8 +15,8 @@ import (
"github.com/adjust/postgresql_exporter/pkg/config"
)
-//DbInterface describes Db methods
-type DbInterface interface {
+//Interface describes Db methods
+type Interface interface {
SetStatementTimeout(time.Duration) error
Exec(string) ([]map[string]interface{}, error)
PgVersion() config.PgVersion
@@ -30,12 +30,13 @@ var ErrQueryTimeout = errors.New("canceled due to statement timeout")
// Db describes database
type Db struct {
+ ctx context.Context
version config.PgVersion
db *pgx.Conn
}
// New creates new instance of database connection
-func New(dbConfig config.DbConfig) (*Db, error) {
+func New(ctx context.Context, dbConfig config.DbConfig) (*Db, error) {
var version config.PgVersion
cfg := pgx.ConnConfig{
@@ -76,10 +77,6 @@ func New(dbConfig config.DbConfig) (*Db, error) {
version = config.NoVersion
}
- if err != nil {
- return nil, fmt.Errorf("could not open connection: %v", err)
- }
-
if !dbConfig.IsNotPg {
if err := dbConn.Ping(context.Background()); err != nil {
return nil, fmt.Errorf("could not ping db: %v", err)
@@ -87,6 +84,7 @@ func New(dbConfig config.DbConfig) (*Db, error) {
}
return &Db{
+ ctx: ctx,
db: dbConn,
version: version,
}, nil
@@ -96,7 +94,7 @@ func New(dbConfig config.DbConfig) (*Db, error) {
func (d *Db) Exec(query string) ([]map[string]interface{}, error) {
values := make([]map[string]interface{}, 0)
- rows, err := d.db.Query(query)
+ rows, err := d.db.QueryEx(d.ctx, query, nil)
if err != nil {
return nil, fmt.Errorf("query error: %v", err)
}
diff --git a/pkg/pgcollector/pgcollector.go b/pkg/pgcollector/pgcollector.go
index 1161ad6..d9e9ce1 100644
--- a/pkg/pgcollector/pgcollector.go
+++ b/pkg/pgcollector/pgcollector.go
@@ -1,6 +1,7 @@
package pgcollector
import (
+ "context"
"fmt"
"log"
"sync"
@@ -32,6 +33,7 @@ type PgCollector struct {
config config.Interface
timeOuts uint32
errors uint32
+ ctx context.Context
}
type workerJob struct {
@@ -40,8 +42,10 @@ type workerJob struct {
}
// New create new instance of the PostgreSQL metrics collector
-func New() *PgCollector {
- return &PgCollector{}
+func New(ctx context.Context) *PgCollector {
+ return &PgCollector{
+ ctx: ctx,
+ }
}
// LoadConfig loads config
@@ -84,7 +88,7 @@ func createMetric(job *workerJob, name string, constLabels prometheus.Labels, ra
return nil, nil
}
-func (p *PgCollector) worker(conn db.DbInterface, jobs chan *workerJob, res chan<- prometheus.Metric, wg *sync.WaitGroup) {
+func (p *PgCollector) worker(conn db.Interface, jobs chan *workerJob, res chan<- prometheus.Metric, wg *sync.WaitGroup) {
defer wg.Done()
jobs:
@@ -204,16 +208,16 @@ func (p *PgCollector) Collect(metricsCh chan<- prometheus.Metric) {
wg := &sync.WaitGroup{}
- dbPool := make(map[string][]db.DbInterface)
+ dbPool := make(map[string][]db.Interface)
dbJobs := make(map[string]chan *workerJob)
for _, dbName := range p.config.DbList() {
dbConf := p.config.Db(dbName)
workersCnt := dbConf.Workers()
- dbPool[dbName] = make([]db.DbInterface, 0)
+ dbPool[dbName] = make([]db.Interface, 0)
for i := 0; i < workersCnt; i++ {
- conn, err := db.New(dbConf)
+ conn, err := db.New(p.ctx, dbConf)
if err != nil {
log.Printf("could not create db instance %q: %v", dbName, err)
atomic.AddUint32(&p.errors, 1)