Skip to content

Commit

Permalink
TF 2.3.1 compatibility
Browse files Browse the repository at this point in the history
Fixes 'L1' object has no attribute 'l2'
  • Loading branch information
OverLordGoldDragon committed Oct 26, 2020
1 parent a9c664a commit f4660d4
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 10 deletions.
20 changes: 12 additions & 8 deletions keras_adamw/utils.py
Expand Up @@ -164,11 +164,8 @@ def _get_layer_penalties(layer, zero_penalties=False):
for weight_name in ['kernel', 'bias']:
_lambda = getattr(layer, weight_name + '_regularizer', None)
if _lambda is not None:
l1l2 = (float(_lambda.l1), float(_lambda.l2))
l1l2 = _get_and_maybe_zero_penalties(_lambda, zero_penalties)
penalties.append([getattr(layer, weight_name).name, l1l2])
if zero_penalties:
_lambda.l1 = np.array(0., dtype=_lambda.l1.dtype)
_lambda.l2 = np.array(0., dtype=_lambda.l2.dtype)
return penalties


Expand All @@ -190,14 +187,21 @@ def _cell_penalties(rnn_cell, zero_penalties=False):
_lambda = getattr(cell, weight_type + '_regularizer', None)
if _lambda is not None:
weight_name = cell.weights[weight_idx].name
l1l2 = (float(_lambda.l1), float(_lambda.l2))
l1l2 = _get_and_maybe_zero_penalties(_lambda, zero_penalties)
penalties.append([weight_name, l1l2])
if zero_penalties:
_lambda.l1 = np.array(0., dtype=_lambda.l1.dtype)
_lambda.l2 = np.array(0., dtype=_lambda.l2.dtype)
return penalties


def _get_and_maybe_zero_penalties(_lambda, zero_penalties):
if zero_penalties:
if hasattr(_lambda, 'l1'):
_lambda.l1 = np.array(0., dtype=_lambda.l1.dtype)
if hasattr(_lambda, 'l2'):
_lambda.l2 = np.array(0., dtype=_lambda.l2.dtype)
return (float(getattr(_lambda, 'l1', 0.)),
float(getattr(_lambda, 'l2', 0.)))


def fill_dict_in_order(_dict, values_list):
for idx, key in enumerate(_dict.keys()):
_dict[key] = values_list[idx]
Expand Down
4 changes: 2 additions & 2 deletions tests/test_optimizers.py
Expand Up @@ -90,8 +90,8 @@ def test_misc(): # tests of non-main features to improve coverage
embed_input_dim = 5

# arbitrarily select SGDW for coverage testing
l1_reg = 1e-4 if optimizer_name == 'SGDW' else 0
l2_reg = 1e-4 if optimizer_name != 'SGDW' else 0
l1_reg = 1e-4 if optimizer_name == 'SGDW' else None
l2_reg = 1e-4 if optimizer_name != 'SGDW' else None
if optimizer_name == 'SGDW':
optimizer_kw.update(dict(zero_penalties=False,
weight_decays={},
Expand Down

0 comments on commit f4660d4

Please sign in to comment.