
Commit

Fetch all pgss rows initially to avoid unnecessary filtering in the execution.
amw-zero committed May 14, 2024
1 parent 58f372b commit 1af5fec
Showing 3 changed files with 101 additions and 63 deletions.
33 changes: 19 additions & 14 deletions postgres/datadog_checks/postgres/query_calls_cache.py
@@ -18,25 +18,30 @@ def end_query_call_snapshot(self):
         self.called_queryids = self.next_called_queryids
         self.next_called_queryids = set()
 
-    def set_calls(self, queryid, calls):
+    def set_calls(self, rows):
         """Updates the cache of calls per query id.
         Returns whether or not the number of calls changed based on
         the newly updated value. The first seen update for a queryid
         does not count as a change in values since that would result
         in an inflated value."""
-        calls_changed = False
+        for row in rows:
+            queryid = row['queryid']
+            calls = row['calls']
+            calls_changed = False
 
-        if queryid in self.cache:
-            diff = calls - self.cache[queryid]
-            # Positive deltas mean the statement remained in pg_stat_statements
-            # between check calls. Negative deltas mean the statement was evicted
-            # and replaced with a new call count. Both cases should count as a call
-            # change.
-            calls_changed = diff != 0
-        else:
-            calls_changed = True
+            if queryid in self.cache:
+                diff = calls - self.cache[queryid]
+                # Positive deltas mean the statement remained in pg_stat_statements
+                # between check calls. Negative deltas mean the statement was evicted
+                # and replaced with a new call count. Both cases should count as a call
+                # change.
+                calls_changed = diff != 0
+            else:
+                calls_changed = True
 
-        self.next_cache[queryid] = calls
-        if calls_changed:
-            self.next_called_queryids.add(queryid)
+            self.next_cache[queryid] = calls
+            if calls_changed:
+                self.next_called_queryids.add(queryid)
+
+        self.end_query_call_snapshot()
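
For reference, a minimal usage sketch of the new rows-based API (hypothetical values; the import path follows this file's location). set_calls now consumes dict rows, as produced by CommenterDictCursor, and snapshots internally, so callers no longer pair it with end_query_call_snapshot:

from datadog_checks.postgres.query_calls_cache import QueryCallsCache

cache = QueryCallsCache()
cache.set_calls([{'queryid': 123, 'calls': 10}, {'queryid': 456, 'calls': 1}])
cache.set_calls([{'queryid': 123, 'calls': 12}])  # 123's call count moved by +2

# Only 123 registered a change; 456 was absent from the second batch.
assert cache.called_queryids == {123}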
104 changes: 72 additions & 32 deletions postgres/datadog_checks/postgres/statements.py
@@ -21,7 +21,7 @@
 
 from .query_calls_cache import QueryCallsCache
 from .util import DatabaseConfigurationError, payload_pg_version, warning_with_tags
-from .version_utils import V9_4, V14
+from .version_utils import V9_4, V10, V14
 
 try:
     import datadog_agent
@@ -42,11 +42,31 @@
 LEFT JOIN pg_database
     ON pg_stat_statements.dbid = pg_database.oid
 WHERE query NOT LIKE 'EXPLAIN %%'
-  AND queryid = ANY('{{ {called_queryids} }}'::bigint[])
+  {queryid_filter}
   {filters}
   {extra_clauses}
 """
 
+def statements_query(**kwargs):
+    pg_stat_statements_view = kwargs.get('pg_stat_statements_view', 'pg_stat_statements')
+    cols = kwargs.get('cols', '*')
+    filters = kwargs.get('filters', '')
+    extra_clauses = kwargs.get('extra_clauses', '')
+    called_queryids = kwargs.get('called_queryids', [])
+
+    queryid_filter = ""
+    if len(called_queryids) > 0:
+        queryid_filter = f"AND queryid = ANY('{{ {called_queryids} }}'::bigint[])"
+
+    return STATEMENTS_QUERY.format(
+        cols=cols,
+        pg_stat_statements_view=pg_stat_statements_view,
+        filters=filters,
+        extra_clauses=extra_clauses,
+        queryid_filter=queryid_filter,
+        called_queryids=called_queryids,
+    )
+
 BASELINE_METRICS_EXPIRY = 60 * 10  # 10 minutes
 
 # Use pg_stat_statements(false) when available as an optimization to avoid pulling SQL text from disk
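
A quick sketch of what statements_query builds in each mode (illustrative only; it assumes the module is importable and, as _load_pg_stat_statements does below, passes called_queryids as a comma-joined string):

from datadog_checks.postgres.statements import statements_query

# No known queryids yet: no filter is added, so the full view is scanned.
full_scan = statements_query(cols='queryid, calls')
assert 'queryid = ANY' not in full_scan

# With called queryids: only statements whose call counts moved are fetched.
incremental = statements_query(cols='queryid, calls', called_queryids='101, 102')
assert "AND queryid = ANY('{ 101, 102 }'::bigint[])" in incremental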
@@ -188,12 +208,10 @@ def _get_pg_stat_statements_columns(self):
             return self._stat_column_cache
 
         # Querying over '*' with limit 0 allows fetching only the column names from the cursor without data
-        query = STATEMENTS_QUERY.format(
+        query = statements_query(
             cols='*',
             pg_stat_statements_view=self._config.pg_stat_statements_view,
             extra_clauses="LIMIT 0",
-            filters="",
-            called_queryids="",
         )
         with self._check._get_main_db() as conn:
             with conn.cursor(cursor_factory=CommenterCursor) as cursor:
@@ -212,16 +230,10 @@ def _check_called_queries(self):
         pgss_view_without_query_text = "pg_stat_statements(false)"
 
         with self._check._get_main_db() as conn:
-            with conn.cursor(cursor_factory=CommenterCursor) as cursor:
+            with conn.cursor(cursor_factory=CommenterDictCursor) as cursor:
                 query = QUERYID_TO_CALLS_QUERY.format(pg_stat_statements_view=pgss_view_without_query_text)
                 rows = self._execute_query(cursor, query, params=(self._config.dbname,))
-
-                for row in rows:
-                    queryid = row[0]
-                    calls = row[1]
-                    self._query_calls_cache.set_calls(queryid, calls)
-
-                self._query_calls_cache.end_query_call_snapshot()
+                self._query_calls_cache.set_calls(rows)
                 self._check.gauge(
                     "dd.postgresql.pg_stat_statements.calls_changed",
                     len(self._query_calls_cache.called_queryids),
@@ -267,7 +279,6 @@ def collect_per_statement_metrics(self):
     @tracked_method(agent_check_getter=agent_check_getter, track_result_length=True)
     def _load_pg_stat_statements(self):
         try:
-            self._check_called_queries()
             available_columns = set(self._get_pg_stat_statements_columns())
             missing_columns = PG_STAT_STATEMENTS_REQUIRED_COLUMNS - available_columns
             if len(missing_columns) > 0:
@@ -335,17 +346,27 @@ def _load_pg_stat_statements(self):
                 params = params + tuple(self._config.ignore_databases)
             with self._check._get_main_db() as conn:
                 with conn.cursor(cursor_factory=CommenterDictCursor) as cursor:
-                    return self._execute_query(
-                        cursor,
-                        STATEMENTS_QUERY.format(
-                            cols=', '.join(query_columns),
-                            pg_stat_statements_view=self._config.pg_stat_statements_view,
-                            filters=filters,
-                            extra_clauses="",
-                            called_queryids=', '.join([str(i) for i in self._query_calls_cache.called_queryids]),
-                        ),
-                        params=params,
-                    )
+                    if len(self._query_calls_cache.cache) > 0:
+                        return self._execute_query(
+                            cursor,
+                            statements_query(
+                                cols=', '.join(query_columns),
+                                pg_stat_statements_view=self._config.pg_stat_statements_view,
+                                filters=filters,
+                                called_queryids=', '.join([str(i) for i in self._query_calls_cache.called_queryids]),
+                            ),
+                            params=params,
+                        )
+                    else:
+                        return self._execute_query(
+                            cursor,
+                            statements_query(
+                                cols=', '.join(query_columns),
+                                pg_stat_statements_view=self._config.pg_stat_statements_view,
+                                filters=filters,
+                            ),
+                            params=params,
+                        )
         except psycopg2.Error as e:
             error_tag = "error:database-{}".format(type(e).__name__)
 
@@ -453,24 +474,27 @@ def _emit_pg_stat_statements_metrics(self):
         except psycopg2.Error as e:
             self._log.warning("Failed to query for pg_stat_statements count: %s", e)
 
-    # _apply_deltas expects normalized rows before any merging of duplicates.
+    # _apply_called_queries expects normalized rows before any merging of duplicates.
     # It takes the incremental pg_stat_statements rows and constructs the full set of rows
     # by adding the existing values in the baseline_metrics cache. This is equivalent to
     # fetching the full set of rows from pg_stat_statements, but we avoid paying the price of
     # actually querying the rows.
-    def _apply_deltas(self, rows):
+    def _apply_called_queries(self, rows):
+        # Apply called queries to baseline_metrics
        for row in rows:
             query_signature = row['query_signature']
             queryid = row['queryid']
             baseline_row = copy.copy(row)
 
             # To avoid high memory usage, don't cache the query text since it can be large.
             del baseline_row['query']
             if query_signature not in self._baseline_metrics:
                 self._baseline_metrics[query_signature] = {queryid: baseline_row}
             else:
                 self._baseline_metrics[query_signature][queryid] = baseline_row
 
-        # Apply query text, so it doesn't have to be cached.
+        # Apply query text for called queries, since it is not cached and uncalled queries
+        # won't result in sent metrics.
         query_text = {row['query_signature']: row['query'] for row in rows}
         applied_rows = []
         for query_signature, query_sig_metrics in self._baseline_metrics.items():
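
The baseline cache this method maintains has the shape sketched below (hypothetical values; the real rows carry all remaining pg_stat_statements metric columns, with the raw query text stripped to keep memory down):

# One entry per query_signature, keyed by queryid, query text elided.
baseline_metrics = {
    'abc123': {  # query_signature of the normalized query
        101: {'queryid': 101, 'query_signature': 'abc123', 'calls': 7},
        102: {'queryid': 102, 'query_signature': 'abc123', 'calls': 3},
    },
}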
@@ -498,10 +522,26 @@ def _collect_metrics_rows(self):
         self._emit_pg_stat_statements_metrics()
         self._emit_pg_stat_statements_dealloc()
 
-        self._check_baseline_metrics_expiry()
-        rows = self._load_pg_stat_statements()
-        rows = self._normalize_queries(rows)
-        rows = self._apply_deltas(rows)
+        rows = []
+        if self._check.version < V10:
+            rows = self._load_pg_stat_statements()
+            rows = self._normalize_queries(rows)
+        elif len(self._baseline_metrics) == 0:
+            # When we don't have baseline metrics (either on the first run or after cache expiry),
+            # we fetch all rows from pg_stat_statements and update the initial state of the
+            # relevant caches.
+            rows = self._load_pg_stat_statements()
+            rows = self._normalize_queries(rows)
+            self._query_calls_cache.set_calls(rows)
+            self._apply_called_queries(rows)
+        else:
+            # When we do have baseline metrics, use them to construct the full set of rows
+            # so that compute_derivative_rows can merge duplicates and calculate deltas.
+            self._check_baseline_metrics_expiry()
+            self._check_called_queries()
+            rows = self._load_pg_stat_statements()
+            rows = self._normalize_queries(rows)
+            rows = self._apply_called_queries(rows)
 
         if not rows:
             return []
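
A compact, illustration-only restatement of the branch above (the helper name and labels are hypothetical; V10 comes from version_utils as in the import hunk at the top of this file):

from datadog_checks.postgres.version_utils import V10

def collection_path(version, baseline_metrics):
    if version < V10:
        return 'full fetch every run'      # pre-10 servers skip the incremental machinery
    if len(baseline_metrics) == 0:
        return 'full fetch, seed caches'   # first run, or after baseline expiry
    return 'incremental fetch'             # only called queryids are re-queried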
27 changes: 10 additions & 17 deletions postgres/tests/test_query_calls_cache.py
@@ -10,47 +10,40 @@
 
 def test_statement_queryid_cache_same_calls_does_not_change():
     cache = QueryCallsCache()
-    cache.set_calls(123, 1)
-    cache.end_query_call_snapshot()
-    cache.set_calls(123, 1)
-    cache.end_query_call_snapshot()
+    cache.set_calls([{'queryid': 123, 'calls': 1}])
+    cache.set_calls([{'queryid': 123, 'calls': 1}])
 
     assert cache.called_queryids == set()
 
 
 def test_statement_queryid_cache_multiple_calls_change():
     cache = QueryCallsCache()
-    cache.set_calls(123, 1)
-    cache.end_query_call_snapshot()
-    cache.set_calls(123, 2)
+    cache.set_calls([{'queryid': 123, 'calls': 1}])
+    cache.set_calls([{'queryid': 123, 'calls': 2}])
 
     assert cache.called_queryids == {123}
 
 
 def test_statement_queryid_cache_after_pg_stat_statement_eviction():
     cache = QueryCallsCache()
-    cache.set_calls(123, 100)
-    cache.end_query_call_snapshot()
-    cache.set_calls(123, 5)
+    cache.set_calls([{'queryid': 123, 'calls': 100}])
+    cache.set_calls([{'queryid': 123, 'calls': 5}])
 
     assert cache.called_queryids == {123}
 
 
 def test_statement_queryid_cache_snapshot_eviction():
     cache = QueryCallsCache()
-    cache.set_calls(123, 100)
-    cache.end_query_call_snapshot()
-    cache.set_calls(124, 5)
-    cache.end_query_call_snapshot()
+    cache.set_calls([{'queryid': 123, 'calls': 100}])
+    cache.set_calls([{'queryid': 124, 'calls': 5}])
 
     assert cache.cache.get(123, None) is None
 
 
 def test_statement_queryid_multiple_inserts():
     cache = QueryCallsCache()
-    cache.set_calls(123, 100)
-    cache.set_calls(124, 5)
-    cache.end_query_call_snapshot()
+    # Both rows must arrive in one batch: each set_calls call snapshots, which
+    # would otherwise evict queryid 123 from the cache before the asserts run.
+    rows = [{'queryid': 123, 'calls': 100}, {'queryid': 124, 'calls': 5}]
+    cache.set_calls(rows)
 
     assert cache.cache[123] == 100
     assert cache.cache[124] == 5
