diff --git a/smdebug/mxnet/hook.py b/smdebug/mxnet/hook.py
index e1c85dc5a..e7f53eb96 100644
--- a/smdebug/mxnet/hook.py
+++ b/smdebug/mxnet/hook.py
@@ -154,9 +154,8 @@ def forward_hook(self, block, inputs, outputs):
         # This overwhelms the logs; turn back on if you really need it
         # logger.debug("Processing the global step {0} for block {1}".format(self.step, block_name))

-        # Output input tensor if it is not a loss block
-        if isinstance(block, mx.gluon.loss.Loss) is False:
-            self._write_inputs(block_name, inputs)
+        # Output input tensor
+        self._write_inputs(block_name, inputs)

         # Output output tensors
         self._write_outputs(block_name, outputs)
diff --git a/tests/mxnet/test_hook_loss_collection.py b/tests/mxnet/test_hook_loss_collection.py
index 1a21b64f7..95307c316 100644
--- a/tests/mxnet/test_hook_loss_collection.py
+++ b/tests/mxnet/test_hook_loss_collection.py
@@ -33,9 +33,6 @@ def test_loss_collection_default():
     loss_val = loss_tensor.value(step_num=1)
     assert len(loss_val) > 0

-    # Assert that we are not logging the inputs to loss block.
-    input_loss_tensors = tr.tensor_names(regex=".*loss._input*")
-    assert len(input_loss_tensors) == 0

     shutil.rmtree(out_dir)
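
Note: after this change, inputs to loss blocks are written like any other block's inputs. A minimal sketch (illustrative only, not part of this diff) of how the new behavior could be asserted with smdebug's trial API, inverting the removed test check:

    from smdebug.trials import create_trial

    # Assumed setup: out_dir is a directory populated by a hook-instrumented
    # training run, as created earlier in test_loss_collection_default().
    tr = create_trial(out_dir)
    # Inputs to the loss block should now be recorded rather than skipped.
    input_loss_tensors = tr.tensor_names(regex=".*loss._input*")
    assert len(input_loss_tensors) > 0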