# 使用预训练的BERT模型解决文本二分类问题

## 导入前置依赖

In [1]:
import os
import pandas as pd
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
# 用于加载bert模型的分词器
from transformers import AutoTokenizer
# 用于加载bert模型
from transformers import BertModel
from pathlib import Path

## 设置全局配置

In [2]:
# Number of samples per mini-batch for every DataLoader in this notebook.
batch_size = 16
# Maximum token length per text; the tokenizer call pads/truncates to this size.
text_max_length = 128
# Total number of training epochs (an arbitrarily chosen upper bound).
epochs = 100
# Learning rate handed to the optimizer.
lr = 3e-5
# Fraction of the training set that is held out as the validation set.
validation_ratio = 0.1
# Train on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Print the accumulated training loss once every this many optimizer steps.
log_per_step = 50

# Directory that holds the competition CSV files.
dataset_dir = Path("./基于论文摘要的文本分类与关键词抽取挑战赛公开数据")
# exist_ok=True makes the call a no-op when the directory is already there,
# without a separate (race-prone) existence check.
os.makedirs(dataset_dir, exist_ok=True)

# Directory where model checkpoints are written.
model_dir = Path("./model/bert_checkpoints")
# Create the checkpoint directory (and missing parents) if it does not exist yet.
os.makedirs(model_dir, exist_ok=True)

print("Device:", device)

Device: cuda


## 进行数据读取与数据预处理

In [3]:
# Read the train/test CSV files and build the single text field fed to the model.

def _build_text(frame):
    """Join title, author, abstract and Keywords of *frame* into one string per row.

    Missing values are replaced by empty strings. A column that is absent from
    *frame* other than ``title`` contributes an empty field, so the same helper
    works for the training and the test table.

    :param frame: DataFrame with at least a ``title`` column
    :return: Series of space-separated text, indexed like *frame*
    """
    text = frame['title'].fillna('')
    for column in ('author', 'abstract', 'Keywords'):
        values = frame[column].fillna('') if column in frame.columns else ''
        text = text + ' ' + values
    return text


pd_train_data = pd.read_csv(dataset_dir / 'train.csv')
pd_train_data['title'] = pd_train_data['title'].fillna('')
pd_train_data['abstract'] = pd_train_data['abstract'].fillna('')

test_data = pd.read_csv(dataset_dir / 'test.csv')
test_data['title'] = test_data['title'].fillna('')
test_data['abstract'] = test_data['abstract'].fillna('')

pd_train_data['text'] = _build_text(pd_train_data)
# Each test row must use its OWN columns: taking Keywords from the training
# frame would attach an unrelated training row's keywords via index alignment.
test_data['text'] = _build_text(test_data)


In [4]:
pd_train_data

Unnamed: 0,uuid,title,author,abstract,Keywords,label,text
0,0,Accessible Visual Artworks for Blind and Visua...,"Quero, Luis Cavazos; Bartolome, Jorge Iranzo; ...",Despite the use of tactile graphics and audio ...,accessibility technology; multimodal interacti...,0,Accessible Visual Artworks for Blind and Visua...
1,1,Seizure Detection and Prediction by Parallel M...,"Li, Chenqi; Lammie, Corey; Dong, Xuening; Amir...","During the past two decades, epileptic seizure...",CNN; Seizure Detection; Seizure Prediction; EE...,1,Seizure Detection and Prediction by Parallel M...
2,2,Fast ScanNet: Fast and Dense Analysis of Multi...,"Lin, Huangjing; Chen, Hao; Graham, Simon; Dou,...",Lymph node metastasis is one of the most impor...,Histopathology image analysis; computational p...,1,Fast ScanNet: Fast and Dense Analysis of Multi...
3,3,Long-Term Effectiveness of Antiretroviral Ther...,"Huang, Peng; Tan, Jingguang; Ma, Wenzhe; Zheng...",In order to assess the effectiveness of the Ch...,HIV; ART; mortality; observational cohort stud...,0,Long-Term Effectiveness of Antiretroviral Ther...
4,4,Real-Time Facial Affective Computing on Mobile...,"Guo, Yuanyuan; Xia, Yifan; Wang, Jing; Yu, Hui...",Convolutional Neural Networks (CNNs) have beco...,facial affective computing; convolutional neur...,0,Real-Time Facial Affective Computing on Mobile...
...,...,...,...,...,...,...,...
5995,5995,ADAPT: A Write Disturbance-Aware Programming T...,"Imran, Muhammad; Kwon, Taehyun; Yang, Joon-Sung","Phase change memory (PCM) is an emerging, resi...",Phase change materials; Programming; Encoding;...,0,ADAPT: A Write Disturbance-Aware Programming T...
5996,5996,Oxy-fuel combustion technology for cement prod...,"Carrasco-Maldonado, Francisco; Spoerl, Reinhol...",Oxy-fuel combustion stands as a promising carb...,Oxy-fuel; Cement sector; Carbon capture; Indus...,0,Oxy-fuel combustion technology for cement prod...
5997,5997,Ultrasonic‑assisted molten salt hydrates pretr...,"Soo Ling Chong,Inn Shi Tan,Henry Chee Yew Foo,...",This study aims to establish an efficient pret...,Eucheuma cottonii residues; L-lactic acid; dil...,1,Ultrasonic‑assisted molten salt hydrates pretr...
5998,5998,Transcriptional Profiling of Pseudomonas aerug...,"Janne G Thöming,Susanne Häussler,Janne G Thömi...",Pseudomonas aeruginosa is an opportunistic pat...,Adaptation; Biofilm formation; Genotype–phenot...,1,Transcriptional Profiling of Pseudomonas aerug...


In [5]:
# Randomly hold out a fraction of the training rows as the validation set;
# every row that was not sampled stays in the training set.
validation_data = pd_train_data.sample(frac=validation_ratio)
train_data = pd_train_data.drop(index=validation_data.index)

查看训练数据

In [6]:
train_data

Unnamed: 0,uuid,title,author,abstract,Keywords,label,text
0,0,Accessible Visual Artworks for Blind and Visua...,"Quero, Luis Cavazos; Bartolome, Jorge Iranzo; ...",Despite the use of tactile graphics and audio ...,accessibility technology; multimodal interacti...,0,Accessible Visual Artworks for Blind and Visua...
1,1,Seizure Detection and Prediction by Parallel M...,"Li, Chenqi; Lammie, Corey; Dong, Xuening; Amir...","During the past two decades, epileptic seizure...",CNN; Seizure Detection; Seizure Prediction; EE...,1,Seizure Detection and Prediction by Parallel M...
3,3,Long-Term Effectiveness of Antiretroviral Ther...,"Huang, Peng; Tan, Jingguang; Ma, Wenzhe; Zheng...",In order to assess the effectiveness of the Ch...,HIV; ART; mortality; observational cohort stud...,0,Long-Term Effectiveness of Antiretroviral Ther...
4,4,Real-Time Facial Affective Computing on Mobile...,"Guo, Yuanyuan; Xia, Yifan; Wang, Jing; Yu, Hui...",Convolutional Neural Networks (CNNs) have beco...,facial affective computing; convolutional neur...,0,Real-Time Facial Affective Computing on Mobile...
5,5,A phenotype-based forward genetic screen ident...,"Yonghe Ding,Di Lang,Jianhua Yan,Haisong Bu,Hon...",Previously we showed the generation of a prote...,Dnajb6; electrocardiogram; genetic diseases; g...,1,A phenotype-based forward genetic screen ident...
...,...,...,...,...,...,...,...
5994,5994,Dual-level attention-aware network for tempora...,"Sun, Bo; Guo, Meng; Cao, Siming; He, Jun; Yu, ...",Human emotions are known to always have four p...,attention mechanism; feature dynamical fusion;...,0,Dual-level attention-aware network for tempora...
5996,5996,Oxy-fuel combustion technology for cement prod...,"Carrasco-Maldonado, Francisco; Spoerl, Reinhol...",Oxy-fuel combustion stands as a promising carb...,Oxy-fuel; Cement sector; Carbon capture; Indus...,0,Oxy-fuel combustion technology for cement prod...
5997,5997,Ultrasonic‑assisted molten salt hydrates pretr...,"Soo Ling Chong,Inn Shi Tan,Henry Chee Yew Foo,...",This study aims to establish an efficient pret...,Eucheuma cottonii residues; L-lactic acid; dil...,1,Ultrasonic‑assisted molten salt hydrates pretr...
5998,5998,Transcriptional Profiling of Pseudomonas aerug...,"Janne G Thöming,Susanne Häussler,Janne G Thömi...",Pseudomonas aeruginosa is an opportunistic pat...,Adaptation; Biofilm formation; Genotype–phenot...,1,Transcriptional Profiling of Pseudomonas aerug...


## 构建训练所需的dataloader与dataset

In [7]:
# Dataset wrapper around the train / validation / test DataFrames.
class MyDataset(Dataset):
    """Serve ``(text, target)`` pairs from one of the notebook's data splits.

    For the ``'train'`` and ``'validation'`` splits the target is the row's
    ``label``. For the ``'test'`` split the target is the row's ``uuid``, which
    makes it easy to pair predictions with rows when writing the results.
    """

    def __init__(self, mode='train'):
        """
        :param mode: one of ``'train'``, ``'validation'`` or ``'test'``
        :raises Exception: if *mode* is none of the three known splits
        """
        super().__init__()
        self.mode = mode
        # Reject unknown splits up front, then bind the matching DataFrame.
        if mode not in ('train', 'validation', 'test'):
            raise Exception("Unknown mode {}".format(mode))
        if mode == 'train':
            self.dataset = train_data
        elif mode == 'validation':
            self.dataset = validation_data
        else:
            self.dataset = test_data

    def __getitem__(self, index):
        """Return ``(text, target)`` for the row at positional *index*."""
        row = self.dataset.iloc[index]
        content = row['text']
        # Test rows carry no label, so their uuid is returned in its place.
        target_column = 'uuid' if self.mode == 'test' else 'label'
        return content, row[target_column]

    def __len__(self):
        """Return the number of rows in the selected split."""
        return len(self.dataset)


In [8]:
# One dataset per split that is used during training.
train_dataset = MyDataset(mode='train')
validation_dataset = MyDataset(mode='validation')

In [9]:
# Inspect the first (text, label) sample; indexing goes through __getitem__.
train_dataset[0]

('Accessible Visual Artworks for Blind and Visually Impaired People: Comparing a Multimodal Approach with Tactile Graphics Quero, Luis Cavazos; Bartolome, Jorge Iranzo; Cho, Jundong Despite the use of tactile graphics and audio guides, blind and visually impaired people still face challenges to experience and understand visual artworks independently at art exhibitions. Art museums and other art places are increasingly exploring the use of interactive guides to make their collections more accessible. In this work, we describe our approach to an interactive multimodal guide prototype that uses audio and tactile modalities to improve the autonomous access to information and experience of visual artworks. The prototype is composed of a touch-sensitive 2.5D artwork relief model that can be freely explored by touch. Users can access localized verbal descriptions and audio by performing touch gestures on the surface while listening to themed background music along. We present the design requi

加载与BERT预训练模型（bert-base-uncased）配套的分词器（tokenizer）

In [10]:
# Load the tokenizer for the "bert-base-uncased" checkpoint, the same checkpoint the model loads.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

接着构造我们的Dataloader。我们需要定义一下collate_fn，在其中完成对句子进行编码、填充、组装batch等动作：

In [11]:
def collate_fn(batch):
    """
    Turn a list of ``(text, target)`` samples into one tokenized batch.

    :param batch: samples of one batch, e.g. [('some text', target), ('other text', target), ...]
    :return: ``(src, target)`` where ``src`` is the tokenizer output, e.g.
             {'input_ids': tensor([[ 101, ..., 102, 0, 0, ...], ...]), 'attention_mask': tensor([[1, ..., 1, 0, ...], ...])},
             and ``target`` is a LongTensor such as tensor([1, 1, 0, ...])
    """
    # Split the samples into a list of texts and a list of targets.
    texts, targets = (list(column) for column in zip(*batch))

    # The encoding is passed to BERT as-is, so the tokenizer output needs no further processing:
    #   padding='max_length' pads shorter texts up to text_max_length
    #   truncation=True cuts longer texts down to text_max_length
    src = tokenizer(
        texts,
        padding='max_length',
        truncation=True,
        max_length=text_max_length,
        return_tensors='pt',
    )

    return src, torch.LongTensor(targets)


In [12]:
# Training batches are reshuffled every epoch; validation batches keep a fixed order.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
)
validation_loader = DataLoader(
    validation_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collate_fn,
)

![](https://pic3.zhimg.com/v2-698a3f82f8c1b8be42ab82a4e8472216_r.jpg)
![](https://pic3.zhimg.com/80/v2-b4af6f8d2d27c74b473015aadde6ebea_720w.webp)

In [13]:
# Pull a single batch from the training loader to inspect what collate_fn produces.
sample_batch = next(iter(train_loader))
inputs, targets = sample_batch
print("inputs:", inputs)
print("targets:", targets)

inputs: {'input_ids': tensor([[  101,  8382, 11268,  ...,  2824, 25505,   102],
        [  101,  5884, 19293,  ...,  2013,  2367,   102],
        [  101, 23208,  1996,  ...,  1011, 22307,   102],
        ...,
        [  101,  3435,  1998,  ...,  2109,  2000,   102],
        [  101,  5415,  8332,  ...,  1010,  2005,   102],
        [  101,  6260, 16066,  ...,  1037,  2555,   102]]), 'token_type_ids': tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        ...,
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1]])}
targets: tensor([1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1])


定义预测模型，该模型由bert模型加上最后的预测层组成

In [14]:
class MyModel(nn.Module):
    """BERT encoder followed by a small feed-forward head for binary classification.

    The head ends in a Sigmoid, so ``forward`` returns one value in [0, 1] per
    input text.
    """

    def __init__(self):
        super().__init__()

        # Pretrained BERT encoder.
        # NOTE(review): `mirror` looks like a download-mirror hint; confirm that the
        # installed transformers version still accepts this keyword.
        self.bert = BertModel.from_pretrained('bert-base-uncased', mirror='tuna')

        # Prediction head: encoder width -> 256 -> 1, squashed by a Sigmoid.
        encoder_width, head_width = 768, 256
        self.predictor = nn.Sequential(
            nn.Linear(encoder_width, head_width),
            nn.ReLU(),
            nn.Linear(head_width, 1),
            nn.Sigmoid(),
        )

    def forward(self, src):
        """
        :param src: tokenizer output for a batch of texts; it is unpacked as
                    keyword arguments into the BERT encoder
        :return: output of the prediction head, with a trailing dimension of size 1
        """
        # The tokenizer and the encoder belong to the same checkpoint, so the
        # encoding can be passed straight through.
        encoded = self.bert(**src)
        # Position 0 of every sequence is taken as the summary of the whole text
        # (the [CLS] token, given the BERT tokenizer used in this notebook).
        cls_state = encoded.last_hidden_state[:, 0, :]

        return self.predictor(cls_state)


BERT模型在文本前插入一个[CLS]符号，并将该符号对应的输出向量作为整篇文本的语义表示，用于文本分类，如下图所示。可以理解为：与文本中已有的其它字/词相比，这个无明显语义信息的符号会更“公平”地融合文本中各个字/词的语义信息。

In [15]:
# Build the model and move its parameters onto the selected device.
model = MyModel().to(device)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


定义出损失函数和优化器。这里使用Binary Cross Entropy：

In [16]:
# Binary cross-entropy on probabilities; the model's final Sigmoid already maps its output into [0, 1].
criteria = nn.BCELoss()
# Adam over all model parameters, so the BERT encoder is fine-tuned together with the prediction head.
optimizer = torch.optim.Adam(model.parameters(), lr=lr)


In [17]:
# The model inputs arrive as a mapping of tensors, so a small helper moves each one to the device.
def to_device(dict_tensors):
    """Return a new dict with every tensor of *dict_tensors* moved to ``device``.

    :param dict_tensors: mapping from name to tensor (e.g. the tokenizer output)
    :return: plain dict with the same keys and the moved tensors
    """
    return {name: tensor.to(device) for name, tensor in dict_tensors.items()}


定义一个验证方法，获取到验证集的准确率和loss。

In [18]:
def validate():
    """Evaluate the global ``model`` on the validation split.

    The model is switched to eval mode and is left in eval mode on return, so
    the caller has to call ``model.train()`` again before further training.

    :return: ``(accuracy, mean_loss)`` as Python floats, both averaged over the
             number of validation samples
    """
    model.eval()
    total_loss = 0.
    total_correct = 0
    # Evaluation needs no gradients; skipping autograd saves memory and time.
    with torch.no_grad():
        for inputs, targets in validation_loader:
            inputs, targets = to_device(inputs), targets.to(device)
            probabilities = model(inputs).view(-1)
            loss = criteria(probabilities, targets.float())
            # The loss is a per-batch mean (default reduction of nn.BCELoss), so it is
            # weighted by the batch size to obtain a true per-sample average below.
            total_loss += loss.item() * targets.size(0)

            # A probability of at least 0.5 counts as predicting class 1.
            predicted = (probabilities >= 0.5).long()
            total_correct += (predicted == targets).sum().item()

    num_samples = len(validation_dataset)
    return total_correct / num_samples, total_loss / num_samples


训练部分

In [19]:
# Put the model into training mode before the first batch.
model.train()

# Release cached CUDA memory that earlier cells may have left behind.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# Loss accumulated since the last log line; reset after every log line.
total_loss = 0.
# Number of optimizer steps taken so far, counted across epochs.
step = 0

# Highest validation accuracy seen so far.
best_accuracy = 0

# Training starts here.
for epoch in range(epochs):
    # validate() switches the model to eval mode, so training mode is re-enabled every epoch.
    model.train()
    for i, (inputs, targets) in enumerate(train_loader):
        # Move the batch onto the training device.
        inputs = to_device(inputs)
        targets = targets.to(device)

        # Forward pass.
        outputs = model(inputs)
        # Flatten the predictions so they line up with the float targets.
        loss = criteria(outputs.view(-1), targets.float())

        # Backpropagate, update the weights, then clear the gradients for the next batch.
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        total_loss += loss.item()
        step += 1

        if step % log_per_step == 0:
            log_line = "Epoch {}/{}, Step: {}/{}, total loss:{:.4f}".format(
                epoch+1, epochs, i, len(train_loader), total_loss)
            print(log_line)
            total_loss = 0

        # Drop the references to this batch before the next one is loaded.
        del inputs, targets

    # After each epoch, evaluate on the validation set.
    accuracy, validation_loss = validate()
    epoch_summary = "Epoch {}, accuracy: {:.4f}, validation loss: {:.4f}".format(
        epoch+1, accuracy, validation_loss)
    print(epoch_summary)
    # Keep one checkpoint per epoch; the whole model object is saved.
    torch.save(model, model_dir / f"model_{epoch}.pt")

    # Additionally keep the checkpoint with the best validation accuracy.
    if accuracy > best_accuracy:
        torch.save(model, model_dir / "model_best.pt")
        best_accuracy = accuracy


KeyboardInterrupt: 

加载最好的模型，然后进行测试集的预测

In [None]:
# Load the best checkpoint. map_location remaps the stored tensors onto the
# current device, so a checkpoint saved on a GPU also loads on a CPU-only machine.
# NOTE(review): the checkpoint is a whole pickled model; PyTorch releases whose
# torch.load defaults to weights_only=True will additionally need
# weights_only=False here - confirm against the installed version.
model = torch.load(model_dir / "model_best.pt", map_location=device)
# Switch to eval mode for inference.
model = model.eval()


In [None]:
# The test split yields (text, uuid) pairs; the order is kept fixed so that
# predictions line up with the rows of test_data.
test_dataset = MyDataset(mode='test')
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collate_fn,
)


将测试数据送入模型，得到结果，最后输出为要求格式

In [None]:
# Collect (uuid, predicted_label) pairs for every test row, in loader order.
results = []
# Inference only: disabling autograd avoids building a graph for every batch.
with torch.no_grad():
    for inputs, ids in test_loader:
        # Move the batch with the same helper the training and validation code use.
        outputs = model(to_device(inputs))
        # A probability of at least 0.5 is mapped to label 1, anything else to label 0.
        outputs = (outputs >= 0.5).int().flatten().tolist()
        ids = ids.tolist()
        # extend() appends in place instead of rebuilding the list for every batch.
        results.extend(zip(ids, outputs))


In [None]:
# Keep only the predicted label of every (uuid, label) pair, in prediction order.
test_label = [label for _uuid, label in results]
test_data['label'] = test_label
# Write the submission columns to CSV without the DataFrame index.
submission_columns = ['uuid', 'Keywords', 'label']
test_data[submission_columns].to_csv('submit_task1.csv', index=None)