-
Notifications
You must be signed in to change notification settings - Fork 0
/
cloudwatch_api.py
144 lines (124 loc) · 5.22 KB
/
cloudwatch_api.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
"""Class for interacting with Cloudwatch API."""
import os
from collections import deque
from datetime import datetime, timedelta, timezone
from math import ceil
from typing import Deque
import boto3
from tap_cloudwatch.exception import InvalidQueryException
from tap_cloudwatch.subquery import Subquery
class CloudwatchAPI:
    """Cloudwatch class for interacting with the API.

    Wraps a boto3 ``logs`` client and runs a query over a time range by
    splitting the range into fixed-size windows, executing one ``Subquery``
    per window, and keeping a bounded number of subqueries in flight.
    """

    def __init__(self, logger, max_concurrent_queries=20):
        """Initialize CloudwatchAPI.

        Args:
            logger: Logger object stored on the instance.
            max_concurrent_queries: Maximum number of subqueries kept in the
                in-flight queue at once. Defaults to 20.

        Raises:
            ValueError: If ``max_concurrent_queries`` is less than 1.
        """
        # A limit below 1 would make the queue "full" while empty, and
        # `_iterate_batches` would then pop from an empty deque.
        if max_concurrent_queries < 1:
            raise ValueError("max_concurrent_queries must be at least 1")
        self._client = None
        self.logger = logger
        self.max_concurrent_queries = max_concurrent_queries

    @property
    def client(self):
        """Property to access client object.

        Raises:
            RuntimeError: If ``authenticate`` has not been called yet.
        """
        if self._client is None:
            raise RuntimeError("Client not yet initialized")
        return self._client

    def authenticate(self, config):
        """Authenticate the AWS client.

        Args:
            config: Mapping of settings; see ``_create_client`` for the keys
                that are read.
        """
        self._client = self._create_client(config)

    def _create_client(self, config):
        """Build a boto3 ``logs`` client from config, falling back to env vars.

        Keys read from ``config``: ``aws_access_key_id``,
        ``aws_secret_access_key``, ``aws_session_token``, ``aws_profile``,
        ``aws_endpoint_url`` and ``aws_region_name``. The first four fall back
        to the matching ``AWS_*`` environment variable when unset or empty.

        Args:
            config: Mapping exposing ``.get(key)``.

        Returns:
            The client returned by ``Session.client("logs", ...)``.
        """
        aws_access_key_id = config.get("aws_access_key_id") or os.environ.get(
            "AWS_ACCESS_KEY_ID"
        )
        aws_secret_access_key = config.get("aws_secret_access_key") or os.environ.get(
            "AWS_SECRET_ACCESS_KEY"
        )
        aws_session_token = config.get("aws_session_token") or os.environ.get(
            "AWS_SESSION_TOKEN"
        )
        aws_profile = config.get("aws_profile") or os.environ.get("AWS_PROFILE")
        aws_endpoint_url = config.get("aws_endpoint_url")
        aws_region_name = config.get("aws_region_name")

        if aws_access_key_id and aws_secret_access_key:
            # AWS credentials based authentication
            aws_session = boto3.session.Session(
                aws_access_key_id=aws_access_key_id,
                aws_secret_access_key=aws_secret_access_key,
                aws_session_token=aws_session_token,
                region_name=aws_region_name,
            )
        else:
            # AWS Profile based authentication. The configured region is
            # forwarded here as well so it is honoured in both branches.
            aws_session = boto3.session.Session(
                profile_name=aws_profile,
                region_name=aws_region_name,
            )

        if aws_endpoint_url:
            return aws_session.client("logs", endpoint_url=aws_endpoint_url)
        return aws_session.client("logs")

    def _split_batch_into_windows(self, start_time, end_time, batch_increment_s):
        """Split a time range into consecutive ``(start, end)`` epoch windows.

        Windows are treated as inclusive on both ends, so every window after
        the first starts one second after the previous window's end. No window
        ends after ``end_time``.

        Args:
            start_time: Object with a ``.timestamp()`` method marking the
                start of the range.
            end_time: Object with a ``.timestamp()`` method marking the end of
                the range.
            batch_increment_s: Window size in seconds; must be positive.

        Returns:
            List of ``(query_start, query_end)`` tuples of integer epoch
            seconds. Empty when ``end_time`` is not after ``start_time``.

        Raises:
            ValueError: If ``batch_increment_s`` is not greater than zero.
        """
        if batch_increment_s <= 0:
            raise ValueError("batch_increment_s must be greater than 0")

        start_time_epoch = start_time.timestamp()
        end_time_epoch = end_time.timestamp()
        diff_s = end_time_epoch - start_time_epoch
        total_batches = ceil(diff_s / batch_increment_s)

        batch_windows = []
        for batch_num in range(total_batches):
            offset_s = batch_increment_s * batch_num
            if batch_num != 0:
                # Inclusive start and end date, so on second iteration
                # we can skip one second.
                query_start = int(start_time_epoch + offset_s + 1)
            else:
                query_start = int(start_time_epoch + offset_s)
            # Never exceed the end_time
            query_end = min(
                int(start_time_epoch + (batch_increment_s * (batch_num + 1))),
                int(end_time_epoch),
            )
            # A trailing remainder shorter than one second would otherwise
            # produce an inverted window because of the one-second skip; the
            # previous window already covers that range.
            if query_start > query_end:
                continue
            batch_windows.append((query_start, query_end))
        return batch_windows

    def _validate_query(self, query):
        """Reject queries that cannot be used for batched extraction.

        Args:
            query: The query string to check.

        Raises:
            InvalidQueryException: If the query pipes into ``sort`` or
                ``limit``, contains the text ``stats`` anywhere, or does not
                mention ``@timestamp`` before the first ``|``.
        """
        # Strip every kind of whitespace (not just spaces) so "|\tsort" and
        # "|\nsort" are caught the same way as "| sort".
        compact_query = "".join(query.split())
        if "|sort" in compact_query:
            raise InvalidQueryException("sort not allowed")
        if "|limit" in compact_query:
            raise InvalidQueryException("limit not allowed")
        # Plain substring match: this also rejects queries that merely mention
        # the text "stats" somewhere.
        if "stats" in query:
            raise InvalidQueryException("stats not allowed")
        if "@timestamp" not in query.split("|")[0]:
            raise InvalidQueryException(
                "@timestamp field is used as the replication key so it must be selected"
            )

    def _queue_is_full(self, queue):
        """Return True when ``queue`` holds ``max_concurrent_queries`` or more."""
        return len(queue) >= self.max_concurrent_queries

    @staticmethod
    def _get_completed_query(queue):
        """Remove and return the oldest entry of ``queue``."""
        return queue.popleft()

    def _iterate_batches(self, batch_windows, log_group, query):
        """Run one subquery per window and yield results in window order.

        At most ``max_concurrent_queries`` subqueries are held in the queue.
        When it is full, the oldest one is removed, the next window's subquery
        is started, and only then are the oldest one's results yielded.

        Args:
            batch_windows: Iterable of ``(start_ts, end_ts)`` pairs.
            log_group: Passed through to each ``Subquery``.
            query: Passed through to each ``Subquery``.

        Yields:
            The value of ``get_results()`` for each subquery, oldest first.
        """
        queue: Deque["Subquery"] = deque()
        for start_ts, end_ts in batch_windows:
            oldest = None
            if self._queue_is_full(queue):
                oldest = self._get_completed_query(queue)
            # The return value of execute() is what gets queued, so it is
            # expected to expose get_results().
            queue.append(
                Subquery(self.client, start_ts, end_ts, log_group, query).execute()
            )
            if oldest is not None:
                yield oldest.get_results()

        # Clear queue to complete
        while queue:
            yield self._get_completed_query(queue).get_results()

    def _alter_end_ts(self, end_time):
        """Cap the end time at five minutes before the current UTC time.

        Args:
            end_time: Requested end time, or a falsy value for no explicit
                end. It is compared against a timezone-aware UTC datetime, so
                it must be comparable with one.

        Returns:
            The earlier of ``end_time`` and "now minus five minutes", or the
            latter alone when ``end_time`` is falsy.
        """
        default_end_time = datetime.now(timezone.utc) - timedelta(minutes=5)
        if end_time:
            return min(end_time, default_end_time)
        return default_end_time

    def get_records_iterator(
        self, bookmark, log_group, query, batch_increment_s, end_time
    ):
        """Retrieve records from Cloudwatch.

        This is a generator, so validation and window computation run when
        iteration starts rather than when the method is called.

        Args:
            bookmark: Start of the range to extract; needs ``.timestamp()``.
            log_group: Passed through to each ``Subquery``.
            query: Query string; checked by ``_validate_query``.
            batch_increment_s: Window size in seconds.
            end_time: Optional end of the range; capped by ``_alter_end_ts``.

        Yields:
            One result set per time window, in window order.

        Raises:
            InvalidQueryException: If ``query`` fails validation.
            ValueError: If ``batch_increment_s`` is not greater than zero.
        """
        self._validate_query(query)
        batch_windows = self._split_batch_into_windows(
            bookmark, self._alter_end_ts(end_time), batch_increment_s
        )
        yield from self._iterate_batches(batch_windows, log_group, query)