This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Add regularization parameter to Models (#3120)
pensono authored and DeNeutoy committed Aug 13, 2019
1 parent bf968c6 commit 0bd3319
Showing 5 changed files with 36 additions and 16 deletions.
11 changes: 7 additions & 4 deletions allennlp/models/basic_classifier.py
@@ -1,12 +1,12 @@
-from typing import Dict
+from typing import Dict, Optional
 
 from overrides import overrides
 import torch
 
 from allennlp.data import Vocabulary
 from allennlp.models.model import Model
 from allennlp.modules import Seq2SeqEncoder, Seq2VecEncoder, TextFieldEmbedder
-from allennlp.nn import InitializerApplicator
+from allennlp.nn import InitializerApplicator, RegularizerApplicator
 from allennlp.nn.util import get_text_field_mask
 from allennlp.training.metrics import CategoricalAccuracy
 
@@ -41,6 +41,8 @@ class BasicClassifier(Model):
         Vocabulary namespace corresponding to labels. By default, we use the "labels" namespace.
     initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
         If provided, will be used to initialize the model parameters.
+    regularizer : ``RegularizerApplicator``, optional (default=``None``)
+        If provided, will be used to calculate the regularization penalty during training.
     """
     def __init__(self,
                  vocab: Vocabulary,
@@ -50,9 +52,10 @@ def __init__(self,
                  dropout: float = None,
                  num_labels: int = None,
                  label_namespace: str = "labels",
-                 initializer: InitializerApplicator = InitializerApplicator()) -> None:
+                 initializer: InitializerApplicator = InitializerApplicator(),
+                 regularizer: Optional[RegularizerApplicator] = None) -> None:
 
-        super().__init__(vocab)
+        super().__init__(vocab, regularizer)
         self._text_field_embedder = text_field_embedder
 
         if seq2seq_encoder:
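The new ``regularizer`` argument accepts a ``RegularizerApplicator``, which pairs parameter-name regexes with penalty functions. The following is a minimal sketch (not part of this commit) of how such an applicator could be built and what it computes, assuming the ``allennlp.nn.regularizers`` API of this release (``L2Regularizer`` and regex/regularizer pairs).

# Minimal sketch (assumption, not from this commit): build a RegularizerApplicator
# and apply it to a toy module to see the penalty it would contribute.
import torch
from allennlp.nn import RegularizerApplicator
from allennlp.nn.regularizers import L2Regularizer

regularizer = RegularizerApplicator([
    ("weight", L2Regularizer(alpha=1e-3)),  # L2 penalty on parameters whose names match "weight"
])

toy_module = torch.nn.Linear(4, 2)
penalty = regularizer(toy_module)  # scalar tensor: alpha * sum of squared matching parameters

An applicator built this way is what the new ``regularizer`` argument of ``BasicClassifier`` (and the other models below) expects; it can also be supplied through a training configuration.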
10 changes: 7 additions & 3 deletions allennlp/models/bert_for_classification.py
@@ -1,4 +1,4 @@
-from typing import Dict, Union
+from typing import Dict, Union, Optional
 
 from overrides import overrides
 import torch
@@ -8,6 +8,7 @@
 from allennlp.models.model import Model
 from allennlp.modules.token_embedders.bert_token_embedder import PretrainedBertModel
 from allennlp.nn.initializers import InitializerApplicator
+from allennlp.nn import RegularizerApplicator
 from allennlp.training.metrics import CategoricalAccuracy
 
 
@@ -42,6 +43,8 @@ class BertForClassification(Model):
         Otherwise, they will be frozen and only the final linear layer will be trained.
     initializer : ``InitializerApplicator``, optional
         If provided, will be used to initialize the final linear layer *only*.
+    regularizer : ``RegularizerApplicator``, optional (default=``None``)
+        If provided, will be used to calculate the regularization penalty during training.
     """
     def __init__(self,
                  vocab: Vocabulary,
@@ -51,8 +54,9 @@ def __init__(self,
                  index: str = "bert",
                  label_namespace: str = "labels",
                  trainable: bool = True,
-                 initializer: InitializerApplicator = InitializerApplicator()) -> None:
-        super().__init__(vocab)
+                 initializer: InitializerApplicator = InitializerApplicator(),
+                 regularizer: Optional[RegularizerApplicator] = None,) -> None:
+        super().__init__(vocab, regularizer)
 
         if isinstance(bert_model, str):
             self.bert_model = PretrainedBertModel.load(bert_model)
12 changes: 9 additions & 3 deletions allennlp/models/bidirectional_lm.py
@@ -1,9 +1,11 @@
+from typing import Optional
+
 from allennlp.data.vocabulary import Vocabulary
 from allennlp.models.language_model import LanguageModel
 from allennlp.models.model import Model
 from allennlp.modules.text_field_embedders import TextFieldEmbedder
 from allennlp.modules.seq2seq_encoders import Seq2SeqEncoder
-from allennlp.nn import InitializerApplicator
+from allennlp.nn import InitializerApplicator, RegularizerApplicator
 
 
 @Model.register('bidirectional-language-model')
@@ -37,6 +39,8 @@ class BidirectionalLanguageModel(LanguageModel):
         the full ``_SoftmaxLoss`` defined above.
     sparse_embeddings: ``bool``, optional (default: False)
         Passed on to ``SampledSoftmaxLoss`` if True.
+    regularizer : ``RegularizerApplicator``, optional (default=``None``)
+        If provided, will be used to calculate the regularization penalty during training.
     """
     def __init__(self,
                  vocab: Vocabulary,
@@ -45,12 +49,14 @@ def __init__(self,
                  dropout: float = None,
                  num_samples: int = None,
                  sparse_embeddings: bool = False,
-                 initializer: InitializerApplicator = None) -> None:
+                 initializer: InitializerApplicator = None,
+                 regularizer: Optional[RegularizerApplicator] = None) -> None:
         super().__init__(vocab=vocab,
                          text_field_embedder=text_field_embedder,
                          contextualizer=contextualizer,
                          dropout=dropout,
                          num_samples=num_samples,
                          sparse_embeddings=sparse_embeddings,
                          bidirectional=True,
-                         initializer=initializer)
+                         initializer=initializer,
+                         regularizer=regularizer)
8 changes: 6 additions & 2 deletions allennlp/models/event2mind.py
@@ -17,6 +17,7 @@
 from allennlp.models.model import Model
 from allennlp.nn.beam_search import BeamSearch
 from allennlp.nn.util import get_text_field_mask, sequence_cross_entropy_with_logits
+from allennlp.nn import RegularizerApplicator
 from allennlp.training.metrics import UnigramRecall
 
 
@@ -55,6 +56,8 @@ class Event2Mind(Model):
     target_embedding_dim : int, optional (default = source_embedding_dim)
         You can specify an embedding dimensionality for the target side. If not, we'll use the same
         value as the source embedder's.
+    regularizer : ``RegularizerApplicator``, optional (default=``None``)
+        If provided, will be used to calculate the regularization penalty during training.
     """
     def __init__(self,
                  vocab: Vocabulary,
@@ -65,8 +68,9 @@ def __init__(self,
                  beam_size: int = 10,
                  target_names: List[str] = None,
                  target_namespace: str = "tokens",
-                 target_embedding_dim: int = None) -> None:
-        super().__init__(vocab)
+                 target_embedding_dim: int = None,
+                 regularizer: Optional[RegularizerApplicator] = None) -> None:
+        super().__init__(vocab, regularizer)
         target_names = target_names or ["xintent", "xreact", "oreact"]
 
         # Note: The original tweaks the embeddings for "personx" to be the mean
11 changes: 7 additions & 4 deletions allennlp/models/language_model.py
@@ -1,4 +1,4 @@
-from typing import Dict, List, Tuple, Union
+from typing import Dict, List, Tuple, Union, Optional
 
 import torch
 import numpy as np
@@ -10,7 +10,7 @@
 from allennlp.modules.sampled_softmax_loss import SampledSoftmaxLoss
 from allennlp.modules.seq2seq_encoders import Seq2SeqEncoder
 from allennlp.nn.util import get_text_field_mask
-from allennlp.nn import InitializerApplicator
+from allennlp.nn import InitializerApplicator, RegularizerApplicator
 from allennlp.training.metrics import Perplexity
 
 
@@ -88,6 +88,8 @@ class LanguageModel(Model):
         Train a bidirectional language model, where the contextualizer
         is used to predict the next and previous token for each input token.
         This must match the bidirectionality of the contextualizer.
+    regularizer : ``RegularizerApplicator``, optional (default=``None``)
+        If provided, will be used to calculate the regularization penalty during training.
     """
     def __init__(self,
                  vocab: Vocabulary,
@@ -97,8 +99,9 @@ def __init__(self,
                  num_samples: int = None,
                  sparse_embeddings: bool = False,
                  bidirectional: bool = False,
-                 initializer: InitializerApplicator = None) -> None:
-        super().__init__(vocab)
+                 initializer: InitializerApplicator = None,
+                 regularizer: Optional[RegularizerApplicator] = None) -> None:
+        super().__init__(vocab, regularizer)
         self._text_field_embedder = text_field_embedder
 
         if contextualizer.is_bidirectional() is not bidirectional:
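Every model touched by this commit follows the same pattern: accept an optional ``RegularizerApplicator`` and forward it to ``Model.__init__``, which stores it so that ``get_regularization_penalty()`` can turn it into an extra loss term during training. A minimal sketch of that pattern with a hypothetical ``ToyModel`` (not part of the commit):

# Hypothetical ToyModel illustrating the constructor pattern added in this commit.
from typing import Dict, Optional

import torch
from allennlp.data import Vocabulary
from allennlp.models.model import Model
from allennlp.nn import RegularizerApplicator
from allennlp.nn.regularizers import L2Regularizer


class ToyModel(Model):
    def __init__(self,
                 vocab: Vocabulary,
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        # Forward the applicator to Model.__init__, exactly as the models above now do.
        super().__init__(vocab, regularizer)
        self._projection = torch.nn.Linear(4, 2)

    def forward(self) -> Dict[str, torch.Tensor]:  # pylint: disable=arguments-differ
        return {"loss": self._projection.weight.sum()}


model = ToyModel(Vocabulary(),
                 RegularizerApplicator([("weight", L2Regularizer(alpha=1e-2))]))
# The penalty is 0.0 when no regularizer is given; a training loop adds it to the batch loss.
loss = model()["loss"] + model.get_regularization_penalty()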
