# bidirectional_lm.py
from typing import Optional, Tuple

from allennlp.modules.token_embedders.token_embedder import TokenEmbedder
from allennlp_models.lm.modules.token_embedders.language_model import LanguageModelTokenEmbedder


@TokenEmbedder.register("bidirectional_lm_token_embedder")
class BidirectionalLanguageModelTokenEmbedder(LanguageModelTokenEmbedder):
"""
Compute a single layer of representations from a bidirectional language model. This is done
by computing a learned scalar average of the layers from the LM. Typically the LM's weights
will be fixed, but they can be fine tuned by setting `requires_grad`.
# Parameters
archive_file : `str`, required
An archive file, typically model.tar.gz, from a BidirectionalLanguageModel. The
contextualizer used by the LM must satisfy two requirements:
1. It must have a num_layers field.
2. It must take a boolean return_all_layers parameter in its constructor.
See BidirectionalLanguageModelTransformer for their definitions.
dropout : `float`, optional.
The dropout value to be applied to the representations.
bos_eos_tokens : `Tuple[str, str]`, optional (default=`("<S>", "</S>")`)
These will be indexed and placed around the indexed tokens. Necessary if the language model
was trained with them, but they were injected external to an indexer.
remove_bos_eos : `bool`, optional (default = `True`)
Typically the provided token indexes will be augmented with begin-sentence and end-sentence
tokens. (Alternatively, you can pass bos_eos_tokens.) If this flag is True the
corresponding embeddings will be removed from the return values.
Warning: This only removes a single start and single end token!
requires_grad : `bool`, optional (default=`False`)
If True, compute gradient of bidirectional language model parameters for fine tuning.
"""

    def __init__(
        self,
        archive_file: str,
        dropout: Optional[float] = None,
        bos_eos_tokens: Tuple[str, str] = ("<S>", "</S>"),
        remove_bos_eos: bool = True,
        requires_grad: bool = False,
    ) -> None:
        # All of the heavy lifting lives in LanguageModelTokenEmbedder; this
        # subclass only registers the bidirectional variant and forwards its
        # arguments to the parent constructor.
        super().__init__(
            archive_file=archive_file,
            dropout=dropout,
            bos_eos_tokens=bos_eos_tokens,
            remove_bos_eos=remove_bos_eos,
            requires_grad=requires_grad,
        )
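

# ---------------------------------------------------------------------------
# Usage sketch (illustrative; not part of the original module). The archive
# path below is hypothetical -- substitute a model.tar.gz produced by training
# a BidirectionalLanguageModel. The registered name above also makes this
# embedder available in experiment configs as
# {"type": "bidirectional_lm_token_embedder", "archive_file": "..."}.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Load the trained LM and expose it as a token embedder. requires_grad=False
    # keeps the LM weights frozen (the typical setup described in the docstring);
    # pass requires_grad=True to fine-tune them instead.
    embedder = BidirectionalLanguageModelTokenEmbedder(
        archive_file="/path/to/bidirectional_lm/model.tar.gz",  # hypothetical path
        dropout=0.2,
        remove_bos_eos=True,
        requires_grad=False,
    )
    # Dimensionality of the scalar-mixed representation this embedder returns.
    print(embedder.get_output_dim())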