This repository has been archived by the owner on Dec 16, 2022. It is now read-only.
Add support for preventing initialization of parameters that match the given regexes (#1405)

## Commit Summary
- Makes the following scenario easy to handle: continue with all other configured initialization schemes for various networks, but skip parameters matching `".*transfer.*"`.
- Add `PreventRegexInitializer` and register it as `"prevent"`.
- Allow `InitializerApplicator` to apply an `initializer` to a parameter only if no prevention regex matches it.
- Add unit tests with example usage.

## Need for this feature
- If one wants to transfer modules from a pretrained model, one would want to initialize the weights of the new model while preventing the transferred modules from being re-initialized.
- A negative-matching regex gives an easy, selective handle on which parameters are transferred and must not be initialized. E.g. continue with all other configured initialization schemes for various networks, but skip parameters matching `".*(transfer)|(pretrained).*"`.
- Also note that, given regex A and regex B, it is in general not easy to write a single regex that matches A but not B, so achieving the same effect with plain regexes can be quite tedious.

## Example Usage

```python
class MyNet(torch.nn.Module):
    # Typically, the transfer models would be loaded with trained weights,
    # e.g. transfer_model1 = load_archive(params.pop("model1.tar.gz")),
    # in from_params of the main model that MyNet is part of.
    def __init__(self, transfer_model1, transfer_model2):
        super(MyNet, self).__init__()
        self.linear_1 = torch.nn.Linear(5, 10)
        self.linear_2 = torch.nn.Linear(10, 5)
        self.linear_3_transfer = transfer_model1.linear_3  # Note this
        self.linear_4_transfer = transfer_model2.linear_4  # Note this
        self.pretrained_conv = transfer_model2.conv        # Note this

    def forward(self, inputs):  # pylint: disable=arguments-differ
        pass
```

```
// experiment.json
...
{"initializer": [
    [".*linear.*", {"type": "xavier_normal"}],
    ...
    [".*conv.*", {"type": "kaiming_normal"}],
    [".*_transfer.*", "prevent"],              // Note this
    [".*pretrained.*", {"type": "prevent"}]    // Note this
]}
...
```

The above would initialize the linear and conv layers except `linear_3_transfer`, `linear_4_transfer`, and `pretrained_conv`, which match the regexes registered under the `"prevent"` type.

```python
net = MyNet()
# load initializer from above configs
initializer(net)
```
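The prevention logic described above can be sketched in a few lines. This is a minimal, self-contained illustration, not the actual `InitializerApplicator` implementation; the helper name `select_initializer` is hypothetical:

```python
import re

def select_initializer(name, initializers, prevent_regexes):
    # Return the initializer spec for a parameter name, or None if the
    # parameter is "prevented" (i.e. its weights should stay untouched).
    if any(re.search(p, name) for p in prevent_regexes):
        return None  # prevention wins, even if an initializer regex matches
    for regex, init in initializers:
        if re.search(regex, name):
            return init
    return None

initializers = [(".*linear.*", "xavier_normal"), (".*conv.*", "kaiming_normal")]
prevent_regexes = [".*_transfer.*", ".*pretrained.*"]

select_initializer("linear_1.weight", initializers, prevent_regexes)
# -> "xavier_normal"
select_initializer("linear_3_transfer.weight", initializers, prevent_regexes)
# -> None, although ".*linear.*" also matches
```

Note how `linear_3_transfer.weight` would match `".*linear.*"` on its own; checking the prevention regexes first is what makes the transferred weights easy to exclude without crafting a single "matches A but not B" regex.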