In [1]:
# Enable IPython's autoreload extension in mode 2 so edits to imported .py
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
def fix_untrained_tokens(model, eps=1e-16):
    """
    Reset untrained input-embedding rows of ``model`` to the mean trained row.

    Llama-3 for eg has untrained vectors in the base model.
    These include <|eot_id|>, <|start_header_id|>, <|end_header_id|>
    We reset them to the mean of the rest of the tokens.

    A row is treated as untrained when every one of its entries lies in
    ``[-eps, eps]``.

    Args:
        model: Object whose ``get_input_embeddings().weight.data`` is a 2-D
            tensor with one row per token.
        eps: Magnitude threshold below which an entry counts as zero.

    Returns:
        The mean of the trained rows, cast to the embedding matrix's dtype.
        This is the vector written into every untrained row.

    Raises:
        ValueError: If no row is trained, so there is nothing to average.

    Note:
        The embedding matrix is modified in place.
    """
    embedding_matrix = model.get_input_embeddings().weight.data

    # Get untrained tokens. Both the row maximum and the row minimum must be
    # within eps of zero: checking the maximum alone would also flag trained
    # rows whose entries all happen to be negative.
    row_max = torch.amax(embedding_matrix, dim=1)
    row_min = torch.amin(embedding_matrix, dim=1)
    indicator_untrained = (row_max <= eps) & (row_min >= -eps)
    where_untrained = torch.where(indicator_untrained)[0]
    n_untrained = where_untrained.shape[0]
    n_trained = embedding_matrix.shape[0] - n_untrained

    # Bail out before touching the weights: dividing by zero below would
    # otherwise write NaN/inf into every row.
    if n_trained == 0:
        raise ValueError(
            "fix_untrained_tokens: no trained embedding rows found; cannot compute a mean"
        )

    # First set untrained to all 0s - sometimes it's not! 1e-23 for bfloat16
    embedding_matrix[where_untrained] = 0

    # Find sum (accumulated in float32; untrained rows now contribute exactly 0)
    sum_embedding = torch.sum(embedding_matrix, dtype=torch.float32, dim=0)

    # Find correct average by dividing by the number of trained tokens
    mean_embedding = (sum_embedding / n_trained).to(embedding_matrix.dtype)
    # Set them to the mean
    embedding_matrix[where_untrained] = mean_embedding

    return mean_embedding

In [3]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Checkpoint directories keyed by a short tag:
#   "b" -> Llama-3.1-8B, "i" -> Llama-3.1-8B-Instruct, "o" -> DeepSeek-R1-Distill-Llama-8B
# Kept under its own name so `models` only ever refers to loaded model objects.
model_paths = {
    "b": "/inspire/hdd/ws-8207e9e2-e733-4eec-a475-cfa1c36480ba/embodied-multimodality/public/zfhe/jiaxing_projects/Language-Model-SAEs/exp/diff_models/Llama-3.1-8B",
    "i": "/inspire/hdd/ws-8207e9e2-e733-4eec-a475-cfa1c36480ba/embodied-multimodality/public/zfhe/jiaxing_projects/Language-Model-SAEs/exp/diff_models/Llama-3.1-8B-Instruct",
    "o": "/inspire/hdd/ws-8207e9e2-e733-4eec-a475-cfa1c36480ba/embodied-multimodality/public/zfhe/jiaxing_projects/Language-Model-SAEs/exp/diff_models/DeepSeek-R1-Distill-Llama-8B",
}
tokenizers = {tag: AutoTokenizer.from_pretrained(path) for tag, path in model_paths.items()}
# NOTE(review): no dtype argument is passed to from_pretrained; confirm the
# default load dtype is the one you want written back by the later
# save_pretrained calls.
models = {tag: AutoModelForCausalLM.from_pretrained(path) for tag, path in model_paths.items()}
models

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

{'b': LlamaForCausalLM(
   (model): LlamaModel(
     (embed_tokens): Embedding(128256, 4096)
     (layers): ModuleList(
       (0-31): 32 x LlamaDecoderLayer(
         (self_attn): LlamaSdpaAttention(
           (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
           (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
           (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
           (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
           (rotary_emb): LlamaRotaryEmbedding()
         )
         (mlp): LlamaMLP(
           (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
           (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
           (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
           (act_fn): SiLU()
         )
         (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
         (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e

In [4]:
# Patch the base ("b") model's untrained embedding rows in place.
# The mean embedding returned by the call is the cell's displayed output.
bmodel = models["b"]
base_mean_embedding = fix_untrained_tokens(bmodel)
base_mean_embedding
# embeddings = bmodel.get_input_embeddings().weight.data
# mean_embedding = embeddings[:128000].mean(dim=0)

tensor([-0.0024, -0.0040,  0.0021,  ...,  0.0003, -0.0012,  0.0008])

In [5]:
# Output directory for each patched checkpoint, keyed by "b" / "i" / "o".
# These are the same directories the checkpoints were loaded from, so saving
# overwrites them in place.
model_names = dict(
    b="/inspire/hdd/ws-8207e9e2-e733-4eec-a475-cfa1c36480ba/embodied-multimodality/public/zfhe/jiaxing_projects/Language-Model-SAEs/exp/diff_models/Llama-3.1-8B",
    i="/inspire/hdd/ws-8207e9e2-e733-4eec-a475-cfa1c36480ba/embodied-multimodality/public/zfhe/jiaxing_projects/Language-Model-SAEs/exp/diff_models/Llama-3.1-8B-Instruct",
    o="/inspire/hdd/ws-8207e9e2-e733-4eec-a475-cfa1c36480ba/embodied-multimodality/public/zfhe/jiaxing_projects/Language-Model-SAEs/exp/diff_models/DeepSeek-R1-Distill-Llama-8B",
)
# Write the patched base model back to its directory.
base_output_dir = model_names["b"]
bmodel.save_pretrained(base_output_dir)

In [21]:
# bnowtokenizer = AutoTokenizer.from_pretrained("/inspire/hdd/ws-8207e9e2-e733-4eec-a475-cfa1c36480ba/embodied-multimodality/public/zfhe/jiaxing_projects/Language-Model-SAEs/exp/diff_models/Llama-3.1-8B")
# bnowmodel = AutoModelForCausalLM.from_pretrained("/inspire/hdd/ws-8207e9e2-e733-4eec-a475-cfa1c36480ba/embodied-multimodality/public/zfhe/jiaxing_projects/Language-Model-SAEs/exp/diff_models/Llama-3.1-8B")
# bnowembeddings = bnowmodel.get_input_embeddings().weight.data
# for token_id in range(128016,128022):
#     print(bnowembeddings[token_id] == mean_embedding)

Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

tensor([True, True, True,  ..., True, True, True])
tensor([True, True, True,  ..., True, True, True])
tensor([True, True, True,  ..., True, True, True])
tensor([True, True, True,  ..., True, True, True])
tensor([True, True, True,  ..., True, True, True])
tensor([True, True, True,  ..., True, True, True])


In [25]:
# Show which token strings the base tokenizer assigns to ids 128016..128021.
inspected_token_ids = range(128016, 128022)
print(tokenizers["b"].convert_ids_to_tokens(inspected_token_ids))

['<|q_start|>', '<|q_end|>', '<|a_start|>', '<|a_end|> ', '<think>', '</think>']


In [6]:
# Patch the instruct ("i") model's untrained embedding rows in place.
# The mean embedding returned by the call is the cell's displayed output.
model = models["i"]
instruct_mean_embedding = fix_untrained_tokens(model)
instruct_mean_embedding
# embeddings = model.get_input_embeddings().weight.data
# mean_embedding = embeddings[:128000].mean(dim=0)
# mean_embedding

tensor([-0.0024, -0.0040,  0.0021,  ...,  0.0003, -0.0012,  0.0008])

In [28]:
# tokenizers['i'] = AutoTokenizer.from_pretrained('/inspire/hdd/ws-8207e9e2-e733-4eec-a475-cfa1c36480ba/embodied-multimodality/public/zfhe/jiaxing_projects/Language-Model-SAEs/exp/diff_models/Llama-3.1-8B-Instruct')
# print(tokenizers['i'].convert_ids_to_tokens(range(128016,128022)))
# for token_id in range(128016,128022):
#     embeddings[token_id] = mean_embedding

['<|reserved_special_token_8|>', '<|reserved_special_token_9|>', '<|reserved_special_token_10|>', '<|reserved_special_token_11|>', '<think>', '</think>']


In [28]:
# Save the patched instruct model. It is looked up by tag rather than through
# the shared `model` variable, which a later cell rebinds to models["o"];
# re-running this cell after that would otherwise write the wrong weights
# into the instruct directory.
models["i"].save_pretrained(model_names["i"])

In [29]:
# Patch the distill ("o") model's untrained embedding rows and save it back
# to its directory, all in one cell.
distill_tag = "o"
model = models[distill_tag]
fix_untrained_tokens(model)
model.save_pretrained(model_names[distill_tag])

In [44]:
# Per-dimension mean over all rows of the patched "o" embedding matrix.
distill_embeddings = models["o"].get_input_embeddings().weight.data
distill_embeddings.mean(dim=0)

tensor([-0.0024, -0.0040,  0.0021,  ...,  0.0004, -0.0012,  0.0008])

In [43]:
# Display the "i" embedding rows for token ids 128016..128021 after patching.
instruct_embeddings = models["i"].get_input_embeddings().weight.data
instruct_embeddings[range(128016, 128022)]

tensor([[-0.0024, -0.0040,  0.0021,  ...,  0.0003, -0.0012,  0.0008],
        [-0.0024, -0.0040,  0.0021,  ...,  0.0003, -0.0012,  0.0008],
        [-0.0024, -0.0040,  0.0021,  ...,  0.0003, -0.0012,  0.0008],
        [-0.0024, -0.0040,  0.0021,  ...,  0.0003, -0.0012,  0.0008],
        [-0.0024, -0.0040,  0.0021,  ...,  0.0003, -0.0012,  0.0008],
        [-0.0024, -0.0040,  0.0021,  ...,  0.0003, -0.0012,  0.0008]])

In [2]:
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load a single tokenizer/model pair from the Llama-3.1-8B directory under
# .../zfhe/models (a different directory from the diff_models one used above).
# Note that `tokenizers` and `models` are single objects here, not dicts.
tokenizers = AutoTokenizer.from_pretrained("/inspire/hdd/ws-8207e9e2-e733-4eec-a475-cfa1c36480ba/embodied-multimodality/public/zfhe/models/Llama-3.1-8B")
models = AutoModelForCausalLM.from_pretrained("/inspire/hdd/ws-8207e9e2-e733-4eec-a475-cfa1c36480ba/embodied-multimodality/public/zfhe/models/Llama-3.1-8B")
models

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (n

In [5]:
# Grab the raw input-embedding weight tensor and display its shape.
embedding_layer = models.get_input_embeddings()
embeddings = embedding_layer.weight.data
embeddings.shape

torch.Size([128256, 4096])

In [8]:
# print(tokenizers.convert_ids_to_tokens(range(128016,128022)))
# Print the embedding row for each token id in 128016..128021.
inspected_ids = range(128016, 128022)
for token_id in inspected_ids:
    print(f"{token_id} - {embeddings[token_id]}")

128016 - tensor([-6.4520e-23, -8.0650e-24, -2.1713e-23,  ...,  8.2718e-23,
         4.3427e-23,  4.0945e-23])
128017 - tensor([ 2.0841e-25, -3.1226e-23,  3.0812e-23,  ...,  3.1019e-23,
         5.2733e-23, -2.3058e-23])
128018 - tensor([-2.2127e-23, -4.4461e-23,  3.3087e-23,  ..., -7.1861e-24,
        -4.5495e-23,  1.7267e-23])
128019 - tensor([ 1.0753e-23,  2.4609e-23,  1.0236e-23,  ..., -3.0761e-24,
         1.4993e-23,  2.2954e-23])
128020 - tensor([ 1.2253e-23, -6.2039e-23,  2.9572e-23,  ...,  3.8464e-23,
         3.8464e-23,  4.4875e-23])
128021 - tensor([ 2.9572e-23, -2.0266e-23,  2.7090e-23,  ...,  1.5199e-23,
         9.9779e-24,  1.6130e-23])


In [15]:
# Compare the mean over every row with the mean over only the first 128000 rows.
mean_embedding = torch.mean(embeddings, dim=0)
mean_embedding1 = torch.mean(embeddings[:128000], dim=0)
print(mean_embedding.tolist())
print(mean_embedding1.tolist())

# Element-wise relative difference of the two means, shown as the cell output.
relative_difference = (mean_embedding1 - mean_embedding) / mean_embedding
relative_difference.tolist()

[-0.0024028841871768236, -0.003996074199676514, 0.0021069259382784367, -0.00017608772031962872, -0.005583043675869703, -0.0001565763377584517, -0.00018419056141283363, 0.0009738958906382322, -0.0005562104051932693, 0.0007188466261141002, -9.571220289217308e-05, 0.0016472636489197612, 0.0005365521064959466, 0.0011423995019868016, 0.0004992990870960057, 0.00017972767818719149, -0.0010319710709154606, -0.000571000506170094, 0.0012551489053294063, 0.0014975698431953788, -0.0017496770014986396, -4.076978439115919e-05, -0.000958869291935116, 9.000254794955254e-05, 0.0009330619359388947, -0.0002409383305348456, 0.0009877526899799705, 0.000976781127974391, 0.0020677002612501383, 0.003261370351538062, -0.0012123682536184788, -0.0009383632568642497, -0.0014888278674334288, -0.0008309604017995298, 0.0002617652644403279, -2.776472501864191e-05, 0.0006166173261590302, 8.464302663924173e-05, 0.0008375857141800225, -0.004613897763192654, 0.0005804368993267417, -0.00024644864606671035, -0.001258353935

[0.002001199871301651,
 0.0019994163885712624,
 0.001995319267734885,
 0.002024353016167879,
 0.0019956636242568493,
 0.0018823692807927728,
 0.002251080237329006,
 0.0020083789713680744,
 0.0019554980099201202,
 0.0019848269876092672,
 0.0014610114740207791,
 0.0020302636548876762,
 0.0020373414736241102,
 0.0019845846109092236,
 0.0019406863721087575,
 0.0020568722393363714,
 0.002048267750069499,
 0.0020748821552842855,
 0.002022232860326767,
 0.0019764418248087168,
 0.00199712417088449,
 0.0025913931895047426,
 0.001976112835109234,
 0.0018488493515178561,
 0.001994337886571884,
 0.0019895927980542183,
 0.001489970600232482,
 0.001977120293304324,
 0.002005133545026183,
 0.0020088553428649902,
 0.001981340115889907,
 0.0019485194934532046,
 0.00199148734100163,
 0.001961992820724845,
 0.002212651539593935,
 0.0018711568554863334,
 0.001989634009078145,
 0.0017450843006372452,
 0.0020277821458876133,
 0.001994192833080888,
 0.0020703321788460016,
 0.002127679530531168,
 0.0020384527