-
Notifications
You must be signed in to change notification settings - Fork 3
/
multihead_attention.py
280 lines (229 loc) · 10.8 KB
/
multihead_attention.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import torch
import torch.nn as nn
import torch.nn.functional as F
class DotProductAttention(nn.Module):
    """Batched dot-product attention over already-projected inputs.

    No 1/sqrt(d) scaling is applied inside this module. Optionally ("knn"
    mode) each query only attends to its highest-scoring keys.

    Args:
        dropout: dropout probability applied to the attention weights while
            the module is in training mode.
    """

    def __init__(self, dropout=0.0):
        super(DotProductAttention, self).__init__()
        self.dropout = dropout

    def forward(self, q, k, v, attn_mask=None, knn=False, knn_ratio=0.75):
        """Attend from ``q`` to ``k``/``v``.

        Args:
            q: queries of shape ``(B, N1, D)`` (3-D, as required by ``torch.bmm``).
            k: keys of shape ``(B, N2, D)``.
            v: values of shape ``(B, N2, Dv)``.
            attn_mask: optional additive mask, broadcastable to ``(B, N1, N2)``,
                added to the raw scores before softmax.
            knn: if True, keep only the top ``knn_ratio`` fraction of keys for
                each query (at least one, at most ``N2``) and mask the rest
                with ``-inf`` before softmax.
            knn_ratio: fraction of keys kept in knn mode. The default keeps
                ``floor(3 * N2 / 4)`` keys, as before.

        Returns:
            Tensor of shape ``(B, N1, Dv)``.
        """
        attn_output_weights = torch.bmm(q, k.transpose(1, 2))
        if attn_mask is not None:
            attn_output_weights += attn_mask
        if knn:
            num_keys = k.shape[-2]
            # Clamp to [1, num_keys]: keeping zero keys (e.g. a single key with the
            # default ratio) would set every score to -inf and softmax would give NaN.
            num_keep = min(num_keys, max(1, int(num_keys * knn_ratio)))
            top_index = torch.topk(attn_output_weights, k=num_keep, dim=-1, largest=True)[1]
            keep = torch.zeros_like(attn_output_weights)
            keep.scatter_(-1, top_index, 1.)
            attn_output_weights = attn_output_weights.masked_fill(keep == 0, float('-inf'))
        attn_output_weights = F.softmax(attn_output_weights, dim=-1)
        attn_output_weights = F.dropout(attn_output_weights,
                                        p=self.dropout,
                                        training=self.training)
        return torch.bmm(attn_output_weights, v)
class DotProductAttentionStream(DotProductAttention):
    """Dot-product attention that caches key scores between streaming steps.

    On the first ``stream_inference`` call the full content scores
    ``q @ k^T`` and position scores ``q @ k_pos^T`` are computed and cached.
    On later calls only the score against the newest key ``k[:, -1]`` is
    computed. The oldest cached column is dropped and the new one appended.
    The position scores are reused unchanged. This presumes ``q`` and
    ``k_pos`` stay the same across calls and the key window advances by one
    step per call — TODO confirm with callers.
    """

    def __init__(self, dropout=0.0):
        super().__init__(dropout)
        # Stream-inference caches; both stay None until the first call.
        self.k_weights_cache = None
        self.k_pos_weights_cache = None

    def stream_inference(self, q, k, v, k_pos, v_pos, attn_mask=None):
        """Return ``softmax(q k^T + q k_pos^T [+ mask]) @ (v + v_pos)`` using the caches.

        Args:
            q: queries, 3-D for ``torch.bmm``.
            k, k_pos: key window and key positional projections.
            v, v_pos: value window and value positional projections.
            attn_mask: optional additive mask added to the combined scores.
        """
        if self.k_weights_cache is None:
            content_scores = torch.bmm(q, k.transpose(1, 2))
            self.k_weights_cache = content_scores
            position_scores = torch.bmm(q, k_pos.transpose(1, 2))
            self.k_pos_weights_cache = position_scores
        else:
            newest_scores = torch.bmm(q, k[:, [-1]].transpose(1, 2))
            # Slide the cached window: drop the oldest key column, append the newest.
            content_scores = torch.cat((self.k_weights_cache[:, :, 1:], newest_scores), dim=-1)
            self.k_weights_cache = content_scores
            position_scores = self.k_pos_weights_cache

        # A fresh tensor, so the in-place mask add below never touches the caches.
        scores = content_scores + position_scores
        if attn_mask is not None:
            scores += attn_mask
        probs = F.dropout(F.softmax(scores, dim=-1), p=self.dropout, training=self.training)
        return torch.bmm(probs, v + v_pos)
class MultiheadAttention(nn.Module):
    """Multi-head attention with a single packed q/k/v input projection.

    Inputs are sequence-first, ``(length, batch, embed_dim)``, as implied by the
    head-folding reshapes in :meth:`forward`. Query, key and value must share
    ``embed_dim``.

    Args:
        embed_dim: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.
        bias: whether the packed input projection has a bias.
        kdim, vdim: key/value widths; must equal ``embed_dim`` when given.

    Raises:
        RuntimeError: if ``kdim`` or ``vdim`` differs from ``embed_dim``.
    """

    def __init__(self, embed_dim, num_heads, dropout=0.0, bias=True, kdim=None, vdim=None):
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
        if self._qkv_same_embed_dim:
            # Rows [0, E) project q, [E, 2E) project k, [2E, 3E) project v.
            self.in_proj_weight = nn.Parameter(torch.empty(3 * embed_dim, embed_dim))
        else:
            raise RuntimeError('Do not support q, k, v have different dimensions')
        if bias:
            self.in_proj_bias = nn.Parameter(torch.empty(3 * embed_dim))
        else:
            self.register_parameter('in_proj_bias', None)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        if self._qkv_same_embed_dim:
            nn.init.xavier_uniform_(self.in_proj_weight)
        if self.in_proj_bias is not None:
            nn.init.constant_(self.in_proj_bias, 0.)
            nn.init.constant_(self.out_proj.bias, 0.)
        self.dotproductattention = DotProductAttention(dropout)

    def _in_proj(self, x, index):
        """Apply slice ``index`` (0 = q, 1 = k, 2 = v) of the packed input projection."""
        start, end = index * self.embed_dim, (index + 1) * self.embed_dim
        bias = self.in_proj_bias
        if bias is not None:
            bias = bias[start:end]
        return F.linear(x, self.in_proj_weight[start:end, :], bias)

    def _merge_masks(self, attn_mask, key_padding_mask, bsz, tsz, dtype):
        """Combine the optional masks into one additive ``(bsz * num_heads, tsz, src_len)`` mask.

        Boolean masks are read as "True = masked out" and converted to -inf / 0.
        Otherwise ``+=`` would add them to the scores as 1 / 0. Float masks
        are used unchanged. A 2-D ``attn_mask`` is shared by every batch entry
        and head. Any other rank is passed through as-is (a 3-D mask is assumed
        to be laid out per batch*head — confirm with callers). Returns None when
        both masks are None.
        """
        def to_additive(mask):
            if mask.dtype == torch.bool:
                additive = torch.zeros(mask.shape, dtype=dtype, device=mask.device)
                return additive.masked_fill(mask, float('-inf'))
            return mask

        merged = None
        if attn_mask is not None:
            attn_mask = to_additive(attn_mask)
            if attn_mask.dim() == 2:
                attn_mask = attn_mask.unsqueeze(0).expand(bsz * self.num_heads, -1, -1)
            merged = attn_mask
        if key_padding_mask is not None:
            src_len = key_padding_mask.shape[-1]
            key_padding_mask = to_additive(key_padding_mask)
            # (bsz, src_len) -> (bsz * num_heads, tsz, src_len), batch-major like q/k/v.
            key_padding_mask = key_padding_mask[:, None, None, :].expand(
                bsz, self.num_heads, tsz, src_len).reshape(bsz * self.num_heads, tsz, src_len)
            merged = key_padding_mask if merged is None else merged + key_padding_mask
        return merged

    def forward(self, q, k, v, attn_mask=None, key_padding_mask=None, knn=False):
        """Compute multi-head attention.

        Args:
            q: query ``(tsz, bsz, embed_dim)``.
            k, v: key/value ``(src_len, bsz, embed_dim)``.
            attn_mask: optional ``(tsz, src_len)`` mask, float (additive) or
                bool (True = masked out).
            key_padding_mask: optional ``(bsz, src_len)`` mask, float or bool.
            knn: forwarded to the inner attention's top-k key selection.

        Returns:
            ``(output, None)``, where ``output`` has shape ``(tsz, bsz, embed_dim)``.
            The second element is always None; attention weights are not returned.
        """
        tsz, bsz, embed_dim = q.shape[0], q.shape[1], q.shape[2]
        head_dim = embed_dim // self.num_heads
        assert head_dim * self.num_heads == embed_dim, \
            'embed_dim must be divisible by num_heads'
        scaling = float(head_dim) ** -0.5

        q = self._in_proj(q, 0) * scaling
        k = self._in_proj(k, 1)
        v = self._in_proj(v, 2)

        # Fold heads into the batch axis: (len, bsz, E) -> (bsz * num_heads, len, head_dim).
        q = q.contiguous().view(-1, bsz * self.num_heads, head_dim).transpose(0, 1)
        k = k.contiguous().view(-1, bsz * self.num_heads, head_dim).transpose(0, 1)
        v = v.contiguous().view(-1, bsz * self.num_heads, head_dim).transpose(0, 1)

        mask = self._merge_masks(attn_mask, key_padding_mask, bsz, tsz, q.dtype)
        attn_output = self.dotproductattention(q, k, v, mask, knn=knn)
        attn_output = attn_output.transpose(0, 1).contiguous().view(tsz, bsz, self.embed_dim)
        return self.out_proj(attn_output), None
class MultiheadAttentionStream(MultiheadAttention):
    """MultiheadAttention with projection caches for step-by-step stream inference.

    ``stream_inference`` keeps this state between calls:

    * ``q_cache``: the projected query from the first call. Later calls reuse it
      and ignore their ``q`` argument.
    * ``k_cache`` / ``v_cache``: bias-free key/value projections. After the first
      call only the last time step of ``k`` / ``v`` is projected and appended,
      and the oldest cached step is dropped, so the window length is constant.
    * ``k_pos_cache`` / ``v_pos_cache``: biased projections of ``pos``, computed
      on the first call only.

    The inner ``DotProductAttentionStream`` keeps its own score caches. Nothing in
    this class clears any cache; they must be reset to None externally before a
    new stream.
    """

    def __init__(self, embed_dim, num_heads, dropout=0.0, bias=True, kdim=None, vdim=None):
        super(MultiheadAttentionStream, self).__init__(embed_dim, num_heads, dropout, bias, kdim, vdim)
        self.dotproductattention = DotProductAttentionStream(dropout)
        # Stream-inference caches; all None until the first stream_inference call.
        self.q_cache = None
        self.k_cache = None
        self.v_cache = None
        self.k_pos_cache = None
        self.v_pos_cache = None

    def _stream_in_proj(self, x, index, with_bias=True):
        """Apply slice ``index`` (0 = q, 1 = k, 2 = v) of the packed input projection."""
        start, end = index * self.embed_dim, (index + 1) * self.embed_dim
        bias = None
        if with_bias and self.in_proj_bias is not None:
            bias = self.in_proj_bias[start:end]
        return F.linear(x, self.in_proj_weight[start:end, :], bias)

    def _stream_attn_mask(self, attn_mask, key_padding_mask, bsz, tsz, dtype):
        """Combine the optional masks into one additive ``(bsz * num_heads, tsz, src_len)`` mask.

        Boolean masks are read as "True = masked out" and converted to -inf / 0.
        Otherwise the in-place score update would add them as 1 / 0. Float
        masks are used unchanged. A 2-D ``attn_mask`` is shared by every batch
        entry and head. Any other rank is passed through as-is. Returns None
        when both masks are None.
        """
        def to_additive(mask):
            if mask.dtype == torch.bool:
                additive = torch.zeros(mask.shape, dtype=dtype, device=mask.device)
                return additive.masked_fill(mask, float('-inf'))
            return mask

        merged = None
        if attn_mask is not None:
            attn_mask = to_additive(attn_mask)
            if attn_mask.dim() == 2:
                attn_mask = attn_mask.unsqueeze(0).expand(bsz * self.num_heads, -1, -1)
            merged = attn_mask
        if key_padding_mask is not None:
            src_len = key_padding_mask.shape[-1]
            key_padding_mask = to_additive(key_padding_mask)
            # (bsz, src_len) -> (bsz * num_heads, tsz, src_len), batch-major like q/k/v.
            key_padding_mask = key_padding_mask[:, None, None, :].expand(
                bsz, self.num_heads, tsz, src_len).reshape(bsz * self.num_heads, tsz, src_len)
            merged = key_padding_mask if merged is None else merged + key_padding_mask
        return merged

    def stream_inference(self, q, k, v, pos, attn_mask=None, key_padding_mask=None):
        """Run one streaming step of multi-head attention using the caches.

        Args:
            q: query ``(tsz, bsz, embed_dim)``; only projected on the first call.
            k, v: key/value windows ``(src_len, bsz, embed_dim)``. After the first
                call only ``k[-1]`` / ``v[-1]`` are read.
            pos: positional input for the key/value window, presumably
                ``(src_len, bsz, embed_dim)`` — TODO confirm; only read on the
                first call.
            attn_mask: optional ``(tsz, src_len)`` mask, float or bool.
            key_padding_mask: optional ``(bsz, src_len)`` mask, float or bool.

        Returns:
            ``(output, None)``, where ``output`` has shape ``(tsz, bsz, embed_dim)``.
        """
        tsz, bsz, embed_dim = q.shape[0], q.shape[1], q.shape[2]
        head_dim = embed_dim // self.num_heads
        assert head_dim * self.num_heads == embed_dim, \
            'embed_dim must be divisible by num_heads'
        scaling = float(head_dim) ** -0.5

        if self.q_cache is None:
            self.q_cache = self._stream_in_proj(q, 0)
        q = self.q_cache

        # Keys/values are projected without bias; the bias enters exactly once,
        # through k_pos / v_pos (W k + (W pos + b) == W (k + pos) + b).
        assert (self.k_cache is None) == (self.k_pos_cache is None)
        if self.k_cache is None:
            k = self._stream_in_proj(k, 1, with_bias=False)
            self.k_pos_cache = self._stream_in_proj(pos, 1)
        else:
            # Slide the window: drop the oldest step, project and append the newest.
            k = torch.cat((self.k_cache[1:], self._stream_in_proj(k[[-1]], 1, with_bias=False)))
        self.k_cache = k
        k_pos = self.k_pos_cache

        assert (self.v_cache is None) == (self.v_pos_cache is None)
        if self.v_cache is None:
            v = self._stream_in_proj(v, 2, with_bias=False)
            self.v_pos_cache = self._stream_in_proj(pos, 2)
        else:
            v = torch.cat((self.v_cache[1:], self._stream_in_proj(v[[-1]], 2, with_bias=False)))
        self.v_cache = v
        v_pos = self.v_pos_cache

        q = q * scaling
        # Fold heads into the batch axis: (len, bsz, E) -> (bsz * num_heads, len, head_dim).
        q, k, v, k_pos, v_pos = (
            t.contiguous().view(-1, bsz * self.num_heads, head_dim).transpose(0, 1)
            for t in (q, k, v, k_pos, v_pos)
        )

        mask = self._stream_attn_mask(attn_mask, key_padding_mask, bsz, tsz, q.dtype)
        attn_output = self.dotproductattention.stream_inference(q, k, v, k_pos, v_pos, mask)
        attn_output = attn_output.transpose(0, 1).contiguous().view(tsz, bsz, self.embed_dim)
        return self.out_proj(attn_output), None