Add omit_names support to get_weights
 - `_id='*'` now fetches all layers (except the input layer) in `get_weights`, `get_gradients`, `get_outputs`, and `weights_norm`
 - Add `omit_names` support to `get_weights`
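
For orientation, a minimal usage sketch of both additions (the toy model and data below are illustrative, not part of the commit; later snippets on this page reuse `model`, `x`, and `y` from here):

```python
import os
os.environ["TF_KERAS"] = '1'  # backend flag read by see_rnn; assumes tf.keras

import numpy as np
from tensorflow.keras.layers import Input, GRU, Dense
from tensorflow.keras.models import Model
from see_rnn.inspect_gen import get_weights

# toy model: Input -> GRU -> Dense
ipt = Input(batch_shape=(8, 10, 4))
out = Dense(1)(GRU(6, return_sequences=True)(ipt))
model = Model(ipt, out)
model.compile('adam', 'mse')
x = np.random.randn(8, 10, 4)
y = np.random.randn(8, 10, 1)

w_all = get_weights(model, '*')                     # every layer w/ weights
w_fit = get_weights(model, '*', omit_names='bias')  # skip '*bias*' weights
```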
OverLordGoldDragon committed Apr 30, 2020
1 parent fcc739f commit 32ad953
Showing 2 changed files with 41 additions and 10 deletions.
47 changes: 37 additions & 10 deletions see_rnn/inspect_gen.py
@@ -18,6 +18,8 @@ def get_outputs(model, _id, input_data, layer=None, learning_phase=0,
list of str/int -> treat each str element as name, int as idx.
Ex: ['gru', 2] gets outputs of first layer with name
substring 'gru', then of layer w/ idx 2
'*': wildcard -> get outputs of all layers (except input) with
'output' attribute. Overrides `layer`.
input_data: np.ndarray & supported formats(1). Data w.r.t. which loss is
to be computed for the gradient. Only for mode=='grads'.
layer: keras.Layer/tf.keras.Layer. Layer whose outputs to return.
@@ -41,7 +43,15 @@ def _get_outs_tensors(model, names, idxs, layers):
layers = [layers]
return [l.output for l in layers]

names, idxs, layers, one_requested = _validate_args(_id, layer)
if _id != '*':
names, idxs, layers, one_requested = _validate_args(_id, layer)
else:
# exclude input layer & non-output layers
names = [l.name for l in model.layers[1:]
if getattr(l, 'output', None) is not None]
idxs, layers = None, None
one_requested = len(_id) == 1

layer_outs = _get_outs_tensors(model, names, idxs, layers)

if TF_KERAS:
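
With `model` and `x` from the sketch in the commit description, the new `'*'` branch above makes the following possible (return structure assumed to be one array per non-input layer):

```python
from see_rnn.inspect_gen import get_outputs

# '*' skips the input layer and anything lacking an 'output' attribute
outs = get_outputs(model, '*', x)  # expected: [GRU output, Dense output]
```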
@@ -69,6 +79,10 @@ def get_gradients(model, _id, input_data, labels, layer=None, mode='outputs',
list of str/int -> treat each str element as name, int as idx.
Ex: ['gru', 2] gets gradients of first layer with name
substring 'gru', then of layer w/ idx 2
'*': wildcard -> get gradients of all layers (except input) with:
- 'output' attribute (mode == 'outputs')
- 'weights' attribute (mode == 'weights')
Overrides `layer`.
input_data: np.ndarray & supported formats(1). Data w.r.t. which loss is
to be computed for the gradient.
labels: np.ndarray & supported formats. Labels w.r.t. which loss is
@@ -100,8 +114,17 @@ def _validate_args_(_id, layer, mode):
raise Exception("`mode` must be one of: 'outputs', 'weights'")
return _validate_args(_id, layer)

names, idxs, layers, one_requested = _validate_args_(_id, layer, mode)
_id = [x for var in (names, idxs) if var for x in var] or None
if _id != '*':
names, idxs, layers, one_requested = _validate_args_(_id, layer, mode)
_id = [x for var in (names, idxs) if var for x in var] or None
else:
# exclude input layer & non-output/weightless layers (`mode`-dependent)
attr = 'output' if mode == 'outputs' else 'weights'
_id = [l.name for l in model.layers[1:]
if getattr(l, attr, None) is not None]
names = _id
idxs, layers = None, None
one_requested = len(_id) == 1

if layers is None:
layers = get_layer(model, _id)
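
Likewise for gradients, reusing `model`, `x`, and `y` from the sketch; `mode` decides which attribute qualifies a layer:

```python
from see_rnn.inspect_gen import get_gradients

g_out = get_gradients(model, '*', x, y)                  # w.r.t. layer outputs
g_w   = get_gradients(model, '*', x, y, mode='weights')  # w.r.t. layer weights
```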
@@ -230,6 +253,7 @@ def get_weights(model, _id, omit_names=None, as_dict=False):
weights of first layer with name substring 'gru', then all
weights of layer w/ idx 2, then weights w/ idxs 1 and 2 of
layer w/ idx 3.
'*': wildcard -> get weights of all layers (except input) with nonempty 'weights' attribute.
omit_names: str/str list. List of names (can be substring) of weights
to omit from fetching.
as_dict: bool. True: return weight fullname-value pairs in a dict
Expand Down Expand Up @@ -285,13 +309,17 @@ def _get_by_name(model, name):
del _weights[w_name]
return _weights

weights = {}
names, idxs, *_ = _validate_args(_id)
_ids = [x for var in (names, idxs) if var for x in var] or None

if _id != '*':
names, idxs, *_ = _validate_args(_id)
_ids = [x for var in (names, idxs) if var for x in var] or None
else:
# exclude input layer & non-weight layers
_ids = [l.name for l in model.layers[1:]
if getattr(l, 'weights', None) not in (None, [])]
if not isinstance(omit_names, list):
omit_names = [omit_names] if omit_names else []

weights = {}
for _id in _ids:
weights.update(_get_weights_tensors(model, _id))

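Combining both additions, again with the sketch's `model` (the key names below are assumed tf.keras weight fullnames):

```python
from see_rnn.inspect_gen import get_weights

w = get_weights(model, '*', omit_names=['bias'], as_dict=True)
print(list(w))  # e.g. ['gru/kernel:0', 'gru/recurrent_kernel:0', 'dense/kernel:0']
```
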
Expand Down Expand Up @@ -411,10 +439,9 @@ def _append(stats_all, l2_stats, w_idx, l_name):
for stat_idx, stat in enumerate(l2_stats):
stats_all[l_name][w_idx][stat_idx].append(stat)

weights = get_weights(model, _id, omit_names, as_dict=True)
w_names, W = zip(*weights.items())
weights = get_weights(model, _id, omit_names)

for w_idx, (w, w_name) in enumerate(zip(W, w_names)):
for w_idx, w in enumerate(weights):
l2 = _compute_norm(w, norm_fn, axis)
l2_stats = [fn(l2) for fn in stat_fns]
_append(stats_all, l2_stats, w_idx, l_name)
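
The refactor above leans on `as_dict=False` (the default) returning a plain list of arrays; roughly, the new loop amounts to the following, with `_compute_norm` approximated as a per-unit L2 norm (an assumption):

```python
import numpy as np
from see_rnn.inspect_gen import get_weights

weights = get_weights(model, '*', omit_names='bias')  # list, not dict
for w_idx, w in enumerate(weights):
    l2 = np.sqrt(np.sum(np.square(w), axis=-1))  # sketch of _compute_norm
    print(w_idx, float(l2.mean()), float(l2.max()))
```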
4 changes: 4 additions & 0 deletions tests/test_all.py
@@ -336,6 +336,10 @@ def test_misc(): # test miscellaneous functionalities
get_weights(model, ['gru', 1, (1, 1)])
pass_on_error(get_weights, model, 'gru/goo')

get_weights(model, '*')
get_gradients(model, '*', x, y)
get_outputs(model, '*', x)

from see_rnn.utils import _filter_duplicates_by_keys
keys, data = _filter_duplicates_by_keys(list('abbc'), [1, 2, 3, 4])
assert keys == ['a', 'b', 'c']
