In [1]:
import torch
from transformers import AutoConfig, AutoModelForCausalLM, LlamaTokenizer

# Locations of the checkpoints / tokenizer files used throughout this notebook.
ORIGIN_MODEL_DIR = '/mnt/bn/pankeyu/mlx/users/pankeyu/playground/backbones/llama7b-v2'
EXTEND_MODEL_DIR = '/mnt/bn/pankeyu/mlx/users/pankeyu/playground/backbones/llama7b-v2-avg-plus'
EXTEND_TOKENIZER_DIR = '/mnt/bn/pankeyu/mlx/users/pankeyu/playground/LLMsTrainer/configs/tokenizer_configs/llama_plus'

# Both checkpoints are loaded the same way: fp16 weights, placement chosen by `device_map='auto'`.
_MODEL_LOAD_KWARGS = dict(
    trust_remote_code=True,
    torch_dtype=torch.float16,
    device_map='auto',
)

# Original model and its tokenizer (both read from the same directory).
origin_model = AutoModelForCausalLM.from_pretrained(ORIGIN_MODEL_DIR, **_MODEL_LOAD_KWARGS)
origin_tokenizer = LlamaTokenizer.from_pretrained(ORIGIN_MODEL_DIR, trust_remote_code=True)

# Tokenizer with the extended vocabulary.
# NOTE: the variable name is misspelled ("tokenzier") but other cells refer to it by this
# exact name, so it is kept as-is.
extend_tokenzier = LlamaTokenizer.from_pretrained(EXTEND_TOKENIZER_DIR, trust_remote_code=True)

# Model whose embedding / lm-head matrices were already extended (loaded for inspection).
extend_model = AutoModelForCausalLM.from_pretrained(EXTEND_MODEL_DIR, **_MODEL_LOAD_KWARGS)

# Sanity check: the embedding rows, the lm-head output features and the extended
# tokenizer's vocabulary size are printed so they can be compared by eye.
print(extend_model.model.embed_tokens)
print(extend_model.lm_head)
print('Extend Vocab Size: ', len(extend_tokenzier.get_vocab()))

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 2/2 [00:09<00:00,  4.70s/it]
Loading checkpoint shards: 100%|██████████| 2/2 [00:09<00:00,  4.94s/it]

Embedding(40419, 4096, padding_idx=0)
Linear(in_features=4096, out_features=40419, bias=False)
Extend Vocab Size:  40419





In [3]:
def _print_boundary_rows(banner, causal_lm):
    """Print the first and last rows of a model's input-embedding and lm-head weights.

    banner:    text placed between the two dashed rules of the header line.
    causal_lm: object exposing `model.embed_tokens.weight` and `lm_head.weight`.
    """
    embedding_rows = causal_lm.model.embed_tokens.weight.data
    lm_head_rows = causal_lm.lm_head.weight.data

    print('-' * 10 + banner + '-' * 10)
    print('first token embedding: ', embedding_rows[0, :])
    print('last token embedding: ', embedding_rows[-1, :])
    print('first lm-head embedding', lm_head_rows[0, :])
    print('last lm-head embedding', lm_head_rows[-1, :])


# Expected when comparing the two dumps below:
#   * first token embedding / first lm-head row -> same values as in the randomly-extended model
#   * last token embedding / last lm-head row   -> different values from the randomly-extended model
_print_boundary_rows(' Avg Extend ', extend_model)
_print_boundary_rows(' Random Extend ', origin_model)

---------- Avg Extend ----------
first token embedding:  tensor([ 1.2517e-06, -1.7881e-06, -4.3511e-06,  ...,  8.9407e-07,
        -6.5565e-06,  8.9407e-07], device='cuda:0', dtype=torch.float16)
last token embedding:  tensor([ 0.0015,  0.0055,  0.0001,  ...,  0.0054,  0.0036, -0.0061],
       device='cuda:0', dtype=torch.float16)
first lm-head embedding tensor([-0.0039,  0.0032, -0.0071,  ...,  0.0053, -0.0082,  0.0070],
       device='cuda:6', dtype=torch.float16)
last lm-head embedding tensor([-0.0056,  0.0034,  0.0096,  ...,  0.0020,  0.0051, -0.0054],
       device='cuda:6', dtype=torch.float16)
---------- Random Extend ----------
first token embedding:  tensor([ 1.2517e-06, -1.7881e-06, -4.3511e-06,  ...,  8.9407e-07,
        -6.5565e-06,  8.9407e-07], device='cuda:0', dtype=torch.float16)
last token embedding:  tensor([-0.0381,  0.0041,  0.0029,  ...,  0.0058, -0.0052,  0.0081],
       device='cuda:0', dtype=torch.float16)
first lm-head embedding tensor([-0.0039,  0.0032, -0.007

In [2]:

# Handles on the original model's input embedding and output projection; the following
# cells copy / average rows out of these two layers.
# The original cell read from a bare `model` that no visible cell defines, so it failed on a
# fresh kernel; `origin_model` is the checkpoint loaded in the first cell.
embed_tokens = origin_model.model.embed_tokens          # Embedding(32000, 4096, padding_idx=0), use `embed_tokens(torch.LongTensor([0, 1]))` to get embedding
embed_dim = embed_tokens.embedding_dim                  # 4096
lm_head = origin_model.lm_head                          # Linear(in_features=4096, out_features=32000, bias=False)
print(lm_head)
print(lm_head.weight.size())

tensor([ 0.0015,  0.0055,  0.0001,  ...,  0.0054,  0.0036, -0.0061],
       device='cuda:0', dtype=torch.float16)
tensor([-0.0056,  0.0034,  0.0096,  ...,  0.0020,  0.0051, -0.0054],
       device='cuda:6', dtype=torch.float16)


In [18]:
# Build the input-embedding matrix for the extended vocabulary:
#   1. allocate a randomly initialised matrix with one row per extended-vocab token,
#   2. copy over the rows of every token that already exists in the original vocabulary,
#   3. initialise each new ordinary token with the mean of the original embeddings of the
#      sub-tokens the original tokenizer splits it into.
extend_token_embedding = []                                                               # not used in this cell; kept in case other cells reference it

# `get_vocab()` returns a token -> id mapping. Sort by id instead of relying on dict order,
# so that list position i corresponds to token id i (assumes ids are contiguous from 0).
origin_vocab = origin_tokenizer.get_vocab()
extend_vocab = extend_tokenzier.get_vocab()
origin_tokenizer_tokens = sorted(origin_vocab, key=origin_vocab.get)
extend_tokenizer_tokens = sorted(extend_vocab, key=extend_vocab.get)
assert len(extend_tokenizer_tokens) > len(origin_tokenizer_tokens), \
    'Extend vocab size should larger than Origin vocab.'

# Allocate the new matrix (torch.Size([40419, 4096]) in the recorded run). dtype, device and
# padding_idx are taken from the original layer so the new layer matches it; without
# `padding_idx` the zeroing guard below could never fire.
new_embeddings = torch.nn.Embedding(
    len(extend_tokenizer_tokens),
    embed_dim,
    padding_idx=embed_tokens.padding_idx,
    dtype=embed_tokens.weight.dtype,
    device=embed_tokens.weight.device,
)

new_embeddings.weight.data.normal_(mean=0.0, std=1.0)                                     # initialize weights
if new_embeddings.padding_idx is not None:
    new_embeddings.weight.data[new_embeddings.padding_idx].zero_()

num_tokens_to_copy = min(
    len(origin_tokenizer_tokens),
    len(extend_tokenizer_tokens)
)

# Rows of tokens shared with the original vocabulary keep their trained values.
new_embeddings.weight.data[:num_tokens_to_copy, :] = embed_tokens.weight.data[:num_tokens_to_copy, :]

extend_token_start_index = len(origin_tokenizer_tokens)                                   # get the first extend token index
print('extend_token_start_index: ', extend_token_start_index)

# bos/eos ids that `encode` may add around the text; they must not take part in the average.
drop_tokens = {origin_tokenizer.bos_token_id, origin_tokenizer.eos_token_id}

for i in range(extend_token_start_index, len(extend_tokenizer_tokens)):
    ext_token = extend_tokenizer_tokens[i]

    # print('Before: ', new_embeddings.weight.data[i, :])

    if '<' in ext_token and '>' in ext_token:                                             # don't initialise special tokens like: <pad>, <title>, ...
        continue

    sub_tokens = [
        sub_token
        for sub_token in origin_tokenizer.encode(ext_token)                               # e.g. [1, 259, 13]
        if sub_token not in drop_tokens                                                   # e.g. [259, 13]
    ]
    if not sub_tokens:
        # Nothing left to average: the mean of zero rows would be NaN, so the random
        # initialisation is kept for this token instead.
        continue

    # Index the weight matrix directly (as the lm-head cell does) rather than calling the
    # module: no CPU index tensor is fed to a module on another device and no autograd graph
    # is built.
    sub_tokens_embeddings = embed_tokens.weight.data[sub_tokens, :]                       # shape: (len(sub_tokens), embed_dim)
    avg_sub_token_embedding = torch.mean(sub_tokens_embeddings, dim=0)                    # average sub-token embedding, shape: (embed_dim,)
    new_embeddings.weight.data[i, :] = avg_sub_token_embedding                            # update average embedding in new embedding matrix

    # print('After: ', new_embeddings.weight.data[i, :])

print(new_embeddings.weight.size())                                                       # torch.Size([40419, 4096])
    

extend_token_start_index:  32000
torch.Size([40419, 4096])


In [27]:
# Build the output projection (lm head) for the extended vocabulary, mirroring the embedding
# cell: rows of tokens shared with the original vocabulary are copied, and each new ordinary
# token's row is the mean of the original lm-head rows of its sub-tokens.
old_lm_head_output_features, old_lm_head_input_features = lm_head.weight.size()

# dtype and device are taken from the original head so the new layer matches it (the default
# would be float32).
new_lm_head = torch.nn.Linear(
    old_lm_head_input_features,
    len(extend_tokenizer_tokens),
    bias=False,
    dtype=lm_head.weight.dtype,
    device=lm_head.weight.device,
)

# Rows of tokens shared with the original vocabulary keep their trained values.
new_lm_head.weight.data[:num_tokens_to_copy, :] = lm_head.weight.data[:num_tokens_to_copy, :]

# bos/eos ids that `encode` may add around the text; they must not take part in the average.
drop_tokens = {origin_tokenizer.bos_token_id, origin_tokenizer.eos_token_id}

for i in range(extend_token_start_index, len(extend_tokenizer_tokens)):
    ext_token = extend_tokenizer_tokens[i]

    # print('Before: ', new_lm_head.weight.data[i, :])

    if '<' in ext_token and '>' in ext_token:                                             # don't initialise special tokens like: <pad>, <title>, ...
        continue

    sub_tokens = [
        sub_token
        for sub_token in origin_tokenizer.encode(ext_token)                               # e.g. [1, 259, 13]
        if sub_token not in drop_tokens                                                   # e.g. [259, 13]
    ]
    if not sub_tokens:
        # Nothing left to average: the mean of zero rows would be NaN, so the default
        # Linear initialisation is kept for this token instead.
        continue

    sub_lm_embeddings = lm_head.weight.data[sub_tokens, :]                                # lm-head rows of all sub-tokens, shape: (len(sub_tokens), in_features)
    avg_sub_lm_embedding = torch.mean(sub_lm_embeddings, dim=0)                           # average sub-token lm row, shape: (in_features,)
    new_lm_head.weight.data[i, :] = avg_sub_lm_embedding                                  # update average row in the new lm-head matrix

    # print('After: ', new_lm_head.weight.data[i, :])

# print(new_lm_head.weight.size())                                                        # torch.Size([40419, 4096])

torch.Size([40419, 4096])
