-
Notifications
You must be signed in to change notification settings - Fork 269
/
components.py
1306 lines (1132 loc) · 54.4 KB
/
components.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""Hooked Transformer Components.
This module contains all the components (e.g. :class:`Attention`, :class:`MLP`, :class:`LayerNorm`)
needed to create many different types of generative language models. They are used by
:class:`transformer_lens.HookedTransformer`.
"""
import logging
from typing import Dict, Optional, Tuple, Union
import einops
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from fancy_einsum import einsum
from jaxtyping import Float, Int
from transformer_lens.FactoredMatrix import FactoredMatrix
from transformer_lens.hook_points import HookPoint
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.past_key_value_caching import HookedTransformerKeyValueCacheEntry
from transformer_lens.utils import gelu_fast, gelu_new, get_offset_position_ids, solu
# Embed & Unembed
class Embed(nn.Module):
    """Token embedding layer.

    Looks up each token id in the ``W_E`` table ([d_vocab, d_model]). When
    ``cfg.post_embedding_ln`` is set (some models, e.g. Bloom, need this), the
    looked-up vectors are passed through a :class:`LayerNorm` before being returned.
    """

    def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
        super().__init__()
        self.cfg = (
            HookedTransformerConfig.from_dict(cfg) if isinstance(cfg, Dict) else cfg
        )
        self.W_E: Float[torch.Tensor, "d_vocab d_model"] = nn.Parameter(
            torch.empty(self.cfg.d_vocab, self.cfg.d_model, dtype=self.cfg.dtype)
        )
        if self.cfg.post_embedding_ln:
            self.ln = LayerNorm(self.cfg)

    def forward(
        self, tokens: Int[torch.Tensor, "batch pos"]
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        # Advanced indexing: a [batch, pos] tensor of row indices into the
        # [d_vocab, d_model] table yields a [batch, pos, d_model] tensor.
        # Every id must satisfy 0 <= id < d_vocab.
        embeddings = self.W_E[tokens, :]
        return self.ln(embeddings) if self.cfg.post_embedding_ln else embeddings
class Unembed(nn.Module):
    """Unembedding layer: maps the residual stream to logits.

    Computes ``residual @ W_U + b_U``. ``W_U`` has shape [d_model, d_vocab_out].
    ``d_vocab_out`` is kept separate from the input vocabulary size ``d_vocab``.
    For language tasks the two are the same, but algorithmic tasks may want
    them to differ.
    """

    def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
        super().__init__()
        self.cfg = (
            HookedTransformerConfig.from_dict(cfg) if isinstance(cfg, Dict) else cfg
        )
        self.W_U: Float[torch.Tensor, "d_model d_vocab_out"] = nn.Parameter(
            torch.empty(self.cfg.d_model, self.cfg.d_vocab_out, dtype=self.cfg.dtype)
        )
        self.b_U: Float[torch.Tensor, "d_vocab_out"] = nn.Parameter(
            torch.zeros(self.cfg.d_vocab_out, dtype=self.cfg.dtype)
        )

    def forward(
        self, residual: Float[torch.Tensor, "batch pos d_model"]
    ) -> Float[torch.Tensor, "batch pos d_vocab_out"]:
        logits_no_bias = einsum(
            "batch pos d_model, d_model vocab -> batch pos vocab",
            residual,
            self.W_U,
        )
        return logits_no_bias + self.b_U
# Positional Embeddings
class PosEmbed(nn.Module):
    """Learned absolute positional embeddings.

    ``W_pos`` has shape [n_ctx, d_model]. Row ``i`` is the embedding for position ``i``.
    """

    def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
        super().__init__()
        self.cfg = (
            HookedTransformerConfig.from_dict(cfg) if isinstance(cfg, Dict) else cfg
        )
        self.W_pos = nn.Parameter(
            torch.empty(self.cfg.n_ctx, self.cfg.d_model, dtype=self.cfg.dtype)
        )

    def forward(
        self,
        tokens: Int[torch.Tensor, "batch pos"],
        past_kv_pos_offset: int = 0,
        attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None,
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        """Return the positional embedding for every token in ``tokens``.

        Args:
            tokens (Int[torch.Tensor, "batch pos"]): Input tokens. Only their shape is used.
            past_kv_pos_offset (int, optional): Number of tokens already held in the
                past_kv_cache. The current tokens start at this position. Defaults to 0.
            attention_mask (Int[torch.Tensor, "batch offset_pos"], optional): Padding
                mask covering both the cached and the current positions. When given,
                positions are counted per sequence (via ``get_offset_position_ids``) and
                pad tokens get a zero embedding. Defaults to None.

        Returns:
            Float[torch.Tensor, "batch pos d_model"]: Absolute position embeddings (a fresh tensor).
        """
        seq_len = tokens.size(-1)
        current_window = slice(past_kv_pos_offset, past_kv_pos_offset + seq_len)

        if attention_mask is not None:
            # Padded path: slower than the unpadded path below, so it is kept separate.
            position_ids = get_offset_position_ids(past_kv_pos_offset, attention_mask)
            gathered = self.W_pos[position_ids]  # [batch, pos, d_model]
            # Pad tokens get a zero positional embedding (an arbitrary choice).
            is_pad = (~attention_mask.bool())[:, current_window].unsqueeze(-1)  # [batch, pos, 1]
            batch_pos_embed = torch.where(is_pad, 0, gathered)
        else:
            shared = self.W_pos[current_window, :]  # [pos, d_model]
            batch_pos_embed = einops.repeat(
                shared, "pos d_model -> batch pos d_model", batch=tokens.size(0)
            )
        return batch_pos_embed.clone()
class TokenTypeEmbed(nn.Module):
    """
    Token-type (segment) embedding for BERT.

    Token-type ids are binary: ``0`` marks tokens from sequence A and ``1`` marks
    tokens from sequence B. For example, for "[CLS] Sentence A [SEP] Sentence B [SEP]"
    the ids are [0, 0, ..., 0, 1, ..., 1, 1]. If no ids are given, BERT treats the
    input as a single sequence. The ids typically have shape (batch_size, sequence_length).
    See the BERT paper for more information: https://arxiv.org/pdf/1810.04805.pdf
    """

    def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
        super().__init__()
        self.cfg = (
            HookedTransformerConfig.from_dict(cfg) if isinstance(cfg, Dict) else cfg
        )
        # One embedding row per token type (A = 0, B = 1).
        self.W_token_type = nn.Parameter(
            torch.empty(2, self.cfg.d_model, dtype=self.cfg.dtype)
        )

    def forward(self, token_type_ids: Int[torch.Tensor, "batch pos"]):
        embedded = self.W_token_type[token_type_ids, :]
        return embedded
class BertEmbed(nn.Module):
    """
    Embedding layer for a BERT-like model.

    Sums the token, positional and token-type embeddings (each exposed through
    its own hook) and returns the LayerNorm of that sum.
    """

    def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
        super().__init__()
        self.cfg = (
            HookedTransformerConfig.from_dict(cfg) if isinstance(cfg, Dict) else cfg
        )
        self.embed = Embed(self.cfg)
        self.pos_embed = PosEmbed(self.cfg)
        self.token_type_embed = TokenTypeEmbed(self.cfg)
        self.ln = LayerNorm(self.cfg)
        self.hook_embed = HookPoint()
        self.hook_pos_embed = HookPoint()
        self.hook_token_type_embed = HookPoint()

    def forward(
        self,
        input_ids: Int[torch.Tensor, "batch pos"],
        token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None,
    ):
        # Position ids are 0..pos-1, repeated for every sequence in the batch.
        positions = torch.arange(input_ids.shape[1], device=input_ids.device)
        index_ids = einops.repeat(positions, "pos -> batch pos", batch=input_ids.shape[0])

        # With no segment information, every token belongs to sequence A (id 0).
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        summed = (
            self.hook_embed(self.embed(input_ids))
            + self.hook_pos_embed(self.pos_embed(index_ids))
            + self.hook_token_type_embed(self.token_type_embed(token_type_ids))
        )
        return self.ln(summed)
class BertMLMHead(nn.Module):
    """
    Masked-language-modelling head for BERT.

    Applies a dense d_model -> d_model transform, then GELU, then LayerNorm. It
    prepares embeddings for predicting the masked tokens in a sentence. Note that
    ``W`` is indexed as [d_model_out, d_model_in], unlike most weights in this file.
    """

    def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
        super().__init__()
        self.cfg = (
            HookedTransformerConfig.from_dict(cfg) if isinstance(cfg, Dict) else cfg
        )
        width = self.cfg.d_model
        self.W = nn.Parameter(torch.empty(width, width, dtype=self.cfg.dtype))
        self.b = nn.Parameter(torch.zeros(width, dtype=self.cfg.dtype))
        self.act_fn = nn.GELU()
        self.ln = LayerNorm(self.cfg)

    def forward(self, resid: Float[torch.Tensor, "batch pos d_model"]) -> torch.Tensor:
        projected = einsum(
            "batch pos d_model_in, d_model_out d_model_in -> batch pos d_model_out",
            resid,
            self.W,
        )
        activated = self.act_fn(projected + self.b)
        return self.ln(activated)
# LayerNormPre
# I fold the LayerNorm weights and biases into later weights and biases.
# This is just the 'center and normalise' part of LayerNorm
# Centering is equivalent to just deleting one direction of residual space,
# and is equivalent to centering the weight matrices of everything writing to the residual stream
# Normalising is a funkier non-linear operation, that projects the residual stream onto the unit hypersphere
class LayerNormPre(nn.Module):
    def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
        """LayerNormPre: only the 'center and normalise' part of LayerNorm.

        It has no learned weight or bias, so no length argument is needed. The
        normalised axis is the last one (normally d_model, or d_mlp for SoLU). Use it
        only in inference mode, after the LayerNorm weights have been folded into the
        following layers.
        """
        super().__init__()
        self.cfg = (
            HookedTransformerConfig.from_dict(cfg) if isinstance(cfg, Dict) else cfg
        )
        self.eps = self.cfg.eps
        # Exposes the normalisation scale factor.
        self.hook_scale = HookPoint()  # [batch, pos, 1]
        # Exposes the output: a vector with mean 0 and RMS 1 along the last axis.
        self.hook_normalized = HookPoint()  # [batch, pos, length]

    def forward(
        self,
        x: Union[
            Float[torch.Tensor, "batch pos d_model"],
            Float[torch.Tensor, "batch pos head_index d_model"],
        ],
    ) -> Union[
        Float[torch.Tensor, "batch pos d_model"],
        Float[torch.Tensor, "batch pos head_index d_model"],
    ]:
        # Normalise in at least float32, then cast back to the configured dtype.
        if self.cfg.dtype not in (torch.float32, torch.float64):
            x = x.to(torch.float32)
        centered = x - x.mean(axis=-1, keepdim=True)
        root_mean_square = (centered.pow(2).mean(-1, keepdim=True) + self.eps).sqrt()
        scale = self.hook_scale(root_mean_square)
        return self.hook_normalized(centered / scale).to(self.cfg.dtype)
class LayerNorm(nn.Module):
    def __init__(
        self, cfg: Union[Dict, HookedTransformerConfig], length: Optional[int] = None
    ):
        """
        LayerNorm with a learned elementwise weight ``w`` and bias ``b``.

        length (Optional[int]): The size of the normalised (last) dimension. If not
            provided, it defaults to d_model.
        """
        super().__init__()
        self.cfg = (
            HookedTransformerConfig.from_dict(cfg) if isinstance(cfg, Dict) else cfg
        )
        self.eps = self.cfg.eps
        self.length = self.cfg.d_model if length is None else length
        self.w = nn.Parameter(torch.ones(self.length, dtype=self.cfg.dtype))
        self.b = nn.Parameter(torch.zeros(self.length, dtype=self.cfg.dtype))
        # Exposes the normalisation scale factor.
        self.hook_scale = HookPoint()  # [batch, pos, 1]
        # Exposes the final LN output (after weight and bias).
        self.hook_normalized = HookPoint()  # [batch, pos, length]

    def forward(
        self,
        x: Union[
            Float[torch.Tensor, "batch pos d_model"],
            Float[torch.Tensor, "batch pos head_index d_model"],
        ],
    ) -> Union[
        Float[torch.Tensor, "batch pos d_model"],
        Float[torch.Tensor, "batch pos head_index d_model"],
    ]:
        # Normalise in at least float32, then cast back to the configured dtype.
        if self.cfg.dtype not in (torch.float32, torch.float64):
            x = x.to(torch.float32)
        centered = x - x.mean(axis=-1, keepdim=True)
        root_mean_square = (centered.pow(2).mean(-1, keepdim=True) + self.eps).sqrt()
        scale: Float[torch.Tensor, "batch pos 1"] = self.hook_scale(root_mean_square)
        normalized = centered / scale
        return self.hook_normalized(normalized * self.w + self.b).to(self.cfg.dtype)
class RMSNormPre(nn.Module):
    def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
        """RMSNormPre: LayerNormPre without the centering step (RMS = Root Mean Square).

        It has no learned parameters. It divides by the root mean square of the last axis.
        """
        super().__init__()
        self.cfg = (
            HookedTransformerConfig.from_dict(cfg) if isinstance(cfg, Dict) else cfg
        )
        self.eps = self.cfg.eps
        # Exposes the normalisation scale factor.
        self.hook_scale = HookPoint()  # [batch, pos, 1]
        self.hook_normalized = HookPoint()  # [batch, pos, length]

    def forward(
        self, x: Float[torch.Tensor, "batch pos length"]
    ) -> Float[torch.Tensor, "batch pos length"]:
        # Normalise in at least float32, then cast back to the configured dtype.
        if self.cfg.dtype not in (torch.float32, torch.float64):
            x = x.to(torch.float32)
        root_mean_square = (x.pow(2).mean(-1, keepdim=True) + self.eps).sqrt()
        scale: Float[torch.Tensor, "batch pos 1"] = self.hook_scale(root_mean_square)
        normalized = self.hook_normalized(x / scale)  # [batch, pos, length]
        return normalized.to(self.cfg.dtype)
class RMSNorm(nn.Module):
    def __init__(
        self, cfg: Union[Dict, HookedTransformerConfig], length: Optional[int] = None
    ):
        """
        RMSNorm: LayerNorm without the centering and bias (RMS = Root Mean Square).

        It keeps only a learned elementwise weight ``w``.

        length (Optional[int]): The size of the normalised (last) dimension. If not
            provided, it defaults to d_model.
        """
        super().__init__()
        self.cfg = (
            HookedTransformerConfig.from_dict(cfg) if isinstance(cfg, Dict) else cfg
        )
        self.eps = self.cfg.eps
        self.length = self.cfg.d_model if length is None else length
        self.w = nn.Parameter(torch.ones(self.length, dtype=self.cfg.dtype))
        # Exposes the normalisation scale factor.
        self.hook_scale = HookPoint()  # [batch, pos, 1]
        self.hook_normalized = HookPoint()  # [batch, pos, length]

    def forward(
        self, x: Float[torch.Tensor, "batch pos length"]
    ) -> Float[torch.Tensor, "batch pos length"]:
        # Normalise in at least float32, then cast back before applying the weight.
        if self.cfg.dtype not in (torch.float32, torch.float64):
            x = x.to(torch.float32)
        root_mean_square = (x.pow(2).mean(-1, keepdim=True) + self.eps).sqrt()
        scale: Float[torch.Tensor, "batch pos 1"] = self.hook_scale(root_mean_square)
        normalized = self.hook_normalized(x / scale).to(self.cfg.dtype)  # [batch, pos, length]
        return normalized * self.w
# Attention
class Attention(nn.Module):
def __init__(
    self,
    cfg: Union[Dict, HookedTransformerConfig],
    attn_type: str = "global",
    layer_id: Optional[int] = None,
):
    """Attention Block.

    Params have shape [head_index, d_model, d_head] (or [head_index, d_head, d_model]
    for W_O) and multiply on the right. ``attn_scores`` refers to the query-key dot
    product immediately before the attention softmax.

    Convention: all attention pattern-style matrices have shape
    [batch, head_index, query_pos, key_pos].

    Args:
        cfg (Union[Dict, HookedTransformerConfig]): Config.
        attn_type (str, optional): "global" or "local", used by GPT-Neo. With local
            attention the model can only attend back cfg.window_size tokens (256 for
            GPT-Neo). No other model uses it at the moment. Defaults to "global".
        layer_id (int, optional): The index of the current layer. Used by the Mistral
            models (labelled here as stanford-gpt2) to scale down attention scores
            pre-softmax by 1/(layer_id+1), for numerical stability. Required when
            cfg.scale_attn_by_inverse_layer_idx is True. Defaults to None.

    Raises:
        ValueError: If attn_type is neither "global" nor "local", or if
            cfg.scale_attn_by_inverse_layer_idx is True but layer_id is None.
    """
    super().__init__()
    if isinstance(cfg, Dict):
        cfg = HookedTransformerConfig.from_dict(cfg)
    self.cfg = cfg
    n_heads, d_model, d_head = self.cfg.n_heads, self.cfg.d_model, self.cfg.d_head
    self.W_Q = nn.Parameter(torch.empty(n_heads, d_model, d_head, dtype=cfg.dtype))
    self.W_K = nn.Parameter(torch.empty(n_heads, d_model, d_head, dtype=cfg.dtype))
    self.W_V = nn.Parameter(torch.empty(n_heads, d_model, d_head, dtype=cfg.dtype))
    self.W_O = nn.Parameter(torch.empty(n_heads, d_head, d_model, dtype=cfg.dtype))
    self.b_Q = nn.Parameter(torch.zeros(n_heads, d_head, dtype=cfg.dtype))
    self.b_K = nn.Parameter(torch.zeros(n_heads, d_head, dtype=cfg.dtype))
    self.b_V = nn.Parameter(torch.zeros(n_heads, d_head, dtype=cfg.dtype))
    self.b_O = nn.Parameter(torch.zeros(d_model, dtype=cfg.dtype))

    self.attn_type = attn_type
    # Build an [n_ctx, n_ctx] mask that is True iff the query position (first axis)
    # may attend to the key position (second axis).
    causal_mask = torch.tril(torch.ones((self.cfg.n_ctx, self.cfg.n_ctx)).bool())
    if self.attn_type == "global":
        # Global attention: lower triangular, key <= query.
        self.register_buffer("mask", causal_mask)
    elif self.attn_type == "local":
        # Local attention: banded, query - window_size < key <= query.
        assert isinstance(self.cfg.window_size, int)
        self.register_buffer(
            "mask", torch.triu(causal_mask, 1 - self.cfg.window_size)
        )
    else:
        raise ValueError(f"Invalid attention type: {self.attn_type}")
    # Score used for masked positions. -inf gives them zero weight after the softmax.
    self.register_buffer("IGNORE", torch.tensor(-torch.inf))

    self.layer_id = layer_id
    # Attention scores are divided by attn_scale before the softmax. Scaling by
    # sqrt(d_head) keeps the dot products at roughly unit variance, so the softmax
    # does not saturate.
    if self.cfg.use_attn_scale:
        self.attn_scale = np.sqrt(self.cfg.d_head)
    else:
        self.attn_scale = 1.0
    if self.cfg.scale_attn_by_inverse_layer_idx:
        if self.layer_id is None:
            raise ValueError(
                "layer_id must be provided when cfg.scale_attn_by_inverse_layer_idx is True"
            )
        self.attn_scale *= self.layer_id + 1

    self.hook_k = HookPoint()  # [batch, pos, head_index, d_head]
    self.hook_q = HookPoint()  # [batch, pos, head_index, d_head]
    self.hook_v = HookPoint()  # [batch, pos, head_index, d_head]
    self.hook_z = HookPoint()  # [batch, pos, head_index, d_head]
    self.hook_attn_scores = HookPoint()  # [batch, head_index, query_pos, key_pos]
    self.hook_pattern = HookPoint()  # [batch, head_index, query_pos, key_pos]
    self.hook_result = HookPoint()  # [batch, pos, head_index, d_model]

    # See HookedTransformerConfig for more details.
    if self.cfg.positional_embedding_type == "shortformer":
        # Tracks the input to the keys and queries, which is resid_pre + pos_embeds.
        self.hook_attn_input = HookPoint()  # [batch, pos, d_model]
    elif self.cfg.positional_embedding_type == "rotary":
        # Rotary embeddings rotate each two-element chunk of the keys and queries
        # before the dot product, which bakes in relative position. See
        # HookedTransformerConfig for details.
        self.hook_rot_k = HookPoint()
        self.hook_rot_q = HookPoint()
        sin, cos = self.calculate_sin_cos_rotary(
            self.cfg.rotary_dim, self.cfg.n_ctx, dtype=self.cfg.dtype
        )
        self.register_buffer("rotary_sin", sin)
        self.register_buffer("rotary_cos", cos)
    elif self.cfg.positional_embedding_type == "alibi":
        # The ALiBi bias is built lazily on the first forward pass. Building it up
        # front for the maximum n_ctx would cost a lot of memory: e.g. a float32 bias
        # of shape (16, 1024, 1024) is ~256MiB of contiguous GPU memory.
        self.alibi = None
@property
def OV(self) -> FactoredMatrix:
    """
    OV-Circuit, as defined in A Mathematical Framework.

    No non-linearity sits between the value vector and the layer output, so for a
    single head the output is pattern @ residual @ W_V @ W_O. It depends only on the
    product W_OV = W_V @ W_O, not on W_V or W_O individually (see the glossary for
    more). The order is W_V then W_O because the paper left-multiplies weight
    matrices while TransformerLens right-multiplies.

    Returns a FactoredMatrix with left factor W_V [head_index, d_model, d_head] and
    right factor W_O [head_index, d_head, d_model]. This is a low-rank factorisation
    of the underlying [head_index, d_model, d_model] matrix. FactoredMatrix has
    helpers for working with such large matrices efficiently. ``attn.OV[k]`` gives
    the OV circuit of head k.
    """
    value_side, output_side = self.W_V, self.W_O
    return FactoredMatrix(value_side, output_side)
@property
def QK(self) -> FactoredMatrix:
    """
    QK-Circuit, as defined in A Mathematical Framework.

    The key-query dot product has no non-linearity, so for a single head
    pattern = destination_residual.T @ W_Q.T @ W_K @ source_residual. It depends only
    on W_QK = W_Q.T @ W_K, not on W_Q or W_K individually (see the glossary for
    more). Q goes on the left and K on the right because the pattern has dimensions
    [destination_pos, source_pos].

    Returns a FactoredMatrix with left factor W_Q [head_index, d_model, d_head] and
    right factor W_K transposed [head_index, d_head, d_model]. This is a low-rank
    factorisation of the underlying [head_index, d_model, d_model] matrix.
    ``attn.QK[k]`` gives the QK circuit of head k.
    """
    # Swap the last two axes: [head_index, d_model, d_head] -> [head_index, d_head, d_model]
    key_side = self.W_K.transpose(-2, -1)
    return FactoredMatrix(self.W_Q, key_side)
def forward(
    self,
    query_input: Union[
        Float[torch.Tensor, "batch pos d_model"],
        Float[torch.Tensor, "batch pos head_index d_model"],
    ],
    key_input: Union[
        Float[torch.Tensor, "batch pos d_model"],
        Float[torch.Tensor, "batch pos head_index d_model"],
    ],
    value_input: Union[
        Float[torch.Tensor, "batch pos d_model"],
        Float[torch.Tensor, "batch pos head_index d_model"],
    ],
    past_kv_cache_entry: Optional[HookedTransformerKeyValueCacheEntry] = None,
    additive_attention_mask: Optional[Float[torch.Tensor, "batch 1 1 pos"]] = None,
    attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None,
) -> Float[torch.Tensor, "batch pos d_model"]:
    """Run multi-head attention and return the summed head outputs plus b_O.

    Args:
        query_input, key_input, value_input: Inputs to the Q, K and V projections.
            Each is [batch, pos, d_model], or [batch, pos, head_index, d_model] (one
            input per head) when cfg.use_split_qkv_input or cfg.use_attn_in is set.
        past_kv_cache_entry: Optional entry of past keys and values for this layer.
            Only relevant when generating text. The new keys and values are appended
            to it. Defaults to None.
        additive_attention_mask: Optional mask added to the attention scores, after
            the causal mask and before the softmax. Defaults to None.
        attention_mask: Optional attention mask for padded tokens, covering both the
            cached and the current positions. Defaults to None.

    Returns:
        The attention output, [batch, pos, d_model].
    """
    if self.cfg.use_split_qkv_input or self.cfg.use_attn_in:
        qkv_einops_string = "batch pos head_index d_model"
    else:
        qkv_einops_string = "batch pos d_model"
    qkv_projection = (
        f"{qkv_einops_string}, head_index d_model d_head "
        "-> batch pos head_index d_head"
    )

    q = self.hook_q(
        einsum(qkv_projection, query_input, self.W_Q) + self.b_Q
    )  # [batch, pos, head_index, d_head]
    k = self.hook_k(
        einsum(qkv_projection, key_input, self.W_K) + self.b_K
    )  # [batch, pos, head_index, d_head]
    v = self.hook_v(
        einsum(qkv_projection, value_input, self.W_V) + self.b_V
    )  # [batch, pos, head_index, d_head]

    if past_kv_cache_entry is not None:
        # Append the new keys and values to the cached ones. This also updates the cache.
        kv_cache_pos_offset = past_kv_cache_entry.past_keys.size(1)
        k, v = past_kv_cache_entry.append(k, v)
    else:
        # Not using a cache.
        kv_cache_pos_offset = 0

    if self.cfg.positional_embedding_type == "rotary":
        q = self.hook_rot_q(self.apply_rotary(q, kv_cache_pos_offset, attention_mask))
        # The keys already include the cached positions, so they are rotated from position 0.
        k = self.hook_rot_k(self.apply_rotary(k, 0, attention_mask))

    if self.cfg.dtype not in [torch.float32, torch.float64]:
        # With 16-bit weights, compute the scores in float32 to avoid numerical instability.
        q = q.to(torch.float32)
        k = k.to(torch.float32)

    attn_scores = (
        einsum(
            "batch query_pos head_index d_head, "
            "batch key_pos head_index d_head "
            "-> batch head_index query_pos key_pos",
            q,
            k,
        )
        / self.attn_scale
    )  # [batch, head_index, query_pos, key_pos]

    if self.cfg.positional_embedding_type == "alibi":
        query_ctx = attn_scores.size(-2)
        # The key context length includes every cached position.
        key_ctx = attn_scores.size(-1)

        # Rebuild the bias only when the current key context outgrows it.
        if self.alibi is None or key_ctx > self.alibi.size(-1):
            self.alibi = Attention.create_alibi_bias(
                self.cfg.n_heads, key_ctx, self.cfg.device
            )

        # The current queries are the LAST query_ctx of the key_ctx positions (earlier
        # positions come from the KV cache), so take the matching bias rows. Without a
        # cache this is rows 0..query_ctx-1, as before. Assumes the cached bias is
        # square [head_index, n, n] with n >= key_ctx.
        attn_scores += self.alibi[
            :, key_ctx - query_ctx : key_ctx, :key_ctx
        ]  # [batch, head_index, query_pos, key_pos]

    if self.cfg.attention_dir == "causal":
        # Causal attention is masked so it only attends backwards. Bidirectional attention is not masked.
        attn_scores = self.apply_causal_mask(
            attn_scores, kv_cache_pos_offset, attention_mask
        )  # [batch, head_index, query_pos, key_pos]
    if additive_attention_mask is not None:
        attn_scores += additive_attention_mask

    attn_scores = self.hook_attn_scores(attn_scores)
    pattern = F.softmax(attn_scores, dim=-1)
    # A row whose scores are all masked to -inf gives NaN after the softmax. Zero those rows.
    pattern = torch.where(torch.isnan(pattern), torch.zeros_like(pattern), pattern)
    pattern = self.hook_pattern(pattern)  # [batch, head_index, query_pos, key_pos]
    pattern = pattern.to(self.cfg.dtype)

    z = self.hook_z(
        einsum(
            "batch key_pos head_index d_head, "
            "batch head_index query_pos key_pos -> "
            "batch query_pos head_index d_head",
            v,
            pattern,
        )
    )  # [batch, pos, head_index, d_head]

    if not self.cfg.use_attn_result:
        out = (
            einsum(
                "batch pos head_index d_head, "
                "head_index d_head d_model -> "
                "batch pos d_model",
                z,
                self.W_O,
            )
            + self.b_O
        )  # [batch, pos, d_model]
    else:
        # Compute each head's result explicitly so a hook can read it. This is off by
        # default because it can easily use up GPU memory.
        result = self.hook_result(
            einsum(
                "batch pos head_index d_head, "
                "head_index d_head d_model -> "
                "batch pos head_index d_model",
                z,
                self.W_O,
            )
        )  # [batch, pos, head_index, d_model]
        out = (
            einops.reduce(
                result, "batch position index model->batch position model", "sum"
            )
            + self.b_O
        )  # [batch, pos, d_model]
    return out
def apply_causal_mask(
    self,
    attn_scores: Float[
        torch.Tensor, "batch head_index pos pos_plus_past_kv_pos_offset"
    ],
    past_kv_pos_offset: int = 0,
    attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None,
):
    """Replace disallowed attention scores with ``self.IGNORE`` (-inf).

    A score survives only where ``self.mask`` (causal or banded-local) allows it. If
    ``attention_mask`` is given, the key position must also be a non-pad token.

    Args:
        attn_scores: [batch, head_index, query_pos, key_pos] scores, where
            key_pos == query_pos + past_kv_pos_offset.
        past_kv_pos_offset: Number of cached key positions before the current queries.
        attention_mask: Optional [batch, key_pos] padding mask. Nonzero marks a real token.

    Returns:
        Scores with the same shape, with masked entries set to -inf.
    """
    # Without a KV cache, the query context is just the current prompt length. With a
    # cache, it can be shorter than the key context.
    query_ctx_length = attn_scores.size(-2)
    # The key context includes every cached position. Without caching it equals query_ctx_length.
    key_ctx_length = attn_scores.size(-1)

    assert (
        query_ctx_length + past_kv_pos_offset == key_ctx_length
    ), f"query_ctx_length {query_ctx_length} + past_kv_pos_offset {past_kv_pos_offset} != key_ctx_length {key_ctx_length} - you likely have a bug."

    # Index from the end so that the local (banded) mask lines up correctly.
    final_mask = self.mask[
        None, None, -query_ctx_length:, -key_ctx_length:
    ]  # [1, 1, pos, pos]

    if attention_mask is not None:
        # Also hide padded key positions. The truth table is the same as multiplying
        # the masks and casting to bool. The padding mask is moved to this layer's
        # device so the combined mask stays on the same device as attn_scores.
        key_is_real = attention_mask.to(final_mask.device).bool()[
            :, None, None, :
        ]  # [batch, 1, 1, offset_pos]
        final_mask = final_mask & key_is_real  # [batch, 1, pos, offset_pos]

    return torch.where(final_mask, attn_scores, self.IGNORE)
def calculate_sin_cos_rotary(
    self,
    rotary_dim: int,
    n_ctx: int,
    base: int = 10000,
    dtype: torch.dtype = torch.float32,
) -> Tuple[
    Float[torch.Tensor, "n_ctx rotary_dim"], Float[torch.Tensor, "n_ctx rotary_dim"]
]:
    """
    Compute the sine and cosine tables for rotary embeddings.

    See https://blog.eleuther.ai/rotary-embeddings/ for details.

    The layout depends on the architecture:
    - GPT-J rotates ADJACENT pairs of elements in q and k.
    - GPT-NeoX and LLaMA rotate element i together with element i + rotary_dim // 2,
      i.e. they fold the vector in half.

    The two layouts are equivalent up to a fixed permutation of the rotary dimensions,
    but the tables must match the layout the checkpoint uses. GPT-J's layout is the
    default, and the NeoX/LLaMA layout is chosen from cfg.original_architecture.
    The angles are computed in float32 (float64 if requested) and then cast to dtype.
    """
    compute_dtype = torch.float64 if dtype == torch.float64 else torch.float32
    positions = torch.arange(n_ctx, dtype=compute_dtype)
    pair_index = torch.arange(rotary_dim // 2, dtype=compute_dtype)
    # Frequency divisors, evenly spaced in log space.
    divisors = base ** (pair_index / (rotary_dim / 2))

    folds_in_half = self.cfg.original_architecture in [
        "GPTNeoXForCausalLM",
        "LlamaForCausalLM",
    ]
    repeat_pattern = "d -> (2 d)" if folds_in_half else "d -> (d 2)"
    divisors = einops.repeat(divisors, repeat_pattern)

    # [n_ctx, rotary_dim]: each column is an arithmetic sequence of angles at its frequency.
    angles = positions.unsqueeze(1) / divisors.unsqueeze(0)
    return torch.sin(angles).to(dtype), torch.cos(angles).to(dtype)
def rotate_every_two(
    self, x: Float[torch.Tensor, "... rotary_dim"]
) -> Float[torch.Tensor, "... rotary_dim"]:
    """
    Rotary helper: rotate each pair of elements along the last axis by 90 degrees.

    Each pair [a, b] becomes [-b, a]. GPT-J pairs adjacent elements (2i, 2i+1), while
    GPT-NeoX and LLaMA pair element i with i + rotary_dim // 2. See
    calculate_sin_cos_rotary for details. The last axis must have even length.
    Returns a new tensor and leaves x unchanged.
    """
    rotated = x.clone()
    if self.cfg.original_architecture in ["GPTNeoXForCausalLM", "LlamaForCausalLM"]:
        half = x.size(-1) // 2
        first_half, second_half = x[..., :half], x[..., half:]
        rotated[..., :half] = -second_half
        rotated[..., half:] = first_half
    else:
        even_slots, odd_slots = x[..., ::2], x[..., 1::2]
        rotated[..., ::2] = -odd_slots
        rotated[..., 1::2] = even_slots
    return rotated
def apply_rotary(
    self,
    x: Float[torch.Tensor, "batch pos head_index d_head"],
    past_kv_pos_offset=0,
    attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None,
) -> Float[torch.Tensor, "batch pos head_index d_head"]:
    """Apply the rotary position rotation to the first cfg.rotary_dim features of x.

    Only the first rotary_dim dimensions are rotated. For example, with rotary_dim=64
    and d_head=256, only the first quarter is rotated and the rest passes through.
    Without an attention_mask, positions past_kv_pos_offset .. past_kv_pos_offset+pos-1
    are used for every sequence. With a mask, per-sequence positions come from
    get_offset_position_ids.
    """
    rotary_dim = self.cfg.rotary_dim
    seq_len = x.size(1)
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    x_flip = self.rotate_every_two(x_rot)

    if attention_mask is not None:
        position_ids = get_offset_position_ids(past_kv_pos_offset, attention_mask)
        cos_table = self.rotary_cos[position_ids, None, :]  # [batch, pos, 1, rotary_dim]
        sin_table = self.rotary_sin[position_ids, None, :]
    else:
        window = slice(past_kv_pos_offset, past_kv_pos_offset + seq_len)
        cos_table = self.rotary_cos[None, window, None, :]  # [1, pos, 1, rotary_dim]
        sin_table = self.rotary_sin[None, window, None, :]

    x_rotated = x_rot * cos_table + x_flip * sin_table
    return torch.cat([x_rotated, x_pass], dim=-1)
@staticmethod
def create_alibi_slope(
    n_ctx: int, device: torch.device = None
) -> Float[torch.Tensor, "query key"]:
    """Create an ALiBi Slope Matrix.

    Build the (query, key) distance matrix used by ALiBi, before it is scaled by the
    head-specific multiplier. See :meth:`create_alibi_bias` for the full ALiBi bias calculation.

    Examples:
        >>> Attention.create_alibi_slope(3)
        tensor([[ 0.,  0.,  0.],
                [-1.,  0.,  0.],
                [-2., -1.,  0.]])

        >>> Attention.create_alibi_slope(4)
        tensor([[ 0.,  0.,  0.,  0.],
                [-1.,  0.,  0.,  0.],
                [-2., -1.,  0.,  0.],
                [-3., -2., -1.,  0.]])

    Args:
        n_ctx: The maximum number of tokens in a prompt.
        device: The device to create the tensor on.

    Returns:
        A float32 tensor of shape (n_ctx, n_ctx) whose entry [q, k] is min(k - q, 0): zero on and
        above the diagonal, decreasing by 1 per step towards the bottom-left corner.
    """
    positions = torch.arange(n_ctx, device=device)
    # offsets[q, k] = k - q: zero on the diagonal, negative below it, positive above it.
    offsets = positions.unsqueeze(0) - positions.unsqueeze(1)
    # Keep the lower triangle (diagonal included); the strictly-upper (positive) part becomes 0.
    return torch.tril(offsets).to(torch.float32)
@staticmethod
def create_alibi_multipliers(
    n_heads: int, device: Optional[torch.device] = None
) -> Float[torch.Tensor, "head_idx"]:
    """Create the ALiBi Scalar Multipliers for each Head.

    When n is a power of 2, the set of multipliers (m) is the geometric sequence that starts at
    2^(-8/n) and uses that same value as its ratio. For example, with 8 heads the values would be
    [1/(2^1), 1/(2^2), ... , 1/(2^8)]. With 16 heads the values would be [1/(2^0.5), 1/(2^1), ... ,
    1/(2^8)].

    When n is not a power of 2, this follows the reference ALiBi implementation (the same scheme
    Hugging Face uses for BLOOM). Let c be the largest power of 2 below n. The multipliers are the
    c values of the power-of-2 sequence for c, followed by the 1st, 3rd, 5th, ... values of the
    sequence for 2c until there are n values. For example, with 12 heads they are
    [1/(2^1), ..., 1/(2^8)] followed by [1/(2^0.5), 1/(2^1.5), 1/(2^2.5), 1/(2^3.5)].

    See :meth:`create_alibi_bias` for the full ALiBi bias calculation.

    Examples:
        >>> Attention.create_alibi_multipliers(8)
        tensor([0.5000, 0.2500, 0.1250, 0.0625, 0.0312, 0.0156, 0.0078, 0.0039])

        >>> Attention.create_alibi_multipliers(16)
        tensor([0.7071, 0.5000, 0.3536, 0.2500, 0.1768, 0.1250, 0.0884, 0.0625, 0.0442, 0.0312,
                0.0221, 0.0156, 0.0110, 0.0078, 0.0055, 0.0039])

    Args:
        n_heads: The number of heads in a layer. Must be at least 1.
        device: The device to create the tensor on.

    Returns:
        A tensor of shape (n_heads,) containing the scalar multiplier for each head.

    Raises:
        ValueError: If n_heads is less than 1.
    """
    if n_heads < 1:
        raise ValueError(f"n_heads must be at least 1, got {n_heads}")

    def geometric_multipliers(n: int) -> torch.Tensor:
        # start, start^2, ..., start^n with start = 2^(-8/n) (exactly the power-of-2 formula).
        start = 2 ** (-8 / n)
        indices = torch.arange(n, device=device)
        return start * (start**indices)

    # Largest power of 2 that is <= n_heads; equal to n_heads when n_heads is a power of 2.
    closest_power_of_2 = 1 << (int(n_heads).bit_length() - 1)
    multipliers = geometric_multipliers(closest_power_of_2)
    if closest_power_of_2 != n_heads:
        # Interleave in the odd-indexed slopes of the next power of 2, as in the ALiBi reference.
        extra = geometric_multipliers(2 * closest_power_of_2)[0::2]
        multipliers = torch.cat([multipliers, extra[: n_heads - closest_power_of_2]])
    return multipliers
@staticmethod
def create_alibi_bias(
    n_heads: int, n_ctx: int, device: torch.device = None
) -> Float[torch.Tensor, "head_idx query key"]:
    """Create the ALiBi Bias for all Heads.

    Calculate the ALiBi bias (https://arxiv.org/pdf/2108.12409.pdf) for all heads in a layer.

    ALiBi drops the positional encoding of the original transformer and instead adds a bias to
    every attention score. The bias is proportional to the distance between the query and the
    key (so it encourages paying less attention to more distant tokens) and is added to the
    attention scores before the softmax. It is used in models such as Bloom.

    Examples:
        >>> Attention.create_alibi_bias(2, 4, torch.device('cpu'))
        tensor([[[ 0.0000,  0.0000,  0.0000,  0.0000],
                 [-0.0625,  0.0000,  0.0000,  0.0000],
                 [-0.1250, -0.0625,  0.0000,  0.0000],
                 [-0.1875, -0.1250, -0.0625,  0.0000]],
                [[ 0.0000,  0.0000,  0.0000,  0.0000],
                 [-0.0039,  0.0000,  0.0000,  0.0000],
                 [-0.0078, -0.0039,  0.0000,  0.0000],
                 [-0.0117, -0.0078, -0.0039,  0.0000]]])

    Args:
        n_heads: The number of heads in a layer.
        n_ctx: The maximum number of tokens in a prompt.
        device: The device to create the tensor on.

    Returns:
        The ALiBi bias, shape (n_heads, n_ctx, n_ctx), to add to the attention scores before the
        softmax: entry [h, q, k] = multiplier[h] * slope[q, k].
    """
    # Distance matrix shared by all heads: (query, key).
    distances = Attention.create_alibi_slope(n_ctx, device)
    # One scalar per head: (head_idx,).
    head_scales = Attention.create_alibi_multipliers(n_heads, device)
    # Outer product via broadcasting: (head_idx, 1, 1) * (1, query, key).
    return head_scales[:, None, None] * distances[None, :, :]
# MLP Layers
class MLP(nn.Module):
    """Position-wise MLP block: d_model -> d_mlp -> activation -> d_model.

    The activation is selected by cfg.act_fn. For "solu_ln" an extra normalisation (LayerNorm
    when cfg.normalization_type == "LN", otherwise LayerNormPre) is applied to the activation
    output, with hook_mid placed between the activation and that normalisation.
    """

    def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
        super().__init__()
        if isinstance(cfg, Dict):
            cfg = HookedTransformerConfig.from_dict(cfg)
        self.cfg = cfg
        d_model, d_mlp, dtype = cfg.d_model, cfg.d_mlp, cfg.dtype

        # Weights are allocated with torch.empty (uninitialised here - presumably filled in by
        # weight init or loading elsewhere); biases start at zero.
        self.W_in = nn.Parameter(torch.empty(d_model, d_mlp, dtype=dtype))
        self.b_in = nn.Parameter(torch.zeros(d_mlp, dtype=dtype))
        self.W_out = nn.Parameter(torch.empty(d_mlp, d_model, dtype=dtype))
        self.b_out = nn.Parameter(torch.zeros(d_model, dtype=dtype))

        self.hook_pre = HookPoint()  # [batch, pos, d_mlp]
        self.hook_post = HookPoint()  # [batch, pos, d_mlp]

        activation_table = {
            "relu": F.relu,
            "gelu": F.gelu,
            "silu": F.silu,
            "gelu_new": gelu_new,
            "gelu_fast": gelu_fast,
            "solu_ln": solu,
        }
        if self.cfg.act_fn not in activation_table:
            raise ValueError(f"Invalid activation function name: {self.cfg.act_fn}")
        self.act_fn = activation_table[self.cfg.act_fn]

        if self.cfg.act_fn == "solu_ln":
            # Hook taken between activation and layer norm
            self.hook_mid = HookPoint()  # [batch, pos, d_mlp]
            self.ln = (
                LayerNorm(self.cfg, self.cfg.d_mlp)
                if self.cfg.normalization_type == "LN"
                else LayerNormPre(self.cfg)
            )

    def forward(
        self, x: Float[torch.Tensor, "batch pos d_model"]
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        """Project x up to d_mlp, apply the activation (and norm for *_ln), project back down."""
        # Written as einsums for readability; each is a single matmul over the last axis.
        hidden = (
            einsum("batch pos d_model, d_model d_mlp -> batch pos d_mlp", x, self.W_in)
            + self.b_in
        )
        pre_act = self.hook_pre(hidden)  # [batch, pos, d_mlp]
        activated = self.act_fn(pre_act)
        if self.cfg.act_fn.endswith("_ln"):
            activated = self.ln(self.hook_mid(activated))
        post_act = self.hook_post(activated)  # [batch, pos, d_mlp]
        projected = einsum(
            "batch pos d_mlp, d_mlp d_model -> batch pos d_model",
            post_act,
            self.W_out,
        )
        return projected + self.b_out
# TODO: decide whether GatedMLP should be folded into MLP or kept as a separate class.
class GatedMLP(nn.Module):
"""
The equation of a gated MLP:
pre = x @ W_gate
pre_linear = x @ W_in
post = Gelu(pre) * (pre_linear) + b_in
mlp_out = post @ W_out + b_out
In one equation, mlp_out = (Gelu(x @ W_gate) * (x @ W_in) + b_in) @ W_out + b_out
"""
def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
super().__init__()
if isinstance(cfg, Dict):
cfg = HookedTransformerConfig.from_dict(cfg)
self.cfg = cfg
self.W_in = nn.Parameter(
torch.empty(self.cfg.d_model, self.cfg.d_mlp, dtype=cfg.dtype)
)
self.W_gate = nn.Parameter(