以下部分为测试UE Rewriter在pretrained dialogue system上测试提供数据。rewriter_data.txt为将所有unseen entity 经UE Rewriter替换成了seen entity后形成的“新”数据集。

### 0. Data

In [1]:
#read all data
docs=[]
import json
f = open('../wizard_of_wikipedia/data.json')
data = json.load(f)
  
for i in data:
    docs.append(i)

f.close()

In [2]:
len(docs)

22311

### 1. whether inputs contain unseen entities
- 定位id

In [3]:
from tqdm import tqdm
import re
import numpy as np
from transformers import BertTokenizer
import nltk
import pandas as pd

all_data = []
unseen_dataset = pd.DataFrame()
sentences=[]
unseen_entities=[]
doc_nums=[]
dialog_indices=[]


for doc_num in tqdm(range(len(docs))):
    dialog = []
    for i in docs[doc_num]['dialog']:
        dialog.append(i['text'])
    dialog_lower = [text.lower() for text in dialog]
    all_data.append(dialog_lower)
    
    #build vocabulary
    text = ''.join(dialog)
    clean_text = re.sub(r"[,.;@#?!&$/]+\ *", " ", text)
    vocabulary = set(clean_text.lower().split())
    
    #BERT tokenization
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    indexed_tokens = []
    for text in dialog:   
        tokenized_text = tokenizer.tokenize(text)
        indexed_tokens.append(tokenizer.convert_tokens_to_ids(tokenized_text))
    
    #BERT tokenized vocabulary
    compact = []
    for i in indexed_tokens:
        compact.extend(i) 
    tokenized_vocab = set(compact)
    
    #BERT text vocabulary
    new_vocab = set(tokenizer.convert_ids_to_tokens(tokenized_vocab))
    
    #unseen words in BERT
    unseen = vocabulary.difference(new_vocab)
    unseen = list(unseen)

    #find sentences with unseen entities
    for word in unseen:
        indices = [i for i, x in enumerate([word in i for i in dialog_lower]) if x == True] 
        for index in indices:
            sentence = dialog_lower[index]
            result = nltk.pos_tag(nltk.word_tokenize(sentence))
            result = dict(result)
            if word in result:
                if result[word] in ['NN', 'NNS', 'NNP', 'NNPS']:
                    sentences.append(dialog[index])
                    unseen_entities.append(word)
                    doc_nums.append(doc_num)
                    dialog_indices.append(index)

unseen_dataset['unseen entity'] = unseen_entities
#unseen_dataset['sentence'] = sentences
unseen_dataset['doc number'] = doc_nums
unseen_dataset['dialog index'] = dialog_indices

  0%|                                        | 2/2312 [00:16<5:21:14,  8.34s/it]


KeyboardInterrupt: 

### 2. UE Rewriter
- window_size作为参数，用id实现
- 注意window size是re-writer的参数，而不是放入生成模型的参数
- mask预测的模型，目前只选取了概率最高的，是否要做多个实验需要考虑

In [None]:
import torch
from transformers import BertTokenizer, BertModel, BertForMaskedLM

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
model.eval()
# model.to('cuda')  # if you have gpu


def predict_masked_sent(text, top_k=5):
    # Tokenize input
    text = "[CLS] %s [SEP]"%text
    tokenized_text = tokenizer.tokenize(text)
    masked_index = tokenized_text.index("[MASK]")
    indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
    tokens_tensor = torch.tensor([indexed_tokens])
    # tokens_tensor = tokens_tensor.to('cuda')    # if you have gpu

    # Predict all tokens
    with torch.no_grad():
        outputs = model(tokens_tensor)
        predictions = outputs[0]

    probs = torch.nn.functional.softmax(predictions[0, masked_index], dim=-1)
    top_k_weights, top_k_indices = torch.topk(probs, top_k, sorted=True)
    
    pred_with_prob = {}
    for i, pred_idx in enumerate(top_k_indices):
        predicted_token = tokenizer.convert_ids_to_tokens([pred_idx])[0]
        token_weight = top_k_weights[i]
        pred_with_prob[predicted_token] = float(token_weight)
        #print("[MASK]: '%s'"%predicted_token, " | weights:", float(token_weight))
    return pred_with_prob

In [None]:
#Replace all unseen entities with [MASK] and put the masked sentences into all_data
window_size = 0
ex = [] #long sentence cannot be tokenized

for i in tqdm(range(len(unseen_dataset))):
    doc_num = unseen_dataset['doc number'][i]
    dialog_num = unseen_dataset['dialog index'][i]
    unseen_entity = unseen_dataset['unseen entity'][i]
    
    unseen_sentence = all_data[doc_num][dialog_num]
    mask_sentence = unseen_sentence.replace(unseen_entity, "[MASK]")
    
    if window_size==0 or dialog_num < window_size:
        try:
            pred = predict_masked_sent(mask_sentence, top_k=1)
            UE_pred = list(pred.keys())[0]
        except ValueError:
            UE_pred = unseen_entity
        except RuntimeError:
            ex.append((doc_num, dialog_num))
    
    else:
        context = ""
        for sen in all_data[doc_num][dialog_num-window_size : dialog_num]:
            context = context+sen
        mask_sentence = context+mask_sentence
        try:
            pred = predict_masked_sent(mask_sentence, top_k=1)
            UE_pred = list(pred.keys())[0]
        except ValueError:
            UE_pred = unseen_entity
        except RuntimeError:
            ex.append((doc_num, dialog_num))
    
    rewrited_sentence = mask_sentence.replace("[MASK]", UE_pred)
    del all_data[doc_num][dialog_num]
    all_data[doc_num].insert(dialog_num, rewrited_sentence)

将经过UE-Rewriter的数据保存成txt

In [None]:
file = open('./rewrite_data_w0.txt','a')
for dialog in all_data:
    for sen in dialog:
        file.write(sen)
        file.write('\n')
    file.write('##')    
file.close()

### 3. To read the new txt
- 不同window size rewrite之后，保存下来，再放pretrain，以免重复工作

In [16]:
file = open('./rewrite_data_w0/rewrite_data_w0_A.txt','r')
file_data = file.read() 
file_data = file_data.split('##')

while "" in file_data:
    file_data.remove("")
    
data = []
for dialog in file_data:
    tep_list = dialog.split('\n')
    del(tep_list[-1])
    data.append(tep_list)

In [17]:
len(data)

4432

In [18]:
def has_duplicates(t):
    for i in t:
        if i in t.pop(t.index(i)):
            return True
    return False

In [19]:
has_duplicates(data)

False

In [21]:
data[9]

['sh! that was a super tough game. i am really looking forward to student tickets to all the home games!',
 'i live in dallas tx now so i only get to watch on tv.  two years ago alabama played usc out here in arlington, tx and i was able to go.  i hope you enjoy your senior year.  you only get to do it once so make the most of it.',
 'thank you. i am going to try.  my family will make sure of it, too - always. ',
 'i am sure you will have a blast.']