In [None]:
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import BertTokenizer, DataCollatorForLanguageModeling

# Configuration for data preparation.
SEED = 42  # fixes the dataset shuffle, the train/test split and torch's global RNG
NUM_SAMPLES = 20000  # upper bound on how many filtered rows are kept
MIN_TITLE_LEN = 20  # titles shorter than this many characters are dropped
# Seed torch's global RNG, which the MLM collator's random masking and the
# DataLoader shuffling below draw from.
torch.manual_seed(SEED)

# Load the dataset; the single CSV file is read from the "train" split of the returned DatasetDict.
dataset = load_dataset(path="csv", data_files="data/news_categorize.csv")["train"]
# Filter out missing titles (empty CSV cells) and titles that are too short.
dataset = dataset.filter(lambda x: x["title"] is not None and len(x["title"]) >= MIN_TITLE_LEN)
# Shuffle, then subsample; cap at the filtered size so select() never indexes past the end.
dataset = dataset.shuffle(seed=SEED).select(range(min(NUM_SAMPLES, len(dataset))))
# 80/20 train/test split, reproducible via the seed.
dataset = dataset.train_test_split(test_size=0.2, shuffle=True, seed=SEED)

# Load the tokenizer.
tokenizer = BertTokenizer.from_pretrained("D:/Code/PythonProject/model/bert-base-chinese")


# Preprocessing function: tokenizes titles for masked language modelling.
def preprocess_func(example):
    """Tokenize the ``"title"`` field of ``example`` with the module-level tokenizer.

    Sequences are truncated to at most 64 tokens, and the tokenizer is asked
    to also return ``special_tokens_mask``.
    """
    titles = example["title"]
    encoding = tokenizer(titles, max_length=64, truncation=True, return_special_tokens_mask=True)
    return encoding


# Tokenize every split in batches.
encoded_dataset = dataset.map(preprocess_func, batched=True)
# Expose the model inputs as torch tensors. special_tokens_mask is kept so the
# MLM collator can use the precomputed mask (it removes it from the batch it
# returns) instead of recomputing it from the input ids.
encoded_dataset.set_format(
    "torch",
    columns=["input_ids", "token_type_ids", "attention_mask", "special_tokens_mask"],
)
# MLM data collator: pads each batch dynamically and generates masks on the fly.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True,
    mlm_probability=0.15,  # fraction of tokens selected for masking
)
# Instantiate the DataLoaders.
train_dataloader = DataLoader(
    encoded_dataset["train"],
    batch_size=16,
    shuffle=True,
    collate_fn=data_collator,
)
# Evaluation does not benefit from shuffling; keep a fixed order.
test_dataloader = DataLoader(
    encoded_dataset["test"],
    batch_size=64,
    shuffle=False,
    collate_fn=data_collator,
)
# Inspect a few masked examples; slicing also copes with batches smaller than 4.
example_batch = next(iter(train_dataloader))
for masked_ids in example_batch["input_ids"][:4]:
    masked_tokens = tokenizer.convert_ids_to_tokens(masked_ids)
    print("".join(masked_tokens))

In [None]:
import torch
import torch.nn as nn
from transformers import BertModel


# Custom model: BERT encoder followed by a linear per-token prediction head.
class Model(nn.Module):
    """BERT encoder (frozen by default) plus a trainable linear head applied to every token.

    Args:
        output_size: number of logits produced per token (the vocabulary size
            when used as an MLM head).
        model_path: path or identifier passed to ``BertModel.from_pretrained``.
        freeze_bert: if True (default), all BERT parameters get
            ``requires_grad=False`` so only the linear head is trained.

    NOTE: freezing does not disable BERT's dropout; ``model.train()`` still puts
    the encoder in training mode.
    """

    def __init__(
        self,
        output_size,
        model_path="D:/Code/PythonProject/model/bert-base-chinese",
        freeze_bert=True,
    ):
        super().__init__()
        # Load the pretrained BERT encoder.
        self.bert = BertModel.from_pretrained(model_path)
        # Size the head from the checkpoint's config instead of hard-coding 768.
        self.linear = nn.Linear(self.bert.config.hidden_size, output_size)

        if freeze_bert:
            for param in self.bert.parameters():
                param.requires_grad = False

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Return per-token logits: ``linear(bert(...).last_hidden_state)``."""
        # Keyword arguments guard against BertModel.forward's positional order.
        output = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        return self.linear(output.last_hidden_state)


model = Model(output_size=tokenizer.vocab_size)  # one logit per vocabulary entry

In [None]:
def train(model, train_dataloader, test_dataloader, lr, num_epoch, device):
    model.to(device)
    optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
    # CrossEntropyLoss 默认 ignore_index: int = -100，所以不用再设置 ignore_index 参数
    criterion = nn.CrossEntropyLoss()
    for epoch in range(num_epoch):
        model.train()
        for batch_count, batch in enumerate(train_dataloader):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            token_type_ids = batch["token_type_ids"].to(device)
            labels = batch["labels"].to(device)
            # 前向传播
            output = model(input_ids, attention_mask, token_type_ids)
            # 反向传播
            loss = criterion(output.view(-1, tokenizer.vocab_size), labels.view(-1))
            optimizer.zero_grad()
            loss.backward()
            #  梯度裁剪
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            if batch_count % 10 == 0:
                # labels != -100 为被掩蔽的位置
                mask = labels != -100
                preds = torch.argmax(output, dim=-1)
                accuracy = (preds[mask] == labels[mask]).sum().item() / mask.sum().item()
                print(f"\repoch:{epoch:0>2}[{'='*(int((batch_count+1) / len(train_dataloader) * 50)):<50}]", end="")
                print(f" loss:{loss}, accuracy={accuracy}")
        # 模型评估
        model.eval()
        accuracy_accumulate = 0
        sample_count = 0
        for batch_count, batch in enumerate(test_dataloader):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            token_type_ids = batch["token_type_ids"].to(device)
            labels = batch["labels"].to(device)
            # 前向传播
            with torch.no_grad():
                output = model(input_ids, attention_mask, token_type_ids)
            # 计算准确率
            mask = labels != -100
            preds = torch.argmax(output, dim=-1)
            this_accuracy = (preds[mask] == labels[mask]).sum().item()
            accuracy_accumulate += this_accuracy
            sample_count += mask.sum().item()
            print(f"\r评估：epoch:{epoch:0>2}[{'='*(int((batch_count+1) / len(test_dataloader) * 50)):<50}]", end="")
            print(f" accuracy={this_accuracy/mask.sum().item()}", end="")
        print(f"\naccuracy: {accuracy_accumulate/sample_count}")


# Training configuration: use the first GPU when CUDA is available, otherwise the CPU.
lr = 5e-5
num_epoch = 2
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
train(model, train_dataloader, test_dataloader, lr=lr, num_epoch=num_epoch, device=device)