-
Notifications
You must be signed in to change notification settings - Fork 2.2k
/
multi_head_attention.py
292 lines (254 loc) · 11.3 KB
/
multi_head_attention.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Parts of this code are adapted from https://github.com/espnet/espnet
"""
import math
import numpy as np
import torch
import torch.nn as nn
__all__ = [
'RelPositionMultiHeadAttention',
'RelPositionalEncoding',
'PositionalEncoding',
]
class MultiHeadAttention(nn.Module):
    """Multi-Head Attention layer.

    Args:
        n_head (int): number of heads
        n_feat (int): size of the features; must be divisible by ``n_head``
        dropout_rate (float): dropout rate applied to the attention weights
    """

    def __init__(self, n_head, n_feat, dropout_rate):
        """Construct a MultiHeadAttention object."""
        super(MultiHeadAttention, self).__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.h = n_head
        self.linear_q = nn.Linear(n_feat, n_feat)
        self.linear_k = nn.Linear(n_feat, n_feat)
        self.linear_v = nn.Linear(n_feat, n_feat)
        self.linear_out = nn.Linear(n_feat, n_feat)
        # Attention weights of the most recent forward pass (before dropout).
        self.attn = None
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward_qkv(self, query, key, value):
        """Project query, key and value and split them into heads.

        Args:
            query (torch.Tensor): (batch, time1, size)
            key (torch.Tensor): (batch, time2, size)
            value (torch.Tensor): (batch, time2, size)
        Returns:
            q (torch.Tensor): (batch, head, time1, d_k)
            k (torch.Tensor): (batch, head, time2, d_k)
            v (torch.Tensor): (batch, head, time2, d_k)
        """
        n_batch = query.size(0)
        # (batch, time, size) -> (batch, time, head, d_k)
        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
        k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
        v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
        # Move the head axis in front of the time axis.
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        return q, k, v

    def forward_attention(self, value, scores, mask):
        """Compute the attention context vector.

        Args:
            value (torch.Tensor): (batch, head, time2, d_k)
            scores (torch.Tensor): (batch, head, time1, time2)
            mask (torch.Tensor or None): (batch, time1, time2) or anything that
                broadcasts to it; positions equal to 0 are masked out
        Returns:
            torch.Tensor: (batch, time1, d_model), the values weighted by the
            attention scores and passed through the output projection
        """
        n_batch = value.size(0)
        if mask is not None:
            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, time1, time2), True where masked
            # Ask torch directly for the most negative finite value of the score
            # dtype; a NumPy round-trip fails for dtypes NumPy lacks (bfloat16).
            min_value = torch.finfo(scores.dtype).min
            scores = scores.masked_fill(mask, min_value)
            # Zero the masked weights again so fully masked rows do not end up
            # with a uniform distribution.
            self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)  # (batch, head, time1, time2)
        else:
            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)

        p_attn = self.dropout(self.attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)  # (batch, time1, d_model)
        return self.linear_out(x)  # (batch, time1, d_model)

    def forward(self, query, key, value, mask):
        """Compute 'Scaled Dot Product Attention'.

        Args:
            query (torch.Tensor): (batch, time1, size)
            key (torch.Tensor): (batch, time2, size)
            value (torch.Tensor): (batch, time2, size)
            mask (torch.Tensor or None): (batch, time1, time2)
        Returns:
            output (torch.Tensor): transformed `value` (batch, time1, d_model)
            weighted by the query dot key attention
        """
        q, k, v = self.forward_qkv(query, key, value)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        return self.forward_attention(v, scores, mask)
class RelPositionMultiHeadAttention(MultiHeadAttention):
    """Multi-Head Attention layer with relative position encoding.

    Paper: https://arxiv.org/abs/1901.02860

    Args:
        n_head (int): number of heads
        n_feat (int): size of the features
        dropout_rate (float): dropout rate
    """

    def __init__(self, n_head, n_feat, dropout_rate):
        """Construct a RelPositionMultiHeadAttention object."""
        super().__init__(n_head, n_feat, dropout_rate)
        # linear transformation for positional encoding
        self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
        # these two learnable biases are used in matrix c and matrix d
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
        self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
        torch.nn.init.xavier_uniform_(self.pos_bias_u)
        torch.nn.init.xavier_uniform_(self.pos_bias_v)

    def rel_shift(self, x, zero_triu=False):
        """Shift the positional score matrix to relative indexing.

        Args:
            x (torch.Tensor): (batch, head, time1, time2)
            zero_triu (bool): if True, keep only the lower triangular part of
                the shifted matrix and zero the rest
        Returns:
            torch.Tensor: shifted scores with the same shape as ``x``
        """
        # Prepend one zero column, reinterpret the buffer with the last two
        # axes swapped in size, then drop the first row: this shifts every row
        # of the original matrix by a different offset.
        zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=-1)
        x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
        x = x_padded[:, :, 1:].view_as(x)

        if zero_triu:
            # Build the triangular mask on x's device and dtype so the product
            # neither fails on non-CPU tensors nor upcasts half-precision input.
            ones = torch.ones((x.size(2), x.size(3)), device=x.device, dtype=x.dtype)
            x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
        return x

    def forward(self, query, key, value, mask, pos_emb):
        """Compute 'Scaled Dot Product Attention' with rel. positional encoding.

        Args:
            query (torch.Tensor): (batch, time1, size)
            key (torch.Tensor): (batch, time2, size)
            value (torch.Tensor): (batch, time2, size)
            mask (torch.Tensor or None): (batch, time1, time2)
            pos_emb (torch.Tensor): (batch_pos, time, size); its batch size is
                read independently of ``query``'s and is broadcast in the
                matmul, so it may be 1
        Returns:
            output (torch.Tensor): transformed `value` (batch, time1, d_model)
            weighted by the query dot key attention
        """
        q, k, v = self.forward_qkv(query, key, value)
        # Back to (batch, time1, head, d_k) so the (head, d_k) biases broadcast.
        q = q.transpose(1, 2)

        n_batch_pos = pos_emb.size(0)
        p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
        p = p.transpose(1, 2)  # (batch_pos, head, time, d_k)

        # (batch, head, time1, d_k)
        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
        # (batch, head, time1, d_k)
        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)

        # compute attention score
        # first compute matrix a and matrix c
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        # (batch, head, time1, time2)
        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))

        # compute matrix b and matrix d
        # (batch, head, time1, time)
        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
        matrix_bd = self.rel_shift(matrix_bd)

        scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k)  # (batch, head, time1, time2)
        return self.forward_attention(v, scores, mask)
class PositionalEncoding(torch.nn.Module):
    """Fixed sinusoidal positional encoding.

    Args:
        d_model (int): embedding dim
        dropout_rate (float): dropout rate
        max_len (int): number of positions precomputed at construction; the
            table is regrown on demand for longer inputs
        reverse (bool): whether to reverse the input position
        xscale (float, optional): factor the input is multiplied by before the
            encoding is added; a falsy value (None, 0) disables scaling
    """

    def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False, xscale=None):
        """Construct a PositionalEncoding object."""
        super(PositionalEncoding, self).__init__()
        self.d_model = d_model
        self.reverse = reverse
        self.xscale = xscale  # math.sqrt(self.d_model)
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        # Cached encoding table of shape (1, length, d_model); a plain
        # attribute, moved to the input's device/dtype lazily in extend_pe.
        self.pe = None
        self.extend_pe(torch.tensor(0.0).expand(1, max_len))

    def extend_pe(self, x):
        """Make sure the cached table covers ``x`` and matches its dtype/device.

        Args:
            x (torch.Tensor): tensor whose dim 1 is the required length
        """
        if self.pe is not None and self.pe.size(1) >= x.size(1):
            # Table is long enough; only align dtype/device if needed.
            if self.pe.dtype != x.dtype or self.pe.device != x.device:
                self.pe = self.pe.to(dtype=x.dtype, device=x.device)
            return

        length = x.size(1)
        pe = torch.zeros(length, self.d_model)
        if self.reverse:
            position = torch.arange(length - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1)
        else:
            position = torch.arange(0, length, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32) * -(math.log(10000.0) / self.d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one cosine column fewer than sine columns,
        # so drop the last frequency; for even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(position * div_term[: self.d_model // 2])
        pe = pe.unsqueeze(0)
        self.pe = pe.to(device=x.device, dtype=x.dtype)

    def forward(self, x: torch.Tensor):
        """Add positional encoding.

        Args:
            x (torch.Tensor): Input. Its shape is (batch, time, ...)
        Returns:
            tuple: ``(encoded, None)`` where ``encoded`` has the shape of ``x``;
            the second element is always None
        """
        self.extend_pe(x)
        if self.xscale:
            x = x * self.xscale
        x = x + self.pe[:, : x.size(1)]
        return self.dropout(x), None
class RelPositionalEncoding(PositionalEncoding):
    """Relative positional encoding module.

    See : Appendix B in https://arxiv.org/abs/1901.02860

    Args:
        d_model (int): embedding dim
        dropout_rate (float): dropout rate applied to the input
        max_len (int): maximum input length
        dropout_emb_rate (float): dropout rate applied to the returned
            positional embedding; values <= 0 disable it
        xscale (float, optional): factor the input is multiplied by; a falsy
            value disables scaling
    """

    def __init__(self, d_model, dropout_rate, max_len=5000, dropout_emb_rate=0.0, xscale=None):
        """Construct a RelPositionalEncoding object (positions are reversed)."""
        super().__init__(d_model, dropout_rate, max_len, reverse=True, xscale=xscale)
        # Only allocate the embedding dropout when it would do something.
        self.dropout_emb = nn.Dropout(dropout_emb_rate) if dropout_emb_rate > 0 else None

    def forward(self, x):
        """Compute positional encoding.

        Unlike the base class, the encoding is not added to the input; it is
        returned separately.

        Args:
            x (torch.Tensor): Input. Its shape is (batch, time, ...)
        Returns:
            x (torch.Tensor): scaled input after dropout, shape (batch, time, ...)
            pos_emb (torch.Tensor): Its shape is (1, time, ...)
        """
        self.extend_pe(x)

        scaled = x * self.xscale if self.xscale else x

        # First `time` rows of the cached (reversed) table.
        n_steps = x.size(1)
        pos_emb = self.pe[:, :n_steps]
        if self.dropout_emb is not None:
            pos_emb = self.dropout_emb(pos_emb)

        return self.dropout(scaled), pos_emb