In [None]:
from datasets import load_dataset
from transformers import BertTokenizer
from torch.utils.data import DataLoader

# Load the sentence-pair dataset from a local parquet file.
dataset = load_dataset(path="parquet", data_files="data/zh_bert_nsp.parquet")["train"]
# Keep a small random subset. The fixed seed makes the subset (and the split
# below) reproducible across kernel restarts, and the min() cap prevents
# select() from indexing past the end of a file with fewer than 1000 rows.
dataset = dataset.shuffle(seed=42).select(range(min(1000, len(dataset))))
dataset = dataset.train_test_split(test_size=0.2, shuffle=True, seed=42)

# Load the tokenizer that matches the pretrained checkpoint.
tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")


# Preprocessing function
def preprocess_func(example):
    """Tokenize sentence pairs and attach their labels.

    Args:
        example: Mapping with the keys ``"sentence1"``, ``"sentence2"`` and
            ``"label"``. This notebook passes it to ``dataset.map`` with
            ``batched=True``.

    Returns:
        The tokenizer output (requested with token type ids, padded and
        truncated to 512 tokens) with an extra ``"labels"`` entry copied from
        ``example["label"]``.
    """
    tokenizer_options = dict(
        padding="max_length",
        truncation=True,
        max_length=512,
        return_token_type_ids=True,
    )
    features = tokenizer(example["sentence1"], example["sentence2"], **tokenizer_options)
    features["labels"] = example["label"]
    return features


# Tokenize both splits in batches, then expose only the model inputs and the
# labels as torch tensors.
encoded_dataset = dataset.map(preprocess_func, batched=True)
encoded_dataset.set_format("torch", columns=["input_ids", "token_type_ids", "attention_mask", "labels"])
# Build the DataLoaders
train_batch_size = 32
test_batch_size = 64
train_dataloader = DataLoader(encoded_dataset["train"], batch_size=train_batch_size, shuffle=True)
# Evaluation only aggregates accuracy, so sample order is irrelevant; keeping
# it fixed makes per-batch evaluation output comparable between epochs.
test_dataloader = DataLoader(encoded_dataset["test"], batch_size=test_batch_size, shuffle=False)

In [None]:
import torch
import torch.nn as nn
from transformers import BertModel


# Custom model
class Model(nn.Module):
    """Frozen pretrained BERT encoder followed by a trainable linear head.

    Args:
        pretrained_name: Checkpoint name or path handed to
            ``BertModel.from_pretrained``. Defaults to ``"bert-base-chinese"``.
        num_labels: Number of output classes of the linear head. Defaults to 2.
    """

    def __init__(self, pretrained_name="bert-base-chinese", num_labels=2):
        super().__init__()
        # Load the pretrained BERT encoder
        self.bert = BertModel.from_pretrained(pretrained_name)
        # Size the head from the checkpoint's config instead of hard-coding
        # 768, so checkpoints with a different hidden width also work.
        self.linear = nn.Linear(self.bert.config.hidden_size, num_labels)

        # Freeze every BERT parameter; only the linear head is trained.
        # NOTE(review): model.train() still switches the frozen encoder's
        # dropout layers to training mode - confirm that is intended.
        self.bert.requires_grad_(False)

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Return the head's output for the first ([CLS]) token of each sequence."""
        # Keyword arguments keep the call independent of BertModel's
        # positional parameter order.
        output = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        # Index 0 along the sequence axis is the hidden state of [CLS].
        output = self.linear(output.last_hidden_state[:, 0])
        return output


# Instantiate the classifier; constructing it loads the pretrained BERT weights.
model = Model()

In [None]:
def _batch_to_device(batch, device):
    """Move the model inputs and the labels of one batch onto ``device``."""
    return (
        batch["input_ids"].to(device),
        batch["attention_mask"].to(device),
        batch["token_type_ids"].to(device),
        batch["labels"].to(device),
    )


def train(model, train_dataloader, test_dataloader, lr, num_epoch, device):
    """Train ``model`` and evaluate it on the test set after every epoch.

    Args:
        model: Module called as ``model(input_ids, attention_mask, token_type_ids)``
            whose output is fed to ``nn.CrossEntropyLoss`` together with the labels.
        train_dataloader: Sized iterable of batches; each batch is a mapping with
            the keys ``"input_ids"``, ``"attention_mask"``, ``"token_type_ids"``
            and ``"labels"``.
        test_dataloader: Sized iterable of batches in the same format. It may be
            empty, in which case the epoch accuracy is reported as ``nan``.
        lr: Learning rate for AdamW.
        num_epoch: Number of passes over ``train_dataloader``.
        device: Device the model and every batch are moved to.

    Returns:
        None. Progress is printed; the model is updated in place and is left in
        evaluation mode when at least one epoch ran.
    """
    model.to(device)
    # Only parameters that require gradients are optimised and clipped, which
    # avoids walking every frozen parameter on each step.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable_params, lr=lr)
    criterion = nn.CrossEntropyLoss()
    num_train_batches = len(train_dataloader)
    num_test_batches = len(test_dataloader)
    for epoch in range(num_epoch):
        model.train()
        for batch_count, batch in enumerate(train_dataloader):
            input_ids, attention_mask, token_type_ids, labels = _batch_to_device(batch, device)
            # Forward pass
            output = model(input_ids, attention_mask, token_type_ids)
            # Backward pass
            loss = criterion(output, labels)
            optimizer.zero_grad()
            loss.backward()
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(trainable_params, max_norm=1.0)
            optimizer.step()
            # Report the loss and batch accuracy every 10 batches.
            if batch_count % 10 == 0:
                preds = torch.argmax(output, dim=1)
                accuracy = (preds == labels).sum().item() / len(labels)
                progress = "=" * int((batch_count + 1) / num_train_batches * 50)
                print(f"\repoch:{epoch:0>2}[{progress:<50}]", end="")
                print(f" loss:{loss.item()}, accuracy={accuracy}")
        # Evaluation
        model.eval()
        correct_count = 0
        sample_count = 0
        with torch.no_grad():
            for batch_count, batch in enumerate(test_dataloader):
                input_ids, attention_mask, token_type_ids, labels = _batch_to_device(batch, device)
                output = model(input_ids, attention_mask, token_type_ids)
                # Count the correct predictions in this batch
                preds = torch.argmax(output, dim=1)
                this_correct = (preds == labels).sum().item()
                correct_count += this_correct
                sample_count += len(labels)
                progress = "=" * int((batch_count + 1) / num_test_batches * 50)
                print(f"\r评估：epoch:{epoch:0>2}[{progress:<50}]", end="")
                print(f" accuracy={this_correct/len(labels)}", end="")
        # An empty test set would otherwise raise ZeroDivisionError here.
        epoch_accuracy = correct_count / sample_count if sample_count else float("nan")
        print(f"\naccuracy: {epoch_accuracy}")


# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Optimisation hyper-parameters for this run.
num_epoch = 3
lr = 5e-5

train(
    model=model,
    train_dataloader=train_dataloader,
    test_dataloader=test_dataloader,
    lr=lr,
    num_epoch=num_epoch,
    device=device,
)

In [None]:
def predict(model, input):
    """Classify one sentence pair and return the predicted class index.

    Args:
        model: Module accepting ``input_ids``, ``token_type_ids`` and
            ``attention_mask`` as keyword arguments. It is switched to
            evaluation mode as a side effect.
        input: Indexable holding the first sentence at position 0 and the
            second sentence at position 1.

    Returns:
        The index of the largest output along the last dimension, as a Python int.
    """
    pt_input = tokenizer(
        input[0],
        input[1],
        padding="max_length",
        truncation=True,
        max_length=512,
        return_tensors="pt",
        return_token_type_ids=True,
    )
    # The tokenizer's tensors must live on the same device as the model, so
    # follow the model's parameters instead of requiring the caller to move
    # the model to the CPU first.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        pt_input = {name: tensor.to(first_param.device) for name, tensor in pt_input.items()}
    model.eval()
    with torch.no_grad():
        output = model(**pt_input)
    return output.argmax(dim=-1).item()


# Move the model to the CPU for the single-example inference demo.
model.to("cpu")
# Sentence pair to classify: [first sentence, second sentence].
text = ["春红落尽", "夏木成荫"]
res = predict(model, text)
# Print the predicted class index.
print(res)