Fix a critical bug in the Splunk results reader: lack of pagination (#657)
* fix a critical bug in the splunk results reader

* typo pagenate -> paginate

* Refactored code and reformatted long lines.

Updated failing tests for new code.

---------

Co-authored-by: Ian Hellen <ianhelle@microsoft.com>
Tatsuya-hasegawa and ianhelle committed May 10, 2023
1 parent b248797 commit 62cfaa5
Showing 2 changed files with 155 additions and 24 deletions.
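For context: after this change, callers can control paging and time-out behavior through new `query` keyword arguments. A minimal usage sketch (the connection values are placeholders; exact connect parameters depend on your Splunk configuration):

    from msticpy.data.drivers.splunk_driver import SplunkDriver

    driver = SplunkDriver()
    # Placeholder credentials - substitute your own Splunk host and login
    driver.connect(host="splunkhost", username="user", password="pass")

    # page_size and timeout are the kwargs added by this commit:
    # fetch results 50 rows at a time, waiting up to 120 seconds for the job
    df = driver.query("search index=main | head 1000", page_size=50, timeout=120)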
151 changes: 128 additions & 23 deletions msticpy/data/drivers/splunk_driver.py
@@ -4,7 +4,8 @@
# license information.
# --------------------------------------------------------------------------
"""Splunk Driver class."""
from datetime import datetime
import logging
from datetime import datetime, timedelta
from time import sleep
from typing import Any, Dict, Iterable, Optional, Tuple, Union

@@ -14,6 +15,7 @@
from ..._version import VERSION
from ...common.exceptions import (
MsticpyConnectionError,
MsticpyDataQueryError,
MsticpyImportExtraError,
MsticpyUserConfigError,
)
@@ -35,6 +37,8 @@
__version__ = VERSION
__author__ = "Ashwin Patil"

logger = logging.getLogger(__name__)


SPLUNK_CONNECT_ARGS = {
"host": "(string) The host name (the default is 'localhost').",
@@ -73,7 +77,9 @@ def __init__(self, **kwargs):
self.service = None
self._loaded = True
self._connected = False
self._debug = kwargs.get("debug", False)
if kwargs.get("debug", False):
logger.setLevel(logging.DEBUG)

self.set_driver_property(
DriverProps.PUBLIC_ATTRS,
{
@@ -194,9 +200,17 @@ def query(
Other Parameters
----------------
count : int, optional
Passed to Splunk oneshot method if `oneshot` is True, by default, 0
Passed to the Splunk job to indicate the maximum number
of entities to return. A value of 0 indicates no maximum,
by default 0
oneshot : bool, optional
Set to True for oneshot (blocking) mode, by default False
page_size : int, optional
Number of results to fetch per page while reading
Splunk results, by default 100
timeout : int, optional
Amount of time (in seconds) to wait for results, by default 60
Returns
-------
@@ -212,35 +226,38 @@
# default to unlimited query unless count is specified
count = kwargs.pop("count", 0)

# Normal, oneshot or blocking searches. Defaults to non-blocking
# Oneshot is a blocking HTTP call which may cause time-outs
# https://dev.splunk.com/enterprise/docs/python/sdk-python/howtousesplunkpython/howtorunsearchespython
# Get sets of N results at a time, N=100 by default
page_size = kwargs.pop("page_size", 100)

# Normal (non-blocking) searches or oneshot (blocking) searches.
# Defaults to normal (non-blocking).

# Oneshot is a blocking search that is scheduled to run immediately.
# Instead of returning a search job, this mode returns the results
# of the search once completed.
# Because this is a blocking search, the results are not available
# until the search has finished.
# https://dev.splunk.com/enterprise/docs/python/
# sdk-python/howtousesplunkpython/howtorunsearchespython
is_oneshot = kwargs.pop("oneshot", False)

if is_oneshot:
kwargs["output_mode"] = "json"
query_results = self.service.jobs.oneshot(query, count=count, **kwargs)
reader = sp_results.ResultsReader(query_results)

reader = sp_results.JSONResultsReader( # pylint: disable=no-member
query_results
) # JSONResultsReader replaces the deprecated ResultsReader
resp_rows = [row for row in reader if isinstance(row, dict)]
else:
# Set mode and initialize async job
kwargs_normalsearch = {"exec_mode": "normal"}
query_job = self.service.jobs.create(query, **kwargs_normalsearch)

# Initiate progress bar and start while loop, waiting for async query to complete
progress_bar = tqdm(total=100, desc="Waiting Splunk job to complete")
while not query_job.is_done():
current_state = query_job.state
progress = float(current_state["content"]["doneProgress"]) * 100
progress_bar.update(progress)
sleep(1)

# Update progress bar indicating completion and fetch results
progress_bar.update(100)
progress_bar.close()
reader = sp_results.ResultsReader(query_job.results())
query_job = self.service.jobs.create(
query, count=count, **kwargs_normalsearch
)
resp_rows, reader = self._exec_async_search(query_job, page_size, **kwargs)

resp_rows = [row for row in reader if isinstance(row, dict)]
if not resp_rows:
if not resp_rows:
print("Warning - query did not return any results.")
return [row for row in reader if isinstance(row, sp_results.Message)]
return pd.DataFrame(resp_rows)
@@ -316,6 +333,94 @@ def driver_queries(self) -> Iterable[Dict[str, Any]]:
]
return []

def _exec_async_search(self, query_job, page_size, timeout=60, **kwargs):
"""Execute an async search and return results."""
del kwargs  # absorb any additional query options not used here
# Initialize progress bar and poll until the async query completes
progress_bar = tqdm(total=100, desc="Waiting for Splunk job to complete")
prev_progress = 0
offset = 0 # Start at result 0
start_time = datetime.now()
end_time = start_time + timedelta(seconds=timeout)
while True:
while not query_job.is_ready():
sleep(1)
job_done, prev_progress = self._retrieve_job_status(
query_job, progress_bar, prev_progress
)
if job_done:
break
if datetime.now() > end_time:
raise MsticpyDataQueryError(
"Timeout waiting for Splunk query to complete",
f"Job completion reported {query_job['doneProgress']}",
title="Splunk query timeout",
)
sleep(1)
# Fill the progress bar to completion and close it
progress_bar.update(100 - prev_progress)
progress_bar.close()
sleep(2)

logger.info("Implicit parameter dump - 'page_size': %d", page_size)
return self._retrieve_results(query_job, offset, page_size)

@staticmethod
def _retrieve_job_status(query_job, progress_bar, prev_progress):
"""Poll the status of a job and update the progress bar."""
stats = {
"is_done": query_job["isDone"],
"done_progress": float(query_job["doneProgress"]) * 100,
"scan_count": int(query_job["scanCount"]),
"event_count": int(query_job["eventCount"]),
"result_count": int(query_job["resultCount"]),
}
status = (
"\r%(done_progress)03.1f%% %(scan_count)d scanned "
"%(event_count)d matched %(result_count)d results"
) % stats
# Advance the bar by the progress delta since the last poll
progress_bar.update(stats["done_progress"] - prev_progress)

if stats["is_done"] == "1":
logger.info(status)
logger.info("Splunk job completed.")
return True, stats["done_progress"]
return False, stats["done_progress"]

@staticmethod
def _retrieve_results(query_job, offset, page_size):
"""Retrieve the results of a job, decode and return them."""
# Retrieve all of the results, paginating with offset and count
result_count = int(
query_job["resultCount"]
) # Number of results this job returned

resp_rows = []
reader = []  # fallback if the job returned no results
progress_bar_paginate = tqdm(
total=result_count, desc="Retrieving Splunk results"
)
while offset < result_count:
kwargs_paginate = {
"count": page_size,
"offset": offset,
"output_mode": "json",
}
# Get the search results and display them
search_results = query_job.results(**kwargs_paginate)
# JSONResultsReader replaces the deprecated ResultsReader
reader = sp_results.JSONResultsReader( # pylint: disable=no-member
search_results
)
resp_rows.extend([row for row in reader if isinstance(row, dict)])
# Advance by the number of results fetched in this page
progress_bar_paginate.update(min(page_size, result_count - offset))
offset += page_size
progress_bar_paginate.close()
logger.info("Retrieved %d results.", len(resp_rows))
return resp_rows, reader

@property
def _saved_searches(self) -> Union[pd.DataFrame, Any]:
"""
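The heart of the fix is the offset/count paging loop in `_retrieve_results`: previously a single `ResultsReader` pass returned only the first batch of results Splunk handed back. A standalone sketch of the same pattern against the raw splunklib API (connection values are placeholders):

    from time import sleep

    import splunklib.client as sp_client
    import splunklib.results as sp_results

    # Placeholder credentials
    service = sp_client.connect(host="splunkhost", username="user", password="pass")
    job = service.jobs.create("search index=main | head 1000", exec_mode="normal")
    while not job.is_done():
        sleep(1)  # the driver shows a tqdm progress bar while polling

    rows = []
    offset, page_size = 0, 100
    result_count = int(job["resultCount"])
    while offset < result_count:
        # Each results() call returns one page; without this loop only
        # the first page of results would ever be read
        stream = job.results(count=page_size, offset=offset, output_mode="json")
        rows.extend(
            row for row in sp_results.JSONResultsReader(stream)
            if isinstance(row, dict)
        )
        offset += page_size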
28 changes: 27 additions & 1 deletion tests/data/drivers/test_splunk_driver.py
@@ -13,6 +13,7 @@

from msticpy.common.exceptions import (
MsticpyConnectionError,
MsticpyDataQueryError,
MsticpyNotConnectedError,
MsticpyUserConfigError,
)
@@ -69,15 +70,35 @@ def __init__(self, name, count):


class _MockAsyncResponse:
stats = {
"isDone": "0",
"doneProgress": 0.0,
"scanCount": 1,
"eventCount": 100,
"resultCount": 100,
}

def __init__(self, query):
self.query = query

def results(self):
def __getitem__(self, key):
"""Mock method."""
return self.stats[key]

def results(self, **kwargs):
return self.query

def is_done(self):
return True

def is_ready(self):
return True

@classmethod
def set_done(cls):
cls.stats["isDone"] = "1"
cls.stats["doneProgress"] = 1


class _MockSplunkCall:
def create(query, **kwargs):
@@ -260,6 +281,7 @@ def test_splunk_query_success(splunk_client, splunk_results):
splunk_client.connect = cli_connect
sp_driver = SplunkDriver()
splunk_results.ResultsReader = _results_reader
splunk_results.JSONResultsReader = _results_reader

# trying to get these before connecting should throw
with pytest.raises(MsticpyNotConnectedError) as mp_ex:
@@ -279,6 +301,10 @@
check.is_not_instance(response, pd.DataFrame)
check.equal(len(response), 0)

with pytest.raises(MsticpyDataQueryError):
df_result = sp_driver.query("some query", timeout=1)

_MockAsyncResponse.set_done()
df_result = sp_driver.query("some query")
check.is_instance(df_result, pd.DataFrame)
check.equal(len(df_result), 10)
