This repository has been archived by the owner on Dec 16, 2022. It is now read-only.
/
transformer_layer.py
255 lines (218 loc) · 9.06 KB
/
transformer_layer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
from typing import Union, Optional, TYPE_CHECKING
from dataclasses import dataclass
import torch
from allennlp.common import FromParams
from allennlp.modules.transformer.transformer_module import TransformerModule
from allennlp.modules.transformer.activation_layer import ActivationLayer
from allennlp.modules.transformer.attention_module import SelfAttention, AttentionOutput
from allennlp.modules.transformer.output_layer import OutputLayer
from allennlp.modules.transformer.util import FloatT
if TYPE_CHECKING:
from transformers.configuration_utils import PretrainedConfig
class AttentionLayer(TransformerModule, FromParams):
    """
    Pairs a `SelfAttention` module with an `OutputLayer`, mirroring the attention
    sub-block of a BERT layer.
    Details in the paper:
    [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding, Devlin et al, 2019]
    (https://api.semanticscholar.org/CorpusID:52967399)

    # Parameters

    hidden_size : `int`
    num_attention_heads : `int`
    attention_dropout : `float` (default = `0.0`)
        Dropout probability for the `SelfAttention` layer.
    hidden_dropout : `float` (default = `0.0`)
        Dropout probability for the `OutputLayer`.
    is_cross_attention : `bool` (default = `False`)
        Forwarded unchanged to `SelfAttention`.
    is_decoder : `bool` (default = `False`)
        Forwarded unchanged to `SelfAttention`.
    """

    _pretrained_relevant_module = "encoder.layer.0.attention"
    _pretrained_mapping = {"layer": "layers"}

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        attention_dropout: float = 0.0,
        hidden_dropout: float = 0.0,
        is_cross_attention: bool = False,
        is_decoder: bool = False,
    ):
        super().__init__()
        # The attention module is built (and registered) before the output layer;
        # keep this order so parameter registration and initialization stay stable.
        attention = SelfAttention(
            hidden_size,
            num_attention_heads,
            attention_dropout,
            is_cross_attention=is_cross_attention,
            is_decoder=is_decoder,
        )
        self.self = attention
        # Input and output sizes of the output layer are both `hidden_size`.
        self.output = OutputLayer(hidden_size, hidden_size, hidden_dropout)

    def forward(
        self,
        input_tensor: torch.Tensor,
        attention_mask: torch.BoolTensor,
        head_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.BoolTensor] = None,
        output_attentions: bool = False,
    ) -> AttentionOutput:
        """
        Run attention over `input_tensor` and feed the result, together with the
        original `input_tensor`, through the output layer.

        # Parameters

        input_tensor : `torch.Tensor`
            Shape `batch_size x seq_len x hidden_dim`
        attention_mask : `torch.BoolTensor`
            Shape `batch_size x seq_len`. Ignored when `encoder_hidden_states`
            is given; `encoder_attention_mask` is used in its place.
        head_mask : `torch.BoolTensor`, optional
        encoder_hidden_states : `torch.Tensor`, optional
            When given, passed to the attention module as its `source_states`.
        encoder_attention_mask : `torch.BoolTensor`, optional
            Mask used instead of `attention_mask` whenever
            `encoder_hidden_states` is given.
        output_attentions : `bool`
            Whether to also return the attention probabilities, default = `False`

        # Returns

        `AttentionOutput`
            The output-layer result, followed by the `key_value_state`,
            `position_bias` and `attention_probs` reported by the attention module.
        """
        attends_to_encoder = encoder_hidden_states is not None
        # When attending over encoder states, the mask must describe those states.
        effective_mask = encoder_attention_mask if attends_to_encoder else attention_mask

        attended = self.self(
            input_tensor,
            source_states=encoder_hidden_states,
            attention_mask=effective_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
        )
        projected = self.output(attended.hidden_states, input_tensor)

        return AttentionOutput(
            projected,
            attended.key_value_state,
            attended.position_bias,
            attended.attention_probs,
        )

    @classmethod
    def _from_config(cls, config: "PretrainedConfig", **kwargs):
        """
        Build an `AttentionLayer` from a pretrained config; entries in `kwargs`
        take precedence over the values read from `config`.
        """
        init_args = {
            "hidden_size": config.hidden_size,
            "num_attention_heads": config.num_attention_heads,
            "attention_dropout": config.attention_probs_dropout_prob,
            "hidden_dropout": config.hidden_dropout_prob,
        }
        init_args.update(kwargs)
        return cls(**init_args)
@dataclass
class TransformerLayerOutput:
    """
    Encapsulates the outputs of the `TransformerLayer` module.

    Instances are built positionally by `TransformerLayer.forward` in the order
    `(hidden_states, self_attention_probs, cross_attention_probs)`.
    """

    # Result of the layer's final `OutputLayer`.
    hidden_states: FloatT
    # `attention_probs` reported by the self-attention sub-layer.
    # NOTE(review): presumably `None` unless `output_attentions=True` — confirm in `SelfAttention`.
    self_attention_probs: Optional[FloatT] = None
    # `attention_probs` reported by the cross-attention sub-layer; stays `None`
    # when `TransformerLayer.forward` receives no `encoder_hidden_states`.
    cross_attention_probs: Optional[FloatT] = None
class TransformerLayer(TransformerModule, FromParams):
    """
    This module is a single transformer layer, mapping to `BertLayer` in the architecture in BERT.
    Details in the paper:
    [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding, Devlin et al, 2019]
    (https://api.semanticscholar.org/CorpusID:52967399)

    # Parameters

    hidden_size : `int`
    intermediate_size : `int`
    num_attention_heads : `int`
    attention_dropout : `float` (default = `0.0`)
        Dropout probability for the `SelfAttention` layer.
    hidden_dropout : `float` (default = `0.0`)
        Dropout probability for the `OutputLayer`.
    activation : `Union[str, torch.nn.Module]` (default = `"relu"`)
    add_cross_attention : `bool` (default = `False`)
        If True, an extra `AttentionLayer` is added for cross-attention.
        This is helpful when using the layer in a decoder.
    """

    _pretrained_relevant_module = "encoder.layer.0"
    _pretrained_mapping = {
        "layer": "layers",
        "intermediate_act_fn": "act_fn",
        "crossattention": "cross_attention",
    }

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        num_attention_heads: int,
        attention_dropout: float = 0.0,
        hidden_dropout: float = 0.0,
        activation: Union[str, torch.nn.Module] = "relu",
        add_cross_attention: bool = False,
    ):
        super().__init__()

        self._hidden_size = hidden_size
        self._add_cross_attention = add_cross_attention

        self.attention = AttentionLayer(
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            attention_dropout=attention_dropout,
            hidden_dropout=hidden_dropout,
        )

        # The `cross_attention` attribute only exists when requested; `forward`
        # relies on its presence to decide whether encoder states can be consumed.
        if add_cross_attention:
            self.cross_attention = AttentionLayer(
                hidden_size=hidden_size,
                num_attention_heads=num_attention_heads,
                attention_dropout=attention_dropout,
                hidden_dropout=hidden_dropout,
                is_cross_attention=True,
                is_decoder=True,
            )

        self.intermediate = ActivationLayer(
            hidden_size=hidden_size, intermediate_size=intermediate_size, activation=activation
        )
        self.output = OutputLayer(
            input_size=intermediate_size, hidden_size=hidden_size, dropout=hidden_dropout
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> TransformerLayerOutput:
        """
        Apply self-attention, optional cross-attention, and the feed-forward
        (`intermediate` + `output`) sub-layers.

        # Parameters

        hidden_states : `torch.Tensor`
            Shape `batch_size x seq_len x hidden_dim`
        attention_mask : `torch.BoolTensor`
            Shape `batch_size x seq_len`
        head_mask : `torch.BoolTensor`, optional
        encoder_hidden_states : `torch.Tensor`, optional
            When given, the cross-attention sub-layer is applied after
            self-attention. Requires the layer to have been constructed with
            `add_cross_attention=True`.
        encoder_attention_mask : `torch.Tensor`, optional
            Mask the cross-attention sub-layer uses for `encoder_hidden_states`.
        output_attentions : `bool`
            Whether to also return the attention probabilities, default = `False`

        # Returns

        `TransformerLayerOutput`
            Holds the layer's final hidden states plus the self-attention and
            cross-attention probabilities. `cross_attention_probs` is `None`
            when `encoder_hidden_states` is not given.

        # Raises

        `AssertionError`
            If `encoder_hidden_states` is passed but the layer has no
            `cross_attention` sub-layer.
        """
        attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
        )
        attention_output = attention_outputs.hidden_states
        self_attention_probs = attention_outputs.attention_probs
        cross_attention_probs = None

        if encoder_hidden_states is not None:
            # The message must be parenthesized: without the parentheses the
            # `assert` statement ends after the first string, and the second
            # fragment becomes a separate no-op expression, truncating the error.
            assert hasattr(self, "cross_attention"), (
                f"If `encoder_hidden_states` are passed, {self} has to be instantiated "
                "with cross-attention layers by setting `config.add_cross_attention=True`"
            )

            # `attention_mask` is still passed positionally, but `AttentionLayer.forward`
            # replaces it with `encoder_attention_mask` whenever encoder states are given.
            cross_attention_outputs = self.cross_attention(
                attention_output,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                output_attentions,
            )
            attention_output = cross_attention_outputs.hidden_states
            cross_attention_probs = cross_attention_outputs.attention_probs

        intermediate_output = self.intermediate(attention_output)
        # The attention output is handed to the output layer alongside the
        # intermediate activations.
        layer_output = self.output(intermediate_output, attention_output)

        outputs = TransformerLayerOutput(layer_output, self_attention_probs, cross_attention_probs)
        return outputs

    @classmethod
    def _from_config(cls, config: "PretrainedConfig", **kwargs):
        """
        Build a `TransformerLayer` from a pretrained config; entries in `kwargs`
        override the values read from `config`.
        """
        final_kwargs = {}
        final_kwargs["hidden_size"] = config.hidden_size
        final_kwargs["num_attention_heads"] = config.num_attention_heads
        final_kwargs["attention_dropout"] = config.attention_probs_dropout_prob
        final_kwargs["hidden_dropout"] = config.hidden_dropout_prob
        final_kwargs["intermediate_size"] = config.intermediate_size
        final_kwargs["activation"] = config.hidden_act
        final_kwargs["add_cross_attention"] = config.add_cross_attention
        final_kwargs.update(**kwargs)
        return cls(**final_kwargs)