In [1]:
import os
import json
import torch
import base64
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm import tqdm
from utils import make_context

from langdetect import detect as langdetect
from langdetect import DetectorFactory
# 确保检测结果的一致性
DetectorFactory.seed = 0

In [2]:
########## Prepare Dataset ###########
data_path = "/home/z00533370/projects/VLMEvalKit/raw_data/"

def get_text_list(folder_path):
    query_list = []
    response_list = []
    for file_name in os.listdir(folder_path):
        if file_name.endswith('json'):
            file = json.load(open(os.path.join(folder_path, file_name)))
            query_list.append(file['query'])
            response_list.append(file['response'])
    return query_list, response_list

query_list, response_list = get_text_list(data_path)

In [3]:
########## Count Token Freqs ###########
model_path = "/home/z00533370/projects/MoH/exp0731_qwenvl_chat_moh_layer16-31_sigmoid_prob_no_norm/checkpoint-5197/"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
token_counts = [0 for _ in range(151936)]
assert len(query_list) == len(response_list)
for i in tqdm(range(len(query_list))):
    query, response = query_list[i], response_list[i]
    _, context_tokens = make_context(tokenizer, query, history=[], system="You are a helpful assistant.")
    for token in context_tokens:
        token_counts[token] += 1
    response_tokens = tokenizer.encode(response)
    for token in response_tokens:
        token_counts[token] += 1

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 17532/17532 [00:17<00:00, 1015.58it/s]


In [58]:
########## update multi-byte to count to sub-token ###########
tiktoken_bpe_file = "/home/z00533370/projects/MoH/exp0731_qwenvl_chat_moh_layer16-31_sigmoid_prob_no_norm/checkpoint-5197/qwen.tiktoken"
with open(tiktoken_bpe_file, "rb") as f:
    contents = f.read()
old_token_list = [base64.b64decode(token) for token, rank in (line.split() for line in contents.splitlines() if line)]
old_bytes_list = [token for token, rank in (line.split() for line in contents.splitlines() if line)]
# generate a list to calculate inherit_counts
inherit_counts = [0 for _ in range(151936)]
for i in range(len(old_bytes_list)):
    t_count = token_counts[i]
    b_len = len(old_bytes_list[i]) // 4
    if t_count > 0 and b_len > 1:
        for j in range(b_len - 1):
            for k in range(b_len - 1):
                sub_token = old_bytes_list[i][j*4*(k+1):(j+1)*4*(k+1)]
                try:
                    inherit_counts[old_bytes_list.index(sub_token)] += t_count

In [77]:
base64.b64decode(b'dGhp').decode('utf-8')

'thi'

In [99]:
########## Dictionary Pruning ###########
def is_special_token(token):
    return ((token.startswith('<') and token.endswith('>') and len(token) > 2) or
            (token.startswith('[') and token.endswith(']') and len(token) > 2))

new_token_list = []
new_bytes_list = []
mapping_new2old = []
# detect language, only keep english and chinese
for i in tqdm(range(len(old_token_list))):
    token = old_token_list[i]
    try:
        # number and symbols cannot be detected by langdetect
        token_str = token.decode("utf-8")
        #if (langdetect(token_str) in ['zh-cn', 'en']) or (token_counts[i] > 0) or is_special_token(token_str):
        if (token_counts[i] + inherit_counts[i] > 0) or is_special_token(token_str):
            new_token_list.append(token)
            new_bytes_list.append(old_bytes_list[i])
            mapping_new2old.append(i)
    except:
        new_token_list.append(token)
        new_bytes_list.append(old_bytes_list[i])
        mapping_new2old.append(i)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 151643/151643 [00:00<00:00, 1225986.40it/s]


In [100]:
########## Add Special Token Mapping ###########
qwen_vocab_size = 151936
for i in range(len(old_token_list), qwen_vocab_size):
    mapping_new2old.append(i)

In [65]:
########## Save New Vocab ###########
new_tiktoken_bpe_file = "/home/z00533370/projects/MoH/exp0731_qwenvl_chat_moh_layer16-31_sigmoid_prob_no_norm/checkpoint-5197-new-vocab/qwen.tiktoken"
vocab_mapping_file = "/home/z00533370/projects/MoH/exp0731_qwenvl_chat_moh_layer16-31_sigmoid_prob_no_norm/checkpoint-5197-new-vocab/token_vocab_mapping.torch"
# saving new tiktoken_bpe_file
with open(new_tiktoken_bpe_file, "w", encoding="utf8") as w:
    for i, token in enumerate(new_token_list):
        line = base64.b64encode(token).decode("utf8") + " " + str(i) + "\n"
        w.write(line)
# saving mapping index
torch.save(torch.LongTensor(mapping_new2old), vocab_mapping_file)

In [66]:
########## update model ###########
old_model_path = "/home/z00533370/projects/MoH/exp0731_qwenvl_chat_moh_layer16-31_sigmoid_prob_no_norm/checkpoint-5197/"
new_model_path = "/home/z00533370/projects/MoH/exp0731_qwenvl_chat_moh_layer16-31_sigmoid_prob_no_norm/checkpoint-5197-new-vocab/"
model = AutoModelForCausalLM.from_pretrained(old_model_path, trust_remote_code=True)
# define new module
new_embeds = torch.nn.Embedding(len(mapping_new2old), model.config.hidden_size, dtype=model.dtype)
new_lm_head = torch.nn.Linear(model.config.hidden_size, len(mapping_new2old), bias=False, dtype=model.dtype)
# get new module parameter from the old
assert len(set(mapping_new2old)) == len(mapping_new2old)
new_embeds.weight.data = model.transformer.wte.weight.data[torch.LongTensor(mapping_new2old, device=model.device)]
new_lm_head.weight.data = model.lm_head.weight.data[torch.LongTensor(mapping_new2old, device=model.device)]
# update model
model.transformer.wte.weight = new_embeds.weight
model.lm_head.weight = new_lm_head.weight
model.transformer.wte.num_embeddings = len(mapping_new2old)
model.lm_head.out_features = len(mapping_new2old)
# update config
model.config.__dict__['vocab_size'] = len(mapping_new2old)
model.config.__dict__['_name_or_path'] = new_model_path
model.generation_config.__dict__['eos_token_id'] = mapping_new2old.index(model.generation_config.__dict__['eos_token_id'])
model.generation_config.__dict__['pad_token_id'] = mapping_new2old.index(model.generation_config.__dict__['pad_token_id'])
# save new model
model.save_pretrained(new_model_path)

[2024-08-09 16:51:06,050] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)


/home/z00533370/anaconda3/envs/litgpt/compiler_compat/ld: cannot find -laio: No such file or directory
collect2: error: ld returned 1 exit status




Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [107]:
for i in old_tokenizer.encode(query_list[4]):
    if i not in mapping_new2old:
        print(i)

In [110]:
text = query_list[4].encode("utf-16", "surrogatepass").decode("utf-16", "replace")
text

"Picture 1: <img>/home/z00533370/projects/datasets/LMUData/images/MMBench/256.jpg</img>\nHint: People can use the engineering-design process to develop solutions to problems. One step in the process is testing if a potential solution meets the requirements of the design.\nThe passage below describes how the engineering-design process was used to test a solution to a problem. Read the passage. Then answer the question below.\n\nDevin was a mechanical engineer who was designing  to record temperature, precipitation, and wind speed. The weather station would be used in a town where the highest recorded temperature was 40¬∞C. Devin wanted to make sure the weather station would work even in unusually warm weather.\nSo, he set an indoor test chamber to 50¬∞C with low moisture and no wind. He left the weather station in the chamber overnight. The next day, he checked to see if the weather station displayed accurate measurements after 24 hours at 50¬∞C.\nFigure: a weather station.\nQuestion: W

In [4]:
########## Dictionary Pruning ###########
tiktoken_bpe_file = "/home/z00533370/projects/MoH/exp0731_qwenvl_chat_moh_layer16-31_sigmoid_prob_no_norm/checkpoint-5197/qwen.tiktoken"

def is_special_token(token):
    return ((token.startswith('<') and token.endswith('>') and len(token) > 2) or
            (token.startswith('[') and token.endswith(']') and len(token) > 2))

with open(tiktoken_bpe_file, "rb") as f:
    contents = f.read()
old_token_list = [base64.b64decode(token) for token, rank in (line.split() for line in contents.splitlines() if line)]

In [10]:
token_counts[0]

101

In [13]:
old_token_list[10000].decode('utf-8')

' rise'

In [148]:
old_tokenizer.encode("irteen")

[44904]

In [149]:
tokenizer.encode("irteen")

[12535]

In [88]:
model_path = "/home/z00533370/projects/MoH/exp0731_qwenvl_chat_moh_layer16-31_sigmoid_prob_no_norm/checkpoint-5197/"
old_tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
old_result = old_tokenizer.encode(query_list[18])
print(len(old_result))

404


In [133]:
old_tokenizer.decode([44904])

'irteen'

In [159]:
b'aXJ0' in old_bytes_list

True

In [161]:
b'aXJ0' in new_bytes_list

True

In [120]:
old_bytes_list[44904]
#old_token_list[44904], old_token_list[30942]

b'aXJ0ZWVu'

In [106]:
model_path = "/home/z00533370/projects/MoH/exp0731_qwenvl_chat_moh_layer16-31_sigmoid_prob_no_norm/checkpoint-5197-new-vocab/"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
result = tokenizer.encode(query_list[18])
print(len(result))

408


In [122]:
old_tokenizer.decode([404, 665, 268])

'irteen'

In [94]:
[mapping_new2old[i] for i in result[294:]]

[663,
 404,
 665,
 268,
 30942,
 550,
 11,
 892,
 1033,
 21286,
 553,
 8513,
 13224,
 304,
 279,
 220,
 16,
 21,
 15,
 15,
 82,
 323,
 220,
 16,
 22,
 15,
 15,
 82,
 13,
 576,
 7042,
 315,
 279,
 16244,
 30942,
 550,
 5230,
 94289,
 323,
 1910,
 1251,
 315,
 11643,
 36952,
 11,
 16703,
 3693,
 5203,
 11,
 323,
 7513,
 60007,
 13,
 576,
 2415,
 3685,
 4933,
 279,
 663,
 404,
 665,
 268,
 30942,
 550,
 304,
 220,
 16,
 22,
 20,
 15,
 13,
 9192,
 518,
 279,
 2415,
 13,
 5005,
 4226,
 279,
 3405,
 3685,
 624,
 14582,
 25,
 15920,
 315,
 1493,
 7482,
 572,
 16244,
 30942,
 550,
 5267,
 3798,
 510,
 32,
 13,
 19771,
 198,
 33,
 13,
 22652,
 198,
 5501,
 3293,
 279,
 4396,
 4226,
 504,
 279,
 2606,
 3403,
 13,
 715]

In [162]:
a = {}
a[('c', 'd')] = 1
a[('c', 'd')] += 1
a[('c', 'd')]

2

In [123]:
old_bytes_list[404], old_bytes_list[665], old_bytes_list[268]
old_token_list[404], old_token_list[665], old_token_list[268]

(b'ir', b'te', b'en')

In [90]:
for i, item in enumerate([mapping_new2old[i] for i in result]):
    if old_result[i] != item:
        print(i)

295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
391
393
394
395
396
397
398
400
401
402
403


IndexError: list index out of range

In [69]:
count = 0
for i, item in enumerate(query_list):
    if len(old_tokenizer.encode(item)) != len(tokenizer.encode(item)):
        count += 1
print(count)

2089


In [70]:
len(tokenizer)

22270

In [82]:
tokenizer.encode("""Hint: People can use the engineering-design process to develop solutions to problems. One step in the process is testing if a potential solution meets the requirements of the design.
The passage below describes how the engineering-design process was used to test a solution to a problem. Read the passage. Then answer the question below.

Devin was a mechanical engineer who was designing  to record temperature, precipitation, and wind speed. The weather station would be used in a town where the highest recorded temperature was 40¬∞C. Devin wanted to make sure the weather station would work even in unusually warm weather.
So, he set an indoor test chamber to 50¬∞C with low moisture and no wind. He left the weather station in the chamber overnight. The next day, he checked to see if the weather station displayed accurate measurements after 24 hours at 50¬∞C.
Figure: a weather station.
Question: Which of the following could Devin's test show?
Options:
A. if the weather station would work when the temperature was 50¬∞C
B. how well the weather station would work when it was windy
Please select the correct answer from the options above.")

[12698,
 25,
 5421,
 563,
 832,
 215,
 8072,
 18641,
 1477,
 247,
 1701,
 5914,
 247,
 3548,
 13,
 2675,
 2189,
 240,
 215,
 1477,
 310,
 4752,
 356,
 200,
 3174,
 4086,
 10176,
 215,
 5238,
 251,
 215,
 2108,
 546,
 680,
 10557,
 2621,
 8822,
 1027,
 215,
 8072,
 18641,
 1477,
 499,
 1197,
 247,
 1047,
 200,
 4086,
 247,
 200,
 2503,
 13,
 3071,
 215,
 10557,
 13,
 3382,
 2938,
 215,
 2445,
 2621,
 318,
 1496,
 5780,
 499,
 200,
 10945,
 11457,
 751,
 499,
 13603,
 156,
 247,
 2344,
 5637,
 11,
 22387,
 11,
 259,
 5935,
 3163,
 13,
 503,
 5530,
 5020,
 866,
 323,
 1197,
 240,
 200,
 4085,
 1121,
 215,
 5194,
 7093,
 5637,
 499,
 13,
 27666,
 3274,
 247,
 1053,
 2000,
 215,
 5530,
 5020,
 866,
 822,
 1208,
 240,
 21545,
 5095,
 5530,
 546,
 3047,
 11,
 493,
 644,
 392,
 13527,
 1047,
 11708,
 247,
 156,
 382,
 2405,
 14083,
 259,
 769,
 5935,
 13,
 1037,
 1628,
 215,
 5530,
 5020,
 240,
 215,
 11708,
 11779,
 13,
 503,
 1407,
 1490,
 11,
 493,
 5996,
 247,
 1203,
 356,
 215,
 5530,
 50

In [21]:
for i, item in enumerate(old_token_list):
    try:
        #print(i, item.decode('utf-8'))
        item.decode('utf-8')
    except:
        print(i, item)

94 b'\xa1'
95 b'\xa2'
96 b'\xa3'
97 b'\xa4'
98 b'\xa5'
99 b'\xa6'
100 b'\xa7'
101 b'\xa8'
102 b'\xa9'
103 b'\xaa'
104 b'\xab'
105 b'\xac'
106 b'\xae'
107 b'\xaf'
108 b'\xb0'
109 b'\xb1'
110 b'\xb2'
111 b'\xb3'
112 b'\xb4'
113 b'\xb5'
114 b'\xb6'
115 b'\xb7'
116 b'\xb8'
117 b'\xb9'
118 b'\xba'
119 b'\xbb'
120 b'\xbc'
121 b'\xbd'
122 b'\xbe'
123 b'\xbf'
124 b'\xc0'
125 b'\xc1'
126 b'\xc2'
127 b'\xc3'
128 b'\xc4'
129 b'\xc5'
130 b'\xc6'
131 b'\xc7'
132 b'\xc8'
133 b'\xc9'
134 b'\xca'
135 b'\xcb'
136 b'\xcc'
137 b'\xcd'
138 b'\xce'
139 b'\xcf'
140 b'\xd0'
141 b'\xd1'
142 b'\xd2'
143 b'\xd3'
144 b'\xd4'
145 b'\xd5'
146 b'\xd6'
147 b'\xd7'
148 b'\xd8'
149 b'\xd9'
150 b'\xda'
151 b'\xdb'
152 b'\xdc'
153 b'\xdd'
154 b'\xde'
155 b'\xdf'
156 b'\xe0'
157 b'\xe1'
158 b'\xe2'
159 b'\xe3'
160 b'\xe4'
161 b'\xe5'
162 b'\xe6'
163 b'\xe7'
164 b'\xe8'
165 b'\xe9'
166 b'\xea'
167 b'\xeb'
168 b'\xec'
169 b'\xed'
170 b'\xee'
171 b'\xef'
172 b'\xf0'
173 b'\xf1'
174 b'\xf2'
175 b'\xf3'
176 b'\xf4'
177 b'\xf5