
Commit

ensure regularizers are only computed for parameters requiring gradients (#2887)
codedecde authored and matt-gardner committed Jun 14, 2019
1 parent acfbb8c commit 6a3d3a8
Showing 2 changed files with 26 additions and 7 deletions.
15 changes: 8 additions & 7 deletions allennlp/nn/regularizers/regularizer_applicator.py
@@ -30,14 +30,15 @@ def __call__(self, module: torch.nn.Module) -> torch.Tensor:
             The module to regularize.
         """
         accumulator = 0.0
-        # For each parameter find the first matching regex.
         for name, parameter in module.named_parameters():
-            for regex, regularizer in self._regularizers:
-                if re.search(regex, name):
-                    penalty = regularizer(parameter)
-                    accumulator = accumulator + penalty
-                    break
-
+            # We first check if the parameter needs gradient updates or not
+            if parameter.requires_grad:
+                # For each parameter find the first matching regex.
+                for regex, regularizer in self._regularizers:
+                    if re.search(regex, name):
+                        penalty = regularizer(parameter)
+                        accumulator = accumulator + penalty
+                        break
         return accumulator
 
     # Requires custom from_params because of complex logic.
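The practical effect of the change above: a parameter whose name matches a regularizer regex is now skipped whenever it has requires_grad=False. A minimal sketch of that behavior (not part of this commit; it only assumes the RegularizerApplicator and L2Regularizer constructors already used in this repository):

import torch

from allennlp.nn.regularizers import L2Regularizer, RegularizerApplicator

# A single linear layer whose weight is frozen, as in fine-tuning setups.
linear = torch.nn.Linear(3, 3)
linear.weight.requires_grad = False

# Penalize only parameters whose names match "weight".
applicator = RegularizerApplicator([("weight", L2Regularizer(1.0))])

# Before this commit the frozen weight was still penalized; with this change
# the applicator skips it, so the accumulated penalty stays at 0.0.
penalty = applicator(linear)
print(penalty)  # 0.0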
18 changes: 18 additions & 0 deletions allennlp/tests/nn/regularizers_test.py
@@ -1,5 +1,7 @@
 # pylint: disable=no-self-use,invalid-name
+import re
+import torch
 
 from allennlp.common.params import Params
 from allennlp.nn import InitializerApplicator, Initializer
 from allennlp.nn.regularizers import L1Regularizer, L2Regularizer, RegularizerApplicator
@@ -57,3 +59,19 @@ def test_from_params(self):
         assert isinstance(conv, L1Regularizer)
         assert isinstance(linear, L2Regularizer)
         assert linear.alpha == 10
+
+    def test_frozen_params(self):
+        model = torch.nn.Sequential(
+                torch.nn.Linear(5, 10),
+                torch.nn.Linear(10, 5)
+        )
+        constant_init = Initializer.from_params(Params({"type": "constant", "val": -1}))
+        initializer = InitializerApplicator([(".*", constant_init)])
+        initializer(model)
+        # freeze the parameters of the first linear
+        for name, param in model.named_parameters():
+            if re.search(r"0.*$", name):
+                param.requires_grad = False
+        value = RegularizerApplicator([("", L1Regularizer(1.0))])(model)
+        # 55 because of bias (5*10 + 5)
+        assert value.data.numpy() == 55
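A quick check of the expected value in this new test: the parameters whose names match the 0.* pattern (the frozen Linear(5, 10)) are skipped entirely, while the second Linear(10, 5) keeps requires_grad=True. With every parameter initialized to -1 and alpha set to 1.0, its L1 penalty is just the parameter count: 10 * 5 = 50 weights plus 5 biases, i.e. 55.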
