In [1]:
from transformers import BertTokenizer,BertForPreTraining
import torch
import json
def writeToJsonFile(path: str, obj, indent=2):
    """Serialize ``obj`` as JSON and write it to ``path`` (UTF-8, overwriting).

    Keys are emitted in sorted order and non-ASCII characters are written
    verbatim instead of being escaped.

    Args:
        path: Destination file path.
        obj: JSON-serializable object.
        indent: Indentation handed to the JSON encoder (default 2).
    """
    with open(path, mode="w", encoding="utf-8") as out:
        text = json.dumps(obj, indent=indent, sort_keys=True, ensure_ascii=False)
        out.write(text)
def readFromJsonFile(path: str):
    """Parse and return the JSON document stored at ``path`` (read as UTF-8)."""
    with open(path, mode="r", encoding="utf-8") as src:
        return json.load(src)
def saveVocab(path: str, obj, sortVocab=False):
    """Write the tokens in ``obj`` to ``path``, one token per line (UTF-8).

    Args:
        path: Destination file; overwritten if it already exists.
        obj: Iterable of string tokens. It is copied first, so the caller's
            collection is never reordered.
        sortVocab: When true, tokens are written in sorted order; otherwise
            in the iteration order of ``obj``.
    """
    tokens = sorted(obj) if sortVocab else list(obj)
    with open(path, mode="w", encoding="utf-8") as out:
        out.writelines(token + "\n" for token in tokens)

In [2]:
# Directory of the original model: config, vocab and weight files.
modelPath = '../bert/chinese_wobert_plus_L-12_H-768_A-12/'
# Output directory. Initially it only needs to contain the new vocab; the
# modified model and config files are saved here.
newModelPath = './new/'

# Load the original pre-training model and switch it to eval mode.
bert = BertForPreTraining.from_pretrained(modelPath)
bert.eval()

# One tokenizer per vocab, so tokens can be mapped between old and new ids.
oldTk = BertTokenizer.from_pretrained(modelPath)
newTk = BertTokenizer.from_pretrained(newModelPath)

### 一、记录找得到的，和找不到只能随机或平均初始化的token，生成新config文件

In [6]:
# Tokens that exist only in the old vocab (dropped) and only in the new vocab
# (added), each as a sorted list.
deled = sorted(oldTk.vocab.keys() - newTk.vocab.keys())
added = sorted(newTk.vocab.keys() - oldTk.vocab.keys())
print(f"原vocab删除{len(deled)}个，新vocab增加{len(added)}个")

# Record the vocab differences, one token per line. Both lists are already
# sorted, so saveVocab does not need to sort them again.
saveVocab(newModelPath + '原vocab已删除.json', deled)
saveVocab(newModelPath + '新vocab已增加.json', added)

# The new config is the old one with only vocab_size changed.
oldConfig = readFromJsonFile(modelPath + 'config.json')
oldConfig['vocab_size'] = newTk.vocab_size
writeToJsonFile(newModelPath + 'config.json', oldConfig)

# Report success only after every file has actually been written.
print(f"删除与新增token、新的config已写入路径：{newModelPath}")

原vocab删除31828个，新vocab增加3696个
删除与新增token、新的config已写入路径：./new/


### 二、共有的token直接复制权重

In [8]:
# Tokens present in both vocabularies.
common = set(oldTk.vocab).intersection(newTk.vocab)

In [9]:
# For every token that also exists in the old vocab, snapshot its embedding
# row and its MLM decoder-bias entry, keyed by the token string.
# The .clone() is deliberate: plain indexing yields views into the live
# parameter storage. If the later resize keeps that storage (e.g. the new
# vocab has the same size as the old one), restoring rows in place would
# overwrite values that are still waiting to be read from these dicts.
key2embedding, key2mlmBias = dict(), dict()
for key in common:
    idx = oldTk.vocab[key]  # row of this token in the old model
    key2embedding[key] = bert.bert.embeddings.word_embeddings.weight.data[idx].clone()
    key2mlmBias[key] = bert.cls.predictions.decoder.bias.data[idx].clone()

### 三、旧词表模型中找不到的token，尝试用旧词表细粒度拆分后取平均权重

In [46]:
# Tokens that are new in the new vocab: split each one with the old tokenizer
# and initialise it with the mean embedding / mean MLM bias of its pieces.
for key in added:
    # Pieces that the old vocab still cannot represent come back as [UNK].
    idx = oldTk.encode(key, add_special_tokens=False)
    if not idx:
        # The old tokenizer produced no pieces at all (presumably the token
        # consists only of characters it strips). A mean over zero rows would
        # be NaN, so fall back to the [UNK] row instead.
        idx = [oldTk.unk_token_id]
    # Indexing with a list selects one row per piece; mean(dim=0) averages them.
    key2embedding[key] = bert.bert.embeddings.word_embeddings.weight.data[idx].mean(dim=0)
    key2mlmBias[key] = bert.cls.predictions.decoder.bias.data[idx].mean(dim=0)

### 四、开始恢复和保存

In [47]:
# Resize the model's embedding layer and MLM output layer to the new vocab size.
# As the last expression of the cell, the returned embedding module is echoed
# as the cell output.
bert.resize_token_embeddings(newTk.vocab_size)

Embedding(21868, 768)

In [49]:
# Restore the weights from the two dicts: every token of the new vocab gets
# its recorded embedding row and MLM bias written at its new index.
for key, idx in newTk.vocab.items():
    # idx is the token's position in the resized model.
    bert.bert.embeddings.word_embeddings.weight.data[idx] = key2embedding[key]
    bert.cls.predictions.decoder.bias.data[idx] = key2mlmBias[key]

In [50]:
# Save the modified model: its state_dict is written as pytorch_model.bin in
# the new model directory.
torch.save(bert.state_dict(),newModelPath+'pytorch_model.bin')

### 五、检查权重是否符合预期

In [4]:
# Reload both checkpoints from disk for verification. Passing the path lets
# torch.load open and close the file itself, so no file handle is left open.
# map_location="cpu" places both state dicts on the CPU so their tensors can
# be compared with each other regardless of where they were saved from.
old = torch.load(modelPath + 'pytorch_model.bin', map_location="cpu")
new = torch.load(newModelPath + 'pytorch_model.bin', map_location="cpu")

In [31]:
# Every token shared by both vocabularies must carry exactly the same
# embedding row and MLM decoder bias in the saved checkpoint as in the
# original one, even though its index may have changed.
checkedParams = ('bert.embeddings.word_embeddings.weight', 'cls.predictions.decoder.bias')
for key in common:
    oldIdx = oldTk.vocab[key]
    newIdx = newTk.vocab[key]
    for paramName in checkedParams:
        notEqual = old[paramName][oldIdx] != new[paramName][newIdx]
        # Name the offending token and parameter so a failure is actionable.
        assert notEqual.sum().item() == 0, (
            f"{paramName} differs for token {key!r} (old id {oldIdx}, new id {newIdx})"
        )