-
Notifications
You must be signed in to change notification settings - Fork 82
/
reductions.py
67 lines (54 loc) · 2.19 KB
/
reductions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
# Standard Library
import re
# Third Party
import numpy as np
# First Party
from smdebug.core.reduction_config import ALLOWED_NORMS, ALLOWED_REDUCTIONS
# Common namespace placed in front of every reduction tensor name: it is
# prepended by get_reduction_tensor_name and split on by
# reverse_reduction_tensor_name.
REDUCTIONS_PREFIX: str = "smdebug/reductions/"
def get_numpy_reduction(reduction_name, numpy_data, abs=False):
    """Validate ``reduction_name`` and reduce ``numpy_data`` with NumPy.

    :param reduction_name: must be a member of ``ALLOWED_REDUCTIONS`` or
        ``ALLOWED_NORMS``.
    :param numpy_data: array-like input to reduce.
    :param abs: when truthy, ``np.absolute`` is applied to the data before
        the reduction is computed.
    :return: the result of ``get_basic_numpy_reduction`` on the (possibly
        absolute-valued) data.
    :raises ValueError: if ``reduction_name`` is in neither allowed set.
    """
    is_supported = reduction_name in ALLOWED_REDUCTIONS or reduction_name in ALLOWED_NORMS
    if not is_supported:
        raise ValueError("Invalid reduction type %s" % reduction_name)
    data = np.absolute(numpy_data) if abs else numpy_data
    return get_basic_numpy_reduction(reduction_name, data)
def get_basic_numpy_reduction(reduction_name, numpy_data):
    """Compute a named reduction or norm of ``numpy_data`` with NumPy.

    Absolute values are NOT applied here. ``get_numpy_reduction`` applies
    ``np.absolute`` itself before calling this function when ``abs=True``.

    :param reduction_name: a name from ``ALLOWED_REDUCTIONS`` or
        ``ALLOWED_NORMS``.
        - ``"min"`` and ``"max"`` map to ``np.amin`` and ``np.amax``.
        - ``"variance"`` maps to ``np.var``.
        - ``"mean"``, ``"prod"``, ``"std"`` and ``"sum"`` map to the NumPy
          function of the same name.
        - ``"l1"`` and ``"l2"`` map to ``np.linalg.norm`` with ``ord=1`` and
          ``ord=2``.
        - Any other norm name uses ``np.linalg.norm`` with ``ord=None``.
    :param numpy_data: array-like input.
    :return: the reduced value, or ``None`` if ``reduction_name`` is not one
        of the reductions handled here.
    """
    if reduction_name in ALLOWED_REDUCTIONS:
        if reduction_name in ("min", "max"):
            # Dispatches to np.amin / np.amax.
            return getattr(np, "a" + reduction_name)(numpy_data)
        if reduction_name in ("mean", "prod", "std", "sum", "variance"):
            # NumPy spells variance as ``var``.
            func_name = "var" if reduction_name == "variance" else reduction_name
            return getattr(np, func_name)(numpy_data)
    elif reduction_name in ALLOWED_NORMS:
        # "l1" -> ord=1, "l2" -> ord=2; anything else uses NumPy's default ord=None.
        norm_order = int(reduction_name[1]) if reduction_name in ("l1", "l2") else None
        # No np.absolute() here: abs handling belongs to the caller. An earlier
        # ``if abs:`` guard tested the builtin ``abs`` function, which is always
        # truthy, so it took |x| unconditionally. That silently changed e.g. the
        # 2-D ord=2 (spectral) norm.
        # NOTE(review): np.linalg.norm treats 2-D input as a matrix and raises
        # for >2-D input when ord is not None. Confirm this is the intended
        # semantics for multi-dimensional tensors.
        return np.linalg.norm(numpy_data, ord=norm_order)
    return None
def get_reduction_tensor_name(tensorname, reduction_name, abs, remove_colon_index=True):
    """Build the stored name of a reduction tensor.

    The result is ``REDUCTIONS_PREFIX``, then ``"abs_"`` if ``abs`` is truthy,
    then ``f"{reduction_name}/{tensorname}"``.

    :param tensorname: name of the tensor being reduced.
    :param reduction_name: reduction or norm applied (e.g. ``"max"``, ``"l2"``).
    :param abs: when truthy, ``"abs_"`` is inserted before the reduction name.
    :param remove_colon_index: when true, every ``:<digits>`` substring of
        ``f"{reduction_name}/{tensorname}"`` is removed.
        - For frameworks other than TF it makes sense to drop output suffixes
          such as ``:0`` and ``:1``.
        - For TF it makes sense to keep them, so that names stay consistent
          with TF's traditional naming style.
        - The pattern is not anchored, so it strips ``:<digits>`` anywhere in
          the name, not only at the end.
    :return: the fully prefixed reduction tensor name.
    """
    base = f"{reduction_name}/{tensorname}"
    if remove_colon_index:
        base = re.sub(r":\d+", "", base)
    abs_marker = "abs_" if abs else ""
    return f"{REDUCTIONS_PREFIX}{abs_marker}{base}"
def reverse_reduction_tensor_name(reduction_tensor_name):
    """Split a reduction tensor name back into its components.

    This is the inverse of ``get_reduction_tensor_name``. It expects a name of
    the form ``REDUCTIONS_PREFIX``, then an optional ``"abs_"``, then
    ``<reduction>/<tensor>``.

    :param reduction_tensor_name: a name containing ``REDUCTIONS_PREFIX``.
    :return: tuple ``(tensor_name, reduction_op_name, abs)``, where ``abs`` is
        True when the reduction part starts with ``"abs_"``.
    :raises IndexError: if the prefix or the ``/`` separator after the
        reduction name is missing.
    """
    abs_marker = "abs_"
    # Split only on the first occurrence of the prefix. A tensor name that
    # itself contains the prefix must stay intact; this matches the
    # maxsplit=1 used on "/" below.
    rest = reduction_tensor_name.split(REDUCTIONS_PREFIX, 1)[1]
    parts = rest.split("/", 1)
    reduction_name = parts[0]
    # get_reduction_tensor_name only ever *prepends* "abs_", so match it
    # exactly at the start rather than anywhere in the reduction name.
    is_abs = reduction_name.startswith(abs_marker)
    reduction_op_name = reduction_name[len(abs_marker):] if is_abs else reduction_name
    # Indexing (not tuple unpacking) keeps IndexError for names without "/".
    tensor_name = parts[1]
    return tensor_name, reduction_op_name, is_abs