# ERNIE原理

- 对于中文，BERT 使用的是基于字的处理，在 mask 时掩盖的也仅仅是一个单字。对于中文语料，这样学习到的模型能很容易地推测出字的搭配，但是并不会学习到短语或者实体的语义信息。
- ERNIE模型在BERT的基础上，加入了海量语料中的实体、短语等先验语义知识，建模真实世界的语义关系。在训练时将短语、实体等先验知识进行mask，强迫模型对其进行建模，学习它们的语义表示。               
            
    <img src="../images/ERNIE-mask.PNG" width="80%" alt="ERNIE MASK">
       


具体来说， ERNIE 采用三种 masking 策略：

    Basic-Level Masking： 跟 bert 一样对单字进行 mask，很难学习到高层次的语义信息；
    Phrase-Level Masking： 输入仍然是单字级别的，mask连续短语；
    Entity-Level Masking： 首先进行实体识别，然后将识别出的实体进行 mask。


<img src="../images/ERNIE-masks.PNG" width="100%" alt="ERNIE masking">

预训练阶段采取上述掩码策略进行训练，获得每个汉字的嵌入向量；微调时，使用方法和 BERT 一样，**文本分词时仍然是按字分词**。

# 基于ERNIE模型的文本分类

In [1]:
import time
from tqdm import tqdm
from datetime import timedelta


import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import BertModel, BertTokenizer
from transformers import AdamW, get_linear_schedule_with_warmup

from sklearn import metrics

## 模型配置

In [2]:
class Config(object):
    """Configuration: dataset/model paths and training hyper-parameters.

    Constructing an instance reads the class-name file from disk and loads
    the tokenizer from ``bert_path``.
    """

    def __init__(self):
        self.model_name = 'ERNIE'
        dataset = "../../H/datasets/THUCNews/"
        self.train_path = dataset + '/train.txt'  # training set
        self.dev_path = dataset + '/dev.txt'  # validation set
        self.test_path = dataset + '/test.txt'  # test set
        # One class name per line. The context manager guarantees the file
        # handle is closed; UTF-8 matches how the data files are read.
        with open(dataset + '/class.txt', encoding='UTF-8') as f:
            self.class_list = [x.strip() for x in f]  # class names
        # Where the best model weights (state_dict) are written
        self.save_path = dataset + '/saved_dict/' + self.model_name + '.ckpt'
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')  # compute device

        # Stop early if the dev loss has not improved for more than this
        # many batches
        self.require_improvement = 1000
        self.num_classes = len(self.class_list)  # number of classes
        self.num_epochs = 3  # number of epochs
        self.batch_size = 128  # mini-batch size
        # Every sentence is padded or truncated to this many tokens
        self.pad_size = 32
        self.learning_rate = 5e-5  # learning rate
        self.bert_path = '../../H/models/huggingface/ERNIE/'
        self.tokenizer = BertTokenizer.from_pretrained(self.bert_path)
        self.hidden_size = 768

In [3]:
# Build the shared configuration (reads class.txt and loads the tokenizer).
config = Config()

## 创建模型

In [4]:
class Model(nn.Module):
    """Pretrained encoder loaded from ``config.bert_path`` plus a linear classifier.

    All encoder parameters are left trainable, so the whole network is
    fine-tuned rather than only the classification head.
    """

    def __init__(self, config):
        super(Model, self).__init__()
        self.bert = BertModel.from_pretrained(config.bert_path)
        for param in self.bert.parameters():
            param.requires_grad = True
        self.fc = nn.Linear(config.hidden_size, config.num_classes)

    def forward(self, x):
        """Return the classification logits for one batch.

        Args:
            x: Batch tuple; ``x[0]`` holds the token ids and ``x[2]`` the
                attention mask (1 for real tokens, 0 for padding, e.g.
                ``[1, 1, 1, 1, 0, 0]``). ``x[1]`` is not used here.

        Returns:
            The output of the linear layer applied to the pooled encoder
            output.
        """
        context = x[0]  # input token ids
        mask = x[2]  # same size as the sentence; padding positions are 0
        outputs = self.bert(context, attention_mask=mask)
        # Index instead of tuple-unpacking: positional indexing yields the
        # pooled output both for plain tuple returns and for dict-like
        # ModelOutput returns, whereas unpacking a ModelOutput yields its
        # key strings.
        pooled = outputs[1]
        return self.fc(pooled)

In [5]:
# Instantiate the classifier; this loads the pretrained weights from config.bert_path.
model = Model(config)

In [6]:
# Bare expression: the notebook displays the module tree as the cell output.
model

Model(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(18000, 768, padding_idx=0)
      (position_embeddings): Embedding(513, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      

## 创建数据集

In [7]:
from tqdm import tqdm

def build_dataset(config):
    """Load and tokenize the train, dev and test splits.

    Each non-empty line of a split file is ``<text>\\t<integer label>``.

    Args:
        config: Object providing ``tokenizer``, ``train_path``, ``dev_path``,
            ``test_path`` and ``pad_size``.

    Returns:
        A ``(train, dev, test)`` tuple. Each element is a list of
        ``(token_ids, label, seq_len, mask)`` tuples.
    """
    def load_dataset(path, pad_size=32):
        """Read one split file and encode every line.

        When ``pad_size`` is truthy, every sample is padded with 0 or
        truncated to exactly ``pad_size`` tokens; ``seq_len`` is the number
        of real tokens that remain and ``mask`` marks them with 1.
        """
        contents = []
        with open(path, 'r', encoding='UTF-8') as f:
            for line in tqdm(f):
                lin = line.strip()
                if not lin:
                    continue
                # The label is the last tab-separated field. Splitting from
                # the right keeps lines whose text itself contains a tab
                # from raising on unpacking.
                content, label = lin.rsplit('\t', 1)

                token_ids = config.tokenizer.encode(
                    content,
                    add_special_tokens=True,  # add the special tokens
                )

                seq_len = len(token_ids)
                # Start from an all-ones mask so it is also correct when
                # no padding/truncation is requested (pad_size falsy).
                mask = [1] * seq_len

                if pad_size:
                    if seq_len < pad_size:
                        padding = pad_size - seq_len
                        mask += [0] * padding
                        token_ids += [0] * padding
                    else:
                        mask = mask[:pad_size]
                        token_ids = token_ids[:pad_size]
                        seq_len = pad_size
                contents.append((token_ids, int(label), seq_len, mask))
        return contents

    train = load_dataset(config.train_path, config.pad_size)
    dev = load_dataset(config.dev_path, config.pad_size)
    test = load_dataset(config.test_path, config.pad_size)
    return train, dev, test


class DatasetIterater(object):
    """Re-iterable mini-batch iterator over a list of encoded samples.

    Each sample is a ``(token_ids, label, seq_len, mask)`` tuple. Iteration
    yields ``((x, seq_len, mask), y)`` where every element is a
    ``LongTensor`` moved to ``device``. After raising ``StopIteration`` the
    iterator resets itself, so it can be looped over once per epoch.
    """

    def __init__(self, batches, batch_size, device):
        self.batch_size = batch_size
        self.batches = batches
        self.n_batches = len(batches) // batch_size  # number of full batches
        # True when a final, smaller batch remains. The remainder must be
        # taken modulo batch_size; taking it modulo n_batches drops the tail
        # for some sizes and divides by zero when there is no full batch.
        self.residue = len(batches) % batch_size != 0
        self.index = 0
        self.device = device

    def _to_tensor(self, datas):
        """Stack a list of samples into tensors on ``self.device``."""
        x = torch.LongTensor([_[0] for _ in datas]).to(self.device)
        y = torch.LongTensor([_[1] for _ in datas]).to(self.device)

        # Length before padding (samples longer than pad_size are capped)
        seq_len = torch.LongTensor([_[2] for _ in datas]).to(self.device)
        mask = torch.LongTensor([_[3] for _ in datas]).to(self.device)
        return (x, seq_len, mask), y

    def __next__(self):
        if self.index < self.n_batches:
            # A full batch
            start = self.index * self.batch_size
            chunk = self.batches[start:start + self.batch_size]
        elif self.residue and self.index == self.n_batches:
            # The trailing partial batch
            chunk = self.batches[self.index * self.batch_size:]
        else:
            # Exhausted: reset so the next epoch starts from the beginning
            self.index = 0
            raise StopIteration
        self.index += 1
        return self._to_tensor(chunk)

    def __iter__(self):
        return self

    def __len__(self):
        """Number of batches per pass, counting a trailing partial batch."""
        return self.n_batches + 1 if self.residue else self.n_batches


def build_iterator(dataset, config):
    """Wrap ``dataset`` in a batch iterator using ``config.batch_size`` and ``config.device``."""
    return DatasetIterater(dataset, config.batch_size, config.device)


def get_time_dif(start_time):
    """Return the time elapsed since ``start_time`` as a whole-second timedelta.

    ``start_time`` is a timestamp as returned by ``time.time()``.
    """
    elapsed_seconds = time.time() - start_time
    return timedelta(seconds=int(round(elapsed_seconds)))

In [8]:
%%time
# Tokenize the three splits once, then wrap each one in a batch iterator.
train_data, dev_data, test_data = build_dataset(config)
train_iter = build_iterator(train_data, config)
dev_iter = build_iterator(dev_data, config)
test_iter = build_iterator(test_data, config)

180000it [00:17, 10078.27it/s]
10000it [00:00, 10299.09it/s]
10000it [00:00, 10382.67it/s]

CPU times: user 19.8 s, sys: 74.2 ms, total: 19.9 s
Wall time: 19.8 s





## 训练模型

In [9]:
def train(config, model, train_iter, dev_iter, test_iter):
    """Fine-tune ``model`` on ``train_iter``, then run the final test.

    Runs up to ``config.num_epochs`` epochs. Every 100 batches the model is
    evaluated on ``dev_iter``; whenever the dev loss reaches a new minimum
    the model's ``state_dict`` is saved to ``config.save_path``. Training
    stops early once more than ``config.require_improvement`` batches have
    passed since the last improvement. Finally ``test`` is called with
    ``test_iter``.
    """
    start = time.time()
    
    # Optimizer
    optimizer = AdamW(model.parameters(), 
                      lr=config.learning_rate, 
                      correct_bias=False)  # To reproduce BertAdam specific behavior set correct_bias=False
    total_steps = len(train_iter) * config.num_epochs
    scheduler = get_linear_schedule_with_warmup(optimizer, 
                                                num_warmup_steps=0, 
                                                num_training_steps=total_steps)  # PyTorch scheduler

    total_batch = 0  # number of batches processed so far
    dev_best_loss = float('inf')
    last_improve = 0  # batch index at which the dev loss last decreased
    flag = False  # set when there has been no improvement for too long

    # Training mode
    model.train()
    for epoch in range(config.num_epochs):
        print('Epoch [{}/{}]'.format(epoch + 1, config.num_epochs))

        for i, (trains, labels) in enumerate(train_iter):
            # Forward pass
            outputs = model(trains)

            # Reset the gradients
            model.zero_grad()

            # Loss
            loss = F.cross_entropy(outputs, labels)

            # Backward pass
            loss.backward()
            
            # Clip the gradient norm to 1.0 before the update
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            # Update the parameters and advance the learning-rate schedule
            optimizer.step()
            scheduler.step()

            # Accuracy
            if total_batch % 100 == 0:
                # Every 100 batches, report metrics for the current training
                # batch and for the whole dev set
                true = labels.data.cpu()
                predic = torch.max(outputs.data, 1)[1].cpu()
                train_acc = metrics.accuracy_score(true, predic)

                # Accuracy and loss on the dev set
                dev_acc, dev_loss = evaluate(config, model, dev_iter)
                if dev_loss < dev_best_loss:
                    dev_best_loss = dev_loss

                    # Save the best model so far
                    torch.save(model.state_dict(), config.save_path)

                    improve = '*'
                    last_improve = total_batch
                else:
                    improve = ''

                time_dif = get_time_dif(start)

                msg = 'Iter: {0:>6},  Train Loss: {1:>5.2},  Train Acc: {2:>6.2%},  Val Loss: {3:>5.2},  Val Acc: {4:>6.2%},  Time: {5} {6}'
                print(
                    msg.format(total_batch, loss.item(), train_acc, dev_loss,
                               dev_acc, time_dif, improve))

                # evaluate() switched the model to eval mode; switch back
                model.train()

            total_batch += 1
            if total_batch - last_improve > config.require_improvement:
                # The dev loss has not decreased for more than
                # config.require_improvement batches: stop training
                print("No optimization for a long time, auto-stopping...")
                flag = True
                break
        if flag:
            break

    # Test the model
    test(config, model, test_iter)

In [10]:
# 测试模型


def test(config, model, test_iter):
    """Load the saved weights and print test-set metrics.

    Loads the ``state_dict`` stored at ``config.save_path`` into ``model``,
    evaluates it on ``test_iter`` and prints the loss, accuracy,
    classification report, confusion matrix and elapsed time.
    """
    # Load the saved weights. map_location remaps the stored tensors onto
    # config.device, so a checkpoint written on a GPU can still be loaded
    # when the current device is the CPU.
    model.load_state_dict(
        torch.load(config.save_path, map_location=config.device))

    # Evaluation mode
    model.eval()

    start_time = time.time()

    test_acc, test_loss, test_report, test_confusion = evaluate(config,
                                                                model,
                                                                test_iter,
                                                                test=True)
    msg = 'Test Loss: {0:>5.2},  Test Acc: {1:>6.2%}'
    print(msg.format(test_loss, test_acc))
    print("Precision, Recall and F1-Score...")
    print(test_report)
    print("Confusion Matrix...")
    print(test_confusion)
    time_dif = get_time_dif(start_time)
    print("Time usage:", time_dif)


# Model evaluation
def evaluate(config, model, data_iter, test=False):
    """Compute accuracy and mean loss of ``model`` over ``data_iter``.

    The model is switched to eval mode and left in it; gradients are
    disabled while iterating.

    Args:
        config: Provides ``class_list`` (used for the report when ``test``
            is true).
        model: Callable returning logits for a batch.
        data_iter: Iterable of ``(inputs, labels)`` batches supporting
            ``len()``.
        test: When true, also return a classification report and a
            confusion matrix.

    Returns:
        ``(acc, mean_loss)``, or ``(acc, mean_loss, report, confusion)``
        when ``test`` is true. ``mean_loss`` is the summed per-batch
        cross-entropy divided by ``len(data_iter)``.
    """
    model.eval()
    loss_total = 0
    # Collect per-batch arrays and concatenate once at the end; appending to
    # a growing array inside the loop would re-copy it on every batch.
    # The empty seed arrays keep the result well-defined (and integer-typed)
    # even when data_iter yields nothing.
    label_chunks = [np.array([], dtype=int)]
    predict_chunks = [np.array([], dtype=int)]
    with torch.no_grad():
        for texts, labels in data_iter:
            outputs = model(texts)
            loss_total += F.cross_entropy(outputs, labels)
            label_chunks.append(labels.data.cpu().numpy())
            # Predicted class = index of the largest logit
            predict_chunks.append(torch.max(outputs.data, 1)[1].cpu().numpy())

    labels_all = np.concatenate(label_chunks)
    predict_all = np.concatenate(predict_chunks)

    acc = metrics.accuracy_score(labels_all, predict_all)
    if test:
        report = metrics.classification_report(labels_all,
                                               predict_all,
                                               target_names=config.class_list,
                                               digits=4)
        confusion = metrics.confusion_matrix(labels_all, predict_all)
        return acc, loss_total / len(data_iter), report, confusion
    return acc, loss_total / len(data_iter)

In [11]:
# Move the model to the configured device, then train it; train() also runs the final test.
model = model.to(config.device)
train(config, model, train_iter, dev_iter, test_iter)

Epoch [1/3]
Iter:      0,  Train Loss:   2.4,  Train Acc:  5.47%,  Val Loss:   2.3,  Val Acc: 22.21%,  Time: 0:00:07 *
Iter:    100,  Train Loss:   0.3,  Train Acc: 92.19%,  Val Loss:  0.34,  Val Acc: 90.17%,  Time: 0:00:36 *
Iter:    200,  Train Loss:   0.3,  Train Acc: 89.84%,  Val Loss:  0.27,  Val Acc: 91.57%,  Time: 0:01:06 *
Iter:    300,  Train Loss:  0.23,  Train Acc: 92.19%,  Val Loss:  0.27,  Val Acc: 91.61%,  Time: 0:01:36 
Iter:    400,  Train Loss:  0.35,  Train Acc: 89.06%,  Val Loss:  0.24,  Val Acc: 92.15%,  Time: 0:02:06 *
Iter:    500,  Train Loss:  0.16,  Train Acc: 94.53%,  Val Loss:  0.23,  Val Acc: 92.63%,  Time: 0:02:36 *
Iter:    600,  Train Loss:  0.29,  Train Acc: 90.62%,  Val Loss:  0.22,  Val Acc: 92.83%,  Time: 0:03:07 *
Iter:    700,  Train Loss:  0.14,  Train Acc: 94.53%,  Val Loss:  0.21,  Val Acc: 93.00%,  Time: 0:03:37 *
Iter:    800,  Train Loss:  0.16,  Train Acc: 95.31%,  Val Loss:  0.21,  Val Acc: 93.12%,  Time: 0:04:07 *
Iter:    900,  Train Loss: