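# Assignment

- Complete the missing code in the program, understand what it does, and get the whole project running;
- Sign up for the [Qianyan dataset: information extraction competition](https://aistudio.baidu.com/aistudio/competition/detail/46).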

![](https://ai-studio-static-online.cdn.bcebos.com/8747d266229c4ed0a3d410ea47a082f2551a0a3364464d5aae02047f151fa7f9)


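# Entity and Relation Extraction with Pre-trained Models

Information extraction aims to extract structured knowledge, such as entities, relations, and events, from unstructured natural-language text. Given a natural-language sentence and a predefined set of schemas, the task is to extract every SPO (subject, predicate, object) triple that satisfies the schema constraints.

For example, the schema for the "妻子" (wife) relation is defined as:      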
{      
    S_TYPE: 人物,        
    P: 妻子,      
    O_TYPE: {      
        @value: 人物       
    }       
}        
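
As a rough illustration of how such a schema constrains a triple, the sketch below checks one SPO record against the 妻子 schema. The record layout (`subject`, `subject_type`, `predicate`, `object`, `object_type` keys), the placeholder names, and the helper `matches_schema` are assumptions made for this example; consult the competition data files for the exact format.

```python
# Schema for the "妻子" (wife) relation, as given in the text above.
WIFE_SCHEMA = {
    "S_TYPE": "人物",
    "P": "妻子",
    "O_TYPE": {"@value": "人物"},
}


def matches_schema(spo, schema):
    """Return True if the triple's predicate and entity types agree with the schema."""
    return (
        spo["predicate"] == schema["P"]
        and spo["subject_type"] == schema["S_TYPE"]
        and spo["object_type"] == schema["O_TYPE"]
    )


# A made-up triple (placeholder names) in a hypothetical record layout.
example_spo = {
    "subject": "张三",
    "subject_type": "人物",
    "predicate": "妻子",
    "object": {"@value": "李四"},
    "object_type": {"@value": "人物"},
}

print(matches_schema(example_spo, WIFE_SCHEMA))
```

A triple whose predicate or entity types differ from the schema would be rejected by this check.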

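This example shows how to use PaddleNLP to quickly build an entity-relation extraction system and enter the leaderboard of the [Qianyan information extraction: relation extraction competition](https://aistudio.baidu.com/aistudio/competition/detail/46).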




In [1]:
# Install the latest version of paddlenlp
!pip install --upgrade paddlenlp

%cd relation_extraction/

Looking in indexes: https://mirror.baidu.com/pypi/simple/
Collecting paddlenlp
  Downloading https://mirror.baidu.com/pypi/packages/b1/e9/128dfc1371db3fc2fa883d8ef27ab6b21e3876e76750a43f58cf3c24e707/paddlenlp-2.0.2-py3-none-any.whl (426kB)
     |████████████████████████████████| 430kB 15.7MB/s eta 0:00:01
Installing collected packages: paddlenlp
  Found existing installation: paddlenlp 2.0.1
    Uninstalling paddlenlp-2.0.1:
      Successfully uninstalled paddlenlp-2.0.1
Successfully installed paddlenlp-2.0.2
/home/aistudio/relation_extraction


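## Introduction to Relation Extraction

To handle the multiple, overlapping SPO triples targeted by the DuIE 2.0 task, the competition extends the standard 'BIO' tagging scheme.
Each token is tagged according to its position in an entity span (B, I, or O), and the B tag is further split by the predicate that the entity participates in. Given a schema set with N distinct predicates, and the two roles of head entity (subject) and tail entity (object), there are 2*N B tags in total. Together with the I and O tags, this gives (2*N+2) possible labels for each token, as shown in the figure below.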


<div align="center">
<img src="https://ai-studio-static-online.cdn.bcebos.com/f984664777b241a9b43ef843c9b752f33906c8916bc146a69f7270b5858bee63" width="500" height="400" alt="Tagging scheme" align=center />
</div>
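
To make the label count concrete, the sketch below builds one possible label layout for a toy schema. The ordering (O, I, then subject B tags, then object B tags), the tag names, and the helper `build_label_ids` are assumptions for illustration; the ids actually used by the baseline come from `predicate2id.json` and the data loader.

```python
def build_label_ids(predicates):
    """Assign ids to O, I, and one B tag per (role, predicate) pair."""
    labels = {"O": 0, "I": 1}
    for role in ("subject", "object"):
        for predicate in predicates:
            labels[f"B-{role}-{predicate}"] = len(labels)
    return labels


toy_predicates = ["妻子", "丈夫", "作者"]  # toy schema with N = 3
label_ids = build_label_ids(toy_predicates)

# 2 * N B tags plus the I and O tags
print(len(label_ids))  # 8
```

Because one token can start entities for several predicates at once, each token is scored against all 2*N+2 labels independently rather than assigned a single class.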

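### Evaluation Method

The SPO results output by a participating system on the test set are matched exactly against the manually annotated SPO results, and F1 is used as the evaluation metric. Note that for SPOs with a complex O value type, every slot must match exactly for the SPO to count as correct. Because some texts refer to entities by an alias, the alias dictionary of the Baidu Knowledge Graph is used to assist evaluation. F1 is computed as follows:

F1 = (2 * P * R) / (P + R), where

- P = number of correctly predicted SPOs over all test-set sentences / number of predicted SPOs over all test-set sentences
- R = number of correctly predicted SPOs over all test-set sentences / number of manually annotated SPOs over all test-set sentences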
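
The formula can be sketched in a few lines of Python. This is a simplified illustration, not the official evaluation script: it assumes each SPO is represented as a hashable tuple `(sentence_id, subject, predicate, object)` and ignores the alias dictionary; the function name `precision_recall_f1` and the toy triples are made up for this example.

```python
def precision_recall_f1(predicted, golden):
    """Micro-averaged P/R/F1 over exact-match SPO tuples."""
    predicted, golden = set(predicted), set(golden)
    correct = len(predicted & golden)
    p = correct / len(predicted) if predicted else 0.0
    r = correct / len(golden) if golden else 0.0
    f1 = 2 * p * r / (p + r) if (p + r) > 0 else 0.0
    return p, r, f1


golden = [(0, "A", "妻子", "B"), (0, "B", "丈夫", "A"),
          (1, "C", "作者", "D"), (1, "E", "作者", "D")]
predicted = [(0, "A", "妻子", "B"), (1, "C", "作者", "D"),
             (1, "C", "作者", "E")]  # the last one is wrong

p, r, f1 = precision_recall_f1(predicted, golden)
print("P=%.4f R=%.4f F1=%.4f" % (p, r, f1))  # P=0.6667 R=0.5000 F1=0.5714
```

When nothing is predicted, both P and F1 are defined as 0 here to avoid division by zero.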

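### Step 1: Build the Model

This task can be treated as a sequence-labeling task, so the baseline uses the ERNIE sequence-labeling model.

**PaddleNLP provides the ERNIE pre-trained model together with common sequence-labeling models, which can be loaded in one step by specifying the model name. To make data processing easier, PaddleNLP also ships a built-in Tokenizer for each pre-trained model, which handles tokenization, conversion to token IDs, truncation to a maximum length, and so on.**

To process text, call the tokenizer directly; it returns the inputs the model expects.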



In [2]:
import os
import sys
import json
from paddlenlp.transformers import ErnieForTokenClassification, ErnieTokenizer

label_map_path = os.path.join('data', "predicate2id.json")

if not (os.path.exists(label_map_path) and os.path.isfile(label_map_path)):
    sys.exit("{} does not exist or is not a file.".format(label_map_path))
with open(label_map_path, 'r', encoding='utf8') as fp:
    label_map = json.load(fp)

# 2 * N "B" labels (subject / object for each of the N predicates) plus "I" and "O";
# the "- 2" removes the two non-predicate entries from the count of label_map keys.
num_classes = (len(label_map.keys()) - 2) * 2 + 2
print(num_classes)
# Complete the code: understand the TokenClassification interface, the relation-extraction
# tagging scheme, and where the number of classes comes from
model = ErnieForTokenClassification.from_pretrained("ernie-1.0", num_classes=num_classes)
tokenizer = ErnieTokenizer.from_pretrained("ernie-1.0")

inputs = tokenizer(text="请输入测试样例", max_seq_len=20)

112


[2021-06-14 21:27:23,336] [    INFO] - Downloading https://paddlenlp.bj.bcebos.com/models/transformers/ernie/ernie_v1_chn_base.pdparams and saved to /home/aistudio/.paddlenlp/models/ernie-1.0
[2021-06-14 21:27:23,339] [    INFO] - Downloading ernie_v1_chn_base.pdparams from https://paddlenlp.bj.bcebos.com/models/transformers/ernie/ernie_v1_chn_base.pdparams
  0%|          | 0/392507 [00:00<?, ?it/s]100%|██████████| 392507/392507 [00:09<00:00, 40076.24it/s]
[2021-06-14 21:27:42,130] [    INFO] - Downloading vocab.txt from https://paddlenlp.bj.bcebos.com/models/transformers/ernie/vocab.txt
100%|██████████| 90/90 [00:00<00:00, 5445.03it/s]


In [3]:
print(tokenizer(text="吴宗宪遭服务生种族歧视, 他气呛", max_seq_len=20))

{'input_ids': [1, 1167, 761, 2075, 1396, 231, 112, 21, 106, 495, 2752, 367, 30, 44, 266, 5706, 2], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}


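### Step 2: Load and Process the Data


Download the dataset from the competition website, extract it into the `data/` directory, and rename the files to `train_data.json`, `dev_data.json`, and `test_data.json`.

We can load a custom dataset by subclassing [`paddle.io.Dataset`](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/io/Dataset_cn.html#dataset) and implementing the `__getitem__` and `__len__` methods.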


In [4]:
from typing import Optional, List, Union, Dict

import numpy as np
import paddle
from tqdm import tqdm
from paddlenlp.utils.log import logger

from data_loader import parse_label, DataCollator, convert_example_to_feature
from extract_chinese_and_punct import ChineseAndPunctuationExtractor


class DuIEDataset(paddle.io.Dataset):
    """
    Dataset of DuIE.
    """

    def __init__(
            self,
            input_ids: List[Union[List[int], np.ndarray]],
            seq_lens: List[Union[List[int], np.ndarray]],
            tok_to_orig_start_index: List[Union[List[int], np.ndarray]],
            tok_to_orig_end_index: List[Union[List[int], np.ndarray]],
            labels: List[Union[List[int], np.ndarray, List[str], List[Dict]]]):
        super(DuIEDataset, self).__init__()

        self.input_ids = input_ids
        self.seq_lens = seq_lens
        self.tok_to_orig_start_index = tok_to_orig_start_index
        self.tok_to_orig_end_index = tok_to_orig_end_index
        self.labels = labels

    def __len__(self):
        if isinstance(self.input_ids, np.ndarray):
            return self.input_ids.shape[0]
        else:
            return len(self.input_ids)

    def __getitem__(self, item):
        return {
            "input_ids": np.array(self.input_ids[item]),
            "seq_lens": np.array(self.seq_lens[item]),
            "tok_to_orig_start_index":
            np.array(self.tok_to_orig_start_index[item]),
            "tok_to_orig_end_index": np.array(self.tok_to_orig_end_index[item]),
            # If model inputs is generated in `collate_fn`, delete the data type casting.
            "labels": np.array(
                self.labels[item], dtype=np.float32),
        }

    @classmethod
    def from_file(cls,
                  file_path: Union[str, os.PathLike],
                  tokenizer: ErnieTokenizer,
                  max_length: Optional[int]=512,
                  pad_to_max_length: Optional[bool]=None):
        assert os.path.exists(file_path) and os.path.isfile(
            file_path), f"{file_path} does not exist or is not a file."
        label_map_path = os.path.join(
            os.path.dirname(file_path), "predicate2id.json")
        assert os.path.exists(label_map_path) and os.path.isfile(
            label_map_path
        ), f"{label_map_path} does not exist or is not a file."
        with open(label_map_path, 'r', encoding='utf8') as fp:
            label_map = json.load(fp)
        chineseandpunctuationextractor = ChineseAndPunctuationExtractor()

        input_ids, seq_lens, tok_to_orig_start_index, tok_to_orig_end_index, labels = (
            [] for _ in range(5))
        logger.info("Preprocessing data, loaded from %s" % file_path)
        with open(file_path, "r", encoding="utf-8") as fp:
            lines = fp.readlines()
            for line in tqdm(lines):
                example = json.loads(line)
                input_feature = convert_example_to_feature(
                    example, tokenizer, chineseandpunctuationextractor,
                    label_map, max_length, pad_to_max_length)
                input_ids.append(input_feature.input_ids)
                seq_lens.append(input_feature.seq_len)
                tok_to_orig_start_index.append(
                    input_feature.tok_to_orig_start_index)
                tok_to_orig_end_index.append(
                    input_feature.tok_to_orig_end_index)
                labels.append(input_feature.labels)

        return cls(input_ids, seq_lens, tok_to_orig_start_index,
                   tok_to_orig_end_index, labels)


In [6]:
data_path = 'data'
batch_size = 64
max_seq_length = 128

train_file_path = os.path.join(data_path, 'train_data.json')
train_dataset = DuIEDataset.from_file(
    train_file_path, tokenizer, max_seq_length, True)
train_batch_sampler = paddle.io.BatchSampler(
    train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
collator = DataCollator()
train_data_loader = paddle.io.DataLoader(
    dataset=train_dataset,
    batch_sampler=train_batch_sampler,
    collate_fn=collator)

eval_file_path = os.path.join(data_path, 'dev_data.json')
test_dataset = DuIEDataset.from_file(
    eval_file_path, tokenizer, max_seq_length, True)
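# Note: drop_last=True discards the final partial batch, so with 1000 dev
# examples and batch_size=64 only 15 full batches (960 examples) are evaluated.
test_batch_sampler = paddle.io.BatchSampler(
    test_dataset, batch_size=batch_size, shuffle=False, drop_last=True)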
test_data_loader = paddle.io.DataLoader(
    dataset=test_dataset,
    batch_sampler=test_batch_sampler,
    collate_fn=collator)

[2021-06-14 21:28:21,120] [    INFO] - Preprocessing data, loaded from data/train_data.json
100%|██████████| 10010/10010 [00:20<00:00, 493.08it/s]
[2021-06-14 21:28:41,485] [    INFO] - Preprocessing data, loaded from data/dev_data.json
100%|██████████| 1000/1000 [00:01<00:00, 571.23it/s]


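### Step 3: Define the Loss Function and Optimizer, then Train

We use binary cross-entropy on the logits (`BCEWithLogitsLoss`) as the loss function, masked so that padding and special tokens do not contribute, and [`paddle.optimizer.AdamW`](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/optimizer/adamw/AdamW_cn.html#adamw) as the optimizer.



During training, the model is saved to the `checkpoints` folder under the current directory. The model is also evaluated with the official evaluation script during training, which reports P/R/F1.
F1 on the validation set can reach 69.42.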


In [7]:
import paddle.nn as nn

class BCELossForDuIE(nn.Layer):
    def __init__(self):
        super(BCELossForDuIE, self).__init__()
        # Keep the per-element loss so that it can be masked in forward().
        self.criterion = nn.BCEWithLogitsLoss(reduction='none')

    def forward(self, logits, labels, mask):
        loss = self.criterion(logits, labels)
        mask = paddle.cast(mask, 'float32')
        # Zero out the loss at masked token positions.
        loss = loss * mask.unsqueeze(-1)
        # Average over classes, then over the unmasked tokens of each sequence.
        loss = paddle.sum(loss.mean(axis=2), axis=1) / paddle.sum(mask, axis=1)
        loss = loss.mean()
        return loss
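
The arithmetic inside `BCELossForDuIE` can be checked without Paddle. The NumPy sketch below is an illustrative re-implementation under two assumptions: token ids 0, 1 and 2 stand for `[PAD]`, `[CLS]` and `[SEP]` (consistent with the tokenizer output shown earlier, but not verified against the vocabulary here), and the function name `masked_bce_with_logits` is made up for this example.

```python
import numpy as np


def masked_bce_with_logits(logits, labels, input_ids, special_ids=(0, 1, 2)):
    """logits, labels: [batch, seq_len, num_classes]; input_ids: [batch, seq_len]."""
    # Numerically stable binary cross-entropy computed directly on logits.
    loss = (np.maximum(logits, 0) - logits * labels
            + np.log1p(np.exp(-np.abs(logits))))
    # 1.0 for ordinary tokens, 0.0 for the assumed special-token ids.
    mask = (~np.isin(input_ids, special_ids)).astype(np.float32)
    loss = loss * mask[..., None]
    # Average over classes, then over the unmasked tokens of each sequence.
    per_sequence = loss.mean(axis=2).sum(axis=1) / mask.sum(axis=1)
    return per_sequence.mean()


input_ids = np.array([[1, 5, 6, 2, 0], [1, 7, 8, 9, 2]])
logits = np.zeros((2, 5, 4), dtype=np.float32)
labels = np.zeros((2, 5, 4), dtype=np.float32)
toy_loss = masked_bce_with_logits(logits, labels, input_ids)
print(toy_loss)
```

With all-zero logits every unmasked element contributes log 2, so the toy loss is about 0.693 regardless of how many tokens are masked in each sequence.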

In [8]:
import paddle.nn.functional as F
from utils import write_prediction_results, get_precision_recall_f1, decoding

@paddle.no_grad()
def evaluate(model, criterion, data_loader, file_path, mode):
    """
    mode eval:
    eval on development set and compute P/R/F1, called between training.
    mode predict:
    eval on development / test set, then write predictions to \
        predictions.json and predictions.json.zip \
        under /home/aistudio/relation_extraction/data dir for later submission or evaluation.
    """
    example_all = []
    with open(file_path, "r", encoding="utf-8") as fp:
        for line in fp:
            example_all.append(json.loads(line))
    id2spo_path = os.path.join(os.path.dirname(file_path), "id2spo.json")
    with open(id2spo_path, 'r', encoding='utf8') as fp:
        id2spo = json.load(fp)

    model.eval()
    loss_all = 0
    eval_steps = 0
    formatted_outputs = []
    current_idx = 0
    for batch in tqdm(data_loader, total=len(data_loader)):
        eval_steps += 1
        input_ids, seq_len, tok_to_orig_start_index, tok_to_orig_end_index, labels = batch
        logits = model(input_ids=input_ids)
        mask = (input_ids != 0).logical_and((input_ids != 1)).logical_and((input_ids != 2))
        loss = criterion(logits, labels, mask)
        loss_all += loss.numpy().item()
        probs = F.sigmoid(logits)
        logits_batch = probs.numpy()
        seq_len_batch = seq_len.numpy()
        tok_to_orig_start_index_batch = tok_to_orig_start_index.numpy()
        tok_to_orig_end_index_batch = tok_to_orig_end_index.numpy()
        formatted_outputs.extend(decoding(example_all[current_idx: current_idx+len(logits)],
                                          id2spo,
                                          logits_batch,
                                          seq_len_batch,
                                          tok_to_orig_start_index_batch,
                                          tok_to_orig_end_index_batch))
        current_idx = current_idx+len(logits)
    loss_avg = loss_all / eval_steps
    print("eval loss: %f" % (loss_avg))

    if mode == "predict":
        predict_file_path = os.path.join("/home/aistudio/relation_extraction/data", 'predictions.json')
    else:
        predict_file_path = os.path.join("/home/aistudio/relation_extraction/data", 'predict_eval.json')

    predict_zipfile_path = write_prediction_results(formatted_outputs,
                                                    predict_file_path)

    if mode == "eval":
        precision, recall, f1 = get_precision_recall_f1(file_path,
                                                        predict_zipfile_path)
        os.system('rm {} {}'.format(predict_file_path, predict_zipfile_path))
        return precision, recall, f1
    elif mode != "predict":
        raise Exception("wrong mode for eval func")

In [9]:
from paddlenlp.transformers import LinearDecayWithWarmup

#learning_rate = 2e-5
learning_rate = 2e-4
num_train_epochs = 5
warmup_ratio = 0.06

criterion = BCELossForDuIE()
# Defines learning rate strategy.
steps_by_epoch = len(train_data_loader)
num_training_steps = steps_by_epoch * num_train_epochs
lr_scheduler = LinearDecayWithWarmup(learning_rate, num_training_steps, warmup_ratio)
optimizer = paddle.optimizer.AdamW(
    learning_rate=lr_scheduler,
    parameters=model.parameters(),
    apply_decay_param_fun=lambda x: x in [
        p.name for n, p in model.named_parameters()
        if not any(nd in n for nd in ["bias", "norm"])])
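
For intuition about the learning-rate schedule, the following pure-Python sketch reproduces the usual shape of a linear warmup followed by linear decay. It is an approximation written for this tutorial, not PaddleNLP's source; the function name `linear_decay_with_warmup` is hypothetical, and the step count of 156 per epoch is taken from the training log further down.

```python
import math


def linear_decay_with_warmup(step, base_lr, total_steps, warmup):
    """Learning rate at `step`: linear ramp-up, then linear decay to zero."""
    # A float `warmup` is read as a fraction of the total number of steps.
    if isinstance(warmup, int):
        warmup_steps = warmup
    else:
        warmup_steps = int(math.floor(warmup * total_steps))
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    remaining = max(0.0, total_steps - step)
    return base_lr * remaining / max(1, total_steps - warmup_steps)


total_steps = 156 * 5  # steps per epoch * num_train_epochs
for step in (0, 23, 46, 413, 780):
    print(step, linear_decay_with_warmup(step, 2e-4, total_steps, 0.06))
```

With these settings the rate peaks at 2e-4 after 46 warmup steps and returns to 0 at step 780.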

In [10]:
# Directory where model parameters are saved (-p avoids an error if it already exists)
!mkdir -p checkpoints



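### Step 4: Submit the Prediction Results

Load the model parameters saved during training and run prediction. The next cell runs the training loop from Step 3; prediction itself is run afterwards with `predict.sh`.

**NOTE:** Make sure to set the path of the model parameters used for prediction.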

In [11]:
import time
import paddle.nn.functional as F

# Starts training.
global_step = 0
logging_steps = 50
save_steps = 10000
# NOTE: lr_scheduler above was built for 5 epochs, so with 2 epochs the decay stops early.
num_train_epochs = 2
output_dir = 'checkpoints'
tic_train = time.time()
model.train()
for epoch in range(num_train_epochs):
    print("\n=====start training of %d epochs=====" % epoch)
    tic_epoch = time.time()
    for step, batch in enumerate(train_data_loader):
        input_ids, seq_lens, tok_to_orig_start_index, tok_to_orig_end_index, labels = batch
        logits = model(input_ids=input_ids)
        mask = (input_ids != 0).logical_and((input_ids != 1)).logical_and(
            (input_ids != 2))
        loss = criterion(logits, labels, mask)
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.clear_gradients()
        loss_item = loss.numpy().item()

        if global_step % logging_steps == 0:
            print(
                "epoch: %d / %d, steps: %d / %d, loss: %f, speed: %.2f step/s"
                % (epoch, num_train_epochs, step, steps_by_epoch,
                    loss_item, logging_steps / (time.time() - tic_train)))
            tic_train = time.time()

        if global_step % save_steps == 0 and global_step != 0:
            print("\n=====start evaluating ckpt of %d steps=====" %
                    global_step)
            precision, recall, f1 = evaluate(
                model, criterion, test_data_loader, eval_file_path, "eval")
            print("precision: %.2f\t recall: %.2f\t f1: %.2f\t" %
                    (100 * precision, 100 * recall, 100 * f1))
            print("saving checkpoint model_%d.pdparams to %s " %
                    (global_step, output_dir))
            paddle.save(model.state_dict(),
                        os.path.join(output_dir, 
                                        "model_%d.pdparams" % global_step))
            model.train()

        global_step += 1
    tic_epoch = time.time() - tic_epoch
    print("epoch time footprint: %d hour %d min %d sec" %
            (tic_epoch // 3600, (tic_epoch % 3600) // 60, tic_epoch % 60))

# Does final evaluation.
print("\n=====start evaluating last ckpt of %d steps=====" %
        global_step)
precision, recall, f1 = evaluate(model, criterion, test_data_loader,
                                    eval_file_path, "eval")
print("precision: %.2f\t recall: %.2f\t f1: %.2f\t" %
        (100 * precision, 100 * recall, 100 * f1))
paddle.save(model.state_dict(),
            os.path.join(output_dir,
                            "model_%d.pdparams" % global_step))
print("\n=====training complete=====")


=====start training of 0 epochs=====
epoch: 0 / 2, steps: 0 / 156, loss: 0.708722, speed: 66.17 step/s
epoch: 0 / 2, steps: 50 / 156, loss: 0.051778, speed: 2.51 step/s
epoch: 0 / 2, steps: 100 / 156, loss: 0.021021, speed: 2.50 step/s
epoch: 0 / 2, steps: 150 / 156, loss: 0.013120, speed: 2.50 step/s
epoch time footprint: 0 hour 1 min 2 sec

=====start training of 1 epochs=====
epoch: 1 / 2, steps: 44 / 156, loss: 0.010394, speed: 2.50 step/s
epoch: 1 / 2, steps: 94 / 156, loss: 0.008290, speed: 2.49 step/s
epoch: 1 / 2, steps: 144 / 156, loss: 0.007733, speed: 2.40 step/s
epoch time footprint: 0 hour 1 min 3 sec

=====start evaluating last ckpt of 312 steps=====


100%|██████████| 15/15 [00:06<00:00,  2.33it/s]


eval loss: 0.007536
precision: 0.00	 recall: 0.00	 f1: 0.00	

=====training complete=====


In [12]:
!bash train.sh

INFO 2021-06-14 21:58:47,085 launch.py:266] Local processes completed.

In [13]:
!bash predict.sh

=====predicting complete=====

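The predictions are saved to `data/predictions.json` and `data/predictions.json.zip`, in the same format as the original dataset files.

You can then use the official evaluation script to measure the trained model on `dev_data.json`, for example:

```shell
python re_official_evaluation.py --golden_file=dev_data.json  --predict_file=predictions.json.zip [--alias_file alias_dict]
```
The reported metrics are Precision, Recall, and F1. The alias file contains the valid entity aliases; it is used in the final evaluation but is not provided here.

After that, run prediction on `test_data.json` and submit the prediction results (the .zip file) to the [Qianyan evaluation page](https://aistudio.baidu.com/aistudio/competition/detail/46).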





In [30]:
!python re_official_evaluation.py --golden_file=./data/dev_data.json  --predict_file=./data/duie.json.zip 

correct spo num = 0.0
submitted spo num = 0.0
golden set spo num = 1771.0
submitted recall spo num = 0.0
{"errorCode": 0, "errorMsg": "success", "data": [{"name": "precision", "value": 0.0}, {"name": "recall", "value": 0.0}, {"name": "f1-score", "value": 0.0}]}



## Tricks

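### Try More Pre-trained Models

The baseline uses ERNIE as its pre-trained model. PaddleNLP provides a rich set of pre-trained models, such as BERT, RoBERTa, ELECTRA, and XLNet;
see the [pre-trained model documentation](https://paddlenlp.readthedocs.io/zh/latest/model_zoo/transformers.html).

For example, you can switch to the RoBERTa large Chinese model to improve results; only the model and tokenizer need to be replaced.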

In [None]:
from paddlenlp.transformers import RobertaForTokenClassification, RobertaTokenizer

model = RobertaForTokenClassification.from_pretrained(
    "roberta-wwm-ext-large",
    num_classes=(len(label_map) - 2) * 2 + 2)
tokenizer = RobertaTokenizer.from_pretrained("roberta-wwm-ext-large")

### Model Ensembling

Train several models, run prediction with each of them, and fuse their predictions.
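
One simple way to fuse predictions is to average the per-label probabilities of the models and then apply a threshold. The sketch below shows that idea with NumPy on made-up numbers; the averaging strategy, the 0.5 threshold, and the function name `ensemble_predictions` are assumptions for illustration, and other schemes (such as voting on the decoded SPO triples) are equally possible.

```python
import numpy as np


def ensemble_predictions(prob_list, threshold=0.5):
    """Average per-model probabilities and threshold them into multi-hot labels."""
    mean_probs = np.mean(np.stack(prob_list, axis=0), axis=0)
    return (mean_probs >= threshold).astype(np.int64)


# Two hypothetical models, one sentence of 2 tokens, 3 labels per token.
probs_a = np.array([[[0.9, 0.2, 0.1], [0.4, 0.7, 0.1]]])
probs_b = np.array([[[0.7, 0.4, 0.1], [0.2, 0.5, 0.3]]])

fused = ensemble_predictions([probs_a, probs_b])
print(fused)
```

All models must share the same tokenization and label layout for the element-wise average to be meaningful.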

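The baseline above is built on PaddleNLP. Maintaining open source is not easy, and your support is much appreciated. 
**Remember to give [PaddleNLP](https://github.com/PaddlePaddle/PaddleNLP) a Star ⭐ to keep up with the latest news and features.**

GitHub: [https://github.com/PaddlePaddle/PaddleNLP](https://github.com/PaddlePaddle/PaddleNLP)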
