
Commit 4bd8e7f

Add support for preventing initialization of parameters which match the given regexes (#1405)

Adds support for preventing (overriding) initialization of parameters which match the given regexes.

## Commit Summary

- Makes the following scenario easy to handle: continue with all other configured initialization schemes for the various networks, but skip parameters matching `".*transfer.*"`.
- Adds `PreventRegexInitializer` and registers it under `"prevent"`.
- Allows `InitializerApplicator` to apply an `initializer` to a parameter only if no prevention regex matches that parameter.
- Adds unit tests with example usage.

## Need for this feature:

- If one wants to transfer modules from a pretrained model, one would want to initialize the weights of the new model while excluding the transferred modules from this initialization.
- A negative-matching regex gives an easy, selective handle on which parameters are transferred and must not be initialized, e.g. continue with all other configured initialization schemes for the various networks but skip parameters matching `".*(transfer)|(pretrained).*"`.
- Also note that, given regex A and regex B, it is in general not easy to write a single regex that matches A but not B, so achieving the required effect with plain regexes can be quite tedious (see the sketch below).
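
A minimal sketch of that difference, using plain `re` and made-up parameter names: excluding transferred parameters with a single regex requires a negative lookahead, whereas a separate prevention regex keeps the two concerns independent.

```python
import re

param_names = ["linear_1.weight", "linear_3_transfer.weight"]

# Plain-regex route: matching "linear" parameters while excluding the
# transferred ones needs a negative lookahead, which grows unwieldy as
# more exclusions are added.
only_new = re.compile(r"^(?!.*transfer).*linear.*$")
assert [n for n in param_names if only_new.search(n)] == ["linear_1.weight"]

# With a separate "prevent" regex, the initialization regex stays simple:
init_regex, prevent_regex = r".*linear.*", r".*transfer.*"
allowed = [n for n in param_names
           if re.search(init_regex, n) and not re.search(prevent_regex, n)]
assert allowed == ["linear_1.weight"]
```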


## Example Usage


```python
class MyNet(torch.nn.Module):
    # Typically, the transfer models would be loaded with trained weights, e.g.
    # transfer_model1 = load_archive(params.pop("model1.tar.gz")),
    # in the from_params of the main MyModel that MyNet is part of.
    def __init__(self, transfer_model1, transfer_model2):
        super(MyNet, self).__init__()
        self.linear_1 = torch.nn.Linear(5, 10)
        self.linear_2 = torch.nn.Linear(10, 5)
        self.linear_3_transfer = transfer_model1.linear_3    # Note this
        self.linear_4_transfer = transfer_model2.linear_4    # Note this
        self.pretrained_conv = transfer_model2.conv          # Note this

    def forward(self, inputs):  # pylint: disable=arguments-differ
        pass
```

```
// experiment.json
...
    {"initializer": [
        [".*linear.*", {"type": "xavier_normal"}], ...
        [".*conv.*", {"type": "kaiming_normal"}],
        [".*_transfer.*", "prevent"],                           // Note this
        [".*pretrained.*",{"type": "prevent"}]                 // Note this
    ]}
...
```
The configuration above initializes the linear and conv layers, except for `linear_3_transfer`,
`linear_4_transfer`, and `pretrained_conv`, which match the regexes mapped to the `"prevent"` type:

```python
net = MyNet(transfer_model1, transfer_model2)
# build the initializer applicator from the config above
initializer(net)
```
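
For completeness, here is a minimal sketch of wiring the JSON config into an applicator, mirroring the unit test shown further down; the import paths, the `pyhocon`/`Params` parsing route, and the already-loaded `transfer_model1`/`transfer_model2` are assumptions rather than part of this commit.

```python
import pyhocon
from allennlp.common import Params
from allennlp.nn.initializers import InitializerApplicator

# Parse experiment.json and build the applicator from its "initializer" key.
config = Params(pyhocon.ConfigFactory.parse_file("experiment.json"))
initializer = InitializerApplicator.from_params(config["initializer"])

net = MyNet(transfer_model1, transfer_model2)
initializer(net)  # initializes everything except parameters matched by a "prevent" regex
```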
HarshTrivedi authored and DeNeutoy committed Jun 21, 2018
1 parent 76deabb commit 4bd8e7f
Showing 2 changed files with 56 additions and 5 deletions.
23 changes: 18 additions & 5 deletions allennlp/nn/initializers.py
@@ -180,22 +180,29 @@ def from_params(cls, params: Params):
}



class InitializerApplicator:
"""
Applies initializers to the parameters of a Module based on regex matches. Any parameter not
explicitly matching a regex will not be initialized, instead using whatever the default
initialization was in the module's code.
"""
def __init__(self, initializers: List[Tuple[str, Initializer]] = None) -> None:
def __init__(self,
initializers: List[Tuple[str, Initializer]] = None,
prevent_regexes: List[str] = None) -> None:
"""
Parameters
----------
initializers : ``List[Tuple[str, Initializer]]``, optional (default = [])
A list mapping parameter regexes to initializers. We will check each parameter against
each regex in turn, and apply the initializer paired with the first matching regex, if
any.
any. If "prevent" is assigned to any regex, then it will override and prevent the matched
parameters to be initialzed.
"""
self._initializers = initializers or []
self._prevent_regex = None
if prevent_regexes:
self._prevent_regex = "(" + ")|(".join(prevent_regexes) + ")"

def __call__(self, module: torch.nn.Module) -> None:
"""
@@ -213,7 +220,8 @@ def __call__(self, module: torch.nn.Module) -> None:
# Store which initialisers were applied to which parameters.
for name, parameter in module.named_parameters():
for initializer_regex, initializer in self._initializers:
if re.search(initializer_regex, name):
allow = self._prevent_regex is None or not bool(re.search(self._prevent_regex, name))
if allow and re.search(initializer_regex, name):
logger.info("Initializing %s using %s intitializer", name, initializer_regex)
initializer(parameter)
unused_regexes.discard(initializer_regex)
@@ -244,6 +252,7 @@ def from_params(cls, params: List[Tuple[str, Params]]) -> "InitializerApplicator
}
],
["parameter_regex_match2", "uniform"]
["prevent_init_regex", "prevent"]
]
where the first item in each tuple is the regex that matches to parameters, and the second
@@ -252,11 +261,15 @@ def from_params(cls, params: List[Tuple[str, Params]]) -> "InitializerApplicator
or dictionaries, in which case they must contain the "type" key, corresponding to the name
of an initializer. In addition, they may contain auxiliary named parameters which will be
fed to the initializer itself. To determine valid auxiliary parameters, please refer to the
torch.nn.init documentation.
torch.nn.init documentation. The only special type is "prevent", which does not have a corresponding
initializer. Any parameter matching its regex will be overridden and NOT initialized.
Returns
-------
An InitializerApplicator containing the specified initializers.
"""
is_prevent = lambda item: item == "prevent" or item == {"type": "prevent"}
prevent_regexes = [param[0] for param in params if is_prevent(param[1])]
params = [param for param in params if param[1] if not is_prevent(param[1])]
initializers = [(name, Initializer.from_params(init_params)) for name, init_params in params]
return InitializerApplicator(initializers)
return InitializerApplicator(initializers, prevent_regexes)
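
As an aside, a minimal sketch (plain `re`, hypothetical parameter names) of how the alternation built above from `prevent_regexes` behaves:

```python
import re

prevent_regexes = [".*_transfer.*", ".*pretrained.*"]
# Same construction as in __init__: one alternation over all prevention regexes.
prevent_regex = "(" + ")|(".join(prevent_regexes) + ")"  # "(.*_transfer.*)|(.*pretrained.*)"

assert re.search(prevent_regex, "linear_3_transfer.weight") is not None   # prevented
assert re.search(prevent_regex, "pretrained_conv.bias") is not None       # prevented
assert re.search(prevent_regex, "linear_1.weight") is None                # allowed
```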
38 changes: 38 additions & 0 deletions allennlp/tests/nn/initializers_test.py
@@ -81,3 +81,41 @@ def test_uniform_unit_scaling_can_initialize(self):
uniform_unit_scaling(tensor, "relu")
assert tensor.data.max() < math.sqrt(3/10) * 1.43
assert tensor.data.min() > -math.sqrt(3/10) * 1.43

def test_regex_match_prevention_prevents_and_overrides(self):

    class Net(torch.nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.linear_1 = torch.nn.Linear(5, 10)
            self.linear_2 = torch.nn.Linear(10, 5)
            # typical actual usage: modules loaded from allennlp.model.load(..)
            self.linear_3_transfer = torch.nn.Linear(5, 10)
            self.linear_4_transfer = torch.nn.Linear(10, 5)
            self.pretrained_conv = torch.nn.Conv1d(5, 5, 5)

        def forward(self, inputs):  # pylint: disable=arguments-differ
            pass

    json_params = """{"initializer": [
        [".*linear.*", {"type": "constant", "val": 10}],
        [".*conv.*", {"type": "constant", "val": 10}],
        [".*_transfer.*", "prevent"],
        [".*pretrained.*", {"type": "prevent"}]
        ]}
        """
    params = Params(pyhocon.ConfigFactory.parse_string(json_params))
    initializers = InitializerApplicator.from_params(params['initializer'])
    model = Net()
    initializers(model)

    for module in [model.linear_1, model.linear_2]:
        for parameter in module.parameters():
            assert torch.equal(parameter.data, torch.ones(parameter.size()) * 10)

    transfered_modules = [model.linear_3_transfer,
                          model.linear_4_transfer,
                          model.pretrained_conv]

    for module in transfered_modules:
        for parameter in module.parameters():
            assert not torch.equal(parameter.data, torch.ones(parameter.size()) * 10)
