# 『2022 CCF BDCI』- 基于TrustAI的阅读理解可解释性评测基线
## 1、项目介绍
深度学习模型在很多NLP任务上已经取得巨大成功，但其常被当作一个黑盒使用，内部预测机制对使用者是不透明的。这使得深度学习模型结果不被使用者信任，增加了落地难度，尤其在医疗、法律等特殊领域。同时，当模型出现效果不好或鲁棒性差等问题时，由于不了解其内部机制，很难对模型进行改进优化。
近期，深度学习模型的可解释性被越来越多的人关注。但模型的可解释性评估还不够完善，本基线提供了阅读理解任务的评测数据和相关评测指标，旨在评估模型的可解释性。
近期百度发布了一款集可信分析和增强于一体的可信AI工具集TrustAI，旨在探索模型预测机制并增强模型效果。本次基线基于TrustAI搭建。
## 2、基线运行



### 依赖安装
安装一些必须的依赖包。

In [3]:
# !mkdir /home/aistudio/external-libraries
!pip3 install -U paddlepaddle-gpu==2.3.2 
# !pip3 install -U paddlenlp==2.4.0 -t /home/aistudio/external-libraries
# !pip3 install trustai==0.1.5 -t /home/aistudio/external-libraries

^C
Traceback (most recent call last):
  File "/opt/conda/envs/python35-paddle120-env/bin/pip3", line 5, in <module>
    from pip._internal.cli.main import main
  File "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/pip/_internal/cli/main.py", line 9, in <module>
    from pip._internal.cli.autocompletion import autocomplete
  File "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/pip/_internal/cli/autocompletion.py", line 10, in <module>
    from pip._internal.cli.main_parser import create_main_parser
  File "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/pip/_internal/cli/main_parser.py", line 8, in <module>
    from pip._internal.cli import cmdoptions
  File "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/pip/_internal/cli/cmdoptions.py", line 24, in <module>
    from pip._internal.cli.parser import ConfigOptionParser
  File "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/pip/_internal/cl

### 数据准备
#### 1）模型训练数据
我们推荐使用DuReader-robust数据集训练中文相似度计算模型。Paddlenlp框架会自动下载及缓存训练数据集，默认缓存存储路径为"~/.paddlenlp/datasets"。如需修改训练数据，请参考『初始化工作』中DATASET_NAME的修改。
#### 2）下载预训练模型
基线使用了ERNIE-3.0-base预训练模型。Paddlenlp框架自动缓存模型文件，默认缓存存储路径为"~/.paddlenlp/models"。如需修改依赖的预训练模型，请在『初始化工作』中修改MODEL_NAME。

### 初始化工作
初始化工作包括了模型选择及加载、训练数据集选择、模型存储路径设定、抽取证据的长度占原文本长度的比例设定等。可按需更改。

In [3]:
import sys
sys.path.append('/home/aistudio/external-libraries')
import json
import numpy as np
import paddle
import paddlenlp
from paddlenlp.transformers import ErnieForQuestionAnswering, ErnieTokenizer
from paddlenlp.datasets import load_dataset

from mrc_utils import *
print(paddle.__version__)
print(paddlenlp.__version__)
# Select pre-trained model
MODEL_NAME = "ernie-3.0-base-zh" # choose from ["ernie-1.0", "ernie-1.0-base-zh", "ernie-1.0-large-zh-cw", "ernie-2.0-base-zh", "ernie-2.0-large-zh", "ernie-3.0-xbase-zh", "ernie-3.0-base-zh", "ernie-3.0-medium-zh", "ernie-3.0-mini-zh", "ernie-3.0-micro-zh", "ernie-3.0-nano-zh"]
# Select dataset for model training
DATASET_NAME = 'dureader_robust'
# Set the path to save the trained model
MODEL_SAVE_PATH = f'save_model/{DATASET_NAME}-{MODEL_NAME}'
# MODEL_SAVE_PATH = f'save_model/model_state.pdparams'
RATIONALE_RATIO = 0.096 # 0.096 for Chinese dataset, 0.102 for English dataset

# Init model and tokenizer
model = ErnieForQuestionAnswering.from_pretrained(MODEL_NAME, num_classes=2)
tokenizer = ErnieTokenizer.from_pretrained(MODEL_NAME)

2.3.2
2.3.4


[2022-10-27 21:55:36,205] [    INFO] - Already cached /home/aistudio/.paddlenlp/models/ernie-3.0-base-zh/ernie_3.0_base_zh.pdparams
[2022-10-27 21:55:38,109] [    INFO] - Already cached /home/aistudio/.paddlenlp/models/ernie-3.0-base-zh/ernie_3.0_base_zh_vocab.txt
[2022-10-27 21:55:38,135] [    INFO] - tokenizer config file saved in /home/aistudio/.paddlenlp/models/ernie-3.0-base-zh/tokenizer_config.json
[2022-10-27 21:55:38,139] [    INFO] - Special tokens file saved in /home/aistudio/.paddlenlp/models/ernie-3.0-base-zh/special_tokens_map.json


In [4]:
from paddlenlp.datasets import load_dataset
# Hyperparameters
batch_size = 56
max_seq_length = 512
epochs = 1  #3
warmup_proportion = 0.1
weight_decay = 0.01
doc_stride = 512
learning_rate = 1e-5

# Load dataset
train_ds, dev_ds, test_ds = load_dataset(DATASET_NAME, splits=["train", "dev", "test"])

# Start training
training_mrc_model(model, 
                tokenizer,
                train_ds, 
                dev_ds,
                batch_size=batch_size,
                epochs=epochs,
                learning_rate=learning_rate,
                warmup_proportion=warmup_proportion,
                max_seq_length=max_seq_length,
                doc_stride=doc_stride, 
                weight_decay=weight_decay,
                save_dir=MODEL_SAVE_PATH)

 37%|███▋      | 7407/20038 [00:04<00:07, 1659.77it/s]


KeyboardInterrupt: 

### 模型训练
这里以ERNIE-3.0为例训练一个阅读理解模型。

In [8]:
from paddlenlp.datasets import load_dataset
# Hyperparameters
batch_size = 12
max_seq_length = 512
epochs = 3  #3
warmup_proportion = 0.1
weight_decay = 0.01
doc_stride = 512
learning_rate = 1e-5

# Load dataset
train_ds, dev_ds, test_ds = load_dataset(DATASET_NAME, splits=["train", "dev", "test"])

# Start training
training_mrc_model(model, 
                tokenizer,
                train_ds, 
                dev_ds,
                batch_size=batch_size,
                epochs=epochs,
                learning_rate=learning_rate,
                warmup_proportion=warmup_proportion,
                max_seq_length=max_seq_length,
                doc_stride=doc_stride, 
                weight_decay=weight_decay,
                save_dir=MODEL_SAVE_PATH)

100%|██████████| 20038/20038 [00:02<00:00, 8724.16it/s] 
Process ForkPoolWorker-13:


KeyboardInterrupt: 

Process ForkPoolWorker-14:
Process ForkPoolWorker-16:
Process ForkPoolWorker-15:
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/multiprocess/process.py", line 297, in _bootstrap
    self.run()
  File "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/multiprocess/process.py", line 297, in _bootstrap
    self.run()
  File "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/multiprocess/process.py", line 297, in _bootstrap
    self.run()
  File "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/multiprocess/process.py", line 297, in _bootstrap
    self.run()
  File "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/multiprocess/process.py", line 99, in run
    self._target(*self._args, **self._kwargs)
  File "/opt/conda/envs/python35-paddle120-env/lib/pytho

### 重要度分数获取
该步为输入中每个词赋一个重要度分数，表示该词对预测的影响度。重要度分数获取共分三步。
#### 1）加载模型和评测数据集
更改模型以及评估数据的存储路径（MODEL_PATH和DATA_PATH），完成模型和评测数据集的加载。赛段一数据量为1855条，赛段二数据量为4366条，请确认评测数据集完整。

In [4]:
from utils import load_data
from functools import partial
from paddlenlp.datasets import load_dataset
# Correct MODEL_PATH and DATA_PATH before executing
MODEL_PATH = MODEL_SAVE_PATH + '/model_state.pdparams'
DATA_PATH = 'mrc_test.txt'

# Load the trained parameters
state_dict = paddle.load(MODEL_PATH)
model.set_dict(state_dict)

# Load test data
# global data_ds
# data_ds = DuReader().read(DATA_PATH)
# data = load_data(DATA_PATH)
# print("Num of data:", len(data))
# print(data)


#### 2）数据预处理

a) 输入格式化：将输入的两个文本组织成模型预测所需格式，如对于Ernie3.0-base模型，其输入形式为[CLS]question[SEP]context[SEP]

b) 分词位置索引：计算每个分词结果对应的原文位置索引，这里的分词包括模型分词和标准分词

In [5]:
# from paddlenlp.datasets import load_dataset
from mrc_utils import*
# Hyperparameters
batch_size = 1
max_seq_length = 512
epochs = 3  #3
warmup_proportion = 0.1
weight_decay = 0.01
doc_stride = 512
learning_rate = 1e-5
train_ds, dev_ds, test_ds = load_dataset(DATASET_NAME, splits=["train", "dev", "test"])
dev_context=[]
dev_question=[]
for data in dev_ds[:3]:
    dev_context.append(data['context'])
    dev_question.append(data['question'])
print(dev_context)
dev_trans_func = partial(prepare_validation_features, 
                            max_seq_length=max_seq_length, 
                            doc_stride=doc_stride,
                            tokenizer=tokenizer)
dev_ds.map(dev_trans_func, batched=True, num_workers=4)
# print(dev_ds[0])
    # 定义BatchSampler

dev_batch_sampler = paddle.io.BatchSampler(
    dev_ds, batch_size=batch_size, shuffle=False)

dev_batchify_fn = lambda samples, fn=Dict({
        "input_ids": Pad(axis=0, pad_val=tokenizer.pad_token_id),
        "token_type_ids": Pad(axis=0, pad_val=tokenizer.pad_token_type_id)
    }): fn(samples)
dev_data_loader = paddle.io.DataLoader(
        dataset=dev_ds,
        batch_sampler=dev_batch_sampler,
        collate_fn=dev_batchify_fn,
        return_list=True)

100%|██████████| 20038/20038 [00:00<00:00, 35325.44it/s]


['爬行垫根据中间材料的不同可以分为:XPE爬行垫、EPE爬行垫、EVA爬行垫、PVC爬行垫;其中XPE爬行垫、EPE爬行垫都属于PE材料加保鲜膜复合而成,都是无异味的环保材料,但是XPE爬行垫是品质较好的爬行垫,韩国进口爬行垫都是这种爬行垫,而EPE爬行垫是国内厂家为了减低成本,使用EPE(珍珠棉)作为原料生产的一款爬行垫,该材料弹性差,易碎,开孔发泡防水性弱。EVA爬行垫、PVC爬行垫是用EVA或PVC作为原材料与保鲜膜复合的而成的爬行垫,或者把图案转印在原材料上,这两款爬行垫通常有异味,如果是图案转印的爬行垫,油墨外露容易脱落。 当时我儿子爬的时候,我们也买了垫子,但是始终有味。最后就没用了,铺的就的薄毯子让他爬。', '真实情况是160-162。她平时谎报的168是因为不离脚穿高水台恨天高(15厘米) 图1她穿着高水台恨天高和刘亦菲一样高,(刘亦菲对外报身高172)范冰冰礼服下厚厚的高水台暴露了她的心机,对比一下两者的鞋子吧 图2 穿着高水台恨天高才和刘德华谢霆锋持平,如果她真的有168,那么加上鞋高,刘和谢都要有180?明显是不可能的。所以刘德华对外报的身高174减去10-15厘米才是范冰冰的真实身高 图3,范冰冰有一次脱鞋上场,这个最说明问题了,看看她的身体比例吧。还有目测一下她手上鞋子的鞋跟有多高多厚吧,至少超过10厘米。', '防水作为目前高端手机的标配,特别是苹果也支持防水之后,国产大多数高端旗舰手机都已经支持防水。虽然我们真的不会故意把手机放入水中,但是有了防水之后,用户心里会多一重安全感。那么近日最为火热的小米6防水吗?小米6的防水级别又是多少呢? 小编查询了很多资料发现,小米6确实是防水的,但是为了保持低调,同时为了不被别人说防水等级不够,很多资料都没有标注小米是否防水。根据评测资料显示,小米6是支持IP68级的防水,是绝对能够满足日常生活中的防水需求的。']


In [5]:
# import jieba
import jieba.posseg as pseg
import numpy as np
def MSE(y,t):
    #形参t代表训练数据（监督数据）（真实）
    #y代表预测数据
    return 0.5*np.sum((y-t)**2)

for step, batch in enumerate(dev_data_loader, start=1):
    if step>1:
        break
    # item=dev_ds[step]
    context=dev_context[step]
    # seg_list=jieba.cut(context,cut_all=False)
    seg_list =pseg.cut(context)
    ner_list=[]
    for w in seg_list:
        # print(w.word)
        # print(w.flag)
        ner_list.append(w)
    # print(ner_list)
    # print(dev_question[step])
    # global_step += 1
    input_ids, segment_ids= batch
    real_logits = model(input_ids=input_ids, token_type_ids=segment_ids)
    start_logits,end_logits=real_logits
    # print(start_logits.numpy())
    # print('\n',end_logits)
    input_ids=input_ids.numpy()
    start=0
    mse_losses={}
    mse_losses_flag={}
    for w in ner_list:
        word=w.word
        if word in[',','.','，','。','-','+','?','!']:
            continue
        flag=w.flag
        change_input_ids=input_ids[0]
        change_input_ids[start:start+len(word)]=0
        new_input_ids=[]
        new_input_ids.append(change_input_ids)
        new_input_ids=paddle.to_tensor(new_input_ids)
        change_logits=model(input_ids=new_input_ids, token_type_ids=segment_ids)
        new_start_logits,new_end_logits=change_logits
        # print(new_start_logits.numpy())
        mse_loss_start=MSE(start_logits.numpy()[0],new_start_logits.numpy()[0])
        mse_loss_end=MSE(end_logits.numpy()[0],new_end_logits.numpy()[0])
        mse_losses[word+flag]=(mse_loss_start+mse_loss_end)/2
        mse_losses_flag[word+flag]=(mse_loss_start+mse_loss_end)/2
        if(flag=='n' or flag=='v'):
            mse_losses_flag[word+flag]*=1.1

    a = sorted(mse_losses.items(), key=lambda x: x[1],reverse = True)
    print(a[:20])
    b = sorted(mse_losses_flag.items(), key=lambda x: x[1],reverse = True)
    print(b[:20])
    # print('logits',logits)
    # loss = criterion(logits, (start_positions, end_positions))

[('加上v', 1195.7999877929688), ('恨v', 1033.4662475585938), ('报n', 990.5422668457031), ('和c', 968.9105834960938), ('超过v', 921.5376586914062), ('真实d', 905.7653503417969), ('刘德华nr', 893.1786804199219), (')x', 877.47119140625), ('平时t', 849.399169921875), ('图n', 849.0646057128906), ('这个r', 821.3957214355469), ('168m', 806.4478759765625), ('水台n', 797.172119140625), ('脱鞋v', 781.5652465820312), ('高a', 756.5150146484375), ('都d', 739.8418579101562), ('身体n', 736.9400634765625), ('谢霆锋nr', 731.70849609375), ('3m', 728.1162109375), ('10m', 725.2652587890625)]
[('加上v', 1315.3799865722658), ('恨v', 1136.8128723144532), ('报n', 1089.5964935302736), ('超过v', 1013.6914245605469), ('和c', 968.9105834960938), ('图n', 933.9710662841798), ('真实d', 905.7653503417969), ('刘德华nr', 893.1786804199219), (')x', 877.47119140625), ('水台n', 876.8893310546875), ('脱鞋v', 859.7217712402345), ('平时t', 849.399169921875), ('这个r', 821.3957214355469), ('身体n', 810.6340698242188), ('168m', 806.4478759765625), ('高a', 756.5150146484375), ('

In [5]:
from mrc_utils import *
from paddlenlp.datasets import load_dataset
DATA_PATH = 'mrc_interpretation_A.txt'
data_ds = DuReader().read(DATA_PATH)
data = load_data(DATA_PATH)
# print(data[:1]))
# Hyperparameters
batch_size = 12
max_seq_length = 512
epochs = 3  #3
warmup_proportion = 0.1
weight_decay = 0.01
doc_stride = 512

# Prepare dataloader
test_trans_func = partial(prepare_validation_features, 
                            max_seq_length=max_seq_length, 
                            doc_stride=doc_stride,
                            tokenizer=tokenizer)
                            
data_ds.map(test_trans_func, batched=True, num_workers=4)
test_batch_sampler = paddle.io.DistributedBatchSampler(
        data_ds, batch_size=batch_size, shuffle=False)

test_batchify_fn = lambda samples, fn=Dict({
    "input_ids": Pad(axis=0, pad_val=tokenizer.pad_token_id),
    "token_type_ids": Pad(axis=0, pad_val=tokenizer.pad_token_type_id)
}): fn(samples)
test_data_loader = paddle.io.DataLoader(
    dataset=data_ds,
    batch_sampler=test_batch_sampler,
    collate_fn=test_batchify_fn,
    return_list=True)

# Get offset maps which will be used for score alignment
contexts, standard_split, ori_offset_maps, standard_split_offset_maps = pre_process(data, data_ds, tokenizer)
print(contexts[:1])
print(standard_split[:1])
print(ori_offset_maps[:1])
print(standard_split_offset_maps[:1])

['[CLS]地瓜不是红薯。地瓜一般生吃或者凉拌，外形是纺锤型的，有明显的瓣状结构，内里的肉是白色的，有清淡的药香味，生吃又脆又甜，常食用可以预防肝癌、胃癌，营养价值非常高。红薯是粗粮，也叫番薯山芋。它是一种属管状花目，旋花科一年生的草本植物，富含丰富的矿物质和维生素，而且非常耐饱。地瓜是红薯吗[SEP]']
[['[CLS]', '地', '瓜', '不', '是', '红', '薯', '。', '地', '瓜', '一', '般', '生', '吃', '或', '者', '凉', '拌', '，', '外', '形', '是', '纺', '锤', '型', '的', '，', '有', '明', '显', '的', '瓣', '状', '结', '构', '，', '内', '里', '的', '肉', '是', '白', '色', '的', '，', '有', '清', '淡', '的', '药', '香', '味', '，', '生', '吃', '又', '脆', '又', '甜', '，', '常', '食', '用', '可', '以', '预', '防', '肝', '癌', '、', '胃', '癌', '，', '营', '养', '价', '值', '非', '常', '高', '。', '红', '薯', '是', '粗', '粮', '，', '也', '叫', '番', '薯', '山', '芋', '。', '它', '是', '一', '种', '属', '管', '状', '花', '目', '，', '旋', '花', '科', '一', '年', '生', '的', '草', '本', '植', '物', '，', '富', '含', '丰', '富', '的', '矿', '物', '质', '和', '维', '生', '素', '，', '而', '且', '非', '常', '耐', '饱', '。', '地', '瓜', '是', '红', '薯', '吗', '[SEP]']]
[[(0, 5), (5, 6), (6, 7), (7, 8), (8, 9), (9, 10), (10, 11), (11, 12), (12, 13), (13, 14), (14, 15), (15, 16), (16, 17

In [24]:
print(len(test_data_loader))

1855


#### 3）重要度分数获取
我们提供attention和IG两种解释方法，可根据实际实验结果选取最有效的一种方法。

##### a） Attention-based Interpreter

In [6]:
from trustai.interpretation.token_level import AttentionInterpreter
from utils import create_dataloader_from_scratch
import paddle
# Hyperparameters
BATCH_SIZE = 64

# Init an attention interpreter and get the importance scores
att = AttentionInterpreter(model, device="gpu", predict_fn=attention_predict_fn)

# Use attention interpreter to get the importance scores for all data
interp_results = None
for batch in test_data_loader:
    if interp_results:
        interp_results += att(batch)
        # print(interp_results)
    else:
        interp_results = att(batch)

# Trim the output to get scores only for context
interp_results = trim_output(interp_results, data_ds, tokenizer)

# Align the results back to the standard splited tokens so that it can be evaluated correctly later
align_res = att.alignment(interp_results, contexts, standard_split, standard_split_offset_maps, ori_offset_maps, special_tokens=["[CLS]", '[SEP]'])

In [11]:
print(align_res[0])

InterpretResult(words=['[CLS]', '地', '瓜', '不', '是', '红', '薯', '。', '地', '瓜', '一', '般', '生', '吃', '或', '者', '凉', '拌', '，', '外', '形', '是', '纺', '锤', '型', '的', '，', '有', '明', '显', '的', '瓣', '状', '结', '构', '，', '内', '里', '的', '肉', '是', '白', '色', '的', '，', '有', '清', '淡', '的', '药', '香', '味', '，', '生', '吃', '又', '脆', '又', '甜', '，', '常', '食', '用', '可', '以', '预', '防', '肝', '癌', '、', '胃', '癌', '，', '营', '养', '价', '值', '非', '常', '高', '。', '红', '薯', '是', '粗', '粮', '，', '也', '叫', '番', '薯', '山', '芋', '。', '它', '是', '一', '种', '属', '管', '状', '花', '目', '，', '旋', '花', '科', '一', '年', '生', '的', '草', '本', '植', '物', '，', '富', '含', '丰', '富', '的', '矿', '物', '质', '和', '维', '生', '素', '，', '而', '且', '非', '常', '耐', '饱', '。', '地', '瓜', '是', '红', '薯', '吗', '[SEP]'], word_attributions=[26.672456741333008, -58.07738494873047, -64.54483795166016, -11.655533790588379, -23.530738830566406, -33.82665252685547, -39.29488754272461, -7.500601291656494, -56.386775970458984, -66.5095443725586, -77.31051635742188, -74.48218536

##### b）IG-based Interpreter

In [9]:
from trustai.interpretation.token_level import IntGradInterpreter
from utils import create_dataloader_from_scratch
# Hyperparameters
IG_STEP = 100
BATCH_SIZE = 64

# Init an IG interpreter
ig = IntGradInterpreter(model, predict_fn=IG_predict_fn, device="gpu")

# Use IG interpreter to get the importance scores for all data
interp_results = None
for batch in test_data_loader:
    if interp_results:
        interp_results += ig(batch, steps=IG_STEP)
    else:
        interp_results = ig(batch, steps=IG_STEP)

# trim the output to get scores only for context
interp_results = trim_output(interp_results, data_ds, tokenizer)

# Align the results back to the standard splited tokens so that it can be evaluated correctly later
align_res = ig.alignment(interp_results, contexts, standard_split, standard_split_offset_maps, ori_offset_maps, special_tokens=["[CLS]", '[SEP]'])

SystemError: (Fatal) Operator elementwise_add raises an paddle::memory::allocation::BadAlloc exception.
The exception content is
:ResourceExhaustedError: 

Out of memory error on GPU 0. Cannot allocate 69.000000MB memory on GPU 0, 15.757812GB memory has been allocated and available memory is only 24.500000MB.

Please check whether there is any other process using GPU 0.
1. If yes, please stop them, or start PaddlePaddle on another GPU.
2. If no, please decrease the batch size of your model. 
If the above ways do not solve the out of memory problem, you can try to use CUDA managed memory. The command is `export FLAGS_use_cuda_managed_memory=false`.
 (at /paddle/paddle/fluid/memory/allocation/cuda_allocator.cc:87)
. (at /paddle/paddle/fluid/imperative/tracer.cc:307)


### 生成用于评估的数据
评估文件格式要求是3列数据：编号\t预测答案\t证据，我们提供了脚本将模型输出结果转成评估所需格式。

In [9]:
import math
# Re-sort the token index according to their importance scores
def resort(index_array, importance_score):
    res = sorted([[idx, importance_score[idx]] for idx in index_array], key=lambda x:x[1], reverse=True)
    res = [n[0] for n in res]
    return res

# Post-prepare the result data so that it can be used for the evaluation directly
def prepare_eval_data(data, results, paddle_model):
    res = {}
    idx = 0
    step=0
    for data_id, inter_res in zip(data, results):
        # Split importance score vectors for query and title from inter_res.word_attributions
        step+=1
        if(step==2):
            break
        # print(data_id,'\n\n\n',inter_res)
        importance_score = np.array(inter_res.word_attributions[1:-1])
        print(importance_score)
        # Extract topK importance scores
        topk = math.ceil(len(data[data_id]['sent_token'])*RATIONALE_RATIO)
        
        eval_data = {}        
        eval_data['id'] = data_id
        label = list(inter_res.pred_label)
        # print('\n\n',label[0])
        if int(label[0])>=int(label[1])+1:
            eval_data['pred_label'] = ''
        else:
            eval_data['pred_label'] = ''.join(tokenizer.convert_ids_to_tokens(data_ds[idx]['input_ids'][int(label[0]):int(label[1])+1]))
            print(eval_data['pred_label'])
        # Find the token index of the topK importance scores
        eval_data['rationale'] = np.argpartition(importance_score, -topk)[-topk:]
        # Re-sort the token index according to their importance scores
        eval_data['rationale'] = resort(eval_data['rationale'], importance_score)
        print(eval_data['rationale'])
        res[data_id] = eval_data
        idx += 1
    return res

# Generate results for evaluation
predicts = prepare_eval_data(data, align_res, model)

[-58.07738495 -64.54483795 -11.65553379 -23.53073883 -33.82665253
 -39.29488754  -7.50060129 -56.38677597 -66.50954437 -77.31051636
 -74.48218536 -80.5030365  -80.91278839 -77.15781403 -69.86212158
 -88.68054199 -87.16474915 -51.09682846 -79.56904602 -82.26538849
 -67.66261292 -85.58882904 -91.9704361  -87.64743805 -72.3473053
 -69.404953   -70.7878418  -81.68314362 -85.27710724 -80.21749878
 -93.42695618 -92.4311676  -90.18592834 -96.14303589 -66.19949341
 -79.55653381 -80.49518585 -73.23722839 -87.37515259 -74.4006958
 -86.96891022 -88.02120972 -81.38883209 -68.48487091 -76.5506134
 -88.89331055 -92.04889679 -80.76564789 -90.63497162 -91.16900635
 -95.46302795 -72.02256775 -89.48639679 -93.03535461 -81.85414124
 -90.78614807 -83.58494568 -96.93214417 -66.24763489 -69.9414444
 -86.49354553 -85.08138275 -79.68699646 -75.93650818 -86.87537384
 -85.18067169 -90.90517426 -93.41381073 -76.50401306 -92.61473083
 -96.92289734 -74.45556641 -86.21369171 -91.52531433 -88.31045532
 -90.39032745 

In [8]:
print(predicts)

{1: {'id': 1, 'pred_label': '不是红薯', 'rationale': [134, 6, 2, 140, 79, 137, 3, 4, 138, 92, 5, 139, 80, 135]}}


In [22]:
import jieba.posseg as pseg
import numpy as np
import math
from tqdm import tqdm
def MSE(y,t):
    #形参t代表训练数据（监督数据）（真实）
    #y代表预测数据
    return 0.5*np.sum((y-t)**2)
step=0
# out_file1 = open('./mrc_rationale_word.txt', 'w')
for step, batch in tqdm(enumerate(test_data_loader, start=0)):
    if step==3:
        break
    data_id=data_ds[step]['example_id']
    results=align_res[step]
    label = list(results.pred_label)
    if int(label[0])>=int(label[1])+1:
        answer = ''
    else:
        answer= ''.join(tokenizer.convert_ids_to_tokens(data_ds[step]['input_ids'][int(label[0]):int(label[1])+1]))
    
    # out_file1.write(str(data_id)+'\t'+answer+'\t')
    # out_file2.write(str(data_id)+'\t'+answer+'\t')
    # out_file3.write(str(data_id)+'\t'+answer+'\t')
    topk = math.ceil(len(data[data_id]['sent_token'])*RATIONALE_RATIO)
    test_context=data[data_id]
    context=test_context['context']
    seg_list =pseg.cut(context)
    ner_list=[]
    for w in seg_list:
        ner_list.append(w)
    # print(ner_list)
    # print(dev_question[step])
    # global_step += 1
    input_ids, segment_ids= batch
    real_logits = model(input_ids=input_ids, token_type_ids=segment_ids)
    start_logits,end_logits=real_logits
    # print(start_logits.numpy())
    # print('\n',end_logits)
    input_ids=input_ids.numpy()
    start=0
    mse_losses={}
    word_len={}
    for w in ner_list:
        word=w.word
        if word in[',','.','，','。','-','+','?','!']:
            continue
        flag=w.flag
        change_input_ids=input_ids[0]
        change_input_ids[start:start+len(word)]=0
        new_input_ids=[]
        new_input_ids.append(change_input_ids)
        new_input_ids=paddle.to_tensor(new_input_ids)
        change_logits=model(input_ids=new_input_ids, token_type_ids=segment_ids)
        new_start_logits,new_end_logits=change_logits
        mse_loss_start=MSE(start_logits.numpy()[0],new_start_logits.numpy()[0])
        mse_loss_end=MSE(end_logits.numpy()[0],new_end_logits.numpy()[0])
        
        if word in word_len.keys():
            word_len[word]+=1
            mse_losses[word]+=(mse_loss_start+mse_loss_end)/2
        else:
            word_len[word]=1
            mse_losses[word]=(mse_loss_start+mse_loss_end)/2
        if(flag=='n' or flag=='v'):
            mse_losses[word]*=10
        word_len[word]+=1
    a = sorted(mse_losses.items(), key=lambda x: x[1],reverse = True)
    print(a[:20])
    for k in mse_losses.keys():
        mse_losses[k]=mse_losses[k]/word_len[k]
    # print(word_len)
    # print(mse_losses)
        # if(flag=='n' or flag=='v'):
        #     mse_losses_flag1[word]*=10
    a = sorted(mse_losses.items(), key=lambda x: x[1],reverse = True)
    print(a[:20])
    b1 = sorted(mse_losses_flag1.items(), key=lambda x: x[1],reverse = True)

    for idx in range(len(b1)):
        element=b1[idx]
        word=element[0]
        # print(word)
        topk-=len(word)
        if(topk<=0):
            break
        out_file1.write(word+',')
    out_file1.write('\n')
        for e in range(len(word)):
            if topk<=len(b1[idx+1][0]) and e==len(word)-1:
                out_file1.write(str(context.find(word)+e)+'\n')
            else:
                out_file1.write(str(context.find(word)+e)+',')
    

out_file1.close()


0it [00:00, ?it/s]

[('是', 17911358.315734863), ('红薯', 184779.06158447266), ('地瓜', 178871.92962646484), ('生', 17667.507934570312), ('吃', 17667.507934570312), ('有', 17667.507934570312), ('粗粮', 2196.8502807617188), ('叫', 2196.8502807617188), ('番薯', 2196.8502807617188), ('山芋', 2196.8502807617188), ('属', 2196.8502807617188), ('管状花', 2196.8502807617188), ('旋', 2196.8502807617188), ('富含', 2196.8502807617188), ('矿物质', 2196.8502807617188), ('凉拌', 1606.1370849609375), ('外形', 1606.1370849609375), ('纺锤', 1606.1370849609375), ('瓣状', 1606.1370849609375), ('结构', 1606.1370849609375)]
[('是', 1791135.8315734863), ('红薯', 30796.510264078777), ('地瓜', 29811.988271077473), ('生', 4416.876983642578), ('吃', 4416.876983642578), ('有', 4416.876983642578), ('粗粮', 1098.4251403808594), ('叫', 1098.4251403808594), ('番薯', 1098.4251403808594), ('山芋', 1098.4251403808594), ('属', 1098.4251403808594), ('管状花', 1098.4251403808594), ('旋', 1098.4251403808594), ('富含', 1098.4251403808594), ('矿物质', 1098.4251403808594), ('凉拌', 803.0685424804688), ('外形

1it [00:03,  3.64s/it]

[('人', 196783.05130004883), ('应当', 19501.02310180664), ('负', 19501.02310180664), ('贩卖毒品', 19501.02310180664), ('罪', 1950.102310180664), ('犯', 1772.8202819824219), ('杀人', 1772.8202819824219), ('故意伤害', 1772.8202819824219), ('致', 1772.8202819824219), ('死亡', 1772.8202819824219), ('强奸', 1772.8202819824219), ('抢劫', 1772.8202819824219), ('放火', 1772.8202819824219), ('爆炸', 1772.8202819824219), ('投放', 1772.8202819824219), ('物质', 1772.8202819824219), ('人犯', 1772.8202819824219), ('应负', 1772.8202819824219), ('、', 1240.9741973876953), ('的', 709.1281127929688)]
[('人', 32797.175216674805), ('应当', 4875.25577545166), ('负', 4875.25577545166), ('贩卖毒品', 4875.25577545166), ('犯', 886.4101409912109), ('杀人', 886.4101409912109), ('故意伤害', 886.4101409912109), ('致', 886.4101409912109), ('死亡', 886.4101409912109), ('强奸', 886.4101409912109), ('抢劫', 886.4101409912109), ('放火', 886.4101409912109), ('爆炸', 886.4101409912109), ('投放', 886.4101409912109), ('物质', 886.4101409912109), ('人犯', 886.4101409912109), ('应负', 886.41014

2it [00:06,  3.35s/it]

[('会', 2261.52446269989), ('有', 2245.594825744629), ('作用', 2245.594825744629), ('促进', 2245.594825744629), ('产生', 2086.2984561920166), ('酸性', 205.59313297271729), ('物质', 205.59313297271729), ('可能', 205.59313297271729), ('导致', 205.59313297271729), ('肌肤', 205.59313297271729), ('PH值', 205.59313297271729), ('改变', 205.59313297271729), ('释放', 205.59313297271729), ('大量', 205.59313297271729), ('烟酸', 205.59313297271729), ('皮肤', 205.59313297271729), ('刺激', 205.59313297271729), ('使用', 205.59313297271729), ('是', 205.59313297271729), ('抗', 205.59313297271729)]
[('会', 565.3811156749725), ('有', 561.3987064361572), ('作用', 561.3987064361572), ('促进', 561.3987064361572), ('产生', 521.5746140480042), ('酸性', 102.79656648635864), ('物质', 102.79656648635864), ('可能', 102.79656648635864), ('导致', 102.79656648635864), ('肌肤', 102.79656648635864), ('PH值', 102.79656648635864), ('改变', 102.79656648635864), ('释放', 102.79656648635864), ('大量', 102.79656648635864), ('烟酸', 102.79656648635864), ('皮肤', 102.79656648635864), ('刺激

3it [00:12,  4.03s/it]
