"""Modules and functions for building attention models.
References (used throughout the code):
[1]: https://arxiv.org/abs/1712.09763
[2]: https://arxiv.org/abs/2006.16236
[3]: https://arxiv.org/abs/1706.03762
"""
import functools
import math
import numpy as np
import torch
from torch import autograd, nn
from torch.nn import functional as F


def positional_encoding(d_model, max_len):
    """Generates the sinusoidal positional encodings introduced in [3].

    Copied from https://pytorch.org/tutorials/beginner/transformer_tutorial.html.

    Args:
        d_model: Dimension of the model (i.e. embedding dimension).
        max_len: Maximum possible sequence length.
    Returns:
        Tensor of shape [max_len, 1, d_model] containing the positional encodings.
    """
    position = torch.arange(max_len).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
    positional_encoding = torch.zeros(max_len, 1, d_model)
    positional_encoding[:, 0, 0::2] = torch.sin(position * div_term)
    positional_encoding[:, 0, 1::2] = torch.cos(position * div_term)
    return positional_encoding
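

# Usage sketch (illustrative only; the shapes below are arbitrary choices):
#     encodings = positional_encoding(d_model=64, max_len=100)
#     assert encodings.shape == (100, 1, 64)
#     # Added to a [seq_len, batch, d_model] input: x = x + encodings[: x.shape[0]]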


@functools.lru_cache(maxsize=32)
def image_positional_encoding(shape):
    """Generates positional encodings for 2d images.

    The positional encoding is a Tensor of shape (N, 2, H, W) of (x, y) pixel
    coordinates scaled to be between -.5 and .5.

    Args:
        shape: NCHW shape of image for which to generate positional encodings.
    Returns:
        The positional encodings.
    """
    n, c, h, w = shape
    zeros = torch.zeros(n, 1, h, w)
    return torch.cat(
        (
            (torch.arange(-0.5, 0.5, 1 / h)[None, None, :, None] + zeros),
            (torch.arange(-0.5, 0.5, 1 / w)[None, None, None, :] + zeros),
        ),
        dim=1,
    )
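

# Usage sketch (illustrative; the argument must be a hashable tuple since the
# function is wrapped in functools.lru_cache):
#     encodings = image_positional_encoding((2, 3, 8, 8))
#     assert encodings.shape == (2, 2, 8, 8)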


@functools.lru_cache(maxsize=32)
def _get_causal_mask(size, mask_center):
    """Generates causal masks for attention weights."""
    return torch.tril(torch.ones((size, size)), diagonal=-int(mask_center))
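

# For example, _get_causal_mask(3, mask_center=False) keeps the diagonal:
#     tensor([[1., 0., 0.],
#             [1., 1., 0.],
#             [1., 1., 1.]])
# while _get_causal_mask(3, mask_center=True) shifts it down one row:
#     tensor([[0., 0., 0.],
#             [1., 0., 0.],
#             [1., 1., 0.]])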


class CausalAttention(nn.Module):
    """Autoregressively masked, multihead self-attention layer.

    Autoregressive masking means that the current pixel can only attend to itself,
    pixels to the left, and pixels above. When mask_center=True, the current pixel
    does not attend to itself.

    This Module generalizes attention to use 2D convolutions instead of fully
    connected layers. As such, the input is expected to be a 4D image tensor.
    """

    def __init__(
        self,
        in_channels,
        n_heads=1,
        embed_channels=None,
        out_channels=None,
        mask_center=False,
        extra_input_channels=0,
    ):
        """Initializes a new CausalAttention instance.

        Args:
            in_channels: Number of input channels.
            n_heads: Number of causal self-attention heads.
            embed_channels: Number of embedding channels. Defaults to in_channels.
            out_channels: Number of output channels. Defaults to in_channels.
            mask_center: Whether to mask the center pixel of the attention matrices.
            extra_input_channels: Extra input channels which are only used to compute
                the embeddings and not the attention weights since doing so may break
                the autoregressive property. For example, in [1] these channels
                include the original input image.
        """
        super().__init__()
        self._n_heads = n_heads
        self._embed_channels = embed_channels or in_channels
        self._out_channels = out_channels or in_channels
        self._mask_center = mask_center

        self._q = nn.Conv2d(
            in_channels=in_channels, out_channels=self._embed_channels, kernel_size=1
        )
        self._kv = nn.Conv2d(
            in_channels=in_channels + extra_input_channels,
            out_channels=self._embed_channels + self._out_channels,
            kernel_size=1,
        )
        # TODO(eugenhotaj): Should we only project if n_heads > 1?
        self._proj = nn.Conv2d(
            in_channels=self._out_channels,
            out_channels=self._out_channels,
            kernel_size=1,
        )

    def forward(self, x, extra_x=None):
        """Computes the forward pass.

        Args:
            x: The input used to compute both embeddings and attention weights.
            extra_x: Extra channels concatenated with 'x' only used to compute the
                embeddings. See the 'extra_input_channels' argument for more info.
        Returns:
            The result of the forward pass.
        """

        def _to_multihead(t):
            """Reshapes an (N, C, H, W) tensor into (N, n_heads, H * W, head_size)."""
            c = t.shape[1]
            t = t.view(n, self._n_heads, c // self._n_heads, -1)
            return t.transpose(2, 3)

        n, _, h, w = x.shape

        # Compute the query, key, and value.
        q = _to_multihead(self._q(x))
        if extra_x is not None:
            x = torch.cat((x, extra_x), dim=1)
        k, v = self._kv(x).split([self._embed_channels, self._out_channels], dim=1)
        k, v = _to_multihead(k), _to_multihead(v)

        # Compute the causal attention weights.
        mask = (
            _get_causal_mask(h * w, self._mask_center)
            .view(1, 1, h * w, h * w)
            .to(next(self.parameters()).device)
        )
        attn = (q @ k.transpose(2, 3)) / np.sqrt(k.shape[-1])
        attn = attn.masked_fill(mask == 0, -np.inf)
        # NOTE: When self._mask_center is True, the first row of the attention matrix
        # will be NaNs. We replace the NaNs with 0s here to prevent downstream issues.
        attn = F.softmax(attn, dim=-1).masked_fill(mask == 0, 0)

        # Attend to output for each head, stack, and project.
        out = (attn @ v).transpose(2, 3).contiguous().view(n, -1, h, w)
        return self._proj(out)
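

# Usage sketch (illustrative shapes only):
#     attn = CausalAttention(in_channels=16, n_heads=4)
#     out = attn(torch.randn(2, 16, 8, 8))  # shape (2, 16, 8, 8)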


def _idx(i):
    """Slices the i-th element along the sequence (third) dimension, keeping dims."""
    return (slice(None), slice(None), slice(i, i + 1, 1), slice(None))


class _UnnormalizedLinearCausalAttention(autograd.Function):
    """Computes unnormalized causal attention using only O(N*C) memory."""

    @staticmethod
    def forward(ctx, Q, K, V):
        ctx.save_for_backward(Q, K, V)

        # Accumulate the running sum S_i = sum_{j <= i} K_j^T @ V_j instead of
        # materializing the full N x N attention matrix.
        Vnew, S = torch.zeros_like(V), 0
        for i in range(V.shape[2]):
            S = S + K[_idx(i)].transpose(2, 3) @ V[_idx(i)]
            Vnew[_idx(i)] = Q[_idx(i)] @ S
        return Vnew

    @staticmethod
    def backward(ctx, G):
        Q, K, V = ctx.saved_tensors

        # dQ_i = G_i @ S_i^T, where S_i is recomputed with a forward pass.
        dQ, S = torch.zeros_like(Q), 0
        for i in range(V.shape[2]):
            S = S + K[_idx(i)].transpose(2, 3) @ V[_idx(i)]
            dQ[_idx(i)] = G[_idx(i)] @ S.transpose(2, 3)

        # dK_i and dV_i depend on the suffix sum of Q_j^T @ G_j over j >= i, so
        # it is accumulated with a backward pass over the sequence.
        dK, dV, S = torch.zeros_like(K), torch.zeros_like(V), 0
        for i in range(V.shape[2] - 1, -1, -1):
            S = S + Q[_idx(i)].transpose(2, 3) @ G[_idx(i)]
            dV[_idx(i)] = K[_idx(i)] @ S
            dK[_idx(i)] = V[_idx(i)] @ S.transpose(2, 3)
        return dQ, dK, dV
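

# The custom backward above can be sanity checked against finite differences
# (a sketch; torch.autograd.gradcheck requires double precision):
#     Q, K, V = (torch.randn(1, 1, 4, 3, dtype=torch.double, requires_grad=True)
#                for _ in range(3))
#     assert autograd.gradcheck(_UnnormalizedLinearCausalAttention.apply, (Q, K, V))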


# TODO(eugenhotaj): LinearCausalAttention currently does O(N) computations each
# time forward is called. During sampling, forward is called N times to generate
# N pixels. This means that during sampling LinearCausalAttention unnecessarily
# does O(N^2) computations, most of which are thrown away. Instead, we can do
# O(N) work during sampling by storing previous activations as proposed in [2].
# TODO(eugenhotaj): This API does not match the CausalAttention API. We need
# to add support for mask_center and extra_input. There is also a lot of shared
# code between the two which should be extracted. It's probably possible to
# have a base class which does the bookkeeping and the subclasses implement
# the actual computations.
class LinearCausalAttention(nn.Module):
    """Memory efficient implementation of CausalAttention as introduced in [2].

    NOTE: LinearCausalAttention is *much* slower than CausalAttention and should
    only be used if your model cannot fit in memory.

    This implementation only requires O(N) memory (instead of O(N^2)) for a
    sequence of N elements (e.g. an image with N pixels). To achieve this memory
    reduction, the implementation avoids storing the full attention matrix in
    memory and instead computes the output directly as Q @ (K^T @ V). However,
    this computation cannot be fully vectorized in the causal setting and
    requires iterating over the sequence, which drastically slows it down.
    """

    def __init__(
        self,
        in_channels,
        feature_fn=lambda x: F.elu(x) + 1,
        n_heads=1,
        embed_channels=None,
        out_channels=None,
    ):
        """Initializes a new LinearCausalAttention instance.

        Args:
            in_channels: Number of input channels.
            feature_fn: A kernel feature map applied to the Query and Key activations.
                Defaults to lambda x: elu(x) + 1.
            n_heads: Number of causal self-attention heads.
            embed_channels: Number of embedding channels. Defaults to in_channels.
            out_channels: Number of output channels. Defaults to in_channels.
        """
        super().__init__()
        self._feature_fn = feature_fn
        self._n_heads = n_heads
        self._embed_channels = embed_channels or in_channels
        self._out_channels = out_channels or in_channels

        self._query = nn.Conv2d(
            in_channels=in_channels, out_channels=self._embed_channels, kernel_size=1
        )
        self._kv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=self._embed_channels + self._out_channels,
            kernel_size=1,
        )
        self._numerator = _UnnormalizedLinearCausalAttention.apply

    def forward(self, x):
        def _to_multihead(t):
            """Reshapes an (N, C, H, W) tensor into (N, n_heads, H * W, head_size)."""
            c = t.shape[1]
            t = t.view(n, self._n_heads, c // self._n_heads, -1)
            return t.transpose(2, 3)

        n, _, h, w = x.shape

        # Compute the Query, Key, and Value.
        Q = _to_multihead(self._query(x))
        K, V = self._kv(x).split([self._embed_channels, self._out_channels], dim=1)
        K, V = _to_multihead(K), _to_multihead(V)

        # Compute the causal attention weights. After _to_multihead, dim 2 indexes
        # the H * W sequence positions, so the causal normalizer cumulatively sums
        # the keys over dim 2 (not dim 1, which indexes the heads).
        Q, K = self._feature_fn(Q), self._feature_fn(K)
        den = 1 / (torch.einsum("nlhi,nlhi->nlh", Q, K.cumsum(2)) + 1e-10)
        num = self._numerator(Q, K, V)
        out = num * torch.unsqueeze(den, -1)
        return out.transpose(2, 3).contiguous().view(n, -1, h, w)
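

# A minimal smoke test (illustrative; the shapes below are arbitrary choices,
# not part of the module's API):
if __name__ == "__main__":
    x = torch.randn(2, 8, 4, 4)
    causal = CausalAttention(in_channels=8, n_heads=2)
    linear = LinearCausalAttention(in_channels=8, n_heads=2)
    print(causal(x).shape)  # torch.Size([2, 8, 4, 4])
    print(linear(x).shape)  # torch.Size([2, 8, 4, 4])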