Skip to content

Commit

Permalink
Prevent metrics reader from reading invalid files (#52)
Browse files Browse the repository at this point in the history
  • Loading branch information
vandanavk committed Aug 14, 2020
1 parent d4d1411 commit fcf336e
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 11 deletions.
22 changes: 15 additions & 7 deletions smdebug/profiler/algorithm_metrics_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,12 @@
SMProfilerEvents,
TensorboardProfilerEvents,
)
from smdebug.profiler.utils import get_node_id_from_tracefilename, get_timestamp_from_tracefilename
from smdebug.profiler.utils import (
get_node_id_from_tracefilename,
get_timestamp_from_tracefilename,
is_valid_tfprof_tracefilename,
is_valid_tracefilename,
)


class AlgorithmMetricsReader(MetricsReaderBase):
Expand Down Expand Up @@ -198,15 +203,18 @@ def parse_event_files(self, event_files):

event_data_list = S3Handler.get_objects(file_read_requests)
for event_data, event_file in zip(event_data_list, event_files):
if event_file.endswith("json.gz"):
if event_file.endswith("json.gz") and is_valid_tfprof_tracefilename(event_file):
self._get_event_parser(event_file).read_events_from_file(event_file)
self._parsed_files.add(event_file)
else:
event_string = event_data.decode("utf-8")
json_data = json.loads(event_string)
node_id = get_node_id_from_tracefilename(event_file)
self._get_event_parser(event_file).read_events_from_json_data(json_data, node_id)
self._parsed_files.add(event_file)
if is_valid_tracefilename(event_file):
event_string = event_data.decode("utf-8")
json_data = json.loads(event_string)
node_id = get_node_id_from_tracefilename(event_file)
self._get_event_parser(event_file).read_events_from_json_data(
json_data, node_id
)
self._parsed_files.add(event_file)

"""
Create a map of timestamp to filename
Expand Down
19 changes: 15 additions & 4 deletions smdebug/profiler/metrics_reader_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,13 @@
from smdebug.core.logger import get_logger
from smdebug.core.utils import list_files_in_directory
from smdebug.profiler.profiler_constants import ENV_TRAIILING_DURATION, TRAILING_DURATION_DEFAULT
from smdebug.profiler.utils import TimeUnits, convert_utc_timestamp_to_microseconds
from smdebug.profiler.utils import (
TimeUnits,
convert_utc_timestamp_to_microseconds,
is_valid_tfprof_tracefilename,
is_valid_tracefilename,
validate_system_profiler_file,
)


class MetricsReaderBase:
Expand Down Expand Up @@ -168,9 +174,14 @@ def _update_start_after_prefix(self):

def _parse_event_files_local_mode(self, event_files):
    """Parse profiler trace files read from the local filesystem.

    Only files whose names match one of the recognized profiler trace
    filename patterns are read; anything else is skipped so that invalid
    or unrelated files cannot corrupt the event parser. Files already
    recorded in ``self._parsed_files`` are not re-read.

    :param event_files: iterable of local trace file paths to consider.
    """
    for event_file in event_files:
        # Accept framework trace files, TF profiler trace files, or
        # system profiler files; reject everything else. This guard is
        # the point of this method — do not parse before validating.
        if (
            is_valid_tracefilename(event_file)
            or is_valid_tfprof_tracefilename(event_file)
            or validate_system_profiler_file(event_file)
        ):
            if event_file not in self._parsed_files:
                self._get_event_parser(event_file).read_events_from_file(event_file)
                self._parsed_files.add(event_file)

def _get_timestamp_from_filename(self, event_file):
    """Extract a timestamp from a trace filename.

    Base-class implementation is a no-op that returns ``None``;
    presumably concrete reader subclasses override this with a
    filename-format-specific implementation — TODO confirm.

    :param event_file: path/name of the trace file.
    """
    pass
Expand Down
10 changes: 10 additions & 0 deletions tests/profiler/core/test_algorithm_metric_readers.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,13 @@ def test_MetricsReader_TFProfiler_timeline(use_in_memory_cache, trace_location):
assert len(events) == 798
elif trace_location == "s3":
assert len(events) >= 73000


@pytest.mark.parametrize("use_in_memory_cache", [True, False])
def test_MetricReader_all_files(use_in_memory_cache):
    """Reading a trace prefix that mixes valid and invalid files must
    still produce a non-empty list of events (invalid files are skipped,
    not fatal)."""
    trace_location = "s3://smdebug-testing/resources/pytorch_traces_with_pyinstru/profiler-output"
    reader = S3AlgorithmMetricsReader(trace_location, use_in_memory_cache=use_in_memory_cache)

    end_time_us = time.time() * CONVERT_TO_MICROSECS
    events = reader.get_events(0, end_time_us)

    assert len(events) != 0

0 comments on commit fcf336e

Please sign in to comment.