## Bert-for-WebQA

 sformers 包从零开始搭建了一个中文阅读问答训练测试
框架 ，数据选择百度的 WebQA 问答数据集， 类似于斯坦福智能问答数据集，
使用 Bert-base-chinese 和 CRF 模型 做基础，模型可以根据需要持续更新。

## 模型

输入：[‘CLS’]+Question+[‘SEP’]+Evidence 字符串

模型框架：采用多任务联合训练的方式，共两个任务：

           任务1. 使用"[CLS]"来判断两个句子是否是Quesntion-Evidence的关系；

           任务2. 使用Question+[‘SEP’]+Evidence的BERT表达 + CRF模型 进行序列标注，找出Evidence中的答案。

输出：

           任务1. [batch_size,1] 的0-1 序列，1表示对应的文章中含有问题答案，0表示没有；
           
           任务2. [batch_size, seq_len] 的0-1 序列, Evidence 中出现答案的位置为 1 ，其余为 0。

备注： 选择使用"[CLS]"做Quesntion-Evidence关系判断的原因是，做大规模文档检索时，通常回返回一些带有迷惑性的负样本，用"[CLS]"可以进行二次过滤。

#### 训练精度     

           Eval On TestData   Eval-Loss: 15.383  Eval-Result（召回）: R = 0.796

           Eval On DevData    Eval-Loss: 13.986  Eval-Result（召回）: R = 0.795

数据集来自：https://pan.baidu.com/s/1QUsKcFWZ7Tg1dk_AbldZ1A 提取码：2dva

BaseLine论文：https://arxiv.org/abs/1607.06275

模型的谷歌云共享连接(训练好的模型)：https://drive.google.com/open?id=1KHlCnT6VEpDCvtJp8FfwMtU5_ABrYzH9

==================== 超参 ====================

           early_stop = 1
                   lr = 1e-05
                   l2 = 1e-05
             n_epochs = 5
            Negweight = 0.01
             trainset = data/me_train.json
               devset = data/me_validation.ann.json
              testset = data/me_test.ann.json
       knowledge_path = data/me_test.ann.json
        Stopword_path = data/stop_words.txt
               device = cuda
                 mode = train
           model_path = save_model/latest_model.pt
           model_back = save_model/back_model.pt
           batch_size = 16


说明：上面效果只训练了半个epoch 因为疫情在家没有服务器，用谷歌云训练的，设备是tesla-P100，回答一个问题平均耗时40ms。

## 问答模块

问答模块设计了两种功能：

1.带有文章的阅读问答；

2.根据问题从知识库中快速检索文章，再进行阅读问答的智能问答，问题的答案要在知识库里面有才行！
 

## 文档检索

           步骤-0 准备知识库 

           步骤-1 jieba分词 

           步骤-2 去停用词 

           步骤-3 基于分词和二元语法词袋，使用sklearn计算TF-IDF矩阵 

           步骤-4 根据Query和知识库的TF-IDF矩阵计算排序出相关度较高的10篇文章。
           
用测试集数据搭建的知识库，文章检索精度 89%，其中文章数为3024，根据一个Query一次筛选出15篇文章，有89%的概率包含正确Evidene。


## 运行

           训练 %run TrainAndEval.py --batch_size=8 --mode="train" --model_path='save_model/latest_model.pt'

           评估 %run TrainAndEval.py --mode="eval" --model_path='save_model/latest_model.pt'

           阅读问答 %run TrainAndEval.py  --mode="demo" --model_path='save_model/latest_model.pt'

           智能问答 %run TrainAndEval.py  --mode="QA" --model_path='save_model/latest_model.pt'

## 不足

1. 目前模型对正确的Evidence能高准确度识别出正确答案，但是很难分辨有迷惑性的错误Evidence，下一步需要对"[CLS]"识别错误Evidence进行提升。这会导致在智能问答模块，识别出多个包含正确答案的候选答案，却无法确定哪一个是唯一正确答案。

2. 大规模文档检索时，因词袋较大，TF-IDF矩阵计算会很慢，下一步会根据FaceBook/DrQA文档检索模块，使用稀疏矩阵和哈希特征进行改进。



In [1]:
"""
    FILE :  CRF.py
    FUNCTION : None
    REFERENCE : https://github.com/jiesutd/NCRFpp/blob/master/model/crf.py
"""
import torch
from torch.autograd.variable import Variable
import torch.nn as nn


def log_sum_exp(vec, m_size):
    """
    Args:
        vec: size=(batch_size, vanishing_dim, hidden_dim)
        m_size: hidden_dim
    Returns:
        size=(batch_size, hidden_dim)
    """
    _, idx = torch.max(vec, 1)  # B * 1 * M
    max_score = torch.gather(vec, 1, idx.view(-1, 1, m_size)).view(-1, 1, m_size)  # B * M
    return max_score.view(-1, m_size) + torch.log(torch.sum(
        torch.exp(vec - max_score.expand_as(vec)), 1)).view(-1, m_size)


class CRF(nn.Module):
    """
        CRF
    """
    def __init__(self, **kwargs):
        """
        kwargs:
            target_size: int, target size
            device: str, device
        """
        super(CRF, self).__init__()
        for k in kwargs:
            self.__setattr__(k, kwargs[k])
        device = self.device

        # init transitions
        self.START_TAG, self.STOP_TAG = -2, -1
        init_transitions = torch.zeros(self.target_size + 2, self.target_size + 2, device=device)
        init_transitions[:, self.START_TAG] = -10000.0
        init_transitions[self.STOP_TAG, :] = -10000.0
        self.transitions = nn.Parameter(init_transitions)

    def _forward_alg(self, feats, mask):
        """
        Do the forward algorithm to compute the partition function (batched).
        Args:
            feats: size=(batch_size, seq_len, self.target_size+2)
            mask: size=(batch_size, seq_len)
        Returns:
            xxx
        """
        #print(feats.shape)
        batch_size = feats.size(0)
        seq_len = feats.size(1)
        tag_size = feats.size(2)
        mask = mask.transpose(1, 0).contiguous().bool()
        ins_num = seq_len * batch_size
        """ be careful the view shape, it is .view(ins_num, 1, tag_size) but not .view(ins_num, tag_size, 1) """
        feats = feats.transpose(1,0).contiguous().view(ins_num,1, tag_size).expand(ins_num, tag_size, tag_size)
        """ need to consider start """
        scores = feats + self.transitions.view(1, tag_size, tag_size).expand(ins_num, tag_size, tag_size)
        scores = scores.view(seq_len, batch_size, tag_size, tag_size)
        # build iter
        seq_iter = enumerate(scores)
        _, inivalues = next(seq_iter)  # bat_size * from_target_size * to_target_size
        """ only need start from start_tag """
        partition = inivalues[:, self.START_TAG, :].clone().view(batch_size, tag_size, 1)  # bat_size * to_target_size

        """
        add start score (from start to all tag, duplicate to batch_size)
        partition = partition + self.transitions[START_TAG,:].view(1, tag_size, 1).expand(batch_size, tag_size, 1)
        iter over last scores
        """
        for idx, cur_values in seq_iter:
            """
            previous to_target is current from_target
            partition: previous results log(exp(from_target)), #(batch_size * from_target)
            cur_values: bat_size * from_target * to_target
            """
            cur_values = cur_values + partition.contiguous().view(batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size)
            cur_partition = log_sum_exp(cur_values, tag_size)

            mask_idx = mask[idx, :].view(batch_size, 1).expand(batch_size, tag_size)

            """ effective updated partition part, only keep the partition value of mask value = 1 """
            masked_cur_partition = cur_partition.masked_select(mask_idx)
            """ let mask_idx broadcastable, to disable warning """
            mask_idx = mask_idx.contiguous().view(batch_size, tag_size, 1)

            """ replace the partition where the maskvalue=1, other partition value keeps the same """
            partition.masked_scatter_(mask_idx, masked_cur_partition)
        """ 
        until the last state, add transition score for all partition (and do log_sum_exp) 
        then select the value in STOP_TAG 
        """
        cur_values = self.transitions.view(1, tag_size, tag_size).expand(batch_size, tag_size, tag_size) + partition.contiguous().view(batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size)
        cur_partition = log_sum_exp(cur_values, tag_size)
        final_partition = cur_partition[:, self.STOP_TAG]
        return final_partition.sum(), scores

    def _viterbi_decode(self, feats, mask):
        """
            input:
                feats: (batch, seq_len, self.tag_size+2)
                mask: (batch, seq_len)
            output:
                decode_idx: (batch, seq_len) decoded sequence
                path_score: (batch, 1) corresponding score for each sequence (to be implementated)
        """
        # print(feats.size())
        batch_size = feats.size(0)
        seq_len = feats.size(1)
        tag_size = feats.size(2)
        # assert(tag_size == self.tagset_size+2)
        """ calculate sentence length for each sentence """
        length_mask = torch.sum(mask.long(), dim=1).view(batch_size, 1).long()
        """ mask to (seq_len, batch_size) """
        mask = mask.transpose(1, 0).contiguous().bool()
        ins_num = seq_len * batch_size
        """ be careful the view shape, it is .view(ins_num, 1, tag_size) but not .view(ins_num, tag_size, 1) """
        feats = feats.transpose(1,0).contiguous().view(ins_num, 1, tag_size).expand(ins_num, tag_size, tag_size)
        """ need to consider start """
        scores = feats + self.transitions.view(1, tag_size, tag_size).expand(ins_num, tag_size, tag_size)
        scores = scores.view(seq_len, batch_size, tag_size, tag_size)

        # build iter
        seq_iter = enumerate(scores)
        # record the position of best score
        back_points = list()
        partition_history = list()
        ##  reverse mask (bug for mask = 1- mask, use this as alternative choice)
        # mask = 1 + (-1)*mask
        mask = (1 - mask.long()).byte().bool()
        _, inivalues = next(seq_iter)  # bat_size * from_target_size * to_target_size
        """ only need start from start_tag """
        partition = inivalues[:, self.START_TAG, :].clone().view(batch_size, tag_size)  # bat_size * to_target_size
        partition_history.append(partition)
        # iter over last scores
        for idx, cur_values in seq_iter:
            """
            previous to_target is current from_target
            partition: previous results log(exp(from_target)), #(batch_size * from_target)
            cur_values: batch_size * from_target * to_target
            """
            cur_values = cur_values + partition.contiguous().view(batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size)
            """ forscores, cur_bp = torch.max(cur_values[:,:-2,:], 1) # do not consider START_TAG/STOP_TAG """
            partition, cur_bp = torch.max(cur_values, 1)
            partition_history.append(partition)
            """
            cur_bp: (batch_size, tag_size) max source score position in current tag
            set padded label as 0, which will be filtered in post processing
            """
            cur_bp.masked_fill_(mask[idx].view(batch_size, 1).expand(batch_size, tag_size), 0)
            back_points.append(cur_bp)
        """ add score to final STOP_TAG """
        partition_history = torch.cat(partition_history, 0).view(seq_len, batch_size, -1).transpose(1, 0).contiguous() ## (batch_size, seq_len. tag_size)
        """ get the last position for each setences, and select the last partitions using gather() """
        last_position = length_mask.view(batch_size,1,1).expand(batch_size, 1, tag_size) -1
        last_partition = torch.gather(partition_history, 1, last_position).view(batch_size,tag_size,1)
        """ calculate the score from last partition to end state (and then select the STOP_TAG from it) """
        last_values = last_partition.expand(batch_size, tag_size, tag_size) + self.transitions.view(1,tag_size, tag_size).expand(batch_size, tag_size, tag_size)
        _, last_bp = torch.max(last_values, 1)
        pad_zero = torch.zeros(batch_size, tag_size, device=self.device, requires_grad=True).long()
        back_points.append(pad_zero)
        back_points = torch.cat(back_points).view(seq_len, batch_size, tag_size)

        """ elect end ids in STOP_TAG """
        pointer = last_bp[:, self.STOP_TAG]
        insert_last = pointer.contiguous().view(batch_size,1,1).expand(batch_size,1, tag_size)
        back_points = back_points.transpose(1,0).contiguous()
        """move the end ids(expand to tag_size) to the corresponding position of back_points to replace the 0 values """
        back_points.scatter_(1, last_position, insert_last)
        back_points = back_points.transpose(1,0).contiguous()
        """ decode from the end, padded position ids are 0, which will be filtered if following evaluation """
        # decode_idx = Variable(torch.LongTensor(seq_len, batch_size))
        decode_idx = torch.empty(seq_len, batch_size, device=self.device, requires_grad=True).long()
        decode_idx[-1] = pointer.detach()
        for idx in range(len(back_points)-2, -1, -1):
            pointer = torch.gather(back_points[idx], 1, pointer.contiguous().view(batch_size, 1))
            decode_idx[idx] = pointer.detach().view(batch_size)
        path_score = None
        decode_idx = decode_idx.transpose(1, 0)
        return path_score, decode_idx

    def forward(self, feats, mask):
        """
        :param feats:
        :param mask:
        :return:
        """
        path_score, best_path = self._viterbi_decode(feats, mask)
        return path_score, best_path

    def _score_sentence(self, scores, mask, tags):
        """
        Args:
            scores: size=(seq_len, batch_size, tag_size, tag_size)
            mask: size=(batch_size, seq_len)
            tags: size=(batch_size, seq_len)
        Returns:
            score:
        """
        # print(scores.size())
        batch_size = scores.size(1)
        seq_len = scores.size(0)
        tag_size = scores.size(-1)
        tags = tags.view(batch_size, seq_len)
        """ convert tag value into a new format, recorded label bigram information to index """
        # new_tags = Variable(torch.LongTensor(batch_size, seq_len))
        new_tags = torch.empty(batch_size, seq_len, device=self.device, requires_grad=True).long()
        for idx in range(seq_len):
            if idx == 0:
                new_tags[:, 0] = (tag_size - 2) * tag_size + tags[:, 0]
            else:
                new_tags[:, idx] = tags[:, idx-1] * tag_size + tags[:, idx]

        """ transition for label to STOP_TAG """
        end_transition = self.transitions[:, self.STOP_TAG].contiguous().view(1, tag_size).expand(batch_size, tag_size)
        """ length for batch,  last word position = length - 1 """
        length_mask = torch.sum(mask, dim=1).view(batch_size, 1).long()
        """ index the label id of last word """
        end_ids = torch.gather(tags, 1, length_mask-1)

        """ index the transition score for end_id to STOP_TAG """
        end_energy = torch.gather(end_transition, 1, end_ids)

        """ convert tag as (seq_len, batch_size, 1) """
        new_tags = new_tags.transpose(1, 0).contiguous().view(seq_len, batch_size, 1)
        """ need convert tags id to search from 400 positions of scores """
        tg_energy = torch.gather(scores.view(seq_len, batch_size, -1), 2, new_tags).view(seq_len, batch_size)
        tg_energy = tg_energy.masked_select(mask.transpose(1, 0))

        """
        add all score together
        gold_score = start_energy.sum() + tg_energy.sum() + end_energy.sum()
        """
        gold_score = tg_energy.sum() + end_energy.sum()

        return gold_score

    def neg_log_likelihood_loss(self, feats, mask, tags):
        """
        Args:
            feats: size=(batch_size, seq_len, tag_size)
            mask: size=(batch_size, seq_len)
            tags: size=(batch_size, seq_len)
        """
        batch_size = feats.size(0)
        forward_score, scores = self._forward_alg(feats, mask)
        gold_score = self._score_sentence(scores, mask, tags)
        return forward_score - gold_score

## Bert-for-WebQA   
使用 torch 和 transformers 包从零开始搭建了一个中文阅读问答训练测试 框架 ，数据选择百度的 WebQA 问答数据集， 类似于斯坦福智能问答数据集， 使用 Bert-base-chinese 和 CRF 模型 做基础，模型可以根据需要持续更新。

In [2]:
import os
import json
from torch.utils import data
from transformers import BertTokenizer , BertModel
import numpy as np
import argparse
import time
import torch
import torch.nn as nn
import torch.optim as optim
import warnings
warnings.filterwarnings('ignore')
from tqdm import tqdm
#from DocumentRetrieval import Knowledge
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer, TfidfTransformer
import random
import jieba
import collections
# 字符ID化
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')

In [3]:
q = '毛主席出生在哪一年？'
e = '毛主席于1896年出生在湖南长沙韶山冲'

tokens = tokenizer.tokenize('[CLS]'+q+'[SEP]'+e)# list
tokens_id = tokenizer.convert_tokens_to_ids(tokens)
token_type_ids = [0 if i <= tokens_id.index(102) else 1 for i in range(len(tokens_id ))]

In [4]:
tokenizer.convert_ids_to_tokens(102)

'[SEP]'

In [5]:
def pad(batch):
    tokens_l, tokens_id_l, token_type_ids_l, answer_offset_l, answer_seq_label_l, IsQA_l= list(map(list, zip(*batch)))
    maxlen = np.array([len(sen) for sen in tokens_l]).max()
    ### pad和截断
    for i in range(len(tokens_l)):
        tokens = tokens_l[i]
        tokens_id= tokens_id_l[i]
        # answer_offset = answer_offset_l[i]
        answer_seq_label = answer_seq_label_l[i]
        token_type_ids = token_type_ids_l[i]
        tokens_l[i] = tokens + (maxlen - len(tokens))*['[PAD]']
        token_type_ids_l[i] = token_type_ids + (maxlen - len(tokens))*[1]
        tokens_id_l[i] =tokens_id + (maxlen - len(tokens))*tokenizer.convert_tokens_to_ids(['[PAD]'])
        answer_seq_label_l[i] = answer_seq_label + [0]*(maxlen - len(tokens))
    return tokens_l, tokens_id_l, token_type_ids_l, answer_offset_l, answer_seq_label_l, IsQA_l

def result_metric(prediction_all, y_2d_all):
    total_num=0
    toral_cur=0
    for prediction, y_2d in zip(prediction_all, y_2d_all):
        batch_size,seq_len = prediction.size()
        currect = torch.sum(torch.sum(prediction == y_2d, dim=1)==seq_len).to("cpu").item()
        toral_cur = toral_cur + currect
        total_num = total_num + batch_size
    return toral_cur/total_num


In [6]:
class WebQADataset(data.Dataset):
    def __init__(self, fpath):
        self.hp = hp
        self.questions, self.evidences, self.answer= [], [], []
        with open(fpath, 'r',encoding='utf-8') as f:
            data = json.load(f)#读取json文件内容
            for key in data:
                item = data[key]
                question = item['question']
                evidences = item['evidences']
                for evi_key in evidences:
                    evi_item = evidences[evi_key]
                    self.questions.append(question)
                    self.evidences.append(evi_item['evidence'])
                    self.answer.append(evi_item['answer'][0])
        shuffled_l = list(zip(self.questions, self.evidences, self.answer))
        random.shuffle(shuffled_l)
        self.questions[:], self.evidences[:], self.answer[:] = zip(*shuffled_l)

    def __len__(self):
        return len(self.answer)

    def FindOffset(self, tokens_id, answer_id):
        n = len(tokens_id)
        m = len(answer_id)
        if n < m:
            return False
        for i in range(n - m + 1):
            if tokens_id[i:i + m] == answer_id:
                return (i, i + m)
        return False

    def __getitem__(self, idx):
        # We give credits only to the first piece.
        q, e, a = self.questions[idx], self.evidences[idx], self.answer[idx]
        tokens = tokenizer.tokenize('[CLS]'+q+'[SEP]'+e)# list
        if len(tokens)>256:
            tokens=tokens[:256]
        tokens_id = tokenizer.convert_tokens_to_ids(tokens)
        token_type_ids = [0 if i <= tokens_id.index(102) else 1 for i in range(len(tokens_id ))]
        answer_offset = (-1, -1)
        IsQA = 0
        answer_seq_label = len(tokens_id) * [0]
        if a != 'no_answer':
            answer_tokens = tokenizer.tokenize(a)
            answer_offset = self.FindOffset(tokens, answer_tokens)#有肯能返回False
            if answer_offset:#在原文中找到答案
                answer_seq_label[answer_offset[0]:answer_offset[1]] = [1]*(len(answer_tokens))
                IsQA = 1
            else:# self.FindOffset 返回False
                answer_offset = (-1, -1)
        return tokens, tokens_id, token_type_ids, answer_offset, answer_seq_label, IsQA
    
    def get_samples_weight(self,Negweight):
        samples_weight = []
        for ans in self.answer:
            if ans != 'no_answer':
                samples_weight.append(1.0)
            else:
                samples_weight.append(Negweight)
        return np.array(samples_weight)

In [20]:
class Prepare_Train_Features_For_CRF:
    def __init__(self,tokenizer,max_length = 384,stride = 128,pad_on_right = "right"):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.stride = stride
        self.pad_on_right = pad_on_right

    def prepare_train_features(self,examples):
        # Tokenize our examples with truncation and padding, but keep the overflows using a stride. This results
        # in one example possible giving several features when a context is long, each of those features having a
        # context that overlaps a bit the context of the previous feature.
        tokenized_examples = self.tokenizer(
            examples["question" if self.pad_on_right else "context"],
            examples["context" if self.pad_on_right else "question"],
            truncation="only_second" if self.pad_on_right else "only_first",
            max_length= self.max_length,
            stride= self.stride,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding="max_length",
        )
        
        # Since one example might give us several features if it has a long context, we need a map from a feature to
        # its corresponding example. This key gives us just that.
        sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
        # The offset mappings will give us a map from token to character position in the original context. This will
        # help us compute the start_positions and end_positions.
        offset_mapping = tokenized_examples.pop("offset_mapping")

        # Let's label those examples!
        tokenized_examples["start_positions"] = []
        tokenized_examples["end_positions"] = []
        tokenized_examples["answer_offset"] = []
        tokenized_examples["answer_seq_label"] = []
        tokenized_examples["labels"] = []
        for i, offsets in enumerate(offset_mapping):
            # We will label impossible answers with the index of the CLS token.
            input_ids = tokenized_examples["input_ids"][i]
            answer_seq_label = len(input_ids) * [0]
            cls_index = input_ids.index(self.tokenizer.cls_token_id)

            # Grab the sequence corresponding to that example (to know what is the context and what is the question).
            sequence_ids = tokenized_examples.sequence_ids(i)

            # One example can give several spans, this is the index of the example containing this span of text.
            sample_index = sample_mapping[i]
            answers = examples["answers"]
            # If no answers are given, set the cls_index as answer.
            if len(answers["answer_start"]) == 0:
                tokenized_examples["start_positions"].append(cls_index)
                tokenized_examples["end_positions"].append(cls_index)
                tokenized_examples["answer_offset"].append((-1,-1))
                tokenized_examples["labels"].append(0)
            else:
                # Start/end character index of the answer in the text.
                start_char = answers["answer_start"][0]
                end_char = start_char + len(answers["text"][0])

                # Start token index of the current span in the text.
                token_start_index = 0
                while sequence_ids[token_start_index] != (1 if self.pad_on_right else 0):
                    token_start_index += 1

                # End token index of the current span in the text.
                token_end_index = len(input_ids) - 1
                while sequence_ids[token_end_index] != (1 if self.pad_on_right else 0):
                    token_end_index -= 1

                # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index).
                if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
                    tokenized_examples["start_positions"].append(cls_index)
                    tokenized_examples["end_positions"].append(cls_index)
                    tokenized_examples["answer_offset"].append((-1,-1))
                    tokenized_examples["labels"].append(0)

                else:
                    tokenized_examples["labels"].append(1)
                    # Otherwise move the token_start_index and token_end_index to the two ends of the answer.
                    # Note: we could go after the last offset if the answer is the last word (edge case).
                    while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
                        token_start_index += 1
                    tokenized_examples["start_positions"].append(token_start_index - 1)
                    while offsets[token_end_index][1] >= end_char:
                        token_end_index -= 1
                    tokenized_examples["end_positions"].append(token_end_index + 1)
                    tokenized_examples["answer_offset"].append((token_start_index - 1,token_end_index + 1))
                    answer_tokens = self.tokenizer.tokenize(answers["text"][0])
                    answer_seq_label[token_start_index - 1:token_end_index + 1] = [1]*(len(answer_tokens))
                    print(len(answer_seq_label))
                    tokenized_examples["answer_seq_label"].append(answer_seq_label)
               

        return tokenized_examples
    
    

In [21]:
from transformers import AutoTokenizer

model_checkpoint = r'C:\Users\Administrator\Desktop\2021.02.08 multi_choice型阅读理解\bert-base-chinese'
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [22]:
prepare_features = Prepare_Train_Features_For_CRF(tokenizer,max_length = 384,stride = 128,pad_on_right = "right")

In [23]:
t = {'answers': {'answer_start': [102],
  'text': ['国家公安部国家工商总局国家科学技术委员会科技部卫生部国家发展改革委员会等部委均接受并采纳过的我的建议']},
 'context': '当我看到建议被采纳，部委领导写给我的回信时，我知道我正在为这个国家的发展尽着一份力量，27日，河北省邢台钢铁有限公司的普通工人白金跃，拿着历年来国家各部委反馈给他的感谢信，激动地对中新网记者说，27年来，国家公安部国家工商总局国家科学技术委员会科技部卫生部国家发展改革委员会等部委均接受并采纳过的我的建议',
 'id': 's000011',
 'question': '激动的原因是什么？',
 'title': ''}

prepare_features.prepare_train_features(t)

385


{'input_ids': [[101, 4080, 1220, 4638, 1333, 1728, 3221, 784, 720, 8043, 102, 2496, 2769, 4692, 1168, 2456, 6379, 6158, 7023, 5287, 8024, 6956, 1999, 7566, 2193, 1091, 5314, 2769, 4638, 1726, 928, 3198, 8024, 2769, 4761, 6887, 2769, 3633, 1762, 711, 6821, 702, 1744, 2157, 4638, 1355, 2245, 2226, 4708, 671, 819, 1213, 7030, 8024, 8149, 3189, 8024, 3777, 1266, 4689, 6928, 1378, 7167, 7188, 3300, 7361, 1062, 1385, 4638, 3249, 6858, 2339, 782, 4635, 7032, 6645, 8024, 2897, 4708, 1325, 2399, 3341, 1744, 2157, 1392, 6956, 1999, 1353, 7668, 5314, 800, 4638, 2697, 6468, 928, 8024, 4080, 1220, 1765, 2190, 704, 3173, 5381, 6381, 5442, 6432, 8024, 8149, 2399, 3341, 8024, 1744, 2157, 1062, 2128, 6956, 1744, 2157, 2339, 1555, 2600, 2229, 1744, 2157, 4906, 2110, 2825, 3318, 1999, 1447, 833, 4906, 2825, 6956, 1310, 4495, 6956, 1744, 2157, 1355, 2245, 3121, 7484, 1999, 1447, 833, 5023, 6956, 1999, 1772, 2970, 1358, 2400, 7023, 5287, 6814, 4638, 2769, 4638, 2456, 6379, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0

In [12]:
dir_path = r'WebQA.v1.0'
trainset_path = 'me_train.json'
dev_path = 'me_validation.ann.json'
test_path = 'me_test.ann.json'
mode_path = r'bert-base-chinese'

In [11]:
train_dataset = WebQADataset(os.path.join(dir_path ,trainset_path))

In [14]:
# 返回结果：tokens, tokens_id, token_type_ids, answer_offset, answer_seq_label, IsQA
print(train_dataset[12])

NameError: name 'train_dataset' is not defined

In [15]:
print(train_dataset[9])

NameError: name 'train_dataset' is not defined

In [7]:
#model = BertModel.from_pretrained(r'bert-base-chinese')

In [8]:
#model.config.num_labels

In [9]:
#input_ids = torch.tensor([1,55,55,55,66,4]).view(1,-1)
#token_type_ids = torch.tensor([0,0,0,0,0,0]).view(1,-1)

#outputs = model(input_ids,token_type_ids)

In [10]:
#emb, _ = self.PreModel(input_ids=tokens_x_2d, token_type_ids=token_type_ids_2d) #[batch_size, seq_len, hidden_size]
#last_hidden_state (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size)) 
#pooler_output (torch.FloatTensor of shape (batch_size, hidden_size))

In [11]:
#outputs 

In [22]:
import torch
import torch.nn as nn
from transformers import AutoModelForQuestionAnswering, BertPreTrainedModel,BertModel
import numpy as np

class BertForQuestionAnsweringWithCRF(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config)
        self.hidden_size = self.bert.config.hidden_size
        self.CRF_fc1 = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(self.hidden_size, config.num_labels + 2, bias=True),
        )
        self.CRF = CRF(target_size = self.bert.config.num_labels,device= torch.device("cuda"))
        self.CrossEntropyLoss = nn.CrossEntropyLoss()
        self.fc2 = nn.Linear(self.hidden_size, 2, bias=True)

    def forward(self,tokens_id_l, token_type_ids_l, answer_offset_l, answer_seq_label_l, IsQA_l):


        ## 字符ID [batch_size, seq_length]
        tokens_x_2d = torch.LongTensor(tokens_id_l).to(self.device)
        token_type_ids_2d = torch.LongTensor(token_type_ids_l).to(self.device)

        # 计算sql_len 不包含[CLS]
        batch_size, seq_length = tokens_x_2d[:,1:].size()

        ## CRF答案ID [batch_size, seq_length]
        y_2d = torch.LongTensor(answer_seq_label_l).to(self.device)[:,1:]
        ## (batch_size,)
        y_IsQA_2d = torch.LongTensor(IsQA_l).to(self.device)


        if self.training:    # self.training基层的外部类
            self.bert.train()
            output = self.bert(input_ids=tokens_x_2d, token_type_ids=token_type_ids_2d, output_hidden_states= True,return_dict= True)  #[batch_size, seq_len, hidden_size]
        else:
            self.bert.eval()
            with torch.no_grad():
                output = self.bert(input_ids=tokens_x_2d, token_type_ids=token_type_ids_2d, output_hidden_states= True,return_dict= True)

        ## [CLS] for IsQA  [batch_size, hidden_size]
        cls_emb = output.last_hidden_state[:,0,:] 
        
        IsQA_logits = self.fc2(cls_emb) ## [batch_size, 2]
        IsQA_loss = self.CrossEntropyLoss.forward(IsQA_logits,y_IsQA_2d)

        ## [batch_size, 1]
        IsQA_prediction = IsQA_logits.argmax(dim=-1).unsqueeze(dim=-1)

        # CRF mask
        mask = np.ones(shape=[batch_size, seq_length], dtype=np.uint8)
        mask = torch.ByteTensor(mask).bool().to(self.device)  # [batch_size, seq_len, 4]
      

        # No [CLS]
        crf_logits = self.CRF_fc1(output.last_hidden_state[:,1:,:] )
        crf_loss = self.CRF.neg_log_likelihood_loss(feats=crf_logits, mask=mask, tags=y_2d )
        _, CRFprediction = self.CRF.forward(feats=crf_logits, mask=mask)

        return IsQA_prediction, CRFprediction, IsQA_loss, crf_loss, y_2d, y_IsQA_2d.unsqueeze(dim=-1)# (batch_size,) -> (batch_size, 1)
    

In [10]:
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--early_stop", type=int, default=1)
parser.add_argument("--lr", type=float, default=1e-5)
parser.add_argument("--l2", type=float, default=1e-5)
parser.add_argument("--n_epochs", type=int, default=5)
parser.add_argument("--Negweight", type=float, default=0.01)

parser.add_argument("--knowledge_path", type=str, default=r"WebQA.v1.0\me_test.ann.json")
parser.add_argument("--Stopword_path",type=str, default= 'data/stop_words.txt')
parser.add_argument("--device", type=str, default='cuda')
parser.add_argument("--mode", type=str, default='train')  # eval / demo / train / QA
parser.add_argument("--model_path", type=str, default="save_model/latest_model.bin")
parser.add_argument("--model_back", type=str, default="save_model/back_model.bin")
parser.add_argument("--batch_size", type=int, default=12)

hp = parser.parse_args([])


In [23]:
def TrainOneEpoch(model, train_iter, dev_iter, test_iter, optimizer, hp):
    model.train()
    CRFprediction_all, CRFloss_all, IsQAloss_all, y_CRF_all, IsQA_prediction_all, y_IsQA_all= [],[],[],[],[],[]

    best_acc = 0
    for i, batch in enumerate(tqdm(train_iter)):
        _, tokens_id_l, token_type_ids_l, answer_offset_l, answer_seq_label_l, IsQA_l = batch
        optimizer.zero_grad()
        IsQA_prediction, CRFprediction, IsQA_loss, crf_loss, y_2d, y_IsQA_2d  = model.forward(tokens_id_l, token_type_ids_l, answer_offset_l, answer_seq_label_l, IsQA_l)
        ## CRF
        CRFprediction_all.append(CRFprediction)
        y_CRF_all.append(y_2d)

        ## IsQA
        IsQA_prediction_all.append(IsQA_prediction)

        y_IsQA_all.append(y_IsQA_2d)

        # loss
        CRFloss_all.append(crf_loss.to("cpu").item())
        IsQAloss_all.append(IsQA_loss.to("cpu").item())
        loss = 0.2*IsQA_loss + 0.8*crf_loss

        nn.utils.clip_grad_norm_(model.parameters(), 3.0)#设置梯度截断阈值
        loss.backward()## 计算梯度
        optimizer.step()## 根据计算的梯度更新网络参数


        if i % 1000 == 0 and i > 0:
            accCRF = result_metric(CRFprediction_all, y_CRF_all)
            accIsQA = result_metric(IsQA_prediction_all, y_IsQA_all)

            print("<Last 100 Steps MeanValue> Setp-{} IsQA-Loss: {:.3f} CRF-Loss: {:.3f}  "
                  "CRF-Result: accCRF = {:.3f}  IsQA-Result: accIsQA = {:.3f}"
                  .format(i,np.mean(IsQAloss_all), np.mean(CRFloss_all),accCRF,accIsQA))

        if i % 2000 == 0:
            print("Eval on Devset...")
            accIsQA, accCRF = Eval(model, dev_iter)
            if accIsQA * accCRF > best_acc:
                best_acc = accIsQA * accCRF
                if i>0:
                    print("Devdata 精度提升 备份模型至{}".format(hp.model_back))
                    model.save_pretrained(hp.model_back)
            model.train()

In [24]:
def Eval(model, iterator):

    model.eval()
    CRFprediction_all, CRFloss_all, IsQAloss_all, y_CRF_all, IsQA_prediction_all, y_IsQA_all= [],[],[],[],[],[]
    final_pred_all = []
    for i, batch in enumerate(iterator):
        _, tokens_id_l, token_type_ids_l, answer_offset_l, answer_seq_label_l, IsQA_l = batch
        IsQA_prediction, CRFprediction, IsQA_loss, crf_loss, y_2d, y_IsQA_2d  = model.forward(tokens_id_l, token_type_ids_l, answer_offset_l, answer_seq_label_l,IsQA_l)

        ## CRF
        CRFprediction_all.append(CRFprediction)
        y_CRF_all.append(y_2d)

        ## IsQA
        IsQA_prediction_all.append(IsQA_prediction)
        y_IsQA_all.append(y_IsQA_2d)

        ## 综合预测
        # [batch_size,seq_len]
        final_pred = torch.LongTensor(np.zeros(CRFprediction.size())).to("cuda")
        final_pred[IsQA_prediction.squeeze(dim=-1)==1] = CRFprediction[IsQA_prediction.squeeze(dim=-1)==1]
        final_pred_all.append(final_pred)

        CRFloss_all.append(crf_loss.to("cpu").item())
        IsQAloss_all.append(IsQA_loss.to("cpu").item())


    accCRF = result_metric(CRFprediction_all, y_CRF_all)
    accIsQA = result_metric(IsQA_prediction_all, y_IsQA_all)
    accFinal = result_metric(final_pred_all, y_CRF_all)

    print("<本次评估结果> IsQA-Loss: {:.3f} CRF-Loss: {:.3f} CRF-Result: accCRF = {:.3f} "
          "IsQA-Result: accIsQA = {:.3f} Final-Result：accFinal = {:.3f}".
          format(np.mean(IsQAloss_all), np.mean(CRFloss_all), accCRF, accIsQA, accFinal))

    return accIsQA, accCRF



In [17]:
dir_path = r'WebQA.v1.0'
trainset_path = 'me_train.json'
dev_path = 'me_validation.ann.json'
test_path = 'me_test.ann.json'
mode_path = r'bert-base-chinese'


train_dataset = WebQADataset(os.path.join(dir_path ,trainset_path))
dev_dataset = WebQADataset(os.path.join(dir_path ,dev_path))
test_dataset = WebQADataset(os.path.join(dir_path ,test_path))

samples_weight = train_dataset.get_samples_weight(hp.Negweight)
sampler = torch.utils.data.WeightedRandomSampler(samples_weight, len(samples_weight))

train_iter = data.DataLoader(dataset=train_dataset,
                                     batch_size=hp.batch_size,
                                     shuffle=False,
                                     sampler=sampler,
                                     num_workers=0,
                                     collate_fn=pad
                                     )
dev_iter = data.DataLoader(dataset=dev_dataset,
                                   batch_size=hp.batch_size,
                                   shuffle=False,
                                   num_workers=0,
                                   collate_fn=pad
                                   )
test_iter = data.DataLoader(dataset=test_dataset,
                                    batch_size=hp.batch_size,
                                    shuffle=False,
                                    num_workers=0,
                                    collate_fn=pad
                                    )


In [25]:
import gc
import torch



gc.collect()
torch.cuda.empty_cache()


In [26]:
if os.path.exists(hp.model_path):
    print('=======载入模型=======')
    model = torch.load(hp.model_path)
else:
    print("=======初始化模型======")
    model = BertForQuestionAnsweringWithCRF.from_pretrained(mode_path )
    if hp.device == 'cuda':
        model = model.cuda()
    #model = nn.DataParallel(model)

optimizer = optim.Adam(model.parameters(), lr=hp.lr, weight_decay=hp.l2)

if not os.path.exists(os.path.split(hp.model_path)[0]):
    os.makedirs(os.path.split(hp.model_path)[0])

print("First Eval On TestData")
accIsQA, accCRF = Eval(model, test_iter)

best_acc = max(0, accIsQA*accCRF )
no_gain_rc = 0    #效果不增加代数

for epoch in range(1, hp.n_epochs + 1):
    print(f"=========TRAIN and EVAL at epoch={epoch}=========")
    TrainOneEpoch(model, train_iter, dev_iter,test_iter, optimizer, hp)

    # print(f"=========eval dev at epoch={epoch}=========")
    # dev_acc = eval(model, dev_iter)
    print(f"=========eval test at epoch={epoch}=========")
    accIsQA, accCRF = Eval(model, test_iter)

    if accIsQA*accCRF >best_acc:
        print("精度值由 {:.3f} 更新至 {:.3f} ".format(best_acc, accIsQA*accCRF))
        best_acc = accIsQA*accCRF
        print("=======保存模型=======")
        torch.save(model, hp.model_path)
        no_gain_rc = 0
    else:
        no_gain_rc = no_gain_rc+1

        # 提前终止
    if no_gain_rc > hp.early_stop:
        print("连续{}个epoch没有提升，在epoch={}提前终止".format(no_gain_rc,epoch))
        break
        



Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertForQuestionAnsweringWithCRF: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForQuestionAnsweringWithCRF from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForQuestionAnsweringWithCRF from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForQuestionAnsweringWithCRF were not initialized from the model check

First Eval On TestData


  0%|                                                                                        | 0/37371 [00:00<?, ?it/s]

<本次评估结果> IsQA-Loss: 0.619 CRF-Loss: 901.782 CRF-Result: accCRF = 0.000 IsQA-Result: accIsQA = 0.733 Final-Result：accFinal = 0.001
Eval on Devset...


  0%|                                                                            | 1/37371 [00:45<475:22:49, 45.80s/it]

<本次评估结果> IsQA-Loss: 0.646 CRF-Loss: 570.588 CRF-Result: accCRF = 0.000 IsQA-Result: accIsQA = 0.667 Final-Result：accFinal = 0.002


  0%|▏                                                                           | 101/37371 [01:50<7:00:31,  1.48it/s]

<Last 100 Steps MeanValue> Setp-100 IsQA-Loss: 0.667 CRF-Loss: 176.391  CRF-Result: accCRF = 0.035  IsQA-Result: accIsQA = 0.610


  1%|▍                                                                           | 201/37371 [02:56<7:17:40,  1.42it/s]

<Last 100 Steps MeanValue> Setp-200 IsQA-Loss: 0.505 CRF-Loss: 137.478  CRF-Result: accCRF = 0.062  IsQA-Result: accIsQA = 0.788


  1%|▌                                                                           | 301/37371 [04:00<6:48:55,  1.51it/s]

<Last 100 Steps MeanValue> Setp-300 IsQA-Loss: 0.412 CRF-Loss: 121.576  CRF-Result: accCRF = 0.087  IsQA-Result: accIsQA = 0.846


  1%|▊                                                                           | 401/37371 [05:05<7:12:35,  1.42it/s]

<Last 100 Steps MeanValue> Setp-400 IsQA-Loss: 0.351 CRF-Loss: 111.794  CRF-Result: accCRF = 0.112  IsQA-Result: accIsQA = 0.877


  1%|█                                                                           | 501/37371 [06:10<7:30:05,  1.37it/s]

<Last 100 Steps MeanValue> Setp-500 IsQA-Loss: 0.314 CRF-Loss: 104.357  CRF-Result: accCRF = 0.139  IsQA-Result: accIsQA = 0.895


  2%|█▏                                                                          | 601/37371 [07:15<7:11:51,  1.42it/s]

<Last 100 Steps MeanValue> Setp-600 IsQA-Loss: 0.290 CRF-Loss: 98.636  CRF-Result: accCRF = 0.161  IsQA-Result: accIsQA = 0.906


  2%|█▍                                                                          | 701/37371 [08:19<7:43:13,  1.32it/s]

<Last 100 Steps MeanValue> Setp-700 IsQA-Loss: 0.273 CRF-Loss: 94.103  CRF-Result: accCRF = 0.179  IsQA-Result: accIsQA = 0.913


  2%|█▋                                                                          | 801/37371 [09:23<6:52:00,  1.48it/s]

<Last 100 Steps MeanValue> Setp-800 IsQA-Loss: 0.261 CRF-Loss: 90.363  CRF-Result: accCRF = 0.197  IsQA-Result: accIsQA = 0.919


  2%|█▊                                                                          | 901/37371 [10:28<7:25:56,  1.36it/s]

<Last 100 Steps MeanValue> Setp-900 IsQA-Loss: 0.250 CRF-Loss: 86.513  CRF-Result: accCRF = 0.219  IsQA-Result: accIsQA = 0.924


  3%|██                                                                         | 1000/37371 [11:33<6:42:58,  1.50it/s]

<Last 100 Steps MeanValue> Setp-1000 IsQA-Loss: 0.242 CRF-Loss: 83.563  CRF-Result: accCRF = 0.236  IsQA-Result: accIsQA = 0.927
Eval on Devset...
<本次评估结果> IsQA-Loss: 0.066 CRF-Loss: 42.741 CRF-Result: accCRF = 0.486 IsQA-Result: accIsQA = 0.992 Final-Result：accFinal = 0.486
Devdata 精度提升 备份模型至save_model/back_model.bin


  3%|██▏                                                                        | 1101/37371 [13:26<7:51:12,  1.28it/s]

<Last 100 Steps MeanValue> Setp-1100 IsQA-Loss: 0.235 CRF-Loss: 80.781  CRF-Result: accCRF = 0.252  IsQA-Result: accIsQA = 0.931


  3%|██▍                                                                        | 1201/37371 [14:34<8:13:01,  1.22it/s]

<Last 100 Steps MeanValue> Setp-1200 IsQA-Loss: 0.229 CRF-Loss: 78.461  CRF-Result: accCRF = 0.265  IsQA-Result: accIsQA = 0.933


  3%|██▌                                                                        | 1301/37371 [15:38<7:13:01,  1.39it/s]

<Last 100 Steps MeanValue> Setp-1300 IsQA-Loss: 0.226 CRF-Loss: 76.469  CRF-Result: accCRF = 0.277  IsQA-Result: accIsQA = 0.935


  4%|██▊                                                                        | 1401/37371 [16:42<8:09:22,  1.23it/s]

<Last 100 Steps MeanValue> Setp-1400 IsQA-Loss: 0.220 CRF-Loss: 74.367  CRF-Result: accCRF = 0.291  IsQA-Result: accIsQA = 0.937


  4%|███                                                                        | 1501/37371 [17:47<7:56:36,  1.25it/s]

<Last 100 Steps MeanValue> Setp-1500 IsQA-Loss: 0.216 CRF-Loss: 72.611  CRF-Result: accCRF = 0.303  IsQA-Result: accIsQA = 0.939


  4%|███▏                                                                       | 1601/37371 [18:53<8:40:10,  1.15it/s]

<Last 100 Steps MeanValue> Setp-1600 IsQA-Loss: 0.210 CRF-Loss: 71.191  CRF-Result: accCRF = 0.312  IsQA-Result: accIsQA = 0.941


  5%|███▍                                                                       | 1701/37371 [20:00<8:43:13,  1.14it/s]

<Last 100 Steps MeanValue> Setp-1700 IsQA-Loss: 0.207 CRF-Loss: 69.994  CRF-Result: accCRF = 0.321  IsQA-Result: accIsQA = 0.942


  5%|███▌                                                                       | 1801/37371 [21:06<8:23:45,  1.18it/s]

<Last 100 Steps MeanValue> Setp-1800 IsQA-Loss: 0.203 CRF-Loss: 68.775  CRF-Result: accCRF = 0.329  IsQA-Result: accIsQA = 0.944


  5%|███▊                                                                       | 1901/37371 [22:12<8:24:07,  1.17it/s]

<Last 100 Steps MeanValue> Setp-1900 IsQA-Loss: 0.202 CRF-Loss: 67.510  CRF-Result: accCRF = 0.338  IsQA-Result: accIsQA = 0.944


  5%|████                                                                       | 2000/37371 [23:17<6:24:16,  1.53it/s]

<Last 100 Steps MeanValue> Setp-2000 IsQA-Loss: 0.199 CRF-Loss: 66.417  CRF-Result: accCRF = 0.347  IsQA-Result: accIsQA = 0.945
Eval on Devset...
<本次评估结果> IsQA-Loss: 0.062 CRF-Loss: 32.140 CRF-Result: accCRF = 0.626 IsQA-Result: accIsQA = 0.992 Final-Result：accFinal = 0.626
Devdata 精度提升 备份模型至save_model/back_model.bin


  6%|████▏                                                                      | 2101/37371 [25:11<8:40:12,  1.13it/s]

<Last 100 Steps MeanValue> Setp-2100 IsQA-Loss: 0.197 CRF-Loss: 65.453  CRF-Result: accCRF = 0.354  IsQA-Result: accIsQA = 0.946


  6%|████▍                                                                      | 2201/37371 [26:18<8:26:33,  1.16it/s]

<Last 100 Steps MeanValue> Setp-2200 IsQA-Loss: 0.195 CRF-Loss: 64.593  CRF-Result: accCRF = 0.361  IsQA-Result: accIsQA = 0.947


  6%|████▌                                                                      | 2301/37371 [27:25<8:56:47,  1.09it/s]

<Last 100 Steps MeanValue> Setp-2300 IsQA-Loss: 0.193 CRF-Loss: 63.665  CRF-Result: accCRF = 0.368  IsQA-Result: accIsQA = 0.948


  6%|████▊                                                                      | 2401/37371 [28:31<8:49:17,  1.10it/s]

<Last 100 Steps MeanValue> Setp-2400 IsQA-Loss: 0.191 CRF-Loss: 62.854  CRF-Result: accCRF = 0.374  IsQA-Result: accIsQA = 0.949


  7%|█████                                                                      | 2501/37371 [29:36<8:46:08,  1.10it/s]

<Last 100 Steps MeanValue> Setp-2500 IsQA-Loss: 0.189 CRF-Loss: 61.922  CRF-Result: accCRF = 0.381  IsQA-Result: accIsQA = 0.949


  7%|█████▏                                                                     | 2601/37371 [30:43<8:54:25,  1.08it/s]

<Last 100 Steps MeanValue> Setp-2600 IsQA-Loss: 0.187 CRF-Loss: 61.117  CRF-Result: accCRF = 0.387  IsQA-Result: accIsQA = 0.950


  7%|█████▍                                                                     | 2701/37371 [31:48<9:05:09,  1.06it/s]

<Last 100 Steps MeanValue> Setp-2700 IsQA-Loss: 0.187 CRF-Loss: 60.324  CRF-Result: accCRF = 0.393  IsQA-Result: accIsQA = 0.950


  7%|█████▌                                                                     | 2801/37371 [32:53<9:00:49,  1.07it/s]

<Last 100 Steps MeanValue> Setp-2800 IsQA-Loss: 0.186 CRF-Loss: 59.729  CRF-Result: accCRF = 0.398  IsQA-Result: accIsQA = 0.951


  8%|█████▊                                                                     | 2901/37371 [33:59<9:23:12,  1.02it/s]

<Last 100 Steps MeanValue> Setp-2900 IsQA-Loss: 0.184 CRF-Loss: 59.074  CRF-Result: accCRF = 0.404  IsQA-Result: accIsQA = 0.951


  8%|██████                                                                     | 3000/37371 [35:04<6:19:57,  1.51it/s]

<Last 100 Steps MeanValue> Setp-3000 IsQA-Loss: 0.184 CRF-Loss: 58.487  CRF-Result: accCRF = 0.408  IsQA-Result: accIsQA = 0.952
Eval on Devset...
<本次评估结果> IsQA-Loss: 0.060 CRF-Loss: 28.914 CRF-Result: accCRF = 0.673 IsQA-Result: accIsQA = 0.992 Final-Result：accFinal = 0.673
Devdata 精度提升 备份模型至save_model/back_model.bin


  8%|██████▏                                                                    | 3101/37371 [36:59<8:25:02,  1.13it/s]

<Last 100 Steps MeanValue> Setp-3100 IsQA-Loss: 0.184 CRF-Loss: 57.869  CRF-Result: accCRF = 0.413  IsQA-Result: accIsQA = 0.952


  9%|██████▍                                                                    | 3201/37371 [38:07<9:28:21,  1.00it/s]

<Last 100 Steps MeanValue> Setp-3200 IsQA-Loss: 0.183 CRF-Loss: 57.329  CRF-Result: accCRF = 0.417  IsQA-Result: accIsQA = 0.952


  9%|██████▌                                                                    | 3301/37371 [39:14<8:53:01,  1.07it/s]

<Last 100 Steps MeanValue> Setp-3300 IsQA-Loss: 0.182 CRF-Loss: 56.724  CRF-Result: accCRF = 0.422  IsQA-Result: accIsQA = 0.952


  9%|██████▊                                                                    | 3401/37371 [40:19<9:53:42,  1.05s/it]

<Last 100 Steps MeanValue> Setp-3400 IsQA-Loss: 0.181 CRF-Loss: 56.196  CRF-Result: accCRF = 0.426  IsQA-Result: accIsQA = 0.953


  9%|███████                                                                    | 3501/37371 [41:25<9:25:04,  1.00s/it]

<Last 100 Steps MeanValue> Setp-3500 IsQA-Loss: 0.180 CRF-Loss: 55.685  CRF-Result: accCRF = 0.430  IsQA-Result: accIsQA = 0.953


 10%|███████▏                                                                   | 3601/37371 [42:32<9:51:00,  1.05s/it]

<Last 100 Steps MeanValue> Setp-3600 IsQA-Loss: 0.180 CRF-Loss: 55.123  CRF-Result: accCRF = 0.435  IsQA-Result: accIsQA = 0.953


 10%|███████▎                                                                  | 3701/37371 [43:38<10:19:28,  1.10s/it]

<Last 100 Steps MeanValue> Setp-3700 IsQA-Loss: 0.180 CRF-Loss: 54.635  CRF-Result: accCRF = 0.439  IsQA-Result: accIsQA = 0.953


 10%|███████▋                                                                   | 3801/37371 [44:46<9:22:13,  1.00s/it]

<Last 100 Steps MeanValue> Setp-3800 IsQA-Loss: 0.179 CRF-Loss: 54.214  CRF-Result: accCRF = 0.442  IsQA-Result: accIsQA = 0.953


 10%|███████▊                                                                   | 3901/37371 [45:53<9:59:45,  1.08s/it]

<Last 100 Steps MeanValue> Setp-3900 IsQA-Loss: 0.179 CRF-Loss: 53.817  CRF-Result: accCRF = 0.446  IsQA-Result: accIsQA = 0.953


 11%|████████                                                                   | 4000/37371 [46:57<6:12:14,  1.49it/s]

<Last 100 Steps MeanValue> Setp-4000 IsQA-Loss: 0.178 CRF-Loss: 53.330  CRF-Result: accCRF = 0.449  IsQA-Result: accIsQA = 0.953
Eval on Devset...
<本次评估结果> IsQA-Loss: 0.059 CRF-Loss: 26.917 CRF-Result: accCRF = 0.706 IsQA-Result: accIsQA = 0.992 Final-Result：accFinal = 0.706
Devdata 精度提升 备份模型至save_model/back_model.bin


 11%|████████▏                                                                  | 4101/37371 [48:52<9:29:53,  1.03s/it]

<Last 100 Steps MeanValue> Setp-4100 IsQA-Loss: 0.177 CRF-Loss: 52.841  CRF-Result: accCRF = 0.454  IsQA-Result: accIsQA = 0.954


 11%|████████▎                                                                 | 4201/37371 [49:59<10:12:21,  1.11s/it]

<Last 100 Steps MeanValue> Setp-4200 IsQA-Loss: 0.176 CRF-Loss: 52.368  CRF-Result: accCRF = 0.457  IsQA-Result: accIsQA = 0.954


 12%|████████▌                                                                 | 4301/37371 [51:04<10:01:06,  1.09s/it]

<Last 100 Steps MeanValue> Setp-4300 IsQA-Loss: 0.175 CRF-Loss: 51.950  CRF-Result: accCRF = 0.461  IsQA-Result: accIsQA = 0.954


 12%|████████▋                                                                 | 4401/37371 [52:11<10:25:49,  1.14s/it]

<Last 100 Steps MeanValue> Setp-4400 IsQA-Loss: 0.174 CRF-Loss: 51.517  CRF-Result: accCRF = 0.465  IsQA-Result: accIsQA = 0.955


 12%|████████▉                                                                 | 4501/37371 [53:18<10:38:05,  1.16s/it]

<Last 100 Steps MeanValue> Setp-4500 IsQA-Loss: 0.173 CRF-Loss: 51.121  CRF-Result: accCRF = 0.469  IsQA-Result: accIsQA = 0.955


 12%|█████████▏                                                                 | 4601/37371 [54:25<9:52:45,  1.09s/it]

<Last 100 Steps MeanValue> Setp-4600 IsQA-Loss: 0.171 CRF-Loss: 50.801  CRF-Result: accCRF = 0.472  IsQA-Result: accIsQA = 0.955


 13%|█████████▎                                                                | 4701/37371 [55:33<10:25:30,  1.15s/it]

<Last 100 Steps MeanValue> Setp-4700 IsQA-Loss: 0.171 CRF-Loss: 50.489  CRF-Result: accCRF = 0.474  IsQA-Result: accIsQA = 0.955


 13%|█████████▋                                                                 | 4801/37371 [56:41<9:56:56,  1.10s/it]

<Last 100 Steps MeanValue> Setp-4800 IsQA-Loss: 0.171 CRF-Loss: 50.169  CRF-Result: accCRF = 0.477  IsQA-Result: accIsQA = 0.955


 13%|█████████▋                                                                | 4901/37371 [57:48<10:31:04,  1.17s/it]

<Last 100 Steps MeanValue> Setp-4900 IsQA-Loss: 0.170 CRF-Loss: 49.769  CRF-Result: accCRF = 0.480  IsQA-Result: accIsQA = 0.955


 13%|██████████                                                                 | 5000/37371 [58:52<6:11:39,  1.45it/s]

<Last 100 Steps MeanValue> Setp-5000 IsQA-Loss: 0.170 CRF-Loss: 49.456  CRF-Result: accCRF = 0.483  IsQA-Result: accIsQA = 0.956
Eval on Devset...
<本次评估结果> IsQA-Loss: 0.059 CRF-Loss: 26.119 CRF-Result: accCRF = 0.713 IsQA-Result: accIsQA = 0.992 Final-Result：accFinal = 0.713
Devdata 精度提升 备份模型至save_model/back_model.bin


 14%|█████████▊                                                              | 5101/37371 [1:00:49<10:51:54,  1.21s/it]

<Last 100 Steps MeanValue> Setp-5100 IsQA-Loss: 0.169 CRF-Loss: 49.122  CRF-Result: accCRF = 0.486  IsQA-Result: accIsQA = 0.956


 14%|██████████                                                              | 5201/37371 [1:01:57<10:37:05,  1.19s/it]

<Last 100 Steps MeanValue> Setp-5200 IsQA-Loss: 0.168 CRF-Loss: 48.813  CRF-Result: accCRF = 0.489  IsQA-Result: accIsQA = 0.956


 14%|██████████▏                                                             | 5301/37371 [1:03:02<10:40:44,  1.20s/it]

<Last 100 Steps MeanValue> Setp-5300 IsQA-Loss: 0.168 CRF-Loss: 48.515  CRF-Result: accCRF = 0.491  IsQA-Result: accIsQA = 0.956


 14%|██████████▍                                                             | 5401/37371 [1:04:09<10:42:30,  1.21s/it]

<Last 100 Steps MeanValue> Setp-5400 IsQA-Loss: 0.167 CRF-Loss: 48.255  CRF-Result: accCRF = 0.494  IsQA-Result: accIsQA = 0.956


 15%|██████████▌                                                             | 5501/37371 [1:05:16<11:03:47,  1.25s/it]

<Last 100 Steps MeanValue> Setp-5500 IsQA-Loss: 0.166 CRF-Loss: 47.938  CRF-Result: accCRF = 0.496  IsQA-Result: accIsQA = 0.956


 15%|██████████▊                                                             | 5601/37371 [1:06:22<11:15:29,  1.28s/it]

<Last 100 Steps MeanValue> Setp-5600 IsQA-Loss: 0.166 CRF-Loss: 47.626  CRF-Result: accCRF = 0.499  IsQA-Result: accIsQA = 0.956


 15%|██████████▉                                                             | 5701/37371 [1:07:29<11:01:29,  1.25s/it]

<Last 100 Steps MeanValue> Setp-5700 IsQA-Loss: 0.165 CRF-Loss: 47.305  CRF-Result: accCRF = 0.502  IsQA-Result: accIsQA = 0.956


 16%|███████████▏                                                            | 5801/37371 [1:08:34<10:03:29,  1.15s/it]

<Last 100 Steps MeanValue> Setp-5800 IsQA-Loss: 0.164 CRF-Loss: 47.004  CRF-Result: accCRF = 0.504  IsQA-Result: accIsQA = 0.957


 16%|███████████▎                                                            | 5901/37371 [1:09:41<10:39:25,  1.22s/it]

<Last 100 Steps MeanValue> Setp-5900 IsQA-Loss: 0.164 CRF-Loss: 46.772  CRF-Result: accCRF = 0.506  IsQA-Result: accIsQA = 0.957


 16%|███████████▋                                                             | 6000/37371 [1:10:45<5:37:29,  1.55it/s]

<Last 100 Steps MeanValue> Setp-6000 IsQA-Loss: 0.163 CRF-Loss: 46.499  CRF-Result: accCRF = 0.509  IsQA-Result: accIsQA = 0.957
Eval on Devset...
<本次评估结果> IsQA-Loss: 0.053 CRF-Loss: 24.179 CRF-Result: accCRF = 0.738 IsQA-Result: accIsQA = 0.992 Final-Result：accFinal = 0.738
Devdata 精度提升 备份模型至save_model/back_model.bin


 16%|███████████▊                                                            | 6101/37371 [1:12:44<11:27:51,  1.32s/it]

<Last 100 Steps MeanValue> Setp-6100 IsQA-Loss: 0.163 CRF-Loss: 46.232  CRF-Result: accCRF = 0.511  IsQA-Result: accIsQA = 0.957


 17%|███████████▉                                                            | 6201/37371 [1:13:52<11:21:15,  1.31s/it]

<Last 100 Steps MeanValue> Setp-6200 IsQA-Loss: 0.162 CRF-Loss: 45.991  CRF-Result: accCRF = 0.513  IsQA-Result: accIsQA = 0.957


 17%|████████████▏                                                           | 6301/37371 [1:14:59<11:38:02,  1.35s/it]

<Last 100 Steps MeanValue> Setp-6300 IsQA-Loss: 0.162 CRF-Loss: 45.729  CRF-Result: accCRF = 0.515  IsQA-Result: accIsQA = 0.957


 17%|████████████▎                                                           | 6401/37371 [1:16:08<11:36:05,  1.35s/it]

<Last 100 Steps MeanValue> Setp-6400 IsQA-Loss: 0.162 CRF-Loss: 45.480  CRF-Result: accCRF = 0.518  IsQA-Result: accIsQA = 0.957


 17%|████████████▌                                                           | 6501/37371 [1:17:16<12:08:23,  1.42s/it]

<Last 100 Steps MeanValue> Setp-6500 IsQA-Loss: 0.161 CRF-Loss: 45.234  CRF-Result: accCRF = 0.520  IsQA-Result: accIsQA = 0.957


 18%|████████████▋                                                           | 6601/37371 [1:18:23<11:43:17,  1.37s/it]

<Last 100 Steps MeanValue> Setp-6600 IsQA-Loss: 0.161 CRF-Loss: 44.976  CRF-Result: accCRF = 0.522  IsQA-Result: accIsQA = 0.957


 18%|████████████▉                                                           | 6701/37371 [1:19:31<11:32:19,  1.35s/it]

<Last 100 Steps MeanValue> Setp-6700 IsQA-Loss: 0.161 CRF-Loss: 44.726  CRF-Result: accCRF = 0.524  IsQA-Result: accIsQA = 0.957


 18%|█████████████                                                           | 6801/37371 [1:20:39<11:31:42,  1.36s/it]

<Last 100 Steps MeanValue> Setp-6800 IsQA-Loss: 0.160 CRF-Loss: 44.487  CRF-Result: accCRF = 0.526  IsQA-Result: accIsQA = 0.957


 18%|█████████████▎                                                          | 6901/37371 [1:21:46<11:52:07,  1.40s/it]

<Last 100 Steps MeanValue> Setp-6900 IsQA-Loss: 0.160 CRF-Loss: 44.262  CRF-Result: accCRF = 0.528  IsQA-Result: accIsQA = 0.957


 19%|█████████████▋                                                           | 7000/37371 [1:22:50<5:05:05,  1.66it/s]

<Last 100 Steps MeanValue> Setp-7000 IsQA-Loss: 0.159 CRF-Loss: 44.064  CRF-Result: accCRF = 0.530  IsQA-Result: accIsQA = 0.957
Eval on Devset...
<本次评估结果> IsQA-Loss: 0.058 CRF-Loss: 23.928 CRF-Result: accCRF = 0.742 IsQA-Result: accIsQA = 0.992 Final-Result：accFinal = 0.742
Devdata 精度提升 备份模型至save_model/back_model.bin


 19%|█████████████▋                                                          | 7101/37371 [1:24:47<11:19:29,  1.35s/it]

<Last 100 Steps MeanValue> Setp-7100 IsQA-Loss: 0.159 CRF-Loss: 43.892  CRF-Result: accCRF = 0.531  IsQA-Result: accIsQA = 0.957


 19%|█████████████▊                                                          | 7201/37371 [1:25:54<11:56:40,  1.43s/it]

<Last 100 Steps MeanValue> Setp-7200 IsQA-Loss: 0.158 CRF-Loss: 43.686  CRF-Result: accCRF = 0.533  IsQA-Result: accIsQA = 0.958


 20%|██████████████                                                          | 7301/37371 [1:27:02<11:50:34,  1.42s/it]

<Last 100 Steps MeanValue> Setp-7300 IsQA-Loss: 0.157 CRF-Loss: 43.478  CRF-Result: accCRF = 0.535  IsQA-Result: accIsQA = 0.958


 20%|██████████████▎                                                         | 7401/37371 [1:28:07<11:30:37,  1.38s/it]

<Last 100 Steps MeanValue> Setp-7400 IsQA-Loss: 0.157 CRF-Loss: 43.246  CRF-Result: accCRF = 0.537  IsQA-Result: accIsQA = 0.958


 20%|██████████████▍                                                         | 7501/37371 [1:29:15<13:23:39,  1.61s/it]

<Last 100 Steps MeanValue> Setp-7500 IsQA-Loss: 0.156 CRF-Loss: 43.065  CRF-Result: accCRF = 0.539  IsQA-Result: accIsQA = 0.958


 20%|██████████████▋                                                         | 7601/37371 [1:30:22<12:17:30,  1.49s/it]

<Last 100 Steps MeanValue> Setp-7600 IsQA-Loss: 0.156 CRF-Loss: 42.865  CRF-Result: accCRF = 0.540  IsQA-Result: accIsQA = 0.958


 21%|██████████████▊                                                         | 7701/37371 [1:31:31<12:01:02,  1.46s/it]

<Last 100 Steps MeanValue> Setp-7700 IsQA-Loss: 0.156 CRF-Loss: 42.667  CRF-Result: accCRF = 0.542  IsQA-Result: accIsQA = 0.958


 21%|███████████████                                                         | 7801/37371 [1:32:42<13:11:58,  1.61s/it]

<Last 100 Steps MeanValue> Setp-7800 IsQA-Loss: 0.156 CRF-Loss: 42.489  CRF-Result: accCRF = 0.544  IsQA-Result: accIsQA = 0.958


 21%|███████████████▏                                                        | 7901/37371 [1:33:51<11:54:49,  1.46s/it]

<Last 100 Steps MeanValue> Setp-7900 IsQA-Loss: 0.155 CRF-Loss: 42.312  CRF-Result: accCRF = 0.546  IsQA-Result: accIsQA = 0.958


 21%|███████████████▋                                                         | 8000/37371 [1:34:53<4:51:19,  1.68it/s]

<Last 100 Steps MeanValue> Setp-8000 IsQA-Loss: 0.155 CRF-Loss: 42.143  CRF-Result: accCRF = 0.547  IsQA-Result: accIsQA = 0.958
Eval on Devset...


 21%|███████████████▏                                                       | 8001/37371 [1:35:43<124:25:39, 15.25s/it]

<本次评估结果> IsQA-Loss: 0.060 CRF-Loss: 22.856 CRF-Result: accCRF = 0.737 IsQA-Result: accIsQA = 0.992 Final-Result：accFinal = 0.737


 22%|███████████████▌                                                        | 8101/37371 [1:36:49<11:32:20,  1.42s/it]

<Last 100 Steps MeanValue> Setp-8100 IsQA-Loss: 0.155 CRF-Loss: 41.979  CRF-Result: accCRF = 0.549  IsQA-Result: accIsQA = 0.958


 22%|███████████████▊                                                        | 8201/37371 [1:38:02<13:21:24,  1.65s/it]

<Last 100 Steps MeanValue> Setp-8200 IsQA-Loss: 0.154 CRF-Loss: 41.803  CRF-Result: accCRF = 0.551  IsQA-Result: accIsQA = 0.958


 22%|███████████████▉                                                        | 8301/37371 [1:39:13<12:07:15,  1.50s/it]

<Last 100 Steps MeanValue> Setp-8300 IsQA-Loss: 0.154 CRF-Loss: 41.626  CRF-Result: accCRF = 0.552  IsQA-Result: accIsQA = 0.958


 22%|████████████████▏                                                       | 8401/37371 [1:40:22<12:44:02,  1.58s/it]

<Last 100 Steps MeanValue> Setp-8400 IsQA-Loss: 0.154 CRF-Loss: 41.477  CRF-Result: accCRF = 0.554  IsQA-Result: accIsQA = 0.958


 23%|████████████████▍                                                       | 8501/37371 [1:41:31<11:47:26,  1.47s/it]

<Last 100 Steps MeanValue> Setp-8500 IsQA-Loss: 0.153 CRF-Loss: 41.288  CRF-Result: accCRF = 0.556  IsQA-Result: accIsQA = 0.958


 23%|████████████████▌                                                       | 8601/37371 [1:42:39<12:40:48,  1.59s/it]

<Last 100 Steps MeanValue> Setp-8600 IsQA-Loss: 0.153 CRF-Loss: 41.143  CRF-Result: accCRF = 0.557  IsQA-Result: accIsQA = 0.958


 23%|████████████████▊                                                        | 8626/37371 [1:42:55<5:43:00,  1.40it/s]


KeyboardInterrupt: 

In [None]:
 def predict(self,tokens_id_l, token_type_ids_l):
        tokens_x_2d = torch.LongTensor(tokens_id_l).to(self.device)
        token_type_ids_2d = torch.LongTensor(token_type_ids_l).to(self.device)

        batch_size, seq_length = tokens_x_2d[:,1:].size()
        self.PreModel.eval()
        with torch.no_grad():
            emb, _ = self.PreModel(tokens_x_2d,token_type_ids=token_type_ids_2d)

        ## [CLS] for IsQA  [batch_size, hidden_size]
        cls_emb = emb[:,0,:]
        ## [batch_size, 2]
        IsQA_logits = self.fc2(cls_emb)
        ## [batch_size, 1]
        IsQA_prediction = IsQA_logits.argmax(dim=-1)

        # CRF mask
        mask = np.ones(shape=[batch_size, seq_length], dtype=np.uint8)
        mask = torch.ByteTensor(mask).to(self.device)
        # [batch_size, seq_len, 4]
        crf_logits = self.CRF_fc1(emb[:,1:,:])
        _, CRFprediction = self.CRF.forward(feats=crf_logits, mask=mask)

        return IsQA_prediction.to("cpu"), CRFprediction.to("cpu")

In [None]:
def Demo(model, q, e):
    tokens = tokenizer.tokenize('[CLS]' + q + '[SEP]' + e)  # list
    if len(tokens) > 512:
        tokens = tokens[:512]
    tokens_id = [tokenizer.convert_tokens_to_ids(tokens)]#[[101,...,102,...]]
    token_type_ids = [0 if i <= tokens_id[0].index(102) else 1 for i in range(len(tokens_id))]
    IsQA_prediction, CRFprediction = model.module.predict(tokens_id,token_type_ids)
    CRFprediction = CRFprediction.numpy()[0]
    IsQA_prediction = IsQA_prediction.numpy()[0]
    answer = ""
    if IsQA_prediction==1:
        for i in range(len(tokens[1:])):
            if CRFprediction[i].item()==1:
                answer = answer + tokens[1:][i]
    return answer

def prepare_knowledge(knowledge_path,Stopword_path):
    def del_stopword(line, Stopword, ngram=False):
        line = list(jieba.cut(line))
        new = [word for word in line if word not in Stopword]
        if ngram:  # 返回2元语法
            N = len(line)
            for i, word_i in enumerate(line):
                for j in range(min(i + 1, N - 1), N):
                    word_j = line[j]
                    if word_i not in Stopword and word_j not in Stopword:
                        new.append(word_i + ' ' + word_j)
        return new  # [w1,w2,...]
    print("正在准备知识库...")
    dataset = Knowledge(knowledge_path)
    with open(Stopword_path, "r", encoding="gbk") as f:
        Stopword = set(f.read().splitlines())
    documents = dataset.evidences
    corpus = [" ".join(del_stopword(e, Stopword)) for e in documents]
    vectorizer = CountVectorizer()  # ngram_range=(1,2)
    count = vectorizer.fit_transform(corpus)
    # 计算TF-IDF向量
    TFIDF = TfidfTransformer()
    tfidf_matrix = TFIDF.fit_transform(count)
    d_matrix = np.array(tfidf_matrix.toarray())
    vocabulary_ = vectorizer.vocabulary_
    return d_matrix, vocabulary_, del_stopword, Stopword, dataset

def QA(model, question, xu, knowledge):
    # 按相关度从大到小 #[1,2,3,...]
    xu.reverse()
    xu = [item[0] for item in xu]
    ## 有序字典 按相关度从大到小插入key
    result =collections.OrderedDict()
    q = question
    tokens_id_l = []
    token_type_ids_l = []
    tokens_l = []
    for index in xu:
        e=knowledge.evidences[index]
        tokens = tokenizer.tokenize('[CLS]' + q + '[SEP]' + e)  # list
        if len(tokens) > 512:
            tokens = tokens[:512]
        tokens_id = tokenizer.convert_tokens_to_ids(tokens)  # [101,...,102,...]
        token_type_ids = [0 if i <= tokens_id.index(102) else 1 for i in range(len(tokens_id))]
        tokens_id_l.append(tokens_id)
        token_type_ids_l.append(token_type_ids)
        tokens_l.append(tokens[1:])
    ## pad
    max_len = max([len(x) for x in tokens_id_l])
    tokens_id_l = [x+(max_len - len(x))*tokenizer.convert_tokens_to_ids(['[PAD]']) for x in tokens_id_l ]
    token_type_ids_l = [x+(max_len - len(x))*[1] for x in token_type_ids_l ]

    ## 批预测
    IsQA_prediction, CRFprediction = model.module.predict(tokens_id_l,token_type_ids_l)
    CRFprediction = CRFprediction.numpy() #[batch_size, max_len]
    IsQA_prediction = IsQA_prediction.numpy()#[batch_size, 1]
    for k in range(len(xu)):
        answer = ""
        tokens = tokens_l[k]
        if IsQA_prediction[k]==1:#[cls]判断
            for i in range(len(tokens)):
                if CRFprediction[k,i].item()==1:
                    answer = answer + tokens[i]
        # 记录answer
        if answer in result:
            result[answer] = result[answer] + 1
        else:
            if answer:
                result[answer] = 1
    return result


In [None]:
model = torch.load(hp.model_path)
            ques_num=1
            while True:
                print("请输入问题-{}:".format(ques_num))
                question = input()
                if question == "OVER":
                    print("问答结束！")
                    break
                print("请输入文章：")
                evidence = input()
                # print("正在解析...")
                start = time.time()
                answer = Demo(model,question,evidence)
                end = time.time()
                if answer:
                    print("问题-{}的答案是：{}".format(ques_num,answer))
                    print("耗时:{:.2f}毫秒".format((end-start)*1e3))
                else:
                    print("文章中没有答案")
                ques_num = ques_num + 1
        else:
            print("没有可用模型！")

In [None]:
model = torch.load(hp.model_path)
            ques_num=1
            #准备知识库
            d_matrix, vocabulary_, del_stopword, Stopword, knowledge = prepare_knowledge(hp.knowledge_path,hp.Stopword_path)
            while True:
                # try:
                print("请输入问题-{}:".format(ques_num))
                question = input()
                if question == "OVER":
                    print("问答结束！")
                    break

                # 创建问句tf-idf向量
                q_vector = np.zeros([1, d_matrix.shape[1]])
                q_list = del_stopword(question,Stopword,ngram=False)
                for word in q_list:
                    if word in vocabulary_:
                        q_vector[0,vocabulary_[word]] = 1.
                dot = (np.mat(d_matrix))*(np.mat(q_vector.T))
                xu=dot.argsort(0)[-15:].tolist()# [[12], [37], [10]] 最大的15个索引

                start = time.time()
                answer = QA(model, question, xu, knowledge)
                end = time.time()

                print("问题-{}的答案是:".format(ques_num))
                for ai,a in enumerate(answer):
                    print("推荐答案No.{}: {}".format(ai+1,a))
                print("耗时:{:.2f}毫秒".format((end-start)*1e3))

                ques_num = ques_num + 1
        else:
            print("没有可用模型！")

## 参考资料：https://github.com/Hanlard/Bert-for-WebQA/blob/master/model.py