### 1.动机
- donut使用了多语言的词表, 实际使用中大多情况下都是只有中文和英文词表

### 2.bloom模型

In [1]:
from tqdm import tqdm
from transformers import AutoTokenizer
from vocabulary_pruner import VocabularyPruner, DonutVocabularyPruner

In [2]:
MP1 = "J:/model/bigscience-bloom-560m"
MP2 = "J:/model/YeungNLP-bloom-396m-zh"

In [3]:
class BloomVocabularyPruner(VocabularyPruner):

    def update_ebeddings(self, model, new2old_token_id, new_embeds, new_lm_head):
        for token_id, old_token_id in tqdm(new2old_token_id.items()):
            new_embeds.weight.data[token_id] = model.transformer.word_embeddings.weight.data[old_token_id]
            new_lm_head.weight.data[token_id] = model.lm_head.weight.data[old_token_id]
        model.transformer.word_embeddings.weight = new_embeds.weight
        model.lm_head.weight = new_lm_head.weight

In [4]:
# save_path = 'bloom-396m-zh'
# pruner = BloomVocabularyPruner()
# # 裁剪
# pruner.prune(MP1, MP2, save_path)
# # 检查裁剪的模型与原模型是否一致
# pruner.check(MP1, save_path, text='长风破浪会有时')

### 3.donut模型
- step1: 构建删减后的tokenizer.json

In [5]:
import langid
MP3 = "J:/model/mllm-model/donut-pretrain/20240102/pl-checkpoint-232000-ned-0.8460975410122905"

In [6]:
tokenizer3 = AutoTokenizer.from_pretrained(MP3)

In [7]:
tokenizer3.get_vocab()

{'て': 35872,
 '▁appreciate': 51578,
 '▁slow': 50091,
 'kse': 9617,
 '▁befinner': 11291,
 '▁سۈرەت': 10351,
 'vol': 38763,
 '天堂': 54898,
 'NAR': 54461,
 'writer': 49237,
 '一同': 47221,
 '絹': 31154,
 '▁amp': 41746,
 '袖': 6643,
 '▁fai': 16102,
 '귀': 49708,
 'さ': 47530,
 'DB': 39373,
 '▁MAX': 53631,
 'dd': 39629,
 '▁krásný': 20112,
 '▁вираз': 18635,
 '知って': 4541,
 '▁gördük': 12530,
 'liste': 7204,
 'സിനെ': 15622,
 '▁UU': 55440,
 'head': 36294,
 '▁Rigtig': 24053,
 'jeva': 39096,
 '▁Low': 46577,
 '储': 52028,
 '囝': 31440,
 '▁mei': 44125,
 '▁Mode': 38105,
 'mish': 51645,
 '▁நட': 3151,
 '▁істеу': 4285,
 'RS': 12790,
 '▁акцыі': 15719,
 '▁Astro': 53615,
 '▁coach': 44204,
 '▁Arra': 49014,
 '▁слот': 5726,
 '사업': 40664,
 '11)': 42340,
 '秧': 30912,
 'வில': 12612,
 '国务院': 43580,
 '▁разве': 2432,
 '▁콘텐츠': 51960,
 '47)': 50535,
 '霄': 24437,
 '▁proběhl': 25881,
 '▁Green': 12992,
 'ées': 53519,
 '▁tart': 12471,
 '第': 41750,
 '▁дисциплін': 10655,
 '▁Lisa': 36882,
 '지고': 42519,
 'pedi': 46480,
 '経': 40283,
 '

In [8]:
## 统计词表中的语言分布
lang_dist, lang_dist_list = {}, {}
for idx, voc in enumerate(tqdm(tokenizer3.get_vocab())):
    r = langid.classify(voc)[0]
    if r not in lang_dist:
        lang_dist_list[r] = [0, []]
        lang_dist[r] = 0
    lang_dist_list[r][0] += 1
    lang_dist[r] += 1
    lang_dist_list[r][1].append(voc)
    if idx <= 10:
        print(voc, r)


  0%|          | 72/58891 [00:02<20:22, 48.13it/s]  

て ja
▁appreciate en
▁slow lv
kse fi
▁befinner en
▁سۈرەت ug
vol en
天堂 zh
NAR en
writer en
一同 zh


100%|██████████| 58891/58891 [01:29<00:00, 659.87it/s]


In [10]:
import json
lang_dist = dict(sorted(lang_dist.items(), key=lambda x:x[1], reverse=True))
lang_dist_list = dict(sorted(lang_dist_list.items(), key=lambda x:x[1], reverse=True))

lang_dist.keys(), lang_dist

with open("lang_dist.json", "w", encoding="utf-8") as f1, \
    open("lang_dist_list.json", "w", encoding="utf-8") as f2:
    f1.write(json.dumps(lang_dist, ensure_ascii=False, indent=2))
    f2.write(json.dumps(lang_dist_list, ensure_ascii=False, indent=2))

In [11]:
lang_dist["zh"]

12734

In [12]:
"飞" in lang_dist_list["zh"][1]

True

In [34]:
# 重新保存tokenzier.json
tokenizer3_dict = tokenizer3.__dict__
tokenizer3_dict, type(tokenizer3_dict)

({'_tokenizer': <tokenizers.Tokenizer at 0x1da82c3ebe0>,
  '_decode_use_source_tokenizer': False,
  'init_inputs': (),
  'init_kwargs': {'bos_token': '<s>',
   'eos_token': '</s>',
   'sep_token': '</s>',
   'cls_token': '<s>',
   'unk_token': '<unk>',
   'pad_token': '<pad>',
   'mask_token': AddedToken("<mask>", rstrip=False, lstrip=True, single_word=False, normalized=True),
   'added_tokens_decoder': {'0': {'content': '<s>',
     'lstrip': False,
     'normalized': False,
     'rstrip': False,
     'single_word': False,
     'special': True},
    '1': {'content': '<pad>',
     'lstrip': False,
     'normalized': False,
     'rstrip': False,
     'single_word': False,
     'special': True},
    '2': {'content': '</s>',
     'lstrip': False,
     'normalized': False,
     'rstrip': False,
     'single_word': False,
     'special': True},
    '3': {'content': '<unk>',
     'lstrip': False,
     'normalized': False,
     'rstrip': False,
     'single_word': False,
     'special': True},

In [39]:
tokenizer3

XLMRobertaTokenizerFast(name_or_path='J:/model/mllm-model/donut-pretrain/20240102/pl-checkpoint-232000-ned-0.8460975410122905', vocab_size=57522, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'sep_token': '</s>', 'pad_token': '<pad>', 'cls_token': '<s>', 'mask_token': '<mask>', 'additional_special_tokens': ['<s_iitcdip>', '<s_synthdog>']}, clean_up_tokenization_spaces=True)

In [131]:
tokenizer3_file = "J:/model/mllm-model/donut-pretrain/20240102/pl-checkpoint-232000-ned-0.8460975410122905/tokenizer.json"
tokenizer3_config_file = "J:/model/mllm-model/donut-pretrain/20240102/pl-checkpoint-232000-ned-0.8460975410122905/tokenizer_config.json"

with open(tokenizer3_file, "r", encoding="utf-8") as f1,\
        open(tokenizer3_config_file, "r", encoding="utf-8") as f2:
    tokenizer_data = json.load(f1)
    tokenizer3_config_data = json.load(f2)

In [132]:
tokenizer_data

{'version': '1.0',
 'truncation': {'direction': 'Right',
  'max_length': 2560,
  'strategy': 'LongestFirst',
  'stride': 0},
 'padding': {'strategy': {'Fixed': 2560},
  'direction': 'Right',
  'pad_to_multiple_of': None,
  'pad_id': 1,
  'pad_type_id': 0,
  'pad_token': '<pad>'},
 'added_tokens': [{'id': 0,
   'content': '<s>',
   'single_word': False,
   'lstrip': False,
   'rstrip': False,
   'normalized': False,
   'special': True},
  {'id': 1,
   'content': '<pad>',
   'single_word': False,
   'lstrip': False,
   'rstrip': False,
   'normalized': False,
   'special': True},
  {'id': 2,
   'content': '</s>',
   'single_word': False,
   'lstrip': False,
   'rstrip': False,
   'normalized': False,
   'special': True},
  {'id': 3,
   'content': '<unk>',
   'single_word': False,
   'lstrip': False,
   'rstrip': False,
   'normalized': False,
   'special': True},
  {'id': 57521,
   'content': '<mask>',
   'single_word': False,
   'lstrip': True,
   'rstrip': False,
   'normalized': True,

In [133]:
vocab = tokenizer_data["model"]["vocab"]
len(vocab), vocab[100]

(57522, ['เดอร์', -13.401432037353516])

In [134]:
import re

In [135]:
def is_chinese_or_english(token):
    pattern = re.compile("^[\u4e00-\u9fa5a-zA-Z]+$")
    return bool(pattern.match(token))

def is_valid_token(token):
    """
    中文、英文、数字、符号
    :param token:
    :return:
    """
    #pattern = re.compile("^[\u4e00-\u9fa5a-zA-Z0-9\s`~!@#$%^&Ω🔞*()-_+={}[\]:;\"'<>,.?/]+$")
    pattern = re.compile("^[\u4e00-\u9fa5a-zA-Z0-9\s`~!@#$%^&Ω🔞*()≈≡≠＝<>＜＞≮≯∷±＋－×÷／∫∮∝∞∧∨∑∏∪∩∈∵∴⊥‖∠⌒≌∽√（）【】｛｝ⅠⅡ⊕⊙∥αβγδεζηθΔ+={}[\]:;\"'<>,.?/]+$")

    return bool(pattern.match(token))

input_str = "3.2"
is_chinese_or_english(input_str), is_valid_token(input_str), langid.classify(input_str)

(False, True, ('en', 9.061840057373047))

In [136]:
new_vocab, filtered_vocab = [], []
for v in tqdm(vocab):
    token = v[0].replace("▁", "")
    lang = langid.classify(token)
    is_ch_or_en = is_chinese_or_english(token)
    is_valid = is_chinese_or_english(token)
    flag = True if lang[0] in ["en", "zh"] else False
    if is_ch_or_en or is_valid or flag:
        new_vocab.append(v)
    else:
        filtered_vocab.append(v)


100%|██████████| 57522/57522 [01:22<00:00, 694.19it/s]


In [137]:
len(new_vocab), len(filtered_vocab)
with open("new_vocab.txt", "w", encoding="utf-8") as f1,\
    open("filtered_vocab.txt", "w", encoding="utf-8") as f2:
    vocab_1 = [v[0] for v in new_vocab]
    vocab_2 = [v[0] for v in filtered_vocab]
    f1.write("\n".join(vocab_1))
    f2.write("\n".join(vocab_2))

In [138]:
new_vocab[:-100], len(new_vocab)

([['<s>', 0.0],
  ['<pad>', 0.0],
  ['</s>', 0.0],
  ['<unk>', 0.0],
  ['a', -5.5477118492126465],
  ['▁_', -7.502312660217285],
  ['▁plus', -9.463733673095703],
  ['II', -11.439068794250488],
  ['▁own', -11.439139366149902],
  ['enseignement', -13.39828395843506],
  ['ective', -13.398374557495115],
  ['▁varianta', -13.398406982421877],
  ['▁bl', -11.439388275146484],
  ['▁hakka', -13.398545265197754],
  ['得到', -11.439473152160645],
  ['開啟', -13.398969650268556],
  ['▁Profi', -13.399009704589844],
  ['▁spel', -11.43959140777588],
  ['▁tranz', -13.399097442626951],
  ['ky', -9.46375560760498],
  ['hir', -11.43960189819336],
  ['▁euros', -11.439615249633787],
  ['▁Day', -11.439632415771484],
  ['▁zame', -13.399542808532717],
  ['▁primo', -11.439697265625],
  ['mark', -11.439749717712402],
  ['nni', -11.43976879119873],
  ['▁powinny', -13.399717330932615],
  ['▁esquerda', -13.39972686767578],
  ['▁Dal', -11.439814567565918],
  ['环境', -11.43984317779541],
  ['▁tegutse', -13.399785995483398

In [139]:
added_tokens = tokenizer_data["added_tokens"]
new_add_tokens = []
new_added_tokens_start_id = len(new_vocab) - 1
for t in added_tokens:
    if t["id"] in [0, 1, 2, 3]:
        new_add_tokens.append(t)
        continue
    t["id"] = new_added_tokens_start_id
    new_add_tokens.append(t)
    new_added_tokens_start_id += 1



In [140]:
new_add_tokens[:100]

[{'id': 0,
  'content': '<s>',
  'single_word': False,
  'lstrip': False,
  'rstrip': False,
  'normalized': False,
  'special': True},
 {'id': 1,
  'content': '<pad>',
  'single_word': False,
  'lstrip': False,
  'rstrip': False,
  'normalized': False,
  'special': True},
 {'id': 2,
  'content': '</s>',
  'single_word': False,
  'lstrip': False,
  'rstrip': False,
  'normalized': False,
  'special': True},
 {'id': 3,
  'content': '<unk>',
  'single_word': False,
  'lstrip': False,
  'rstrip': False,
  'normalized': False,
  'special': True},
 {'id': 40238,
  'content': '<mask>',
  'single_word': False,
  'lstrip': True,
  'rstrip': False,
  'normalized': True,
  'special': True},
 {'id': 40239,
  'content': '<sep/>',
  'single_word': False,
  'lstrip': False,
  'rstrip': False,
  'normalized': True,
  'special': False},
 {'id': 40240,
  'content': '<s_iitcdip>',
  'single_word': False,
  'lstrip': False,
  'rstrip': False,
  'normalized': False,
  'special': True},
 {'id': 40241,
  'c

In [156]:
import copy
new_tokenizer_data = copy.deepcopy(tokenizer_data)
new_tokenizer_data["model"]["vocab"] = new_vocab
new_tokenizer_data["added_tokens"] = new_add_tokens

new_tokenizer_config_data = copy.deepcopy(tokenizer3_config_data)
new_tokenizer_config_data["added_tokens"] = new_add_tokens
new_tokenizer_config_data["added_tokens_decoder"] = new_add_tokens

with open("./donut-zh/tokenizer.json", "w", encoding="utf-8") as f1,\
        open("./donut-zh/tokenizer_config.json", "w", encoding="utf-8") as f2:
    f1.write(json.dumps(new_tokenizer_data, ensure_ascii=False, indent=2))
    f2.write(json.dumps(new_tokenizer_config_data, ensure_ascii=False, indent=2))

In [157]:
MP4 = "./donut-zh"
tokenizer4 = AutoTokenizer.from_pretrained(MP4)

In [158]:
text = "中国平安是世界500强"
tokenizer3.tokenize(text), tokenizer4.tokenize(text)

(['中国', '平安', '是', '世界', '500', '强'], ['中国', '平安', '是', '世界', '500', '强'])

In [159]:
tokenizer3.convert_tokens_to_ids(tokenizer3.tokenize(text)), tokenizer4.convert_tokens_to_ids(tokenizer4.tokenize(text))

([20741, 53214, 37830, 36669, 40556, 48372],
 [11153, 36718, 23175, 22148, 25606, 32483])

In [160]:
len(tokenizer4.get_vocab())

41608

In [1]:
save_path = 'donut-zh-model'
# pruner = BloomVocabularyPruner()
pruner = DonutVocabularyPruner()
# 裁剪
pruner.prune(MP3, MP4, save_path)

NameError: name 'DonutVocabularyPruner' is not defined