Skip to content

Commit

Permalink
Remove only_shape from ReductionConfig (#11)
Browse files Browse the repository at this point in the history
  • Loading branch information
jarednielsen committed Nov 20, 2019
1 parent 7d91649 commit b63c151
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 27 deletions.
31 changes: 5 additions & 26 deletions smdebug/core/reduction_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,7 @@
ALLOWED_REDUCTIONS = ["min", "max", "mean", "std", "variance", "sum", "prod"]
ALLOWED_NORMS = ["l1", "l2"]
REDUCTION_CONFIG_VERSION_NUM = "v0"
ALLOWED_PARAMS = [
"only_shape",
"reductions",
"abs_reductions",
"norms",
"abs_norms",
"save_raw_tensor",
]
ALLOWED_PARAMS = ["reductions", "abs_reductions", "norms", "abs_norms", "save_raw_tensor"]


class ReductionConfig:
Expand All @@ -31,10 +24,6 @@ class ReductionConfig:
Attributes
----------
only_shape: bool
If this is set, only the shape of tensor is saved.
Not yet supported.
reductions: list of str
takes list of names of reductions to be computed.
should be one of 'min', 'max', 'median', 'mean', 'std', 'variance', 'sum', 'prod'
Expand All @@ -55,14 +44,12 @@ class ReductionConfig:

def __init__(
self,
only_shape=False,
reductions=None,
abs_reductions=None,
norms=None,
abs_norms=None,
save_raw_tensor=False,
):
self.only_shape = only_shape
self.reductions = reductions if reductions is not None else []
self.abs_reductions = abs_reductions if abs_reductions is not None else []
self.norms = norms if norms is not None else []
Expand Down Expand Up @@ -91,13 +78,12 @@ def _check(self):

@classmethod
def from_dict(cls, params: Dict[str, Any]) -> "ReductionConfig":
"""Parses a flattened dict with two keys: `only_shape` and `reductions`."""
"""Parses a flattened dict with two keys: `save_raw_tensor` and `reductions`."""
if params is None:
return None
if not isinstance(params, dict):
raise ValueError(f"params={params} must be dict")

only_shape = params.get("only_shape", False)
save_raw_tensor = params.get("save_raw_tensor", False)
# Parse comma-separated string into array
all_reductions = split(params.get("reductions", ""))
Expand All @@ -117,7 +103,6 @@ def from_dict(cls, params: Dict[str, Any]) -> "ReductionConfig":
reductions.append(red) # mean -> mean

return cls(
only_shape=only_shape,
reductions=reductions,
abs_reductions=abs_reductions,
norms=norms,
Expand All @@ -131,7 +116,6 @@ def from_json(cls, json_str: str) -> "ReductionConfig":
return cls.from_dict(d)

def to_json_dict(self) -> Dict[str, Any]:
only_shape = self.only_shape
save_raw_tensor = self.save_raw_tensor
# Convert reductions from various arrays into single comma-separated string
all_reductions = []
Expand All @@ -145,11 +129,7 @@ def to_json_dict(self) -> Dict[str, Any]:
all_reductions.append(f"abs_{red}_norm")
all_reductions_str = ",".join(all_reductions)
# Return the dict
return {
"only_shape": only_shape,
"save_raw_tensor": save_raw_tensor,
"reductions": all_reductions_str,
}
return {"save_raw_tensor": save_raw_tensor, "reductions": all_reductions_str}

def to_json(self) -> str:
return json.dumps(self.to_json_dict())
Expand All @@ -159,8 +139,7 @@ def __eq__(self, other):
return NotImplemented

return (
self.only_shape == other.only_shape
and self.reductions == other.reductions
self.reductions == other.reductions
and self.abs_reductions == other.abs_reductions
and self.norms == other.norms
and self.abs_norms == other.abs_norms
Expand All @@ -169,6 +148,6 @@ def __eq__(self, other):

def __repr__(self):
return (
f"<class ReductionConfig: only_shape={self.only_shape}, reductions={self.reductions}, "
f"<class ReductionConfig: reductions={self.reductions}, "
f"abs_reductions={self.abs_reductions}, norms={self.norms}, abs_norms={self.abs_norms}>"
)
2 changes: 1 addition & 1 deletion tests/core/test_reduction_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


def test_export_load():
r1 = ReductionConfig(only_shape=True, reductions=["min"], norms=["l2"])
r1 = ReductionConfig(reductions=["min"], norms=["l2"])
r2 = ReductionConfig.from_json(r1.to_json())
assert r1 == r2
assert r1.to_json() == r2.to_json()
Expand Down

0 comments on commit b63c151

Please sign in to comment.