-
Notifications
You must be signed in to change notification settings - Fork 316
/
components.py
466 lines (410 loc) · 20.5 KB
/
components.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
from mimetypes import init
from typing import Callable, Union, List, Tuple, Dict, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import einops
import logging
from functools import *
from easy_transformer.hook_points import HookPoint
from easy_transformer.utils import (
gelu_new,
solu,
)
from easy_transformer.EasyTransformerConfig import EasyTransformerConfig
from fancy_einsum import einsum
from easy_transformer.caching import (
EasyTransformerKeyValueCache,
EasyTransformerKeyValueCacheEntry,
)
# Embed & Unembed
class Embed(nn.Module):
    """Token embedding: maps integer token ids to d_model-dimensional vectors.

    W_E has shape [d_vocab, d_model]; the forward pass is a pure table lookup.
    """

    def __init__(self, cfg: Union[Dict, EasyTransformerConfig]):
        super().__init__()
        # Accept either a plain dict or an already-built config object.
        self.cfg = EasyTransformerConfig.from_dict(cfg) if isinstance(cfg, Dict) else cfg
        self.W_E = nn.Parameter(torch.empty(self.cfg.d_vocab, self.cfg.d_model))

    def forward(self, tokens):
        # Advanced indexing: indexing W_E (shape [d_vocab, d_model]) with a
        # [batch, pos] tensor of token ids yields [batch, pos, d_model] —
        # tokens acts as a tensor of row indices (each >= 0 and < d_vocab).
        return self.W_E[tokens, :]  # [batch, pos, d_model]
class Unembed(nn.Module):
    """Unembedding: projects the residual stream to logits over the output vocab.

    d_vocab_out is tracked separately from d_vocab (the input vocab size); the
    two coincide for language tasks but may differ for algorithmic tasks.
    """

    def __init__(self, cfg: Union[Dict, EasyTransformerConfig]):
        super().__init__()
        self.cfg = EasyTransformerConfig.from_dict(cfg) if isinstance(cfg, Dict) else cfg
        self.W_U = nn.Parameter(torch.empty(self.cfg.d_model, self.cfg.d_vocab_out))
        self.b_U = nn.Parameter(torch.zeros(self.cfg.d_vocab_out))

    def forward(self, residual):
        # [batch, pos, d_model] x [d_model, d_vocab_out] -> [batch, pos, d_vocab_out]
        logits = einsum(
            "batch pos d_model, d_model vocab -> batch pos vocab",
            residual,
            self.W_U,
        )
        return logits + self.b_U  # [batch, pos, d_vocab]
# Positional Embeddings
class PosEmbed(nn.Module):
    """Learned absolute positional embeddings.

    W_pos has shape [n_ctx, d_model]; forward returns the rows corresponding to
    the current token positions, offset by the length of any cached context.
    """

    def __init__(self, cfg: Union[Dict, EasyTransformerConfig]):
        super().__init__()
        if isinstance(cfg, Dict):
            cfg = EasyTransformerConfig.from_dict(cfg)
        self.cfg = cfg
        self.W_pos = nn.Parameter(torch.empty(self.cfg.n_ctx, self.cfg.d_model))

    def forward(
        self,
        tokens: torch.Tensor,
        past_kv_pos_offset: int = 0):
        """Tokens have shape [batch, pos].

        past_kv_pos_offset is the length of tokens in the past_kv_cache (if
        used, defaults to zero if unused).

        Returns:
            [pos, d_model] slice of W_pos - will be broadcast along batch dim.

        Raises:
            ValueError: if the requested positions exceed the model's context
                length n_ctx.
        """
        tokens_length = tokens.size(-1)
        end = tokens_length + past_kv_pos_offset
        # Fail loudly here rather than silently returning a truncated slice,
        # which would otherwise surface later as a confusing shape mismatch.
        if end > self.cfg.n_ctx:
            raise ValueError(
                f"Cannot embed {end} positions: model context length is {self.cfg.n_ctx}"
            )
        return self.W_pos[past_kv_pos_offset:end, :]  # [pos, d_model]
# LayerNormPre
# I fold the LayerNorm weights and biases into later weights and biases.
# This is just the 'center and normalise' part of LayerNorm
# Centering is equivalent to just deleting one direction of residual space,
# and is equivalent to centering the weight matrices of everything writing to the residual stream
# Normalising is a funkier non-linear operation, that projects the residual stream onto the unit hypersphere
class LayerNormPre(nn.Module):
    """The parameter-free part of LayerNorm: centre, then divide by the RMS of
    the centred activations. The learned weights/biases are assumed to have
    been folded into downstream weights, so this should only be used in
    inference mode after that folding. Length is normally d_model.
    """

    def __init__(self, cfg: Union[Dict, EasyTransformerConfig]):
        super().__init__()
        self.cfg = EasyTransformerConfig.from_dict(cfg) if isinstance(cfg, Dict) else cfg
        self.eps = self.cfg.eps
        # Hook on the per-position normalisation factor, and on the output.
        self.hook_scale = HookPoint()  # [batch, pos]
        self.hook_normalized = HookPoint()  # [batch, pos, length]

    def forward(self, x):
        centered = x - x.mean(dim=-1, keepdim=True)  # [batch, pos, length]
        rms = (centered.pow(2).mean(-1, keepdim=True) + self.eps).sqrt()
        scale = self.hook_scale(rms)  # [batch, pos, 1]
        return self.hook_normalized(centered / scale)  # [batch, pos, length]
class LayerNorm(nn.Module):
    """Standard LayerNorm with learned elementwise scale (w) and bias (b).

    Args:
        cfg: Config, as a dict or EasyTransformerConfig.
        length (Optional[int]): The dimension normalised over. If not
            provided, assumed to be d_model.
    """

    def __init__(
        self, cfg: Union[Dict, EasyTransformerConfig], length: Optional[int] = None
    ):
        super().__init__()
        self.cfg = EasyTransformerConfig.from_dict(cfg) if isinstance(cfg, Dict) else cfg
        self.eps = self.cfg.eps
        self.length = self.cfg.d_model if length is None else length
        self.w = nn.Parameter(torch.ones(self.length))
        self.b = nn.Parameter(torch.zeros(self.length))
        # Hook on the per-position normalisation factor, and on the normalised
        # (but not yet affine-transformed) activations.
        self.hook_scale = HookPoint()  # [batch, pos, 1]
        self.hook_normalized = HookPoint()  # [batch, pos, length]

    def forward(self, x):
        centered = x - x.mean(dim=-1, keepdim=True)  # [batch, pos, length]
        rms = (centered.pow(2).mean(-1, keepdim=True) + self.eps).sqrt()
        scale = self.hook_scale(rms)  # [batch, pos, 1]
        normalized = self.hook_normalized(centered / scale)  # [batch, pos, length]
        return normalized * self.w + self.b
# Attention
class Attention(nn.Module):
    def __init__(self, cfg: Union[Dict, EasyTransformerConfig], attn_type="global", layer_id=None):
        """Attention Block - params have shape [head_index, d_model, d_head] (or [head_index, d_head, d_model] for W_O) and multiply on the right. attn_scores refers to query key dot product immediately before attention softmax

        Convention: All attention pattern-style matrices have shape [batch, head_index, query_pos, key_pos]

        Args:
            cfg (Union[Dict, EasyTransformerConfig]): Config
            attn_type (str, optional): "global" or "local", used by GPT-Neo. Local attention means the model can only attend back cfg.window_size tokens (here, 256). Not used by any other model at the moment. Defaults to "global".
            layer_id (int, optional): The index of the current layer. Used by the Mistral models (labelled here as stanford-gpt2) to scale down attention scores pre softmax for numerical stability reasons by 1/(layer_id+1). Defaults to None.
        """
        super().__init__()
        if isinstance(cfg, Dict):
            cfg = EasyTransformerConfig.from_dict(cfg)
        self.cfg = cfg
        # Per-head projections; each head multiplies on the right.
        self.W_Q = nn.Parameter(
            torch.empty(self.cfg.n_heads, self.cfg.d_model, self.cfg.d_head)
        )
        self.W_K = nn.Parameter(
            torch.empty(self.cfg.n_heads, self.cfg.d_model, self.cfg.d_head)
        )
        self.W_V = nn.Parameter(
            torch.empty(self.cfg.n_heads, self.cfg.d_model, self.cfg.d_head)
        )
        self.W_O = nn.Parameter(
            torch.empty(self.cfg.n_heads, self.cfg.d_head, self.cfg.d_model)
        )
        self.b_Q = nn.Parameter(torch.zeros(self.cfg.n_heads, self.cfg.d_head))
        self.b_K = nn.Parameter(torch.zeros(self.cfg.n_heads, self.cfg.d_head))
        self.b_V = nn.Parameter(torch.zeros(self.cfg.n_heads, self.cfg.d_head))
        self.b_O = nn.Parameter(torch.zeros(self.cfg.d_model))
        self.attn_type = attn_type
        # Create a max_ctx x max_ctx mask, with True iff that query position
        # can attend to that key position (query is first axis, key is second axis)
        causal_mask = torch.tril(torch.ones((self.cfg.n_ctx, self.cfg.n_ctx)).bool())
        if self.attn_type == "global":
            # For global attention, this is a lower triangular matrix - key <= query
            self.register_buffer("mask", causal_mask)
        elif self.attn_type == "local":
            # For local, this is banded, query - window_size < key <= query
            assert isinstance(self.cfg.window_size, int)
            self.register_buffer(
                "mask", torch.triu(causal_mask, 1 - self.cfg.window_size)
            )
        else:
            raise ValueError(f"Invalid attention type: {self.attn_type}")
        # Large negative (not -inf) score assigned to masked positions, so the
        # softmax sends them to ~0 without numerical issues.
        self.register_buffer("IGNORE", torch.tensor(-1e5))
        self.layer_id = layer_id
        # attn_scale is a constant that we divide the attention scores by pre-softmax. I'm not entirely sure why it matters, but it's probably a mix of softmax not being scale invariant and numerical stability?
        if self.cfg.use_attn_scale:
            self.attn_scale = np.sqrt(self.cfg.d_head)
        else:
            self.attn_scale = 1.0
        if self.cfg.scale_attn_by_inverse_layer_idx:
            # Scores are divided by attn_scale, so multiplying it by
            # (layer_id + 1) scales scores DOWN by 1/(layer_id+1).
            self.attn_scale *= (self.layer_id + 1)
        self.hook_k = HookPoint()  # [batch, pos, head_index, d_head]
        self.hook_q = HookPoint()  # [batch, pos, head_index, d_head]
        self.hook_v = HookPoint()  # [batch, pos, head_index, d_head]
        self.hook_z = HookPoint()  # [batch, pos, head_index, d_head]
        self.hook_attn_scores = HookPoint()  # [batch, head_index, query_pos, key_pos]
        self.hook_attn = HookPoint()  # [batch, head_index, query_pos, key_pos]
        self.hook_result = HookPoint()  # [batch, pos, head_index, d_model]
        # See EasyTransformerConfig for more details.
        if self.cfg.positional_embedding_type == "shortformer":
            # This tracks the input to the keys and queries, which is resid_pre + pos_embeds
            self.hook_attn_input = HookPoint()  # [batch, pos, d_model]

    def forward(self,
                resid_pre: torch.Tensor,
                shortformer_pos_embed: Optional[torch.Tensor] = None,
                past_kv_cache_entry: Optional[EasyTransformerKeyValueCacheEntry] = None
                ):
        """
        shortformer_pos_embed is only used if self.cfg.positional_embedding_type == "shortformer", else defaults to None and is irrelevant. See EasyTransformerConfig for more details
        past_kv_cache_entry is an optional entry of past keys and values for this layer, only relevant if generating text. Defaults to None
        """
        if self.cfg.positional_embedding_type != "shortformer":
            # Normal attention
            q = self.hook_q(
                einsum("batch pos d_model, head_index d_model d_head -> batch pos head_index d_head",
                       resid_pre, self.W_Q) + self.b_Q
            )  # [batch, pos, head_index, d_head]
            k = self.hook_k(
                einsum("batch pos d_model, head_index d_model d_head -> batch pos head_index d_head",
                       resid_pre, self.W_K) + self.b_K
            )  # [batch, pos, head_index, d_head]
        else:
            # Weird shortformer attention see EasyTransformerConfig for details
            q, k = self.shortformer_calculate_qk(resid_pre, shortformer_pos_embed)
        v = self.hook_v(
            einsum("batch pos d_model, head_index d_model d_head -> batch pos head_index d_head",
                   resid_pre, self.W_V) + self.b_V
        )  # [batch, pos, head_index, d_head]
        if past_kv_cache_entry is not None:
            # Appends the new keys and values to the cached values, and automatically updates the cache
            kv_cache_pos_offset = past_kv_cache_entry.past_keys.size(1)
            k, v = past_kv_cache_entry.append(k, v)
        else:
            # Not using a cache
            kv_cache_pos_offset = 0
        attn_scores = (
            einsum("batch query_pos head_index d_head, batch key_pos head_index d_head -> batch head_index query_pos key_pos",
                   q, k) / self.attn_scale
        )  # [batch, head_index, query_pos, key_pos]
        if self.cfg.attention_dir == 'causal':
            # If causal attention, we mask it to only attend backwards. If bidirectional, we don't mask.
            attn_scores = self.apply_causal_mask(
                attn_scores,
                kv_cache_pos_offset
            )  # [batch, head_index, query_pos, key_pos]
        # Fix: hook_attn_scores was previously declared in __init__ but never
        # called, so hooks registered on it silently never fired. HookPoint is
        # an identity when no hook is attached, so this is backward compatible.
        attn_scores = self.hook_attn_scores(attn_scores)  # [batch, head_index, query_pos, key_pos]
        attn_matrix = self.hook_attn(
            F.softmax(attn_scores, dim=-1)
        )  # [batch, head_index, query_pos, key_pos]
        # Weighted sum of values per query position.
        z = self.hook_z(
            einsum("batch key_pos head_index d_head, batch head_index query_pos key_pos -> batch query_pos head_index d_head",
                   v, attn_matrix)
        )  # [batch, pos, head_index, d_head]
        if not self.cfg.use_attn_result:
            out = (
                einsum("batch pos head_index d_head, head_index d_head d_model -> batch pos d_model",
                       z,
                       self.W_O)
            ) + self.b_O  # [batch, pos, d_model]
        else:
            # Explicitly calculate the attention result so it can be accessed by a hook
            # This is off by default because it can easily eat through your GPU memory.
            result = self.hook_result(
                einsum("batch pos head_index d_head, head_index d_head d_model -> batch pos head_index d_model",
                       z,
                       self.W_O)
            )  # [batch, pos, head_index, d_model]
            out = (
                einops.reduce(
                    result, "batch position index model->batch position model", "sum"
                )
                + self.b_O
            )  # [batch, pos, d_model]
        return out

    def apply_causal_mask(self, attn_scores, past_kv_pos_offset):
        """Replace scores for disallowed (future / out-of-window) key positions with IGNORE."""
        # The query context length is the number of positions we take queries from - if not using a past_kv_cache this is just the context length (for the current prompt), but if we're caching it's just a single token.
        query_ctx_length = attn_scores.size(-2)
        # The key context length is the number of positions in the past - this includes all positions in the cache
        # If not caching, query_ctx_length == key_ctx_length
        key_ctx_length = attn_scores.size(-1)
        assert query_ctx_length + past_kv_pos_offset == key_ctx_length, f"query_ctx_length {query_ctx_length} + past_kv_pos_offset {past_kv_pos_offset} != key_ctx_length {key_ctx_length} - you likely have a bug."
        return torch.where(
            self.mask[
                past_kv_pos_offset : past_kv_pos_offset + query_ctx_length,
                : key_ctx_length,
            ],
            attn_scores,
            self.IGNORE,
        )

    def shortformer_calculate_qk(self, x, shortformer_pos_embed):
        """Compute (q, k) with positional encodings added JUST for the keys and
        queries - the positional embedding is not added to the normal residual
        stream (see EasyTransformerConfig for the shortformer scheme)."""
        attn_input = self.hook_attn_input(
            x + shortformer_pos_embed
        )  # [batch, pos, d_model]
        q = self.hook_q(
            einsum("batch pos d_model, head_index d_model d_head -> batch pos head_index d_head",
                   attn_input, self.W_Q) + self.b_Q
        )  # [batch, pos, head_index, d_head]
        k = self.hook_k(
            einsum("batch pos d_model, head_index d_model d_head -> batch pos head_index d_head",
                   attn_input, self.W_K) + self.b_K
        )  # [batch, pos, head_index, d_head]
        return (q, k)
# MLP Layers
class MLP(nn.Module):
    """Two-layer feed-forward block: W_in -> activation -> (optional LN) -> W_out.

    For act_fn == "solu_ln" an extra LayerNorm (over d_mlp) is applied after
    the activation, with its own hook point.
    """

    def __init__(self, cfg: Union[Dict, EasyTransformerConfig]):
        super().__init__()
        self.cfg = EasyTransformerConfig.from_dict(cfg) if isinstance(cfg, Dict) else cfg
        self.W_in = nn.Parameter(torch.empty(self.cfg.d_model, self.cfg.d_mlp))
        self.b_in = nn.Parameter(torch.zeros(self.cfg.d_mlp))
        self.W_out = nn.Parameter(torch.empty(self.cfg.d_mlp, self.cfg.d_model))
        self.b_out = nn.Parameter(torch.zeros(self.cfg.d_model))
        self.hook_pre = HookPoint()  # [batch, pos, d_mlp]
        self.hook_post = HookPoint()  # [batch, pos, d_mlp]
        # Plain activations dispatch through a table; solu_ln needs extra setup.
        plain_act_fns = {
            "relu": F.relu,
            "gelu": F.gelu,
            "silu": F.silu,
            "gelu_new": gelu_new,
        }
        if self.cfg.act_fn in plain_act_fns:
            self.act_fn = plain_act_fns[self.cfg.act_fn]
        elif self.cfg.act_fn == "solu_ln":
            self.act_fn = solu
            self.hook_post_ln = HookPoint()  # [batch, pos, d_mlp]
            self.ln = LayerNorm(self.cfg, self.cfg.d_mlp)
        else:
            raise ValueError(f"Invalid activation function name: {self.cfg.act_fn}")

    def forward(self, x):
        # Technically these einsums could be a single matmul, but this is more readable.
        pre_act = self.hook_pre(
            einsum("batch pos d_model, d_model d_mlp -> batch pos d_mlp", x, self.W_in)
            + self.b_in
        )  # [batch, pos, d_mlp]
        post_act = self.hook_post(self.act_fn(pre_act))  # [batch, pos, d_mlp]
        if self.cfg.act_fn.endswith("_ln"):
            post_act = self.hook_post_ln(self.ln(post_act))
        return (
            einsum("batch pos d_mlp, d_mlp d_model -> batch pos d_model", post_act, self.W_out)
            + self.b_out
        )  # [batch, pos, d_model]
# Transformer Block
class TransformerBlock(nn.Module):
    """A single transformer layer: norm -> attention -> (norm -> MLP, unless attn_only)."""

    def __init__(self, cfg: Union[Dict, EasyTransformerConfig], block_index):
        super().__init__()
        if isinstance(cfg, Dict):
            cfg = EasyTransformerConfig.from_dict(cfg)
        self.cfg = cfg
        # Pick the normalisation layer via a builder table. "LNPre" means the
        # LayerNorm weights have been folded into later weights, so only the
        # centre + scale parts remain.
        norm_builders = {
            "LN": lambda: LayerNorm(cfg),
            "LNPre": lambda: LayerNormPre(cfg),
            None: lambda: nn.Identity(),
        }
        builder = norm_builders.get(self.cfg.normalization_type)
        if builder is None:
            # Matches original behaviour: warn, leaving ln1/ln2 unset.
            logging.warning(
                f"Invalid normalization_type passed in {self.cfg.normalization_type}"
            )
        else:
            self.ln1 = builder()
            if not self.cfg.attn_only:
                self.ln2 = builder()
        if self.cfg.use_local_attn:
            # Per-layer attention type (e.g. GPT-Neo alternates global/local).
            assert self.cfg.attn_types is not None
            self.attn = Attention(cfg, self.cfg.attn_types[block_index], block_index)
        else:
            self.attn = Attention(cfg, "global", block_index)
        if not self.cfg.attn_only:
            self.mlp = MLP(cfg)
        self.hook_attn_out = HookPoint()  # [batch, pos, d_model]
        self.hook_mlp_out = HookPoint()  # [batch, pos, d_model]
        self.hook_resid_pre = HookPoint()  # [batch, pos, d_model]
        if not self.cfg.attn_only:
            self.hook_resid_mid = HookPoint()  # [batch, pos, d_model]
        self.hook_resid_post = HookPoint()  # [batch, pos, d_model]

    def forward(
        self,
        resid_pre: torch.Tensor,
        shortformer_pos_embed: Optional[torch.Tensor] = None,
        past_kv_cache_entry: Optional[EasyTransformerKeyValueCacheEntry] = None,
    ):
        """Run one transformer block.

        Args:
            resid_pre (torch.Tensor): The residual stream - shape [batch, pos, d_model]
            shortformer_pos_embed (torch.Tensor, optional): Only used for
                positional_embedding_type == "shortformer". See
                EasyTransformerConfig for details. Defaults to None.
            past_kv_cache_entry (EasyTransformerKeyValueCacheEntry, optional):
                Cache of previous keys and values, used only when generating
                text. Defaults to None.

        Returns:
            torch.Tensor: The updated residual stream [batch, pos, d_model].
        """
        resid_pre = self.hook_resid_pre(resid_pre)  # [batch, pos, d_model]
        attn_out = self.hook_attn_out(
            self.attn(
                self.ln1(resid_pre),
                shortformer_pos_embed=shortformer_pos_embed,
                past_kv_cache_entry=past_kv_cache_entry,
            )
        )  # [batch, pos, d_model]
        if self.cfg.attn_only:
            return self.hook_resid_post(resid_pre + attn_out)  # [batch, pos, d_model]
        resid_mid = self.hook_resid_mid(resid_pre + attn_out)  # [batch, pos, d_model]
        mlp_out = self.hook_mlp_out(
            self.mlp(self.ln2(resid_mid))
        )  # [batch, pos, d_model]
        return self.hook_resid_post(resid_mid + mlp_out)  # [batch, pos, d_model]