Add ability to only save shapes of tensors (#328)
rahul003 committed Aug 26, 2020
1 parent 47ceaf0 commit c9eb769
Showing 22 changed files with 681 additions and 205 deletions.
1 change: 1 addition & 0 deletions docs/api.md
@@ -96,6 +96,7 @@ include_workers
include_regex
reductions
save_raw_tensor
save_shape
save_interval
save_steps
start_step
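With this commit, `save_shape` joins the reduction-config options documented above. A minimal sketch of enabling it (assuming the PyTorch hook; the output directory and interval values are illustrative, not from this diff):

```python
# Record only tensor shapes, not values (PyTorch hook shown as an example;
# out_dir and save_interval are illustrative).
import smdebug.pytorch as smd

hook = smd.Hook(
    out_dir="/tmp/smdebug_shapes",
    save_config=smd.SaveConfig(save_interval=100),
    reduction_config=smd.ReductionConfig(save_shape=True),
)
```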
78 changes: 65 additions & 13 deletions smdebug/core/hook.py
@@ -46,7 +46,7 @@
size_and_shape,
validate_custom_tensor_value,
)
from smdebug.core.writer import FileWriter
from smdebug.core.writer import FileWriter, ShapeWriter
from smdebug.exceptions import InvalidCollectionConfiguration

try:
@@ -222,7 +222,7 @@ def __init__(
self.mode = ModeKeys.GLOBAL
self.mode_steps = {ModeKeys.GLOBAL: init_step}
self.writer = None

self.shape_writer = None
if is_sagemaker_job() and SageMakerFileMetricsWriter is not None:
self.metrics_writer = SageMakerFileMetricsWriter()
else:
@@ -343,6 +343,12 @@ def _get_collections_to_save_for_step(self) -> Set["Collection"]:
)
return self._collections_to_save_for_step

def _saving_shapes_in_step(self) -> bool:
for coll in self._get_collections_to_save_for_step():
if coll.reduction_config.save_shape is True:
return True
return False

def _get_collections_with_tensor(self, tensor_name) -> Set["Collection"]:
self._assert_prep()
# for tf this will be prepopulated in check_and_add_tensor
@@ -404,6 +410,17 @@ def _prepare_collections(self):
self.prepared_collections = True

#### End of Save Manager methods ####
@staticmethod
def _close_given_writer_map(writer_dict):
# Delete all the dist training writers
to_delete_writers = []
for key, writer in writer_dict.items():
# close calls flush
writer.close()
to_delete_writers.append(key)

for key in to_delete_writers:
del writer_dict[key]

def _close_writers(self) -> None:
if self.dry_run:
@@ -417,16 +434,11 @@ def _close_writers(self) -> None:
self.writer.close()
self.writer = None

to_delete_writers = []
self._close_given_writer_map(self.tb_writers)

# Delete all the tb writers
for mode, writer in self.tb_writers.items():
if writer is not None:
writer.flush()
writer.close()
to_delete_writers.append(mode)
for mode in to_delete_writers:
del self.tb_writers[mode]
if self.shape_writer is not None:
self.shape_writer.close()
self.shape_writer = None

def _initialize_writers(self, only_initialize_if_missing=False) -> None:
# Function is overridden in smdebug/tensorflow/base_hook.py
@@ -454,17 +466,32 @@ def _initialize_writers(self, only_initialize_if_missing=False) -> None:
if self.save_all_workers is False:
if self.worker != self.chief_worker:
return

self.writer = FileWriter(trial_dir=self.out_dir, step=self.step, worker=self.worker)

def _get_writers(self, tensor_name, tensor_ref=None) -> List[FileWriter]:
if self._saving_shapes_in_step():
self.shape_writer = ShapeWriter(
trial_dir=self.out_dir,
step=self.step,
worker=self.worker,
index_writer=self.writer.index_writer,
)

def _get_single_process_writers(self, shape_writers=False) -> List[FileWriter]:
if shape_writers is False:
return [self.writer] if self.writer else []
else:
return [self.shape_writer] if self.shape_writer else []

def _get_writers(self, tensor_name, tensor_ref=None, shape_writers=False) -> List[FileWriter]:
"""
:param tensor_name:
:param tensor_ref: used by TF
:return: List[FileWriter]
"""
if self.save_all_workers is False and self.worker != self.chief_worker:
return []
return [self.writer] if self.writer else []
return self._get_single_process_writers(shape_writers)

def _maybe_get_tb_writer(self) -> Optional[FileWriter]:
""" Returns a FileWriter object if `hook.tensorboard_dir` has been specified, else None.
@@ -726,6 +753,28 @@ def _write_raw_tensor(self, tensor_name, tensor_value, save_collections, tensor_
self._write_raw_tensor_simple(tensor_name, tensor_value, tensor_ref=tensor_ref)
break

def _write_shape(self, tensor_name, tensor_value, save_collections, tensor_ref=None):
shape_writers = self._get_writers(tensor_name, tensor_ref=tensor_ref, shape_writers=True)
for s_col in save_collections:
reduction_config = s_col.reduction_config
if self.dry_run is False and reduction_config.save_shape is True:
numpy_tensor_value = self._make_numpy_array(tensor_value)
this_size, this_shape = size_and_shape(numpy_tensor_value)
if tensor_ref is not None and tensor_ref.tf_obj is not None:
original_name = tensor_ref.tf_obj.name
else:
original_name = None

for writer in shape_writers:
writer.write_shape(
tensor_name,
this_shape,
self.mode,
self.mode_steps[self.mode],
original_name=original_name,
)
break

def _write_raw_tensor_simple(self, tensor_name, tensor_value, tensor_ref=None, timestamp=None):
# tensor_ref is used by TF
# todo: if fp16, check perf of saving as fp16 in proto vs as fp32
@@ -805,6 +854,9 @@ def _write_for_tensor(self, tensor_name, tensor_value, save_collections, tensor_
:param save_collections: list of collections which are being saved for this step
"""
self._log_save(tensor_name, save_collections)

self._write_shape(tensor_name, tensor_value, save_collections, tensor_ref=tensor_ref)

# write reductions defined for collections this tensor may be part of
self._write_reductions(tensor_name, tensor_value, save_collections, tensor_ref=tensor_ref)

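Once shapes are written through `ShapeWriter`, they can be read back from the trial directory. A hypothetical read-side sketch (the per-step `shape()` accessor on trial tensors is an assumption, not confirmed by the hunks shown here):

```python
# Hypothetical read-side sketch; the trial path and Tensor.shape() accessor
# are assumptions for illustration.
from smdebug.trials import create_trial

trial = create_trial("/tmp/smdebug_shapes")
for name in trial.tensor_names():
    tensor = trial.tensor(name)
    for step in tensor.steps():
        print(name, step, tensor.shape(step))  # e.g. (64, 128)
```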
100 changes: 47 additions & 53 deletions smdebug/core/index_reader.py
@@ -16,7 +16,7 @@
MISSING_EVENT_FILE_RETRY_LIMIT,
MISSING_EVENT_FILE_RETRY_LIMIT_KEY,
)
from smdebug.core.locations import IndexFileLocationUtils, TensorLocation
from smdebug.core.locations import IndexFileLocationUtils, TensorLocation, TensorShape
from smdebug.core.logger import get_logger
from smdebug.core.modes import ModeKeys
from smdebug.core.s3_utils import list_s3_objects
@@ -120,12 +120,22 @@ def fetch_tensor_value(self, tensor_location: TensorLocation):
def list_event_files(self, start_after_prefix):
pass

@abstractmethod
def load_tensor_data_from_index_files(
self, start_after_key=None, range_steps=None
) -> Tuple[Dict[str, Dict[int, Dict[str, TensorLocation]]], str]:
"""Return a triply nested dict referring to tensor data."""

responses, steps, last_index_token, workers = self.read_index_files(
start_after_key, range_steps
)

tensor_data = {}
for step, response, worker in zip(steps, responses, workers):
tensor_data = self._update_tensors_from_json(
tensor_data, step, response, self.path, worker
)
return tensor_data, last_index_token

@abstractmethod
def _is_event_file_present(self, file_name) -> bool:
pass
@@ -203,8 +213,10 @@ def _validate(index_dict):
raise IndexReaderException("meta section is not present")
if len(index_dict["meta"]) == 0:
raise IndexReaderException("meta section is empty")
if "tensor_payload" not in index_dict:
raise IndexReaderException("tensor_payload section is not present")
if "tensor_payload" not in index_dict and "shape_payload" not in index_dict:
raise IndexReaderException(
"neither tensor_payload nor shape_payload sections are present"
)

def _update_tensors_from_json(
self, index_tensors_dict, step, response: bytes, path, worker
@@ -233,28 +245,41 @@ def _update_tensors_from_json(
mode = index_meta["mode"]
mode = ModeKeys[mode.strip()]
mode_step = index_meta["mode_step"]
event_file_name = os.path.join(path, index_meta["event_file_name"])
tensors = index_dict["tensor_payload"]
for tensor in tensors:
tensor_name = tensor["tensorname"]
start_idx = tensor["start_idx"]
length = tensor["length"]
tensor_location = TensorLocation(
tensor_name, mode, mode_step, event_file_name, start_idx, length, worker
)

to_update_index_dict = []

if "tensor_payload" in index_dict and len(index_dict["tensor_payload"]):
event_file_name = os.path.join(path, index_meta["event_file_name"])
for tensor in index_dict["tensor_payload"]:
tensor_name = tensor["tensorname"]
start_idx = tensor["start_idx"]
length = tensor["length"]
tensor_location = TensorLocation(
tensor_name, mode, mode_step, event_file_name, start_idx, length, worker
)
to_update_index_dict.append((tensor_name, step, tensor_location))

if "shape_payload" in index_dict and len(index_dict["shape_payload"]):
for tensor in index_dict["shape_payload"]:
tensor_name = tensor["tensorname"]
original_name = tensor["originalname"]
shape = tensor["shape"]
ts = TensorShape(tensor_name, mode, mode_step, shape, original_name)
to_update_index_dict.append((tensor_name, step, ts))

for tu in to_update_index_dict:
tensor_name, step, obj = tu
if isinstance(obj, TensorLocation):
obj_dict = {"tensor_location": obj}
elif isinstance(obj, TensorShape):
obj_dict = {"tensor_shape": obj}
if tensor_name in index_tensors_dict:
if step in index_tensors_dict[tensor_name]:
index_tensors_dict[tensor_name][step].update(
{worker: {"tensor_location": tensor_location}}
)
index_tensors_dict[tensor_name][step].update({worker: obj_dict})
else:
index_tensors_dict[tensor_name].update(
{step: {worker: {"tensor_location": tensor_location}}}
)
index_tensors_dict[tensor_name].update({step: {worker: obj_dict}})
else:
index_tensors_dict[tensor_name] = {
step: {worker: {"tensor_location": tensor_location}}
}
index_tensors_dict[tensor_name] = {step: {worker: obj_dict}}
return index_tensors_dict
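For orientation, the structure built here is a triply nested dict keyed by tensor name, step, and worker, where each leaf now carries either a `tensor_location` or a `tensor_shape` entry. An illustrative layout (tensor name, worker, and shape values are made up):

```python
from smdebug.core.locations import TensorShape
from smdebug.core.modes import ModeKeys

# Illustrative index_tensors_dict after one worker wrote shapes at step 0.
index_tensors_dict = {
    "gradients/dense1": {
        0: {  # step
            "worker_0": {
                "tensor_shape": TensorShape(
                    "gradients/dense1", ModeKeys.TRAIN, 0, (64, 128)
                )
            }
        }
    }
}
```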


@@ -285,22 +310,6 @@ def fetch_tensor_value(self, tensor_location: TensorLocation) -> np.ndarray:
tensor_name, step, tensor_data, mode, mode_step = tensor_tuple
return tensor_data

def load_tensor_data_from_index_files(
self, start_after_key=None, range_steps=None
) -> Tuple[Dict[str, Dict[int, Dict[str, TensorLocation]]], str]:
"""Return a triply nested dict referring to tensor data."""

responses, steps, last_index_token, workers = self.read_index_files(
start_after_key, range_steps
)

tensor_data = {}
for step, response, worker in zip(steps, responses, workers):
tensor_data = self._update_tensors_from_json(
tensor_data, step, response, self.path, worker
)
return tensor_data, last_index_token

def read_index_files(
self, start_after_key: str, range_steps=None
) -> Tuple[List[bytes], list, str, List[str]]:
@@ -398,21 +407,6 @@ def fetch_tensor_value(self, tensor_location: TensorLocation) -> np.ndarray:
tensor_name, step, tensor_data, mode, mode_step = tensor_tuple
return tensor_data

def load_tensor_data_from_index_files(
self, start_after_key=None, range_steps=None
) -> Tuple[Dict[str, Dict[int, Dict[str, TensorLocation]]], str]:
"""Return a triply nested dict referring to tensor data."""

responses, steps, last_index_token, workers = self.read_index_files(
start_after_key, range_steps
)
tensor_data = {}
for step, response, worker in zip(steps, responses, workers):
tensor_data = self._update_tensors_from_json(
tensor_data, step, response, self.path, worker
)
return tensor_data, last_index_token

def read_index_files(
self, start_after_key: str, range_steps=None
) -> Tuple[List[bytes], list, str, List[str]]:
14 changes: 14 additions & 0 deletions smdebug/core/locations.py
@@ -24,6 +24,20 @@ def to_dict(self):
return {"tensorname": self.tensorname, "start_idx": self.start_idx, "length": self.length}


class TensorShape:
def __init__(self, name, mode, mode_step, shape, original_name=None):
if original_name is None:
original_name = name
self.name = name
self.original_name = original_name
self.mode = mode
self.mode_step = mode_step
self.shape = tuple(shape)

def to_dict(self):
return {"tensorname": self.name, "originalname": self.original_name, "shape": self.shape}


STEP_NUMBER_FORMATTING_LENGTH = "012"

