In [None]:
from transformers import AutoTokenizer, BertConfig, TrainingArguments, Trainer
from CustomBertModel import fixed_predict, DataCollatorForMultiMask
from MoELayer import BertWwmMoE
from datasets import Dataset
from ltp import LTP

# https://github.com/huggingface/transformers/blob/main/examples/research_projects/mlm_wwm/run_chinese_ref.py
from bert.run_chinese_ref import prepare_ref

import random
import torch


In [None]:
# Seed Python's RNG; it drives the random.shuffle used for the train/eval split.
random.seed(42)
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# LTP instance that is later handed to prepare_ref.
ltp = LTP()

# Tokenizer, config and weights are all loaded from the same checkpoint id.
MODEL_NAME = "Midsummra/CNMBert-MoE"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
config = BertConfig.from_pretrained(MODEL_NAME)
# Move the model to the detected device instead of a hard-coded 'cuda',
# so this cell also runs on machines without a GPU.
model = BertWwmMoE.from_pretrained(MODEL_NAME, config=config).to(device)

In [None]:
# Data preprocessing

# Number of shuffled samples held out for evaluation; the rest is used for training.
EVAL_SIZE = 1000

# Read the corpus one sample per line. The newline is removed before
# de-duplication so that a final line without a trailing newline still
# collapses with an identical earlier line.
with open('./webtext/train.csv', mode='r', encoding='utf-8') as file:
    unique_lines = {line.replace('\n', '') for line in file}

# Iteration order of a set of strings changes between interpreter runs
# (string hash randomization), so sort before shuffling; otherwise the seeded
# shuffle would still yield a different train/eval split on every run.
text = sorted(unique_lines)
random.shuffle(text)

# Wrap the splits in datasets.Dataset objects: the tokenization step calls
# .map() on them, which a plain dict does not provide.
train_data = Dataset.from_dict({'text': text[EVAL_SIZE:]})
eval_data = Dataset.from_dict({'text': text[:EVAL_SIZE]})


In [None]:
def tokenize_func(dataset):
    """Turn a batch of raw sentences into model input features.

    Args:
        dataset: Batch mapping whose 'text' entry holds the raw sentences.

    Returns:
        dict with three entries: 'input_ids' and 'attention_mask' taken from
        the tokenizer output, and 'chinese_ref' as returned by prepare_ref.
    """
    sentences = dataset['text']
    # Every sentence is truncated / padded to exactly 64 tokens.
    encoded = tokenizer(
        sentences,
        truncation=True,
        padding='max_length',
        max_length=64,
        return_tensors='pt',
    )
    # Reference data computed by prepare_ref from the raw sentences, using LTP and the tokenizer.
    word_refs = prepare_ref(sentences, ltp, tokenizer)
    return {
        'input_ids': encoded['input_ids'],
        'chinese_ref': word_refs,
        'attention_mask': encoded['attention_mask'],
    }

# Batch collator configured for masked-language-model training with a 15% masking probability.
data_collator = DataCollatorForMultiMask(
    tokenizer,
    mlm=True,
    mlm_probability=0.15,
)

# Tokenize both splits batch-wise; the raw "text" column is dropped afterwards.
# train_data / eval_data must expose a .map() method for this to work.
train_dataset = train_data.map(
    tokenize_func,
    batched=True,
    remove_columns=["text"],
)
eval_dataset = eval_data.map(
    tokenize_func,
    batched=True,
    remove_columns=["text"],
)


In [None]:
# Training

torch.manual_seed(42)

model = model.to(device)
for name, param in model.named_parameters():
    # Re-store every weight as a contiguous tensor
    # (presumably so checkpoint saving does not trip over non-contiguous tensors - TODO confirm).
    param.data = param.data.contiguous()
    # List the parameters that will actually be updated.
    if param.requires_grad:
        print(f"Trainable layer: {name}")


training_args = TrainingArguments(
    # Output locations
    output_dir='./model/checkpoints/',
    logging_dir='./model/logs/',
    # Optimization
    num_train_epochs=20,
    per_device_train_batch_size=256,
    learning_rate=2e-5,  # recommended learning rate: 1e-5 to 2e-5
    weight_decay=1e-5,
    warmup_ratio=0.05,
    max_grad_norm=1.0,
    # Evaluation: every 500 steps
    eval_strategy='steps',
    eval_steps=500,
    # Logging: every 100 steps, plus the very first step
    logging_steps=100,
    logging_first_step=True,
    # Checkpointing: every 500 steps, with save_total_limit set to 4
    save_strategy='steps',
    save_steps=500,
    save_total_limit=4,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)


In [None]:
# Run the fine-tuning loop configured in the Trainer above.
trainer.train()
# Save the model under ./model/cnmbert-ft once training has finished.
trainer.save_model('./model/cnmbert-ft')
# Run a final evaluation after training and print the returned results.
eval_results = trainer.evaluate()
print(f"Evaluation cnmbert-ft: {eval_results}")