Ability to save shapes #341

Merged: 41 commits, Sep 10, 2020

Commits:
ac3c250  WIP saveshape (rahul003, Aug 10, 2020)
34084cc  Add shape writer (rahul003, Aug 10, 2020)
e8a6d98  Add pytorch test (rahul003, Aug 11, 2020)
907cf64  Add untested keras test (rahul003, Aug 11, 2020)
86842e6  fix syntax (rahul003, Aug 11, 2020)
651c440  fix syntax (rahul003, Aug 11, 2020)
fc25940  Import (rahul003, Aug 11, 2020)
1357f5d  Import (rahul003, Aug 11, 2020)
44358ee  Add tests for TF (rahul003, Aug 12, 2020)
f146c77  Simplify read code (rahul003, Aug 13, 2020)
5906e5a  Add read API and tests (rahul003, Aug 13, 2020)
681e35c  Add mxnet test (rahul003, Aug 14, 2020)
5dc47ff  Add s3 and json tests (rahul003, Aug 14, 2020)
c775942  lint (NihalHarish, Aug 14, 2020)
355be0b  Fix payload (rahul003, Aug 17, 2020)
3eb0202  fix import (rahul003, Aug 17, 2020)
c14a67e  Handle different num tensors for losses (rahul003, Aug 17, 2020)
d12b824  Fix exact equal condition (rahul003, Aug 17, 2020)
972d95a  Fix mode bug (rahul003, Aug 17, 2020)
850cc44  trigger CI (rahul003, Aug 18, 2020)
2c44796  Add support for distributed training with writer map (rahul003, Aug 19, 2020)
1b09b8e  Check that value throws exception (rahul003, Aug 19, 2020)
f4106f3  Fix tests to make them more resilient (rahul003, Aug 19, 2020)
78b67d6  Fix mxnet and pytorch tests (rahul003, Aug 19, 2020)
2515a2d  Remove tensor names (rahul003, Aug 19, 2020)
7f3ea4e  pre-commmit (NihalHarish, Aug 19, 2020)
cdf6578  Fix get_mode (rahul003, Aug 19, 2020)
d16d1de  Fix bug with old index files (rahul003, Aug 19, 2020)
384b71c  Fix keras test with names of tensors (rahul003, Aug 20, 2020)
cd8a4d1  Set original name to None if tf_obj is None (rahul003, Aug 20, 2020)
c4881b7  Fix mirrored test for cpu (rahul003, Aug 20, 2020)
b5fc689  Merge branch 'master' of https://github.com/awslabs/sagemaker-debugge… (rahul003, Aug 31, 2020)
dd434c6  Add docs (rahul003, Sep 1, 2020)
4fe8df0  trigger CI (rahul003, Sep 1, 2020)
fa664d3  Fix shape writer get (rahul003, Sep 1, 2020)
b5b29b1  Simplify by removing shape writer (rahul003, Sep 2, 2020)
131ec44  Cleanup (rahul003, Sep 2, 2020)
dee7106  Fix name of writer (rahul003, Sep 2, 2020)
5c89fa4  Addressed review comments (rahul003, Sep 8, 2020)
a893d91  trigger ci (rahul003, Sep 8, 2020)
1f94933  retrigger CI (NihalHarish, Sep 8, 2020)
40 changes: 35 additions & 5 deletions docs/analysis.md
@@ -30,8 +30,10 @@ This page describes the programming model that SageMaker Debugger provides for y
* [steps](#steps-1)
* [value](#value)
* [reduction_value](#reduction_value)
* [reduction_values](#reduction_values)
* [shape](#shape)
* [values](#values)
* [reduction_values](#reduction_values)
* [shapes](#shapes)
* [workers](#workers-1)
* [prev_steps](#prev_steps)
* [Rules](#Rules)
@@ -356,6 +358,34 @@ trial.tensor(name).reduction_value(step_num, reduction_name,
###### Returns
`numpy.ndarray` The reduction value of the tensor at the given step and worker (if the training job saved data from multiple workers) as a 1x1 numpy array. If this reduction was saved for the tensor during training as part of the reduction config specification, it will be loaded and returned. If the given reduction was not saved but the full tensor was, the reduction will be computed on the fly and returned. If neither the chosen reduction nor the full tensor is available, this method raises a `TensorUnavailableForStep` exception.

#### shape
Get the shape of the chosen tensor at a particular step.

```python
trial.tensor(name).shape(step_num, mode=modes.GLOBAL, worker=None)
```
###### Arguments
- `step_num (int)` The step number for which the shape is to be returned, for the mode passed through the next parameter.
- `mode (smdebug.modes enum value)` The mode applicable for the step number passed above. Defaults to `modes.GLOBAL`.
- `worker (str)` This parameter is only applicable for distributed training. You can retrieve the shape of the tensor from a specific worker by passing the worker name. You can query all the workers seen by the trial with the `trial.workers()` method. You can also query the workers which saved a value for the tensor at a specific step with `trial.tensor(name).workers(step, mode)`.

###### Returns
`tuple(int)` If only the shape of this tensor was saved, through the `save_shape` configuration in ReductionConfig, it will be returned. If the full tensor was saved, the shape will be computed from it and returned. If neither the shape nor the full tensor is available, this method raises a `TensorUnavailableForStep` exception.
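
For example, a minimal sketch assuming a trial created with `create_trial` (the trial path and tensor name are illustrative):

```python
from smdebug.trials import create_trial

trial = create_trial("s3://my-bucket/training-run")  # illustrative trial path
# e.g. (32, 28, 28) for a batch of 32 28x28 images saved at step 0
print(trial.tensor("conv0_output").shape(step_num=0))
```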

#### values
Get the values of the tensor for all steps of a given mode.

```python
trial.tensor(name).values(mode=modes.GLOBAL, worker=None)
```

###### Arguments
- `mode (smdebug.modes enum value)` The mode whose steps' values are to be returned. Defaults to `modes.GLOBAL`.
- `worker (str)` This parameter is only applicable for distributed training. You can retrieve the values of the tensor from a specific worker by passing the worker name. You can query all the workers seen by the trial with the `trial.workers()` method. You can also query the workers which saved a value for the tensor at a specific step with `trial.tensor(name).workers(step, mode)`.

###### Returns
`dict[int -> numpy.ndarray]` A dictionary with step numbers as keys and numpy arrays representing the value of the tensor as values.
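
A short sketch of walking the returned dictionary, reusing the hypothetical `trial` from the example above:

```python
from smdebug import modes

# Print the loss at every saved TRAIN step
for step, value in trial.tensor("loss").values(mode=modes.TRAIN).items():
    print(step, value)
```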

#### reduction_values
Get all reduction values saved for the chosen tensor at a particular step. A reduction value is a tensor reduced to a single value through reduction or aggregation operations. Please go through the description of the method `reduction_value` for more details.
@@ -372,19 +402,19 @@ trial.tensor(name).reduction_values(step_num, mode=modes.GLOBAL, worker=None)
###### Returns
`dict[(str, bool) -> numpy.ndarray]` A dictionary with keys being tuples of the form `(reduction_name, abs)` to a 1x1 numpy ndarray value. `abs` here is a boolean that denotes whether the reduction was performed on the absolute value of the tensor or not. Note that this method only returns the reductions which were saved from the training job. It does not compute all known reductions and return them if only the raw tensor was saved.

#### values
Get the values of the tensor for all steps of a given mode.
#### shapes
> **Contributor:** Is this diff incorrect? If not, why have you changed the API from `values` to `shapes`? Is this not a breaking change?
>
> **Author:** This diff looks weird. I copied the `values` contents and modified them to become `shapes`, then moved them to group the sections appropriately. I didn't remove anything.

Get the shapes of the tensor for all steps of a given mode.

```python
trial.tensor(name).values(mode=modes.GLOBAL, worker=None)
trial.tensor(name).shapes(mode=modes.GLOBAL, worker=None)
```

###### Arguments
- `mode (smdebug.modes enum value)` The mode whose steps' shapes are to be returned. Defaults to `modes.GLOBAL`.
- `worker (str)` This parameter is only applicable for distributed training. You can retrieve the shapes of the tensor from a specific worker by passing the worker name. You can query all the workers seen by the trial with the `trial.workers()` method. You can also query the workers which saved a value for the tensor at a specific step with `trial.tensor(name).workers(step, mode)`.

###### Returns
`dict[int -> numpy.ndarray]` A dictionary with step numbers as keys and numpy arrays representing the value of the tensor as values.
`dict[int -> tuple(int)]` A dictionary with step numbers as keys and tuples of ints representing the shapes of the tensor as values.
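
For instance, one can use this to verify that a layer's output shape stays constant across steps (tensor name illustrative):

```python
# e.g. {0: (32, 10), 500: (32, 10)}
shapes = trial.tensor("fc1_output").shapes(mode=modes.GLOBAL)
assert len(set(shapes.values())) == 1, "shape changed between steps"
```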

#### workers
Get all the workers for which this tensor was saved at a given step.
1 change: 1 addition & 0 deletions docs/api.md
@@ -96,6 +96,7 @@ include_workers
include_regex
reductions
save_raw_tensor
save_shape
save_interval
save_steps
start_step
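A sketch of enabling the new `save_shape` key through a ReductionConfig; the hook class, import path, and output directory below are placeholders, and this assumes the framework modules re-export `ReductionConfig`:

```python
from smdebug.pytorch import Hook, ReductionConfig

hook = Hook(
    out_dir="/tmp/smdebug-run",  # placeholder output directory
    # Save only tensor shapes rather than raw values or reductions
    reduction_config=ReductionConfig(save_shape=True),
)
```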
56 changes: 45 additions & 11 deletions smdebug/core/hook.py
@@ -420,6 +420,17 @@ def _prepare_collections(self):
self.prepared_collections = True

#### End of Save Manager methods ####
@staticmethod
def _close_given_writer_map(writer_dict):
# Delete all the dist training writers
to_delete_writers = []
for key, writer in writer_dict.items():
# close calls flush
writer.close()
to_delete_writers.append(key)

for key in to_delete_writers:
del writer_dict[key]

def _close_writers(self) -> None:
if self.dry_run:
@@ -433,16 +444,7 @@ def _close_writers(self) -> None:
self.writer.close()
self.writer = None

to_delete_writers = []

# Delete all the tb writers
for mode, writer in self.tb_writers.items():
if writer is not None:
writer.flush()
writer.close()
to_delete_writers.append(mode)
for mode in to_delete_writers:
del self.tb_writers[mode]
self._close_given_writer_map(self.tb_writers)

def _initialize_writers(self, only_initialize_if_missing=False) -> None:
# Function is overridden in smdebug/tensorflow/base_hook.py
@@ -470,8 +472,12 @@ def _initialize_writers(self, only_initialize_if_missing=False) -> None:
if self.save_all_workers is False:
if self.worker != self.chief_worker:
return

self.writer = FileWriter(trial_dir=self.out_dir, step=self.step, worker=self.worker)

def _get_main_writer(self) -> List[FileWriter]:
return [self.writer] if self.writer else []

def _get_writers(self, tensor_name, tensor_ref=None) -> List[FileWriter]:
"""
:param tensor_name:
Expand All @@ -480,7 +486,7 @@ def _get_writers(self, tensor_name, tensor_ref=None) -> List[FileWriter]:
"""
if self.save_all_workers is False and self.worker != self.chief_worker:
return []
return [self.writer] if self.writer else []
return self._get_main_writer()

def _maybe_get_tb_writer(self) -> Optional[FileWriter]:
""" Returns a FileWriter object if `hook.tensorboard_dir` has been specified, else None.
@@ -749,6 +755,31 @@ def _write_raw_tensor(self, tensor_name, tensor_value, save_collections, tensor_
self._write_raw_tensor_simple(tensor_name, tensor_value, tensor_ref=tensor_ref)
break

def _write_shape(self, tensor_name, tensor_value, save_collections, tensor_ref=None):
> **Contributor:** If `tensor_value` is always going to be a tuple, can we add type annotations to this function?
>
> **Author:** `tensor_value` is a framework data format; it's not a tuple here. It becomes a tuple inside this function.

writers = self._get_writers(tensor_name, tensor_ref=tensor_ref)
for s_col in save_collections:
reduction_config = s_col.reduction_config
if self.dry_run is False and reduction_config.save_shape is True:
numpy_tensor_value = self._make_numpy_array(tensor_value)
this_size, this_shape = size_and_shape(numpy_tensor_value)
# In TF Keras, and for Variables in all interfaces of TF, we sometimes output tensors with
# more meaningful names than the original name. In such cases we output
# both the smdebug-given name and the original name
if tensor_ref is not None and tensor_ref.tf_obj is not None:
original_name = tensor_ref.tf_obj.name
else:
original_name = None
> **Contributor (on lines +768 to +771):** Can you add comments explaining the need for this if-else block? Which framework and which mode of execution requires this check?
>
> **Author:** Added


for writer in writers:
writer.write_shape(
tensor_name,
this_shape,
self.mode,
self.mode_steps[self.mode],
original_name=original_name,
)
break

def _write_raw_tensor_simple(self, tensor_name, tensor_value, tensor_ref=None, timestamp=None):
# tensor_ref is used by TF
# todo: if fp16, check perf of saving as fp16 in proto vs as fp32
@@ -828,6 +859,9 @@ def _write_for_tensor(self, tensor_name, tensor_value, save_collections, tensor_
:param save_collections: list of collections which are being saved for this step
"""
self._log_save(tensor_name, save_collections)

self._write_shape(tensor_name, tensor_value, save_collections, tensor_ref=tensor_ref)

# write reductions defined for collections this tensor may be part of
self._write_reductions(tensor_name, tensor_value, save_collections, tensor_ref=tensor_ref)

100 changes: 47 additions & 53 deletions smdebug/core/index_reader.py
@@ -16,7 +16,7 @@
MISSING_EVENT_FILE_RETRY_LIMIT,
MISSING_EVENT_FILE_RETRY_LIMIT_KEY,
)
from smdebug.core.locations import IndexFileLocationUtils, TensorLocation
from smdebug.core.locations import IndexFileLocationUtils, TensorLocation, TensorShape
from smdebug.core.logger import get_logger
from smdebug.core.modes import ModeKeys
from smdebug.core.s3_utils import list_s3_objects
Expand Down Expand Up @@ -120,12 +120,22 @@ def fetch_tensor_value(self, tensor_location: TensorLocation):
def list_event_files(self, start_after_prefix):
pass

@abstractmethod
def load_tensor_data_from_index_files(
self, start_after_key=None, range_steps=None
) -> Tuple[Dict[str, Dict[int, Dict[str, TensorLocation]]], str]:
"""Return a triply nested dict referring to tensor data."""

responses, steps, last_index_token, workers = self.read_index_files(
start_after_key, range_steps
)

tensor_data = {}
for step, response, worker in zip(steps, responses, workers):
tensor_data = self._update_tensors_from_json(
tensor_data, step, response, self.path, worker
)
return tensor_data, last_index_token

@abstractmethod
def _is_event_file_present(self, file_name) -> bool:
pass
@@ -203,8 +213,10 @@ def _validate(index_dict):
raise IndexReaderException("meta section is not present")
if len(index_dict["meta"]) == 0:
raise IndexReaderException("meta section is empty")
if "tensor_payload" not in index_dict:
raise IndexReaderException("tensor_payload section is not present")
if "tensor_payload" not in index_dict and "shape_payload" not in index_dict:
raise IndexReaderException(
"neither tensor_payload nor shape_payload sections are present"
)
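
For reference, a sketch of the index-file layout this validation accepts, using the field names consumed by `_update_tensors_from_json` below (all values illustrative):

```python
# "tensor_payload" holds locations of raw tensor data in the event file;
# "shape_payload" holds shapes saved via save_shape. Either may be absent or empty.
index_dict = {
    "meta": {"mode": "TRAIN", "mode_step": 0, "event_file_name": "events.out.tfevents"},
    "tensor_payload": [{"tensorname": "loss", "start_idx": 0, "length": 128}],
    "shape_payload": [{"tensorname": "loss", "originalname": "loss:0", "shape": [1]}],
}
```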

def _update_tensors_from_json(
self, index_tensors_dict, step, response: bytes, path, worker
@@ -233,28 +245,41 @@
mode = index_meta["mode"]
mode = ModeKeys[mode.strip()]
mode_step = index_meta["mode_step"]
event_file_name = os.path.join(path, index_meta["event_file_name"])
tensors = index_dict["tensor_payload"]
for tensor in tensors:
tensor_name = tensor["tensorname"]
start_idx = tensor["start_idx"]
length = tensor["length"]
tensor_location = TensorLocation(
tensor_name, mode, mode_step, event_file_name, start_idx, length, worker
)

to_update_index_dict = []

if "tensor_payload" in index_dict and len(index_dict["tensor_payload"]):
> **Contributor:** nit: `len(index_dict["tensor_payload"])` seems like a redundant check.
>
> **Author:** There can be an empty payload.

event_file_name = os.path.join(path, index_meta["event_file_name"])
for tensor in index_dict["tensor_payload"]:
tensor_name = tensor["tensorname"]
start_idx = tensor["start_idx"]
length = tensor["length"]
tensor_location = TensorLocation(
tensor_name, mode, mode_step, event_file_name, start_idx, length, worker
)
to_update_index_dict.append((tensor_name, step, tensor_location))

if "shape_payload" in index_dict and len(index_dict["shape_payload"]):
for tensor in index_dict["shape_payload"]:
tensor_name = tensor["tensorname"]
original_name = tensor["originalname"]
shape = tensor["shape"]
ts = TensorShape(tensor_name, mode, mode_step, shape, original_name)
to_update_index_dict.append((tensor_name, step, ts))

for tu in to_update_index_dict:
tensor_name, step, obj = tu
if isinstance(obj, TensorLocation):
obj_dict = {"tensor_location": obj}
elif isinstance(obj, TensorShape):
obj_dict = {"tensor_shape": obj}
if tensor_name in index_tensors_dict:
if step in index_tensors_dict[tensor_name]:
index_tensors_dict[tensor_name][step].update(
{worker: {"tensor_location": tensor_location}}
)
index_tensors_dict[tensor_name][step].update({worker: obj_dict})
else:
index_tensors_dict[tensor_name].update(
{step: {worker: {"tensor_location": tensor_location}}}
)
index_tensors_dict[tensor_name].update({step: {worker: obj_dict}})
else:
index_tensors_dict[tensor_name] = {
step: {worker: {"tensor_location": tensor_location}}
}
index_tensors_dict[tensor_name] = {step: {worker: obj_dict}}
> **Contributor (on lines +278 to +282):** Is this simply a lint change?
>
> **Author:** Just made the innermost dict a variable.

return index_tensors_dict
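
The returned mapping is the triply nested dict described in the base-class docstring: tensor name, then step, then worker, with each worker keyed to either a location or a shape entry. A sketch, with the object values shown as strings:

```python
index_tensors_dict = {
    "loss": {  # tensor name
        0: {   # step number
            "worker_0": {"tensor_location": "<TensorLocation instance>"},
            # or {"tensor_shape": "<TensorShape instance>"} when only the shape was saved
        }
    }
}
```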


@@ -285,22 +310,6 @@ def fetch_tensor_value(self, tensor_location: TensorLocation) -> np.ndarray:
tensor_name, step, tensor_data, mode, mode_step = tensor_tuple
return tensor_data

def load_tensor_data_from_index_files(
self, start_after_key=None, range_steps=None
) -> Tuple[Dict[str, Dict[int, Dict[str, TensorLocation]]], str]:
"""Return a triply nested dict referring to tensor data."""

responses, steps, last_index_token, workers = self.read_index_files(
start_after_key, range_steps
)

tensor_data = {}
for step, response, worker in zip(steps, responses, workers):
tensor_data = self._update_tensors_from_json(
tensor_data, step, response, self.path, worker
)
return tensor_data, last_index_token

def read_index_files(
self, start_after_key: str, range_steps=None
) -> Tuple[List[bytes], list, str, List[str]]:
@@ -398,21 +407,6 @@ def fetch_tensor_value(self, tensor_location: TensorLocation) -> np.ndarray:
tensor_name, step, tensor_data, mode, mode_step = tensor_tuple
return tensor_data

def load_tensor_data_from_index_files(
self, start_after_key=None, range_steps=None
) -> Tuple[Dict[str, Dict[int, Dict[str, TensorLocation]]], str]:
"""Return a triply nested dict referring to tensor data."""

responses, steps, last_index_token, workers = self.read_index_files(
start_after_key, range_steps
)
tensor_data = {}
for step, response, worker in zip(steps, responses, workers):
tensor_data = self._update_tensors_from_json(
tensor_data, step, response, self.path, worker
)
return tensor_data, last_index_token

def read_index_files(
self, start_after_key: str, range_steps=None
) -> Tuple[List[bytes], list, str, List[str]]:
12 changes: 12 additions & 0 deletions smdebug/core/locations.py
@@ -24,6 +24,18 @@ def to_dict(self):
return {"tensorname": self.tensorname, "start_idx": self.start_idx, "length": self.length}


class TensorShape:
def __init__(self, name, mode, mode_step, shape, original_name=None):
self.name = name
self.original_name = original_name if original_name is not None else name
self.mode = mode
self.mode_step = mode_step
self.shape = tuple(shape)

def to_dict(self):
return {"tensorname": self.name, "originalname": self.original_name, "shape": self.shape}
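
A quick sketch of what `to_dict` produces; the mode is passed as a plain string here for brevity, though in smdebug it would be a `ModeKeys` value:

```python
ts = TensorShape("dense/kernel", "TRAIN", 0, [784, 10], original_name="dense/kernel:0")
# The constructor normalizes shape to a tuple
assert ts.to_dict() == {
    "tensorname": "dense/kernel",
    "originalname": "dense/kernel:0",
    "shape": (784, 10),
}
```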


STEP_NUMBER_FORMATTING_LENGTH = "012"

