
Merge branch 'master' into vision
epwalsh committed Jul 27, 2020
2 parents 3137961 + e53d185 commit 71d7cb4
Showing 8 changed files with 86 additions and 12 deletions.
7 changes: 6 additions & 1 deletion CHANGELOG.md
@@ -11,7 +11,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Removed unnecessary warning about deadlocks in `DataLoader`.
- Use slower tqdm intervals when output is being piped or redirected.
- Fixed testing models that only return a loss when they are in training mode
- Fixed testing models that only return a loss when they are in training mode.
- Fixed a bug in `FromParams` that caused silent failures when the parameter type was `Optional[Union[...]]`.

### Added

- Added the option to specify `requires_grad: false` within an optimizer's parameter groups.


## [v1.1.0rc1](https://github.com/allenai/allennlp/releases/tag/v1.1.0rc1) - 2020-07-14
5 changes: 3 additions & 2 deletions allennlp/common/from_params.py
@@ -105,8 +105,9 @@ def remove_optional(annotation: type):
"""
origin = getattr(annotation, "__origin__", None)
args = getattr(annotation, "__args__", ())
if origin == Union and len(args) == 2 and args[1] == type(None): # noqa
return args[0]

if origin == Union:
return Union[tuple([arg for arg in args if arg != type(None)])] # noqa: E721
else:
return annotation

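The fix above makes `remove_optional` strip `NoneType` out of a `Union` of any size instead of only handling the two-element `Optional[X]` case. Here is a minimal, self-contained sketch of that behavior, reproducing the fixed helper outside allennlp so it can be run standalone:

```python
from typing import Optional, Union


def remove_optional(annotation: type):
    """Strip NoneType out of a Union annotation, as in the fixed helper above."""
    origin = getattr(annotation, "__origin__", None)
    args = getattr(annotation, "__args__", ())
    if origin == Union:
        # Rebuild the Union without NoneType; a one-element Union collapses
        # to that element automatically (Union[str] is str).
        return Union[tuple([arg for arg in args if arg != type(None)])]  # noqa: E721
    return annotation


assert remove_optional(Optional[str]) is str
assert remove_optional(Optional[Union[str, int]]) == Union[str, int]
assert remove_optional(int) is int  # non-Union annotations pass through unchanged
```

With the old implementation, `Optional[Union[str, int]]` (a three-member `Union`) fell through unchanged with `NoneType` still inside it, which is what led to the silent `FromParams` failure noted in the CHANGELOG.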
3 changes: 3 additions & 0 deletions allennlp/modules/conditional_random_field.py
@@ -324,6 +324,9 @@ def forward(

if mask is None:
mask = torch.ones(*tags.size(), dtype=torch.bool)
else:
# The code below fails in weird ways if this isn't a bool tensor, so we make sure.
mask = mask.to(torch.bool)

log_denominator = self._input_likelihood(inputs, mask)
log_numerator = self._joint_likelihood(inputs, tags, mask)
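The added `else` branch casts any provided mask to `torch.bool` before the likelihood computations. A hedged usage sketch follows; the tag count, tensor shapes, and values are invented for illustration, and only the `ConditionalRandomField` call itself comes from the module being changed:

```python
import torch
from allennlp.modules import ConditionalRandomField

# Toy setup: batch of 2 sequences, max length 4, 5 possible tags.
crf = ConditionalRandomField(num_tags=5)
logits = torch.randn(2, 4, 5)        # unnormalized per-token tag scores
tags = torch.randint(0, 5, (2, 4))   # gold tag indices

# A 0/1 integer padding mask. Before this change a non-bool mask could make the
# likelihood code fail in confusing ways; it is now cast to bool inside forward().
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])

log_likelihood = crf(logits, tags, mask)  # scalar: sum of per-sequence log-likelihoods
```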
47 changes: 43 additions & 4 deletions allennlp/training/optimizers.py
@@ -50,8 +50,22 @@ def make_parameter_groups(
]
```
The return value in the right format to be passed directly as the `params` argument to a pytorch
`Optimizer`. If there are multiple groups specified, this is list of dictionaries, where each
All of the key-value pairs specified in each of these dictionaries will be passed as-is
to the optimizer, with the exception of dictionaries that specify `requires_grad` to be `False`:
```
[
...
(["regex"], {"requires_grad": False})
]
```
When a parameter group has `{"requires_grad": False}`, the gradient on all matching parameters
will be disabled and that group will be dropped so that it's not actually passed to the optimizer.
Ultimately, the return value of this function is in the right format to be passed directly
as the `params` argument to a pytorch `Optimizer`.
If there are multiple groups specified, this is a list of dictionaries, where each
dict contains a "parameter group" and group-specific options, e.g., {'params': [list of
parameters], 'lr': 1e-3, ...}. Any config option not specified in the additional options (e.g.
for the default group) is inherited from the top level arguments given in the constructor. See:
@@ -97,13 +111,38 @@ def make_parameter_groups(
parameter_groups[-1]["params"].append(param)
parameter_group_names[-1].add(name)

# log the parameter groups
# find and remove any groups with 'requires_grad = False'
no_grad_group_indices: List[int] = []
for k, (names, group) in enumerate(zip(parameter_group_names, parameter_groups)):
if group.get("requires_grad") is False:
no_grad_group_indices.append(k)
logger.info("Disabling gradient for the following parameters: %s", names)
for param in group["params"]:
param.requires_grad_(False)

# warn about any other unused options in that group.
unused_options = {
key: val for key, val in group.items() if key not in ("params", "requires_grad")
}
if unused_options:
logger.warning("Ignoring unused options %s for %s", unused_options, names)
parameter_group_names = [
names
for (k, names) in enumerate(parameter_group_names)
if k not in no_grad_group_indices
]
parameter_groups = [
group for (k, group) in enumerate(parameter_groups) if k not in no_grad_group_indices
]

# log the remaining parameter groups
logger.info("Done constructing parameter groups.")
for k in range(len(groups) + 1):
for k in range(len(parameter_groups)):
group_options = {
key: val for key, val in parameter_groups[k].items() if key != "params"
}
logger.info("Group %s: %s, %s", k, list(parameter_group_names[k]), group_options)

# check for unused regex
for regex, count in regex_use_counts.items():
if count == 0:
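To make the new `requires_grad` group option concrete, here is a hedged sketch of calling `make_parameter_groups` directly. The toy `ModuleDict` and the regexes are invented for illustration; only the function itself and the `(["regex"], {...})` group format come from the code and docstring above:

```python
import torch
from allennlp.training.optimizers import make_parameter_groups

# Toy model: an embedding we want frozen, plus an encoder and a projection.
model = torch.nn.ModuleDict(
    {
        "embedder": torch.nn.Embedding(10, 8),
        "encoder": torch.nn.Linear(8, 8),
        "projection": torch.nn.Linear(8, 4),
    }
)

groups = [
    (["embedder"], {"requires_grad": False}),  # gradients disabled, group dropped
    (["projection"], {"lr": 1e-3}),            # custom learning rate for this group
]
parameter_groups = make_parameter_groups(list(model.named_parameters()), groups)

# The frozen group was removed entirely: what remains is the "projection" group
# plus the default group holding the encoder parameters.
assert len(parameter_groups) == 2
assert not model["embedder"].weight.requires_grad

# The result can be handed straight to a pytorch optimizer.
optimizer = torch.optim.Adam(parameter_groups, lr=1e-4)
```

The related change in `allennlp/training/trainer.py` below moves `log_frozen_and_tunable_parameter_names` to after the optimizer is constructed, so parameters frozen through a `requires_grad: False` group are reported as frozen rather than tunable.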
4 changes: 2 additions & 2 deletions allennlp/training/trainer.py
@@ -1115,13 +1115,13 @@ def from_partial_objects(
if any(re.search(regex, name) for regex in no_grad):
parameter.requires_grad_(False)

common_util.log_frozen_and_tunable_parameter_names(model)

parameters = [[n, p] for n, p in model.named_parameters() if p.requires_grad]
optimizer_ = optimizer.construct(model_parameters=parameters)
if not optimizer_:
optimizer_ = Optimizer.default(parameters)

common_util.log_frozen_and_tunable_parameter_names(model)

batches_per_epoch: Optional[int]
try:
batches_per_epoch = len(data_loader)
2 changes: 1 addition & 1 deletion dev-requirements.txt
@@ -42,7 +42,7 @@ nr.databind.core<0.0.17
nr.interface<0.0.4

mkdocs==1.1.2
mkdocs-material==5.4.0
mkdocs-material==5.5.0
markdown-include==0.5.1

#### PACKAGE-UPLOAD PACKAGES ####
16 changes: 16 additions & 0 deletions tests/common/from_params_test.py
@@ -845,3 +845,19 @@ def __init__(self):

with pytest.raises(ConfigurationError, match="no registered concrete types"):
B.from_params(Params({}))

def test_from_params_raises_error_on_wrong_parameter_name_in_optional_union(self):
class NestedClass(FromParams):
def __init__(self, varname: Optional[str] = None):
self.varname = varname

class WrapperClass(FromParams):
def __init__(self, nested_class: Optional[Union[str, NestedClass]] = None):
if isinstance(nested_class, str):
nested_class = NestedClass(varname=nested_class)
self.nested_class = nested_class

with pytest.raises(ConfigurationError):
WrapperClass.from_params(
params=Params({"nested_class": {"wrong_varname": "varstring"}})
)
14 changes: 12 additions & 2 deletions tests/training/optimizer_test.py
@@ -44,13 +44,23 @@ def test_optimizer_parameter_groups(self):
# NOT_A_VARIABLE_NAME displays a warning but does not raise an exception
[["weight_i", "bias_", "bias_", "NOT_A_VARIABLE_NAME"], {"lr": 2}],
[["tag_projection_layer"], {"lr": 3}],
[["^text_field_embedder.*$"], {"requires_grad": False}],
],
}
)

# Before initializing the optimizer all params in this module will still require grad.
assert all([param.requires_grad for param in self.model.text_field_embedder.parameters()])

parameters = [[n, p] for n, p in self.model.named_parameters() if p.requires_grad]
optimizer = Optimizer.from_params(model_parameters=parameters, params=optimizer_params)
param_groups = optimizer.param_groups

# After initializing the optimizer, requires_grad should be false for all params in this module.
assert not any(
[param.requires_grad for param in self.model.text_field_embedder.parameters()]
)

assert len(param_groups) == 3
assert param_groups[0]["lr"] == 2
assert param_groups[1]["lr"] == 3
@@ -63,8 +73,8 @@ def test_optimizer_parameter_groups(self):
assert len(param_groups[0]["params"]) == 6
# just the projection weight and bias
assert len(param_groups[1]["params"]) == 2
# the embedding + recurrent connections left in the default group
assert len(param_groups[2]["params"]) == 3
# the recurrent connections left in the default group
assert len(param_groups[2]["params"]) == 2


class TestDenseSparseAdam(AllenNlpTestCase):
