### 准备数据

In [None]:
import pandas as pd
# Read the source .parquet file (replace the placeholder path with your own).
parquet_file = '/path/file_name.parquet'
df = pd.read_parquet(parquet_file)

# Take the first 10,000 rows of the `text` column; a small sample is enough
# for a quick tokenizer-training test. `.iloc` makes the slice explicitly positional.
text_col = df['text'].iloc[:10000]

# Plain-text corpus that spm_train reads (one sentence per line).
txt_file = '/path/file_name.txt'

# Append mode: re-running this cell appends the same rows again, so delete the
# file first if you want a clean corpus. UTF-8 is required for Chinese text.
with open(txt_file, 'a', encoding='utf-8') as file:
    for text in text_col.dropna().astype(str):
        # Flatten embedded line breaks so every record stays on one line; a CSV
        # writer would instead wrap such records in quotes (and double inner
        # quotes), polluting the training text.
        file.write(text.replace('\r', ' ').replace('\n', ' ') + '\n')
print(f'前1万条content数据已写入到 {txt_file}')

### 开始训练

In [None]:
pip install sentencepiece

# Run in a terminal (or in a %%bash cell when inside Jupyter).
# Trains a 10k-piece BPE model on the corpus written above; with
# --model_prefix bpe_test the outputs are bpe_test.model / bpe_test.vocab in the
# current working directory.
# spm_train prints its progress log to stderr, so stderr is redirected too;
# otherwise bpe_test.log would stay empty.
# The job runs in the background (&): check bpe_test.log and wait for training
# to finish before loading bpe_test.model in the next cell.
nohup spm_train --input '/path/file_name.txt' \
--input_format text \
--model_prefix bpe_test \
--model_type bpe \
--vocab_size 10000 \
--character_coverage 0.9995 \
--num_threads 32 \
--split_digits True \
--byte_fallback True \
--max_sentence_length 24000 > bpe_test.log 2>&1 &

### 开始使用

In [None]:
import sentencepiece as spm

# Load the freshly trained BPE model and show how it splits one English and
# one Chinese sample sentence (pieces first, then the piece count).
sp_bpe = spm.SentencePieceProcessor()
sp_bpe.load('bpe_test.model')

print('*** BPE ***')
for sample in (
    'The excellence of a translation can only be judged by noting',
    '麒麟，是中国古代神话中的一种瑞兽',
):
    pieces = sp_bpe.encode_as_pieces(sample)
    print(pieces)
    print(len(pieces))

### 合并LLaMA词表

In [None]:
import os
# Select protobuf's pure-Python backend. The variable is read when protobuf is
# first imported, so it must stay above the imports below.
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"]="python"
from transformers import LlamaTokenizer
from sentencepiece import sentencepiece_model_pb2 as sp_pb2_model
import sentencepiece as spm

# Locations: point these at your own files.
llama_tokenizer_dir = "/path/llama-2-7b-hf"  # original LLaMA checkpoint (HF format)
chinese_sp_model_file = "/path/bpe_test.model"  # BPE model trained above


def _to_model_proto(serialized):
    """Parse a serialized SentencePiece model into an editable ModelProto."""
    proto = sp_pb2_model.ModelProto()
    proto.ParseFromString(serialized)
    return proto


# Load both tokenizers, then expose their underlying protos for editing.
llama_tokenizer = LlamaTokenizer.from_pretrained(llama_tokenizer_dir)
chinese_sp_model = spm.SentencePieceProcessor()
chinese_sp_model.Load(chinese_sp_model_file)
llama_spm = _to_model_proto(llama_tokenizer.sp_model.serialized_model_proto())
chinese_spm = _to_model_proto(chinese_sp_model.serialized_model_proto())

# Report both vocabulary sizes, then LLaMA's special tokens.
print(len(llama_tokenizer), len(chinese_sp_model))
print(llama_tokenizer.all_special_tokens)
print(llama_tokenizer.all_special_ids)
print(llama_tokenizer.special_tokens_map)

# 结果:
# 32000 10000
# ['<s>', '</s>', '<unk>']
# [1, 2, 0]
# {'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>'}

# Append to the LLaMA vocabulary every piece of the Chinese BPE model that it
# does not already contain. You can also append your own words or
# domain-specific terms here the same way.
llama_spm_tokens_set = {p.piece for p in llama_spm.pieces}
print(len(llama_spm_tokens_set))
print(f"Before:{len(llama_spm_tokens_set)}")
for p in chinese_spm.pieces:
    piece = p.piece
    if piece in llama_spm_tokens_set:
        continue
    # Instantiate the nested message class directly instead of building a
    # throwaway ModelProto for every piece.
    new_p = sp_pb2_model.ModelProto.SentencePiece()
    new_p.piece = piece
    # NOTE(review): every new piece gets the same flat score 0; for BPE models
    # scores affect merge priority, so confirm this is the intended ranking.
    new_p.score = 0
    llama_spm.pieces.append(new_p)
    # Record the addition so a repeated piece can never be appended twice
    # (sentencepiece refuses to load a model with duplicate pieces).
    llama_spm_tokens_set.add(piece)
print(f"New model pieces: {len(llama_spm.pieces)}")

# 结果:
# 32000
# Before:32000
# New model pieces: 40114
# 中文词表原有1万个token，与LLaMA词表去重后，实际新增了8114个。

# Save the merged SentencePiece model, then wrap it as a Hugging Face
# LlamaTokenizer and save that as well.
output_sp_dir = 'merged_tokenizer_sp_test'
output_hf_dir = 'merged_tokenizer_hf_test'
os.makedirs(output_sp_dir, exist_ok=True)
# Build the path once with os.path.join instead of string concatenation.
merged_sp_model_file = os.path.join(output_sp_dir, 'chinese_llama.model')
with open(merged_sp_model_file, 'wb') as f:
    f.write(llama_spm.SerializeToString())
tokenizer = LlamaTokenizer(vocab_file=merged_sp_model_file)

tokenizer.save_pretrained(output_hf_dir)
print(f"Chinese-LLaMA tokenizer has been saved to {output_hf_dir}")

# Compare how the original and the merged tokenizer split the same inputs.
llama_tokenizer = LlamaTokenizer.from_pretrained(llama_tokenizer_dir)
chinese_llama_tokenizer = LlamaTokenizer.from_pretrained(output_hf_dir)

for text in (
    "The excellence of a translation can only be judged by noting",
    "麒麟，是中国古代神话中的一种瑞兽",
):
    # Tokenize once per tokenizer and reuse the result for both prints.
    llama_tokens = llama_tokenizer.tokenize(text)
    chinese_tokens = chinese_llama_tokenizer.tokenize(text)
    print("Test text:\n", text)
    print(f"Tokenized by LLaMA tokenizer:{llama_tokens}")
    print(f"Tokenized length by LLaMA tokenizer:{len(llama_tokens)}")
    print(f"Tokenized by chinese_llama tokenizer:{chinese_tokens}")
    # Label unified with the other chinese_llama lines (was "LLaMA-extent-1").
    print(f"Tokenized length by chinese_llama tokenizer:{len(chinese_tokens)}")
