In [1]:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GPT2Tokenizer
from fancy_einsum import einsum

In [7]:
model_name = "meta-llama/Llama-3.1-8B"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [4]:
print(model)

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((4096,), eps=1e-05)
    (rotary_

In [6]:
print(model)

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 1024)
    (wpe): Embedding(1024, 1024)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-23): 24 x GPT2Block(
        (ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=3072, nx=1024)
          (c_proj): Conv1D(nf=1024, nx=1024)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=4096, nx=1024)
          (c_proj): Conv1D(nf=1024, nx=4096)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=1024, out_features=50257, bias=False)
)


In [11]:
token_embeds = model.model.embed_tokens.weight
value_vectors = torch.cat(
    [
        model.model.h[layer_idx].mlp.down_proj.weight
        for layer_idx in range(model.config.num_hidden_layers)
    ],
    dim=0,
)

AttributeError: 'LlamaForCausalLM' object has no attribute 'transformer'

In [10]:

seed_token_pos = ["happy", "joy", " happy", " joy", " smile"]
seed_token_neg = [" sad", " angry", " disgust"]

pos_token_id = [tokenizer.encode(tok)[0] for tok in seed_token_pos]
neg_token_id = [tokenizer.encode(tok)[0] for tok in seed_token_neg]

print(pos_token_id)
print(neg_token_id)

pos_embed = token_embeds[pos_token_id].mean(dim=0)
neg_embed = token_embeds[neg_token_id].mean(dim=0)

[34191, 2633, 3772, 8716, 8212]
[6507, 7954, 16234]


In [12]:


def unembed_to_text(vector, model, tokenizer, k=10):
    norm = model.transformer.ln_f
    lm_head = model.lm_head.weight
    dots = einsum("vocab d_model, d_model -> vocab", lm_head, norm(vector))
    top_k = dots.topk(k).indices
    return tokenizer.batch_decode(top_k, skip_special_tokens=True)

In [21]:

k = 20
norm = model.transformer.ln_f

target_vec = pos_embed - neg_embed
dot_prods = einsum("value_vecs d_model, d_model -> value_vecs", norm(value_vectors), target_vec)
top_value_vecs = dot_prods.topk(k=num_tp_vecs).indices
for vec_idx in top_value_vecs:
    print(f"Value vec: Layer {vec_idx // 4096}, index {vec_idx % 4096}")
    print(unembed_to_text(value_vectors[vec_idx], model, tokenizer))

Value vec: Layer 20, index 988
[' secure', ' successful', ' successfully', ' satisfactory', ' optimal', ' improved', ' optim', ' perfected', ' excellent', ' efficient']
Value vec: Layer 16, index 3208
[' peaceful', ' stable', ' satisfactory', ' good', ' trustworthy', ' safe', ' reassured', ' credibility', 'Safe', ' impartial']
Value vec: Layer 13, index 2434
[' positives', ' advant', ' blessed', ' mirac', ' upl', ' pristine', ' bright', ' smiles', ' buoy', ' boon']
Value vec: Layer 17, index 3439
[' Congratulations', 'osponsors', 'Congratulations', ' Honor', 'aug', ' enriched', ' Excellence', ' rewarded', ' Blossom', ' Celebr']
Value vec: Layer 15, index 301
[' collaborations', ' achievements', ' excellence', ' breakthrough', ' Inspired', ' Excellence', ' inspired', ' Aub', ' amazing', ' Citation']
Value vec: Layer 23, index 1438
['hari', ' externalToEVAOnly', 'perm', 'ifty', 'kin', ' Atomic', 'ourced', 'pac', '=-=-=-=-=-=-=-=-', 'allows']
Value vec: Layer 13, index 2488
[' unaffected'

[' Deluxe', ' Sapp', ' Hendricks', ' Vlad', ' Particip', 'cko', 'Particip', ' Rus', ' Haw', 'Advanced']
Value vec: Layer 20, index 1211
['Advertisement', 'advertising', 'Spoiler', 'favorite', 'Advertisements', 'Premium', 'Contribut', 'Cert', 'Offline', 'Lind']
Value vec: Layer 20, index 3784
['SIGN', 'ーティ', '________________________________________________________________', 'SPONSORED', '________________________________', 'RPG', 'Facebook', 'Maps', 'CBC', '¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯¯']
Value vec: Layer 14, index 3573
['kins', ' innocent', ' sanity', 'forth', 'change', ' Turtles', 'Save', 'OSP', 'mb', ' Nile']
Value vec: Layer 23, index 1269
['release', 'pass', 'state', 'tap', 'slot', 'cur', 'travel', 'share', 'battle', 'strength']
Value vec: Layer 20, index 654
['matic', 'bats', 'lins', 'ead', 'hop', ' Cind', 'th', 'question', 'bilt', 'mented']
Value vec: Layer 17, index 2886
[' Sacrament', ' sacrament', ' Christ', ' Holy', ' Incarn', ' Sacred', ' sac', ' baptism', 'Holy', ' Blessed']
Value vec: 

In [22]:


target_vec = neg_embed - pos_embed

dot_prods = einsum("value_vecs d_model, d_model -> value_vecs", norm(value_vectors), target_vec)
top_value_vecs = dot_prods.topk(k=num_tp_vecs).indices

for vec_idx in top_value_vecs:
    print(f"Value vec: Layer {vec_idx // 4096}, index {vec_idx % 4096}")
    print(unembed_to_text(value_vectors[vec_idx], model, tokenizer))

Value vec: Layer 17, index 2953
[' hate', ' hated', ' negativity', ' bad', ' dreaded', ' harmful', ' adversaries', ' enemies', ' harsh', 'enemy']
Value vec: Layer 13, index 3214
[' burdens', ' troubled', ' risks', ' misfortune', ' headache', ' trouble', ' risk', ' nightmare', ' adverse', ' toxicity']
Value vec: Layer 17, index 1591
[' inability', ' failed', ' unable', ' inadequate', ' lack', ' failing', ' lacking', ' failure', ' insufficient', ' fail']
Value vec: Layer 17, index 443
[' burdens', ' worst', ' worse', ' toxic', ' humiliating', ' waste', ' nightmare', ' pests', ' wasting', ' protracted']
Value vec: Layer 16, index 974
[' inappropriately', ' prejud', ' unnecessarily', ' improperly', ' unchecked', ' incorrectly', ' inefficient', ' miscon', ' arrogance', ' excessively']
Value vec: Layer 20, index 1786
[' problems', ' malfunction', ' failure', ' failures', ' damage', ' woes', ' dysfunction', ' trouble', ' injuries', ' damaged']
Value vec: Layer 16, index 1643
[' inability', ' 

[' Shape', ' Gors', ' Beir', ' mater', ' Pigs', ' pree', ' Grande', ' questionnaire', ' Triangle', ' subclass']
Value vec: Layer 22, index 3227
[' march', ' Citizens', ' Greenpeace', ' groups', ' peaceful', ' boycot', ' outraged', ' pro', ' activist', ' protester']
Value vec: Layer 22, index 586
[' P', ' Acc', ' Port', ' B', ' Q', ' Hub', ' Sm', ' L', ' Pic', ' PS']
Value vec: Layer 14, index 1671
['itia', 'iann', ' ethics', ' Carbuncle', 'icative', ' Chao', ' ├', ' Leilan', ' Recomm', 'ettes']
Value vec: Layer 15, index 3244
[' Peb', ' Buildings', ' Implement', ' Presidential', 'OWS', 'アル', ' FW', '三', 'ARB', '�']
Value vec: Layer 16, index 2908
[' limitation', ' barrier', ' prohib', ' limit', ' restrictions', ' limiting', ' imped', ' restrictive', ' restriction', ' barriers']
Value vec: Layer 17, index 287
['avier', 'kHz', ' kHz', '00', ' Opp', ' paced', '999', '�', '4000', ' opposite']
Value vec: Layer 22, index 3047
['itone', 'itar', 'itu', '��', 'itary', 'itarian', 'igun', 'ibi', 