In [2]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GPT2Tokenizer
from fancy_einsum import einsum

#### GPT2-medium

In [3]:
model_name = "gpt2-medium"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [4]:
token_embeds = model.transformer.wte.weight
value_vectors = torch.cat(
    [
        model.transformer.h[layer_idx].mlp.c_proj.weight
        for layer_idx in range(model.config.num_hidden_layers)
    ],
    dim=0,
)
print(value_vectors.shape)

torch.Size([98304, 1024])


In [5]:
seed_token_toxic = ["fuck", "shit", "crap"]
seed_token_non_toxic = ["hello", "thanks", "friend", "peace", "welcome"]

toxic_token_id = [
    tokenizer(tok, add_special_tokens=False)["input_ids"][0]
    for tok in seed_token_toxic
    if len(tokenizer(tok, add_special_tokens=False)["input_ids"]) == 1
]

non_toxic_token_id = [
    tokenizer(tok, add_special_tokens=False)["input_ids"][0]
    for tok in seed_token_non_toxic
    if len(tokenizer(tok, add_special_tokens=False)["input_ids"]) == 1
]

print("Toxic token IDs:", toxic_token_id)
print("Non-toxic token IDs:", non_toxic_token_id)

toxic_embed = token_embeds[toxic_token_id].mean(dim=0)
non_toxic_embed = token_embeds[non_toxic_token_id].mean(dim=0)

Toxic token IDs: [31699, 16211]
Non-toxic token IDs: [31373, 27547, 6726, 22988]


In [6]:

def unembed_to_text(vector, model, tokenizer, k=10):
    norm = model.transformer.ln_f
    lm_head = model.lm_head.weight
    dots = einsum("vocab d_model, d_model -> vocab", lm_head, norm(vector))
    top_k = dots.topk(k).indices
    return tokenizer.batch_decode(top_k, skip_special_tokens=True)

In [7]:

k = 20
norm = model.transformer.ln_f

target_vec = toxic_embed - non_toxic_embed
dot_prods = einsum("value_vecs d_model, d_model -> value_vecs", norm(value_vectors), target_vec)
top_value_vecs = dot_prods.topk(k).indices
for vec_idx in top_value_vecs:
    print(f"Value vec: Layer {vec_idx // 4096}, index {vec_idx % 4096}")
    print(unembed_to_text(value_vectors[vec_idx], model, tokenizer))

Value vec: Layer 19, index 770
[' shit', ' ass', ' crap', ' fuck', ' garbage', ' asses', ' cunt', ' trash', ' dick', 'shit']
Value vec: Layer 12, index 882
['fuck', ' shit', ' piss', 'Fuck', ' hilar', 'shit', ' stupidity', ' poop', ' shitty', ' stupid']
Value vec: Layer 15, index 659
[' dudes', ' stuff', ' dude', ' shit', ' kinda', ' fuckin', ' goddamn', ' badass', ' blah', ' pretty']
Value vec: Layer 17, index 2877
[' kinda', ' stuff', ' fuckin', ' guys', ' yeah', ' gonna', ' dudes', ' crap', ' gotta', ' guy']
Value vec: Layer 13, index 4065
[' fuck', ' fucking', ' piss', ' goddamn', ' shit', ' godd', ' damned', ' damn', ' crap', ' shri']
Value vec: Layer 19, index 1767
[' fucking', ' dudes', ' fuckin', ' goddamn', ' shit', ' shitty', ' dude', ' kinda', ' guys', ' gotta']
Value vec: Layer 7, index 3358
[' crap', ' shri', ' shit', ' whine', ' Godd', ' bullshit', ' gigg', ' euphem', ' goddamn', 'uphem']
Value vec: Layer 8, index 1079
[' crap', ' dudes', ' kinda', ' dude', ' crappy', ' g

In [8]:
print(unembed_to_text(target_vec, model, tokenizer))

['shit', 'fuck', 'Fuck', ' fuck', ' shit', ' Fuck', ' Shit', ' fucking', ' FUCK', ' fucked']


In [9]:
torch.save(target_vec, 'gpt2_toxic_embed.pt')

#### Llama3

In [122]:
model_name = "meta-llama/Llama-3.1-8B"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [123]:
token_embeds = model.model.embed_tokens.weight
value_vectors = torch.cat(
    [
        model.model.layers[layer_idx].mlp.down_proj.weight.T
        for layer_idx in range(model.config.num_hidden_layers)
    ],
    dim=0,
)
print(value_vectors.shape)

torch.Size([458752, 4096])


In [132]:
seed_token_toxic = ["fuck", "shit", "crap"]
seed_token_non_toxic = ["hello", "thanks", "friend", "peace", "welcome"]

toxic_token_id = [
    tokenizer(tok, add_special_tokens=False)["input_ids"][0]
    for tok in seed_token_toxic
    if len(tokenizer(tok, add_special_tokens=False)["input_ids"]) == 1
]

non_toxic_token_id = [
    tokenizer(tok, add_special_tokens=False)["input_ids"][0]
    for tok in seed_token_non_toxic
    if len(tokenizer(tok, add_special_tokens=False)["input_ids"]) == 1
]

print("Toxic token IDs:", toxic_token_id)
print("Non-toxic token IDs:", non_toxic_token_id)

toxic_embed = token_embeds[toxic_token_id].mean(dim=0)
non_toxic_embed = token_embeds[non_toxic_token_id].mean(dim=0)


Toxic token IDs: [71574, 41153, 99821]
Non-toxic token IDs: [15339, 46593, 10931, 55225, 35184]


In [133]:
def unembed_to_text(vector, model, tokenizer, k=10):
    norm = model.model.norm  
    lm_head = model.lm_head.weight
    dots = torch.einsum("vd,d->v", lm_head, norm(vector))
    top_k = dots.topk(k).indices
    return tokenizer.batch_decode(top_k, skip_special_tokens=True)


In [134]:
k = 20
norm = model.model.norm  

target_vec = toxic_embed - non_toxic_embed
dot_prods = torch.einsum("nd,d->n", norm(value_vectors), target_vec)
top_value_vecs = dot_prods.topk(k).indices

for vec_idx in top_value_vecs:
    print(f"Value vec: Layer {vec_idx // 4096}, index {vec_idx % 4096}")
    print(unembed_to_text(value_vectors[vec_idx], model, tokenizer))


Value vec: Layer 7, index 1612
['atur', 'odom', 'olang', 'kur', 'ーラ', ' Franken', 'ourmet', 'بو', '-prepend', '-folder']
Value vec: Layer 1, index 3677
['angen', 'rego', 'aji', ' Brun', 'icha', 'ifi', 'andan', '(Gravity', 'strup', '坪']
Value vec: Layer 7, index 1261
['хов', '363', 'entai', '_salt', 'ДА', '場', 'asaki', 'ンガ', 'ledge', 'イク']
Value vec: Layer 74, index 2602
['anj', 'δι', 'aint', 'inflate', 'ource', '�', 'andom', 'VAS', 'ewire', 'bury']
Value vec: Layer 110, index 370
['rego', 'еро', 'cing', 'RATE', 'ames', 'PageRoute', 'riott', 'öt', 'agi', 'AMES']
Value vec: Layer 45, index 604
['ira', 'lesi', 'orse', 'ulg', ' burg', ' current', 'mnt', 'quer', 'igan', ' бух']
Value vec: Layer 34, index 1478
[' tables', 'eah', ' publicity', ' deb', ' Hast', 'ighton', ' liv', ' Min', '幕', ' Governors']
Value vec: Layer 75, index 3246
[' me', ' him', ' them', ' us', 'NTAX', 'them', ' lui', ' Him', ' Dud', ' ihm']
Value vec: Layer 66, index 1808
[' lobby', ' Lobby', ' Rooms', 'hosts', ' hosts

In [135]:
print(unembed_to_text(target_vec, model, tokenizer))

['antal', 'ę', 'irse', 'llll', ' tiener', 'ington', 'obus', ' eldre', ' weg', 'dik']


In [None]:
torch.save(target_vec, '/data/kebl6672/dpo-toxic-general/checkpoints/llama3_toxic_embed.pt')

#### Gemma-2

In [10]:
model_name = "google/gemma-2-2b"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [11]:
token_embeds = model.model.embed_tokens.weight
value_vectors = torch.cat(
    [
        model.model.layers[layer_idx].mlp.down_proj.weight.T
        for layer_idx in range(model.config.num_hidden_layers)
    ],
    dim=0,
)
print(value_vectors.shape)

torch.Size([239616, 2304])


In [12]:
seed_token_toxic = ["fuck", "shit", "crap"]
seed_token_non_toxic = ["hello", "thanks", "friend", "peace", "welcome"]

toxic_token_id = [
    tokenizer(tok, add_special_tokens=False)["input_ids"][0]
    for tok in seed_token_toxic
    if len(tokenizer(tok, add_special_tokens=False)["input_ids"]) == 1
]

non_toxic_token_id = [
    tokenizer(tok, add_special_tokens=False)["input_ids"][0]
    for tok in seed_token_non_toxic
    if len(tokenizer(tok, add_special_tokens=False)["input_ids"]) == 1
]

print("Toxic token IDs:", toxic_token_id)
print("Non-toxic token IDs:", non_toxic_token_id)

toxic_embed = token_embeds[toxic_token_id].mean(dim=0)
non_toxic_embed = token_embeds[non_toxic_token_id].mean(dim=0)


Toxic token IDs: [34024, 31947, 101886]
Non-toxic token IDs: [17534, 12203, 9141, 44209, 28583]


In [13]:
def unembed_to_text(vector, model, tokenizer, k=10):
    norm = model.model.norm  
    lm_head = model.lm_head.weight
    dots = torch.einsum("vd,d->v", lm_head, norm(vector))
    top_k = dots.topk(k).indices
    return tokenizer.batch_decode(top_k, skip_special_tokens=True)


In [14]:
k = 20
norm = model.model.norm  

target_vec = toxic_embed - non_toxic_embed
dot_prods = torch.einsum("nd,d->n", norm(value_vectors), target_vec)
top_value_vecs = dot_prods.topk(k).indices

for vec_idx in top_value_vecs:
    print(f"Value vec: Layer {vec_idx // 2304}, index {vec_idx % 2304}")
    print(unembed_to_text(value_vectors[vec_idx], model, tokenizer))


Value vec: Layer 87, index 1892
['HSSF', 'sptr', ' umge', ' siihen', '例句', ' advoc', 'Computed', ' riten', 'subpackage', 'glieder']
Value vec: Layer 14, index 119
[' shit', ' Shit', 'shit', 'Shit', ' SHIT', ' crap', ' shits', 'Crap', ' shite', ' shitty']
Value vec: Layer 79, index 385
['esModule', 'migrationBuilder', 'celot', ' pinulongan', 'RectangleBorder', 'hoeddwyd', 'oa̍t', 'WireFormatLite', ' fourrure', 'fillType']
Value vec: Layer 79, index 1454
[' dudes', ' dude', ' stuff', ' guys', ' kinda', ' shit', ' guy', ' crap', ' thingy', ' hella']
Value vec: Layer 95, index 195
[' fuck', ' fucks', ' fucking', ' fucked', ' shit', 'fuck', 'Fuck', ' Fucking', 'fucking', 'Fucking']
Value vec: Layer 102, index 143
['convertView', 'NavController', 'ClassNotFound', 'cellulose', ' defaultstate', ' Chuk', ' Vikipedi', 'queryInterface', 'دانشنامهٔ', ' PopupWindow']
Value vec: Layer 76, index 1704
['ValueStyle', 'GenerationType', 'BeginContext', 'InjectAttribute', ' мәкал', 'enumi', 'IntoConstrain

In [15]:
print(unembed_to_text(target_vec, model, tokenizer))

['shit', ' shit', 'fuck', 'Shit', ' SHIT', 'crap', ' Shit', ' fuck', ' crap', ' Fuck']


In [16]:
torch.save(target_vec, 'gemma2_toxic_embed.pt')

#### Mistral

In [17]:
model_name = "mistralai/Mistral-7B-v0.1"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [18]:
token_embeds = model.model.embed_tokens.weight
value_vectors = torch.cat(
    [
        model.model.layers[layer_idx].mlp.down_proj.weight.T
        for layer_idx in range(model.config.num_hidden_layers)
    ],
    dim=0,
)
print(value_vectors.shape)

torch.Size([458752, 4096])


In [19]:
seed_token_toxic = ['shit', ' shit', 'fuck', 'Shit', ' SHIT', 'crap', ' Shit', ' fuck', ' crap', ' Fuck']
seed_token_non_toxic = [
    "hello", "thanks", "please", "friend", "peace", "smile", 
    "happy", "welcome", "kind", "goodbye"
]

toxic_token_id = [
    tokenizer(tok, add_special_tokens=False)["input_ids"][0]
    for tok in seed_token_toxic
    if len(tokenizer(tok, add_special_tokens=False)["input_ids"]) == 1
]

non_toxic_token_id = [
    tokenizer(tok, add_special_tokens=False)["input_ids"][0]
    for tok in seed_token_non_toxic
    if len(tokenizer(tok, add_special_tokens=False)["input_ids"]) == 1
]

print("Toxic token IDs:", toxic_token_id)
print("Non-toxic token IDs:", non_toxic_token_id)

toxic_embed = token_embeds[toxic_token_id].mean(dim=0)
non_toxic_embed = token_embeds[non_toxic_token_id].mean(dim=0)


Toxic token IDs: [5492, 5492, 4159, 21849, 4159, 21849, 19449]
Non-toxic token IDs: [8196, 4665, 1832, 6405, 6458, 4610, 10058, 2112]


In [20]:
def unembed_to_text(vector, model, tokenizer, k=10):
    norm = model.model.norm  
    lm_head = model.lm_head.weight
    dots = torch.einsum("vd,d->v", lm_head, norm(vector))
    top_k = dots.topk(k).indices
    return tokenizer.batch_decode(top_k, skip_special_tokens=True)


In [21]:
k = 20
norm = model.model.norm  

target_vec = toxic_embed - non_toxic_embed
dot_prods = torch.einsum("nd,d->n", norm(value_vectors), target_vec)
top_value_vecs = dot_prods.topk(k).indices

for vec_idx in top_value_vecs:
    print(f"Value vec: Layer {vec_idx // 4096}, index {vec_idx % 4096}")
    print(unembed_to_text(value_vectors[vec_idx], model, tokenizer))


Value vec: Layer 0, index 1046
['iska', 'Leon', 'Pir', 'Kre', 'uper', 'TA', 'Aut', 'Cast', 'own', 'im']
Value vec: Layer 2, index 84
['üng', 'Ship', 'ơ', 'ớ', 'redu', 'esch', 'uffer', 'scar', 'zeichnet', 'º']
Value vec: Layer 9, index 3295
['bounds', 'season', 'acc', 'Box', 'Box', 'lengths', 'ala', 'tol', 'esse', 'Morning']
Value vec: Layer 9, index 3846
['dern', 'igin', '#!', 'citiz', 'ulin', 'zung', 'stadt', 'osi', 'urm', 'olu']
Value vec: Layer 0, index 3499
['stick', 'Perry', 'tack', 'Peru', '�', 'iere', 'жно', 'ië', 'icus', 'edia']
Value vec: Layer 7, index 3295
['istan', 'aku', 'signal', 'label', 'meantime', 'cia', 'station', 'Array', '야', 'Ces']
Value vec: Layer 3, index 1257
['iera', 'même', 'mutable', 'iesa', 'ele', 'Tube', 'havet', 'chter', 'Lad', 'nero']
Value vec: Layer 10, index 1287
['iw', 'Tow', 'slide', 'Reserve', 'ಠ', 'reserv', 'ety', 'PL', 'SC', 'Ki']
Value vec: Layer 4, index 1351
['pte', 'pha', 'wig', 'mina', 'phrase', 'isson', 'astr', 'controvers', 'thumb', 'aml']


In [22]:
print(unembed_to_text(target_vec, model, tokenizer))

['ery', 'azar', 'eda', 'tere', 'oren', 'Roose', 'izen', 'ntil', 'ment', 'asso']


In [None]:
torch.save(target_vec, 'mistral_toxic_embed.pt')