什么时候需要重新训练 tokenizer？
- 如果目标语言没有可用的语言模型
- 如果你的语料与现有语言模型训练时所用的语料有很大的不同

## 1、准备语料库

In [3]:
from datasets import load_dataset

# Load the "python" configuration of the "code_search_net" dataset and cache
# the downloaded files under ./data.
# This can take a few minutes to load, so grab a coffee or tea while you wait!
raw_datasets = load_dataset(path="code_search_net", 
                            name="python", 
                            cache_dir="./data")

Downloading data:   0%|          | 0.00/941M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/412178 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/22176 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/23107 [00:00<?, ? examples/s]

In [4]:
raw_datasets

DatasetDict({
    train: Dataset({
        features: ['repository_name', 'func_path_in_repository', 'func_name', 'whole_func_string', 'language', 'func_code_string', 'func_code_tokens', 'func_documentation_string', 'func_documentation_tokens', 'split_name', 'func_code_url'],
        num_rows: 412178
    })
    test: Dataset({
        features: ['repository_name', 'func_path_in_repository', 'func_name', 'whole_func_string', 'language', 'func_code_string', 'func_code_tokens', 'func_documentation_string', 'func_documentation_tokens', 'split_name', 'func_code_url'],
        num_rows: 22176
    })
    validation: Dataset({
        features: ['repository_name', 'func_path_in_repository', 'func_name', 'whole_func_string', 'language', 'func_code_string', 'func_code_tokens', 'func_documentation_string', 'func_documentation_tokens', 'split_name', 'func_code_url'],
        num_rows: 23107
    })
})

此处仅使用whole_func_string列训练tokenizer

In [5]:
print(raw_datasets["train"][123456]["whole_func_string"])

def updater():
    """Update the current installation.

    git clones the latest version and merges it with the current directory.
    """
    print('%s Checking for updates' % run)
    # Changes must be separated by ;
    changes = '''major bug fixes;removed ninja mode;dropped python < 3.2 support;fixed unicode output;proxy support;more intels'''
    latest_commit = requester('https://raw.githubusercontent.com/s0md3v/Photon/master/core/updater.py', host='raw.githubusercontent.com')
    # Just a hack to see if a new version is available
    if changes not in latest_commit:
        changelog = re.search(r"changes = '''(.*?)'''", latest_commit)
        # Splitting the changes to form a list
        changelog = changelog.group(1).split(';')
        print('%s A new version of Photon is available.' % good)
        print('%s Changes:' % info)
        for change in changelog: # print changes
            print('%s>%s %s' % (green, end, change))

        current_path = os.getcwd().split('/') #

使用生成器避免所有内容都加载到内存中

In [6]:
# Generator-expression variant: batches are produced lazily, one at a time.
def get_training_corpus(batch_size: int = 1000):
    """Return a lazy iterator over batches of the training split.

    Each item is the ``whole_func_string`` column of up to ``batch_size``
    consecutive rows of ``raw_datasets["train"]``, so the whole corpus is
    never materialized in memory at once.

    Args:
        batch_size: Number of rows per batch. Must be a positive integer.
            Defaults to 1000.

    Returns:
        A generator object. It can be iterated only once; call this
        function again to get a fresh iterator.
    """
    # Look the split up once instead of once per batch.
    dataset = raw_datasets["train"]
    return (
        dataset[start: start + batch_size]["whole_func_string"]
        for start in range(0, len(dataset), batch_size)
    )


# A generator object can be iterated only once; call get_training_corpus()
# again whenever a fresh iterator over the corpus is needed.
training_corpus = get_training_corpus()

In [7]:
def get_training_corpus(batch_size: int = 1000):
    """Yield the training split batch by batch.

    Generator-function variant: each ``yield`` produces the
    ``whole_func_string`` column of up to ``batch_size`` consecutive rows of
    ``raw_datasets["train"]``, so only one batch is held at a time.

    Args:
        batch_size: Number of rows per batch. Must be a positive integer.
            Defaults to 1000.

    Yields:
        The ``whole_func_string`` values of one batch of rows.
    """
    dataset = raw_datasets["train"]
    for start_idx in range(0, len(dataset), batch_size):
        # Slice a batch of rows, then keep only the column used for training.
        samples = dataset[start_idx: start_idx + batch_size]
        yield samples["whole_func_string"]

## 2、训练一个新的标记器
从一个旧的标记器开始，然后使用新的语料库训练一个新的标记器

In [8]:
from transformers import AutoTokenizer

# Load the existing tokenizer that the new one will be derived from.
# NOTE(review): this is a machine-specific local path; presumably a local copy
# of GPT-2 — confirm, and adjust the path when running elsewhere.
old_tokenizer = AutoTokenizer.from_pretrained("/Volumes/WD_BLACK/models/gpt2")

先看看旧的标记器如何对一段 Python 代码进行分词

In [9]:
example = '''def add_numbers(a, b):
    """Add the two numbers `a` and `b`."""
    return a + b'''

tokens = old_tokenizer.tokenize(example)
tokens

['def',
 'Ġadd',
 '_',
 'n',
 'umbers',
 '(',
 'a',
 ',',
 'Ġb',
 '):',
 'Ċ',
 'Ġ',
 'Ġ',
 'Ġ',
 'Ġ"""',
 'Add',
 'Ġthe',
 'Ġtwo',
 'Ġnumbers',
 'Ġ`',
 'a',
 '`',
 'Ġand',
 'Ġ`',
 'b',
 '`',
 '."',
 '""',
 'Ċ',
 'Ġ',
 'Ġ',
 'Ġ',
 'Ġreturn',
 'Ġa',
 'Ġ+',
 'Ġb']

In [10]:
tokenizer = old_tokenizer.train_new_from_iterator(training_corpus, 52000)  # 52000 is the vocabulary size






In [11]:
tokens = tokenizer.tokenize(example)
tokens

['def',
 'Ġadd',
 '_',
 'numbers',
 '(',
 'a',
 ',',
 'Ġb',
 '):',
 'ĊĠĠĠ',
 'Ġ"""',
 'Add',
 'Ġthe',
 'Ġtwo',
 'Ġnumbers',
 'Ġ`',
 'a',
 '`',
 'Ġand',
 'Ġ`',
 'b',
 '`."""',
 'ĊĠĠĠ',
 'Ġreturn',
 'Ġa',
 'Ġ+',
 'Ġb']

In [12]:
print(len(tokens))
print(len(old_tokenizer.tokenize(example)))

27
36


In [13]:
example = """class LinearLayer():
    def __init__(self, input_size, output_size):
        self.weight = torch.randn(input_size, output_size)
        self.bias = torch.zeros(output_size)

    def __call__(self, x):
        return x @ self.weights + self.bias
    """
tokenizer.tokenize(example)

['class',
 'ĠLinear',
 'Layer',
 '():',
 'ĊĠĠĠ',
 'Ġdef',
 'Ġ__',
 'init',
 '__(',
 'self',
 ',',
 'Ġinput',
 '_',
 'size',
 ',',
 'Ġoutput',
 '_',
 'size',
 '):',
 'ĊĠĠĠĠĠĠĠ',
 'Ġself',
 '.',
 'weight',
 'Ġ=',
 'Ġtorch',
 '.',
 'randn',
 '(',
 'input',
 '_',
 'size',
 ',',
 'Ġoutput',
 '_',
 'size',
 ')',
 'ĊĠĠĠĠĠĠĠ',
 'Ġself',
 '.',
 'bias',
 'Ġ=',
 'Ġtorch',
 '.',
 'zeros',
 '(',
 'output',
 '_',
 'size',
 ')',
 'ĊĊĠĠĠ',
 'Ġdef',
 'Ġ__',
 'call',
 '__(',
 'self',
 ',',
 'Ġx',
 '):',
 'ĊĠĠĠĠĠĠĠ',
 'Ġreturn',
 'Ġx',
 'Ġ@',
 'Ġself',
 '.',
 'weights',
 'Ġ+',
 'Ġself',
 '.',
 'bias',
 'ĊĠĠĠĠ']

## 3、保存标记器

In [14]:
# Save the trained tokenizer to a local directory.
# NOTE(review): the recorded output below lists files under
# ./data/tokenizer/..., not ./model/tokenizer/... — the path was presumably
# edited after the cell last ran; confirm which location is intended.
tokenizer.save_pretrained("./model/tokenizer/code-search-net-tokenizer")

('./data/tokenizer/code-search-net-tokenizer/tokenizer_config.json',
 './data/tokenizer/code-search-net-tokenizer/special_tokens_map.json',
 './data/tokenizer/code-search-net-tokenizer/vocab.json',
 './data/tokenizer/code-search-net-tokenizer/merges.txt',
 './data/tokenizer/code-search-net-tokenizer/added_tokens.json',
 './data/tokenizer/code-search-net-tokenizer/tokenizer.json')

登录huggingface账号上传

In [None]:
from huggingface_hub import notebook_login

# Log in to the Hugging Face account from inside the notebook.
notebook_login()

# Alternatively, log in from the command line:
# huggingface-cli login

In [23]:
tokenizer.push_to_hub("code-search-net-tokenizer")

CommitInfo(commit_url='https://huggingface.co/snowfly/code-search-net-tokenizer/commit/3d69cdc7e07b84aeddd710835b7ed6fb7eae149d', commit_message='Upload tokenizer', commit_description='', oid='3d69cdc7e07b84aeddd710835b7ed6fb7eae149d', pr_url=None, pr_revision=None, pr_num=None)

使用上传的tokenizer

In [None]:
tokenizer = AutoTokenizer.from_pretrained("snowfly/code-search-net-tokenizer")