In [1]:
import einops
from fancy_einsum import einsum
from transformer_lens import EasyTransformer
from dataclasses import dataclass
from transformer_lens.hook_points import HookPoint, HookedRootModule

In [2]:
reference_gpt2 = EasyTransformer.from_pretrained("gpt2-small", fold_ln=True, center_unembed=False, center_writing_weights=False)

# Check which device the model is loaded onto
print(f"The model is loaded on: {reference_gpt2.cfg.device}")



Loaded pretrained model gpt2-small into HookedTransformer
The model is loaded on: cuda


In [3]:
reference_gpt2

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0-11): 12 x TransformerBlock(
      (ln1): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (h

In [4]:
attributes = dir(reference_gpt2)
print("All attributes:", attributes)

All attributes: ['OV', 'QK', 'T_destination', 'W_E', 'W_E_pos', 'W_K', 'W_O', 'W_Q', 'W_U', 'W_V', 'W_gate', 'W_in', 'W_out', 'W_pos', '__annotations__', '__call__', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattr__', '__getattribute__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__setstate__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', '_apply', '_backward_hooks', '_backward_pre_hooks', '_buffers', '_call_impl', '_compiled_call_impl', '_enable_hook', '_enable_hook_with_name', '_enable_hooks_for_points', '_forward_hooks', '_forward_hooks_always_called', '_forward_hooks_with_kwargs', '_forward_pre_hooks', '_forward_pre_hooks_with_kwargs', '_get_backward_hooks', '_get_backward_pre_hooks', '_get_name', '_init_weights_gpt2', '_init_weights_kaiming', '_init_weights_muP', '_ini

In [5]:
total_params = sum(p.numel() for p in reference_gpt2.parameters())
print(f"Total number of parameters: {total_params:,}")

Total number of parameters: 163,049,041


In [6]:
print("Model configuration:", reference_gpt2.cfg)
print("Model type:", type(reference_gpt2))

Model configuration: HookedTransformerConfig:
{'act_fn': 'gelu_new',
 'attention_dir': 'causal',
 'attn_only': False,
 'attn_scale': 8.0,
 'attn_scores_soft_cap': -1.0,
 'attn_types': None,
 'checkpoint_index': None,
 'checkpoint_label_type': None,
 'checkpoint_value': None,
 'd_head': 64,
 'd_mlp': 3072,
 'd_model': 768,
 'd_vocab': 50257,
 'd_vocab_out': 50257,
 'decoder_start_token_id': None,
 'default_prepend_bos': True,
 'device': device(type='cuda'),
 'dtype': torch.float32,
 'eps': 1e-05,
 'experts_per_token': None,
 'final_rms': False,
 'from_checkpoint': False,
 'gated_mlp': False,
 'init_mode': 'gpt2',
 'init_weights': False,
 'initializer_range': 0.02886751345948129,
 'load_in_4bit': False,
 'model_name': 'gpt2',
 'n_ctx': 1024,
 'n_devices': 1,
 'n_heads': 12,
 'n_key_value_heads': None,
 'n_layers': 12,
 'n_params': 84934656,
 'normalization_type': 'LNPre',
 'num_experts': None,
 'original_architecture': 'GPT2LMHeadModel',
 'output_logits_soft_cap': -1.0,
 'parallel_attn_m

### Important Info about how Transformers are trained

Transformers perdict the next token for *each* prefix. Makes training more efficient as for every 100 tokens, we get 100 bits of feedback, not just one.

The 99 previous predictions are not trivial, due to *causal attention*.

Dimensions are things that can vary independently. Each input initially (token) can be thought of independently, we don't bake in any relation.

Integers to vectors -> Create a lookup table. Embedding.

Lookup table <=> Multiply a fixed matrix with a one-hot encoded vector.

# Tokens - Transformer Inputs

Model needs to deal with arbitrary text. Convert to integers *in a bounded range* - needed for calculating probability distribution over all tokens.

Idea: Form vocabulary.

**Idea 1** -> Dictionary.

Problem: Can't cope with arbitary text like URLs, mispellings.

**Idea 2** -> Characters. 256 ASCII characters.

Problem: Some sequences of characters are more meaningful than others. We want to use things efficiently, and want things that appear often together to be a single token, and other things should be separate. 

**Idea 3** -> Byte-Pair encoding. 

Ġ - begins with a space.

Basically first tokenize the text with 256 meaningful ASCII characters, then what pair of tokens occur most frequently together? Merge them.

In [14]:
sorted_vocab = sorted(reference_gpt2.tokenizer.vocab.items(), key=lambda n:n[1])
sorted_vocab

[('!', 0),
 ('"', 1),
 ('#', 2),
 ('$', 3),
 ('%', 4),
 ('&', 5),
 ("'", 6),
 ('(', 7),
 (')', 8),
 ('*', 9),
 ('+', 10),
 (',', 11),
 ('-', 12),
 ('.', 13),
 ('/', 14),
 ('0', 15),
 ('1', 16),
 ('2', 17),
 ('3', 18),
 ('4', 19),
 ('5', 20),
 ('6', 21),
 ('7', 22),
 ('8', 23),
 ('9', 24),
 (':', 25),
 (';', 26),
 ('<', 27),
 ('=', 28),
 ('>', 29),
 ('?', 30),
 ('@', 31),
 ('A', 32),
 ('B', 33),
 ('C', 34),
 ('D', 35),
 ('E', 36),
 ('F', 37),
 ('G', 38),
 ('H', 39),
 ('I', 40),
 ('J', 41),
 ('K', 42),
 ('L', 43),
 ('M', 44),
 ('N', 45),
 ('O', 46),
 ('P', 47),
 ('Q', 48),
 ('R', 49),
 ('S', 50),
 ('T', 51),
 ('U', 52),
 ('V', 53),
 ('W', 54),
 ('X', 55),
 ('Y', 56),
 ('Z', 57),
 ('[', 58),
 ('\\', 59),
 (']', 60),
 ('^', 61),
 ('_', 62),
 ('`', 63),
 ('a', 64),
 ('b', 65),
 ('c', 66),
 ('d', 67),
 ('e', 68),
 ('f', 69),
 ('g', 70),
 ('h', 71),
 ('i', 72),
 ('j', 73),
 ('k', 74),
 ('l', 75),
 ('m', 76),
 ('n', 77),
 ('o', 78),
 ('p', 79),
 ('q', 80),
 ('r', 81),
 ('s', 82),
 ('t', 83),
 

In [16]:
sorted_vocab[-20:]

[('Revolution', 50237),
 ('Ġsnipers', 50238),
 ('Ġreverted', 50239),
 ('Ġconglomerate', 50240),
 ('Terry', 50241),
 ('794', 50242),
 ('Ġharsher', 50243),
 ('Ġdesolate', 50244),
 ('ĠHitman', 50245),
 ('Commission', 50246),
 ('Ġ(/', 50247),
 ('âĢ¦."', 50248),
 ('Compar', 50249),
 ('Ġamplification', 50250),
 ('ominated', 50251),
 ('Ġregress', 50252),
 ('ĠCollider', 50253),
 ('Ġinformants', 50254),
 ('Ġgazed', 50255),
 ('<|endoftext|>', 50256)]

In [17]:
print(reference_gpt2.to_str_tokens("Aritra"))
print(reference_gpt2.to_str_tokens(" Aritra"))
print(reference_gpt2.to_str_tokens(" aritra"))
print(reference_gpt2.to_str_tokens("aritra"))

['<|endoftext|>', 'A', 'rit', 'ra']
['<|endoftext|>', ' A', 'rit', 'ra']
['<|endoftext|>', ' ar', 'it', 'ra']
['<|endoftext|>', 'ar', 'it', 'ra']


Arithmetic is fucked due to tokenization: Length is consistent, common numbers bundled together more.

In [18]:
reference_gpt2.to_str_tokens("78265+894758=18927489-2139298479")

['<|endoftext|>',
 '78',
 '265',
 '+',
 '89',
 '475',
 '8',
 '=',
 '189',
 '27',
 '489',
 '-',
 '2',
 '139',
 '298',
 '479']

Conversion of text to numbers

In [19]:
print(reference_gpt2.to_tokens("Whether a word begings with a capital or space matters!"))
print(reference_gpt2.to_tokens("Whether a word begings with a capital or space matters!", prepend_bos=False))

tensor([[50256, 15354,   257,  1573,  4123,   654,   351,   257,  3139,   393,
          2272,  6067,     0]], device='cuda:0')
tensor([[15354,   257,  1573,  4123,   654,   351,   257,  3139,   393,  2272,
          6067,     0]], device='cuda:0')


## Key Takeaways:

We learn a dictionary vocab of tokens (sub-words).

We (apprx.) losslessly convert language to integers by tokenizing it.

We convert language to integers via a lookup table.

Input to the transformer - sequence of tokens (not vectors).

# Logits - Transformer Outputs