In [1]:
from transformers import AutoTokenizer, BertConfig, TrainingArguments, Trainer
from bert.CustomBertModel import DataCollatorForMultiMask
from MoELayer import BertWwmMoE
from datasets import Dataset
from ltp import LTP

# https://github.com/huggingface/transformers/blob/main/examples/research_projects/mlm_wwm/run_chinese_ref.py
from bert.run_chinese_ref import prepare_ref

import random
import torch


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
random.seed(123)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ltp = LTP().to(device=device)

tokenizer = AutoTokenizer.from_pretrained("Midsummra/CNMBert-MoE")
config = BertConfig.from_pretrained('Midsummra/CNMBert-MoE')
model = BertWwmMoE.from_pretrained('Midsummra/CNMBert-MoE', config=config).to('cuda')

  state_dict = torch.load(model_file, map_location=map_location)
BertForMaskedLM has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.
Some weights of the model checkpoint at Midsummra/CNMBert-MoE were not used when initializing BertWwmMoE: ['bert.encoder.layer.0.intermediate.dense.sparse_moe.bias', 'bert.encoder.layer.10.intermediate.dense.sparse_moe.bias', 'bert.encoder.layer.12.intermediate.dense.sparse_moe.bias', 'bert.en

In [3]:
# 数据预处理

text = set()
bilibili = set()
with open('../webtext/train.csv', mode='r', encoding='utf-8') as file:
    line = file.readline()
    while True:
        if not line:
            break
        text.add(line)
        line = file.readline()
with open('../webtext/bilibili.csv', mode='r', encoding='utf-8') as file:
    line = file.readline()
    while True:
        if not line:
            break
        bilibili.add(line)
        line = file.readline()

text = [t.replace('\n', '') for t in list(text)]
bilibili = [t.replace('\n', '') for t in list(bilibili)]
random.shuffle(text)
random.shuffle(bilibili)

train_data = {'text': text[:750000] + bilibili[:750000]}
eval_data = {'text': text[len(text) - 20000:] + bilibili[len(bilibili) - 20000:]}

train_data = Dataset.from_dict(train_data)
eval_data = Dataset.from_dict(eval_data)

In [4]:
def tokenize_func(dataset):
    tokens = tokenizer(dataset['text'],
                       max_length=64,
                       padding='max_length',
                       truncation=True,
                       return_tensors='pt'
                       )
    ref = prepare_ref(dataset['text'], ltp, tokenizer)
    features = {'input_ids': tokens['input_ids'], 'chinese_ref': ref, 'attention_mask': tokens['attention_mask']}
    return features

data_collator = DataCollatorForMultiMask(tokenizer,
                                             mlm_probability=0.15,
                                             mlm=True,
                                             pad_to_multiple_of=64)

train_dataset = train_data.map(tokenize_func, batched=True, remove_columns=["text"])
eval_dataset = eval_data.map(tokenize_func, batched=True, remove_columns=["text"])


Map: 100%|██████████| 1500000/1500000 [17:24<00:00, 1435.58 examples/s]
Map: 100%|██████████| 40000/40000 [00:26<00:00, 1508.86 examples/s]


In [5]:
for val in eval_dataset.__iter__():
    if len(val['input_ids']) == 64:
        print(val)

{'input_ids': [101, 2792, 809, 1920, 2157, 833, 1355, 4385, 8024, 2769, 812, 4495, 3833, 704, 6432, 671, 702, 782, 758, 7509, 679, 1059, 4638, 3198, 952, 8024, 1071, 2141, 1920, 1914, 3221, 2501, 2159, 671, 702, 782, 4638, 4028, 1548, 8024, 5445, 7478, 727, 2697, 8024, 1316, 2575, 4526, 749, 6821, 702, 782, 1377, 5543, 5688, 1941, 2697, 6820, 679, 7231, 102, 0, 0, 0], 'chinese_ref': [2, 4, 7, 10, 12, 16, 19, 20, 21, 24, 27, 29, 32, 34, 38, 43, 47, 50, 53, 55, 56, 59], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]}
{'input_ids': [101, 2791, 2094, 1297, 749, 8024, 6756, 6158, 965, 712, 2458, 6624, 749, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'chinese_ref': [2, 9], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1,

IOPub data rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_data_rate_limit`.

Current values:
ServerApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
ServerApp.rate_limit_window=3.0 (secs)



In [6]:
for name, param in model.named_parameters():
    if name.startswith('bert.embeddings.'):
        param.requires_grad = True
    else:
        param.requires_grad = False
    if param.requires_grad:
        print(name)

bert.embeddings.word_embeddings.weight
bert.embeddings.position_embeddings.weight
bert.embeddings.token_type_embeddings.weight
bert.embeddings.LayerNorm.weight
bert.embeddings.LayerNorm.bias


In [7]:
# 训练

torch.manual_seed(42)

model = model.to(device)
for name, param in model.named_parameters():
    if param.requires_grad:
        print(f"Trainable layer: {name}")
    param.data = param.data.contiguous()


training_args = TrainingArguments(
    output_dir='./model/checkpoints/',
    num_train_epochs=20,
    per_device_train_batch_size=128,
    eval_strategy='steps',
    eval_steps=500,
    learning_rate=1e-5,  #学习率建议给1e-5~2e-5
    weight_decay=1e-5,
    logging_dir='./model/logs/',
    logging_steps=100,
    logging_first_step=True,
    save_strategy='steps',
    save_steps=100,
    save_total_limit=4,
    max_grad_norm=1.0,
    warmup_ratio=1 / 20,
    disable_tqdm=True
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
)


Trainable layer: bert.embeddings.word_embeddings.weight
Trainable layer: bert.embeddings.position_embeddings.weight
Trainable layer: bert.embeddings.token_type_embeddings.weight
Trainable layer: bert.embeddings.LayerNorm.weight
Trainable layer: bert.embeddings.LayerNorm.bias


In [8]:
trainer.train()
trainer.save_model('./model/cnmbert-ft')
eval_results = trainer.evaluate()
print(f"Evaluation cnmbert-ft: {eval_results}")

  attn_output = torch.nn.functional.scaled_dot_product_attention(


{'loss': 0.8008, 'grad_norm': 1.6644216775894165, 'learning_rate': 8.533151292772422e-10, 'epoch': 8.53315129277242e-05}


KeyboardInterrupt: 