In [None]:
from datasets import load_dataset
from transformers import BertTokenizer, DataCollatorForLanguageModeling
from torch.utils.data import DataLoader

# Load the news-headline CSV; the "csv" builder exposes all rows as the "train" split.
dataset = load_dataset(path="csv", data_files="data/news_categorize.csv")["train"]
# Fixed seed so the subsample and the train/test split are reproducible across runs.
SEED = 42
# Size of the random working subset; min() keeps this valid when the CSV is smaller.
NUM_SAMPLES = 50000
dataset = dataset.shuffle(seed=SEED).select(range(min(NUM_SAMPLES, len(dataset))))
dataset = dataset.train_test_split(test_size=0.2, shuffle=True, seed=SEED)

# Load the tokenizer that matches the bert-base-chinese checkpoint used by the model.
tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")


# Preprocessing applied to the dataset via dataset.map(..., batched=True).
def preprocess_func(example):
    """Tokenize the ``title`` field of a batch of examples.

    Because the function is mapped with ``batched=True``, ``example["title"]``
    is a list of strings. Titles are truncated to 64 tokens, and the
    special-tokens mask is requested alongside the ids and attention mask.
    Uses the module-level ``tokenizer``.
    """
    options = {
        "truncation": True,
        "max_length": 64,
        "return_special_tokens_mask": True,
    }
    titles = example["title"]
    return tokenizer(titles, **options)


encoded_dataset = dataset.map(preprocess_func, batched=True)
# Keep special_tokens_mask in the torch-formatted columns. preprocess_func asks for it,
# and the MLM collator uses it (when present) to avoid masking [CLS]/[SEP].
# Listing columns also drops the raw text columns, which the collator cannot pad.
encoded_dataset.set_format(
    "torch",
    columns=["input_ids", "token_type_ids", "attention_mask", "special_tokens_mask"],
)
# MLM collator: pads each batch and draws a fresh random mask every time it is called.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True,
    mlm_probability=0.15,  # fraction of tokens selected for masking
)
# Build the DataLoaders.
train_batch_size = 32
test_batch_size = 64
train_dataloader = DataLoader(
    encoded_dataset["train"],
    batch_size=train_batch_size,
    shuffle=True,
    collate_fn=data_collator,
)
# Evaluation order does not matter, so the test set is not shuffled.
# Note: masks are still re-drawn by the collator on each pass over the test set.
test_dataloader = DataLoader(
    encoded_dataset["test"],
    batch_size=test_batch_size,
    shuffle=False,
    collate_fn=data_collator,
)
# Preview up to five masked training examples, hiding [PAD] so only real tokens show.
example_batch = next(iter(train_dataloader))
for example_ids in example_batch["input_ids"][:5]:
    example_tokens = tokenizer.convert_ids_to_tokens(example_ids.tolist())
    print("".join(tok for tok in example_tokens if tok != tokenizer.pad_token))

In [None]:
import torch
import torch.nn as nn
from transformers import BertModel

# Pretrained bert-base-chinese encoder, wrapped by the Model class defined below.
bert_model = BertModel.from_pretrained(
    pretrained_model_name_or_path="bert-base-chinese",
)


# Custom model: frozen BERT encoder + trainable linear projection.
class Model(nn.Module):
    """A trainable linear head on top of a frozen BERT encoder.

    The encoder runs under ``torch.no_grad()``, so only ``self.linear``
    receives gradients. The head maps every position's hidden state to
    ``output_size`` scores. The notebook uses one score per vocabulary entry.

    Args:
        output_size: Number of output scores per position.
        bert: Encoder to wrap. Defaults to the module-level ``bert_model``, so
            ``Model(vocab_size)`` keeps working. It must expose
            ``config.hidden_size`` and return an object with ``last_hidden_state``.
    """

    def __init__(self, output_size, bert=None):
        super().__init__()
        self.bert = bert_model if bert is None else bert
        # Take the input width from the encoder config instead of hard-coding 768.
        self.linear = nn.Linear(self.bert.config.hidden_size, output_size)

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Return per-position scores from the linear head.

        The result has ``last_hidden_state``'s leading dimensions (presumably
        ``(batch, seq_len)``) plus a trailing ``output_size`` dimension.
        """
        # Freeze the encoder: no graph is built through BERT, only through the head.
        with torch.no_grad():
            output = self.bert(input_ids, attention_mask, token_type_ids)
        return self.linear(output.last_hidden_state)


model = Model(output_size=tokenizer.vocab_size)  # one score per vocabulary entry

In [None]:
def train(model, train_dataloader, test_dataloader, lr, num_epoch, device):
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    for epoch in range(num_epoch):
        model.train()
        for batch_count, batch in enumerate(train_dataloader):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            token_type_ids = batch["token_type_ids"].to(device)
            labels = batch["labels"].to(device)
            # 前向传播
            output = model(input_ids, attention_mask, token_type_ids)
            # 反向传播
            loss = criterion(output.view(-1, tokenizer.vocab_size), labels.view(-1))
            optimizer.zero_grad()
            loss.backward()
            #  梯度裁剪
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            if batch_count % 10 == 0:
                preds = torch.argmax(output, dim=-1)
                # labels != -100 为被掩蔽的位置
                mask = labels != -100
                accuracy = (preds[mask] == labels[mask]).sum().item() / mask.sum().item()
                print(f"\repoch:{epoch:0>2}[{'='*(int((batch_count+1) / len(train_dataloader) * 50)):<50}]", end="")
                print(f" loss:{loss}, accuracy={accuracy}")
        # 模型评估
        model.eval()
        accuracy_accumulate = 0
        sample_count = 0
        for batch_count, batch in enumerate(test_dataloader):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            token_type_ids = batch["token_type_ids"].to(device)
            labels = batch["labels"].to(device)
            # 前向传播
            with torch.no_grad():
                output = model(input_ids, attention_mask, token_type_ids)
            # 计算准确率
            preds = torch.argmax(output, dim=-1)
            mask = labels != -100
            this_accuracy = (preds[mask] == labels[mask]).sum().item()
            accuracy_accumulate += this_accuracy
            sample_count += mask.sum().item()
            print(f"\r评估：epoch:{epoch:0>2}[{'='*(int((batch_count+1) / len(test_dataloader) * 50)):<50}]", end="")
            print(f" accuracy={this_accuracy/mask.sum().item()}", end="")
        print(f"\naccuracy: {accuracy_accumulate/sample_count}")


# Use the first GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# Optimisation settings for this run.
lr = 5e-5
num_epoch = 5
train(model, train_dataloader, test_dataloader, lr=lr, num_epoch=num_epoch, device=device)

In [None]:
def predict(model, input):
    """Replace every ``[MASK]`` in each sentence with the model's top-1 token.

    Uses the module-level ``tokenizer``.

    Args:
        model: Module called with the tokenizer output as keyword arguments
            (``input_ids``, ``token_type_ids``, ``attention_mask``). Its output
            is read as per-position logits whose last axis is the vocabulary.
        input: A sentence, or a list of sentences, containing literal
            ``[MASK]`` markers. The name shadows the builtin but is kept for callers.

    Returns:
        list[str]: one string per sentence, with its ``[MASK]`` markers replaced
        left to right by the predicted tokens. Markers that fall beyond the
        64-token truncation limit are left unchanged.
    """
    sentences = [input] if isinstance(input, str) else list(input)
    pt_input = tokenizer(
        sentences,
        padding=True,
        truncation=True,
        max_length=64,
        return_tensors="pt",
        return_attention_mask=True,
    )

    # Feed the model on whatever device its weights live on.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    with torch.no_grad():
        logits = model(**{name: tensor.to(device) for name, tensor in pt_input.items()})

    # Top-1 vocabulary id at every position (argmax over the vocab axis, not the sequence).
    predicted_ids = logits.argmax(dim=-1).cpu()
    # Only positions that hold [MASK] in the encoded input are filled in.
    is_mask = pt_input["input_ids"] == tokenizer.mask_token_id

    results = []
    for sentence, row_ids, row_mask in zip(sentences, predicted_ids, is_mask):
        filled = sentence
        for token in tokenizer.convert_ids_to_tokens(row_ids[row_mask].tolist()):
            # Masked positions are visited left to right, matching the textual markers.
            filled = filled.replace(tokenizer.mask_token, token, 1)
        results.append(filled)
    return results


# Demo: run predict on two headlines containing a [MASK] slot, on the CPU.
text = ["立夏：春[MASK]落尽，夏木成荫", "一个人去爬[MASK]合适吗？"]
model.to("cpu")
res = predict(model, text)
print(res)

In [None]:
def predict(model, text, tokenizer, device="cpu", top_k=5):
    """Fill each ``[MASK]`` by sampling among the model's top-k candidates.

    At every masked position, the ``top_k`` highest logits are renormalised
    with a softmax, and one candidate is drawn from that distribution.

    Args:
        model: Module called as ``model(input_ids, attention_mask, token_type_ids)``.
            Its output is indexed as ``[batch, position]`` to get vocabulary logits.
        text: Sentence(s) containing ``[MASK]`` markers.
        tokenizer: Tokenizer used both to encode ``text`` and to decode the result.
        device: Device the model and inputs are moved to (the model is moved in place).
        top_k: Number of highest-scoring candidates to sample from.

    Returns:
        list[str]: the decoded sentences, with special tokens skipped.
    """
    encoded = tokenizer(
        text,
        max_length=64,
        truncation=True,
        padding=True,
        return_token_type_ids=True,
        return_attention_mask=True,
        return_tensors="pt",
    )
    # Row/column coordinates of every [MASK] token in the padded batch.
    mask_rows, mask_cols = (encoded["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)

    model.to(device)
    model.eval()
    with torch.no_grad():
        all_logits = model(
            *(encoded[key].to(device) for key in ("input_ids", "attention_mask", "token_type_ids"))
        )

    filled_ids = encoded["input_ids"].clone()
    for row, col in zip(mask_rows, mask_cols):
        candidates = torch.topk(all_logits[row, col], top_k)
        weights = torch.softmax(candidates.values, dim=-1)
        pick = torch.multinomial(weights, 1)
        filled_ids[row, col] = candidates.indices[pick]

    return tokenizer.batch_decode(filled_ids, skip_special_tokens=True)


# Usage example: sample a filler for each [MASK] from the top-k candidates, on the CPU.
text = ["立夏：春[MASK]落尽，夏木成荫", "一个人去爬[MASK]合适吗？"]
model.to("cpu")
res = predict(model, text, tokenizer, device="cpu")
print(res)