Replay stored procedures #703

Open · wants to merge 4 commits into master
88 changes: 80 additions & 8 deletions src/SimpleReplay/extract/extractor/extractor.py
@@ -11,16 +11,22 @@
import pathlib
import re
from collections import OrderedDict

import asyncio
import datetime

import dateutil.parser
import redshift_connector
from boto3 import client
from tqdm import tqdm

from audit_logs_parsing import (
ConnectionLog,
)
from helper import aws_service as aws_service_helper
import util
from log_validation import remove_line_comments
from .cloudwatch_extractor import CloudwatchExtractor
from .s3_extractor import S3Extractor
@@ -75,6 +81,63 @@ def get_extract(self, log_location, start_time, end_time):
else:
    return self.local_extractor.get_extract_locally(log_location, start_time, end_time)

async def get_stored_procedures(self, start_time, end_time, username, stored_procedures_map, sql_json):
    """
    Handle bind variables for stored procedures by querying the sys_query_history table.
    """
    parse_start_time = "'" + start_time + "'"
    parse_end_time = "'" + end_time + "'"

    async def fetch_procedure_data(transaction_ids):
        transaction_ids_str = ','.join(str(tid) for tid in transaction_ids)
        sys_query_history = f'SELECT transaction_id, query_text \
            FROM sys_query_history \
            WHERE user_id > 1 \
            AND transaction_id IN ({transaction_ids_str}) \
            AND start_time >= {parse_start_time} \
            AND end_time <= {parse_end_time} \
            ORDER BY 1;'
        cluster_object = util.cluster_dict(endpoint=self.config["source_cluster_endpoint"])
        return await aws_service_helper.redshift_execute_query_async(
            redshift_cluster_id=cluster_object['id'],
            redshift_database_name=cluster_object['database'],
            redshift_user=username,
            region=self.config['region'],
            query=sys_query_history,
        )

    transaction_ids = list(stored_procedures_map.keys())
    tasks = []
    batch_size = 100
    for i in range(0, len(transaction_ids), batch_size):
        # fetch a batch of up to 100 transaction ids per query
        batch = transaction_ids[i:i + batch_size]
        tasks.append(fetch_procedure_data(batch))

    # gather the per-batch results once all tasks complete
    results = await asyncio.gather(*tasks)

    # process the results and overwrite each map entry with the
    # bind-resolved query text returned by sys_query_history
    for result in results:
        if 'ColumnMetadata' in result and 'Records' in result:
            for row in result['Records']:
                for field in row:
                    if 'stringValue' in field and field['stringValue'].startswith('call'):
                        # drop any trailing "--" comment from the call text
                        modified_string_value = field['stringValue'].rsplit('--')[0]
                        transaction_id = row[0]['longValue']
                        stored_procedures_map[transaction_id] = modified_string_value

    # update sql_json with the modified query text
    for xid, query_text in stored_procedures_map.items():
        if xid in sql_json['transactions'] and sql_json['transactions'][xid]['queries']:
            sql_json['transactions'][xid]['queries'][0]['text'] = query_text
    return sql_json
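
[Editor's note] The helper aws_service_helper.redshift_execute_query_async that this method awaits is not part of this diff. Below is a minimal sketch of what such a helper could look like, assuming the boto3 Redshift Data API and a thread offload via asyncio.to_thread; the actual implementation in helper/aws_service.py may differ.

# Hypothetical sketch only -- not part of this PR. Assumes the boto3
# Redshift Data API; the real helper in helper/aws_service.py may differ.
import asyncio
import time
import boto3

async def redshift_execute_query_async(redshift_cluster_id, redshift_database_name,
                                        redshift_user, region, query):
    def run_query():
        client = boto3.client("redshift-data", region_name=region)
        statement = client.execute_statement(
            ClusterIdentifier=redshift_cluster_id,
            Database=redshift_database_name,
            DbUser=redshift_user,
            Sql=query,
        )
        # poll until the statement reaches a terminal state
        while True:
            desc = client.describe_statement(Id=statement["Id"])
            if desc["Status"] in ("FINISHED", "FAILED", "ABORTED"):
                break
            time.sleep(0.5)
        if desc["Status"] != "FINISHED":
            raise RuntimeError(desc.get("Error", "query did not finish"))
        # result shape matches what get_stored_procedures expects:
        # {'ColumnMetadata': [...], 'Records': [[{'longValue': ...}, {'stringValue': ...}], ...]}
        return client.get_statement_result(Id=statement["Id"])

    # run the blocking boto3 calls off the event loop so batches can overlap
    return await asyncio.to_thread(run_query)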


def save_logs(self, logs, last_connections, output_directory, connections, start_time, end_time):
"""
Save the extracted logs to the S3 location in the following format:
@@ -111,11 +174,16 @@ def save_logs(self, logs, last_connections, output_directory, connections, start_time, end_time):
)
pathlib.Path(output_directory).mkdir(parents=True, exist_ok=True)

sql_json, missing_audit_log_connections, replacements, stored_procedures_map = \
    self.get_sql_connections_replacements(last_connections, log_items)

if self.config.get('replay_stored_procedures'):
    logger.info(f'Total number of stored procedures found: {len(stored_procedures_map)}')
    sql_json_with_stored_procedure = asyncio.run(
        self.get_stored_procedures(
            start_time=self.config['start_time'],
            end_time=self.config['end_time'],
            username=self.config['master_username'],
            stored_procedures_map=stored_procedures_map,
            sql_json=sql_json,
        )
    )
    with gzip.open(archive_filename, "wb") as f:
        f.write(json.dumps(sql_json_with_stored_procedure, indent=2).encode("utf-8"))
else:
    with gzip.open(archive_filename, "wb") as f:
        f.write(json.dumps(sql_json, indent=2).encode("utf-8"))

if is_s3:
    dest = output_prefix + "/SQLs.json.gz"
Expand Down Expand Up @@ -169,6 +237,7 @@ def get_sql_connections_replacements(self, last_connections, log_items):
sql_json = {"transactions": OrderedDict()}
missing_audit_log_connections = set()
replacements = set()
stored_procedures_map = {}
for filename, queries in tqdm(
    log_items,
    disable=self.disable_progress_bar,
@@ -220,7 +289,10 @@ def get_sql_connections_replacements(self, last_connections, log_items):
query.text,
flags=re.IGNORECASE,
)

if self.config.get("replay_stored_procedures", None) and query.text.lower().startswith("call"):
    stored_procedures_map[query.xid] = query.text

query.text = f"{query.text.strip()}"
if not len(query.text) == 0:
    if not query.text.endswith(";"):
@@ -231,7 +303,7 @@ def get_sql_connections_replacements(self, last_connections, log_items):

if not hash((query.database_name, query.username, query.pid)) in last_connections:
    missing_audit_log_connections.add((query.database_name, query.username, query.pid))
return sql_json, missing_audit_log_connections, replacements, stored_procedures_map
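
[Editor's note] To make the two-phase flow concrete: get_sql_connections_replacements records the audit-log text of every transaction whose query starts with "call", and get_stored_procedures later overwrites each entry with the bind-resolved text from sys_query_history. A toy walk-through under assumed values (the transaction id 41782 and the procedure name are invented for illustration):

# Invented example data; mirrors the record shape get_stored_procedures parses.
stored_procedures_map = {41782: "call etl.refresh_sales($1)"}  # from the audit log
row = [{"longValue": 41782},
       {"stringValue": "call etl.refresh_sales(20230101) -- trailing comment"}]

text = row[1]["stringValue"]
if text.startswith("call"):
    # rsplit('--')[0] keeps everything before the first '--'
    stored_procedures_map[row[0]["longValue"]] = text.rsplit("--")[0]

print(stored_procedures_map[41782])  # -> "call etl.refresh_sales(20230101) "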

def unload_system_table(
self,
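
[Editor's note] For reference, the batching-plus-gather pattern that get_stored_procedures uses reduces to the following self-contained sketch; names and sizes here are illustrative, not from the PR.

import asyncio

async def fetch(batch):
    # stand-in for fetch_procedure_data: pretend each task returns one result dict
    return {"count": len(batch)}

async def main():
    ids = list(range(250))
    batch_size = 100
    tasks = [fetch(ids[i:i + batch_size]) for i in range(0, len(ids), batch_size)]
    results = await asyncio.gather(*tasks)  # runs the batches concurrently
    print([r["count"] for r in results])    # [100, 100, 50]

asyncio.run(main())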
2 changes: 1 addition & 1 deletion src/SimpleReplay/extract/extractor/s3_extractor.py
@@ -32,7 +32,7 @@ def get_extract_from_s3(self, log_bucket, log_prefix, start_time, end_time):
last_connections = {}
databases = set()

bucket_objects = aws_service_helper.sync_s3_get_bucket_contents(log_bucket, log_prefix)

s3_connection_logs = []
s3_user_activity_logs = []
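
[Editor's note] The renamed helper sync_s3_get_bucket_contents is defined elsewhere in helper/aws_service.py and is not shown in this diff. A minimal sketch of what a synchronous, fully paginated listing could look like (the real helper may differ):

# Hypothetical sketch -- the actual helper is not part of this diff.
import boto3

def sync_s3_get_bucket_contents(bucket, prefix):
    s3 = boto3.client("s3")
    contents = []
    # list_objects_v2 returns at most 1,000 keys per call, so paginate
    paginator = s3.get_paginator("list_objects_v2")
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
        contents.extend(page.get("Contents", []))
    return contents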